Files
HuangHai 6655e0cc29 'commit'
2026-01-18 18:59:17 +08:00

486 lines
18 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import logging
import os
import cv2
import numpy as np
import time
import json
import hashlib
from Apps.AiTeJiYiChong.Config.Setting import BOTTOM_SAFE_EXCLUDE_RATIO
from Config.Config import TEMP_IMAGE_DIR
# Process-wide root logging configuration (INFO level, timestamped format).
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# Initial module logger; a later module-level assignment rebinds it via setup_logger("Kit").
logger = logging.getLogger(name=__name__)
def clean_station_name(name):
    """
    Clean a station name: trim surrounding whitespace and drop trailing ellipses.

    Both ASCII dots ("...") and the CJK ellipsis ("……", i.e. U+2026 characters)
    are removed from the end, then whitespace is trimmed again.

    :param name: raw station name; any falsy value yields "".
    :return: the cleaned name (never None).
    """
    if not name:
        return ""
    # rstrip with a character set removes any trailing run of "." / U+2026.
    # The ellipsis is written as an escape so it cannot silently vanish; with
    # an empty string in its place, endswith("") is always True and the old
    # while-loop never terminated.
    return name.strip().rstrip(".\u2026").strip()
def get_file_md5(file_path):
    """Return the hex MD5 digest of a file's bytes (read in 4096-byte chunks), or None if the path does not exist."""
    if os.path.exists(file_path):
        digest = hashlib.md5()
        with open(file_path, "rb") as handle:
            while True:
                block = handle.read(4096)
                if not block:
                    break
                digest.update(block)
        return digest.hexdigest()
    return None
def get_name_md5(name):
    """
    Return the hex MD5 of ``name`` encoded as UTF-8.

    Non-string values are converted with ``str()`` first; any falsy value
    (None, "", 0, ...) yields the literal "unknown".
    """
    if name:
        text = name if isinstance(name, str) else str(name)
        return hashlib.md5(text.encode("utf-8")).hexdigest()
    return "unknown"
def get_image_content_md5(file_path, top_ratio=0.1, bottom_ratio=0.1):
    """
    Return the MD5 of an image's central band, ignoring its top and bottom strips.

    The top ``top_ratio`` and bottom ``bottom_ratio`` of the height are cut away
    (intended to exclude the status bar and navigation bar) before hashing.

    :param file_path: image file to hash
    :param top_ratio: fraction of the height dropped from the top
    :param bottom_ratio: fraction of the height dropped from the bottom
    :return: hex digest, or None if the image cannot be read
    """
    image = read_image(file_path)
    if image is None:
        return None
    height, _width = image.shape[:2]
    band = image[int(height * top_ratio):int(height * (1 - bottom_ratio)), :]
    # Hash the JPEG-encoded band; fall back to the raw pixel bytes if encoding reports failure.
    encoded_ok, jpeg_buf = cv2.imencode(".jpg", band)
    payload = jpeg_buf.tobytes() if encoded_ok else band.tobytes()
    return hashlib.md5(payload).hexdigest()
def read_image(path):
    """
    Decode an image file into an array, supporting non-ASCII (e.g. Chinese) paths.

    The raw bytes are loaded with ``np.fromfile`` and decoded with
    ``cv2.imdecode(..., -1)`` (unchanged channels/depth).

    :param path: image file path
    :return: whatever ``cv2.imdecode`` returns (None if it cannot decode), or None
             when ``path`` is empty/missing, the file is empty, or an exception is
             raised (the exception is logged).
    """
    if path and os.path.exists(path):
        try:
            raw = np.fromfile(path, dtype=np.uint8)
            if raw.size:
                return cv2.imdecode(raw, -1)
        except Exception as e:
            logger.error(f"Error reading image {path}: {e}")
    return None
def save_image(path, img):
    """
    Encode ``img`` and write it to ``path``, supporting non-ASCII (e.g. Chinese) paths.

    The encoder is chosen from the path's extension (".jpg" when there is none);
    the encoded buffer is written with ``ndarray.tofile``.

    :param path: destination file path
    :param img: image array to encode
    :return: True on success; False if encoding reports failure or any exception
             occurs (the error is logged).
    """
    try:
        ext = os.path.splitext(path)[1] or ".jpg"
        ok, buf = cv2.imencode(ext, img)
        if not ok:
            # imencode can report failure through its flag rather than raising;
            # don't write a bogus buffer and claim success.
            logger.error(f"Error saving image {path}: cv2.imencode failed for '{ext}'")
            return False
        buf.tofile(path)
        return True
    except Exception as e:
        logger.error(f"Error saving image {path}: {e}")
        return False
# Screenshot capture
def take_screenshot(d, image_uuid, save_dir=TEMP_IMAGE_DIR):
    """
    Save a screenshot from device ``d`` as ``<save_dir>/<image_uuid>.jpg``.

    :param d: device object exposing ``screenshot(path)``
    :param image_uuid: file stem for the saved image
    :param save_dir: target directory, created if missing
    :return: path of the saved screenshot
    """
    os.makedirs(save_dir, exist_ok=True)
    target = os.path.join(save_dir, f"{image_uuid}.jpg")
    d.screenshot(target)
    return target
def clear_temp_dir(save_dir=TEMP_IMAGE_DIR):
    """
    Delete every file and sub-directory inside ``save_dir``; the directory itself is kept.

    Does nothing if ``save_dir`` does not exist. A failure on one entry is logged
    and the remaining entries are still processed.
    """
    import shutil

    if os.path.exists(save_dir):
        logger.info(f"正在清空临时目录: {save_dir}")
        for entry in os.listdir(save_dir):
            file_path = os.path.join(save_dir, entry)
            try:
                if os.path.isfile(file_path):
                    os.remove(file_path)
                elif os.path.isdir(file_path):
                    shutil.rmtree(file_path)
            except Exception as e:
                logger.error(f"无法删除文件 {file_path}: {e}")
def click_image_template(d, template_path, timeout=5.0, threshold=0.8):
    """
    Poll the screen until a template image is found via OpenCV matching, then click it.

    Each attempt saves a screenshot named "temp_click_check" in TEMP_IMAGE_DIR and
    matches the template at 5 scales from 0.8x to 1.2x (TM_CCOEFF_NORMED).
    Scales whose resized template would be empty or larger than the screenshot are
    skipped. If the best score is >= ``threshold``, the centre of that match is
    clicked. Attempts are spaced 0.5 s apart.

    :param d: device object exposing ``screenshot(path)`` and ``click(x, y)``
    :param template_path: image file of the element to click
    :param timeout: seconds to keep retrying
    :param threshold: minimum normalized correlation score to accept a match
    :return: True if a click was issued; False if the template is missing or
             unreadable, or no match was found before the timeout
    """
    if not os.path.exists(template_path):
        logger.info(f"Template file not found: {template_path}")
        return False
    template = read_image(template_path)
    if template is None:
        logger.info(f"Failed to load template: {template_path}")
        return False
    t_h, t_w = template.shape[:2]
    # monotonic() is unaffected by wall-clock adjustments, unlike time.time().
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        temp_uuid = "temp_click_check"
        screenshot_path = take_screenshot(d, temp_uuid, save_dir=TEMP_IMAGE_DIR)
        target = read_image(screenshot_path)
        if target is None:
            time.sleep(0.5)
            continue
        target_h, target_w = target.shape[:2]
        # Multi-scale matching: keep the best-scoring scale.
        best_match = None
        for scale in np.linspace(0.8, 1.2, 5):
            scaled_w, scaled_h = int(t_w * scale), int(t_h * scale)
            # cv2.resize rejects a zero-sized target and cv2.matchTemplate raises when
            # the template exceeds the image, so skip such scales instead of crashing.
            if scaled_w < 1 or scaled_h < 1 or scaled_w > target_w or scaled_h > target_h:
                continue
            resized = cv2.resize(template, (scaled_w, scaled_h))
            res = cv2.matchTemplate(target, resized, cv2.TM_CCOEFF_NORMED)
            _, max_val, _, max_loc = cv2.minMaxLoc(res)
            if best_match is None or max_val > best_match[0]:
                best_match = (max_val, max_loc, resized.shape[1], resized.shape[0])
        if best_match is not None and best_match[0] >= threshold:
            max_val, max_loc, r_w, r_h = best_match
            center_x = max_loc[0] + r_w // 2
            center_y = max_loc[1] + r_h // 2
            logger.info(f"成功点击图片模板: {template_path}, 匹配度: {max_val:.2f}")
            d.click(center_x, center_y)
            return True
        time.sleep(0.5)
    return False
def setup_logger(name, log_file=None, clear_old_log=False):
    """
    Configure logging to both the console and a file, and return a logger.

    A parent logger named after this file's directory (the supplier code, e.g.
    "AiTeJiYiChong") owns the handlers. The returned logger is either that parent
    or a "<supplier>.<name>" child that inherits them. The parent has
    ``propagate = False`` so records are not duplicated by the root logger.
    Handlers are attached only when the parent has none. Later calls reuse the
    existing handlers and ignore ``log_file``, unless ``clear_old_log`` removed
    them first.

    :param name: logger name; if it equals the supplier code, the parent itself is returned
    :param log_file: log file path. Defaults to "<root>/Logs/<supplier>.log", where
                     <root> is obtained by applying dirname three times to this file's
                     absolute path. Missing parent directories are created.
    :param clear_old_log: if True and the log file exists, close and detach the
                          parent's handlers and delete the file first
    :return: the configured ``logging.Logger``
    """
    # 1. Supplier code = name of the directory containing this file.
    supplier_code = os.path.basename(os.path.dirname(os.path.abspath(__file__)))
    # 2. Configure the parent logger that owns the handlers.
    parent_logger = logging.getLogger(supplier_code)
    parent_logger.setLevel(logging.INFO)
    parent_logger.propagate = False  # don't also emit via the root logger (avoids duplicates)
    if log_file is None:
        root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
        log_file = os.path.join(root_dir, "Logs", f"{supplier_code}.log")
    if clear_old_log and os.path.exists(log_file):
        try:
            # Close and detach existing handlers so the file can be deleted.
            for handler in parent_logger.handlers[:]:
                handler.close()
                parent_logger.removeHandler(handler)
            os.remove(log_file)
        except Exception as e:
            print(f"无法清空旧日志文件 {log_file}: {e}")
    if not parent_logger.handlers:
        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        console_handler = logging.StreamHandler()
        console_handler.setFormatter(formatter)
        parent_logger.addHandler(console_handler)
        # FileHandler does not create missing directories; do it for both the
        # default and a caller-supplied path. exist_ok avoids a check-then-create race.
        log_dir = os.path.dirname(log_file)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
        file_handler = logging.FileHandler(log_file, encoding='utf-8')
        file_handler.setFormatter(formatter)
        parent_logger.addHandler(file_handler)
    # 3. Return the parent itself, or a child that inherits its handlers.
    if name == supplier_code:
        return parent_logger
    return logging.getLogger(f"{supplier_code}.{name}")
# Module-wide logger: a "<supplier>.Kit" child sharing the supplier-level handlers.
logger = setup_logger(name="Kit")
def find_template_coords(img_path, template_path, threshold=0.8):
    """
    Locate a template inside an image and return the centre of the best match.

    The template is tried at 5 scales from 0.8x to 1.2x with TM_CCOEFF_NORMED.
    Scales that would make it taller or wider than the image are skipped, and the
    best-scoring scale wins (earliest scale on ties).

    :param img_path: image to search in
    :param template_path: template image to look for
    :param threshold: minimum score required for a match
    :return: ``(center_x, center_y, score)`` when the best score is >= ``threshold``;
             otherwise None (also None when either file is missing or unreadable)
    """
    if not (os.path.exists(img_path) and os.path.exists(template_path)):
        return None
    img = read_image(img_path)
    template = read_image(template_path)
    if img is None or template is None:
        return None
    img_h, img_w = img.shape[:2]
    tpl_h, tpl_w = template.shape[:2]
    candidates = []
    for factor in np.linspace(0.8, 1.2, 5):
        scaled = cv2.resize(template, (int(tpl_w * factor), int(tpl_h * factor)))
        scaled_h, scaled_w = scaled.shape[:2]
        if scaled_h > img_h or scaled_w > img_w:
            continue
        scores = cv2.matchTemplate(img, scaled, cv2.TM_CCOEFF_NORMED)
        _, peak, _, peak_loc = cv2.minMaxLoc(scores)
        candidates.append((peak, peak_loc, scaled_w, scaled_h))
    if not candidates:
        return None
    # max() keeps the first of equal keys, matching a strict ">" running-best scan.
    score, (loc_x, loc_y), match_w, match_h = max(candidates, key=lambda c: c[0])
    if score >= threshold:
        return (loc_x + match_w // 2, loc_y + match_h // 2, score)
    return None
def crop_cards_from_image(img_path, output_dir=None, save_debug=True):
    """
    Detect station "cards" in a screenshot and produce annotated images plus a JSON summary.

    Pipeline:
      1. Anchors. First, centres of matches of the navigation-arrow template
         (``BiaoShi/arrow.jpg`` next to this file; score >= 0.7; hits within 50 px
         in both x and y of a kept hit are dropped). Then, centres of red regions
         (price text) that are not within 100 px vertically of an existing anchor.
      2. Separator rows. Over the central 20%-80% of columns, a row is a separator
         when its mean saturation is > 15, mean hue is in (90, 120), and saturation
         std is < 10 (a uniform light-blue band). Runs of non-separator rows form
         segments, and segments separated by < 20 rows are merged.
      3. Cards. A merged segment taller than 150 px that contains an anchor
         (with 20 px slack) becomes a card. Its bottom is pulled up to 5 px below
         the lowest red region whose top lies inside it. Segments taller than 600 px
         are split between anchors more than 200 px apart, at the row of highest
         mean saturation between them; parts of 150 px or less are dropped.
      4. Every card spans x from 2% to 98% of the width. Cards are sorted top to bottom.

    :param img_path: screenshot to process
    :param output_dir: where output files go; defaults to the directory of ``img_path``
    :param save_debug: if True, write ``<stem>_flag<ext>`` (green card boxes plus red
                       click dots), ``<stem>_vl<ext>`` (green card boxes only) and
                       ``<stem>.json``
    :return: dict with keys "image", "width", "height" and "cards". Each card has
             "id", "rect" [x1, y1, x2, y2], "bounds_norm" (coordinates divided by
             width/height, rounded to 4 places) and "click_point". If ``img_path``
             is missing or unreadable, an empty list is returned instead (not a dict).
    """
    logger.info(f"Processing: {img_path}")
    if not os.path.exists(img_path):
        return []
    img = read_image(img_path)
    if img is None:
        return []
    h, w = img.shape[:2]
    # NOTE(review): read_image decodes with flag -1 (unchanged), so a grayscale or
    # 4-channel input would make this BGR->HSV conversion fail; confirm inputs are
    # 3-channel screenshots.
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    s_channel = hsv[:, :, 1]
    anchors = []
    # 1. Primary anchors: template-match the navigation icon (arrow.jpg).
    template_path = os.path.join(os.path.dirname(__file__), "BiaoShi", "arrow.jpg")
    if os.path.exists(template_path):
        template = read_image(template_path)
        if template is not None:
            t_h, t_w = template.shape[:2]
            # Single-scale template matching; every location scoring >= 0.7 is a candidate.
            res = cv2.matchTemplate(img, template, cv2.TM_CCOEFF_NORMED)
            threshold = 0.7
            loc = np.where(res >= threshold)
            # De-duplicate hits that are within 50 px (in both x and y) of a kept hit,
            # and record the template centre of each kept hit as an anchor.
            matched_points = []
            for pt in zip(*loc[::-1]):  # (x, y)
                is_duplicate = False
                for mx, my in matched_points:
                    if abs(mx - pt[0]) < 50 and abs(my - pt[1]) < 50:
                        is_duplicate = True
                        break
                if not is_duplicate:
                    matched_points.append(pt)
                    anchors.append((pt[0] + t_w // 2, pt[1] + t_h // 2))
            logger.info(f"Found {len(matched_points)} anchors via arrow template matching.")
    # 2. Supplementary anchors from red (price) regions; this always runs, in
    #    addition to the template hits. The red HSV range is deliberately loose.
    lower_red1 = np.array([0, 30, 30])
    upper_red1 = np.array([15, 255, 255])
    lower_red2 = np.array([150, 30, 30])
    upper_red2 = np.array([180, 255, 255])
    mask1 = cv2.inRange(hsv, lower_red1, upper_red1)
    mask2 = cv2.inRange(hsv, lower_red2, upper_red2)
    red_mask = cv2.bitwise_or(mask1, mask2)
    # Dilate so nearby red pixels merge into a single contour.
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (40, 20))
    morphed = cv2.dilate(red_mask, kernel, iterations=1)
    contours, _ = cv2.findContours(morphed, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    red_anchors_count = 0
    red_rects = []  # every size-qualifying red region, as (x, y, w, h)
    for cnt in contours:
        x, y, cw, ch = cv2.boundingRect(cnt)
        # Size filter; also ignore regions whose top is in the upper 20% of the image.
        if 20 < cw < 600 and 10 < ch < 150 and y > h * 0.2:
            # Re-measure the bounds on the undilated mask to remove the dilation padding.
            roi_mask = red_mask[y:y+ch, x:x+cw]
            points = cv2.findNonZero(roi_mask)
            if points is not None:
                rx, ry, rcw, rch = cv2.boundingRect(points)
                # Convert back to full-image coordinates.
                exact_rect = (x + rx, y + ry, rcw, rch)
                red_rects.append(exact_rect)
                ax, ay = exact_rect[0] + exact_rect[2] // 2, exact_rect[1] + exact_rect[3] // 2
            else:
                red_rects.append((x, y, cw, ch))
                ax, ay = x + cw // 2, y + ch // 2
            # Skip if any existing anchor lies within 100 px vertically.
            is_duplicate = False
            for ex, ey in anchors:
                if abs(ey - ay) < 100:
                    is_duplicate = True
                    break
            if not is_duplicate:
                anchors.append((ax, ay))
                red_anchors_count += 1
    logger.info(f"Found {red_anchors_count} additional anchors via red price detection.")
    # 3. Locate the background separator rows from per-row hue/saturation
    #    statistics over the central 20%-80% of the columns.
    h_channel = hsv[:, :, 0]
    row_s_means = np.mean(s_channel[:, int(w*0.2):int(w*0.8)], axis=1)
    row_h_means = np.mean(h_channel[:, int(w*0.2):int(w*0.8)], axis=1)
    row_s_stds = np.std(s_channel[:, int(w*0.2):int(w*0.8)], axis=1)
    # A separator row is light blue (hue ~105, saturation > 15) and uniform (low saturation std).
    is_separator = (row_s_means > 15) & (row_h_means > 90) & (row_h_means < 120) & (row_s_stds < 10)
    # Collect the [start, end) row ranges that are NOT separators.
    segments = []
    start_y = None
    for y in range(h):
        if not is_separator[y]:
            if start_y is None:
                start_y = y
        else:
            if start_y is not None:
                segments.append((start_y, y))
                start_y = None
    if start_y is not None:
        segments.append((start_y, h))
    # Merge segments separated by short gaps (rows likely mis-detected as separators).
    merged_segments = []
    if segments:
        curr_start, curr_end = segments[0]
        for i in range(1, len(segments)):
            next_start, next_end = segments[i]
            if next_start - curr_end < 20:  # gap under 20 px: merge
                curr_end = next_end
            else:
                merged_segments.append((curr_start, curr_end))
                curr_start, curr_end = next_start, next_end
        merged_segments.append((curr_start, curr_end))
    logger.info(f"Found {len(merged_segments)} merged segments: {merged_segments}")
    final_cards = []
    for b_start, b_end in merged_segments:
        card_h = b_end - b_start
        # Station cards are typically 250-500 px tall; segments of 150 px or less are ignored.
        if card_h > 150:
            # Only segments containing at least one anchor (with 20 px slack) become cards.
            segment_anchors = [ (ax, ay) for ax, ay in anchors if b_start - 20 < ay < b_end + 20 ]
            if segment_anchors:
                # Tighten the bottom edge: if red price text is present, the card ends
                # below the lowest red region whose top lies inside the segment.
                segment_red_rects = [ (rx, ry, rcw, rch) for rx, ry, rcw, rch in red_rects if b_start < ry < b_end ]
                y2_refined = b_end
                if segment_red_rects:
                    # Bottom edge of the lowest red region.
                    max_red_bottom = max([ry + rch for rx, ry, rcw, rch in segment_red_rects])
                    # Stop where the red text ends, with a small 5 px margin.
                    y2_refined = min(b_end, max_red_bottom + 5)
                    logger.info(f"Refined y2 from {b_end} to {y2_refined} based on red text.")
                # A very tall segment may contain several cards whose separator was not
                # detected; try to split it between anchors.
                if card_h > 600:
                    logger.info(f"Segment at y=[{b_start}, {b_end}] is too large ({card_h}), attempting split by anchors...")
                    segment_anchors.sort(key=lambda a: a[1])
                    # Split only between anchors that are far enough apart.
                    splits = [b_start]
                    for i in range(len(segment_anchors) - 1):
                        ay1 = segment_anchors[i][1]
                        ay2 = segment_anchors[i+1][1]
                        if ay2 - ay1 > 200:  # only consider anchors more than 200 px apart
                            # Between the two anchors, take the most separator-like row
                            # (highest mean saturation).
                            split_y = ay1 + np.argmax(row_s_means[ay1:ay2])
                            splits.append(split_y)
                    splits.append(b_end)
                    for i in range(len(splits) - 1):
                        s1, s2 = splits[i], splits[i+1]
                        if s2 - s1 > 150:
                            # Tighten each part's bottom edge by the same red-text rule.
                            part_red_rects = [ (rx, ry, rcw, rch) for rx, ry, rcw, rch in red_rects if s1 < ry < s2 ]
                            s2_refined = s2
                            if part_red_rects:
                                max_red_bottom = max([ry + rch for rx, ry, rcw, rch in part_red_rects])
                                s2_refined = min(s2, max_red_bottom + 5)
                            final_cards.append((s1, s2_refined, int(w*0.02), int(w*0.98)))
                            logger.info(f"Added split card: y=[{s1}, {s2_refined}]")
                else:
                    final_cards.append((b_start, y2_refined, int(w*0.02), int(w*0.98)))
                    logger.info(f"Added card: y=[{b_start}, {y2_refined}], original_h={card_h}")
    # 4. Sort cards top to bottom (by y1).
    final_cards.sort(key=lambda c: c[0])
    # Build the annotated images and the JSON summary.
    if output_dir is None:
        output_dir = os.path.dirname(img_path)
    base_name = os.path.basename(img_path)
    stem, ext = os.path.splitext(base_name)
    debug_img = img.copy()  # annotated copy saved as <stem>_flag<ext>
    vl_img = img.copy()  # annotated copy saved as <stem>_vl<ext>
    json_data = {"image": base_name, "width": w, "height": h, "cards": []}
    for idx, (y1, y2, x1, x2) in enumerate(final_cards):
        # Click point: 20% in from the card's left edge and 20% down from its top.
        click_x = int(x1 + (x2 - x1) * 0.2)
        click_y = int(y1 + (y2 - y1) * 0.2)
        # _flag image: green card outline plus a filled red dot at the click point.
        cv2.rectangle(debug_img, (x1, y1), (x2, y2), (0, 255, 0), 2)
        cv2.circle(debug_img, (click_x, click_y), 10, (0, 0, 255), -1)
        # _vl image: green card outline only.
        cv2.rectangle(vl_img, (x1, y1), (x2, y2), (0, 255, 0), 2)
        json_data["cards"].append({
            "id": idx + 1,
            "rect": [int(x1), int(y1), int(x2), int(y2)],
            "bounds_norm": {
                "left": round(float(x1) / w, 4),
                "top": round(float(y1) / h, 4),
                "right": round(float(x2) / w, 4),
                "bottom": round(float(y2) / h, 4)
            },
            "click_point": [int(click_x), int(click_y)]
        })
    # Save the output files.
    if save_debug:
        save_image(os.path.join(output_dir, f"{stem}_flag{ext}"), debug_img)
        save_image(os.path.join(output_dir, f"{stem}_vl{ext}"), vl_img)
        with open(os.path.join(output_dir, f"{stem}.json"), 'w', encoding='utf-8') as f:
            json.dump(json_data, f, ensure_ascii=False, indent=2)
    logger.info(f"Generated _flag and _vl images for {len(final_cards)} cards.")
    return json_data