Files
aiData/DbKit/SqlTemplateLoader.py
HuangHai e51dc18d06 'commit'
2026-01-21 08:41:47 +08:00

533 lines
23 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import logging
import os
import re
import asyncio
import aiofiles
import json
# 导入RedisKit
from Util.RedisKit import redisKit
# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Redis键名常量
REDIS_SQL_TEMPLATES_KEY = "sql_templates:all_templates"
REDIS_SQL_LOADED_FILES_KEY = "sql_templates:loaded_files"
REDIS_SQL_TEMPLATE_MAP_KEY = "sql_templates:template_map"
REDIS_SQL_TEMPLATES_LOADED_KEY = "sql_templates:templates_loaded"
class SqlTemplateLoader:
"""SQL模板加载器负责加载和解析SQL模板文件异步版本"""
def __init__(self, sql_dir=None, auto_load=True):
"""初始化SQL模板加载器
Args:
sql_dir: SQL模板文件所在目录
auto_load: 是否自动加载目录下所有SQL模板文件
"""
self.sql_dir = sql_dir
self.templates = {} # 使用嵌套字典存储命名空间下的SQL模板
self._loaded_files = set() # 用于记录已加载的文件,避免重复加载
self._template_map = {} # 使用平面字典存储完整的模板名称映射,用于快速查找
# 注意异步版本不会在__init__中自动加载需要单独调用异步方法
async def load_all_templates(self):
"""自动加载SQL目录下所有的SQL模板文件异步版本
Returns:
int: 成功加载的模板文件数量
"""
logger.info(f"开始加载SQL模板目录: {self.sql_dir}")
if not os.path.exists(self.sql_dir):
logger.warning(f"SQL模板目录不存在: {self.sql_dir}")
return 0
# 检查是否已经从Redis加载过
templates_loaded = await self._check_templates_loaded_from_redis()
if templates_loaded:
await self._load_from_redis()
logger.info(f"从Redis加载SQL模板完成{len(self._loaded_files)}个文件,{len(self._template_map)}个模板")
return len(self._loaded_files)
loaded_count = 0
sql_files = []
try:
# 遍历目录下所有.sql文件
for filename in os.listdir(self.sql_dir):
if filename.lower().endswith('.sql'):
file_path = os.path.join(self.sql_dir, filename)
file_abs_path = os.path.abspath(file_path)
sql_files.append(file_abs_path)
# 使用异步方式加载所有模板,不立即构建映射以提高性能
for file_abs_path in sql_files:
# 跳过已经加载过的文件
if file_abs_path in self._loaded_files:
logger.debug(f"SQL文件已加载跳过: {file_abs_path}")
continue
try:
filename = os.path.basename(file_abs_path)
if await self.load_template(file_abs_path, build_map=False):
loaded_count += 1
except Exception as e:
logger.error(f"加载SQL模板文件 {filename} 失败: {str(e)}")
# 所有模板加载完成后,统一构建映射
if loaded_count > 0:
self._build_template_map()
logger.info(f"成功加载 {loaded_count} 个SQL模板文件已构建模板映射")
# 加载完成后保存到Redis
await self._save_to_redis()
return loaded_count
except Exception as e:
logger.error(f"遍历SQL模板目录时出错: {str(e)}")
return 0
async def load_template(self, template_path, build_map=True):
"""加载指定的SQL模板文件异步版本
Args:
template_path: SQL模板文件的路径可以是绝对路径或相对于sql_dir的相对路径
build_map: 是否立即构建模板映射默认为True。在批量加载时可设为False以提高性能
Returns:
bool: 加载是否成功
Raises:
FileNotFoundError: 当模板文件不存在时
Exception: 当加载过程发生其他错误时
"""
# 确定模板文件的完整路径
if self.sql_dir and not os.path.isabs(template_path):
file_path = os.path.join(self.sql_dir, template_path)
else:
file_path = template_path
# 检查文件是否已加载,避免重复加载
file_abs_path = os.path.abspath(file_path)
if file_abs_path in self._loaded_files:
logger.debug(f"SQL模板文件已加载跳过: {file_abs_path}")
return True
try:
# 异步读取文件内容
async with aiofiles.open(file_path, 'r', encoding='utf-8') as f:
content = await f.read()
# 使用栈结构解析整个内容
templates = self._parse_with_stack(content)
# 处理命名空间
if 'namespace' in templates and templates['namespace']:
namespace = templates['namespace']
sql_templates = {k: v for k, v in templates.items() if k != 'namespace'}
else:
# 没有命名空间的情况
namespace = 'default'
sql_templates = templates
# 将解析出的SQL模板添加到templates字典中
if namespace not in self.templates:
self.templates[namespace] = {}
self.templates[namespace].update(sql_templates)
# 重新构建模板映射,确保新加载的模板可以被快速查找
if build_map:
self._build_template_map()
# 记录已加载的文件
self._loaded_files.add(file_abs_path)
return True
except FileNotFoundError:
logger.error(f"SQL模板文件不存在: {file_path}")
raise FileNotFoundError(f"SQL模板文件不存在: {file_path}")
except Exception as e:
logger.error(f"加载SQL模板失败: {str(e)}", exc_info=True)
raise
def _parse_with_stack(self, content):
"""使用栈结构解析SQL模板支持嵌套的标签匹配
支持--单行注释和/* */多行注释
Args:
content: 要解析的SQL模板内容
Returns:
解析后的SQL模板字典键为模板名称值为模板内容
"""
templates = {}
stack = [] # 用于存储标签信息的栈
current_content = [] # 当前正在收集的内容
in_multiline_comment = False # 多行注释标记
try:
# 按行处理内容
lines = content.split('\n')
for line_num, line in enumerate(lines):
# 处理多行注释
if in_multiline_comment:
# 检查是否是多行注释结束
if '*/' in line:
in_multiline_comment = False
# 截取注释结束后的内容
line = line.split('*/', 1)[1]
else:
# 完全在多行注释内,跳过此行
continue
# 处理单行注释和检查多行注释开始
if not line.strip().startswith('#') and stack and stack[-1][0] == 'sql':
# 检查是否有--注释
if '--' in line:
# 保留--前面的内容
line = line.split('--', 1)[0]
# 检查是否有/*多行注释开始
if '/*' in line:
parts = line.split('/*', 1)
# 如果/*和*/在同一行
if '*/' in parts[1]:
# 只保留/*前面和*/后面的内容
comment_part = parts[1].split('*/', 1)
line = parts[0] + comment_part[1]
else:
# 多行注释开始,只保留/*前面的内容
line = parts[0]
in_multiline_comment = True
stripped_line = line.strip()
# 检查是否是标签行
if stripped_line.startswith('#'):
# 检查是否是命名空间标签
if stripped_line.startswith('#namespace(') and (
stripped_line.find('"') > 0 or stripped_line.find("'") > 0):
# 提取命名空间名称
quote_pattern = re.search(r'#namespace\((["\'])([^"\']+)\1\)', stripped_line)
if quote_pattern:
namespace = quote_pattern.group(2)
templates['namespace'] = namespace
# 将命名空间标签入栈
stack.append(('namespace', namespace, line_num))
# 检查是否是SQL模板开始标签
elif stripped_line.startswith('#sql(') and (
stripped_line.find('"') > 0 or stripped_line.find("'") > 0):
# 提取SQL模板名称
quote_pattern = re.search(r'#sql\((["\'])([^"\']+)\1\)', stripped_line)
if quote_pattern:
sql_name = quote_pattern.group(2)
# 将SQL模板开始标签入栈
stack.append(('sql', sql_name, line_num))
current_content = [] # 重置内容收集器
# 检查是否是#if条件开始标签
elif stripped_line.startswith('#if(') and ')' in stripped_line:
# 提取条件表达式
condition_pattern = re.search(r'#if\(([^)]+)\)', stripped_line)
if condition_pattern:
condition = condition_pattern.group(1)
stack.append(('if', condition, line_num))
# 保留#if行到内容中 - 只要有SQL标签在栈中就保留
if any(tag[0] == 'sql' for tag in stack):
current_content.append(line.rstrip())
# 检查是否是结束标签
elif stripped_line == '#end' and stack:
# 获取栈顶元素
tag_type, tag_name, start_line = stack[-1]
# 只有当标签不是SQL结束标签时才保留#end行到内容中
if any(tag[0] == 'sql' for tag in stack) and tag_type != 'sql':
current_content.append(line.rstrip())
# 弹出当前标签
stack.pop()
# 如果是SQL模板结束标签保存内容
if tag_type == 'sql':
sql_content = '\n'.join(current_content).strip()
if sql_content:
templates[tag_name] = sql_content
# 重置内容收集器,准备下一个模板
current_content = []
# 收集内容不是标签行且当前有活动的SQL模板
else:
# 只要栈中有SQL标签就收集内容不管是在条件块内还是外
if any(tag[0] == 'sql' for tag in stack) and line.strip():
# 保留原始缩进
current_content.append(line.rstrip())
# 检查是否所有标签都正确闭合
if stack:
unclosed_tags = [f"{tag_type}('{tag_name}')" for tag_type, tag_name, _ in stack]
logger.warning(f"存在未闭合的标签: {', '.join(unclosed_tags)}")
else:
logger.debug("所有标签都已正确闭合")
except Exception as e:
logger.error(f"使用栈结构解析SQL模板时出错: {str(e)}", exc_info=True)
return templates
async def get_all_template_names(self):
"""
获取所有已加载的SQL模板的完整名称列表异步版本
Returns:
list: 包含所有模板名称的列表,格式为 "命名空间.模板名称"
"""
template_names = []
# 遍历所有命名空间
for namespace, templates in self.templates.items():
# 遍历当前命名空间下的所有模板
for template_name in templates.keys():
# 构建完整的模板名称(命名空间.模板名称)
full_name = f"{namespace}.{template_name}"
template_names.append(full_name)
return template_names
def _has_template(self, full_template_name):
"""
内部方法:快速检查指定的完整模板名称是否存在
使用MAP机制实现O(1)时间复杂度的查找
Args:
full_template_name: 完整的模板名称,格式为 "命名空间.模板名称"
Returns:
bool: 如果模板存在返回True否则返回False
"""
logger.debug(f"检查模板是否存在: {full_template_name}")
# 如果_map为空先构建映射
if not self._template_map and self.templates:
logger.debug("模板映射为空,正在构建映射...")
self._build_template_map()
logger.debug(f"映射构建完成,包含 {len(self._template_map)} 个模板")
# 直接在映射中查找
exists = full_template_name in self._template_map
logger.debug(f"模板 '{full_template_name}' 存在性检查结果: {exists}")
if not exists and self._template_map:
logger.debug(f"当前映射中的模板: {list(self._template_map.keys())}")
return exists
def _build_template_map(self):
"""
内部方法构建完整模板名称到SQL内容的映射
用于支持快速查找功能
"""
logger.debug("开始构建模板映射...")
self._template_map.clear()
# 遍历所有命名空间和模板,构建映射
for namespace, templates in self.templates.items():
logger.debug(f"处理命名空间: {namespace},包含 {len(templates)} 个模板")
for template_name, sql_content in templates.items():
full_name = f"{namespace}.{template_name}"
self._template_map[full_name] = sql_content
logger.debug(f"添加映射: {full_name} -> {len(sql_content)} 字符的SQL")
logger.debug(f"模板映射构建完成,总共 {len(self._template_map)} 个模板")
logger.debug(f"映射中的模板列表: {list(self._template_map.keys())}")
async def get_sql(self, sql_name):
"""
获取指定名称的SQL语句支持命名空间格式namespace.sql_name异步版本
Args:
sql_name: SQL模板名称可以包含命名空间前缀
Returns:
str: SQL语句
Raises:
KeyError: 当找不到指定名称的SQL模板时
"""
try:
logger.debug(f"尝试获取SQL模板: {sql_name}")
logger.debug(f"当前模板映射状态: {len(self._template_map) if self._template_map else 0} 个模板")
# 快速检查模板是否存在
if self._has_template(sql_name):
logger.debug(f"在映射中找到模板: {sql_name}")
# 如果在映射中存在,直接返回
sql_content = self._template_map[sql_name]
logger.debug(f"模板内容长度: {len(sql_content)} 字符")
return sql_content
# 如果内存中没有检查Redis中是否有加载的模板数据
templates_loaded = await self._check_templates_loaded_from_redis()
if templates_loaded:
await self._load_from_redis()
# 再次尝试从内存中获取
if self._has_template(sql_name):
sql_content = self._template_map[sql_name]
logger.debug(f"从Redis加载后找到模板: {sql_name}")
return sql_content
logger.debug(f"映射中未找到模板: {sql_name},尝试从原始模板中查找")
# 尝试解析命名空间和SQL名称
if '.' in sql_name:
namespace, name = sql_name.split('.', 1)
else:
# 默认命名空间
namespace = "default"
name = sql_name
logger.debug(f"解析结果 - 命名空间: {namespace}, 模板名: {name}")
logger.debug(f"可用命名空间: {list(self.templates.keys())}")
# 检查命名空间是否存在
if namespace and namespace not in self.templates:
logger.error(f"命名空间 '{namespace}' 不存在,可用命名空间: {list(self.templates.keys())}")
raise KeyError(f"找不到命名空间: {namespace}")
# 检查SQL名称是否存在
available_templates = list(self.templates[namespace].keys())
logger.debug(f"命名空间 '{namespace}' 下的可用模板: {available_templates}")
if name and name not in self.templates[namespace]:
logger.error(f"SQL模板 '{name}' 在命名空间 '{namespace}' 下不存在,可用模板: {available_templates}")
raise KeyError(f"在命名空间 '{namespace}' 下找不到SQL模板: {name}")
sql = self.templates[namespace][name]
logger.debug(f"成功获取SQL模板内容长度: {len(sql)} 字符")
return sql
except KeyError:
raise
except Exception as e:
logger.error(f"获取SQL模板时出错: {str(e)}", exc_info=True)
raise
async def reload_template(self, template_path):
"""重新加载指定的SQL模板文件异步版本
Args:
template_path: SQL模板文件的路径可以是绝对路径或相对于sql_dir的相对路径
Returns:
bool: 重新加载是否成功
"""
try:
# 确定模板文件的完整路径
if self.sql_dir and not os.path.isabs(template_path):
file_path = os.path.join(self.sql_dir, template_path)
else:
file_path = template_path
# 获取文件的绝对路径
file_abs_path = os.path.abspath(file_path)
# 如果文件已加载,则从已加载列表中移除
if file_abs_path in self._loaded_files:
self._loaded_files.remove(file_abs_path)
# 清除映射表,下次访问时会重新构建
self._template_map.clear()
# 清除Redis中的模板数据
await self._clear_redis_data()
# 重新加载模板,立即构建映射
success = await self.load_template(file_abs_path, build_map=True)
# 如果加载成功更新Redis
if success:
await self._save_to_redis()
return success
except Exception as e:
logger.error(f"重新加载SQL模板失败: {str(e)}", exc_info=True)
return False
async def _check_templates_loaded_from_redis(self) -> bool:
"""
检查Redis中是否已经加载了SQL模板
Returns:
bool: 如果Redis中已加载返回True否则返回False
"""
try:
return await redisKit.exists(REDIS_SQL_TEMPLATES_LOADED_KEY)
except Exception as e:
logger.error(f"检查Redis中SQL模板状态失败: {str(e)}")
return False
async def _load_from_redis(self):
"""
从Redis加载SQL模板数据
"""
try:
# 加载templates
templates_str = await redisKit.get_data(REDIS_SQL_TEMPLATES_KEY)
if templates_str:
self.templates = json.loads(templates_str)
# 加载loaded_files
loaded_files_str = await redisKit.get_data(REDIS_SQL_LOADED_FILES_KEY)
if loaded_files_str:
self._loaded_files = set(json.loads(loaded_files_str))
# 加载template_map
template_map_str = await redisKit.get_data(REDIS_SQL_TEMPLATE_MAP_KEY)
if template_map_str:
self._template_map = json.loads(template_map_str)
# 无论从Redis加载了什么都尝试重新构建一次映射以确保一致性
if self.templates and not self._template_map:
self._build_template_map()
except Exception as e:
logger.error(f"从Redis加载SQL模板数据失败: {str(e)}")
# 出错时不抛出异常,使用内存中的默认值
async def _save_to_redis(self):
"""
将SQL模板数据保存到Redis
"""
try:
# 保存templates
await redisKit.set_data(REDIS_SQL_TEMPLATES_KEY, json.dumps(self.templates))
# 保存loaded_files
await redisKit.set_data(REDIS_SQL_LOADED_FILES_KEY, json.dumps(list(self._loaded_files)))
# 保存template_map
await redisKit.set_data(REDIS_SQL_TEMPLATE_MAP_KEY, json.dumps(self._template_map))
# 设置标记表示已加载
await redisKit.set_data(REDIS_SQL_TEMPLATES_LOADED_KEY, "1")
except Exception as e:
logger.error(f"将SQL模板数据保存到Redis失败: {str(e)}")
# 出错时不抛出异常,继续执行
async def _clear_redis_data(self):
"""
清除Redis中的模板数据
"""
try:
await redisKit.delete(REDIS_SQL_TEMPLATES_KEY)
await redisKit.delete(REDIS_SQL_LOADED_FILES_KEY)
await redisKit.delete(REDIS_SQL_TEMPLATE_MAP_KEY)
await redisKit.delete(REDIS_SQL_TEMPLATES_LOADED_KEY)
except Exception as e:
logger.error(f"清除Redis中SQL模板数据失败: {str(e)}")
# 出错时不抛出异常,继续执行