Files
aiData/DbKit/SqlTemplateLoader.py

533 lines
23 KiB
Python
Raw Normal View History

2026-01-12 07:49:18 +08:00
import logging
import os
import re
import asyncio
import aiofiles
import json
# 导入RedisKit
from Util.RedisKit import redisKit
# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Redis键名常量
REDIS_SQL_TEMPLATES_KEY = "sql_templates:all_templates"
REDIS_SQL_LOADED_FILES_KEY = "sql_templates:loaded_files"
REDIS_SQL_TEMPLATE_MAP_KEY = "sql_templates:template_map"
REDIS_SQL_TEMPLATES_LOADED_KEY = "sql_templates:templates_loaded"
class SqlTemplateLoader:
"""SQL模板加载器负责加载和解析SQL模板文件异步版本"""
def __init__(self, sql_dir=None, auto_load=True):
"""初始化SQL模板加载器
Args:
sql_dir: SQL模板文件所在目录
auto_load: 是否自动加载目录下所有SQL模板文件
"""
self.sql_dir = sql_dir
self.templates = {} # 使用嵌套字典存储命名空间下的SQL模板
self._loaded_files = set() # 用于记录已加载的文件,避免重复加载
self._template_map = {} # 使用平面字典存储完整的模板名称映射,用于快速查找
# 注意异步版本不会在__init__中自动加载需要单独调用异步方法
async def load_all_templates(self):
"""自动加载SQL目录下所有的SQL模板文件异步版本
Returns:
int: 成功加载的模板文件数量
"""
logger.info(f"开始加载SQL模板目录: {self.sql_dir}")
if not os.path.exists(self.sql_dir):
logger.warning(f"SQL模板目录不存在: {self.sql_dir}")
return 0
# 检查是否已经从Redis加载过
templates_loaded = await self._check_templates_loaded_from_redis()
if templates_loaded:
await self._load_from_redis()
logger.info(f"从Redis加载SQL模板完成{len(self._loaded_files)}个文件,{len(self._template_map)}个模板")
return len(self._loaded_files)
loaded_count = 0
sql_files = []
try:
# 遍历目录下所有.sql文件
for filename in os.listdir(self.sql_dir):
if filename.lower().endswith('.sql'):
file_path = os.path.join(self.sql_dir, filename)
file_abs_path = os.path.abspath(file_path)
sql_files.append(file_abs_path)
# 使用异步方式加载所有模板,不立即构建映射以提高性能
for file_abs_path in sql_files:
# 跳过已经加载过的文件
if file_abs_path in self._loaded_files:
logger.debug(f"SQL文件已加载跳过: {file_abs_path}")
continue
try:
filename = os.path.basename(file_abs_path)
if await self.load_template(file_abs_path, build_map=False):
loaded_count += 1
except Exception as e:
logger.error(f"加载SQL模板文件 {filename} 失败: {str(e)}")
# 所有模板加载完成后,统一构建映射
if loaded_count > 0:
self._build_template_map()
logger.info(f"成功加载 {loaded_count} 个SQL模板文件已构建模板映射")
# 加载完成后保存到Redis
await self._save_to_redis()
return loaded_count
except Exception as e:
logger.error(f"遍历SQL模板目录时出错: {str(e)}")
return 0
async def load_template(self, template_path, build_map=True):
"""加载指定的SQL模板文件异步版本
Args:
template_path: SQL模板文件的路径可以是绝对路径或相对于sql_dir的相对路径
build_map: 是否立即构建模板映射默认为True在批量加载时可设为False以提高性能
Returns:
bool: 加载是否成功
Raises:
FileNotFoundError: 当模板文件不存在时
Exception: 当加载过程发生其他错误时
"""
# 确定模板文件的完整路径
if self.sql_dir and not os.path.isabs(template_path):
file_path = os.path.join(self.sql_dir, template_path)
else:
file_path = template_path
# 检查文件是否已加载,避免重复加载
file_abs_path = os.path.abspath(file_path)
if file_abs_path in self._loaded_files:
logger.debug(f"SQL模板文件已加载跳过: {file_abs_path}")
return True
try:
# 异步读取文件内容
async with aiofiles.open(file_path, 'r', encoding='utf-8') as f:
content = await f.read()
# 使用栈结构解析整个内容
templates = self._parse_with_stack(content)
# 处理命名空间
if 'namespace' in templates and templates['namespace']:
namespace = templates['namespace']
sql_templates = {k: v for k, v in templates.items() if k != 'namespace'}
else:
# 没有命名空间的情况
namespace = 'default'
sql_templates = templates
# 将解析出的SQL模板添加到templates字典中
if namespace not in self.templates:
self.templates[namespace] = {}
self.templates[namespace].update(sql_templates)
# 重新构建模板映射,确保新加载的模板可以被快速查找
if build_map:
self._build_template_map()
# 记录已加载的文件
self._loaded_files.add(file_abs_path)
return True
except FileNotFoundError:
logger.error(f"SQL模板文件不存在: {file_path}")
raise FileNotFoundError(f"SQL模板文件不存在: {file_path}")
except Exception as e:
logger.error(f"加载SQL模板失败: {str(e)}", exc_info=True)
raise
def _parse_with_stack(self, content):
"""使用栈结构解析SQL模板支持嵌套的标签匹配
支持--单行注释和/* */多行注释
Args:
content: 要解析的SQL模板内容
Returns:
解析后的SQL模板字典键为模板名称值为模板内容
"""
templates = {}
stack = [] # 用于存储标签信息的栈
current_content = [] # 当前正在收集的内容
in_multiline_comment = False # 多行注释标记
try:
# 按行处理内容
lines = content.split('\n')
for line_num, line in enumerate(lines):
# 处理多行注释
if in_multiline_comment:
# 检查是否是多行注释结束
if '*/' in line:
in_multiline_comment = False
# 截取注释结束后的内容
line = line.split('*/', 1)[1]
else:
# 完全在多行注释内,跳过此行
continue
# 处理单行注释和检查多行注释开始
if not line.strip().startswith('#') and stack and stack[-1][0] == 'sql':
# 检查是否有--注释
if '--' in line:
# 保留--前面的内容
line = line.split('--', 1)[0]
# 检查是否有/*多行注释开始
if '/*' in line:
parts = line.split('/*', 1)
# 如果/*和*/在同一行
if '*/' in parts[1]:
# 只保留/*前面和*/后面的内容
comment_part = parts[1].split('*/', 1)
line = parts[0] + comment_part[1]
else:
# 多行注释开始,只保留/*前面的内容
line = parts[0]
in_multiline_comment = True
stripped_line = line.strip()
# 检查是否是标签行
if stripped_line.startswith('#'):
# 检查是否是命名空间标签
if stripped_line.startswith('#namespace(') and (
stripped_line.find('"') > 0 or stripped_line.find("'") > 0):
# 提取命名空间名称
quote_pattern = re.search(r'#namespace\((["\'])([^"\']+)\1\)', stripped_line)
if quote_pattern:
namespace = quote_pattern.group(2)
templates['namespace'] = namespace
# 将命名空间标签入栈
stack.append(('namespace', namespace, line_num))
# 检查是否是SQL模板开始标签
elif stripped_line.startswith('#sql(') and (
stripped_line.find('"') > 0 or stripped_line.find("'") > 0):
# 提取SQL模板名称
quote_pattern = re.search(r'#sql\((["\'])([^"\']+)\1\)', stripped_line)
if quote_pattern:
sql_name = quote_pattern.group(2)
# 将SQL模板开始标签入栈
stack.append(('sql', sql_name, line_num))
current_content = [] # 重置内容收集器
# 检查是否是#if条件开始标签
elif stripped_line.startswith('#if(') and ')' in stripped_line:
# 提取条件表达式
condition_pattern = re.search(r'#if\(([^)]+)\)', stripped_line)
if condition_pattern:
condition = condition_pattern.group(1)
stack.append(('if', condition, line_num))
# 保留#if行到内容中 - 只要有SQL标签在栈中就保留
if any(tag[0] == 'sql' for tag in stack):
current_content.append(line.rstrip())
# 检查是否是结束标签
elif stripped_line == '#end' and stack:
# 获取栈顶元素
tag_type, tag_name, start_line = stack[-1]
# 只有当标签不是SQL结束标签时才保留#end行到内容中
if any(tag[0] == 'sql' for tag in stack) and tag_type != 'sql':
current_content.append(line.rstrip())
# 弹出当前标签
stack.pop()
# 如果是SQL模板结束标签保存内容
if tag_type == 'sql':
sql_content = '\n'.join(current_content).strip()
if sql_content:
templates[tag_name] = sql_content
# 重置内容收集器,准备下一个模板
current_content = []
# 收集内容不是标签行且当前有活动的SQL模板
else:
# 只要栈中有SQL标签就收集内容不管是在条件块内还是外
if any(tag[0] == 'sql' for tag in stack) and line.strip():
# 保留原始缩进
current_content.append(line.rstrip())
# 检查是否所有标签都正确闭合
if stack:
unclosed_tags = [f"{tag_type}('{tag_name}')" for tag_type, tag_name, _ in stack]
logger.warning(f"存在未闭合的标签: {', '.join(unclosed_tags)}")
else:
logger.debug("所有标签都已正确闭合")
except Exception as e:
logger.error(f"使用栈结构解析SQL模板时出错: {str(e)}", exc_info=True)
return templates
async def get_all_template_names(self):
"""
获取所有已加载的SQL模板的完整名称列表异步版本
Returns:
list: 包含所有模板名称的列表格式为 "命名空间.模板名称"
"""
template_names = []
# 遍历所有命名空间
for namespace, templates in self.templates.items():
# 遍历当前命名空间下的所有模板
for template_name in templates.keys():
# 构建完整的模板名称(命名空间.模板名称)
full_name = f"{namespace}.{template_name}"
template_names.append(full_name)
return template_names
def _has_template(self, full_template_name):
"""
内部方法快速检查指定的完整模板名称是否存在
使用MAP机制实现O(1)时间复杂度的查找
Args:
full_template_name: 完整的模板名称格式为 "命名空间.模板名称"
Returns:
bool: 如果模板存在返回True否则返回False
"""
logger.debug(f"检查模板是否存在: {full_template_name}")
# 如果_map为空先构建映射
if not self._template_map and self.templates:
logger.debug("模板映射为空,正在构建映射...")
self._build_template_map()
logger.debug(f"映射构建完成,包含 {len(self._template_map)} 个模板")
# 直接在映射中查找
exists = full_template_name in self._template_map
logger.debug(f"模板 '{full_template_name}' 存在性检查结果: {exists}")
if not exists and self._template_map:
logger.debug(f"当前映射中的模板: {list(self._template_map.keys())}")
return exists
def _build_template_map(self):
"""
内部方法构建完整模板名称到SQL内容的映射
用于支持快速查找功能
"""
logger.debug("开始构建模板映射...")
self._template_map.clear()
# 遍历所有命名空间和模板,构建映射
for namespace, templates in self.templates.items():
logger.debug(f"处理命名空间: {namespace},包含 {len(templates)} 个模板")
for template_name, sql_content in templates.items():
full_name = f"{namespace}.{template_name}"
self._template_map[full_name] = sql_content
logger.debug(f"添加映射: {full_name} -> {len(sql_content)} 字符的SQL")
logger.debug(f"模板映射构建完成,总共 {len(self._template_map)} 个模板")
logger.debug(f"映射中的模板列表: {list(self._template_map.keys())}")
async def get_sql(self, sql_name):
"""
获取指定名称的SQL语句支持命名空间格式namespace.sql_name异步版本
Args:
sql_name: SQL模板名称可以包含命名空间前缀
Returns:
str: SQL语句
Raises:
KeyError: 当找不到指定名称的SQL模板时
"""
try:
logger.debug(f"尝试获取SQL模板: {sql_name}")
logger.debug(f"当前模板映射状态: {len(self._template_map) if self._template_map else 0} 个模板")
# 快速检查模板是否存在
if self._has_template(sql_name):
logger.debug(f"在映射中找到模板: {sql_name}")
# 如果在映射中存在,直接返回
sql_content = self._template_map[sql_name]
logger.debug(f"模板内容长度: {len(sql_content)} 字符")
return sql_content
# 如果内存中没有检查Redis中是否有加载的模板数据
templates_loaded = await self._check_templates_loaded_from_redis()
if templates_loaded:
await self._load_from_redis()
# 再次尝试从内存中获取
if self._has_template(sql_name):
sql_content = self._template_map[sql_name]
logger.debug(f"从Redis加载后找到模板: {sql_name}")
return sql_content
logger.debug(f"映射中未找到模板: {sql_name},尝试从原始模板中查找")
# 尝试解析命名空间和SQL名称
if '.' in sql_name:
namespace, name = sql_name.split('.', 1)
else:
# 默认命名空间
namespace = "default"
name = sql_name
logger.debug(f"解析结果 - 命名空间: {namespace}, 模板名: {name}")
logger.debug(f"可用命名空间: {list(self.templates.keys())}")
# 检查命名空间是否存在
if namespace and namespace not in self.templates:
logger.error(f"命名空间 '{namespace}' 不存在,可用命名空间: {list(self.templates.keys())}")
raise KeyError(f"找不到命名空间: {namespace}")
# 检查SQL名称是否存在
available_templates = list(self.templates[namespace].keys())
logger.debug(f"命名空间 '{namespace}' 下的可用模板: {available_templates}")
if name and name not in self.templates[namespace]:
logger.error(f"SQL模板 '{name}' 在命名空间 '{namespace}' 下不存在,可用模板: {available_templates}")
raise KeyError(f"在命名空间 '{namespace}' 下找不到SQL模板: {name}")
sql = self.templates[namespace][name]
logger.debug(f"成功获取SQL模板内容长度: {len(sql)} 字符")
return sql
except KeyError:
raise
except Exception as e:
logger.error(f"获取SQL模板时出错: {str(e)}", exc_info=True)
raise
async def reload_template(self, template_path):
"""重新加载指定的SQL模板文件异步版本
Args:
template_path: SQL模板文件的路径可以是绝对路径或相对于sql_dir的相对路径
Returns:
bool: 重新加载是否成功
"""
try:
# 确定模板文件的完整路径
if self.sql_dir and not os.path.isabs(template_path):
file_path = os.path.join(self.sql_dir, template_path)
else:
file_path = template_path
# 获取文件的绝对路径
file_abs_path = os.path.abspath(file_path)
# 如果文件已加载,则从已加载列表中移除
if file_abs_path in self._loaded_files:
self._loaded_files.remove(file_abs_path)
# 清除映射表,下次访问时会重新构建
self._template_map.clear()
# 清除Redis中的模板数据
await self._clear_redis_data()
# 重新加载模板,立即构建映射
success = await self.load_template(file_abs_path, build_map=True)
# 如果加载成功更新Redis
if success:
await self._save_to_redis()
return success
except Exception as e:
logger.error(f"重新加载SQL模板失败: {str(e)}", exc_info=True)
return False
async def _check_templates_loaded_from_redis(self) -> bool:
"""
检查Redis中是否已经加载了SQL模板
Returns:
bool: 如果Redis中已加载返回True否则返回False
"""
try:
return await redisKit.exists(REDIS_SQL_TEMPLATES_LOADED_KEY)
except Exception as e:
logger.error(f"检查Redis中SQL模板状态失败: {str(e)}")
return False
async def _load_from_redis(self):
"""
从Redis加载SQL模板数据
"""
try:
# 加载templates
templates_str = await redisKit.get_data(REDIS_SQL_TEMPLATES_KEY)
if templates_str:
self.templates = json.loads(templates_str)
# 加载loaded_files
loaded_files_str = await redisKit.get_data(REDIS_SQL_LOADED_FILES_KEY)
if loaded_files_str:
self._loaded_files = set(json.loads(loaded_files_str))
# 加载template_map
template_map_str = await redisKit.get_data(REDIS_SQL_TEMPLATE_MAP_KEY)
if template_map_str:
self._template_map = json.loads(template_map_str)
2026-01-21 08:41:47 +08:00
# 无论从Redis加载了什么都尝试重新构建一次映射以确保一致性
if self.templates and not self._template_map:
self._build_template_map()
2026-01-12 07:49:18 +08:00
except Exception as e:
logger.error(f"从Redis加载SQL模板数据失败: {str(e)}")
# 出错时不抛出异常,使用内存中的默认值
async def _save_to_redis(self):
"""
将SQL模板数据保存到Redis
"""
try:
# 保存templates
await redisKit.set_data(REDIS_SQL_TEMPLATES_KEY, json.dumps(self.templates))
# 保存loaded_files
await redisKit.set_data(REDIS_SQL_LOADED_FILES_KEY, json.dumps(list(self._loaded_files)))
# 保存template_map
await redisKit.set_data(REDIS_SQL_TEMPLATE_MAP_KEY, json.dumps(self._template_map))
# 设置标记表示已加载
await redisKit.set_data(REDIS_SQL_TEMPLATES_LOADED_KEY, "1")
except Exception as e:
logger.error(f"将SQL模板数据保存到Redis失败: {str(e)}")
# 出错时不抛出异常,继续执行
async def _clear_redis_data(self):
"""
清除Redis中的模板数据
"""
try:
await redisKit.delete(REDIS_SQL_TEMPLATES_KEY)
await redisKit.delete(REDIS_SQL_LOADED_FILES_KEY)
await redisKit.delete(REDIS_SQL_TEMPLATE_MAP_KEY)
await redisKit.delete(REDIS_SQL_TEMPLATES_LOADED_KEY)
except Exception as e:
logger.error(f"清除Redis中SQL模板数据失败: {str(e)}")
# 出错时不抛出异常,继续执行