231 lines
9.7 KiB
Python
231 lines
9.7 KiB
Python
import logging
|
||
import uuid
|
||
import asyncio
|
||
from datetime import datetime
|
||
from fastapi import APIRouter, HTTPException
|
||
from pydantic import BaseModel
|
||
from sqlalchemy.sql import text
|
||
|
||
from Util.BananaClient import BananaClient
|
||
from Util.LlmUtil import get_llm_response
|
||
from DbKit.Db import Db
|
||
|
||
router = APIRouter(prefix="/haibao")
|
||
logger = logging.getLogger("HaiBaoController")
|
||
|
||
class GenerateRequest(BaseModel):
|
||
prompt: str
|
||
width: int = 1024
|
||
height: int = 1024
|
||
|
||
@router.on_event("startup")
|
||
async def startup_event():
|
||
"""初始化时检查并创建表"""
|
||
db = Db()
|
||
await db.init_db()
|
||
|
||
# Doris 建表语句
|
||
create_table_sql = """
|
||
CREATE TABLE IF NOT EXISTS haibao_history (
|
||
id VARCHAR(50) COMMENT "ID",
|
||
prompt TEXT COMMENT "提示词",
|
||
image_url VARCHAR(500) COMMENT "图片URL",
|
||
scheme_content TEXT COMMENT "文案方案",
|
||
created_at DATETIME COMMENT "创建时间"
|
||
)
|
||
DUPLICATE KEY(id)
|
||
DISTRIBUTED BY HASH(id) BUCKETS 1
|
||
PROPERTIES (
|
||
"replication_num" = "1"
|
||
);
|
||
"""
|
||
try:
|
||
# 使用 engine 直接执行 DDL
|
||
async with db.engine.begin() as conn:
|
||
await conn.execute(text(create_table_sql))
|
||
|
||
# 尝试添加列(如果表已存在但列不存在)
|
||
# Doris 不支持 IF NOT EXISTS for ADD COLUMN directly nicely in all versions without error if exists
|
||
# 所以这里简单捕获异常,如果列已存在则忽略
|
||
try:
|
||
alter_sql = "ALTER TABLE haibao_history ADD COLUMN scheme_content TEXT COMMENT '文案方案'"
|
||
await conn.execute(text(alter_sql))
|
||
except Exception as e:
|
||
# 忽略列已存在的错误
|
||
pass
|
||
|
||
logger.info("海报历史表检查/更新成功")
|
||
except Exception as e:
|
||
logger.error(f"海报历史表创建/更新失败: {e}")
|
||
|
||
class RefineRequest(BaseModel):
|
||
prompt: str
|
||
|
||
@router.post("/refine")
|
||
async def refine_prompt(req: RefineRequest):
|
||
"""润色提示词"""
|
||
try:
|
||
refine_system_prompt = "你是一个资深的AI绘画提示词专家。你的任务是将用户简短的描述扩充为一段详细、高质量的画面描述提示词,用于生成宣传海报。"
|
||
refine_user_prompt = f"""
|
||
请根据以下主题,为充电企业“驿来特”设计一张宣传海报的画面描述。
|
||
|
||
主题:{req.prompt}
|
||
|
||
要求:
|
||
1. 描述画面主体、背景、光影、色彩、构图。
|
||
2. 风格要求:现代感、科技感、精美、3D渲染风格或高品质插画风格。
|
||
3. 融入新能源、绿色环保、充电桩等元素。
|
||
4. 关于品牌元素:画面中可自然融入品牌Logo的视觉元素(如配色、形状),能用多少就用多少,有元素体现即可,不必生搬硬套,保持画面自然和谐。
|
||
5. 直接输出提示词内容,不要包含“好的”、“以下是”等无关废话。
|
||
6. 字数在100-300字之间。
|
||
"""
|
||
|
||
refined_prompt = ""
|
||
try:
|
||
async for chunk in get_llm_response(query_text=refine_user_prompt, system_prompt=refine_system_prompt, stream=False):
|
||
refined_prompt += chunk
|
||
except Exception as e:
|
||
logger.error(f"提示词润色失败: {e}")
|
||
raise Exception("润色服务暂时不可用")
|
||
|
||
return {"refined_prompt": refined_prompt}
|
||
except Exception as e:
|
||
raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
@router.post("/generate")
|
||
async def generate_poster(req: GenerateRequest):
|
||
"""生成海报及文案"""
|
||
try:
|
||
# 并行执行生图和生文
|
||
client = BananaClient()
|
||
|
||
# 1. 构造生图任务 (包含智能润色判断)
|
||
async def generate_image_task():
|
||
final_prompt = req.prompt
|
||
|
||
# 智能判断:如果提示词太短(少于50字),则认为用户未进行润色,自动执行润色
|
||
# 如果用户使用了"一键扩写"功能,提示词通常会很长,这里就会跳过自动润色,尊重用户的修改
|
||
if len(final_prompt) < 50:
|
||
logger.info(f"提示词较短({len(final_prompt)}字),执行自动润色...")
|
||
refine_system_prompt = "你是一个资深的AI绘画提示词专家。你的任务是将用户简短的描述扩充为一段详细、高质量的画面描述提示词,用于生成宣传海报。"
|
||
refine_user_prompt = f"""
|
||
请根据以下主题,为充电企业“驿来特”设计一张宣传海报的画面描述。
|
||
|
||
主题:{req.prompt}
|
||
|
||
要求:
|
||
1. 描述画面主体、背景、光影、色彩、构图。
|
||
2. 风格要求:现代感、科技感、精美、3D渲染风格或高品质插画风格。
|
||
3. 融入新能源、绿色环保、充电桩等元素。
|
||
4. 关于品牌元素:画面中可自然融入品牌Logo的视觉元素(如配色、形状),能用多少就用多少,有元素体现即可,不必生搬硬套,保持画面自然和谐。
|
||
5. 直接输出提示词内容,不要包含“好的”、“以下是”等无关废话。
|
||
6. 字数在100-300字之间。
|
||
"""
|
||
|
||
refined_prompt = ""
|
||
try:
|
||
async for chunk in get_llm_response(query_text=refine_user_prompt, system_prompt=refine_system_prompt, stream=False):
|
||
refined_prompt += chunk
|
||
if refined_prompt and refined_prompt.strip():
|
||
final_prompt = refined_prompt
|
||
except Exception as e:
|
||
logger.error(f"自动润色失败,使用原始提示词: {e}")
|
||
|
||
logger.info(f"Final generation prompt: {final_prompt}")
|
||
|
||
# 1.2 调用生图
|
||
resp = await client.generate_image(prompt=final_prompt, size=f"{req.width}x{req.height}")
|
||
# 定义Logo路径
|
||
LOGO_PATH = r"d:\dsWork\aiData\static\Images\login_logo.png"
|
||
obs_urls = await client.download_and_upload_to_obs(resp, overlay_logo_path=LOGO_PATH)
|
||
if not obs_urls:
|
||
raise Exception("未获取到有效的图片URL")
|
||
return obs_urls[0]
|
||
|
||
# 2. 构造生文任务
|
||
async def generate_text_task():
|
||
scheme_prompt = f"""
|
||
你是一个专业的社群运营专家。请为充电企业“驿来特”撰写一段发在微信群里的宣传文案。
|
||
|
||
主题:{req.prompt}
|
||
|
||
要求:
|
||
1. 语气亲切、有吸引力,适合微信社群传播。
|
||
2. 突出“驿来特”品牌,强调新能源、优惠、便利等特点(根据主题自由发挥)。
|
||
3. 包含适当的emoji表情,增加趣味性。
|
||
4. 字数控制在150字以内。
|
||
5. 格式清晰,分段合理。
|
||
"""
|
||
# get_llm_response 是一个异步生成器 (stream=True by default) 或者直接返回 (stream=False)
|
||
# 这里我们强制 stream=False 获取完整文本
|
||
text_response = ""
|
||
# LlmUtil.get_llm_response 默认为 stream=True,我们需要修改调用方式或适配
|
||
# 查看 LlmUtil 源码,如果 stream=False,它 yield 内容。
|
||
# 所以我们需要迭代它
|
||
async for chunk in get_llm_response(query_text=scheme_prompt, stream=False):
|
||
text_response += chunk
|
||
return text_response
|
||
|
||
# 3. 并行执行
|
||
image_url, scheme_content = await asyncio.gather(generate_image_task(), generate_text_task())
|
||
|
||
# 4. 保存到数据库
|
||
db = Db()
|
||
record_id = str(uuid.uuid4())
|
||
created_at = datetime.now()
|
||
|
||
insert_sql = """
|
||
INSERT INTO haibao_history (id, prompt, image_url, scheme_content, created_at)
|
||
VALUES (:id, :prompt, :image_url, :scheme_content, :created_at)
|
||
"""
|
||
|
||
params = {
|
||
"id": record_id,
|
||
"prompt": req.prompt,
|
||
"image_url": image_url,
|
||
"scheme_content": scheme_content,
|
||
"created_at": created_at
|
||
}
|
||
|
||
async with db.get_session() as session:
|
||
async with session.begin():
|
||
await session.execute(text(insert_sql), params)
|
||
|
||
return {
|
||
"id": record_id,
|
||
"image_url": image_url,
|
||
"prompt": req.prompt,
|
||
"scheme_content": scheme_content,
|
||
"created_at": created_at.strftime("%Y-%m-%d %H:%M:%S")
|
||
}
|
||
|
||
except Exception as e:
|
||
logger.error(f"生成海报/文案失败: {e}")
|
||
raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
@router.get("/history")
|
||
async def get_history():
|
||
"""获取海报生成历史"""
|
||
db = Db()
|
||
sql = "SELECT * FROM haibao_history ORDER BY created_at DESC LIMIT 50"
|
||
try:
|
||
result = await db.find(sql)
|
||
|
||
formatted_result = []
|
||
for item in result:
|
||
item_dict = dict(item) if not isinstance(item, dict) else item
|
||
|
||
if isinstance(item_dict.get('created_at'), datetime):
|
||
item_dict['created_at'] = item_dict['created_at'].strftime("%Y-%m-%d %H:%M:%S")
|
||
|
||
# 确保 scheme_content 存在
|
||
if 'scheme_content' not in item_dict:
|
||
item_dict['scheme_content'] = ""
|
||
|
||
formatted_result.append(item_dict)
|
||
|
||
return formatted_result
|
||
except Exception as e:
|
||
logger.error(f"获取历史失败: {e}")
|
||
return []
|