Files
aiData/Apps/TeLaiDian/FirstPageKit.py
HuangHai 6688daa446 'commit'
2026-01-18 13:44:51 +08:00

447 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import json
import os
import sys
import cv2
import numpy as np
# Resolve the directory two levels above this file's own directory and make
# sure it is on sys.path. This must run before the `Apps.*` / `Util.*`
# imports that follow -- presumably so those absolute imports resolve when the
# module is executed directly (confirm against how the project is launched).
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
if project_root not in sys.path:
    sys.path.append(project_root)
from Apps.TeLaiDian.Kit import setup_logger, get_ocr_reader, draw_rectangles
from Apps.TeLaiDian.Config.Setting import SAFE_EXCLUDE_RATIO, BOTTOM_SAFE_EXCLUDE_RATIO
from Util.LlmUtil import get_llm_response
# Module-level logger for this kit.
logger = setup_logger("TeLaiDian.FirstPageKit")

# Normalized y threshold (0-1, measured from the top of the screenshot).
# run_ocr_rect drops OCR lines whose center lies above it unless the line
# looks like a station title or a distance.
TEXT_TOP_RATIO = SAFE_EXCLUDE_RATIO * 0.8

# Substrings that mark an OCR line as UI chrome rather than station content.
# run_ocr_rect drops any otherwise-kept OCR line containing one of these.
NON_STATION_KEYWORDS = [
    "地图",
    "目的地",
    "电站名",
    "充电礼",
    "再充",
    "注册礼",
    "元券",
    "PLUS会员",
    "我的收藏",
    "最近充电",
    "我的卡券",
    "我的订单",
    "充电券",
    "电信积分兑换",
    "确认",
    "广告",
    "距离/区域",
    "综合排序",
    "偏好",
    "星级站",
    "停车减免",
    "重卡可用",
    "首页",
    "特省钱",
    "扫码",
    "输入",
    "商城",
    "推荐",
]
def _load_image(path):
if not os.path.exists(path):
raise FileNotFoundError(path)
img = cv2.imread(path)
if img is None:
raise RuntimeError(f"无法读取图片: {path}")
h, w = img.shape[:2]
return img, w, h
def _extract_json(text: str) -> str:
if not text:
return "[]"
cleaned = text.strip()
if "```" in cleaned:
lines = []
for line in cleaned.splitlines():
if line.strip().startswith("```"):
continue
lines.append(line)
cleaned = "\n".join(lines).strip()
decoder = json.JSONDecoder()
pos = 0
while pos < len(cleaned):
idx_dict = cleaned.find("{", pos)
idx_list = cleaned.find("[", pos)
candidates = [i for i in (idx_dict, idx_list) if i != -1]
if not candidates:
break
start = min(candidates)
snippet = cleaned[start:]
try:
_, end = decoder.raw_decode(snippet)
return snippet[:end]
except json.JSONDecodeError:
pos = start + 1
continue
return "[]"
async def run_ocr_rect(image_path, log_path=None):
    """Locate charging-station cards in one list-page screenshot.

    Pipeline: OCR the image, filter OCR lines by confidence, vertical
    position and keyword, ask the LLM to group the remaining lines into
    station cards, then derive one non-overlapping rectangle and click point
    per station from the anchor points the LLM returned.

    Args:
        image_path: Path of the screenshot to analyse.
        log_path: Optional log file the detailed trace is appended to. When
            ``None``, ``<image stem>.log`` is (over)written in the ``Logs``
            directory next to this module.

    Returns:
        A (possibly empty) list of dicts with keys ``index``,
        ``station_name``, ``rect`` (``[x1, y1, x2, y2]`` in pixels),
        ``click_point`` (``[cx, cy]``), ``distance_text`` and ``busy_list``.
        ``None`` is returned when processing stops early: no usable OCR text,
        an unusable LLM reply, or no station left after filtering.

    Raises:
        FileNotFoundError, RuntimeError: propagated from ``_load_image``.
    """
    log_lines = []

    def log_detail(msg):
        # Mirror every message to the logger and to the per-image trace.
        logger.info(msg)
        log_lines.append(msg)

    def write_detail_log():
        # Persist the collected trace. Per-image files are overwritten, a
        # caller-supplied shared log is appended to.
        logs_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "Logs")
        os.makedirs(logs_dir, exist_ok=True)
        if log_path is None:
            # splitext (not replace(".jpg", ...)) so .png/.JPG inputs also
            # get a ".log" name instead of a text file named like an image.
            stem = os.path.splitext(os.path.basename(image_path))[0]
            final_log_path = os.path.join(logs_dir, stem + ".log")
            mode = "w"
        else:
            final_log_path = log_path
            mode = "a"
        with open(final_log_path, mode, encoding="utf-8") as f:
            f.writelines(line + "\n" for line in log_lines)
        log_detail(f"已写入详细日志到: {final_log_path}")

    img, w, h = _load_image(image_path)
    log_detail(f"开始处理图片: {image_path}, 宽={w}, 高={h}")
    reader = get_ocr_reader()
    ocr_results = reader.read_text(img)
    log_detail(f"OCR 原始结果数量: {len(ocr_results)}")

    # ---- Step 1: filter raw OCR lines ------------------------------------
    entries = []
    for idx, (quad, text, prob) in enumerate(ocr_results):
        pts = np.array(quad).astype(int)
        x_min = int(np.min(pts[:, 0]))
        y_min = int(np.min(pts[:, 1]))
        x_max = int(np.max(pts[:, 0]))
        y_max = int(np.max(pts[:, 1]))
        # Center of the axis-aligned bounding box, normalized to 0-1.
        cx_norm = ((x_min + x_max) / 2.0) / w
        cy_norm = ((y_min + y_max) / 2.0) / h

        status = "keep"
        reasons = []
        txt = text or ""
        if prob < 0.3:
            status = "drop"
            reasons.append("prob<0.3")
        if not txt:
            status = "drop"
            reasons.append("empty_text")

        has_station_kw = False
        has_distance_kw = False
        if txt:
            if ("充电站" in txt) or ("超快充" in txt) or ("快充" in txt) or ("慢充" in txt):
                has_station_kw = True
            # NOTE(review): "m" already matches every text containing "km",
            # so any text with a lowercase "m" counts as a distance here.
            if ("km" in txt) or ("m" in txt):
                has_distance_kw = True

        # Lines in the top band survive only if they look like a station
        # title (keyword and at least 4 characters) or a distance.
        if cy_norm < TEXT_TOP_RATIO and not (
            (has_station_kw and len(txt) >= 4) or has_distance_kw
        ):
            status = "drop"
            reasons.append("top_safe_zone")
        if cy_norm > (1 - BOTTOM_SAFE_EXCLUDE_RATIO):
            status = "drop"
            reasons.append("bottom_safe_zone")
        if status == "keep" and txt:
            for kw in NON_STATION_KEYWORDS:
                if kw and kw in txt:
                    status = "drop"
                    reasons.append("non_station_keyword")
                    break

        log_detail(
            f"OCR[{idx + 1}] text={repr(text)} prob={prob:.3f} "
            f"cx_norm={cx_norm:.4f} cy_norm={cy_norm:.4f} "
            f"status={status} reasons={','.join(reasons) if reasons else '-'}"
        )
        if status != "keep":
            continue
        entries.append(
            {
                "text": text,
                "prob": float(prob),
                "cx_norm": cx_norm,
                "cy_norm": cy_norm,
            }
        )

    log_detail(f"OCR 通过过滤的有效文本数量: {len(entries)}")
    if not entries:
        log_detail("无有效 OCR 文本, 结束当前图片处理")
        return

    # ---- Step 2: ask the LLM to group OCR lines into station cards -------
    indexed_entries = [
        {
            "id": entry_id,
            "text": e["text"],
            "prob": e["prob"],
            "cx_norm": round(e["cx_norm"], 4),
            "cy_norm": round(e["cy_norm"], 4),
        }
        for entry_id, e in enumerate(entries, start=1)
    ]
    payload_json = json.dumps(indexed_entries, ensure_ascii=False)
    log_detail(f"传给 LLM 的 OCR 条目数: {len(indexed_entries)}")
    query_text = (
        "下面是特来电列表页整张截图的 OCR 结果,每一项代表一行文字,包含其中心点的归一化坐标:\n"
        "ocr_items = " + payload_json + "\n"
        "请你根据这些文本,将它们聚合成若干个“充电场站卡片”。输出一个 JSON 数组,每个元素必须包含:\n"
        "1) station_name: 场站名称,只能是卡片标题中的名称,不允许是筛选标签、导航按钮、底部功能区等。\n"
        "2) anchor_point_norm: 一个对象 {\"x\": number, \"y\": number},表示该场站名称文字所在行的中心点坐标,取值范围 0-1。\n"
        "并且尽量补充以下可选字段(找不到时可以省略或设为 null:\n"
        "3) distance_text: 距离字符串,例如 \"6.9km\"\"500m\",从对应卡片中的距离行提取。\n"
        "4) busy_list: 忙闲信息数组,数组中的每一项是 {\"mode\": \"快|慢|超|普通\", \"idle\": number, \"total\": number}。\n"
        " 在特来电 UI 中,忙闲文本通常是 \"空闲x/总y\"\"快 空闲x/总y\" 这种形式,\n"
        " 请从相应行中解析出模式和空闲/总数。\n"
        "额外提示:\n"
        "- 每个场站卡片通常包含一行类似 \"1.4km\"\"3.6km\" 的距离文本;\n"
        "- 该距离文本所在行的左侧、且在同一卡片中的那一行文字,就是对应的场站标题 station_name\n"
        "- 忙闲信息通常出现在卡片右侧的彩色小块中,例如 \"快 空闲24/32\"\"慢 空闲0/10\" 等;\n"
        "要求:\n"
        "- 场站按从上到下排序;\n"
        "- station_name 不能取距离行本身(如 \"1.4km\"),而是要取与之成一对的标题行;\n"
        "- 如果某些 OCR 文本显然不属于任何场站卡片,可以忽略;\n"
        "- 只输出 JSON 数组,不要输出其它任何文字。"
    )
    chunks = []
    async for part in get_llm_response(
        query_text=query_text,
        stream=False,
        system_prompt="你是一个帮助整理 OCR 文本的助手,只输出 JSON。",
        chat_history=None,
        temperature=0,
    ):
        chunks.append(part)
    full_text = "".join(chunks)
    log_detail("LLM 原始返回内容开始")
    # The raw reply goes to the file trace only, not to the logger.
    log_lines.append(full_text)
    log_detail("LLM 原始返回内容结束")
    raw = _extract_json(full_text)
    log_detail(f"从 LLM 返回内容中抽取出的 JSON 片段: {raw}")
    try:
        data = json.loads(raw)
    except Exception as e:
        log_detail(f"解析 LLM 返回 JSON 失败: {e}")
        write_detail_log()
        return
    if not isinstance(data, list):
        log_detail(f"期望 LLM 返回 JSON 数组, 实际类型: {type(data)}")
        return
    if not data:
        log_detail("LLM 返回空数组, 结束当前图片处理")
        return

    # ---- Step 3: validate and normalize the LLM items --------------------
    stations = []
    for idx, item in enumerate(data):
        if not isinstance(item, dict):
            log_detail(f"LLM item[{idx}] 不是对象类型, 跳过")
            continue
        name = item.get("station_name") or item.get("name")
        anchor = item.get("anchor_point_norm") or item.get("anchor")
        distance = item.get("distance_text") or item.get("distance")
        busy_list = item.get("busy_list") or []
        if not name or not isinstance(anchor, dict):
            log_detail(f"LLM item[{idx}] 缺少必要字段, 跳过: {item}")
            continue
        try:
            ax = float(anchor.get("x", 0.5))
            ay = float(anchor.get("y", 0.5))
        except (TypeError, ValueError):
            # A null / non-numeric coordinate must not abort the whole image.
            log_detail(f"LLM item[{idx}] anchor_point_norm 无法解析为数值, 跳过: {anchor}")
            continue
        if not (0 <= ax <= 1 and 0 <= ay <= 1):
            log_detail(f"LLM item[{idx}] anchor_point_norm 超出范围, 跳过: {anchor}")
            continue
        px = int(ax * w)
        py = int(ay * h)
        stations.append(
            {
                "station_name": name,
                "anchor_point_norm": {"x": ax, "y": ay},
                "distance_text": distance,
                "busy_list": busy_list,
                "anchor_px": px,
                "anchor_py": py,
            }
        )
        log_detail(
            f"LLM anchor 规范化[{len(stations)}] name={name} ax={ax:.4f} ay={ay:.4f} "
            f"px={px} py={py} distance={distance} busy_list={busy_list}"
        )
    if not stations:
        log_detail("LLM 解析后无有效场站, 结束当前图片处理")
        return

    # A station is kept only if it carries a distance string; missing busy
    # information is logged but tolerated.
    filtered = []
    for s in stations:
        dl = s.get("distance_text")
        bl = s.get("busy_list") or []
        ok_dist = isinstance(dl, str) and (("km" in dl) or ("m" in dl))
        ok_busy = isinstance(bl, list) and len(bl) > 0
        if ok_dist:
            filtered.append(s)
            if not ok_busy:
                log_detail(f"场站缺少忙闲信息但保留: {s.get('station_name')}")
        else:
            log_detail(f"丢弃缺少距离信息的条目: {s.get('station_name')}")
    stations = filtered
    if not stations:
        log_detail("过滤后无有效场站, 结束当前图片处理")
        return

    # ---- Step 4: derive the rectangle height from anchor spacing ---------
    stations.sort(key=lambda s: s["anchor_py"])
    min_gap = None
    if len(stations) >= 2:
        gaps = [
            lower["anchor_py"] - upper["anchor_py"]
            for upper, lower in zip(stations, stations[1:])
        ]
        min_gap = min(gaps)
        # Clamp the box height to 20%-32% of the image height.
        box_h = min_gap
        if box_h < int(h * 0.20):
            box_h = int(h * 0.20)
        if box_h > int(h * 0.32):
            box_h = int(h * 0.32)
        log_detail(f"根据最小锚点间距调整 box_h: min_gap={min_gap}, final={box_h}")
    else:
        box_h = int(h * 0.22)
        log_detail(f"仅有一个场站, 使用默认 box_h={box_h}")

    # Every rectangle spans the same horizontally centered 90% of the width.
    box_w = int(w * 0.90)
    x1_fixed = int((w - box_w) / 2)
    x2_fixed = x1_fixed + box_w

    result = []
    rects = []
    click_points = []
    # NOTE(review): effective_top is computed but never read below.
    effective_top = int(h * SAFE_EXCLUDE_RATIO)
    effective_bottom = int(h * (1 - BOTTOM_SAFE_EXCLUDE_RATIO))
    prev_y2 = None
    # Fraction of the box height placed above the title anchor.
    anchor_ratio = 0.15

    # ---- Step 5: place one non-overlapping rectangle per station ---------
    for idx, st in enumerate(stations, start=1):
        py = st["anchor_py"]
        y1 = int(py - box_h * anchor_ratio)
        y2 = y1 + box_h
        if min_gap is not None:
            # Do not extend further below the anchor than 75% of the
            # smallest anchor gap.
            downward_limit = int(min_gap * 0.75)
            max_y2 = py + downward_limit
            if y2 > max_y2:
                y2 = max_y2
                y1 = y2 - box_h
        # Keep the box inside the image and above the bottom safe zone.
        if y1 < 0:
            y1 = 0
            y2 = y1 + box_h
        if y2 > h:
            y2 = h
            y1 = y2 - box_h
        if y2 > effective_bottom:
            y2 = effective_bottom
            y1 = y2 - box_h
        if prev_y2 is not None and y1 <= prev_y2:
            # Push the box down so it starts just below the previous one.
            shift = prev_y2 - y1 + 1
            y1 += shift
            y2 += shift
            # y2 was clamped to effective_bottom above, so it can only
            # exceed it again as a result of this shift.
            if y2 > effective_bottom:
                if idx == len(stations):
                    # Last station: shrink it into the remaining space.
                    new_y1 = prev_y2 + 1
                    new_y2 = effective_bottom
                    if new_y2 <= new_y1:
                        log_detail(f"底部空间不足,丢弃: {st.get('station_name')}")
                        continue
                    y1 = new_y1
                    y2 = new_y2
                else:
                    log_detail(f"避免重叠无法放置,丢弃: {st.get('station_name')}")
                    continue
        rect = [x1_fixed, y1, x2_fixed, y2]
        cx = int((rect[0] + rect[2]) / 2)
        cy = int((rect[1] + rect[3]) / 2)
        click_points.append([cx, cy])
        rects.append(rect)
        prev_y2 = y2
        item = {
            "index": idx,
            "station_name": st["station_name"],
            "rect": rect,
            "click_point": [cx, cy],
            "distance_text": st.get("distance_text"),
            "busy_list": st.get("busy_list") or [],
        }
        result.append(item)
        log_detail(
            f"Station[{idx}] name={item['station_name']} "
            f"rect={rect} click={item['click_point']} "
            f"distance={item['distance_text']} busy_list={item['busy_list']}"
        )

    # Debug overlay is best-effort: a drawing failure must not lose results.
    try:
        draw_rectangles(image_path, rects, click_points, save_vl=False)
    except Exception as e:
        log_detail(f"绘制调试矩形失败: {e}")
    write_detail_log()
    return result
async def run_batch_in_dir(image_dir, log_file=None):
    """Run ``run_ocr_rect`` on every eligible image directly in ``image_dir``.

    Eligible files end in ``.jpg`` or ``.png`` (case-insensitive) and do not
    contain ``_flag`` or ``_vl`` in their lower-cased name. Files are
    processed in sorted path order; a failure on one image is logged and the
    batch continues with the next one.

    Args:
        image_dir: Directory to scan (not recursive).
        log_file: Passed through to ``run_ocr_rect`` as ``log_path``.

    Returns:
        None.
    """

    def is_candidate(file_name):
        lowered = file_name.lower()
        if not lowered.endswith((".jpg", ".png")):
            return False
        return "_flag" not in lowered and "_vl" not in lowered

    img_files = sorted(
        os.path.join(image_dir, entry)
        for entry in os.listdir(image_dir)
        if is_candidate(entry)
    )
    if not img_files:
        logger.info(f"目录下未找到图片: {image_dir}")
        return

    for position, img_path in enumerate(img_files, start=1):
        logger.info(f"[批处理] 开始处理第 {position} 张图片: {img_path}")
        try:
            await run_ocr_rect(img_path, log_path=log_file)
        except Exception as err:
            logger.exception(f"[批处理] 处理图片失败: {img_path}, {err}")