This commit is contained in:
HuangHai
2026-01-17 17:36:27 +08:00
parent ec680d1f92
commit 26ec86207b
9 changed files with 201 additions and 0 deletions

Binary file not shown.

After

Width:  |  Height:  |  Size: 130 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 280 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 130 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 297 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 122 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 261 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 131 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 289 KiB

View File

@@ -0,0 +1,201 @@
import asyncio
import base64
import json
import os
import sys
import cv2
import numpy as np
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
if project_root not in sys.path:
sys.path.append(project_root)
from Apps.XinDianTu.Kit import setup_logger
from Apps.XinDianTu.ReadImageKit import ReadImageKit
from Apps.XinDianTu.Config.Setting import SAFE_EXCLUDE_RATIO, BOTTOM_SAFE_EXCLUDE_RATIO
from Config.Config import VL_MODEL_NAME
logger = setup_logger("XinDianTu.VLRectTest")
def _load_image(path):
if not os.path.exists(path):
raise FileNotFoundError(path)
img = cv2.imread(path)
if img is None:
raise RuntimeError(f"无法读取图片: {path}")
h, w = img.shape[:2]
return img, w, h
def _encode_image_to_data_url(path):
with open(path, "rb") as f:
encoded = base64.b64encode(f.read()).decode("utf-8")
return {"url": f"data:image/jpeg;base64,{encoded}"}
async def run_test(image_path):
img, w, h = _load_image(image_path)
img_input = _encode_image_to_data_url(image_path)
prompt = (
"这是一张名为“新电途 到哪都能充”的充电站列表页面完整截图。\n"
"请你从整张图中找出所有“单个充电站卡片”的区域,并按从上到下的顺序输出一个 JSON 数组。\n"
"每个数组元素是一个对象,必须包含:\n"
"1. station_name: 该卡片内的充电站名称字符串。\n"
"2. bounds_norm: 该卡片在整张图片中的相对矩形坐标,字段为 left, top, right, bottom"
"取值范围均为 0 到 1表示相对于整张图片宽高的比例。\n"
"矩形应尽量完整覆盖卡片,包括名称、价格、桩数、距离等内容。\n"
"如果你不完全确定精确边界,也要给出你认为最接近的矩形位置,不要省略。\n"
"返回示例:\n"
"[\n"
" {\"station_name\": \"长春市南关区南溪湿地公园公共充电站\", \"bounds_norm\": {\"left\": 0.03, \"top\": 0.32, \"right\": 0.97, \"bottom\": 0.43}},\n"
" {\"station_name\": \"长春环球贸易中心极东站\", \"bounds_norm\": {\"left\": 0.03, \"top\": 0.45, \"right\": 0.97, \"bottom\": 0.56}}\n"
"]\n"
"只允许输出一个 JSON 数组,不要包含任何解释性文字。"
)
resp = await ReadImageKit._client.chat.completions.create(
model=VL_MODEL_NAME,
messages=[
{
"role": "user",
"content": [
{"type": "image_url", "image_url": img_input},
{"type": "text", "text": prompt},
],
},
],
)
content = resp.choices[0].message.content or ""
raw = ReadImageKit._extract_json(content)
try:
data = json.loads(raw)
except Exception as e:
logger.error(f"解析模型返回JSON失败: {e}")
logger.error(f"原始内容: {content}")
return
if not isinstance(data, list):
logger.error(f"期望返回JSON数组但得到: {type(data)}")
logger.error(f"解析结果: {data}")
return
if not data:
logger.info("模型返回空数组。原始内容如下:")
logger.info(content)
return
overlay = img.copy()
results = []
for idx, item in enumerate(data):
if not isinstance(item, dict):
continue
name = item.get("station_name")
bounds_norm = item.get("bounds_norm") or {}
try:
l = float(bounds_norm.get("left"))
t = float(bounds_norm.get("top"))
r = float(bounds_norm.get("right"))
b = float(bounds_norm.get("bottom"))
except Exception:
continue
if not (0 <= l < r <= 1 and 0 <= t < b <= 1):
continue
x1 = int(l * w)
y1 = int(t * h)
x2 = int(r * w)
y2 = int(b * h)
center_y = (y1 + y2) / 2.0
top_safe_ratio = 0.30
bottom_safe_ratio = BOTTOM_SAFE_EXCLUDE_RATIO
in_top_safe = center_y < h * top_safe_ratio
in_bottom_safe = center_y > h * (1 - bottom_safe_ratio)
color = (0, 255, 0)
click_x = None
click_y = None
if in_top_safe or in_bottom_safe:
color = (0, 0, 255)
cv2.rectangle(overlay, (x1, y1), (x2, y2), color, 2)
if not in_top_safe and not in_bottom_safe:
click_x = int((x1 + x2) / 2)
click_y = int((y1 + y2) / 2)
cv2.circle(overlay, (click_x, click_y), 8, (0, 0, 255), -1)
if name:
label = f"{idx + 1}:{name}"
cv2.putText(
overlay,
label[:20],
(x1 + 5, max(y1 - 10, 20)),
cv2.FONT_HERSHEY_SIMPLEX,
0.6,
color,
2,
cv2.LINE_AA,
)
results.append(
{
"index": idx + 1,
"station_name": name,
"rect": [x1, y1, x2, y2],
"center_y": center_y,
"in_top_safe": in_top_safe,
"in_bottom_safe": in_bottom_safe,
"click_point": [click_x, click_y],
}
)
out_path = image_path.replace(".jpg", "_vlm_rect.jpg")
cv2.imwrite(out_path, overlay)
logger.info(f"输入图片: {image_path}")
logger.info(f"调试输出图片: {out_path}")
logger.info("识别到的场站及矩形坐标如下:")
for item in results:
logger.info(json.dumps(item, ensure_ascii=False))
async def run_batch():
base_dir = os.path.dirname(os.path.abspath(__file__))
files = sorted(os.listdir(base_dir))
jpg_files = []
for name in files:
if not name.lower().endswith(".jpg"):
continue
if name.lower().endswith("_vlm_rect.jpg"):
continue
jpg_files.append(os.path.join(base_dir, name))
if not jpg_files:
logger.info(f"目录中未找到待处理的 JPG 文件: {base_dir}")
return
logger.info(f"即将批量处理 {len(jpg_files)} 张图片:")
for p in jpg_files:
logger.info(f" - {p}")
for p in jpg_files:
await run_test(p)
if __name__ == "__main__":
# 单张测试示例:
# test_image = r"d:\dsWork\aiData\Test\TestFirstPage\Screenshot_20260117_161824.jpg"
# asyncio.run(run_test(test_image))
asyncio.run(run_batch())