1310 lines
56 KiB
Python
1310 lines
56 KiB
Python
import os
|
||
import logging
|
||
import re
|
||
import asyncio
|
||
import functools
|
||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
|
||
from sqlalchemy.orm import sessionmaker
|
||
from sqlalchemy.sql import text
|
||
|
||
import Config
|
||
from DbKit.SqlTemplateLoader import SqlTemplateLoader
|
||
from DbKit.TransactionContext import TransactionContext
|
||
|
||
# 配置日志
|
||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
def db_retry(max_retries=3, delay=1):
|
||
"""
|
||
数据库操作重试装饰器
|
||
|
||
Args:
|
||
max_retries: 最大重试次数,默认为3次
|
||
delay: 重试间隔时间(秒),默认为1秒
|
||
"""
|
||
def decorator(func):
|
||
@functools.wraps(func)
|
||
async def wrapper(*args, **kwargs):
|
||
# 检查是否提供了 session,如果提供了 session,通常意味着是在一个更大的事务中,不建议在这里重试
|
||
session = kwargs.get('session')
|
||
if session is not None:
|
||
return await func(*args, **kwargs)
|
||
|
||
last_exception = None
|
||
for attempt in range(max_retries):
|
||
try:
|
||
return await func(*args, **kwargs)
|
||
except Exception as e:
|
||
last_exception = e
|
||
# 只有在还有重试机会时才打印警告
|
||
if attempt < max_retries - 1:
|
||
logger.warning(f"数据库操作失败 (尝试 {attempt + 1}/{max_retries}),正在进行重试: {str(e)}")
|
||
await asyncio.sleep(delay)
|
||
else:
|
||
logger.error(f"数据库操作在 {max_retries} 次尝试后仍然失败: {str(e)}")
|
||
|
||
# 如果循环结束仍未返回,说明最后一次尝试也失败了,抛出异常
|
||
if last_exception:
|
||
raise last_exception
|
||
return wrapper
|
||
return decorator
|
||
|
||
|
||
class Db:
|
||
"""通用数据库操作封装类,提供数据库连接和操作功能"""
|
||
# 单例实例
|
||
_instance = None
|
||
|
||
def __new__(cls, db_url=None, sql_dir=None):
|
||
"""确保Db类只有一个实例(异步版本)"""
|
||
if cls._instance is None:
|
||
cls._instance = super(Db, cls).__new__(cls)
|
||
# 初始化数据库连接
|
||
cls._instance.engine = None
|
||
cls._instance.AsyncSessionLocal = None
|
||
cls._instance.db_url = db_url
|
||
cls._instance.sql_dir = sql_dir or os.path.join(os.path.dirname(os.path.abspath(__file__)), 'Sql')
|
||
|
||
# 初始化SQL模板加载器
|
||
cls._instance.sql_loader = SqlTemplateLoader(cls._instance.sql_dir, auto_load=False)
|
||
# 添加标志表示SQL模板尚未加载
|
||
cls._instance._templates_loaded = False
|
||
return cls._instance
|
||
|
||
def __init__(self, db_url=None, sql_dir=None):
|
||
"""初始化数据库操作类
|
||
|
||
Args:
|
||
db_url: 数据库连接URL,不提供时将使用默认配置
|
||
sql_dir: SQL模板文件目录,不提供时将使用默认目录
|
||
"""
|
||
# 单例模式下,初始化只会在第一次创建实例时执行
|
||
# 这里可以为空,因为实际初始化在__new__中完成
|
||
self.sql_dir = sql_dir
|
||
self.db_url = db_url
|
||
|
||
async def init_db(self, db_url=None):
|
||
"""
|
||
初始化数据库连接(异步版本)
|
||
|
||
Args:
|
||
db_url: 数据库连接URL,优先级高于初始化时提供的URL
|
||
"""
|
||
if self.engine:
|
||
return
|
||
|
||
# 确定使用的数据库连接URL
|
||
final_url = db_url or self.db_url
|
||
if not final_url:
|
||
# 如果没有提供URL,使用配置文件中的参数构建连接URL
|
||
try:
|
||
from Config.Config import DB_URL
|
||
final_url = DB_URL
|
||
except ImportError:
|
||
# 兼容旧配置或Postgres配置
|
||
from Config.Config import POSTGRES_HOST, POSTGRES_PORT, POSTGRES_USER, POSTGRES_PASSWORD, POSTGRES_DATABASE
|
||
# 转换为异步连接URL
|
||
final_url = f"postgresql+asyncpg://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{POSTGRES_HOST}:{POSTGRES_PORT}/{POSTGRES_DATABASE}"
|
||
else:
|
||
# 如果是同步URL,转换为异步URL
|
||
if final_url.startswith("postgresql://"):
|
||
final_url = final_url.replace("postgresql://", "postgresql+asyncpg://")
|
||
|
||
# 创建异步数据库引擎
|
||
self.engine = create_async_engine(final_url)
|
||
|
||
# 创建异步会话工厂
|
||
self.AsyncSessionLocal = sessionmaker(
|
||
class_=AsyncSession,
|
||
autocommit=False,
|
||
autoflush=False,
|
||
bind=self.engine
|
||
)
|
||
|
||
# 检查Redis中是否已经加载了SQL模板
|
||
templates_loaded_in_redis = await self.sql_loader._check_templates_loaded_from_redis()
|
||
|
||
# 如果Redis中有加载的模板,直接从Redis加载
|
||
if templates_loaded_in_redis:
|
||
try:
|
||
await self.sql_loader._load_from_redis()
|
||
self._templates_loaded = True
|
||
logger.info(f"从Redis加载SQL模板完成,共{len(self.sql_loader._loaded_files)}个文件,{len(self.sql_loader._template_map)}个模板")
|
||
except Exception as e:
|
||
logger.error(f"从Redis加载SQL模板失败: {str(e)}")
|
||
# 不抛出异常,允许继续执行
|
||
pass
|
||
# 如果Redis中没有,且内存中也没有加载过,执行正常加载流程
|
||
elif not getattr(self, '_templates_loaded', False):
|
||
try:
|
||
await self.load_all_templates()
|
||
self._templates_loaded = True
|
||
|
||
# 获取并记录SQL模板统计信息
|
||
loaded_files = len(self.sql_loader._loaded_files)
|
||
template_count = 0
|
||
for namespace, templates in self.sql_loader.templates.items():
|
||
template_count += len(templates)
|
||
logger.info(f"SQL模板加载完成,共加载{loaded_files}个文件,{template_count}个模板")
|
||
except Exception as e:
|
||
logger.error(f"加载SQL模板失败: {str(e)}")
|
||
# 不抛出异常,允许数据库操作继续进行
|
||
pass
|
||
|
||
async def get_session(self):
|
||
"""获取数据库会话(异步版本)
|
||
|
||
Returns:
|
||
AsyncSession: 异步数据库会话对象
|
||
"""
|
||
if self.AsyncSessionLocal is None:
|
||
await self.init_db()
|
||
return self.AsyncSessionLocal()
|
||
|
||
async def close(self):
|
||
if hasattr(self, "engine") and self.engine is not None:
|
||
try:
|
||
await self.engine.dispose()
|
||
finally:
|
||
self.engine = None
|
||
self.AsyncSessionLocal = None
|
||
|
||
async def load_all_templates(self):
|
||
"""异步加载所有SQL模板
|
||
|
||
Returns:
|
||
self: 支持链式调用
|
||
"""
|
||
await self.sql_loader.load_all_templates()
|
||
return self
|
||
|
||
def _is_sql_template(self, sql_input):
|
||
"""
|
||
判断输入的带点字符串是SQL模板还是SQL语句
|
||
|
||
Args:
|
||
sql_input: 输入的字符串,可能是SQL模板名称或SQL语句
|
||
|
||
Returns:
|
||
bool: 如果是SQL模板返回True,如果是SQL语句返回False
|
||
"""
|
||
# 如果输入不是字符串或者不包含点号,直接判定为SQL语句
|
||
if not isinstance(sql_input, str) or '.' not in sql_input:
|
||
return False
|
||
|
||
# 检查sql_loader是否存在且已初始化
|
||
if not hasattr(self, 'sql_loader') or not self.sql_loader:
|
||
return False
|
||
|
||
# 检查模板映射是否存在
|
||
if not hasattr(self.sql_loader, '_template_map') or not self.sql_loader._template_map:
|
||
return False
|
||
|
||
# 在模板映射中查找,如果找到就是SQL模板,否则是SQL语句
|
||
is_template = sql_input in self.sql_loader._template_map
|
||
logger.debug(f"输入 '{sql_input}' 在模板映射中查找结果: {is_template}")
|
||
|
||
return is_template
|
||
|
||
async def get_sql(self, sql_identifier, params=None):
|
||
"""获取指定名称的SQL模板,并处理参数替换
|
||
|
||
Args:
|
||
sql_identifier: 格式为'namespace.template_name'的SQL标识符
|
||
params: 可选的参数字典,用于替换模板中的参数
|
||
|
||
Returns:
|
||
str: 处理后的SQL语句
|
||
"""
|
||
# logger.info(f"尝试获取SQL模板: {sql_identifier}")
|
||
sql_template = await self.sql_loader.get_sql(sql_identifier)
|
||
|
||
# 处理模板,即使没有提供参数也进行处理(清理条件块等)
|
||
processed_sql = self._process_sql_template(sql_template, params)
|
||
logger.debug(f"处理后的SQL语句: {processed_sql}")
|
||
return processed_sql
|
||
|
||
def _process_sql_template(self, sql_template, params=None):
|
||
"""
|
||
处理SQL模板中的条件判断和参数替换
|
||
|
||
Args:
|
||
sql_template: SQL模板字符串
|
||
params: 参数字典
|
||
|
||
Returns:
|
||
str: 处理后的SQL语句
|
||
"""
|
||
if params is None:
|
||
params = {}
|
||
|
||
# 检查是否包含#if条件块
|
||
if_matches = re.findall(r'#if\([^)]+\)', sql_template)
|
||
logger.debug(f"发现 {len(if_matches)} 个#if条件块")
|
||
for i, match in enumerate(if_matches, 1):
|
||
logger.debug(f" {i}. {match}")
|
||
|
||
# 检查是否包含ORDER BY子句
|
||
if 'ORDER BY' in sql_template:
|
||
logger.debug("SQL模板包含ORDER BY子句")
|
||
|
||
processed_sql = sql_template
|
||
|
||
# 处理#if条件判断
|
||
# 修复正则表达式,确保正确匹配条件块,支持 ! 取反
|
||
if_pattern = r'#if\((!?\w+)\)([\s\S]*?)#end'
|
||
|
||
# 使用finditer来获取匹配的位置信息
|
||
matches = list(re.finditer(if_pattern, processed_sql))
|
||
|
||
# 从后向前替换,避免位置偏移问题
|
||
for match in reversed(matches):
|
||
raw_param_name = match.group(1)
|
||
sql_block = match.group(2)
|
||
|
||
# 处理取反逻辑
|
||
is_negation = raw_param_name.startswith('!')
|
||
param_name = raw_param_name[1:] if is_negation else raw_param_name
|
||
|
||
# 检查参数是否存在且不为None (且不为空,如果需要更严格的检查)
|
||
# 这里保持原逻辑基础上增加对空值的判定,因为空列表通常也意味着False
|
||
has_value = param_name in params and params[param_name] is not None and params[param_name]
|
||
|
||
# 决定是否保留代码块
|
||
keep_block = (not has_value) if is_negation else has_value
|
||
|
||
if keep_block:
|
||
# 如果条件满足,保留SQL块
|
||
# 移除多余的空白行和缩进
|
||
lines = sql_block.strip().split('\n')
|
||
cleaned_lines = [line.strip() for line in lines if line.strip()]
|
||
cleaned_block = '\n'.join(cleaned_lines)
|
||
|
||
# 替换原始匹配的内容
|
||
processed_sql = processed_sql[:match.start()] + cleaned_block + processed_sql[match.end():]
|
||
|
||
logger.debug(f"保留条件块 #if({raw_param_name}),内容: {cleaned_block}")
|
||
else:
|
||
# 如果条件不满足,移除整个条件块
|
||
processed_sql = processed_sql[:match.start()] + "" + processed_sql[match.end():]
|
||
logger.debug(f"移除条件块 #if({raw_param_name})")
|
||
|
||
# 创建参数字典的副本,用于存储处理后的参数
|
||
if params is not None:
|
||
processed_params = params.copy()
|
||
else:
|
||
processed_params = {}
|
||
|
||
# 处理参数替换 - 支持直接的变量名格式
|
||
# 首先找出所有可能的参数占位符(#para格式和直接变量名格式)
|
||
# 处理#para格式的参数
|
||
para_pattern = r'#para\((\w+)\)'
|
||
all_para_placeholders = re.findall(para_pattern, processed_sql)
|
||
|
||
# 替换#para格式的参数
|
||
for param_name in all_para_placeholders:
|
||
placeholder = f"#para({param_name})"
|
||
if placeholder in processed_sql and param_name in params and params[param_name] is not None:
|
||
# 检查参数是否为列表类型且用于IN操作
|
||
if isinstance(params[param_name], (list, tuple)) and "in" in processed_sql.lower():
|
||
logger.debug(f"检测到列表参数 {param_name} 用于IN操作,进行特殊处理")
|
||
# 为列表中的每个元素创建单独的参数
|
||
list_values = params[param_name]
|
||
if list_values: # 确保列表不为空
|
||
# 生成多个占位符,如 :idList_1, :idList_2, :idList_3
|
||
placeholders = ', '.join([f":{param_name}_{i}" for i in range(len(list_values))])
|
||
processed_sql = processed_sql.replace(placeholder, placeholders)
|
||
# 更新参数字典,添加每个列表元素的参数
|
||
for i, value in enumerate(list_values):
|
||
processed_params[f"{param_name}_{i}"] = value
|
||
logger.debug(f"为列表参数 {param_name} 创建子参数 {param_name}_{i} = {value}")
|
||
else:
|
||
# 空列表情况,将整个条件移除(因为IN () 是无效的SQL)
|
||
logger.warning(f"列表参数 {param_name} 为空,将移除包含该参数的条件表达式")
|
||
# 查找包含该占位符的整个条件表达式行
|
||
lines = processed_sql.split('\n')
|
||
new_lines = []
|
||
for line in lines:
|
||
if placeholder in line:
|
||
line_stripped = line.strip().upper()
|
||
if line_stripped.startswith('AND ') or line_stripped.startswith('OR '):
|
||
continue
|
||
new_lines.append(line)
|
||
processed_sql = '\n'.join(new_lines)
|
||
else:
|
||
# 普通参数处理
|
||
processed_sql = processed_sql.replace(placeholder, f":{param_name}")
|
||
logger.debug(f"已将参数 {placeholder} 转换为 :{param_name}")
|
||
elif placeholder in processed_sql:
|
||
# 参数缺失或为None,移除包含该参数的条件表达式
|
||
logger.warning(f"参数 {param_name} 缺失或为None,将移除包含该参数的条件表达式")
|
||
lines = processed_sql.split('\n')
|
||
new_lines = []
|
||
for line in lines:
|
||
if placeholder in line:
|
||
line_stripped = line.strip().upper()
|
||
if line_stripped.startswith('AND ') or line_stripped.startswith('OR '):
|
||
continue
|
||
new_lines.append(line)
|
||
processed_sql = '\n'.join(new_lines)
|
||
|
||
# 处理直接的变量名格式(如 $person_id 或 #{person_id})
|
||
# 尝试匹配各种常见的变量名格式
|
||
variable_patterns = [
|
||
r'\$([a-zA-Z_][a-zA-Z0-9_]*)', # $variable 格式
|
||
r'#\{([a-zA-Z_][a-zA-Z0-9_]*)\}', # #{variable} 格式
|
||
r'\{([a-zA-Z_][a-zA-Z0-9_]*)\}' # {variable} 格式
|
||
]
|
||
|
||
for pattern in variable_patterns:
|
||
matches = re.finditer(pattern, processed_sql)
|
||
# 收集所有匹配并从后向前替换,避免位置偏移
|
||
variable_matches = list(matches)
|
||
for match in reversed(variable_matches):
|
||
placeholder = match.group(0)
|
||
param_name = match.group(1)
|
||
|
||
# 检查参数是否存在且不为None
|
||
if param_name in params and params[param_name] is not None:
|
||
# 替换为SQLAlchemy参数格式
|
||
replacement = f":{param_name}"
|
||
processed_sql = processed_sql[:match.start()] + replacement + processed_sql[match.end():]
|
||
logger.debug(f"已将变量 {placeholder} 转换为 {replacement}")
|
||
else:
|
||
# 变量不存在或为None,记录警告
|
||
logger.warning(f"变量 {param_name} 缺失或为None,保留原样")
|
||
|
||
# 更新params,使其包含处理后的参数(包括扩展的列表参数)
|
||
if params is not None:
|
||
for key, value in processed_params.items():
|
||
params[key] = value
|
||
|
||
# 清理多余的空白行和空格
|
||
lines = processed_sql.split('\n')
|
||
processed_lines = [line.strip() for line in lines if line.strip()]
|
||
processed_sql = '\n'.join(processed_lines)
|
||
|
||
# 额外的清理:处理可能产生的孤立逻辑运算符
|
||
# 例如,如果我们移除了一行,可能会留下单独的AND/OR
|
||
processed_sql = re.sub(r'(^|\n)\s*(AND|OR)\s*($|\n)', '\1\3', processed_sql)
|
||
|
||
# 最终清理
|
||
lines = processed_sql.split('\n')
|
||
processed_lines = [line.strip() for line in lines if line.strip()]
|
||
processed_sql = '\n'.join(processed_lines)
|
||
|
||
return processed_sql
|
||
|
||
@db_retry()
|
||
async def find(self, sql, params=None, session=None):
|
||
"""
|
||
执行SQL查询并返回结果(异步版本)
|
||
|
||
Args:
|
||
sql: SQL查询语句或SQL模板名称(支持namespace.sql_name格式)
|
||
params: 查询参数,字典形式
|
||
session: 可选的数据库会话对象,不提供时将自动创建
|
||
|
||
Returns:
|
||
list: 查询结果的字典列表
|
||
"""
|
||
# 使用当前实例
|
||
instance = self
|
||
|
||
# 如果没有提供params,创建一个空字典
|
||
if params is None:
|
||
params = {}
|
||
|
||
# 确保数据库连接已初始化
|
||
if instance.AsyncSessionLocal is None:
|
||
await instance.init_db()
|
||
|
||
# 检查是否需要加载SQL模板
|
||
# 首先检查Redis中是否有加载的模板
|
||
if not getattr(instance, '_templates_loaded', False):
|
||
templates_loaded_in_redis = await instance.sql_loader._check_templates_loaded_from_redis()
|
||
if templates_loaded_in_redis:
|
||
try:
|
||
await instance.sql_loader._load_from_redis()
|
||
instance._templates_loaded = True
|
||
logger.info("从Redis自动加载SQL模板完成")
|
||
except Exception as e:
|
||
logger.error(f"从Redis自动加载SQL模板失败: {str(e)}")
|
||
else:
|
||
# Redis中没有,尝试正常加载
|
||
try:
|
||
await instance.load_all_templates()
|
||
instance._templates_loaded = True
|
||
logger.info("SQL模板自动加载完成")
|
||
except Exception as e:
|
||
logger.error(f"自动加载SQL模板失败: {str(e)}")
|
||
# 不抛出异常,允许数据库操作继续进行
|
||
pass
|
||
|
||
# 检查是否为SQL模板名称并处理
|
||
# 即使不是模板名称,如果包含点号,我们也尝试在加载后检查它是否真的是模板
|
||
is_template = self._is_sql_template(sql)
|
||
if not is_template and '.' in sql:
|
||
# 如果包含点号但没在映射中找到,可能是因为缓存加载问题,尝试重新加载映射
|
||
if hasattr(self.sql_loader, '_build_template_map'):
|
||
self.sql_loader._build_template_map()
|
||
is_template = self._is_sql_template(sql)
|
||
|
||
if is_template:
|
||
try:
|
||
logger.debug(f"确认为SQL模板: {sql}")
|
||
|
||
# 尝试获取SQL模板
|
||
sql_content = await instance.get_sql(sql, params)
|
||
# 如果成功获取到内容且不等于原始名称,说明是模板
|
||
if sql_content and sql_content != sql:
|
||
sql = sql_content
|
||
logger.debug(f"成功获取SQL模板,处理后的SQL: {sql}")
|
||
# 对于模板SQL,参数已经被get_sql处理过了,不需要再处理
|
||
# params = None
|
||
else:
|
||
logger.debug(f"获取到的内容与输入相同,可能不是模板: {sql_content}")
|
||
except Exception as e:
|
||
logger.error(f"获取SQL模板失败: {str(e)},将作为普通SQL执行")
|
||
# 继续执行,将其作为普通SQL处理
|
||
|
||
# 只有当SQL不是从模板获取时,才处理普通SQL中的#para()格式参数
|
||
# 因为模板已经通过get_sql方法处理了参数替换
|
||
if params and isinstance(params, dict) and isinstance(sql, str):
|
||
# 检查SQL是否还包含#para()格式(如果是模板,应该已经被处理过了)
|
||
if '#para(' in sql:
|
||
logger.debug("检测到SQL中仍包含#para()格式,进行参数替换")
|
||
# 替换SQL中的#para(param_name)为:param_name
|
||
for param_name, param_value in list(params.items()): # 使用list()创建副本,避免在遍历时修改
|
||
para_pattern = f'#para\\({param_name}\\)' # 转义括号
|
||
if re.search(para_pattern, sql):
|
||
# 检查参数是否为列表类型且用于IN操作
|
||
if isinstance(param_value, (list, tuple)) and "in" in sql.lower():
|
||
logger.debug(f"检测到普通SQL中的列表参数 {param_name} 用于IN操作,进行特殊处理")
|
||
# 为列表中的每个元素创建单独的参数
|
||
list_values = param_value
|
||
if list_values: # 确保列表不为空
|
||
# 生成多个占位符,如 :idList_1, :idList_2, :idList_3
|
||
placeholders = ', '.join([f":{param_name}_{i}" for i in range(len(list_values))])
|
||
sql = re.sub(para_pattern, placeholders, sql)
|
||
# 更新参数字典,添加每个列表元素的参数
|
||
for i, value in enumerate(list_values):
|
||
params[f"{param_name}_{i}"] = value
|
||
logger.debug(f"为列表参数 {param_name} 创建子参数 {param_name}_{i} = {value}")
|
||
else:
|
||
# 空列表情况,将整个条件移除(因为IN () 是无效的SQL)
|
||
logger.warning(f"列表参数 {param_name} 为空,将移除包含该参数的条件表达式")
|
||
# 查找包含该占位符的整个条件表达式行
|
||
lines = sql.split('\n')
|
||
new_lines = []
|
||
for line in lines:
|
||
if f"#para({param_name})" in line:
|
||
line_stripped = line.strip().upper()
|
||
if line_stripped.startswith('AND ') or line_stripped.startswith('OR '):
|
||
continue
|
||
new_lines.append(line)
|
||
sql = '\n'.join(new_lines)
|
||
else:
|
||
# 普通参数处理
|
||
sql = re.sub(para_pattern, f':{param_name}', sql)
|
||
logger.debug(f"已替换参数占位符: #para({param_name}) -> :{param_name}")
|
||
else:
|
||
logger.debug("SQL模板已处理过参数,跳过重复处理")
|
||
|
||
is_own_session = False
|
||
try:
|
||
# 如果没有提供会话,创建新会话
|
||
if session is None:
|
||
session = await instance.get_session()
|
||
is_own_session = True
|
||
|
||
# 执行查询
|
||
# logger.info(f"执行查询: {sql},参数: {params}")
|
||
result = await session.execute(text(sql), params or {})
|
||
|
||
# 将结果转换为字典列表
|
||
return [dict(row._mapping) for row in result]
|
||
except Exception as e:
|
||
logger.error(f"查询执行失败: {str(e)}")
|
||
raise
|
||
finally:
|
||
# 如果是自己创建的会话,负责关闭
|
||
if is_own_session and session:
|
||
await session.close()
|
||
|
||
async def findFirst(self, sql, params=None, session=None):
|
||
"""
|
||
执行查询并返回结果的第一条记录(异步版本)
|
||
|
||
Args:
|
||
sql: SQL查询语句或SQL模板名称(支持namespace.sql_name格式)
|
||
params: 查询参数,字典形式
|
||
session: 可选的数据库会话对象,不提供时将自动创建
|
||
|
||
Returns:
|
||
dict: 第一条记录的字典,如果没有结果则返回None
|
||
"""
|
||
# 直接调用异步的find方法,它已经使用_is_sql_template进行了模板检查
|
||
rows = await self.find(sql, params, session)
|
||
return rows[0] if rows else None
|
||
|
||
async def findById(self, table_name, primary_key_value, primary_key='id', session=None):
|
||
"""
|
||
根据主键查询指定表中的记录(异步版本)
|
||
|
||
Args:
|
||
table_name: 表名
|
||
primary_key_value: 主键值
|
||
primary_key: 主键字段名,默认为'id'
|
||
session: 可选的数据库会话对象,不提供时将自动创建
|
||
|
||
Returns:
|
||
dict: 查询到的记录字典,如果没有找到则返回None
|
||
"""
|
||
if not table_name:
|
||
raise ValueError("表名不能为空")
|
||
|
||
# 构建查询SQL
|
||
sql = f"SELECT * FROM {table_name} WHERE {primary_key} = :primary_key_value"
|
||
|
||
# 准备参数
|
||
params = {"primary_key_value": primary_key_value}
|
||
|
||
logger.debug(f"生成的查询SQL: {sql}")
|
||
logger.debug(f"查询参数: {params}")
|
||
|
||
# 异步使用findFirst方法执行查询
|
||
return await self.findFirst(sql, params, session)
|
||
|
||
async def paginate(self, sql, page_number=1, page_size=10, params=None, session=None):
|
||
"""
|
||
执行SQL查询并返回分页结果(异步版本)
|
||
|
||
Args:
|
||
sql: SQL查询语句或SQL模板名称(支持namespace.sql_name格式)
|
||
page_number: 页码,从1开始
|
||
page_size: 每页记录数
|
||
params: 查询参数,字典形式
|
||
session: 可选的数据库会话对象,不提供时将自动创建
|
||
|
||
Returns:
|
||
dict: 包含分页信息的字典,格式如下:
|
||
{
|
||
'list': 当前页的数据列表,
|
||
'pageNumber': 当前页码,
|
||
'pageSize': 每页记录数,
|
||
'totalRow': 总记录数,
|
||
'totalPage': 总页数
|
||
}
|
||
"""
|
||
# 确保数据库已初始化
|
||
if self.AsyncSessionLocal is None:
|
||
await self.init_db()
|
||
|
||
# 初始化诊断日志
|
||
logger.debug("======= 开始分页查询 =======")
|
||
|
||
# 如果没有提供params,创建一个空字典
|
||
if params is None:
|
||
params = {}
|
||
logger.debug(f"传入的参数: {params}")
|
||
|
||
# 确保页码和每页大小为正整数
|
||
try:
|
||
page_number = int(page_number)
|
||
page_size = int(page_size)
|
||
except (ValueError, TypeError):
|
||
logger.error(f"页码或每页大小类型错误,重置为默认值")
|
||
page_number = 1
|
||
page_size = 10
|
||
|
||
if page_number < 1:
|
||
page_number = 1
|
||
if page_size < 1:
|
||
page_size = 10
|
||
|
||
logger.debug(f"处理后的页码: {page_number}, 每页大小: {page_size}")
|
||
|
||
# 使用_is_sql_template方法检查并获取SQL模板
|
||
is_template = self._is_sql_template(sql)
|
||
if not is_template and '.' in sql:
|
||
# 如果包含点号但没在映射中找到,可能是因为缓存加载问题,尝试重新加载映射
|
||
if hasattr(self.sql_loader, '_build_template_map'):
|
||
self.sql_loader._build_template_map()
|
||
is_template = self._is_sql_template(sql)
|
||
|
||
if is_template:
|
||
try:
|
||
logger.debug(f"确认为SQL模板: {sql}")
|
||
|
||
# 尝试获取SQL模板
|
||
sql_content = await self.get_sql(sql, params)
|
||
# 如果成功获取到内容且不等于原始名称,说明是模板
|
||
if sql_content and sql_content != sql:
|
||
sql = sql_content
|
||
logger.debug(f"成功获取SQL模板,处理后的SQL: {sql}")
|
||
else:
|
||
logger.debug(f"获取到的内容与输入相同,可能不是模板: {sql_content}")
|
||
except Exception as e:
|
||
logger.error(f"获取SQL模板失败: {str(e)},将作为普通SQL执行")
|
||
# 继续执行,将其作为普通SQL处理
|
||
|
||
# 只有当SQL不是从模板获取时,才处理普通SQL中的#para()格式参数
|
||
# 因为模板已经通过get_sql方法处理了参数替换
|
||
if params and isinstance(params, dict) and isinstance(sql, str):
|
||
# 检查SQL是否还包含#para()格式(如果是模板,应该已经被处理过了)
|
||
if '#para(' in sql:
|
||
logger.debug("检测到SQL中仍包含#para()格式,进行参数替换")
|
||
# 替换SQL中的#para(param_name)为:param_name
|
||
for param_name, param_value in list(params.items()): # 使用list()创建副本,避免在遍历时修改
|
||
para_pattern = f'#para\\({param_name}\\)' # 转义括号
|
||
if re.search(para_pattern, sql):
|
||
# 检查参数是否为列表类型且用于IN操作
|
||
if isinstance(param_value, (list, tuple)) and "in" in sql.lower():
|
||
logger.debug(f"检测到分页查询中的列表参数 {param_name} 用于IN操作,进行特殊处理")
|
||
# 为列表中的每个元素创建单独的参数
|
||
list_values = param_value
|
||
if list_values: # 确保列表不为空
|
||
# 生成多个占位符,如 :idList_1, :idList_2, :idList_3
|
||
placeholders = ', '.join([f":{param_name}_{i}" for i in range(len(list_values))])
|
||
sql = re.sub(para_pattern, placeholders, sql)
|
||
# 更新参数字典,添加每个列表元素的参数
|
||
for i, value in enumerate(list_values):
|
||
params[f"{param_name}_{i}"] = value
|
||
logger.debug(f"为列表参数 {param_name} 创建子参数 {param_name}_{i} = {value}")
|
||
else:
|
||
# 空列表情况,将整个条件移除(因为IN () 是无效的SQL)
|
||
logger.warning(f"列表参数 {param_name} 为空,将移除包含该参数的条件表达式")
|
||
# 查找包含该占位符的整个条件表达式行
|
||
lines = sql.split('\n')
|
||
new_lines = []
|
||
for line in lines:
|
||
if f"#para({param_name})" in line:
|
||
line_stripped = line.strip().upper()
|
||
if line_stripped.startswith('AND ') or line_stripped.startswith('OR '):
|
||
continue
|
||
new_lines.append(line)
|
||
sql = '\n'.join(new_lines)
|
||
else:
|
||
# 普通参数处理
|
||
sql = re.sub(para_pattern, f':{param_name}', sql)
|
||
logger.debug(f"已替换参数占位符: #para({param_name}) -> :{param_name}")
|
||
else:
|
||
logger.debug("SQL模板已处理过参数,跳过重复处理")
|
||
|
||
is_own_session = False
|
||
try:
|
||
# 如果没有提供会话,创建新会话
|
||
if session is None:
|
||
session = await self.get_session()
|
||
is_own_session = True
|
||
|
||
# 计算总记录数
|
||
count_sql = self._build_count_sql(sql)
|
||
logger.debug(f"构建的COUNT SQL: {count_sql}")
|
||
|
||
try:
|
||
count_result = await session.execute(text(count_sql), params or {})
|
||
count_value = count_result.scalar()
|
||
|
||
# 确保total_row是正确的整数值
|
||
try:
|
||
total_row = int(count_value) if count_value is not None else 0
|
||
except (ValueError, TypeError):
|
||
total_row = 0
|
||
except Exception as e:
|
||
logger.error(f"COUNT查询失败: {str(e)}")
|
||
total_row = 0
|
||
# 如果是自己创建的会话,当COUNT查询失败时,回滚事务并重新创建会话
|
||
if is_own_session:
|
||
logger.debug("COUNT查询失败,回滚并重新创建会话")
|
||
await session.rollback()
|
||
await session.close()
|
||
session = await self.get_session()
|
||
|
||
# 添加详细日志记录总记录数和每页大小
|
||
logger.debug(f"总记录数: {total_row}, 每页大小: {page_size}")
|
||
|
||
# 确保total_row是整数
|
||
total_row = int(total_row) if total_row is not None else 0
|
||
|
||
# 计算总页数(使用更清晰的计算方式)
|
||
if page_size > 0:
|
||
total_page = (total_row + page_size - 1) // page_size
|
||
else:
|
||
total_page = 0
|
||
|
||
logger.debug(f"计算的总页数: {total_page}")
|
||
|
||
# 计算偏移量
|
||
offset = (page_number - 1) * page_size
|
||
|
||
# 构建分页SQL(PostgreSQL语法)
|
||
# 首先检查SQL是否已经包含LIMIT或OFFSET子句
|
||
has_limit = re.search(r'(?i)\bLIMIT\b', sql)
|
||
has_offset = re.search(r'(?i)\bOFFSET\b', sql)
|
||
|
||
# 如果已经包含LIMIT或OFFSET,直接使用原始SQL
|
||
if has_limit or has_offset:
|
||
paginated_sql = sql
|
||
else:
|
||
# 否则,在SQL末尾添加LIMIT和OFFSET
|
||
# 不需要复杂的ORDER BY处理,PostgreSQL会自动处理顺序
|
||
paginated_sql = f"{sql} LIMIT :limit OFFSET :offset"
|
||
|
||
logger.debug(f"构建的分页SQL: {paginated_sql}")
|
||
|
||
# 添加分页参数
|
||
paginated_params = params.copy() if params else {}
|
||
paginated_params['limit'] = page_size
|
||
paginated_params['offset'] = offset
|
||
logger.debug(f"分页查询参数: {paginated_params}")
|
||
|
||
# 执行分页查询
|
||
result = await session.execute(text(paginated_sql), paginated_params)
|
||
|
||
# 将结果转换为字典列表
|
||
list_data = [dict(row._mapping) for row in result]
|
||
|
||
# 确保total_row是整数
|
||
try:
|
||
total_row_int = int(total_row)
|
||
except Exception as e:
|
||
total_row_int = 0
|
||
|
||
# 计算总页数
|
||
if page_size > 0:
|
||
remainder = total_row_int % page_size
|
||
quotient = total_row_int // page_size
|
||
calculated_pages = quotient + 1 if remainder > 0 else quotient
|
||
else:
|
||
calculated_pages = 0
|
||
|
||
# 构建返回结果,使用正确的页数
|
||
pagination_result = {
|
||
'list': list_data if list_data is not None else [],
|
||
'page': page_number,
|
||
'page_size': page_size,
|
||
'total': total_row,
|
||
'total_pages': calculated_pages # 使用正确的页数
|
||
}
|
||
|
||
return pagination_result
|
||
except Exception as e:
|
||
logger.error(f"分页查询失败: {str(e)}")
|
||
raise
|
||
finally:
|
||
# 如果是自己创建的会话,负责关闭
|
||
if is_own_session and session:
|
||
await session.close()
|
||
|
||
def _build_count_sql(self, sql):
|
||
"""
|
||
根据原始SQL构建用于计算总记录数的SQL
|
||
|
||
Args:
|
||
sql: 原始SQL查询语句
|
||
|
||
Returns:
|
||
str: 计算总记录数的SQL
|
||
"""
|
||
import re
|
||
|
||
# 清理SQL,移除前后空白
|
||
sql = sql.strip()
|
||
|
||
# 策略1: 使用正则表达式查找FROM关键字(更灵活)
|
||
match = re.search(r'(?i)\bFROM\b', sql)
|
||
if match:
|
||
from_pos = match.start()
|
||
# 提取FROM后面的内容(包括FROM关键字)
|
||
from_part = sql[from_pos:]
|
||
|
||
# 检查是否包含ORDER BY子句,如果有,移除
|
||
order_by_match = re.search(r'(?i)\bORDER\s+BY\b', from_part)
|
||
if order_by_match:
|
||
order_by_pos = order_by_match.start()
|
||
from_part = from_part[:order_by_pos]
|
||
|
||
# 检查是否包含GROUP BY子句
|
||
if re.search(r'(?i)\bGROUP\s+BY\b', from_part):
|
||
# 对于复杂查询,使用子查询来计算总数
|
||
subquery_sql = sql[:order_by_pos] if order_by_match else sql
|
||
return f"SELECT COUNT(*) FROM ({subquery_sql}) AS subquery"
|
||
else:
|
||
# 简单查询,直接替换SELECT部分
|
||
return f"SELECT COUNT(*) {from_part}"
|
||
|
||
# 策略2: 使用多种空格格式查找FROM关键字
|
||
from_variants = [' FROM ', ' FROM', 'FROM ', 'FROM']
|
||
sql_upper = sql.upper()
|
||
for variant in from_variants:
|
||
if variant.upper() in sql_upper:
|
||
from_pos = sql_upper.find(variant.upper())
|
||
from_part = sql[from_pos:]
|
||
|
||
# 移除ORDER BY
|
||
if ' ORDER BY ' in from_part.upper():
|
||
order_by_pos = from_part.upper().rfind(' ORDER BY ')
|
||
from_part = from_part[:order_by_pos]
|
||
|
||
return f"SELECT COUNT(*) {from_part}"
|
||
|
||
# 如果所有策略都失败,返回默认查询
|
||
return "SELECT COUNT(*)"
|
||
|
||
async def check_column_exists(self, table_name, column_name):
|
||
"""检查表中是否存在指定列
|
||
|
||
Args:
|
||
table_name: 表名
|
||
column_name: 列名
|
||
|
||
Returns:
|
||
bool: 如果列存在返回True,否则返回False
|
||
"""
|
||
try:
|
||
# 兼容 Doris 和 MySQL 的语法
|
||
sql = f"SHOW COLUMNS FROM {table_name} LIKE :column_name"
|
||
params = {"column_name": column_name}
|
||
# 使用 find 方法执行查询
|
||
result = await self.find(sql, params)
|
||
return len(result) > 0
|
||
except Exception as e:
|
||
logger.debug(f"检查列是否存在时出错 (可能表不存在): {str(e)}")
|
||
return False
|
||
|
||
@db_retry()
|
||
async def execute_update(self, sql, params=None, session=None):
|
||
"""执行SQL更新操作(插入、更新、删除)(异步版本)
|
||
|
||
Args:
|
||
sql: SQL更新语句或SQL模板名称(支持namespace.sql_name格式)
|
||
params: 更新参数,字典形式
|
||
session: 可选的数据库会话对象,不提供时将自动创建
|
||
|
||
Returns:
|
||
int: 受影响的行数
|
||
"""
|
||
# 确保数据库已初始化
|
||
if self.AsyncSessionLocal is None:
|
||
await self.init_db()
|
||
|
||
# 如果没有提供params,创建一个空字典
|
||
if params is None:
|
||
params = {}
|
||
|
||
# 检查是否为SQL模板名称并处理
|
||
if self._is_sql_template(sql):
|
||
try:
|
||
logger.debug(f"确认为SQL模板: {sql}")
|
||
|
||
# 尝试获取SQL模板
|
||
sql_content = await self.get_sql(sql, params)
|
||
# 如果成功获取到内容且不等于原始名称,说明是模板
|
||
if sql_content and sql_content != sql:
|
||
sql = sql_content
|
||
logger.debug(f"成功获取SQL模板,处理后的SQL: {sql}")
|
||
# 对于模板SQL,参数已经被get_sql处理过了,不需要再处理
|
||
params = None
|
||
else:
|
||
logger.debug(f"获取到的内容与输入相同,可能不是模板: {sql_content}")
|
||
except Exception as e:
|
||
logger.error(f"获取SQL模板失败: {str(e)},将作为普通SQL执行")
|
||
# 继续执行,将其作为普通SQL处理
|
||
|
||
# 只有当SQL不是从模板获取时,才处理普通SQL中的#para()格式参数
|
||
# 因为模板已经通过get_sql方法处理了参数替换
|
||
if params and isinstance(params, dict):
|
||
# 检查SQL是否还包含#para()格式(如果是模板,应该已经被处理过了)
|
||
if '#para(' in sql:
|
||
logger.debug("检测到SQL中仍包含#para()格式,进行参数替换")
|
||
# 替换SQL中的#para(param_name)为:param_name
|
||
for param_name, param_value in list(params.items()): # 使用list()创建副本,避免在遍历时修改
|
||
para_pattern = f'#para\\({param_name}\\)' # 转义括号
|
||
if re.search(para_pattern, sql):
|
||
# 检查参数是否为列表类型且用于IN操作
|
||
if isinstance(param_value, (list, tuple)) and "in" in sql.lower():
|
||
logger.debug(f"检测到更新SQL中的列表参数 {param_name} 用于IN操作,进行特殊处理")
|
||
# 为列表中的每个元素创建单独的参数
|
||
list_values = param_value
|
||
if list_values: # 确保列表不为空
|
||
# 生成多个占位符,如 :idList_1, :idList_2, :idList_3
|
||
placeholders = ', '.join([f":{param_name}_{i}" for i in range(len(list_values))])
|
||
sql = re.sub(para_pattern, placeholders, sql)
|
||
# 更新参数字典,添加每个列表元素的参数
|
||
for i, value in enumerate(list_values):
|
||
params[f"{param_name}_{i}"] = value
|
||
logger.debug(f"为列表参数 {param_name} 创建子参数 {param_name}_{i} = {value}")
|
||
else:
|
||
# 空列表情况,将整个条件移除(因为IN () 是无效的SQL)
|
||
logger.warning(f"列表参数 {param_name} 为空,将移除包含该参数的条件表达式")
|
||
# 查找包含该占位符的整个条件表达式行
|
||
lines = sql.split('\n')
|
||
new_lines = []
|
||
for line in lines:
|
||
if f"#para({param_name})" in line:
|
||
line_stripped = line.strip().upper()
|
||
if line_stripped.startswith('AND ') or line_stripped.startswith('OR '):
|
||
continue
|
||
new_lines.append(line)
|
||
sql = '\n'.join(new_lines)
|
||
else:
|
||
# 普通参数处理
|
||
sql = re.sub(para_pattern, f':{param_name}', sql)
|
||
logger.debug(f"已替换参数占位符: #para({param_name}) -> :{param_name}")
|
||
else:
|
||
logger.debug("SQL模板已处理过参数,跳过重复处理")
|
||
|
||
is_own_session = False
|
||
try:
|
||
# 如果没有提供会话,创建新会话
|
||
if session is None:
|
||
session = await self.get_session()
|
||
is_own_session = True
|
||
|
||
# 执行更新
|
||
# logger.info(f"执行更新: {sql},参数: {params}")
|
||
result = await session.execute(text(sql), params or {})
|
||
|
||
# 提交事务
|
||
await session.commit()
|
||
|
||
# 返回受影响的行数(如果有)
|
||
return result.rowcount if hasattr(result, 'rowcount') else 0
|
||
except Exception as e:
|
||
logger.error(f"更新操作失败: {str(e)}")
|
||
if session:
|
||
await session.rollback()
|
||
raise
|
||
finally:
|
||
# 如果是自己创建的会话,负责关闭
|
||
if is_own_session and session:
|
||
await session.close()
|
||
|
||
@db_retry()
|
||
async def save(self, table_name, data, primary_key, session=None):
|
||
"""插入数据到指定表,并返回插入后的主键值(异步版本)
|
||
|
||
Args:
|
||
table_name: 表名
|
||
data: 要插入的数据字典
|
||
primary_key: 主键字段名
|
||
session: 可选的数据库会话对象,不提供时将自动创建
|
||
|
||
Returns:
|
||
插入后的主键值
|
||
|
||
Raises:
|
||
ValueError: 当参数验证失败时
|
||
Exception: 当插入操作失败时
|
||
"""
|
||
if not table_name or not data:
|
||
raise ValueError("表名和数据不能为空")
|
||
|
||
# 主键可能由数据库自动生成,不传递也可以
|
||
# 无需额外检查,直接构建SQL语句
|
||
# 如果data中不包含primary_key,SQL会自动处理
|
||
|
||
# logger.info(f"准备插入数据到表 {table_name},主键字段: {primary_key}")
|
||
|
||
# 构建插入SQL (移除 RETURNING,Doris/MySQL 不支持)
|
||
columns = ', '.join(data.keys())
|
||
placeholders = ', '.join([f":{key}" for key in data.keys()])
|
||
sql = f"INSERT INTO {table_name} ({columns}) VALUES ({placeholders})"
|
||
|
||
logger.debug(f"生成的插入SQL: {sql}")
|
||
logger.debug(f"插入参数: {data}")
|
||
|
||
is_own_session = False
|
||
try:
|
||
# 如果没有提供会话,创建新会话
|
||
if session is None:
|
||
session = await self.get_session()
|
||
is_own_session = True
|
||
|
||
# 执行插入
|
||
result = await session.execute(text(sql), data)
|
||
await session.commit()
|
||
|
||
# 获取插入的主键值
|
||
# 如果 data 中包含主键,直接返回
|
||
if primary_key in data:
|
||
return data[primary_key]
|
||
|
||
# 否则尝试获取 lastrowid (某些 DB 支持)
|
||
try:
|
||
return result.lastrowid
|
||
except:
|
||
return None
|
||
except Exception as e:
|
||
logger.error(f"数据插入失败: {str(e)}")
|
||
if session:
|
||
await session.rollback()
|
||
raise
|
||
finally:
|
||
# 如果是自己创建的会话,负责关闭
|
||
if is_own_session and session:
|
||
await session.close()
|
||
|
||
@db_retry()
|
||
async def update(self, table_name, data, primary_key, session=None):
|
||
"""根据主键更新指定表中的数据(异步版本)
|
||
|
||
Args:
|
||
table_name: 表名
|
||
data: 要更新的数据字典,必须包含主键字段和值
|
||
primary_key: 主键字段名
|
||
session: 可选的数据库会话对象,不提供时将自动创建
|
||
|
||
Returns:
|
||
int: 受影响的行数
|
||
|
||
Raises:
|
||
ValueError: 当参数验证失败时
|
||
Exception: 当更新操作失败时
|
||
"""
|
||
if not table_name or not data:
|
||
raise ValueError("表名和数据不能为空")
|
||
|
||
if primary_key not in data:
|
||
raise ValueError(f"数据字典中必须包含主键字段 '{primary_key}'")
|
||
|
||
# 获取主键值
|
||
primary_key_value = data[primary_key]
|
||
|
||
# logger.info(f"准备更新表 {table_name} 中的数据,主键: {primary_key}={primary_key_value}")
|
||
|
||
# 构建更新字段部分,排除主键
|
||
update_fields = [f"{key} = :{key}" for key in data.keys() if key != primary_key]
|
||
if not update_fields:
|
||
logger.warning("没有需要更新的字段")
|
||
return 0
|
||
|
||
update_sql_part = ', '.join(update_fields)
|
||
|
||
# 构建完整的更新SQL
|
||
sql = f"UPDATE {table_name} SET {update_sql_part} WHERE {primary_key} = :{primary_key}"
|
||
|
||
logger.debug(f"生成的更新SQL: {sql}")
|
||
logger.debug(f"更新参数: {data}")
|
||
|
||
is_own_session = False
|
||
try:
|
||
# 如果没有提供会话,创建新会话
|
||
if session is None:
|
||
session = await self.get_session()
|
||
is_own_session = True
|
||
|
||
# 执行更新
|
||
result = await session.execute(text(sql), data)
|
||
await session.commit()
|
||
|
||
affected_rows = result.rowcount if hasattr(result, 'rowcount') else 0
|
||
# logger.info(f"数据更新成功,受影响行数: {affected_rows}")
|
||
return affected_rows
|
||
except Exception as e:
|
||
logger.error(f"数据更新失败: {str(e)}")
|
||
if session:
|
||
await session.rollback()
|
||
raise
|
||
finally:
|
||
# 如果是自己创建的会话,负责关闭
|
||
if is_own_session and session:
|
||
await session.close()
|
||
|
||
@db_retry()
|
||
async def batch_insert(self, table_name, data_list, primary_key=None, session=None):
|
||
"""
|
||
批量插入数据到指定表(异步版本)
|
||
|
||
Args:
|
||
table_name: 表名
|
||
data_list: 要插入的数据字典列表
|
||
primary_key: 主键字段名,如果提供,将返回插入后的主键值列表
|
||
session: 可选的数据库会话对象,不提供时将自动创建
|
||
|
||
Returns:
|
||
list: 插入后的主键值列表,如果不提供primary_key则返回None
|
||
|
||
Raises:
|
||
ValueError: 当参数验证失败时
|
||
Exception: 当插入操作失败时
|
||
"""
|
||
if not table_name or not data_list:
|
||
raise ValueError("表名和数据列表不能为空")
|
||
|
||
if not isinstance(data_list, list) or not data_list:
|
||
raise ValueError("数据列表必须是非空的列表")
|
||
|
||
# 获取第一个数据字典的键作为列名
|
||
columns = data_list[0].keys()
|
||
columns_str = ', '.join(columns)
|
||
|
||
# 构建占位符部分
|
||
placeholders_list = []
|
||
all_params = {}
|
||
|
||
# 为每一行数据生成占位符和参数
|
||
for i, data in enumerate(data_list):
|
||
# 确保每一行数据都有相同的键
|
||
if data.keys() != columns:
|
||
raise ValueError(f"第{i+1}行数据的字段不匹配")
|
||
|
||
# 生成该行的占位符,如 (:name_1, :age_1)
|
||
row_placeholders = [f":{key}_{i}" for key in columns]
|
||
placeholders_list.append(f"({', '.join(row_placeholders)})")
|
||
|
||
# 更新参数
|
||
for key, value in data.items():
|
||
all_params[f"{key}_{i}"] = value
|
||
|
||
# 构建完整的批量插入SQL
|
||
values_clause = ', '.join(placeholders_list)
|
||
if primary_key:
|
||
sql = f"INSERT INTO {table_name} ({columns_str}) VALUES {values_clause} RETURNING {primary_key}"
|
||
else:
|
||
sql = f"INSERT INTO {table_name} ({columns_str}) VALUES {values_clause}"
|
||
|
||
logger.debug(f"生成的批量插入SQL: {sql}")
|
||
|
||
is_own_session = False
|
||
try:
|
||
# 如果没有提供会话,创建新会话
|
||
if session is None:
|
||
session = await self.get_session()
|
||
is_own_session = True
|
||
|
||
# 执行批量插入
|
||
result = await session.execute(text(sql), all_params)
|
||
await session.commit()
|
||
|
||
# 如果指定了主键,返回插入的主键值列表
|
||
if primary_key:
|
||
return [row[0] for row in result]
|
||
return None
|
||
except Exception as e:
|
||
logger.error(f"批量插入失败: {str(e)}")
|
||
if session:
|
||
await session.rollback()
|
||
raise
|
||
finally:
|
||
# 如果是自己创建的会话,负责关闭
|
||
if is_own_session and session:
|
||
await session.close()
|
||
|
||
@db_retry()
|
||
async def batch_update(self, table_name, data_list, primary_key, session=None):
|
||
"""
|
||
批量更新指定表中的数据(异步版本)
|
||
|
||
Args:
|
||
table_name: 表名
|
||
data_list: 要更新的数据字典列表,每个字典必须包含主键字段和值
|
||
primary_key: 主键字段名
|
||
session: 可选的数据库会话对象,不提供时将自动创建
|
||
|
||
Returns:
|
||
int: 受影响的总行数
|
||
|
||
Raises:
|
||
ValueError: 当参数验证失败时
|
||
Exception: 当更新操作失败时
|
||
"""
|
||
if not table_name or not data_list or not primary_key:
|
||
raise ValueError("表名、数据列表和主键不能为空")
|
||
|
||
if not isinstance(data_list, list) or not data_list:
|
||
raise ValueError("数据列表必须是非空的列表")
|
||
|
||
is_own_session = False
|
||
total_affected = 0
|
||
|
||
try:
|
||
# 如果没有提供会话,创建新会话
|
||
if session is None:
|
||
session = await self.get_session()
|
||
is_own_session = True
|
||
|
||
# 逐个更新数据(PostgreSQL不支持标准的批量更新语法)
|
||
for data in data_list:
|
||
if primary_key not in data:
|
||
raise ValueError(f"数据字典中必须包含主键字段 '{primary_key}'")
|
||
|
||
# 构建更新字段部分,排除主键
|
||
update_fields = [f"{key} = :{key}" for key in data.keys() if key != primary_key]
|
||
if not update_fields:
|
||
continue # 没有需要更新的字段
|
||
|
||
update_sql_part = ', '.join(update_fields)
|
||
|
||
# 构建完整的更新SQL
|
||
sql = f"UPDATE {table_name} SET {update_sql_part} WHERE {primary_key} = :{primary_key}"
|
||
|
||
# 执行更新
|
||
result = await session.execute(text(sql), data)
|
||
total_affected += result.rowcount if hasattr(result, 'rowcount') else 0
|
||
|
||
# 提交事务
|
||
await session.commit()
|
||
return total_affected
|
||
except Exception as e:
|
||
logger.error(f"批量更新失败: {str(e)}")
|
||
if session:
|
||
await session.rollback()
|
||
raise
|
||
finally:
|
||
# 如果是自己创建的会话,负责关闭
|
||
if is_own_session and session:
|
||
await session.close()
|
||
|
||
async def begin_transaction(self, session=None):
|
||
"""
|
||
开始一个新的事务(异步版本)
|
||
|
||
Args:
|
||
session: 可选的数据库会话对象,不提供时将创建新会话
|
||
|
||
Returns:
|
||
Session: 数据库会话对象
|
||
"""
|
||
if session is None:
|
||
session = await self.get_session()
|
||
return session
|
||
|
||
async def commit_transaction(self, session):
|
||
"""
|
||
提交当前事务(异步版本)
|
||
|
||
Args:
|
||
session: 数据库会话对象
|
||
"""
|
||
if session:
|
||
await session.commit()
|
||
|
||
async def rollback_transaction(self, session):
|
||
"""
|
||
回滚当前事务(异步版本)
|
||
|
||
Args:
|
||
session: 数据库会话对象
|
||
"""
|
||
if session:
|
||
await session.rollback()
|
||
|
||
async def transaction(self):
|
||
"""
|
||
异步事务上下文管理器,用于简化异步事务管理
|
||
|
||
使用示例:
|
||
async with await db.transaction() as session:
|
||
await db.execute_update("INSERT INTO users (name) VALUES (:name)", {"name": "张三"}, session)
|
||
await db.execute_update("UPDATE statistics SET count = count + 1", {}, session)
|
||
"""
|
||
return TransactionContext(self)
|
||
|
||
|
||
# Db类结束
|