270 lines
9.0 KiB
Python
270 lines
9.0 KiB
Python
# coding=utf-8
|
||
import os
|
||
import logging
|
||
|
||
import hashlib
|
||
import numpy as np
|
||
import cv2
|
||
|
||
# Module-level logger named after this module; the error paths of the helpers
# below report through it.
logger = logging.getLogger(name=__name__)
|
||
|
||
def clean_station_name(name):
    """
    Clean a station name by stripping parenthesised remarks.

    Every ASCII ``(...)`` segment and every full-width ``(...)`` segment
    (the parentheses used in Chinese text, e.g. a distance remark) is removed,
    then surrounding whitespace is trimmed.

    :param name: raw station name; may be None or empty
    :return: the cleaned name, or "" when ``name`` is falsy
    """
    if not name:
        return ""
    import re

    # Remove common bracketed remarks: ASCII parentheses first ...
    name = re.sub(r'\(.*?\)', '', name)
    # ... then full-width parentheses U+FF08 / U+FF09. Written as escapes so
    # the pattern cannot silently degrade into the ASCII group "(.*?)", which
    # matches only the empty string and would make this line a no-op.
    name = re.sub(r'\uff08.*?\uff09', '', name)
    return name.strip()
|
||
|
||
def take_screenshot(d, filename, save_dir=None):
    """
    Capture the screen through ``d`` and save it into ``save_dir``.

    :param d: device/driver object exposing ``screenshot(path)``
        (assumed to be the automation device handle — confirm against callers)
    :param filename: target file name; ``.jpg`` is appended unless it already
        ends with ``.jpg``, ``.jpeg`` or ``.png`` (checked case-insensitively)
    :param save_dir: output directory; when empty, ``TEMP_IMAGE_DIR`` from
        ``Config.Config`` is used. The directory is created if missing.
    :return: full path of the saved screenshot
    """
    if not save_dir:
        from Config.Config import TEMP_IMAGE_DIR
        save_dir = TEMP_IMAGE_DIR

    # exist_ok avoids the race between a separate exists() check and makedirs().
    os.makedirs(save_dir, exist_ok=True)

    # Ensure the file has an image extension. Compare case-insensitively so a
    # name like "shot.PNG" is not turned into "shot.PNG.jpg".
    if not filename.lower().endswith((".jpg", ".jpeg", ".png")):
        filename = f"{filename}.jpg"

    full_path = os.path.join(save_dir, filename)
    d.screenshot(full_path)
    return full_path
|
||
|
||
def get_file_md5(file_path):
    """
    Return the hex MD5 digest of the file at ``file_path``.

    The file is streamed in 4096-byte chunks, so it is never loaded into
    memory at once.

    :return: the hex digest, or None when the path does not exist
    """
    if os.path.exists(file_path):
        digest = hashlib.md5()
        with open(file_path, "rb") as handle:
            while True:
                block = handle.read(4096)
                if not block:
                    break
                digest.update(block)
        return digest.hexdigest()
    return None
|
||
|
||
def get_image_content_md5(file_path, top_ratio=0.1, bottom_ratio=0.1):
    """
    Return the MD5 of an image's central band, excluding the status bar and navigation bar.

    The top ``top_ratio`` and bottom ``bottom_ratio`` fractions of the rows are
    cut away. The digest is computed over the raw pixel bytes of the remaining
    band, so no lossy re-encoding is involved.

    :return: hex digest, or None when the image cannot be read
    """
    image = read_image(file_path)
    if image is None:
        return None

    height, _ = image.shape[:2]
    first_row = int(height * top_ratio)
    end_row = int(height * (1 - bottom_ratio))

    # Keep only the middle rows (all columns).
    middle_band = image[first_row:end_row, :]
    raw_pixels = middle_band.tobytes()
    return hashlib.md5(raw_pixels).hexdigest()
|
||
|
||
def read_image(path):
    """
    Read an image file into an array, supporting non-ASCII (e.g. Chinese) paths.

    The bytes are loaded with ``np.fromfile`` and decoded with ``cv2.imdecode``
    using flag -1 (unchanged: channels and depth are kept as stored).

    :return: the decoded image; None if ``path`` is empty, missing, empty on
        disk, or reading fails. Decode failures are logged.
    """
    if path and os.path.exists(path):
        try:
            raw = np.fromfile(path, dtype=np.uint8)
            if raw.size:
                return cv2.imdecode(raw, -1)
            return None
        except Exception as e:
            logger.error(f"Error reading image {path}: {e}")
    return None
|
||
|
||
def save_image(path, img):
    """
    Encode ``img`` and write it to ``path``, supporting non-ASCII (e.g. Chinese) paths.

    The image is encoded in memory with ``cv2.imencode`` and written with
    ``ndarray.tofile``. The encoder is chosen from the path's extension,
    defaulting to ``.jpg`` when there is none; the file is still written to
    ``path`` exactly as given.

    :return: True on success; False if encoding or writing failed (the error is logged)
    """
    try:
        ext = os.path.splitext(path)[1] or ".jpg"
        ok, buf = cv2.imencode(ext, img)
        if not ok:
            # imencode can report failure through its flag instead of raising;
            # don't write an invalid file and claim success.
            logger.error(f"Error saving image {path}: encoding failed")
            return False
        buf.tofile(path)
        return True
    except Exception as e:
        logger.error(f"Error saving image {path}: {e}")
        return False
|
||
|
||
def draw_rectangles(image_path, bboxes=None, click_points=None, save_vl=True):
    """
    Draw boxes and click points on an image with OpenCV, producing ``_vl.jpg`` and ``_flag.jpg``.

    - ``<stem>_vl.jpg``: boxes only, for the vision model to reference
    - ``<stem>_flag.jpg``: boxes plus click points, for manual debugging

    Output names are built from the path's stem (``os.path.splitext``). This
    means a non-``.jpg`` input such as ``.png`` still gets distinct companion
    files instead of having the original screenshot overwritten.

    :param image_path: source image path
    :param bboxes: iterable of ``(x1, y1, x2, y2)`` pixel boxes; entries whose
        length is not 4 are skipped
    :param click_points: iterable of ``(x, y)`` pixel points; entries whose
        length is not 2 are skipped
    :param save_vl: whether to write the ``_vl.jpg`` image
    :return: ``(vl_path or flag_path, flag_path)`` on success;
        ``(image_path, image_path)`` if the image cannot be read or drawing fails
    """
    DEBUG_BOX_COLOR = (0, 255, 0)  # green boxes (OpenCV colors are BGR)
    DEBUG_POINT_COLOR = (0, 0, 255)  # red click points (BGR)
    DEBUG_BOX_THICKNESS = 3

    try:
        img = read_image(image_path)
        if img is None:
            # Return a pair like every other path so callers can always unpack.
            return image_path, image_path

        stem = os.path.splitext(image_path)[0]

        base_img = img.copy()
        if bboxes:
            for box in bboxes:
                if len(box) == 4:
                    cv2.rectangle(base_img, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), DEBUG_BOX_COLOR, DEBUG_BOX_THICKNESS)

        vl_path = None
        if save_vl:
            vl_path = f"{stem}_vl.jpg"
            save_image(vl_path, base_img)

        flag_img = base_img.copy()
        if click_points:
            for p in click_points:
                if len(p) == 2:
                    center = (int(p[0]), int(p[1]))
                    # Filled red dot with a small white center marking the exact point.
                    cv2.circle(flag_img, center, 12, DEBUG_POINT_COLOR, -1)
                    cv2.circle(flag_img, center, 2, (255, 255, 255), -1)

        flag_path = f"{stem}_flag.jpg"
        save_image(flag_path, flag_img)

        return vl_path or flag_path, flag_path
    except Exception as e:
        logger.error(f"绘制诊断图片失败: {e}")
        return image_path, image_path
|
||
|
||
def clear_temp_dir(save_dir=None):
    """
    Delete every regular file directly inside ``save_dir``.

    Subdirectories are left untouched, and a missing directory is a no-op.
    When ``save_dir`` is None, ``TEMP_IMAGE_DIR`` from ``Config.Config`` is
    used. Per-file deletion errors are logged and do not stop the loop.
    """
    if save_dir is None:
        from Config.Config import TEMP_IMAGE_DIR
        save_dir = TEMP_IMAGE_DIR

    if not os.path.exists(save_dir):
        return

    for entry in os.listdir(save_dir):
        file_path = os.path.join(save_dir, entry)
        try:
            if not os.path.isfile(file_path):
                continue
            os.remove(file_path)
        except Exception as e:
            logger.error(f"Error deleting file {file_path}: {e}")
|
||
|
||
from Util.EasyOcrKit import get_easyocr_reader
|
||
|
||
# Shared EasyOCR reader accessor. The original note describes it as a preloaded
# singleton; presumably the caching lives in get_easyocr_reader (not visible here).
def get_ocr_reader():
    """Return the EasyOCR reader wrapper, requested with ``gpu=True``."""
    ocr_reader = get_easyocr_reader(gpu=True)
    return ocr_reader
|
||
|
||
def detect_price_info_container_cv(image_path, keywords=('全部时段', '全天价格统一'), min_prob=0.5):
    """
    Locate the price-entry text ("全部时段" or "全天价格统一") on a detail page using OCR.

    OCR runs once over the whole image. The first result (in OCR order) whose
    text contains any keyword and whose confidence is at least ``min_prob``
    is returned.

    :param image_path: screenshot path
    :param keywords: iterable of substrings to search for (defaults to the two price-entry labels)
    :param min_prob: minimum OCR confidence for a match
    :return: normalized ``[x1, y1, x2, y2]`` from ``reader.get_normalized_rect``,
        or None if the image is unreadable, nothing matches, or OCR fails
    """
    img = read_image(image_path)
    if img is None:
        return None
    h, w = img.shape[:2]

    try:
        reader = get_ocr_reader()
        # Fetch all results in one pass to avoid running OCR repeatedly.
        results = reader.read_text(img)

        for (quad, text, prob) in results:
            if any(kw in text for kw in keywords) and prob >= min_prob:
                res = reader.get_normalized_rect(quad, w, h)
                print(f"[OCR识别] 找到文本: '{text}', 置信度: {prob:.4f}, 归一化坐标: {res}")
                return res
    except Exception as e:
        # Best-effort: report through the module logger like the file's other
        # error paths, then fall through to "not found".
        logger.error(f"OCR 识别发生异常: {e}")

    return None
|
||
|
||
def detect_warm_popup_xczs_cv(image_path, keyword='下次再说', min_prob=0.5):
    """
    Locate the "下次再说" (not now) button of the "温馨提示" popup using OCR.

    OCR runs once over the whole image. The first result (in OCR order) whose
    text contains ``keyword`` with confidence at least ``min_prob`` is returned.

    :param image_path: screenshot path
    :param keyword: substring to search for (defaults to the button label)
    :param min_prob: minimum OCR confidence for a match
    :return: normalized ``[x1, y1, x2, y2]`` from ``reader.get_normalized_rect``,
        or None if the image is unreadable, nothing matches, or OCR fails
    """
    img = read_image(image_path)
    if img is None:
        return None
    h, w = img.shape[:2]

    try:
        reader = get_ocr_reader()
        results = reader.read_text(img)

        for (quad, text, prob) in results:
            if keyword in text and prob >= min_prob:
                res = reader.get_normalized_rect(quad, w, h)
                # With the default keyword this prints exactly the original message.
                print(f"[OCR识别] 找到“{keyword}”: '{text}', 置信度: {prob:.4f}, 归一化坐标: {res}")
                return res
    except Exception as e:
        # Best-effort: report through the module logger like the file's other
        # error paths, then fall through to "not found".
        logger.error(f"OCR 识别发生异常: {e}")

    return None
|
||
|
||
def setup_logger(name, log_file=None, clear_old_log=False):
    """
    Configure logging to both the console and a file.

    A parent logger named after the supplier code owns the handlers. The
    supplier code is the name of the directory containing this file (e.g.
    ``TelaiDian``). Child loggers ``<supplier>.<name>`` inherit those handlers
    through propagation. The parent sets ``propagate = False`` so records are
    not duplicated by the root logger.

    Handlers are attached only while the parent has none. Later calls with a
    different ``log_file`` therefore reuse the existing handlers, unless
    ``clear_old_log`` detached them first (which it does only when the file exists).

    :param name: logger name; if equal to the supplier code the parent logger itself is returned
    :param log_file: log file path; defaults to ``Logs/<supplier>.log`` under the
        directory three levels above this file (the project root, aiData).
        Missing parent directories are created.
    :param clear_old_log: if True and the file exists, close/detach the parent's
        handlers and delete the old log file before reconfiguring
    :return: the configured logger
    """
    # 1. Supplier code = name of the directory containing this file (e.g. TelaiDian).
    supplier_code = os.path.basename(os.path.dirname(os.path.abspath(__file__)))

    # 2. The parent logger owns the handlers; stop propagation to root to avoid duplicates.
    parent_logger = logging.getLogger(supplier_code)
    parent_logger.setLevel(logging.INFO)
    parent_logger.propagate = False

    if log_file is None:
        # Project root (aiData) is three directory levels above this file.
        root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
        log_file = os.path.join(root_dir, "Logs", f"{supplier_code}.log")

    # Create the log directory (default or caller-supplied) so FileHandler does
    # not fail on a missing folder; exist_ok avoids a check-then-create race.
    log_dir = os.path.dirname(log_file)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    # Clear the old log file if requested and it exists.
    if clear_old_log and os.path.exists(log_file):
        try:
            # Close existing handlers first so the file can be deleted
            # (an open handle may block deletion on some platforms).
            for handler in parent_logger.handlers[:]:
                handler.close()
                parent_logger.removeHandler(handler)
            os.remove(log_file)
        except Exception as e:
            print(f"无法清空旧日志文件 {log_file}: {e}")

    if not parent_logger.handlers:
        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

        # Console handler
        ch = logging.StreamHandler()
        ch.setFormatter(formatter)
        parent_logger.addHandler(ch)

        # File handler
        fh = logging.FileHandler(log_file, encoding='utf-8')
        fh.setFormatter(formatter)
        parent_logger.addHandler(fh)

    # 3. Return the child logger (e.g. TelaiDian.T4_TelaiDian), or the parent itself.
    if name == supplier_code:
        return parent_logger
    return logging.getLogger(f"{supplier_code}.{name}")
|