486 lines
18 KiB
Python
486 lines
18 KiB
Python
import logging
|
||
import os
|
||
import cv2
|
||
import numpy as np
|
||
import time
|
||
import json
|
||
import hashlib
|
||
from Apps.AiTeJiYiChong.Config.Setting import BOTTOM_SAFE_EXCLUDE_RATIO
|
||
from Config.Config import TEMP_IMAGE_DIR
|
||
|
||
# Process-wide logging defaults: INFO level with a timestamped format. Per
# logging.basicConfig semantics this is a no-op if the root logger already has handlers.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# Provisional module logger; the module-level name is rebound later in this file
# to the supplier-level logger returned by setup_logger("Kit").
logger = logging.getLogger(__name__)
|
||
|
||
|
||
def clean_station_name(name):
    """
    Normalise a station name by trimming whitespace and trailing truncation marks.

    Trailing ASCII dots (``...``), Chinese full stops (``。``) and the ellipsis
    character ``…`` (U+2026; two of them form the Chinese ellipsis ``……``) are
    removed, then surrounding whitespace is trimmed again.

    :param name: Raw station name; falsy values yield ``""``.
    :return: The cleaned name.
    """
    if not name:
        return ""
    # rstrip with a character set removes the whole trailing run of these marks,
    # the same as popping them one at a time.
    return name.strip().rstrip(".。…").strip()
|
||
|
||
|
||
def get_file_md5(file_path):
    """
    Return the hexadecimal MD5 digest of a file's bytes.

    The file is read in 4096-byte chunks, so large files are not loaded into
    memory at once.

    :param file_path: Path of the file to hash.
    :return: The hex digest string, or None if *file_path* does not exist.
    """
    if os.path.exists(file_path):
        digest = hashlib.md5()
        with open(file_path, "rb") as stream:
            while True:
                block = stream.read(4096)
                if not block:
                    break
                digest.update(block)
        return digest.hexdigest()
    return None
|
||
|
||
def get_name_md5(name):
    """
    Return the hexadecimal MD5 of *name*, encoded as UTF-8.

    Non-string values are converted with ``str()`` first.

    :param name: Value to hash.
    :return: The hex digest, or ``"unknown"`` when *name* is falsy (e.g. None, "", 0).
    """
    if not name:
        return "unknown"
    text = name if isinstance(name, str) else str(name)
    return hashlib.md5(text.encode("utf-8")).hexdigest()
|
||
|
||
|
||
def get_image_content_md5(file_path, top_ratio=0.1, bottom_ratio=0.1):
    """
    Return an MD5 of the image's central band, ignoring its top and bottom strips.

    The strips are meant to drop the status bar and navigation bar. The
    remaining rows are JPEG-encoded in memory and the encoded bytes are hashed.
    If encoding fails, the raw pixel buffer is hashed instead.

    :param file_path: Path of the image.
    :param top_ratio: Fraction of the height to drop from the top.
    :param bottom_ratio: Fraction of the height to drop from the bottom.
    :return: The hex digest, or None if the image cannot be read.
    """
    img = read_image(file_path)
    if img is None:
        return None

    height = img.shape[0]
    first_row = int(height * top_ratio)
    stop_row = int(height * (1 - bottom_ratio))
    core = img[first_row:stop_row, :]

    # Hash the JPEG encoding of the crop; fall back to raw pixels if encoding fails.
    encoded_ok, jpeg_buf = cv2.imencode(".jpg", core)
    payload = jpeg_buf.tobytes() if encoded_ok else core.tobytes()
    return hashlib.md5(payload).hexdigest()
|
||
|
||
|
||
def read_image(path):
    """
    Read an image from disk, supporting non-ASCII (e.g. Chinese) paths.

    The file is read into a byte array with ``np.fromfile`` and decoded with
    ``cv2.imdecode(..., -1)`` (IMREAD_UNCHANGED, so alpha/grayscale channels are
    kept as stored).

    :param path: Image path.
    :return: The decoded image array, or None if *path* is empty or missing, the
        file is empty, decoding fails, or reading raises (the error is logged).
    """
    if not (path and os.path.exists(path)):
        return None
    try:
        # Read bytes via numpy rather than cv2.imread, to work around the
        # Chinese-path problem noted by the original author.
        raw = np.fromfile(path, dtype=np.uint8)
        return None if raw.size == 0 else cv2.imdecode(raw, -1)
    except Exception as e:
        logger.error(f"Error reading image {path}: {e}")
        return None
|
||
|
||
|
||
def save_image(path, img):
    """
    Encode *img* and write it to *path*, supporting non-ASCII (e.g. Chinese) paths.

    The encoder is chosen from the extension of *path*, or ``.jpg`` if it has
    none. The image is encoded in memory with ``cv2.imencode`` and written with
    ``ndarray.tofile`` rather than ``cv2.imwrite``.

    :param path: Destination file path.
    :param img: Image array to encode.
    :return: True if the image was encoded and written, False otherwise (the
        failure is logged).
    """
    try:
        ext = os.path.splitext(path)[1] or ".jpg"
        # imencode reports failure through its first return value; writing the
        # buffer without checking it could leave a bogus file and report success.
        ok, buffer = cv2.imencode(ext, img)
        if not ok:
            logger.error(f"Error saving image {path}: encoding as {ext} failed")
            return False
        buffer.tofile(path)
        return True
    except Exception as e:
        logger.error(f"Error saving image {path}: {e}")
        return False
|
||
|
||
|
||
# Screenshot helper
def take_screenshot(d, image_uuid, save_dir=TEMP_IMAGE_DIR):
    """
    Save a screenshot from *d* as ``<save_dir>/<image_uuid>.jpg`` and return that path.

    *save_dir* is created if it does not exist. *d* must provide a
    ``screenshot(path)`` method.
    """
    target_path = os.path.join(save_dir, f"{image_uuid}.jpg")
    os.makedirs(save_dir, exist_ok=True)
    d.screenshot(target_path)
    return target_path
|
||
|
||
|
||
def clear_temp_dir(save_dir=TEMP_IMAGE_DIR):
    """
    Delete every entry (files, symlinks and sub-directories) inside *save_dir*.

    The directory itself is kept, and a missing directory is ignored. Deletion
    is best-effort: a failure on one entry is logged and the sweep continues.

    :param save_dir: Directory to empty.
    """
    import shutil

    if not os.path.exists(save_dir):
        return
    logger.info(f"正在清空临时目录: {save_dir}")
    for entry in os.listdir(save_dir):
        file_path = os.path.join(save_dir, entry)
        try:
            # Symlinks are unlinked rather than followed: shutil.rmtree refuses
            # to operate on a symlink, and a dangling link is neither isfile()
            # nor isdir(), so it would otherwise be skipped silently.
            if os.path.islink(file_path) or os.path.isfile(file_path):
                os.remove(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except Exception as e:
            logger.error(f"无法删除文件 {file_path}: {e}")
|
||
|
||
|
||
def click_image_template(d, template_path, timeout=5.0, threshold=0.8):
    """
    Poll the screen until a template image is found, then click its centre.

    Each attempt saves a screenshot as ``temp_click_check.jpg`` in TEMP_IMAGE_DIR
    and searches it for the template with OpenCV template matching
    (``cv2.TM_CCOEFF_NORMED``). Five scales between 0.8x and 1.2x are tried;
    scales whose resized template would not fit in the screenshot are skipped.
    When the best score reaches *threshold*, the centre of that match is clicked
    via ``d.click(x, y)``. After each unsuccessful attempt the function sleeps 0.5 s.

    :param d: Automation driver; must provide ``screenshot(path)`` and ``click(x, y)``.
    :param template_path: Path of the template image.
    :param timeout: Seconds to keep polling.
    :param threshold: Minimum match score required to click.
    :return: True if a match was clicked; False if the template is missing or
        unreadable, or nothing matched before the timeout.
    """
    if not os.path.exists(template_path):
        logger.info(f"Template file not found: {template_path}")
        return False

    template = read_image(template_path)
    if template is None:
        logger.info(f"Failed to load template: {template_path}")
        return False

    t_h, t_w = template.shape[:2]
    # The scaled templates do not depend on the screenshot, so build them once.
    scaled_templates = [
        cv2.resize(template, (int(t_w * scale), int(t_h * scale)))
        for scale in np.linspace(0.8, 1.2, 5)
    ]

    start_time = time.time()
    while time.time() - start_time < timeout:
        screenshot_path = take_screenshot(d, "temp_click_check", save_dir=TEMP_IMAGE_DIR)
        target = read_image(screenshot_path)
        if target is None:
            time.sleep(0.5)
            continue

        # Multi-scale matching: keep the best score across all scales.
        best_match = None
        for resized in scaled_templates:
            r_h, r_w = resized.shape[:2]
            # cv2.matchTemplate raises if the template is larger than the image.
            if r_h > target.shape[0] or r_w > target.shape[1]:
                continue
            res = cv2.matchTemplate(target, resized, cv2.TM_CCOEFF_NORMED)
            _, max_val, _, max_loc = cv2.minMaxLoc(res)
            if best_match is None or max_val > best_match[0]:
                best_match = (max_val, max_loc, r_w, r_h)

        if best_match and best_match[0] >= threshold:
            max_val, max_loc, r_w, r_h = best_match
            center_x = max_loc[0] + r_w // 2
            center_y = max_loc[1] + r_h // 2
            logger.info(f"成功点击图片模板: {template_path}, 匹配度: {max_val:.2f}")
            d.click(center_x, center_y)
            return True

        time.sleep(0.5)

    return False
|
||
|
||
|
||
def setup_logger(name, log_file=None, clear_old_log=False):
    """
    Configure and return a logger that writes to both the console and a file.

    The handlers are owned by a parent logger named after the supplier code,
    i.e. the directory containing this module (e.g. ``AiTeJiYiChong``). Child
    loggers named ``<supplier_code>.<name>`` inherit those handlers through
    propagation. The parent has ``propagate = False`` so records are not
    emitted a second time by the root logger.

    Handlers are attached only when the parent has none yet. On later calls a
    different *log_file* is therefore ignored, unless *clear_old_log* is set and
    that file exists, which detaches the existing handlers first.

    :param name: Child logger name. If it equals the supplier code, the parent
        logger itself is returned.
    :param log_file: Log file path. It defaults to ``Logs/<supplier_code>.log``
        under the directory two levels above this module's directory. Missing
        parent directories are created.
    :param clear_old_log: If True and *log_file* exists, close and detach the
        parent's handlers and delete the file before reconfiguring. Failures are
        printed, not raised.
    :return: The configured ``logging.Logger``.
    """
    # 1. Supplier code = name of the directory containing this module.
    supplier_code = os.path.basename(os.path.dirname(os.path.abspath(__file__)))

    # 2. Configure the parent logger shared by every child of this supplier.
    parent_logger = logging.getLogger(supplier_code)
    parent_logger.setLevel(logging.INFO)
    parent_logger.propagate = False  # don't hand records to the root logger (avoids duplicates)

    if log_file is None:
        # Project root ("aiData" per the original note).
        root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
        log_dir = os.path.join(root_dir, "Logs")
        # exist_ok avoids the race in a separate exists() check followed by makedirs().
        os.makedirs(log_dir, exist_ok=True)
        log_file = os.path.join(log_dir, f"{supplier_code}.log")

    # Optionally start from an empty log file.
    if clear_old_log and os.path.exists(log_file):
        try:
            # Close and detach existing handlers so the file is no longer held open.
            for handler in parent_logger.handlers[:]:
                handler.close()
                parent_logger.removeHandler(handler)
            os.remove(log_file)
        except Exception as e:
            print(f"无法清空旧日志文件 {log_file}: {e}")

    if not parent_logger.handlers:
        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

        # Build both handlers before attaching either. If the file cannot be
        # opened, the logger is left without handlers and a later call retries,
        # instead of being stuck with a console-only configuration.
        ch = logging.StreamHandler()
        ch.setFormatter(formatter)

        # FileHandler opens the file immediately, so a caller-supplied path must
        # have an existing parent directory.
        os.makedirs(os.path.dirname(os.path.abspath(log_file)), exist_ok=True)
        fh = logging.FileHandler(log_file, encoding='utf-8')
        fh.setFormatter(formatter)

        parent_logger.addHandler(ch)
        parent_logger.addHandler(fh)

    # 3. Return the parent itself, or a child that inherits its handlers.
    if name == supplier_code:
        return parent_logger
    return logging.getLogger(f"{supplier_code}.{name}")
|
||
|
||
|
||
# Module-level logger used by the functions in this file: the "Kit" child of
# the supplier-level logger configured by setup_logger().
logger = setup_logger(name="Kit")
|
||
|
||
|
||
def find_template_coords(img_path, template_path, threshold=0.8):
    """
    Locate a template inside an image and return the centre of the best match.

    The template is tried at five scales between 0.8x and 1.2x with
    ``cv2.TM_CCOEFF_NORMED``. Scales whose resized template would not fit inside
    the image are skipped, and the best score across scales wins.

    :param img_path: Path of the image to search.
    :param template_path: Path of the template image.
    :param threshold: Minimum match score required.
    :return: ``(center_x, center_y, score)`` for the best match, or None if
        either file is missing or unreadable, or no scale reaches *threshold*.
    """
    if not (os.path.exists(img_path) and os.path.exists(template_path)):
        return None

    img = read_image(img_path)
    template = read_image(template_path)
    if img is None or template is None:
        return None

    img_h, img_w = img.shape[:2]
    tpl_h, tpl_w = template.shape[:2]

    best_score = None
    best_loc = None
    best_w = best_h = 0
    for scale in np.linspace(0.8, 1.2, 5):
        scaled = cv2.resize(template, (int(tpl_w * scale), int(tpl_h * scale)))
        sh, sw = scaled.shape[:2]
        if sh > img_h or sw > img_w:
            continue
        scores = cv2.matchTemplate(img, scaled, cv2.TM_CCOEFF_NORMED)
        _, peak, _, peak_loc = cv2.minMaxLoc(scores)
        if best_score is None or peak > best_score:
            best_score, best_loc, best_w, best_h = peak, peak_loc, sw, sh

    if best_score is not None and best_score >= threshold:
        return (best_loc[0] + best_w // 2, best_loc[1] + best_h // 2, best_score)
    return None
|
||
|
||
|
||
def crop_cards_from_image(img_path, output_dir=None, save_debug=True):
    """
    Crop station cards from a screenshot and produce annotated _flag / _vl images.

    Anchors come from two sources:

    - Navigation-arrow template matches (``BiaoShi/arrow.jpg`` next to this
      module) are the primary anchors.
    - Red price regions are added as supplementary anchors.

    The image is split vertically into segments separated by light-blue,
    uniform background rows. Each segment taller than 150 px that contains an
    anchor becomes a card; segments taller than 600 px are further split between
    anchors. A card's bottom edge is pulled up to 5 px below the lowest red
    region inside it. Horizontal bounds are fixed at 2%-98% of the image width.

    :param img_path: Path of the screenshot to process.
    :param output_dir: Output directory; defaults to the directory of *img_path*.
    :param save_debug: If True, write ``<stem>_flag<ext>`` (green boxes plus red
        click dots), ``<stem>_vl<ext>`` (green boxes only) and ``<stem>.json``
        to *output_dir*.
    :return: A dict ``{"image", "width", "height", "cards"}``. Each card has
        ``id``, ``rect`` ``[x1, y1, x2, y2]``, ``bounds_norm`` (coordinates divided
        by width/height, rounded to 4 places) and ``click_point`` (20% in from
        the card's left and top edges).
        NOTE(review): returns ``[]`` (a list, not a dict) when the image is
        missing or unreadable; callers must handle both.
    """
    logger.info(f"Processing: {img_path}")
    if not os.path.exists(img_path):
        return []

    img = read_image(img_path)
    if img is None:
        return []

    h, w = img.shape[:2]
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    s_channel = hsv[:, :, 1]

    # (x, y) anchor points, each assumed to lie inside a card.
    anchors = []

    # 1. Template-match the navigation arrow icon (arrow.jpg).
    template_path = os.path.join(os.path.dirname(__file__), "BiaoShi", "arrow.jpg")
    if os.path.exists(template_path):
        template = read_image(template_path)
        if template is not None:
            t_h, t_w = template.shape[:2]
            # Single-scale template matching.
            res = cv2.matchTemplate(img, template, cv2.TM_CCOEFF_NORMED)
            threshold = 0.7
            loc = np.where(res >= threshold)

            # De-duplicate hits (within 50 px on both axes) and record each match centre as an anchor.
            matched_points = []
            for pt in zip(*loc[::-1]): # (x, y)
                is_duplicate = False
                for mx, my in matched_points:
                    if abs(mx - pt[0]) < 50 and abs(my - pt[1]) < 50:
                        is_duplicate = True
                        break
                if not is_duplicate:
                    matched_points.append(pt)
                    anchors.append((pt[0] + t_w // 2, pt[1] + t_h // 2))
            logger.info(f"Found {len(matched_points)} anchors via arrow template matching.")

    # 2. Supplement (or substitute when template matching finds too few) with red price regions.
    # Deliberately loose red range: hue 0-15 or 150-180, saturation and value >= 30.
    lower_red1 = np.array([0, 30, 30])
    upper_red1 = np.array([15, 255, 255])
    lower_red2 = np.array([150, 30, 30])
    upper_red2 = np.array([180, 255, 255])

    mask1 = cv2.inRange(hsv, lower_red1, upper_red1)
    mask2 = cv2.inRange(hsv, lower_red2, upper_red2)
    red_mask = cv2.bitwise_or(mask1, mask2)

    # Dilate with a wide 40x20 kernel so the characters of one price merge into one blob.
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (40, 20))
    morphed = cv2.dilate(red_mask, kernel, iterations=1)
    contours, _ = cv2.findContours(morphed, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    red_anchors_count = 0
    red_rects = [] # every red region's (x, y, w, h) rectangle
    for cnt in contours:
        x, y, cw, ch = cv2.boundingRect(cnt)
        # Keep plausibly sized blobs below the top 20% of the screen.
        if 20 < cw < 600 and 10 < ch < 150 and y > h * 0.2:
            # Re-measure the bounds on the undilated red_mask to avoid the dilation margin.
            roi_mask = red_mask[y:y+ch, x:x+cw]
            points = cv2.findNonZero(roi_mask)
            if points is not None:
                rx, ry, rcw, rch = cv2.boundingRect(points)
                # Convert back to full-image coordinates.
                exact_rect = (x + rx, y + ry, rcw, rch)
                red_rects.append(exact_rect)
                ax, ay = exact_rect[0] + exact_rect[2] // 2, exact_rect[1] + exact_rect[3] // 2
            else:
                red_rects.append((x, y, cw, ch))
                ax, ay = x + cw // 2, y + ch // 2

            # Skip if an existing anchor is within 100 px vertically.
            is_duplicate = False
            for ex, ey in anchors:
                if abs(ey - ay) < 100:
                    is_duplicate = True
                    break
            if not is_duplicate:
                anchors.append((ax, ay))
                red_anchors_count += 1
    logger.info(f"Found {red_anchors_count} additional anchors via red price detection.")

    # 3. Locate background separator rows (light-blue bands between cards).
    # Per-row mean saturation, mean hue and saturation std over the middle 60% of columns.
    h_channel = hsv[:, :, 0]
    row_s_means = np.mean(s_channel[:, int(w*0.2):int(w*0.8)], axis=1)
    row_h_means = np.mean(h_channel[:, int(w*0.2):int(w*0.8)], axis=1)
    row_s_stds = np.std(s_channel[:, int(w*0.2):int(w*0.8)], axis=1)

    # A separator row is light blue (hue 90-120, i.e. around 105; saturation > 15) and uniform (low std).
    is_separator = (row_s_means > 15) & (row_h_means > 90) & (row_h_means < 120) & (row_s_stds < 10)

    # Collect maximal runs of non-separator rows as [start, end) segments.
    segments = []
    start_y = None
    for y in range(h):
        if not is_separator[y]:
            if start_y is None:
                start_y = y
        else:
            if start_y is not None:
                segments.append((start_y, y))
                start_y = None
    if start_y is not None:
        segments.append((start_y, h))

    # Merge segments split by short runs of rows that were probably misclassified as separators.
    merged_segments = []
    if segments:
        curr_start, curr_end = segments[0]
        for i in range(1, len(segments)):
            next_start, next_end = segments[i]
            if next_start - curr_end < 20: # merge when the gap is under 20 px
                curr_end = next_end
            else:
                merged_segments.append((curr_start, curr_end))
                curr_start, curr_end = next_start, next_end
        merged_segments.append((curr_start, curr_end))

    logger.info(f"Found {len(merged_segments)} merged segments: {merged_segments}")

    # Each entry is (y1, y2, x1, x2).
    final_cards = []
    for b_start, b_end in merged_segments:
        card_h = b_end - b_start
        # Original note: station cards are usually 250-500 px tall (the cut-off used here is 150).
        if card_h > 150:
            # Only keep segments that contain an anchor (20 px tolerance on each side).
            segment_anchors = [ (ax, ay) for ax, ay in anchors if b_start - 20 < ay < b_end + 20 ]
            if segment_anchors:
                # Tighten the bottom edge: if red price text exists, end the card there (user suggestion).
                segment_red_rects = [ (rx, ry, rcw, rch) for rx, ry, rcw, rch in red_rects if b_start < ry < b_end ]
                y2_refined = b_end
                if segment_red_rects:
                    # Bottom of the lowest red region.
                    max_red_bottom = max([ry + rch for rx, ry, rcw, rch in segment_red_rects])
                    # User suggestion: stop at the end of the red text, with a small 5 px margin.
                    y2_refined = min(b_end, max_red_bottom + 5)
                    logger.info(f"Refined y2 from {b_end} to {y2_refined} based on red text.")

                # An oversized segment probably spans several cards (separator not detected).
                if card_h > 600:
                    logger.info(f"Segment at y=[{b_start}, {b_end}] is too large ({card_h}), attempting split by anchors...")
                    segment_anchors.sort(key=lambda a: a[1])

                    # Only split between anchors that are far enough apart.
                    splits = [b_start]
                    for i in range(len(segment_anchors) - 1):
                        ay1 = segment_anchors[i][1]
                        ay2 = segment_anchors[i+1][1]
                        if ay2 - ay1 > 200: # only split when anchors are more than 200 px apart
                            # Split at the row between the two anchors that looks most like a separator (highest mean saturation).
                            split_y = ay1 + np.argmax(row_s_means[ay1:ay2])
                            splits.append(split_y)
                    splits.append(b_end)

                    for i in range(len(splits) - 1):
                        s1, s2 = splits[i], splits[i+1]
                        if s2 - s1 > 150:
                            # Apply the same red-text bottom refinement to each split part.
                            part_red_rects = [ (rx, ry, rcw, rch) for rx, ry, rcw, rch in red_rects if s1 < ry < s2 ]
                            s2_refined = s2
                            if part_red_rects:
                                max_red_bottom = max([ry + rch for rx, ry, rcw, rch in part_red_rects])
                                s2_refined = min(s2, max_red_bottom + 5)

                            final_cards.append((s1, s2_refined, int(w*0.02), int(w*0.98)))
                            logger.info(f"Added split card: y=[{s1}, {s2_refined}]")
                else:
                    final_cards.append((b_start, y2_refined, int(w*0.02), int(w*0.98)))
                    logger.info(f"Added card: y=[{b_start}, {y2_refined}], original_h={card_h}")

    # 4. Sort cards top to bottom by y1.
    final_cards.sort(key=lambda c: c[0])

    # Build the annotated output images and JSON.
    if output_dir is None:
        output_dir = os.path.dirname(img_path)
    base_name = os.path.basename(img_path)
    stem, ext = os.path.splitext(base_name)

    debug_img = img.copy() # _flag.jpg
    vl_img = img.copy() # _vl.jpg

    json_data = {"image": base_name, "width": w, "height": h, "cards": []}

    for idx, (y1, y2, x1, x2) in enumerate(final_cards):
        # Click point: 20% in from the card's left and top edges (upper part of the card).
        click_x = int(x1 + (x2 - x1) * 0.2)
        click_y = int(y1 + (y2 - y1) * 0.2)

        # Flag image: green box plus filled red dot at the click point.
        cv2.rectangle(debug_img, (x1, y1), (x2, y2), (0, 255, 0), 2)
        cv2.circle(debug_img, (click_x, click_y), 10, (0, 0, 255), -1)

        # VL image: green box only.
        cv2.rectangle(vl_img, (x1, y1), (x2, y2), (0, 255, 0), 2)

        json_data["cards"].append({
            "id": idx + 1,
            "rect": [int(x1), int(y1), int(x2), int(y2)],
            "bounds_norm": {
                "left": round(float(x1) / w, 4),
                "top": round(float(y1) / h, 4),
                "right": round(float(x2) / w, 4),
                "bottom": round(float(y2) / h, 4)
            },
            "click_point": [int(click_x), int(click_y)]
        })

    # Write the output files.
    # NOTE(review): the JSON dump is placed under save_debug together with the
    # images; confirm this nesting is intended.
    if save_debug:
        save_image(os.path.join(output_dir, f"{stem}_flag{ext}"), debug_img)
        save_image(os.path.join(output_dir, f"{stem}_vl{ext}"), vl_img)
        with open(os.path.join(output_dir, f"{stem}.json"), 'w', encoding='utf-8') as f:
            json.dump(json_data, f, ensure_ascii=False, indent=2)

    logger.info(f"Generated _flag and _vl images for {len(final_cards)} cards.")
    return json_data
|