Files
aiData/Test/TestOcrLlm.py

174 lines
5.5 KiB
Python
Raw Normal View History

2026-01-12 07:49:18 +08:00
# coding=utf-8
import os
import sys
import time
import asyncio
import json
# Make the project root (one directory above this script) importable so the
# project-local `Util` package resolves when this file is run directly.
current_dir = os.path.split(os.path.abspath(__file__))[0]
project_root = os.path.split(current_dir)[0]
# Appended, not prepended: entries already on sys.path keep precedence.
if project_root not in sys.path:
    sys.path.append(project_root)
from paddleocr import PaddleOCR
from Util.LlmUtil import get_llm_response
from Util.OcrParser import OcrParser
import re
# Debug log written next to this script; truncated at the start of each run by main().
LOG_FILE = os.path.join(current_dir, "ocr_llm_debug.txt")


def log(msg):
    """Print *msg* to stdout and append it as one line to ``LOG_FILE``.

    Non-string messages are converted with ``str()`` once, so the console
    and the log file always receive the same text.
    """
    text = str(msg)
    print(text, flush=True)
    # Re-open in append mode on every call so each line is written out and the
    # file closed immediately, keeping the log intact if the run crashes later.
    with open(LOG_FILE, "a", encoding="utf-8") as f:
        f.write(text + "\n")
def _extract_ocr_lines(result):
    """Return the recognized text lines from a PaddleOCR ``ocr()`` result.

    Handles the result shapes this script has seen:
      * a dict-like page with a ``'rec_texts'`` key,
      * an object exposing a ``rec_texts`` attribute,
      * the legacy list of ``[box, (text, score)]`` entries.

    Logs and returns an empty list when no text can be extracted.
    """
    if not result:
        log("OCR returned empty result.")
        return []
    page = result[0]
    if page is None:
        log("OCR result[0] is None.")
        return []
    if hasattr(page, 'get') and 'rec_texts' in page:
        texts = page.get('rec_texts')
        return list(texts) if texts is not None else []
    if hasattr(page, 'rec_texts'):
        texts = page.rec_texts
        return list(texts) if texts is not None else []
    if isinstance(page, list):
        # Legacy format: each entry is [box, (text, score)].
        return [
            entry[1][0]
            for entry in page
            if entry is not None and len(entry) >= 2 and isinstance(entry[1], (tuple, list))
        ]
    log(f"Unrecognized OCR result structure: {type(page).__name__}")
    return []


def run_ocr_sync():
    """Run PaddleOCR on the sample image next to this script and log the text.

    Uses ``2.jpg`` in this script's directory, falling back to ``1.jpg`` when
    ``2.jpg`` does not exist.

    Returns:
        ``(ocr_text_lines, ocr_text_block, (t_ocr_start, t_ocr_end))`` on
        success, where ``ocr_text_block`` is the lines joined with newlines
        and the timestamps come from ``time.time()``. On any failure the same
        three-element shape is returned with every element ``None``, so
        callers can always unpack three values.
    """
    failure = (None, None, None)

    image_path = os.path.join(current_dir, "2.jpg")
    if not os.path.exists(image_path):
        image_path = os.path.join(current_dir, "1.jpg")
    log(f"Testing OCR + LLM Pipeline on: {image_path}")
    log("-" * 50)
    if not os.path.exists(image_path):
        # Bail out before the (slow) model initialization.
        log(f"Test image not found: {image_path}")
        return failure

    # --- Step 1: PaddleOCR ---
    log("Initializing PaddleOCR...")
    t_init_start = time.time()
    try:
        # Try the lightweight (mobile) models to improve speed; per the
        # original note, ocr_version='PP-OCRv4' usually defaults to mobile.
        ocr = PaddleOCR(use_textline_orientation=True, lang="ch", ocr_version='PP-OCRv4')
    except Exception as e:
        log(f"PaddleOCR Init Failed: {e}")
        return failure
    log(f"PaddleOCR Init Time: {time.time() - t_init_start:.4f}s")

    log("Running OCR Inference...")
    t_ocr_start = time.time()
    try:
        result = ocr.ocr(image_path)
    except Exception as e:
        log(f"OCR Inference Failed: {e}")
        return failure
    t_ocr_end = time.time()

    ocr_text_lines = _extract_ocr_lines(result)
    ocr_text_block = "\n".join(ocr_text_lines)
    log(f"OCR Result ({t_ocr_end - t_ocr_start:.4f}s):")
    log(ocr_text_block)
    log("-" * 50)
    return ocr_text_lines, ocr_text_block, (t_ocr_start, t_ocr_end)
def _build_llm_prompt(ocr_text_block):
    """Return the extraction prompt for the LLM, embedding *ocr_text_block*."""
    # NOTE(review): the first example tag below is an empty string (""); a
    # character (possibly an emoji) appears to have been lost — confirm it.
    return f"""
You are a data extraction assistant. Below is the OCR text recognized from a charging station list card.
Please extract the structured data and return it ONLY as a JSON object (no markdown, no extra text).
Fields to extract:
- station_name: (String) Name of the charging station.
- distance: (String) Distance info (e.g., "7.4km").
- price: (String) Price info (e.g., "0.7111/度").
- tags: (List[String]) Any tags like "", "闲3/4", "组团", "2倍积分", "P", etc.
- parking_info: (String) Parking related info.
OCR Text:
{ocr_text_block}
"""


def _parse_llm_json(response_text):
    """Decode the JSON value in an LLM reply.

    Markdown code fences are stripped first. If the cleaned text is not valid
    JSON as a whole, the outermost ``{...}`` span is tried instead, which
    tolerates stray prose the model may add around the object.

    Raises:
        json.JSONDecodeError: if no valid JSON could be decoded.
    """
    clean_text = response_text.replace("```json", "").replace("```", "").strip()
    try:
        return json.loads(clean_text)
    except json.JSONDecodeError:
        match = re.search(r"\{.*\}", clean_text, re.DOTALL)
        if match is None:
            raise
        return json.loads(match.group(0))


async def run_parsing_comparison(ocr_text_lines, ocr_text_block, timing_ocr):
    """Parse OCR output with the regex parser and with an LLM, then log timings.

    Args:
        ocr_text_lines: Recognized text lines, passed to ``OcrParser.parse``.
        ocr_text_block: The same text as one newline-joined string, embedded
            in the LLM prompt.
        timing_ocr: ``(t_ocr_start, t_ocr_end)`` from the OCR step; only their
            difference is used.

    Failures in either parsing mode are logged rather than raised. The
    reported parse durations cover only the parse call / LLM stream itself,
    not the logging of its results.
    """
    t_ocr_start, t_ocr_end = timing_ocr
    ocr_duration = t_ocr_end - t_ocr_start

    # --- Mode 1: Regex Parsing ---
    log("Running Regex Parsing...")
    regex_error = None
    t_regex_start = time.time()
    try:
        regex_data = OcrParser.parse(ocr_text_lines)
    except Exception as e:
        regex_error = e
    # Stop the clock before logging so only the parse itself is timed.
    regex_duration = time.time() - t_regex_start
    if regex_error is not None:
        log(f"Regex Parsing Failed: {regex_error}")
    else:
        log("\nParsed Data (Regex):")
        try:
            log(json.dumps(regex_data, indent=2, ensure_ascii=False))
        except (TypeError, ValueError) as e:
            log(f"Regex result is not JSON-serializable: {e}")
    log(f"Regex Parsing Time: {regex_duration:.4f}s")
    log("-" * 50)

    # --- Mode 2: LLM Parsing ---
    log("Running LLM Parsing...")
    prompt = _build_llm_prompt(ocr_text_block)
    chunks = []
    llm_error = None
    t_llm_start = time.time()
    try:
        log("Starting LLM request...")
        async for chunk in get_llm_response(prompt, stream=True):
            print(chunk, end='', flush=True)  # live echo to the console only
            chunks.append(chunk)
    except Exception as e:
        llm_error = e
    finally:
        # Single end timestamp, taken as soon as the stream ends or fails.
        t_llm_end = time.time()
    llm_duration = t_llm_end - t_llm_start

    if llm_error is not None:
        log(f"LLM Error: {llm_error}")
    else:
        print("\n")
        log("LLM request finished.")
        response_text = "".join(chunks)
        log(f"\nLLM Response ({llm_duration:.4f}s):")
        log(response_text)
        try:
            data = _parse_llm_json(response_text)
        except json.JSONDecodeError:
            log("\nFailed to parse JSON from LLM response.")
        else:
            log("\nParsed JSON Data:")
            log(json.dumps(data, indent=2, ensure_ascii=False))

    log("-" * 50)
    log("Summary:")
    log(f"OCR Time: {ocr_duration:.4f}s")
    log(f"Regex Parsing Time: {regex_duration:.4f}s")
    log(f"LLM Parsing Time: {llm_duration:.4f}s")
    total_regex = ocr_duration + regex_duration
    total_llm = ocr_duration + llm_duration
    log(f"Total Pipeline (OCR+Regex): {total_regex:.4f}s")
    log(f"Total Pipeline (OCR+LLM): {total_llm:.4f}s")
def main():
    """Reset the debug log, run OCR, then compare regex and LLM parsing.

    The parsing comparison runs only when OCR produced at least one text
    line; OCR failures are logged by ``run_ocr_sync`` itself.
    """
    # Truncate the log so each run starts with a fresh file.
    with open(LOG_FILE, "w", encoding="utf-8") as f:
        f.write("Starting TestOcrLlm...\n")

    result = run_ocr_sync()
    # Check before unpacking: failure paths return None placeholders (and may
    # return fewer than three values), so a blind 3-way unpack can raise.
    if not result or not result[0]:
        log("No OCR text recognized; skipping parsing comparison.")
        return
    ocr_lines, ocr_text, timing = result
    asyncio.run(run_parsing_comparison(ocr_lines, ocr_text, timing))


if __name__ == "__main__":
    main()