186 lines
7.8 KiB
Python
186 lines
7.8 KiB
Python
import asyncio
|
|
import os
|
|
import sys
|
|
import logging
|
|
from typing import List, Dict
|
|
|
|
# Make sure the project root (the parent of this script's directory) is on
# sys.path, and ahead of any external module that shares a name with a
# project package.
_script_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(_script_dir)
if project_root not in sys.path:
    sys.path.insert(0, project_root)
|
|
|
|
from Util import Win32Patch
|
|
from DbKit.Db import Db
|
|
from Util.RedisKit import RedisKit
|
|
from sqlalchemy.sql import text
|
|
|
|
# Resolve DB_URL. The regular package import is tried first; if the module
# cannot be found, Config/Config.py is loaded directly from its file path
# under the project root and DB_URL is read from the loaded module.
try:
    from Config.Config import DB_URL
except ModuleNotFoundError:
    import importlib.util

    _root_dir_cfg = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    _config_path_cfg = os.path.join(_root_dir_cfg, "Config", "Config.py")
    _spec_cfg = importlib.util.spec_from_file_location(
        "project_config_fallback",
        _config_path_cfg,
    )
    _cfg_mod = importlib.util.module_from_spec(_spec_cfg)
    # A spec without a loader cannot be executed.
    assert _spec_cfg.loader is not None
    _spec_cfg.loader.exec_module(_cfg_mod)
    DB_URL = _cfg_mod.DB_URL
|
|
|
|
# Logging setup: INFO level on the root logger, plus a named logger for this script.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger("T6_ClearHistory")
|
|
|
|
# Vendor table. The key is the number typed at the vendor prompt; "name" is the
# operator value used in the database queries and "redis_pattern" is the key
# pattern passed to the Redis cleanup.
VENDORS = {
    vendor_id: {"name": name, "redis_pattern": redis_pattern}
    for vendor_id, name, redis_pattern in (
        ("1", "特来电", "crawled:tld:*"),
        ("2", "新电途", "crawled:xdt:*"),
        ("3", "艾特吉易充", "crawled:aite:*"),
        ("4", "驿来特", "crawled:ylt:*"),
    )
}
|
|
|
|
# Cleanup-mode selectors; the values are the answers typed at the mode prompt.
# Full cleanup.
MODE_ALL: str = "1"
# History-only cleanup (rows with is_current = 0).
MODE_HISTORY: str = "2"
|
|
|
|
async def _clear_vendor_database(db: Db, operator: str, mode: str) -> None:
    """Delete one vendor's rows from the three SCD tables in a single transaction.

    The vendor's station hashes are read from t_station_profile_scd, then the
    status and price-schedule tables are purged in batches keyed by those
    hashes, and finally the profile rows themselves are deleted. In
    MODE_HISTORY every DELETE is restricted to rows with ``is_current = 0``.

    Args:
        db: Database wrapper exposing ``AsyncSessionLocal``.
        operator: Vendor name stored in the ``operator`` column.
        mode: ``MODE_ALL`` or ``MODE_HISTORY``.

    Raises:
        Exception: Whatever the session or driver raises; the caller logs it.
    """
    # Function-scope import: SQLAlchemy is already a dependency of this file,
    # and this keeps the block self-contained.
    from sqlalchemy.sql import bindparam

    # Extra predicate appended to every DELETE in history-only mode.
    where_clause = "AND is_current = 0" if mode == MODE_HISTORY else ""
    # Batch size bounds the length of each IN (...) list.
    batch_size = 500

    async with db.AsyncSessionLocal() as session:
        # One transaction block covers the lookup and all deletes.
        async with session.begin():
            logger.info(f"[{operator}] 正在获取场站列表...")
            sql_get_hashes = "SELECT station_hash FROM t_station_profile_scd WHERE operator = :operator"
            result_hashes = await session.execute(text(sql_get_hashes), {"operator": operator})
            station_hashes = [row[0] for row in result_hashes.fetchall()]

            if not station_hashes:
                logger.info(f"[{operator}] 数据库中未找到该供应商的场站记录。")
                return

            logger.info(f"[{operator}] 找到 {len(station_hashes)} 个场站。")

            # Build the batch statements once. An expanding bind parameter lets
            # SQLAlchemy render "IN (?, ?, ...)" itself instead of depending on
            # the driver to escape a Python list.
            stmt_status = text(
                f"DELETE FROM t_station_status_scd WHERE station_hash IN :hashes {where_clause}"
            ).bindparams(bindparam("hashes", expanding=True))
            stmt_price = text(
                f"DELETE FROM t_station_price_schedule_scd WHERE station_hash IN :hashes {where_clause}"
            ).bindparams(bindparam("hashes", expanding=True))

            total_status_deleted = 0
            total_price_deleted = 0
            total_batches = (len(station_hashes) - 1) // batch_size + 1

            for batch_no, start in enumerate(range(0, len(station_hashes), batch_size), start=1):
                batch = station_hashes[start:start + batch_size]
                # Log before the work so the "processing" message is accurate.
                logger.info(f"[{operator}] 正在处理分批 {batch_no}/{total_batches}...")

                # 1. t_station_status_scd
                res_status = await session.execute(stmt_status, {"hashes": batch})
                total_status_deleted += res_status.rowcount

                # 2. t_station_price_schedule_scd
                res_price = await session.execute(stmt_price, {"hashes": batch})
                total_price_deleted += res_price.rowcount

            logger.info(f"[{operator}] 表 t_station_status_scd 清理完成,共删除 {total_status_deleted} 条记录。")
            logger.info(f"[{operator}] 表 t_station_price_schedule_scd 清理完成,共删除 {total_price_deleted} 条记录。")

            # 3. The profile table goes last: it is the source of the hashes
            # used to find rows in the other two tables.
            sql_profile = f"DELETE FROM t_station_profile_scd WHERE operator = :operator {where_clause}"
            res_profile = await session.execute(text(sql_profile), {"operator": operator})
            logger.info(f"[{operator}] 表 t_station_profile_scd 清理完成,共删除 {res_profile.rowcount} 条记录。")


async def _clear_vendor_redis(redis_kit: RedisKit, operator: str, pattern: str) -> None:
    """Delete every Redis key matching ``pattern``, 1000 keys per delete call.

    Args:
        redis_kit: Redis wrapper exposing awaitable ``keys`` and ``delete``.
        operator: Vendor name, used only in log messages.
        pattern: Key pattern passed to ``redis_kit.keys``.
    """
    keys = await redis_kit.keys(pattern)
    if not keys:
        logger.info(f"[{operator}] 未匹配到模式为 '{pattern}' 的 Redis 键。")
        return

    logger.info(f"[{operator}] 正在清理 Redis 缓存 (模式: {pattern})...")
    for start in range(0, len(keys), 1000):
        await redis_kit.delete(*keys[start:start + 1000])
    logger.info(f"[{operator}] Redis 缓存清理完成,共删除 {len(keys)} 个键。")


async def clear_vendor_data(db: Db, redis_kit: RedisKit, vendor_info: Dict, mode: str):
    """Clean up one vendor's database rows and, in full mode, its Redis keys.

    Database and Redis cleanup are attempted independently: a failure in
    either step is logged with its traceback and is not re-raised, so the
    caller can continue with the next vendor.

    Args:
        db: Database wrapper exposing ``AsyncSessionLocal``.
        redis_kit: Redis wrapper exposing awaitable ``keys`` and ``delete``.
        vendor_info: Mapping with ``"name"`` (operator name) and
            ``"redis_pattern"`` (Redis key pattern).
        mode: ``MODE_ALL`` deletes all rows and clears Redis; any other value
            is treated as history-only and skips Redis.
    """
    operator = vendor_info["name"]
    pattern = vendor_info["redis_pattern"]

    mode_name = "全量清理" if mode == MODE_ALL else "仅清理历史记录 (is_current=0)"
    logger.info(f"--- 正在清理供应商: {operator} ({mode_name}) ---")

    # 1. Database cleanup.
    try:
        await _clear_vendor_database(db, operator, mode)
        logger.info(f"[{operator}] 数据库记录清理完成。")
    except Exception as e:
        # logger.exception records the traceback as well as the message.
        logger.exception(f"[{operator}] 数据库清理失败: {e}")

    # 2. Redis cleanup, only in full-cleanup mode.
    if mode != MODE_ALL:
        logger.info(f"[{operator}] 模式为仅清理历史,跳过 Redis 缓存清理。")
        return

    try:
        await _clear_vendor_redis(redis_kit, operator, pattern)
    except Exception as e:
        logger.exception(f"[{operator}] Redis 清理失败: {e}")
|
|
|
|
async def main():
    """Run the interactive cleanup tool.

    Prompts for a vendor ("0" selects all, "q" quits) and a cleanup mode
    (empty input defaults to MODE_HISTORY), then runs clear_vendor_data for
    each selected vendor. Invalid input prints a message and returns without
    touching any resource. The database is closed in a ``finally`` block once
    it has been initialised.
    """
    print("\n" + "="*40)
    print(" 供应商数据清理工具 (Doris 优化版)")
    print("="*40)

    # Step 1: choose the vendor.
    print("【第一步】选择供应商:")
    for key, info in VENDORS.items():
        print(f" {key}. {info['name']}")
    print(" 0. 全部供应商")
    print(" q. 退出")
    print("-" * 40)

    vendor_choice = input("请输入供应商编号: ").strip().lower()
    if vendor_choice == 'q':
        return

    if vendor_choice == '0':
        selected_vendors = list(VENDORS.values())
    elif vendor_choice in VENDORS:
        selected_vendors = [VENDORS[vendor_choice]]
    else:
        print("❌ 无效的选择,请重新运行脚本。")
        return

    # Step 2: choose the cleanup mode.
    print("\n【第二步】选择清理模式:")
    print(f" {MODE_ALL}. 全量清理 (删除所有数据库记录 + 清除 Redis 缓存)")
    print(f" {MODE_HISTORY}. 仅清理历史 (仅删除 is_current=0 的历史记录)")
    print("-" * 40)

    # Empty input falls back to the history-only mode.
    mode_choice = input(f"请输入模式编号 [默认 {MODE_HISTORY}]: ").strip() or MODE_HISTORY

    if mode_choice not in [MODE_ALL, MODE_HISTORY]:
        print("❌ 无效的模式选择。")
        return

    mode_text = "全量清理" if mode_choice == MODE_ALL else "仅清理历史"
    print(f"\n🚀 即将对 {len(selected_vendors)} 个供应商执行 [{mode_text}] 操作...")

    # Initialise resources.
    db = Db(db_url=DB_URL)
    await db.init_db()

    try:
        # RedisKit is created inside the try so that a failure here still
        # reaches the finally block and closes the initialised database.
        redis_kit = RedisKit()
        for vendor in selected_vendors:
            await clear_vendor_data(db, redis_kit, vendor, mode_choice)
    finally:
        logger.info("正在关闭数据库连接...")
        await db.close()
|
|
|
|
if __name__ == "__main__":
    # Apply the project's Win32 patch before asyncio.run starts the event loop.
    Win32Patch.patch()
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # Ctrl+C aborts the interactive tool quietly.
        pass
    except Exception as e:
        # Top-level boundary: logger.exception keeps the traceback, which the
        # message alone does not convey.
        logger.exception(f"程序运行异常: {e}")
|