Files
railway_cloud/app/services/database.py
2025-09-27 09:38:27 +08:00

335 lines
11 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from sqlalchemy.orm import Session
from sqlalchemy import text, MetaData, Table, Column, create_engine, inspect
from sqlalchemy.exc import SQLAlchemyError
from typing import List, Dict, Any, Optional
from ..core.database import engine
from ..utils.file_import import FileImportUtils
import pandas as pd
class DatabaseService:
    """Stateless helpers for inspecting and manipulating database tables.

    Every public method returns a plain dict carrying at least ``success``
    (bool) and ``message`` (str); errors are reported through that dict
    rather than raised. Identifiers in hand-built SQL are quoted with the
    module-level ``engine``'s dialect.
    """

    @staticmethod
    def _quote_name(name: str) -> str:
        """Quote ``name`` for the engine's dialect only if the dialect needs it.

        Ordinary names are returned unchanged; names that need quoting
        (e.g. containing spaces or hyphens, or reserved words) are wrapped in
        the dialect's quote characters with embedded quotes escaped.
        """
        return engine.dialect.identifier_preparer.quote(name)

    @staticmethod
    def _quote_name_always(name: Any) -> str:
        """Unconditionally quote ``str(name)`` as an identifier, escaping embedded quotes."""
        return engine.dialect.identifier_preparer.quote_identifier(str(name))

    @staticmethod
    def _rows_to_dicts(result) -> List[Dict[str, Any]]:
        """Materialise every row of ``result`` as a ``{column: value}`` dict."""
        columns = list(result.keys())
        return [dict(zip(columns, row)) for row in result]

    @staticmethod
    def execute_sql(db: Session, sql: str) -> Dict[str, Any]:
        """Execute one arbitrary SQL statement on ``db``.

        Args:
            db: Session the statement runs on.
            sql: Raw SQL text, executed through ``text()`` (so ``:name``
                tokens are interpreted as bind parameters).

        Returns:
            If the statement produces a result set (SELECT, SHOW, DESCRIBE,
            WITH ... SELECT, ...): ``{"success", "message", "data"}`` where
            ``data`` is a list of row dicts. Otherwise
            ``{"success", "message", "rows_affected"}``. On any error the
            session is rolled back and ``{"success": False, "message"}`` is
            returned.

        Statements whose text does not start with SELECT are committed;
        SELECT statements are left in the caller's transaction.
        """
        try:
            result = db.execute(text(sql))
            # Commit decision keeps the original rule: only SELECT skips it.
            needs_commit = not sql.strip().upper().startswith('SELECT')
            if result.returns_rows:
                # Ask the cursor whether rows came back instead of guessing from
                # the leading keyword, so SHOW / DESCRIBE / WITH / "(SELECT"
                # keep their result set. Fetch before committing.
                data = DatabaseService._rows_to_dicts(result)
                if needs_commit:
                    db.commit()
                return {
                    "success": True,
                    "message": "查询成功",
                    "data": data
                }
            rows_affected = result.rowcount
            if needs_commit:
                db.commit()
            return {
                "success": True,
                "message": "执行成功",
                "rows_affected": rows_affected
            }
        except SQLAlchemyError as e:
            db.rollback()
            return {
                "success": False,
                "message": f"SQL执行失败: {str(e)}"
            }
        except Exception as e:
            db.rollback()
            return {
                "success": False,
                "message": f"未知错误: {str(e)}"
            }

    @staticmethod
    def get_table_data(db: Session, table_name: str, limit: int = 100, offset: int = 0) -> Dict[str, Any]:
        """Return one page of rows from ``table_name`` plus its total row count.

        Args:
            db: Session used for the queries.
            table_name: Table to read; must be one of the tables reported by
                the inspector for ``engine``.
            limit: Maximum number of rows to return.
            offset: Number of rows to skip.

        Returns:
            ``{"success", "message", "data", "total_count"}`` on success,
            otherwise ``{"success": False, "message"}``.
        """
        try:
            # Only names the inspector reports are accepted; this also keeps
            # arbitrary SQL out of the FROM clause.
            inspector = inspect(engine)
            if table_name not in inspector.get_table_names():
                return {
                    "success": False,
                    "message": f"{table_name} 不存在"
                }
            # Quote the validated name so tables whose names contain special
            # characters (import_file_to_database allows them) can be read.
            quoted_table = DatabaseService._quote_name(table_name)

            count_row = db.execute(text(f"SELECT COUNT(*) as total FROM {quoted_table}")).fetchone()
            total_count = count_row.total if count_row else 0

            # LIMIT/OFFSET go through bound parameters, not string formatting.
            result = db.execute(
                text(f"SELECT * FROM {quoted_table} LIMIT :limit OFFSET :offset"),
                {"limit": int(limit), "offset": int(offset)},
            )
            data = DatabaseService._rows_to_dicts(result)
            return {
                "success": True,
                "message": "获取数据成功",
                "data": data,
                "total_count": total_count
            }
        except SQLAlchemyError as e:
            db.rollback()
            return {
                "success": False,
                "message": f"获取表数据失败: {str(e)}"
            }
        except Exception as e:
            db.rollback()
            return {
                "success": False,
                "message": f"未知错误: {str(e)}"
            }

    @staticmethod
    def create_table(db: Session, table_name: str, columns: Dict[str, str], primary_key: Optional[str] = None) -> Dict[str, Any]:
        """Create ``table_name`` from a ``{column_name: column_type_ddl}`` mapping.

        Args:
            db: Session the DDL runs on (committed on success).
            table_name: Name of the new table.
            columns: Column name -> column type/definition text, in order.
            primary_key: Optional text placed inside ``PRIMARY KEY (...)``.

        Returns:
            ``{"success", "message"}``; the session is rolled back on error.

        NOTE(review): all names, types and ``primary_key`` are interpolated
        verbatim into the DDL (which lets callers pass pre-quoted names or a
        composite key such as ``"a, b"``); they must not come from untrusted
        input.
        """
        try:
            column_definitions = [f"{col_name} {col_type}" for col_name, col_type in columns.items()]
            if primary_key:
                column_definitions.append(f"PRIMARY KEY ({primary_key})")
            sql = f"CREATE TABLE {table_name} ({', '.join(column_definitions)})"
            db.execute(text(sql))
            db.commit()
            return {
                "success": True,
                "message": f"{table_name} 创建成功"
            }
        except SQLAlchemyError as e:
            db.rollback()
            return {
                "success": False,
                "message": f"创建表失败: {str(e)}"
            }
        except Exception as e:
            db.rollback()
            return {
                "success": False,
                "message": f"未知错误: {str(e)}"
            }

    @staticmethod
    def drop_table(db: Session, table_name: str) -> Dict[str, Any]:
        """Drop ``table_name`` if it exists.

        Names that exist as tables in ``engine``'s default schema are quoted
        for the dialect, so names with special characters can be dropped;
        anything else (e.g. schema-qualified or pre-quoted names) is passed
        through verbatim.

        Returns:
            ``{"success", "message"}``; the session is rolled back on error.
        """
        try:
            if table_name in inspect(engine).get_table_names():
                target = DatabaseService._quote_name(table_name)
            else:
                target = table_name
            sql = f"DROP TABLE IF EXISTS {target}"
            db.execute(text(sql))
            db.commit()
            return {
                "success": True,
                "message": f"{table_name} 删除成功"
            }
        except SQLAlchemyError as e:
            db.rollback()
            return {
                "success": False,
                "message": f"删除表失败: {str(e)}"
            }
        except Exception as e:
            db.rollback()
            return {
                "success": False,
                "message": f"未知错误: {str(e)}"
            }

    @staticmethod
    def import_data(db: Session, table_name: str, data: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Append ``data`` (a list of row dicts) to ``table_name``.

        The rows are written with ``DataFrame.to_sql(..., if_exists='append')``
        through ``engine`` (not through ``db``).

        Returns:
            ``{"success", "message"}``; an empty ``data`` is rejected.
        """
        try:
            if not data:
                return {
                    "success": False,
                    "message": "导入数据不能为空"
                }
            df = pd.DataFrame(data)
            df.to_sql(table_name, engine, if_exists='append', index=False)
            return {
                "success": True,
                "message": f"成功导入 {len(data)} 条数据到表 {table_name}"
            }
        except SQLAlchemyError as e:
            db.rollback()
            return {
                "success": False,
                "message": f"导入数据失败: {str(e)}"
            }
        except Exception as e:
            db.rollback()
            return {
                "success": False,
                "message": f"未知错误: {str(e)}"
            }

    @staticmethod
    def get_table_list() -> Dict[str, Any]:
        """List user-visible tables together with their column metadata.

        Tables whose names start with ``mysql`` and the ``accounts`` and
        ``apscheduler_jobs`` tables are excluded.

        Returns:
            ``{"success", "message", "data"}`` where ``data`` is a list of
            ``{"table_name", "columns": [{"name", "type", "nullable", "default"}]}``.
        """
        try:
            inspector = inspect(engine)
            table_names = [table for table in inspector.get_table_names()
                           if not table.startswith('mysql') and table != 'accounts' and table != 'apscheduler_jobs']
            tables_info = []
            for table_name in table_names:
                column_info = [
                    {
                        "name": column['name'],
                        "type": str(column['type']),
                        "nullable": column['nullable'],
                        "default": column['default']
                    }
                    for column in inspector.get_columns(table_name)
                ]
                tables_info.append({
                    "table_name": table_name,
                    "columns": column_info
                })
            return {
                "success": True,
                "message": "获取表名及字段信息成功",
                "data": tables_info
            }
        except Exception as e:
            return {
                "success": False,
                "message": f"获取表名失败: {str(e)}"
            }

    @staticmethod
    def import_file_to_database(db: Session, filename: str, file_content: str,
                                table_name: Optional[str] = None,
                                force_overwrite: bool = False) -> Dict[str, Any]:
        """Parse an uploaded file and load each non-empty dataset into a new table.

        ``FileImportUtils.parse_file`` yields items with ``table_name``,
        ``data`` (list of row dicts) and ``columns``. Each non-empty item goes
        to ``table_name`` when one is given, otherwise to the item's own
        ``table_name``. Each target table is created with the column types
        from ``FileImportUtils.prepare_table_columns`` (``TEXT`` when none is
        suggested) plus an ``id INT AUTO_INCREMENT PRIMARY KEY`` column,
        unless the data already has a column named ``id`` (any case).

        All targets are validated (no duplicates, no clash with existing
        tables unless ``force_overwrite``) before anything is dropped or
        created, so a failed validation leaves the database untouched.

        Args:
            db: Session used for the DROP/CREATE statements.
            filename: Original file name, passed to the parser.
            file_content: Raw file content, passed to the parser.
            table_name: Optional explicit target table name.
            force_overwrite: Drop and recreate target tables that already exist.

        Returns:
            For one table: ``{"success", "message", "table_name",
            "rows_imported", "columns"}``; for several: ``{"success",
            "message", "tables", "total_rows", "table_names"}``; on failure
            ``{"success": False, "message"}``.
        """
        try:
            file_data = FileImportUtils.parse_file(filename, file_content)
            if not file_data:
                return {
                    "success": False,
                    "message": "文件中没有找到有效数据"
                }

            # (target table, rows, columns) for every dataset that has rows.
            plan = []
            for item in file_data:
                data = item["data"]
                if not data:
                    continue
                target = table_name if table_name else item["table_name"]
                plan.append((target, data, item["columns"]))
            if not plan:
                # Every dataset was empty; nothing would be imported.
                return {
                    "success": False,
                    "message": "文件中没有找到有效数据"
                }

            # Two datasets aimed at one table (e.g. an explicit table_name with
            # a multi-dataset file) would either fail mid-import or silently
            # overwrite each other, so reject that before writing anything.
            seen_targets = set()
            for target, _, _ in plan:
                if target in seen_targets:
                    return {
                        "success": False,
                        "message": f"多个数据集将导入到同一个表 {target},请分别导入或不要指定 table_name"
                    }
                seen_targets.add(target)

            existing_tables = set(inspect(engine).get_table_names())
            if not force_overwrite:
                for target, _, _ in plan:
                    if target in existing_tables:
                        return {
                            "success": False,
                            "message": f"{target} 已存在,请使用 force_overwrite=true 覆盖或选择其他表名"
                        }

            results = []
            for target, data, columns in plan:
                # Always quote (with escaping): names come from uploaded content.
                quoted_table = DatabaseService._quote_name_always(target)
                if target in existing_tables:
                    # Only reachable with force_overwrite=True (checked above).
                    db.execute(text(f"DROP TABLE IF EXISTS {quoted_table}"))
                    db.commit()

                column_types = FileImportUtils.prepare_table_columns(data, columns)
                column_definitions = [
                    f"{DatabaseService._quote_name_always(col)} {column_types.get(col, 'TEXT')}"
                    for col in columns
                ]
                # A file column named "id" would collide with the surrogate key
                # (MySQL column names are case-insensitive).
                if not any(str(col).lower() == "id" for col in columns):
                    column_definitions.insert(0, "id INT AUTO_INCREMENT PRIMARY KEY")
                create_table_sql = f"CREATE TABLE {quoted_table} ({', '.join(column_definitions)})"
                db.execute(text(create_table_sql))
                db.commit()

                # Bulk insert through the engine in multi-row chunks.
                pd.DataFrame(data).to_sql(target, engine, if_exists='append',
                                          index=False, method='multi', chunksize=1000)
                results.append({
                    "table_name": target,
                    "rows_imported": len(data),
                    "columns": columns
                })

            if len(results) == 1:
                result = results[0]
                return {
                    "success": True,
                    "message": f"成功导入 {result['rows_imported']} 行数据到表 {result['table_name']}",
                    "table_name": result["table_name"],
                    "rows_imported": result["rows_imported"],
                    "columns": result["columns"]
                }
            total_rows = sum(r["rows_imported"] for r in results)
            table_names = [r["table_name"] for r in results]
            return {
                "success": True,
                "message": f"成功导入 {total_rows} 行数据到 {len(results)} 个表",
                "tables": results,
                "total_rows": total_rows,
                "table_names": table_names
            }
        except ValueError as e:
            return {
                "success": False,
                "message": str(e)
            }
        except SQLAlchemyError as e:
            db.rollback()
            return {
                "success": False,
                "message": f"数据库操作失败: {str(e)}"
            }
        except Exception as e:
            db.rollback()
            return {
                "success": False,
                "message": f"导入失败: {str(e)}"
            }