# Source listing metadata (from the page this file was copied from): 593 lines, 20 KiB, Python.
import asyncio
import json
import os
import sys

import cv2
import numpy as np

# The Apps.* / Util.* packages are resolved from the directory three levels
# above this file. Put that directory on sys.path (only once) so the project
# imports below also work when this module is executed directly as a script.
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
if project_root not in sys.path:
    sys.path.append(project_root)

# Project imports: these must stay below the sys.path bootstrap, and their
# relative order is preserved in case the modules have import-time effects.
from Apps.YeLiTe.Kit import setup_logger, get_ocr_reader  # noqa: E402
# NOTE(review): ReadImageKit is not referenced in the code visible here; it is
# kept in case other modules import it from this one.
from Apps.YeLiTe.ReadImageKit import ReadImageKit  # noqa: E402,F401
from Apps.YeLiTe.Config.Setting import (  # noqa: E402
    SAFE_EXCLUDE_RATIO,
    BOTTOM_SAFE_EXCLUDE_RATIO,
    STATION_BOX_WIDTH_RATIO,
    STATION_BOX_HEIGHT_RATIO,
)
from Util.LlmUtil import get_llm_response  # noqa: E402

logger = setup_logger("YeLiTe.FirstPageKit")


# An OCR line containing any of these substrings is dropped before it reaches
# the LLM (checked only for lines that passed the confidence/zone filters).
# They look like UI chrome: tab labels, scan/search entries, promos, filters.
NON_STATION_KEYWORDS = [
    "首页", "找桩", "会员", "我的",
    "扫码充电", "扫码",
    "搜索场站或目的地",
    "签到", "优惠券",
    "站内搜索", "筛选",
]

# A low-confidence OCR line containing any of these substrings is rescued as
# a probable station-card title instead of being dropped.
STATION_TITLE_KEYWORDS = [
    "充电站", "交流站",
    "超快充", "快充",
    "换电站", "加电站",
    "超级充电", "超充",
    "充电",
]

# A line lying in the top safe zone but within 0.06 of its lower edge is kept
# when it contains any of these substrings (a station title cut by the zone).
TOP_ZONE_STATION_HINT_KEYWORDS = [
    "充电站", "公共充电", "超级充电",
    "慢充站", "慢充",
    "超快充", "快充",
    "充电", "超充",
]

# An LLM-returned station_name containing any of these substrings is discarded
# (parking-fee notes and end-of-list placeholders).
# NOTE(review): matching is by substring, so "end" also rejects names such as
# "Legend ..."; consider whole-name matching for the END markers.
LLM_NON_STATION_NAME_KEYWORDS = [
    "收费停车", "限时免费停车", "停车费",
    "END", "End", "end",
    "结束",
]


def _load_image(path):
    """Load ``path`` with ``cv2.imread`` and report its pixel size.

    Returns:
        ``(img, w, h)``: the decoded image, then its width and height taken
        from ``img.shape[:2]``.

    Raises:
        FileNotFoundError: nothing exists at ``path``.
        RuntimeError: ``cv2.imread`` returned ``None`` (file unreadable).
    """
    if os.path.exists(path):
        image = cv2.imread(path)
    else:
        raise FileNotFoundError(path)

    if image is None:
        raise RuntimeError(f"无法读取图片: {path}")

    rows, cols = image.shape[:2]
    return image, cols, rows


def _extract_json(text: str) -> str:
    """Return the first complete JSON object or array found in ``text``.

    Markdown code-fence lines (lines starting with ```` ``` ````) are removed
    first. The text is then scanned left to right from each ``{`` / ``[``
    until one position decodes as JSON; that exact substring is returned.
    ``"[]"`` is returned for empty input or when nothing decodes.
    """
    if not text:
        return "[]"

    body = text.strip()
    if "```" in body:
        kept = [ln for ln in body.splitlines() if not ln.strip().startswith("```")]
        body = "\n".join(kept).strip()

    decoder = json.JSONDecoder()
    cursor = 0
    while cursor < len(body):
        openers = [i for i in (body.find("{", cursor), body.find("[", cursor)) if i >= 0]
        if not openers:
            break
        start = min(openers)
        candidate = body[start:]
        try:
            _, length = decoder.raw_decode(candidate)
        except json.JSONDecodeError:
            # Not valid JSON from here; retry from the next character.
            cursor = start + 1
        else:
            return candidate[:length]
    return "[]"


async def run_ocr_rect(image_path, log_path=None):
    """Detect station cards on one list-page screenshot and box them.

    Pipeline: OCR the image, filter the OCR lines, ask the LLM to group the
    surviving lines into station cards, lay out one fixed-width box per card
    (top to bottom, non-overlapping), and save an annotated copy next to the
    input as ``<stem>_ocr_rect<ext>``.

    Every diagnostic line goes to ``logger`` and is also buffered. The buffer
    is written to a detail log on every exit path (including early returns
    and exceptions) once anything has been logged. With ``log_path=None``
    the detail log is ``Logs/<image stem>.log`` next to this module
    (overwritten). Otherwise it is appended to ``log_path``.

    Args:
        image_path: Path of the screenshot to process.
        log_path: Optional detail-log file to append to instead.

    Returns:
        A list of dicts with keys ``index``, ``station_name``, ``rect``
        (``[x1, y1, x2, y2]`` pixels), ``click_point`` (``[x, y]``),
        ``distance_text`` and ``busy_list``. Returns ``None`` when processing
        stops early (no usable OCR text, unusable LLM reply, no complete
        stations).

    Raises:
        FileNotFoundError, RuntimeError: propagated from ``_load_image``.
    """
    log_lines = []

    def log_detail(msg):
        logger.info(msg)
        log_lines.append(msg)

    try:
        return await _detect_station_boxes(image_path, log_detail, log_lines)
    finally:
        # Previously only the success and JSON-parse-failure paths persisted
        # the detail log; every other early return silently lost it.
        if log_lines:
            _write_detail_log(image_path, log_path, log_lines, log_detail)


async def _detect_station_boxes(image_path, log_detail, log_lines):
    """Run the OCR -> LLM -> box-layout pipeline behind ``run_ocr_rect``."""
    img, w, h = _load_image(image_path)
    log_detail(f"开始处理图片: {image_path}, 宽={w}, 高={h}")

    entries = _collect_ocr_entries(img, w, h, log_detail)
    log_detail(f"OCR 通过过滤的有效文本数量: {len(entries)}")
    if not entries:
        log_detail("无有效 OCR 文本, 结束当前图片处理")
        return None

    raw = await _query_llm_for_stations(entries, log_detail, log_lines)
    try:
        data = json.loads(raw)
    except ValueError as e:  # json.JSONDecodeError subclasses ValueError
        log_detail(f"解析 LLM 返回 JSON 失败: {e}")
        return None

    if not isinstance(data, list):
        log_detail(f"期望 LLM 返回 JSON 数组, 实际类型: {type(data)}")
        return None
    if not data:
        log_detail("LLM 返回空数组, 结束当前图片处理")
        return None

    stations = []
    for idx, item in enumerate(data):
        station = _normalize_llm_station(idx, item, w, h, log_detail)
        if station is None:
            continue
        stations.append(station)
        log_detail(
            f"LLM anchor 规范化[{len(stations)}] name={station['name']} "
            f"ax={station['ax']:.4f} ay={station['ay']:.4f} py={station['py']} "
            f"distance={station['distance_text']} busy_list={station['busy_list']}"
        )

    if not stations:
        log_detail("LLM 解析后没有可用的场站锚点, 结束当前图片处理")
        return None

    # Cards without any busy/idle info are treated as incomplete.
    complete = []
    for s in stations:
        if s.get("busy_list"):
            complete.append(s)
        else:
            log_detail(f"场站 {s.get('name')} busy_list 为空, 视为信息不完整, 丢弃")
    stations = complete
    if not stations:
        log_detail("所有场站的忙闲信息均为空, 本页不画绿框")
        return None

    stations.sort(key=lambda s: s["py"])

    overlay = img.copy()
    results = []

    box_w = int(w * STATION_BOX_WIDTH_RATIO)
    box_h = _choose_box_height(stations, int(h * STATION_BOX_HEIGHT_RATIO), log_detail)

    # Every box shares the same horizontally centred x-range.
    x1_fixed = max(0, (w - box_w) // 2)
    x2_fixed = min(w, x1_fixed + box_w)
    effective_top = int(h * SAFE_EXCLUDE_RATIO)
    effective_bottom = int(h * (1 - BOTTOM_SAFE_EXCLUDE_RATIO))

    log_detail(
        f"固定绿框参数: box_w={box_w}, box_h={box_h}, "
        f"x1_fixed={x1_fixed}, x2_fixed={x2_fixed}, "
        f"effective_top={effective_top}, effective_bottom={effective_bottom}"
    )

    prev_y2 = None
    last_idx = len(stations) - 1
    for idx, s in enumerate(stations):
        placed = _place_station_box(
            idx, s, box_h, h, effective_top, effective_bottom,
            prev_y2, idx == last_idx, log_detail,
        )
        if placed is None:
            continue
        y1, y2, orig_y1, orig_y2 = placed
        prev_y2 = y2

        name, ax, ay, px, py = s["name"], s["ax"], s["ay"], s["px"], s["py"]
        click_x = int((x1_fixed + x2_fixed) / 2)
        click_y = int((y1 + y2) / 2)

        log_detail(
            f"Station[{idx + 1}] name={name} anchor=({ax:.4f},{ay:.4f}) "
            f"px={px}, py={py}, box_orig=({orig_y1},{orig_y2}), "
            f"box_adj=({y1},{y2}), click=({click_x},{click_y})"
        )

        cv2.rectangle(overlay, (x1_fixed, y1), (x2_fixed, y2), (0, 255, 0), 2)
        cv2.circle(overlay, (click_x, click_y), 8, (0, 0, 255), -1)

        results.append(
            {
                "index": idx + 1,
                "station_name": name,
                "rect": [x1_fixed, y1, x2_fixed, y2],
                "click_point": [click_x, click_y],
                "distance_text": s.get("distance_text"),
                "busy_list": s.get("busy_list"),
            }
        )

    stem, ext = os.path.splitext(image_path)
    out_path = f"{stem}_ocr_rect{ext}"
    # cv2.imwrite reports failure via its return value instead of raising.
    if not cv2.imwrite(out_path, overlay):
        log_detail(f"调试输出图片写入失败: {out_path}")

    log_detail(f"输入图片: {image_path}")
    log_detail(f"调试输出图片: {out_path}")
    log_detail("识别到的场站及矩形坐标如下:")
    for item in results:
        log_detail(json.dumps(item, ensure_ascii=False))

    return results


def _collect_ocr_entries(img, w, h, log_detail):
    """OCR ``img`` and return the lines that survive ``_classify_ocr_line``."""
    reader = get_ocr_reader()
    ocr_results = reader.read_text(img)
    log_detail(f"OCR 原始结果数量: {len(ocr_results)}")

    entries = []
    for idx, (quad, text, prob) in enumerate(ocr_results):
        # quad is presumably the list of corner points of the text box
        # (indexed as pts[:, 0] / pts[:, 1]); only its bounding box is used.
        pts = np.array(quad).astype(int)
        x_min = int(np.min(pts[:, 0]))
        y_min = int(np.min(pts[:, 1]))
        x_max = int(np.max(pts[:, 0]))
        y_max = int(np.max(pts[:, 1]))
        cx_norm = (x_min + x_max) / 2.0 / w
        cy_norm = (y_min + y_max) / 2.0 / h

        status, reasons = _classify_ocr_line(text, prob, cy_norm)

        log_detail(
            f"OCR[{idx + 1}] text={repr(text)} prob={prob:.3f} "
            f"cx_norm={cx_norm:.4f} cy_norm={cy_norm:.4f} "
            f"status={status} reasons={','.join(reasons) if reasons else '-'}"
        )

        if status == "keep":
            entries.append(
                {
                    "text": text,
                    "prob": float(prob),
                    "cx_norm": cx_norm,
                    "cy_norm": cy_norm,
                }
            )
    return entries


def _classify_ocr_line(text, prob, cy_norm):
    """Decide whether one OCR line is forwarded to the LLM.

    Returns ``(status, reasons)``: status is ``"keep"`` or ``"drop"`` and
    reasons lists every rule that fired (used only for logging).

    Drop rules: confidence < 0.3, empty text, centre inside the top/bottom
    safe zone, or (for otherwise-kept lines) a NON_STATION_KEYWORDS hit.
    Rescue rules for dropped lines:
      * low confidence only, text looks like a station title;
      * low confidence only, text looks like "idle/total" (contains '/' and
        a digit);
      * inside the top safe zone but within 0.06 of its edge, with a
        TOP_ZONE_STATION_HINT_KEYWORDS hit.
    The first two rescues never apply inside a safe zone. The title rescue
    used to ignore the zones, which let low-confidence chrome such as
    "扫码充电" in the bottom bar through via the "充电" keyword.
    """
    status = "keep"
    reasons = []
    if prob < 0.3:
        status = "drop"
        reasons.append("prob<0.3")
    if not text:
        status = "drop"
        reasons.append("empty_text")
    if cy_norm < SAFE_EXCLUDE_RATIO:
        status = "drop"
        reasons.append("top_safe_zone")
    if cy_norm > (1 - BOTTOM_SAFE_EXCLUDE_RATIO):
        status = "drop"
        reasons.append("bottom_safe_zone")

    if status == "keep" and any(kw and kw in text for kw in NON_STATION_KEYWORDS):
        status = "drop"
        reasons.append("non_station_keyword")

    in_safe_zone = "top_safe_zone" in reasons or "bottom_safe_zone" in reasons

    def low_conf_rescuable():
        return status == "drop" and "prob<0.3" in reasons and bool(text) and not in_safe_zone

    if low_conf_rescuable() and any(kw and kw in text for kw in STATION_TITLE_KEYWORDS):
        status = "keep"
        reasons.append("force_keep_station_title")

    if low_conf_rescuable() and "/" in text and any(ch.isdigit() for ch in text):
        status = "keep"
        reasons.append("force_keep_busy_or_price_pattern")

    if (
        status == "drop"
        and "top_safe_zone" in reasons
        and text
        and SAFE_EXCLUDE_RATIO - 0.06 <= cy_norm < SAFE_EXCLUDE_RATIO
        and any(kw and kw in text for kw in TOP_ZONE_STATION_HINT_KEYWORDS)
    ):
        status = "keep"
        reasons.append("force_keep_top_station_title")

    return status, reasons


async def _query_llm_for_stations(entries, log_detail, log_lines):
    """Send the filtered OCR lines to the LLM; return the JSON text extracted from its reply."""
    indexed_entries = [
        {
            "id": idx,
            "text": e["text"],
            "prob": e["prob"],
            "cx_norm": round(e["cx_norm"], 4),
            "cy_norm": round(e["cy_norm"], 4),
        }
        for idx, e in enumerate(entries, start=1)
    ]
    payload_json = json.dumps(indexed_entries, ensure_ascii=False)
    log_detail(f"传给 LLM 的 OCR 条目数: {len(indexed_entries)}")

    chunks = []
    async for part in get_llm_response(
        query_text=_build_station_query(payload_json),
        stream=False,
        system_prompt="你是一个帮助整理 OCR 文本的助手,只输出 JSON。",
        chat_history=None,
        temperature=0,
    ):
        chunks.append(part)

    full_text = "".join(chunks)
    log_detail("LLM 原始返回内容开始")
    # The raw reply goes only to the detail log, not to `logger`.
    log_lines.append(full_text)
    log_detail("LLM 原始返回内容结束")

    raw = _extract_json(full_text)
    log_detail(f"从 LLM 返回内容中抽取出的 JSON 片段: {raw}")
    return raw


def _build_station_query(payload_json):
    """Build the LLM prompt that asks for station cards from ``payload_json``."""
    return (
        "下面是驿来特列表页整张截图的 OCR 结果,每一项代表一行文字,包含其中心点的归一化坐标:\n"
        "ocr_items = " + payload_json + "\n"
        "请你根据这些文本,将它们聚合成若干个“充电场站卡片”。输出一个 JSON 数组,每个元素必须包含:\n"
        "1) station_name: 场站名称,只能是卡片标题中的名称,不允许是筛选标签、导航按钮、底部功能区等。\n"
        "2) anchor_point_norm: 一个对象 {\"x\": number, \"y\": number},表示该场站名称文字所在行的中心点坐标,取值范围 0-1。\n"
        "并且尽量补充以下可选字段(找不到时可以省略或设为 null):\n"
        "3) distance_text: 距离字符串,例如 \"6.9km\"、\"500m\",从对应卡片中的距离行提取。\n"
        "4) busy_list: 忙闲信息数组,数组中的每一项是 {\"mode\": \"快|慢|超|普通\", \"idle\": number, \"total\": number}。\n"
        " 在驿来特 UI 中,忙闲文本通常是 \"2/16\"、\"34/54\"、\"1/2\" 这种形式,\n"
        " 有时前面会带有“快”“慢”“超”等模式字样,有时则只有数字,语义都是“当前空闲/总数”。\n"
        " 例如:\"2/16\"(快充) => {\"mode\": \"快\", \"idle\": 2, \"total\": 16};\n"
        " \"34/54\" => {\"mode\": \"快\", \"idle\": 34, \"total\": 54};\n"
        " \"1/2\"(交流慢充) => {\"mode\": \"慢\", \"idle\": 1, \"total\": 2}。\n"
        " 如果没有明确模式字样,可以根据上下文合理猜测一个模式(通常为\"快\"或\"慢\"),但 mode 字段必须填写。\n"
        "额外提示:\n"
        "- 每个场站卡片通常包含一行类似 \"1.4km\"、\"3.6km\" 的距离文本;\n"
        "- 该距离文本所在行的左侧、且在同一卡片中的那一行文字,就是对应的场站标题 station_name;\n"
        "- 忙闲信息通常出现在卡片右侧的彩色小块中,例如 \"2/16\"、\"34/54\"、\"1/2\" 等;\n"
        "- 即使 station_name 中不包含“充电站”“超快充”等字样(例如包含“交流站”),但只要与某个 \"x.xkm\" 行在同一卡片区域内,也应视为一个完整的场站名称。\n"
        "要求:\n"
        "- 场站按从上到下排序;\n"
        "- station_name 不能取距离行本身(如 \"1.4km\"),而是要取与之成一对的标题行;\n"
        "- 如果某些 OCR 文本显然不属于任何场站卡片,可以忽略;\n"
        "- 不要输出用于标记结束的 'END'、'end'、'结束' 等占位元素,数组中每个元素都必须是真实存在的充电场站;\n"
        "- 只输出 JSON 数组,不要输出其它任何文字。"
    )


def _to_int_or_none(value):
    """Return ``int(value)``, or ``None`` when value is None or not convertible."""
    if value is None:
        return None
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        return None


def _parse_busy_list(raw):
    """Normalize the LLM busy info (a list of dicts or a single dict) into a list of dicts."""
    if isinstance(raw, dict):
        raw = [raw]
    elif not isinstance(raw, list):
        return []
    return [
        {
            "mode": entry.get("mode") or entry.get("type"),
            "idle": _to_int_or_none(entry.get("idle")),
            "total": _to_int_or_none(entry.get("total")),
        }
        for entry in raw
        if isinstance(entry, dict)
    ]


def _normalize_llm_station(idx, item, w, h, log_detail):
    """Validate one LLM array element; return a station dict or None to skip it."""
    if not isinstance(item, dict):
        log_detail(f"LLM item[{idx}] 不是对象类型, 跳过")
        return None

    name = item.get("station_name") or item.get("name")
    # A non-string name (e.g. a number) previously crashed `kw in name`.
    if name and not isinstance(name, str):
        name = str(name)
    if name:
        for kw in LLM_NON_STATION_NAME_KEYWORDS:
            if kw and kw in name:
                log_detail(
                    f"LLM item[{idx}] station_name 命中非场站关键词 {kw}, 丢弃, name={name}"
                )
                return None

    anchor = item.get("anchor_point_norm") or item.get("anchor") or item.get("center_norm")
    if not anchor:
        log_detail(f"LLM item[{idx}] 缺少 anchor 信息, 跳过, content={item}")
        return None
    try:
        ax = float(anchor.get("x"))
        ay = float(anchor.get("y"))
    except (AttributeError, TypeError, ValueError):
        log_detail(f"LLM item[{idx}] anchor 不可解析, content={anchor}")
        return None
    if not (0 <= ax <= 1 and 0 <= ay <= 1):
        log_detail(
            f"LLM item[{idx}] anchor 超出范围, x={ax}, y={ay}, 跳过"
        )
        return None

    busy_raw = (
        item.get("busy_list")
        or item.get("busyInfos")
        or item.get("busy_info")
        or item.get("busy")
    )
    return {
        "name": name,
        "ax": ax,
        "ay": ay,
        "px": int(ax * w),
        # Anchor y is moved up 20 px (clamped at 0) before laying out boxes.
        "py": max(int(ay * h) - 20, 0),
        "distance_text": item.get("distance_text") or item.get("distance"),
        "busy_list": _parse_busy_list(busy_raw),
    }


def _choose_box_height(stations, box_h_conf, log_detail):
    """Pick the box height, shrinking the configured one to the smallest anchor gap when sensible.

    ``stations`` must already be sorted by ``py``.
    """
    if len(stations) < 2:
        log_detail(f"仅一个场站, 使用配置高度 box_h={box_h_conf}")
        return box_h_conf

    gaps = [b["py"] - a["py"] for a, b in zip(stations, stations[1:]) if b["py"] - a["py"] > 0]
    if not gaps:
        # This message was missing its f-prefix and logged "{box_h_conf}" literally.
        log_detail(f"未找到有效间距, 使用配置高度 box_h={box_h_conf}")
        return box_h_conf

    min_gap = min(gaps)
    max_no_overlap = max(min_gap - 10, 40)
    if max_no_overlap < box_h_conf * 0.6:
        # Shrinking that much would be too small; keep the configured height
        # and let the per-box overlap handling push boxes apart instead.
        log_detail(
            f"场站间距过小(min_gap={min_gap}), 使用配置高度 box_h={box_h_conf}"
        )
        return box_h_conf

    box_h = min(box_h_conf, max_no_overlap)
    log_detail(
        f"根据最小锚点间距调整 box_h: conf={box_h_conf}, "
        f"min_gap={min_gap}, final={box_h}"
    )
    return box_h


def _place_station_box(idx, station, box_h, h, effective_top, effective_bottom,
                       prev_y2, is_last, log_detail):
    """Compute one station box's vertical span; return None when it must be dropped.

    Returns ``(y1, y2, orig_y1, orig_y2)``: the adjusted span and the span
    before clamping/overlap handling (the latter only for logging).
    """
    name = station["name"]
    ay = station["ay"]
    py = station["py"]

    if py < effective_top:
        # Anchors slightly above the usable area (within 0.05 of the top
        # safe ratio) are pinned to its top edge; higher ones are dropped.
        if ay >= SAFE_EXCLUDE_RATIO - 0.05:
            log_detail(
                f"Station[{idx + 1}] {name} 锚点 py={py} 稍微高于有效区域, 允许保留并从 effective_top 开始绘制"
            )
            y1 = effective_top
        else:
            log_detail(
                f"Station[{idx + 1}] {name} 锚点 py={py} 位于顶部保护区之上(effective_top={effective_top}), 丢弃"
            )
            return None
    else:
        # Start the box 10% of its height above the anchor line.
        y1 = py - int(box_h * 0.1)
    y2 = y1 + box_h
    orig_y1, orig_y2 = y1, y2

    # Keep the box inside the image.
    if y1 < 0:
        y1, y2 = 0, box_h
    if y2 > h:
        y1, y2 = h - box_h, h

    # Keep boxes of in-area anchors below the top safe zone.
    if ay >= SAFE_EXCLUDE_RATIO and y1 < effective_top:
        y1 = effective_top
        y2 = y1 + box_h
        if y2 > h:
            y1, y2 = h - box_h, h

    # Keep the box above the bottom safe zone.
    if y2 > effective_bottom:
        y2 = effective_bottom
        y1 = y2 - box_h
        if y1 < 0:
            y1, y2 = 0, box_h

    # Push the box below the previous one so boxes never overlap.
    if prev_y2 is not None and y1 <= prev_y2:
        shift = prev_y2 - y1 + 1
        y1 += shift
        y2 += shift
        if y2 > effective_bottom:
            if not is_last:
                log_detail(
                    f"Station[{idx + 1}] {name} 因避免重叠无法放入有效区域, 被丢弃"
                )
                return None
            # The last card may be shortened to fit, but not below half height.
            new_y1 = prev_y2 + 1
            new_y2 = effective_bottom
            if new_y2 - new_y1 < int(box_h * 0.5):
                log_detail(
                    f"Station[{idx + 1}] {name} 底部剩余空间不足以放置绿框, 被丢弃"
                )
                return None
            y1, y2 = new_y1, new_y2

    return y1, y2, orig_y1, orig_y2


def _write_detail_log(image_path, log_path, log_lines, log_detail):
    """Write the buffered detail lines to disk and return the file path used."""
    logs_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "Logs")
    os.makedirs(logs_dir, exist_ok=True)
    if log_path is None:
        # splitext handles any extension and case (".JPG", ".png"); the old
        # str.replace(".jpg", ".log") left such names unchanged.
        stem = os.path.splitext(os.path.basename(image_path))[0]
        final_log_path = os.path.join(logs_dir, stem + ".log")
        file_mode = "w"
    else:
        final_log_path = log_path
        file_mode = "a"
    with open(final_log_path, file_mode, encoding="utf-8") as f:
        f.writelines(line + "\n" for line in log_lines)
    log_detail(f"已写入详细日志到: {final_log_path}")
    return final_log_path


async def run_batch_in_dir(base_dir, log_path=None):
    """Run ``run_ocr_rect`` on every source JPG in ``base_dir``, one at a time.

    Candidates are directory entries whose name ends in ``.jpg``
    (case-insensitive). The ``*_ocr_rect.jpg`` debug images written by
    ``run_ocr_rect`` are excluded. Files are processed sequentially in sorted
    name order, and ``log_path`` is forwarded unchanged to each call.
    Returns ``None``.
    """
    def is_source_jpg(entry_name):
        lowered = entry_name.lower()
        return lowered.endswith(".jpg") and not lowered.endswith("_ocr_rect.jpg")

    jpg_files = [
        os.path.join(base_dir, entry_name)
        for entry_name in sorted(os.listdir(base_dir))
        if is_source_jpg(entry_name)
    ]

    if not jpg_files:
        logger.info(f"目录中未找到待处理的 JPG 文件: {base_dir}")
        return

    logger.info(f"即将批量处理 {len(jpg_files)} 张图片:")
    for image_file in jpg_files:
        logger.info(f" - {image_file}")

    for image_file in jpg_files:
        await run_ocr_rect(image_file, log_path=log_path)


if __name__ == "__main__":
    # Batch-process the screenshots stored next to this script.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    asyncio.run(run_batch_in_dir(script_dir))