447 lines
16 KiB
Python
447 lines
16 KiB
Python
import asyncio
|
||
import json
|
||
import os
|
||
import sys
|
||
|
||
import cv2
|
||
import numpy as np
|
||
|
||
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||
if project_root not in sys.path:
|
||
sys.path.append(project_root)
|
||
|
||
from Apps.TeLaiDian.Kit import setup_logger, get_ocr_reader, draw_rectangles
|
||
from Apps.TeLaiDian.Config.Setting import SAFE_EXCLUDE_RATIO, BOTTOM_SAFE_EXCLUDE_RATIO
|
||
from Util.LlmUtil import get_llm_response
|
||
|
||
|
||
logger = setup_logger("TeLaiDian.FirstPageKit")
|
||
TEXT_TOP_RATIO = SAFE_EXCLUDE_RATIO * 0.8
|
||
|
||
NON_STATION_KEYWORDS = [
|
||
"地图",
|
||
"目的地",
|
||
"电站名",
|
||
"充电礼",
|
||
"再充",
|
||
"注册礼",
|
||
"元券",
|
||
"PLUS会员",
|
||
"我的收藏",
|
||
"最近充电",
|
||
"我的卡券",
|
||
"我的订单",
|
||
"充电券",
|
||
"电信积分兑换",
|
||
"确认",
|
||
"广告",
|
||
"距离/区域",
|
||
"综合排序",
|
||
"偏好",
|
||
"星级站",
|
||
"停车减免",
|
||
"重卡可用",
|
||
"首页",
|
||
"特省钱",
|
||
"扫码",
|
||
"输入",
|
||
"商城",
|
||
"推荐",
|
||
]
|
||
|
||
def _load_image(path):
|
||
if not os.path.exists(path):
|
||
raise FileNotFoundError(path)
|
||
img = cv2.imread(path)
|
||
if img is None:
|
||
raise RuntimeError(f"无法读取图片: {path}")
|
||
h, w = img.shape[:2]
|
||
return img, w, h
|
||
|
||
|
||
def _extract_json(text: str) -> str:
|
||
if not text:
|
||
return "[]"
|
||
cleaned = text.strip()
|
||
if "```" in cleaned:
|
||
lines = []
|
||
for line in cleaned.splitlines():
|
||
if line.strip().startswith("```"):
|
||
continue
|
||
lines.append(line)
|
||
cleaned = "\n".join(lines).strip()
|
||
decoder = json.JSONDecoder()
|
||
pos = 0
|
||
while pos < len(cleaned):
|
||
idx_dict = cleaned.find("{", pos)
|
||
idx_list = cleaned.find("[", pos)
|
||
candidates = [i for i in (idx_dict, idx_list) if i != -1]
|
||
if not candidates:
|
||
break
|
||
start = min(candidates)
|
||
snippet = cleaned[start:]
|
||
try:
|
||
_, end = decoder.raw_decode(snippet)
|
||
return snippet[:end]
|
||
except json.JSONDecodeError:
|
||
pos = start + 1
|
||
continue
|
||
return "[]"
|
||
|
||
|
||
async def run_ocr_rect(image_path, log_path=None):
|
||
log_lines = []
|
||
|
||
def log_detail(msg):
|
||
logger.info(msg)
|
||
log_lines.append(msg)
|
||
|
||
img, w, h = _load_image(image_path)
|
||
log_detail(f"开始处理图片: {image_path}, 宽={w}, 高={h}")
|
||
|
||
reader = get_ocr_reader()
|
||
ocr_results = reader.read_text(img)
|
||
log_detail(f"OCR 原始结果数量: {len(ocr_results)}")
|
||
|
||
entries = []
|
||
for idx, (quad, text, prob) in enumerate(ocr_results):
|
||
pts = np.array(quad).astype(int)
|
||
x_min = int(np.min(pts[:, 0]))
|
||
y_min = int(np.min(pts[:, 1]))
|
||
x_max = int(np.max(pts[:, 0]))
|
||
y_max = int(np.max(pts[:, 1]))
|
||
cx = (x_min + x_max) / 2.0
|
||
cy = (y_min + y_max) / 2.0
|
||
cx_norm = cx / w
|
||
cy_norm = cy / h
|
||
|
||
status = "keep"
|
||
reasons = []
|
||
txt = text or ""
|
||
if prob < 0.3:
|
||
status = "drop"
|
||
reasons.append("prob<0.3")
|
||
if not txt:
|
||
status = "drop"
|
||
reasons.append("empty_text")
|
||
has_station_kw = False
|
||
has_distance_kw = False
|
||
if txt:
|
||
if ("充电站" in txt) or ("超快充" in txt) or ("快充" in txt) or ("慢充" in txt):
|
||
has_station_kw = True
|
||
if ("km" in txt) or ("m" in txt):
|
||
has_distance_kw = True
|
||
if cy_norm < TEXT_TOP_RATIO and not (
|
||
(has_station_kw and len(txt) >= 4) or has_distance_kw
|
||
):
|
||
status = "drop"
|
||
reasons.append("top_safe_zone")
|
||
if cy_norm > (1 - BOTTOM_SAFE_EXCLUDE_RATIO):
|
||
status = "drop"
|
||
reasons.append("bottom_safe_zone")
|
||
if status == "keep" and txt:
|
||
for kw in NON_STATION_KEYWORDS:
|
||
if kw and kw in txt:
|
||
status = "drop"
|
||
reasons.append("non_station_keyword")
|
||
break
|
||
|
||
log_detail(
|
||
f"OCR[{idx + 1}] text={repr(text)} prob={prob:.3f} "
|
||
f"cx_norm={cx_norm:.4f} cy_norm={cy_norm:.4f} "
|
||
f"status={status} reasons={','.join(reasons) if reasons else '-'}"
|
||
)
|
||
|
||
if status != "keep":
|
||
continue
|
||
|
||
entries.append(
|
||
{
|
||
"text": text,
|
||
"prob": float(prob),
|
||
"cx_norm": cx_norm,
|
||
"cy_norm": cy_norm,
|
||
}
|
||
)
|
||
|
||
log_detail(f"OCR 通过过滤的有效文本数量: {len(entries)}")
|
||
|
||
if not entries:
|
||
log_detail("无有效 OCR 文本, 结束当前图片处理")
|
||
return
|
||
|
||
indexed_entries = []
|
||
for idx, e in enumerate(entries):
|
||
indexed_entries.append(
|
||
{
|
||
"id": idx + 1,
|
||
"text": e["text"],
|
||
"prob": e["prob"],
|
||
"cx_norm": round(e["cx_norm"], 4),
|
||
"cy_norm": round(e["cy_norm"], 4),
|
||
}
|
||
)
|
||
|
||
payload_json = json.dumps(indexed_entries, ensure_ascii=False)
|
||
log_detail(f"传给 LLM 的 OCR 条目数: {len(indexed_entries)}")
|
||
|
||
query_text = (
|
||
"下面是特来电列表页整张截图的 OCR 结果,每一项代表一行文字,包含其中心点的归一化坐标:\n"
|
||
"ocr_items = " + payload_json + "\n"
|
||
"请你根据这些文本,将它们聚合成若干个“充电场站卡片”。输出一个 JSON 数组,每个元素必须包含:\n"
|
||
"1) station_name: 场站名称,只能是卡片标题中的名称,不允许是筛选标签、导航按钮、底部功能区等。\n"
|
||
"2) anchor_point_norm: 一个对象 {\"x\": number, \"y\": number},表示该场站名称文字所在行的中心点坐标,取值范围 0-1。\n"
|
||
"并且尽量补充以下可选字段(找不到时可以省略或设为 null):\n"
|
||
"3) distance_text: 距离字符串,例如 \"6.9km\"、\"500m\",从对应卡片中的距离行提取。\n"
|
||
"4) busy_list: 忙闲信息数组,数组中的每一项是 {\"mode\": \"快|慢|超|普通\", \"idle\": number, \"total\": number}。\n"
|
||
" 在特来电 UI 中,忙闲文本通常是 \"空闲x/总y\" 或 \"快 空闲x/总y\" 这种形式,\n"
|
||
" 请从相应行中解析出模式和空闲/总数。\n"
|
||
"额外提示:\n"
|
||
"- 每个场站卡片通常包含一行类似 \"1.4km\"、\"3.6km\" 的距离文本;\n"
|
||
"- 该距离文本所在行的左侧、且在同一卡片中的那一行文字,就是对应的场站标题 station_name;\n"
|
||
"- 忙闲信息通常出现在卡片右侧的彩色小块中,例如 \"快 空闲24/32\"、\"慢 空闲0/10\" 等;\n"
|
||
"要求:\n"
|
||
"- 场站按从上到下排序;\n"
|
||
"- station_name 不能取距离行本身(如 \"1.4km\"),而是要取与之成一对的标题行;\n"
|
||
"- 如果某些 OCR 文本显然不属于任何场站卡片,可以忽略;\n"
|
||
"- 只输出 JSON 数组,不要输出其它任何文字。"
|
||
)
|
||
|
||
chunks = []
|
||
async for part in get_llm_response(
|
||
query_text=query_text,
|
||
stream=False,
|
||
system_prompt="你是一个帮助整理 OCR 文本的助手,只输出 JSON。",
|
||
chat_history=None,
|
||
temperature=0,
|
||
):
|
||
chunks.append(part)
|
||
|
||
full_text = "".join(chunks)
|
||
log_detail("LLM 原始返回内容开始")
|
||
log_lines.append(full_text)
|
||
log_detail("LLM 原始返回内容结束")
|
||
|
||
raw = _extract_json(full_text)
|
||
log_detail(f"从 LLM 返回内容中抽取出的 JSON 片段: {raw}")
|
||
|
||
try:
|
||
data = json.loads(raw)
|
||
except Exception as e:
|
||
log_detail(f"解析 LLM 返回 JSON 失败: {e}")
|
||
logs_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "Logs")
|
||
os.makedirs(logs_dir, exist_ok=True)
|
||
if log_path is None:
|
||
final_log_path = os.path.join(
|
||
logs_dir, os.path.basename(image_path).replace(".jpg", ".log")
|
||
)
|
||
mode = "w"
|
||
else:
|
||
final_log_path = log_path
|
||
mode = "a"
|
||
with open(final_log_path, mode, encoding="utf-8") as f:
|
||
for line in log_lines:
|
||
f.write(line + "\n")
|
||
log_detail(f"已写入详细日志到: {final_log_path}")
|
||
return
|
||
|
||
if not isinstance(data, list):
|
||
log_detail(f"期望 LLM 返回 JSON 数组, 实际类型: {type(data)}")
|
||
return
|
||
|
||
if not data:
|
||
log_detail("LLM 返回空数组, 结束当前图片处理")
|
||
return
|
||
|
||
stations = []
|
||
for idx, item in enumerate(data):
|
||
if not isinstance(item, dict):
|
||
log_detail(f"LLM item[{idx}] 不是对象类型, 跳过")
|
||
continue
|
||
name = item.get("station_name") or item.get("name")
|
||
anchor = item.get("anchor_point_norm") or item.get("anchor")
|
||
distance = item.get("distance_text") or item.get("distance")
|
||
busy_list = item.get("busy_list") or []
|
||
if not name or not isinstance(anchor, dict):
|
||
log_detail(f"LLM item[{idx}] 缺少必要字段, 跳过: {item}")
|
||
continue
|
||
ax = float(anchor.get("x", 0.5))
|
||
ay = float(anchor.get("y", 0.5))
|
||
if not (0 <= ax <= 1 and 0 <= ay <= 1):
|
||
log_detail(f"LLM item[{idx}] anchor_point_norm 超出范围, 跳过: {anchor}")
|
||
continue
|
||
px = int(ax * w)
|
||
py = int(ay * h)
|
||
stations.append(
|
||
{
|
||
"station_name": name,
|
||
"anchor_point_norm": {"x": ax, "y": ay},
|
||
"distance_text": distance,
|
||
"busy_list": busy_list,
|
||
"anchor_px": px,
|
||
"anchor_py": py,
|
||
}
|
||
)
|
||
log_detail(
|
||
f"LLM anchor 规范化[{len(stations)}] name={name} ax={ax:.4f} ay={ay:.4f} "
|
||
f"px={px} py={py} distance={distance} busy_list={busy_list}"
|
||
)
|
||
|
||
if not stations:
|
||
log_detail("LLM 解析后无有效场站, 结束当前图片处理")
|
||
return
|
||
filtered = []
|
||
for s in stations:
|
||
dl = s.get("distance_text")
|
||
bl = s.get("busy_list") or []
|
||
ok_dist = isinstance(dl, str) and (("km" in dl) or ("m" in dl))
|
||
ok_busy = isinstance(bl, list) and len(bl) > 0
|
||
if ok_dist:
|
||
filtered.append(s)
|
||
if not ok_busy:
|
||
log_detail(f"场站缺少忙闲信息但保留: {s.get('station_name')}")
|
||
else:
|
||
log_detail(f"丢弃缺少距离信息的条目: {s.get('station_name')}")
|
||
stations = filtered
|
||
if not stations:
|
||
log_detail("过滤后无有效场站, 结束当前图片处理")
|
||
return
|
||
stations.sort(key=lambda s: s["anchor_py"])
|
||
|
||
min_gap = None
|
||
if len(stations) >= 2:
|
||
gaps = []
|
||
for i in range(1, len(stations)):
|
||
gaps.append(stations[i]["anchor_py"] - stations[i - 1]["anchor_py"])
|
||
min_gap = min(gaps)
|
||
box_h = min_gap
|
||
if box_h < int(h * 0.20):
|
||
box_h = int(h * 0.20)
|
||
if box_h > int(h * 0.32):
|
||
box_h = int(h * 0.32)
|
||
log_detail(f"根据最小锚点间距调整 box_h: min_gap={min_gap}, final={box_h}")
|
||
else:
|
||
box_h = int(h * 0.22)
|
||
log_detail(f"仅有一个场站, 使用默认 box_h={box_h}")
|
||
|
||
box_w = int(w * 0.90)
|
||
x1_fixed = int((w - box_w) / 2)
|
||
x2_fixed = x1_fixed + box_w
|
||
|
||
result = []
|
||
rects = []
|
||
click_points = []
|
||
effective_top = int(h * SAFE_EXCLUDE_RATIO)
|
||
effective_bottom = int(h * (1 - BOTTOM_SAFE_EXCLUDE_RATIO))
|
||
prev_y2 = None
|
||
anchor_ratio = 0.15
|
||
for idx, st in enumerate(stations, start=1):
|
||
name = st.get("station_name")
|
||
py = st["anchor_py"]
|
||
anchor_norm = st.get("anchor_point_norm") or {}
|
||
try:
|
||
ay = float(anchor_norm.get("y", py / h))
|
||
except Exception:
|
||
ay = py / h
|
||
y1 = int(py - box_h * anchor_ratio)
|
||
y2 = y1 + box_h
|
||
if min_gap is not None:
|
||
downward_limit = int(min_gap * 0.75)
|
||
max_y2 = py + downward_limit
|
||
if y2 > max_y2:
|
||
y2 = max_y2
|
||
y1 = y2 - box_h
|
||
if y1 < 0:
|
||
y1 = 0
|
||
y2 = y1 + box_h
|
||
if y2 > h:
|
||
y2 = h
|
||
y1 = y2 - box_h
|
||
if y2 > effective_bottom:
|
||
y2 = effective_bottom
|
||
y1 = y2 - box_h
|
||
if prev_y2 is not None and y1 <= prev_y2:
|
||
shift = prev_y2 - y1 + 1
|
||
y1 += shift
|
||
y2 += shift
|
||
if y2 > effective_bottom:
|
||
if idx == len(stations):
|
||
new_y1 = prev_y2 + 1
|
||
new_y2 = effective_bottom
|
||
if new_y2 <= new_y1:
|
||
log_detail(f"底部空间不足,丢弃: {st.get('station_name')}")
|
||
continue
|
||
y1 = new_y1
|
||
y2 = new_y2
|
||
else:
|
||
log_detail(f"避免重叠无法放置,丢弃: {st.get('station_name')}")
|
||
continue
|
||
rect = [x1_fixed, y1, x2_fixed, y2]
|
||
cx = int((rect[0] + rect[2]) / 2)
|
||
cy = int((rect[1] + rect[3]) / 2)
|
||
|
||
click_points.append([cx, cy])
|
||
rects.append(rect)
|
||
prev_y2 = y2
|
||
|
||
item = {
|
||
"index": idx,
|
||
"station_name": st["station_name"],
|
||
"rect": rect,
|
||
"click_point": [cx, cy],
|
||
"distance_text": st.get("distance_text"),
|
||
"busy_list": st.get("busy_list") or [],
|
||
}
|
||
result.append(item)
|
||
log_detail(
|
||
f"Station[{idx}] name={item['station_name']} "
|
||
f"rect={rect} click={item['click_point']} "
|
||
f"distance={item['distance_text']} busy_list={item['busy_list']}"
|
||
)
|
||
|
||
try:
|
||
draw_rectangles(image_path, rects, click_points, save_vl=False)
|
||
except Exception as e:
|
||
log_detail(f"绘制调试矩形失败: {e}")
|
||
|
||
logs_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "Logs")
|
||
os.makedirs(logs_dir, exist_ok=True)
|
||
if log_path is None:
|
||
final_log_path = os.path.join(
|
||
logs_dir, os.path.basename(image_path).replace(".jpg", ".log")
|
||
)
|
||
mode = "w"
|
||
else:
|
||
final_log_path = log_path
|
||
mode = "a"
|
||
with open(final_log_path, mode, encoding="utf-8") as f:
|
||
for line in log_lines:
|
||
f.write(line + "\n")
|
||
log_detail(f"已写入详细日志到: {final_log_path}")
|
||
|
||
return result
|
||
|
||
|
||
async def run_batch_in_dir(image_dir, log_file=None):
|
||
img_files = []
|
||
for name in os.listdir(image_dir):
|
||
lower = name.lower()
|
||
if not (lower.endswith(".jpg") or lower.endswith(".png")):
|
||
continue
|
||
if "_flag" in lower or "_vl" in lower:
|
||
continue
|
||
img_files.append(os.path.join(image_dir, name))
|
||
img_files.sort()
|
||
|
||
if not img_files:
|
||
logger.info(f"目录下未找到图片: {image_dir}")
|
||
return
|
||
|
||
for idx, path in enumerate(img_files, start=1):
|
||
logger.info(f"[批处理] 开始处理第 {idx} 张图片: {path}")
|
||
try:
|
||
await run_ocr_rect(path, log_path=log_file)
|
||
except Exception as e:
|
||
logger.exception(f"[批处理] 处理图片失败: {path}, {e}")
|
||
|