'commit'
BIN
Test/TestFirstPage/Screenshot_20260117_161824.jpg
Normal file
|
After Width: | Height: | Size: 130 KiB |
BIN
Test/TestFirstPage/Screenshot_20260117_161824_vlm_rect.jpg
Normal file
|
After Width: | Height: | Size: 280 KiB |
BIN
Test/TestFirstPage/Screenshot_20260117_161830.jpg
Normal file
|
After Width: | Height: | Size: 130 KiB |
BIN
Test/TestFirstPage/Screenshot_20260117_161830_vlm_rect.jpg
Normal file
|
After Width: | Height: | Size: 297 KiB |
BIN
Test/TestFirstPage/Screenshot_20260117_171622.jpg
Normal file
|
After Width: | Height: | Size: 122 KiB |
BIN
Test/TestFirstPage/Screenshot_20260117_171622_vlm_rect.jpg
Normal file
|
After Width: | Height: | Size: 261 KiB |
BIN
Test/TestFirstPage/Screenshot_20260117_172118.jpg
Normal file
|
After Width: | Height: | Size: 131 KiB |
BIN
Test/TestFirstPage/Screenshot_20260117_172118_vlm_rect.jpg
Normal file
|
After Width: | Height: | Size: 289 KiB |
201
Test/TestFirstPage/T_XinDianTu_VLRectTest.py
Normal file
@@ -0,0 +1,201 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
if project_root not in sys.path:
|
||||
sys.path.append(project_root)
|
||||
|
||||
from Apps.XinDianTu.Kit import setup_logger
|
||||
from Apps.XinDianTu.ReadImageKit import ReadImageKit
|
||||
from Apps.XinDianTu.Config.Setting import SAFE_EXCLUDE_RATIO, BOTTOM_SAFE_EXCLUDE_RATIO
|
||||
from Config.Config import VL_MODEL_NAME
|
||||
|
||||
|
||||
logger = setup_logger("XinDianTu.VLRectTest")
|
||||
|
||||
|
||||
def _load_image(path):
|
||||
if not os.path.exists(path):
|
||||
raise FileNotFoundError(path)
|
||||
img = cv2.imread(path)
|
||||
if img is None:
|
||||
raise RuntimeError(f"无法读取图片: {path}")
|
||||
h, w = img.shape[:2]
|
||||
return img, w, h
|
||||
|
||||
|
||||
def _encode_image_to_data_url(path):
|
||||
with open(path, "rb") as f:
|
||||
encoded = base64.b64encode(f.read()).decode("utf-8")
|
||||
return {"url": f"data:image/jpeg;base64,{encoded}"}
|
||||
|
||||
|
||||
async def run_test(image_path):
|
||||
img, w, h = _load_image(image_path)
|
||||
img_input = _encode_image_to_data_url(image_path)
|
||||
|
||||
prompt = (
|
||||
"这是一张名为“新电途 到哪都能充”的充电站列表页面完整截图。\n"
|
||||
"请你从整张图中找出所有“单个充电站卡片”的区域,并按从上到下的顺序输出一个 JSON 数组。\n"
|
||||
"每个数组元素是一个对象,必须包含:\n"
|
||||
"1. station_name: 该卡片内的充电站名称字符串。\n"
|
||||
"2. bounds_norm: 该卡片在整张图片中的相对矩形坐标,字段为 left, top, right, bottom,"
|
||||
"取值范围均为 0 到 1,表示相对于整张图片宽高的比例。\n"
|
||||
"矩形应尽量完整覆盖卡片,包括名称、价格、桩数、距离等内容。\n"
|
||||
"如果你不完全确定精确边界,也要给出你认为最接近的矩形位置,不要省略。\n"
|
||||
"返回示例:\n"
|
||||
"[\n"
|
||||
" {\"station_name\": \"长春市南关区南溪湿地公园公共充电站\", \"bounds_norm\": {\"left\": 0.03, \"top\": 0.32, \"right\": 0.97, \"bottom\": 0.43}},\n"
|
||||
" {\"station_name\": \"长春环球贸易中心极东站\", \"bounds_norm\": {\"left\": 0.03, \"top\": 0.45, \"right\": 0.97, \"bottom\": 0.56}}\n"
|
||||
"]\n"
|
||||
"只允许输出一个 JSON 数组,不要包含任何解释性文字。"
|
||||
)
|
||||
|
||||
resp = await ReadImageKit._client.chat.completions.create(
|
||||
model=VL_MODEL_NAME,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image_url", "image_url": img_input},
|
||||
{"type": "text", "text": prompt},
|
||||
],
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
content = resp.choices[0].message.content or ""
|
||||
raw = ReadImageKit._extract_json(content)
|
||||
|
||||
try:
|
||||
data = json.loads(raw)
|
||||
except Exception as e:
|
||||
logger.error(f"解析模型返回JSON失败: {e}")
|
||||
logger.error(f"原始内容: {content}")
|
||||
return
|
||||
|
||||
if not isinstance(data, list):
|
||||
logger.error(f"期望返回JSON数组,但得到: {type(data)}")
|
||||
logger.error(f"解析结果: {data}")
|
||||
return
|
||||
|
||||
if not data:
|
||||
logger.info("模型返回空数组。原始内容如下:")
|
||||
logger.info(content)
|
||||
return
|
||||
|
||||
overlay = img.copy()
|
||||
results = []
|
||||
|
||||
for idx, item in enumerate(data):
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
name = item.get("station_name")
|
||||
bounds_norm = item.get("bounds_norm") or {}
|
||||
try:
|
||||
l = float(bounds_norm.get("left"))
|
||||
t = float(bounds_norm.get("top"))
|
||||
r = float(bounds_norm.get("right"))
|
||||
b = float(bounds_norm.get("bottom"))
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if not (0 <= l < r <= 1 and 0 <= t < b <= 1):
|
||||
continue
|
||||
|
||||
x1 = int(l * w)
|
||||
y1 = int(t * h)
|
||||
x2 = int(r * w)
|
||||
y2 = int(b * h)
|
||||
|
||||
center_y = (y1 + y2) / 2.0
|
||||
top_safe_ratio = 0.30
|
||||
bottom_safe_ratio = BOTTOM_SAFE_EXCLUDE_RATIO
|
||||
in_top_safe = center_y < h * top_safe_ratio
|
||||
in_bottom_safe = center_y > h * (1 - bottom_safe_ratio)
|
||||
|
||||
color = (0, 255, 0)
|
||||
click_x = None
|
||||
click_y = None
|
||||
if in_top_safe or in_bottom_safe:
|
||||
color = (0, 0, 255)
|
||||
|
||||
cv2.rectangle(overlay, (x1, y1), (x2, y2), color, 2)
|
||||
|
||||
if not in_top_safe and not in_bottom_safe:
|
||||
click_x = int((x1 + x2) / 2)
|
||||
click_y = int((y1 + y2) / 2)
|
||||
cv2.circle(overlay, (click_x, click_y), 8, (0, 0, 255), -1)
|
||||
|
||||
if name:
|
||||
label = f"{idx + 1}:{name}"
|
||||
cv2.putText(
|
||||
overlay,
|
||||
label[:20],
|
||||
(x1 + 5, max(y1 - 10, 20)),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
0.6,
|
||||
color,
|
||||
2,
|
||||
cv2.LINE_AA,
|
||||
)
|
||||
|
||||
results.append(
|
||||
{
|
||||
"index": idx + 1,
|
||||
"station_name": name,
|
||||
"rect": [x1, y1, x2, y2],
|
||||
"center_y": center_y,
|
||||
"in_top_safe": in_top_safe,
|
||||
"in_bottom_safe": in_bottom_safe,
|
||||
"click_point": [click_x, click_y],
|
||||
}
|
||||
)
|
||||
|
||||
out_path = image_path.replace(".jpg", "_vlm_rect.jpg")
|
||||
cv2.imwrite(out_path, overlay)
|
||||
|
||||
logger.info(f"输入图片: {image_path}")
|
||||
logger.info(f"调试输出图片: {out_path}")
|
||||
logger.info("识别到的场站及矩形坐标如下:")
|
||||
for item in results:
|
||||
logger.info(json.dumps(item, ensure_ascii=False))
|
||||
|
||||
|
||||
async def run_batch():
|
||||
base_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
files = sorted(os.listdir(base_dir))
|
||||
|
||||
jpg_files = []
|
||||
for name in files:
|
||||
if not name.lower().endswith(".jpg"):
|
||||
continue
|
||||
if name.lower().endswith("_vlm_rect.jpg"):
|
||||
continue
|
||||
jpg_files.append(os.path.join(base_dir, name))
|
||||
|
||||
if not jpg_files:
|
||||
logger.info(f"目录中未找到待处理的 JPG 文件: {base_dir}")
|
||||
return
|
||||
|
||||
logger.info(f"即将批量处理 {len(jpg_files)} 张图片:")
|
||||
for p in jpg_files:
|
||||
logger.info(f" - {p}")
|
||||
|
||||
for p in jpg_files:
|
||||
await run_test(p)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 单张测试示例:
|
||||
# test_image = r"d:\dsWork\aiData\Test\TestFirstPage\Screenshot_20260117_161824.jpg"
|
||||
# asyncio.run(run_test(test_image))
|
||||
|
||||
asyncio.run(run_batch())
|
||||
|
||||