Files
aiData/DbKit/Db.py
HuangHai e51dc18d06 'commit'
2026-01-21 08:41:47 +08:00

1310 lines
56 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import logging
import re
import asyncio
import functools
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker
from sqlalchemy.sql import text
import Config
from DbKit.SqlTemplateLoader import SqlTemplateLoader
from DbKit.TransactionContext import TransactionContext
# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def db_retry(max_retries=3, delay=1):
"""
数据库操作重试装饰器
Args:
max_retries: 最大重试次数默认为3次
delay: 重试间隔时间默认为1秒
"""
def decorator(func):
@functools.wraps(func)
async def wrapper(*args, **kwargs):
# 检查是否提供了 session如果提供了 session通常意味着是在一个更大的事务中不建议在这里重试
session = kwargs.get('session')
if session is not None:
return await func(*args, **kwargs)
last_exception = None
for attempt in range(max_retries):
try:
return await func(*args, **kwargs)
except Exception as e:
last_exception = e
# 只有在还有重试机会时才打印警告
if attempt < max_retries - 1:
logger.warning(f"数据库操作失败 (尝试 {attempt + 1}/{max_retries}),正在进行重试: {str(e)}")
await asyncio.sleep(delay)
else:
logger.error(f"数据库操作在 {max_retries} 次尝试后仍然失败: {str(e)}")
# 如果循环结束仍未返回,说明最后一次尝试也失败了,抛出异常
if last_exception:
raise last_exception
return wrapper
return decorator
class Db:
"""通用数据库操作封装类,提供数据库连接和操作功能"""
# 单例实例
_instance = None
def __new__(cls, db_url=None, sql_dir=None):
"""确保Db类只有一个实例异步版本"""
if cls._instance is None:
cls._instance = super(Db, cls).__new__(cls)
# 初始化数据库连接
cls._instance.engine = None
cls._instance.AsyncSessionLocal = None
cls._instance.db_url = db_url
cls._instance.sql_dir = sql_dir or os.path.join(os.path.dirname(os.path.abspath(__file__)), 'Sql')
# 初始化SQL模板加载器
cls._instance.sql_loader = SqlTemplateLoader(cls._instance.sql_dir, auto_load=False)
# 添加标志表示SQL模板尚未加载
cls._instance._templates_loaded = False
return cls._instance
def __init__(self, db_url=None, sql_dir=None):
"""初始化数据库操作类
Args:
db_url: 数据库连接URL不提供时将使用默认配置
sql_dir: SQL模板文件目录不提供时将使用默认目录
"""
# 单例模式下,初始化只会在第一次创建实例时执行
# 这里可以为空因为实际初始化在__new__中完成
self.sql_dir = sql_dir
self.db_url = db_url
async def init_db(self, db_url=None):
"""
初始化数据库连接(异步版本)
Args:
db_url: 数据库连接URL优先级高于初始化时提供的URL
"""
if self.engine:
return
# 确定使用的数据库连接URL
final_url = db_url or self.db_url
if not final_url:
# 如果没有提供URL使用配置文件中的参数构建连接URL
try:
from Config.Config import DB_URL
final_url = DB_URL
except ImportError:
# 兼容旧配置或Postgres配置
from Config.Config import POSTGRES_HOST, POSTGRES_PORT, POSTGRES_USER, POSTGRES_PASSWORD, POSTGRES_DATABASE
# 转换为异步连接URL
final_url = f"postgresql+asyncpg://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{POSTGRES_HOST}:{POSTGRES_PORT}/{POSTGRES_DATABASE}"
else:
# 如果是同步URL转换为异步URL
if final_url.startswith("postgresql://"):
final_url = final_url.replace("postgresql://", "postgresql+asyncpg://")
# 创建异步数据库引擎
self.engine = create_async_engine(final_url)
# 创建异步会话工厂
self.AsyncSessionLocal = sessionmaker(
class_=AsyncSession,
autocommit=False,
autoflush=False,
bind=self.engine
)
# 检查Redis中是否已经加载了SQL模板
templates_loaded_in_redis = await self.sql_loader._check_templates_loaded_from_redis()
# 如果Redis中有加载的模板直接从Redis加载
if templates_loaded_in_redis:
try:
await self.sql_loader._load_from_redis()
self._templates_loaded = True
logger.info(f"从Redis加载SQL模板完成{len(self.sql_loader._loaded_files)}个文件,{len(self.sql_loader._template_map)}个模板")
except Exception as e:
logger.error(f"从Redis加载SQL模板失败: {str(e)}")
# 不抛出异常,允许继续执行
pass
# 如果Redis中没有且内存中也没有加载过执行正常加载流程
elif not getattr(self, '_templates_loaded', False):
try:
await self.load_all_templates()
self._templates_loaded = True
# 获取并记录SQL模板统计信息
loaded_files = len(self.sql_loader._loaded_files)
template_count = 0
for namespace, templates in self.sql_loader.templates.items():
template_count += len(templates)
logger.info(f"SQL模板加载完成共加载{loaded_files}个文件,{template_count}个模板")
except Exception as e:
logger.error(f"加载SQL模板失败: {str(e)}")
# 不抛出异常,允许数据库操作继续进行
pass
async def get_session(self):
"""获取数据库会话(异步版本)
Returns:
AsyncSession: 异步数据库会话对象
"""
if self.AsyncSessionLocal is None:
await self.init_db()
return self.AsyncSessionLocal()
async def close(self):
if hasattr(self, "engine") and self.engine is not None:
try:
await self.engine.dispose()
finally:
self.engine = None
self.AsyncSessionLocal = None
async def load_all_templates(self):
"""异步加载所有SQL模板
Returns:
self: 支持链式调用
"""
await self.sql_loader.load_all_templates()
return self
def _is_sql_template(self, sql_input):
"""
判断输入的带点字符串是SQL模板还是SQL语句
Args:
sql_input: 输入的字符串可能是SQL模板名称或SQL语句
Returns:
bool: 如果是SQL模板返回True如果是SQL语句返回False
"""
# 如果输入不是字符串或者不包含点号直接判定为SQL语句
if not isinstance(sql_input, str) or '.' not in sql_input:
return False
# 检查sql_loader是否存在且已初始化
if not hasattr(self, 'sql_loader') or not self.sql_loader:
return False
# 检查模板映射是否存在
if not hasattr(self.sql_loader, '_template_map') or not self.sql_loader._template_map:
return False
# 在模板映射中查找如果找到就是SQL模板否则是SQL语句
is_template = sql_input in self.sql_loader._template_map
logger.debug(f"输入 '{sql_input}' 在模板映射中查找结果: {is_template}")
return is_template
async def get_sql(self, sql_identifier, params=None):
"""获取指定名称的SQL模板并处理参数替换
Args:
sql_identifier: 格式为'namespace.template_name'的SQL标识符
params: 可选的参数字典,用于替换模板中的参数
Returns:
str: 处理后的SQL语句
"""
# logger.info(f"尝试获取SQL模板: {sql_identifier}")
sql_template = await self.sql_loader.get_sql(sql_identifier)
# 处理模板,即使没有提供参数也进行处理(清理条件块等)
processed_sql = self._process_sql_template(sql_template, params)
logger.debug(f"处理后的SQL语句: {processed_sql}")
return processed_sql
def _process_sql_template(self, sql_template, params=None):
"""
处理SQL模板中的条件判断和参数替换
Args:
sql_template: SQL模板字符串
params: 参数字典
Returns:
str: 处理后的SQL语句
"""
if params is None:
params = {}
# 检查是否包含#if条件块
if_matches = re.findall(r'#if\([^)]+\)', sql_template)
logger.debug(f"发现 {len(if_matches)} 个#if条件块")
for i, match in enumerate(if_matches, 1):
logger.debug(f" {i}. {match}")
# 检查是否包含ORDER BY子句
if 'ORDER BY' in sql_template:
logger.debug("SQL模板包含ORDER BY子句")
processed_sql = sql_template
# 处理#if条件判断
# 修复正则表达式,确保正确匹配条件块,支持 ! 取反
if_pattern = r'#if\((!?\w+)\)([\s\S]*?)#end'
# 使用finditer来获取匹配的位置信息
matches = list(re.finditer(if_pattern, processed_sql))
# 从后向前替换,避免位置偏移问题
for match in reversed(matches):
raw_param_name = match.group(1)
sql_block = match.group(2)
# 处理取反逻辑
is_negation = raw_param_name.startswith('!')
param_name = raw_param_name[1:] if is_negation else raw_param_name
# 检查参数是否存在且不为None (且不为空,如果需要更严格的检查)
# 这里保持原逻辑基础上增加对空值的判定因为空列表通常也意味着False
has_value = param_name in params and params[param_name] is not None and params[param_name]
# 决定是否保留代码块
keep_block = (not has_value) if is_negation else has_value
if keep_block:
# 如果条件满足保留SQL块
# 移除多余的空白行和缩进
lines = sql_block.strip().split('\n')
cleaned_lines = [line.strip() for line in lines if line.strip()]
cleaned_block = '\n'.join(cleaned_lines)
# 替换原始匹配的内容
processed_sql = processed_sql[:match.start()] + cleaned_block + processed_sql[match.end():]
logger.debug(f"保留条件块 #if({raw_param_name}),内容: {cleaned_block}")
else:
# 如果条件不满足,移除整个条件块
processed_sql = processed_sql[:match.start()] + "" + processed_sql[match.end():]
logger.debug(f"移除条件块 #if({raw_param_name})")
# 创建参数字典的副本,用于存储处理后的参数
if params is not None:
processed_params = params.copy()
else:
processed_params = {}
# 处理参数替换 - 支持直接的变量名格式
# 首先找出所有可能的参数占位符(#para格式和直接变量名格式
# 处理#para格式的参数
para_pattern = r'#para\((\w+)\)'
all_para_placeholders = re.findall(para_pattern, processed_sql)
# 替换#para格式的参数
for param_name in all_para_placeholders:
placeholder = f"#para({param_name})"
if placeholder in processed_sql and param_name in params and params[param_name] is not None:
# 检查参数是否为列表类型且用于IN操作
if isinstance(params[param_name], (list, tuple)) and "in" in processed_sql.lower():
logger.debug(f"检测到列表参数 {param_name} 用于IN操作进行特殊处理")
# 为列表中的每个元素创建单独的参数
list_values = params[param_name]
if list_values: # 确保列表不为空
# 生成多个占位符,如 :idList_1, :idList_2, :idList_3
placeholders = ', '.join([f":{param_name}_{i}" for i in range(len(list_values))])
processed_sql = processed_sql.replace(placeholder, placeholders)
# 更新参数字典,添加每个列表元素的参数
for i, value in enumerate(list_values):
processed_params[f"{param_name}_{i}"] = value
logger.debug(f"为列表参数 {param_name} 创建子参数 {param_name}_{i} = {value}")
else:
# 空列表情况将整个条件移除因为IN () 是无效的SQL
logger.warning(f"列表参数 {param_name} 为空,将移除包含该参数的条件表达式")
# 查找包含该占位符的整个条件表达式行
lines = processed_sql.split('\n')
new_lines = []
for line in lines:
if placeholder in line:
line_stripped = line.strip().upper()
if line_stripped.startswith('AND ') or line_stripped.startswith('OR '):
continue
new_lines.append(line)
processed_sql = '\n'.join(new_lines)
else:
# 普通参数处理
processed_sql = processed_sql.replace(placeholder, f":{param_name}")
logger.debug(f"已将参数 {placeholder} 转换为 :{param_name}")
elif placeholder in processed_sql:
# 参数缺失或为None移除包含该参数的条件表达式
logger.warning(f"参数 {param_name} 缺失或为None将移除包含该参数的条件表达式")
lines = processed_sql.split('\n')
new_lines = []
for line in lines:
if placeholder in line:
line_stripped = line.strip().upper()
if line_stripped.startswith('AND ') or line_stripped.startswith('OR '):
continue
new_lines.append(line)
processed_sql = '\n'.join(new_lines)
# 处理直接的变量名格式(如 $person_id 或 #{person_id}
# 尝试匹配各种常见的变量名格式
variable_patterns = [
r'\$([a-zA-Z_][a-zA-Z0-9_]*)', # $variable 格式
r'#\{([a-zA-Z_][a-zA-Z0-9_]*)\}', # #{variable} 格式
r'\{([a-zA-Z_][a-zA-Z0-9_]*)\}' # {variable} 格式
]
for pattern in variable_patterns:
matches = re.finditer(pattern, processed_sql)
# 收集所有匹配并从后向前替换,避免位置偏移
variable_matches = list(matches)
for match in reversed(variable_matches):
placeholder = match.group(0)
param_name = match.group(1)
# 检查参数是否存在且不为None
if param_name in params and params[param_name] is not None:
# 替换为SQLAlchemy参数格式
replacement = f":{param_name}"
processed_sql = processed_sql[:match.start()] + replacement + processed_sql[match.end():]
logger.debug(f"已将变量 {placeholder} 转换为 {replacement}")
else:
# 变量不存在或为None记录警告
logger.warning(f"变量 {param_name} 缺失或为None保留原样")
# 更新params使其包含处理后的参数包括扩展的列表参数
if params is not None:
for key, value in processed_params.items():
params[key] = value
# 清理多余的空白行和空格
lines = processed_sql.split('\n')
processed_lines = [line.strip() for line in lines if line.strip()]
processed_sql = '\n'.join(processed_lines)
# 额外的清理:处理可能产生的孤立逻辑运算符
# 例如如果我们移除了一行可能会留下单独的AND/OR
processed_sql = re.sub(r'(^|\n)\s*(AND|OR)\s*($|\n)', '\1\3', processed_sql)
# 最终清理
lines = processed_sql.split('\n')
processed_lines = [line.strip() for line in lines if line.strip()]
processed_sql = '\n'.join(processed_lines)
return processed_sql
@db_retry()
async def find(self, sql, params=None, session=None):
"""
执行SQL查询并返回结果异步版本
Args:
sql: SQL查询语句或SQL模板名称支持namespace.sql_name格式
params: 查询参数,字典形式
session: 可选的数据库会话对象,不提供时将自动创建
Returns:
list: 查询结果的字典列表
"""
# 使用当前实例
instance = self
# 如果没有提供params创建一个空字典
if params is None:
params = {}
# 确保数据库连接已初始化
if instance.AsyncSessionLocal is None:
await instance.init_db()
# 检查是否需要加载SQL模板
# 首先检查Redis中是否有加载的模板
if not getattr(instance, '_templates_loaded', False):
templates_loaded_in_redis = await instance.sql_loader._check_templates_loaded_from_redis()
if templates_loaded_in_redis:
try:
await instance.sql_loader._load_from_redis()
instance._templates_loaded = True
logger.info("从Redis自动加载SQL模板完成")
except Exception as e:
logger.error(f"从Redis自动加载SQL模板失败: {str(e)}")
else:
# Redis中没有尝试正常加载
try:
await instance.load_all_templates()
instance._templates_loaded = True
logger.info("SQL模板自动加载完成")
except Exception as e:
logger.error(f"自动加载SQL模板失败: {str(e)}")
# 不抛出异常,允许数据库操作继续进行
pass
# 检查是否为SQL模板名称并处理
# 即使不是模板名称,如果包含点号,我们也尝试在加载后检查它是否真的是模板
is_template = self._is_sql_template(sql)
if not is_template and '.' in sql:
# 如果包含点号但没在映射中找到,可能是因为缓存加载问题,尝试重新加载映射
if hasattr(self.sql_loader, '_build_template_map'):
self.sql_loader._build_template_map()
is_template = self._is_sql_template(sql)
if is_template:
try:
logger.debug(f"确认为SQL模板: {sql}")
# 尝试获取SQL模板
sql_content = await instance.get_sql(sql, params)
# 如果成功获取到内容且不等于原始名称,说明是模板
if sql_content and sql_content != sql:
sql = sql_content
logger.debug(f"成功获取SQL模板处理后的SQL: {sql}")
# 对于模板SQL参数已经被get_sql处理过了不需要再处理
# params = None
else:
logger.debug(f"获取到的内容与输入相同,可能不是模板: {sql_content}")
except Exception as e:
logger.error(f"获取SQL模板失败: {str(e)}将作为普通SQL执行")
# 继续执行将其作为普通SQL处理
# 只有当SQL不是从模板获取时才处理普通SQL中的#para()格式参数
# 因为模板已经通过get_sql方法处理了参数替换
if params and isinstance(params, dict) and isinstance(sql, str):
# 检查SQL是否还包含#para()格式(如果是模板,应该已经被处理过了)
if '#para(' in sql:
logger.debug("检测到SQL中仍包含#para()格式,进行参数替换")
# 替换SQL中的#para(param_name)为:param_name
for param_name, param_value in list(params.items()): # 使用list()创建副本,避免在遍历时修改
para_pattern = f'#para\\({param_name}\\)' # 转义括号
if re.search(para_pattern, sql):
# 检查参数是否为列表类型且用于IN操作
if isinstance(param_value, (list, tuple)) and "in" in sql.lower():
logger.debug(f"检测到普通SQL中的列表参数 {param_name} 用于IN操作进行特殊处理")
# 为列表中的每个元素创建单独的参数
list_values = param_value
if list_values: # 确保列表不为空
# 生成多个占位符,如 :idList_1, :idList_2, :idList_3
placeholders = ', '.join([f":{param_name}_{i}" for i in range(len(list_values))])
sql = re.sub(para_pattern, placeholders, sql)
# 更新参数字典,添加每个列表元素的参数
for i, value in enumerate(list_values):
params[f"{param_name}_{i}"] = value
logger.debug(f"为列表参数 {param_name} 创建子参数 {param_name}_{i} = {value}")
else:
# 空列表情况将整个条件移除因为IN () 是无效的SQL
logger.warning(f"列表参数 {param_name} 为空,将移除包含该参数的条件表达式")
# 查找包含该占位符的整个条件表达式行
lines = sql.split('\n')
new_lines = []
for line in lines:
if f"#para({param_name})" in line:
line_stripped = line.strip().upper()
if line_stripped.startswith('AND ') or line_stripped.startswith('OR '):
continue
new_lines.append(line)
sql = '\n'.join(new_lines)
else:
# 普通参数处理
sql = re.sub(para_pattern, f':{param_name}', sql)
logger.debug(f"已替换参数占位符: #para({param_name}) -> :{param_name}")
else:
logger.debug("SQL模板已处理过参数跳过重复处理")
is_own_session = False
try:
# 如果没有提供会话,创建新会话
if session is None:
session = await instance.get_session()
is_own_session = True
# 执行查询
# logger.info(f"执行查询: {sql},参数: {params}")
result = await session.execute(text(sql), params or {})
# 将结果转换为字典列表
return [dict(row._mapping) for row in result]
except Exception as e:
logger.error(f"查询执行失败: {str(e)}")
raise
finally:
# 如果是自己创建的会话,负责关闭
if is_own_session and session:
await session.close()
async def findFirst(self, sql, params=None, session=None):
"""
执行查询并返回结果的第一条记录(异步版本)
Args:
sql: SQL查询语句或SQL模板名称支持namespace.sql_name格式
params: 查询参数,字典形式
session: 可选的数据库会话对象,不提供时将自动创建
Returns:
dict: 第一条记录的字典如果没有结果则返回None
"""
# 直接调用异步的find方法它已经使用_is_sql_template进行了模板检查
rows = await self.find(sql, params, session)
return rows[0] if rows else None
async def findById(self, table_name, primary_key_value, primary_key='id', session=None):
"""
根据主键查询指定表中的记录(异步版本)
Args:
table_name: 表名
primary_key_value: 主键值
primary_key: 主键字段名,默认为'id'
session: 可选的数据库会话对象,不提供时将自动创建
Returns:
dict: 查询到的记录字典如果没有找到则返回None
"""
if not table_name:
raise ValueError("表名不能为空")
# 构建查询SQL
sql = f"SELECT * FROM {table_name} WHERE {primary_key} = :primary_key_value"
# 准备参数
params = {"primary_key_value": primary_key_value}
logger.debug(f"生成的查询SQL: {sql}")
logger.debug(f"查询参数: {params}")
# 异步使用findFirst方法执行查询
return await self.findFirst(sql, params, session)
async def paginate(self, sql, page_number=1, page_size=10, params=None, session=None):
"""
执行SQL查询并返回分页结果异步版本
Args:
sql: SQL查询语句或SQL模板名称支持namespace.sql_name格式
page_number: 页码从1开始
page_size: 每页记录数
params: 查询参数,字典形式
session: 可选的数据库会话对象,不提供时将自动创建
Returns:
dict: 包含分页信息的字典,格式如下:
{
'list': 当前页的数据列表,
'pageNumber': 当前页码,
'pageSize': 每页记录数,
'totalRow': 总记录数,
'totalPage': 总页数
}
"""
# 确保数据库已初始化
if self.AsyncSessionLocal is None:
await self.init_db()
# 初始化诊断日志
logger.debug("======= 开始分页查询 =======")
# 如果没有提供params创建一个空字典
if params is None:
params = {}
logger.debug(f"传入的参数: {params}")
# 确保页码和每页大小为正整数
try:
page_number = int(page_number)
page_size = int(page_size)
except (ValueError, TypeError):
logger.error(f"页码或每页大小类型错误,重置为默认值")
page_number = 1
page_size = 10
if page_number < 1:
page_number = 1
if page_size < 1:
page_size = 10
logger.debug(f"处理后的页码: {page_number}, 每页大小: {page_size}")
# 使用_is_sql_template方法检查并获取SQL模板
is_template = self._is_sql_template(sql)
if not is_template and '.' in sql:
# 如果包含点号但没在映射中找到,可能是因为缓存加载问题,尝试重新加载映射
if hasattr(self.sql_loader, '_build_template_map'):
self.sql_loader._build_template_map()
is_template = self._is_sql_template(sql)
if is_template:
try:
logger.debug(f"确认为SQL模板: {sql}")
# 尝试获取SQL模板
sql_content = await self.get_sql(sql, params)
# 如果成功获取到内容且不等于原始名称,说明是模板
if sql_content and sql_content != sql:
sql = sql_content
logger.debug(f"成功获取SQL模板处理后的SQL: {sql}")
else:
logger.debug(f"获取到的内容与输入相同,可能不是模板: {sql_content}")
except Exception as e:
logger.error(f"获取SQL模板失败: {str(e)}将作为普通SQL执行")
# 继续执行将其作为普通SQL处理
# 只有当SQL不是从模板获取时才处理普通SQL中的#para()格式参数
# 因为模板已经通过get_sql方法处理了参数替换
if params and isinstance(params, dict) and isinstance(sql, str):
# 检查SQL是否还包含#para()格式(如果是模板,应该已经被处理过了)
if '#para(' in sql:
logger.debug("检测到SQL中仍包含#para()格式,进行参数替换")
# 替换SQL中的#para(param_name)为:param_name
for param_name, param_value in list(params.items()): # 使用list()创建副本,避免在遍历时修改
para_pattern = f'#para\\({param_name}\\)' # 转义括号
if re.search(para_pattern, sql):
# 检查参数是否为列表类型且用于IN操作
if isinstance(param_value, (list, tuple)) and "in" in sql.lower():
logger.debug(f"检测到分页查询中的列表参数 {param_name} 用于IN操作进行特殊处理")
# 为列表中的每个元素创建单独的参数
list_values = param_value
if list_values: # 确保列表不为空
# 生成多个占位符,如 :idList_1, :idList_2, :idList_3
placeholders = ', '.join([f":{param_name}_{i}" for i in range(len(list_values))])
sql = re.sub(para_pattern, placeholders, sql)
# 更新参数字典,添加每个列表元素的参数
for i, value in enumerate(list_values):
params[f"{param_name}_{i}"] = value
logger.debug(f"为列表参数 {param_name} 创建子参数 {param_name}_{i} = {value}")
else:
# 空列表情况将整个条件移除因为IN () 是无效的SQL
logger.warning(f"列表参数 {param_name} 为空,将移除包含该参数的条件表达式")
# 查找包含该占位符的整个条件表达式行
lines = sql.split('\n')
new_lines = []
for line in lines:
if f"#para({param_name})" in line:
line_stripped = line.strip().upper()
if line_stripped.startswith('AND ') or line_stripped.startswith('OR '):
continue
new_lines.append(line)
sql = '\n'.join(new_lines)
else:
# 普通参数处理
sql = re.sub(para_pattern, f':{param_name}', sql)
logger.debug(f"已替换参数占位符: #para({param_name}) -> :{param_name}")
else:
logger.debug("SQL模板已处理过参数跳过重复处理")
is_own_session = False
try:
# 如果没有提供会话,创建新会话
if session is None:
session = await self.get_session()
is_own_session = True
# 计算总记录数
count_sql = self._build_count_sql(sql)
logger.debug(f"构建的COUNT SQL: {count_sql}")
try:
count_result = await session.execute(text(count_sql), params or {})
count_value = count_result.scalar()
# 确保total_row是正确的整数值
try:
total_row = int(count_value) if count_value is not None else 0
except (ValueError, TypeError):
total_row = 0
except Exception as e:
logger.error(f"COUNT查询失败: {str(e)}")
total_row = 0
# 如果是自己创建的会话当COUNT查询失败时回滚事务并重新创建会话
if is_own_session:
logger.debug("COUNT查询失败回滚并重新创建会话")
await session.rollback()
await session.close()
session = await self.get_session()
# 添加详细日志记录总记录数和每页大小
logger.debug(f"总记录数: {total_row}, 每页大小: {page_size}")
# 确保total_row是整数
total_row = int(total_row) if total_row is not None else 0
# 计算总页数(使用更清晰的计算方式)
if page_size > 0:
total_page = (total_row + page_size - 1) // page_size
else:
total_page = 0
logger.debug(f"计算的总页数: {total_page}")
# 计算偏移量
offset = (page_number - 1) * page_size
# 构建分页SQLPostgreSQL语法
# 首先检查SQL是否已经包含LIMIT或OFFSET子句
has_limit = re.search(r'(?i)\bLIMIT\b', sql)
has_offset = re.search(r'(?i)\bOFFSET\b', sql)
# 如果已经包含LIMIT或OFFSET直接使用原始SQL
if has_limit or has_offset:
paginated_sql = sql
else:
# 否则在SQL末尾添加LIMIT和OFFSET
# 不需要复杂的ORDER BY处理PostgreSQL会自动处理顺序
paginated_sql = f"{sql} LIMIT :limit OFFSET :offset"
logger.debug(f"构建的分页SQL: {paginated_sql}")
# 添加分页参数
paginated_params = params.copy() if params else {}
paginated_params['limit'] = page_size
paginated_params['offset'] = offset
logger.debug(f"分页查询参数: {paginated_params}")
# 执行分页查询
result = await session.execute(text(paginated_sql), paginated_params)
# 将结果转换为字典列表
list_data = [dict(row._mapping) for row in result]
# 确保total_row是整数
try:
total_row_int = int(total_row)
except Exception as e:
total_row_int = 0
# 计算总页数
if page_size > 0:
remainder = total_row_int % page_size
quotient = total_row_int // page_size
calculated_pages = quotient + 1 if remainder > 0 else quotient
else:
calculated_pages = 0
# 构建返回结果,使用正确的页数
pagination_result = {
'list': list_data if list_data is not None else [],
'page': page_number,
'page_size': page_size,
'total': total_row,
'total_pages': calculated_pages # 使用正确的页数
}
return pagination_result
except Exception as e:
logger.error(f"分页查询失败: {str(e)}")
raise
finally:
# 如果是自己创建的会话,负责关闭
if is_own_session and session:
await session.close()
def _build_count_sql(self, sql):
"""
根据原始SQL构建用于计算总记录数的SQL
Args:
sql: 原始SQL查询语句
Returns:
str: 计算总记录数的SQL
"""
import re
# 清理SQL移除前后空白
sql = sql.strip()
# 策略1: 使用正则表达式查找FROM关键字更灵活
match = re.search(r'(?i)\bFROM\b', sql)
if match:
from_pos = match.start()
# 提取FROM后面的内容包括FROM关键字
from_part = sql[from_pos:]
# 检查是否包含ORDER BY子句如果有移除
order_by_match = re.search(r'(?i)\bORDER\s+BY\b', from_part)
if order_by_match:
order_by_pos = order_by_match.start()
from_part = from_part[:order_by_pos]
# 检查是否包含GROUP BY子句
if re.search(r'(?i)\bGROUP\s+BY\b', from_part):
# 对于复杂查询,使用子查询来计算总数
subquery_sql = sql[:order_by_pos] if order_by_match else sql
return f"SELECT COUNT(*) FROM ({subquery_sql}) AS subquery"
else:
# 简单查询直接替换SELECT部分
return f"SELECT COUNT(*) {from_part}"
# 策略2: 使用多种空格格式查找FROM关键字
from_variants = [' FROM ', ' FROM', 'FROM ', 'FROM']
sql_upper = sql.upper()
for variant in from_variants:
if variant.upper() in sql_upper:
from_pos = sql_upper.find(variant.upper())
from_part = sql[from_pos:]
# 移除ORDER BY
if ' ORDER BY ' in from_part.upper():
order_by_pos = from_part.upper().rfind(' ORDER BY ')
from_part = from_part[:order_by_pos]
return f"SELECT COUNT(*) {from_part}"
# 如果所有策略都失败,返回默认查询
return "SELECT COUNT(*)"
async def check_column_exists(self, table_name, column_name):
"""检查表中是否存在指定列
Args:
table_name: 表名
column_name: 列名
Returns:
bool: 如果列存在返回True否则返回False
"""
try:
# 兼容 Doris 和 MySQL 的语法
sql = f"SHOW COLUMNS FROM {table_name} LIKE :column_name"
params = {"column_name": column_name}
# 使用 find 方法执行查询
result = await self.find(sql, params)
return len(result) > 0
except Exception as e:
logger.debug(f"检查列是否存在时出错 (可能表不存在): {str(e)}")
return False
@db_retry()
async def execute_update(self, sql, params=None, session=None):
"""执行SQL更新操作插入、更新、删除异步版本
Args:
sql: SQL更新语句或SQL模板名称支持namespace.sql_name格式
params: 更新参数,字典形式
session: 可选的数据库会话对象,不提供时将自动创建
Returns:
int: 受影响的行数
"""
# 确保数据库已初始化
if self.AsyncSessionLocal is None:
await self.init_db()
# 如果没有提供params创建一个空字典
if params is None:
params = {}
# 检查是否为SQL模板名称并处理
if self._is_sql_template(sql):
try:
logger.debug(f"确认为SQL模板: {sql}")
# 尝试获取SQL模板
sql_content = await self.get_sql(sql, params)
# 如果成功获取到内容且不等于原始名称,说明是模板
if sql_content and sql_content != sql:
sql = sql_content
logger.debug(f"成功获取SQL模板处理后的SQL: {sql}")
# 对于模板SQL参数已经被get_sql处理过了不需要再处理
params = None
else:
logger.debug(f"获取到的内容与输入相同,可能不是模板: {sql_content}")
except Exception as e:
logger.error(f"获取SQL模板失败: {str(e)}将作为普通SQL执行")
# 继续执行将其作为普通SQL处理
# 只有当SQL不是从模板获取时才处理普通SQL中的#para()格式参数
# 因为模板已经通过get_sql方法处理了参数替换
if params and isinstance(params, dict):
# 检查SQL是否还包含#para()格式(如果是模板,应该已经被处理过了)
if '#para(' in sql:
logger.debug("检测到SQL中仍包含#para()格式,进行参数替换")
# 替换SQL中的#para(param_name)为:param_name
for param_name, param_value in list(params.items()): # 使用list()创建副本,避免在遍历时修改
para_pattern = f'#para\\({param_name}\\)' # 转义括号
if re.search(para_pattern, sql):
# 检查参数是否为列表类型且用于IN操作
if isinstance(param_value, (list, tuple)) and "in" in sql.lower():
logger.debug(f"检测到更新SQL中的列表参数 {param_name} 用于IN操作进行特殊处理")
# 为列表中的每个元素创建单独的参数
list_values = param_value
if list_values: # 确保列表不为空
# 生成多个占位符,如 :idList_1, :idList_2, :idList_3
placeholders = ', '.join([f":{param_name}_{i}" for i in range(len(list_values))])
sql = re.sub(para_pattern, placeholders, sql)
# 更新参数字典,添加每个列表元素的参数
for i, value in enumerate(list_values):
params[f"{param_name}_{i}"] = value
logger.debug(f"为列表参数 {param_name} 创建子参数 {param_name}_{i} = {value}")
else:
# 空列表情况将整个条件移除因为IN () 是无效的SQL
logger.warning(f"列表参数 {param_name} 为空,将移除包含该参数的条件表达式")
# 查找包含该占位符的整个条件表达式行
lines = sql.split('\n')
new_lines = []
for line in lines:
if f"#para({param_name})" in line:
line_stripped = line.strip().upper()
if line_stripped.startswith('AND ') or line_stripped.startswith('OR '):
continue
new_lines.append(line)
sql = '\n'.join(new_lines)
else:
# 普通参数处理
sql = re.sub(para_pattern, f':{param_name}', sql)
logger.debug(f"已替换参数占位符: #para({param_name}) -> :{param_name}")
else:
logger.debug("SQL模板已处理过参数跳过重复处理")
is_own_session = False
try:
# 如果没有提供会话,创建新会话
if session is None:
session = await self.get_session()
is_own_session = True
# 执行更新
# logger.info(f"执行更新: {sql},参数: {params}")
result = await session.execute(text(sql), params or {})
# 提交事务
await session.commit()
# 返回受影响的行数(如果有)
return result.rowcount if hasattr(result, 'rowcount') else 0
except Exception as e:
logger.error(f"更新操作失败: {str(e)}")
if session:
await session.rollback()
raise
finally:
# 如果是自己创建的会话,负责关闭
if is_own_session and session:
await session.close()
@db_retry()
async def save(self, table_name, data, primary_key, session=None):
"""插入数据到指定表,并返回插入后的主键值(异步版本)
Args:
table_name: 表名
data: 要插入的数据字典
primary_key: 主键字段名
session: 可选的数据库会话对象,不提供时将自动创建
Returns:
插入后的主键值
Raises:
ValueError: 当参数验证失败时
Exception: 当插入操作失败时
"""
if not table_name or not data:
raise ValueError("表名和数据不能为空")
# 主键可能由数据库自动生成,不传递也可以
# 无需额外检查直接构建SQL语句
# 如果data中不包含primary_keySQL会自动处理
# logger.info(f"准备插入数据到表 {table_name},主键字段: {primary_key}")
# 构建插入SQL (移除 RETURNINGDoris/MySQL 不支持)
columns = ', '.join(data.keys())
placeholders = ', '.join([f":{key}" for key in data.keys()])
sql = f"INSERT INTO {table_name} ({columns}) VALUES ({placeholders})"
logger.debug(f"生成的插入SQL: {sql}")
logger.debug(f"插入参数: {data}")
is_own_session = False
try:
# 如果没有提供会话,创建新会话
if session is None:
session = await self.get_session()
is_own_session = True
# 执行插入
result = await session.execute(text(sql), data)
await session.commit()
# 获取插入的主键值
# 如果 data 中包含主键,直接返回
if primary_key in data:
return data[primary_key]
# 否则尝试获取 lastrowid (某些 DB 支持)
try:
return result.lastrowid
except:
return None
except Exception as e:
logger.error(f"数据插入失败: {str(e)}")
if session:
await session.rollback()
raise
finally:
# 如果是自己创建的会话,负责关闭
if is_own_session and session:
await session.close()
@db_retry()
async def update(self, table_name, data, primary_key, session=None):
"""根据主键更新指定表中的数据(异步版本)
Args:
table_name: 表名
data: 要更新的数据字典,必须包含主键字段和值
primary_key: 主键字段名
session: 可选的数据库会话对象,不提供时将自动创建
Returns:
int: 受影响的行数
Raises:
ValueError: 当参数验证失败时
Exception: 当更新操作失败时
"""
if not table_name or not data:
raise ValueError("表名和数据不能为空")
if primary_key not in data:
raise ValueError(f"数据字典中必须包含主键字段 '{primary_key}'")
# 获取主键值
primary_key_value = data[primary_key]
# logger.info(f"准备更新表 {table_name} 中的数据,主键: {primary_key}={primary_key_value}")
# 构建更新字段部分,排除主键
update_fields = [f"{key} = :{key}" for key in data.keys() if key != primary_key]
if not update_fields:
logger.warning("没有需要更新的字段")
return 0
update_sql_part = ', '.join(update_fields)
# 构建完整的更新SQL
sql = f"UPDATE {table_name} SET {update_sql_part} WHERE {primary_key} = :{primary_key}"
logger.debug(f"生成的更新SQL: {sql}")
logger.debug(f"更新参数: {data}")
is_own_session = False
try:
# 如果没有提供会话,创建新会话
if session is None:
session = await self.get_session()
is_own_session = True
# 执行更新
result = await session.execute(text(sql), data)
await session.commit()
affected_rows = result.rowcount if hasattr(result, 'rowcount') else 0
# logger.info(f"数据更新成功,受影响行数: {affected_rows}")
return affected_rows
except Exception as e:
logger.error(f"数据更新失败: {str(e)}")
if session:
await session.rollback()
raise
finally:
# 如果是自己创建的会话,负责关闭
if is_own_session and session:
await session.close()
@db_retry()
async def batch_insert(self, table_name, data_list, primary_key=None, session=None):
"""
批量插入数据到指定表(异步版本)
Args:
table_name: 表名
data_list: 要插入的数据字典列表
primary_key: 主键字段名,如果提供,将返回插入后的主键值列表
session: 可选的数据库会话对象,不提供时将自动创建
Returns:
list: 插入后的主键值列表如果不提供primary_key则返回None
Raises:
ValueError: 当参数验证失败时
Exception: 当插入操作失败时
"""
if not table_name or not data_list:
raise ValueError("表名和数据列表不能为空")
if not isinstance(data_list, list) or not data_list:
raise ValueError("数据列表必须是非空的列表")
# 获取第一个数据字典的键作为列名
columns = data_list[0].keys()
columns_str = ', '.join(columns)
# 构建占位符部分
placeholders_list = []
all_params = {}
# 为每一行数据生成占位符和参数
for i, data in enumerate(data_list):
# 确保每一行数据都有相同的键
if data.keys() != columns:
raise ValueError(f"{i+1}行数据的字段不匹配")
# 生成该行的占位符,如 (:name_1, :age_1)
row_placeholders = [f":{key}_{i}" for key in columns]
placeholders_list.append(f"({', '.join(row_placeholders)})")
# 更新参数
for key, value in data.items():
all_params[f"{key}_{i}"] = value
# 构建完整的批量插入SQL
values_clause = ', '.join(placeholders_list)
if primary_key:
sql = f"INSERT INTO {table_name} ({columns_str}) VALUES {values_clause} RETURNING {primary_key}"
else:
sql = f"INSERT INTO {table_name} ({columns_str}) VALUES {values_clause}"
logger.debug(f"生成的批量插入SQL: {sql}")
is_own_session = False
try:
# 如果没有提供会话,创建新会话
if session is None:
session = await self.get_session()
is_own_session = True
# 执行批量插入
result = await session.execute(text(sql), all_params)
await session.commit()
# 如果指定了主键,返回插入的主键值列表
if primary_key:
return [row[0] for row in result]
return None
except Exception as e:
logger.error(f"批量插入失败: {str(e)}")
if session:
await session.rollback()
raise
finally:
# 如果是自己创建的会话,负责关闭
if is_own_session and session:
await session.close()
@db_retry()
async def batch_update(self, table_name, data_list, primary_key, session=None):
"""
批量更新指定表中的数据(异步版本)
Args:
table_name: 表名
data_list: 要更新的数据字典列表,每个字典必须包含主键字段和值
primary_key: 主键字段名
session: 可选的数据库会话对象,不提供时将自动创建
Returns:
int: 受影响的总行数
Raises:
ValueError: 当参数验证失败时
Exception: 当更新操作失败时
"""
if not table_name or not data_list or not primary_key:
raise ValueError("表名、数据列表和主键不能为空")
if not isinstance(data_list, list) or not data_list:
raise ValueError("数据列表必须是非空的列表")
is_own_session = False
total_affected = 0
try:
# 如果没有提供会话,创建新会话
if session is None:
session = await self.get_session()
is_own_session = True
# 逐个更新数据PostgreSQL不支持标准的批量更新语法
for data in data_list:
if primary_key not in data:
raise ValueError(f"数据字典中必须包含主键字段 '{primary_key}'")
# 构建更新字段部分,排除主键
update_fields = [f"{key} = :{key}" for key in data.keys() if key != primary_key]
if not update_fields:
continue # 没有需要更新的字段
update_sql_part = ', '.join(update_fields)
# 构建完整的更新SQL
sql = f"UPDATE {table_name} SET {update_sql_part} WHERE {primary_key} = :{primary_key}"
# 执行更新
result = await session.execute(text(sql), data)
total_affected += result.rowcount if hasattr(result, 'rowcount') else 0
# 提交事务
await session.commit()
return total_affected
except Exception as e:
logger.error(f"批量更新失败: {str(e)}")
if session:
await session.rollback()
raise
finally:
# 如果是自己创建的会话,负责关闭
if is_own_session and session:
await session.close()
async def begin_transaction(self, session=None):
"""
开始一个新的事务(异步版本)
Args:
session: 可选的数据库会话对象,不提供时将创建新会话
Returns:
Session: 数据库会话对象
"""
if session is None:
session = await self.get_session()
return session
async def commit_transaction(self, session):
"""
提交当前事务(异步版本)
Args:
session: 数据库会话对象
"""
if session:
await session.commit()
async def rollback_transaction(self, session):
"""
回滚当前事务(异步版本)
Args:
session: 数据库会话对象
"""
if session:
await session.rollback()
async def transaction(self):
"""
异步事务上下文管理器,用于简化异步事务管理
使用示例:
async with await db.transaction() as session:
await db.execute_update("INSERT INTO users (name) VALUES (:name)", {"name": "张三"}, session)
await db.execute_update("UPDATE statistics SET count = count + 1", {}, session)
"""
return TransactionContext(self)
# Db类结束