# File: aiData/Util/BananaClient.py (551 lines, 22 KiB, Python)
# Repository-viewer residue (Raw / Permalink / Normal View / History); last change shown: 2026-01-20 09:00:05 +08:00
import asyncio
import logging
import os
import re
import sys
import uuid
from Config.Config import *
import aiohttp
# 配置日志
# 获取项目根目录
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(project_root)
from Util.ObsUtil import ObsUploader
# 配置日志
logger = logging.getLogger(__name__)
class BananaClient:
    """
    Generic "Banana" image-generation client supporting the TuZi and GPTNB providers.

    The client posts generation requests to ``{base_url}/v1/images/generations``
    and uses ``ObsUploader`` to host input images and to re-host generated
    images on OBS.
    """

    # Built-in providers: credentials and base URL come from Config, together
    # with the default model and the model names regarded as supported.
    PROVIDERS = {
        "TuZi": {
            "api_key": TUZI_API_KEY,
            "base_url": TUZI_BASE_URL,
            "default_model": "gemini-3-pro-image-preview",
            "models": {
                "gemini-3-pro-image-preview": "gemini-3-pro-image-preview",
                "gemini-3-pro-image-preview-2k": "gemini-3-pro-image-preview-2k",
                "gemini-3-pro-image-preview-4k": "gemini-3-pro-image-preview-4k",
            },
        },
        "GPTNB": {
            "api_key": GPTNB_API_KEY,
            "base_url": GPTNB_BASE_URL,
            "default_model": "nano-banana-2",
            "models": {
                "nano-banana-2": "nano-banana-2",
                "nano-banana-2-4k": "nano-banana-2-4k",
                "nano-banana-2-hd": "nano-banana-2-hd",
            },
        },
    }

    # Seconds allowed for downloading a single generated image.
    DOWNLOAD_TIMEOUT = 120

    # URL suffix -> OBS object extension; anything else is stored as "jpg".
    _EXTENSIONS = (
        (".png", "png"),
        ((".jpg", ".jpeg"), "jpg"),
        (".webp", "webp"),
    )

    def __init__(self, provider=None, api_key=None, base_url=None, model=None, response_format="url", timeout=120):
        """
        Initialise the client.

        Args:
            provider: Provider name, "TuZi" or "GPTNB". When None, the value of
                the BANANA_MODEL config constant is used; if that is missing or
                empty, "TuZi" is used.
            api_key: API key; defaults to the provider's configured key.
            base_url: Base URL; defaults to the provider's configured URL.
            model: Model name; defaults to the provider's default model.
            response_format: Value sent as ``response_format`` (default "url").
            timeout: Request timeout in seconds (default 120).

        Raises:
            ValueError: If the provider is not a key of ``PROVIDERS``.
        """
        if provider is None:
            # BANANA_MODEL only exists if the wildcard Config import defined it.
            try:
                provider = BANANA_MODEL
            except NameError:
                provider = None
            if provider:
                logger.info(f"从配置中读取 Banana 提供商: {provider}")
            else:
                provider = "TuZi"
                logger.warning("未在配置中找到 BANANA_MODEL默认使用 TuZi")

        if provider not in self.PROVIDERS:
            raise ValueError(f"不支持的供应商: {provider}。支持的供应商: {list(self.PROVIDERS.keys())}")

        self.provider_name = provider
        provider_config = self.PROVIDERS[provider]
        self.api_key = api_key or provider_config["api_key"]
        self.base_url = base_url or provider_config["base_url"]
        self.model = model or provider_config["default_model"]
        self.available_models = provider_config["models"]
        self.response_format = response_format
        self.timeout = timeout
        self.obs_uploader = ObsUploader()
        # Lower-case provider identifier, used when building OBS paths.
        self.provider_code = provider.lower()
        logger.info(f"初始化 BananaClient - 供应商: {self.provider_name}, base_url: {self.base_url}, 模型: {self.model}")

    async def _prepare_single_image(self, image_path):
        """Turn one local path or http(s) URL into an ``image_url`` entry."""
        image_path = image_path.strip()
        logger.info(f"开始处理图片: {image_path}")
        try:
            # The scheme is matched case-insensitively so "HTTP://..." is not
            # mistaken for a local file.
            if re.match(r'^https?://', image_path, re.IGNORECASE):
                logger.info(f"图片是URL类型直接使用: {image_path}")
                return {
                    "type": "image_url",
                    "image_url": {"url": image_path}
                }

            # Anything else is treated as a local file and uploaded to OBS.
            logger.info(f"图片是本地文件开始上传到OBS: {image_path}")
            if not os.path.exists(image_path):
                raise FileNotFoundError(f"本地图片文件不存在: {image_path}")
            if not os.path.isfile(image_path):
                raise IsADirectoryError(f"指定路径不是文件: {image_path}")
            file_size = os.path.getsize(image_path)
            logger.info(f"本地文件存在,大小: {file_size} bytes")

            # A fresh GUID is the object name; the extension is kept from the
            # local file and defaults to .jpg when there is none.
            file_guid = str(uuid.uuid4())
            _, ext = os.path.splitext(image_path)
            ext = ext.lower() if ext else '.jpg'
            obs_key = f"{OBS_TMP_PREFIX}/{file_guid}{ext}"
            logger.debug(f"构建OBS key: {obs_key}")

            logger.info(f"正在上传到OBS: {obs_key}")
            # NOTE(review): upload_file is called without await; if it does
            # blocking I/O it stalls the event loop for its duration - confirm.
            success, result = self.obs_uploader.upload_file(
                file_path=image_path,
                object_key=obs_key,
                bucket_name=OBS_BUCKET
            )
            if not success:
                logger.error(f"上传到OBS失败: {result}")
                raise RuntimeError(f"上传到OBS失败: {result}")

            cdn_url = f"https://{CDN_DOMAIN}/{obs_key}"
            logger.info(f"✅ 图片上传到OBS成功: {cdn_url}")
            logger.info(f" OBS Key: {obs_key}")
            return {
                "type": "image_url",
                "image_url": {"url": cdn_url}
            }
        except Exception as e:
            logger.error(f"处理单个图片时出错: {image_path} - {str(e)}")
            raise

    async def prepare_image_data(self, image_paths):
        """
        Convert image inputs into API ``image_url`` entries.

        http(s) URLs are used as-is; local files are uploaded to OBS and
        referenced through their CDN URL.

        Args:
            image_paths: A single path/URL string, or a list of them.

        Returns:
            One ``{"type": "image_url", "image_url": {"url": ...}}`` dict for a
            string input, or a list of such dicts for a list input.

        Raises:
            TypeError: If ``image_paths`` is neither a str nor a list.
            FileNotFoundError: If a local path does not exist.
            IsADirectoryError: If a local path is not a regular file.
            RuntimeError: If the OBS upload reports failure.
        """
        try:
            logger.info(f"准备图片数据: {image_paths}")
            if isinstance(image_paths, str):
                return await self._prepare_single_image(image_paths)
            if isinstance(image_paths, list):
                result = []
                total = len(image_paths)
                for position, image_path in enumerate(image_paths, start=1):
                    logger.info(f"处理列表中的图片 {position}/{total}: {image_path}")
                    result.append(await self._prepare_single_image(image_path))
                logger.info(f"成功处理 {len(result)} 张图片")
                return result
            raise TypeError(f"image_paths参数类型错误应为str或list实际为{type(image_paths)}")
        except Exception as e:
            logger.error(f"准备图片数据时出错: {e}")
            raise

    async def generate_image(self, prompt, image_paths=None, n=1, size="1024x1024"):
        """
        Call the provider API to generate images (text-to-image or image editing).

        Args:
            prompt: Text description of the desired image.
            image_paths: Optional path/URL string or list of them; when given,
                the request carries them in the ``image`` field.
            n: Number of images to request (default 1).
            size: Image size string (default "1024x1024").

        Returns:
            The decoded JSON body of the API response.

        Raises:
            Exception: Any error from input preparation or the HTTP call is
                logged and re-raised unchanged.
        """
        current_model = self.model
        if current_model not in self.available_models:
            # Only a warning: the request is still sent with the given model.
            logger.warning(f"警告: 模型 {current_model} 可能不在 {self.provider_name} 的支持列表中: {list(self.available_models.keys())}")
        logger.info(f"开始生成图片 - 提供商: {self.provider_name}, 模型: {current_model}, 尺寸: {size}, 数量: {n}")
        logger.info(f"提示词: {prompt}")

        payload = {
            "prompt": prompt,
            "n": n,
            "model": current_model,
            "size": size,
            "response_format": self.response_format,
        }
        if image_paths:
            logger.info(f"图片输入: {image_paths}")
            image_data = await self.prepare_image_data(image_paths)
            entries = image_data if isinstance(image_data, list) else [image_data]
            image_urls = [
                entry["image_url"]["url"]
                for entry in entries
                if "image_url" in entry and "url" in entry["image_url"]
            ]
            # A single URL is sent as a plain string, several as a list.
            payload["image"] = image_urls[0] if len(image_urls) == 1 else image_urls
            payload["timeout"] = self.timeout
            logger.info(f"构建改图请求,使用模型: {current_model},尺寸: {size}")
        else:
            payload["timeout"] = self.timeout
            logger.info(f"构建文字生图请求,使用模型: {current_model},尺寸: {size}")

        # Strip a trailing slash so a base URL ending in "/" does not yield "//v1/...".
        endpoint = f"{self.base_url.rstrip('/')}/v1/images/generations"
        headers = {
            'Content-Type': 'application/json',
            "Authorization": f"Bearer {self.api_key}"
        }
        # Never write the API key to the log; log a redacted copy instead.
        redacted_headers = {**headers, "Authorization": "Bearer ***"}
        logger.debug(f"请求头: {redacted_headers}")
        logger.debug(f"请求payload: {payload}")

        try:
            logger.info(f"正在异步调用{self.provider_name} API...")
            logger.info(f"使用端点: {endpoint}")
            # aiohttp expects a ClientTimeout object; bare numbers are deprecated.
            request_timeout = aiohttp.ClientTimeout(total=self.timeout)
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    endpoint,
                    json=payload,
                    headers=headers,
                    timeout=request_timeout
                ) as response:
                    if response.status >= 400:
                        # Keep the provider's error body; raise_for_status drops it.
                        error_body = await response.text(errors="replace")
                        logger.error(f"API返回错误状态 {response.status}: {error_body[:500]}")
                    response.raise_for_status()
                    response_json = await response.json()
            logger.info("API调用完成")
            logger.debug(f"API响应: {response_json}")
            return response_json
        except Exception as e:
            logger.error(f"API调用失败: {str(e)}")
            raise

    @classmethod
    def _collect_urls(cls, node, found):
        """Append, in document order, every URL reachable from *node* to *found*."""
        if isinstance(node, str):
            if node:
                found.append(node)
        elif isinstance(node, dict):
            url = node.get("url")
            if isinstance(url, str) and url:
                found.append(url)
            elif "data" in node:
                # No usable "url" here: look inside the nested "data" value.
                cls._collect_urls(node["data"], found)
        elif isinstance(node, list):
            for child in node:
                cls._collect_urls(child, found)

    async def get_image_urls(self, response_json):
        """
        Extract image URLs from an API response.

        Handles ``data`` given as a URL string, a dict, or a list whose items
        are URL strings or dicts with a ``url`` key or a nested ``data`` value.
        Falls back to a top-level ``url`` key when ``data`` yields nothing.
        Empty or non-string URL values are skipped.

        Args:
            response_json: Decoded JSON response, or None.

        Returns:
            A list of URL strings (empty when nothing could be extracted).
        """
        if response_json is None:
            logger.info("未提供响应数据")
            return []

        image_urls = []
        # Only a dict can carry "data"/"url"; other payload types yield nothing.
        if isinstance(response_json, dict):
            if "data" in response_json:
                self._collect_urls(response_json["data"], image_urls)
            if not image_urls:
                top_level_url = response_json.get("url")
                if isinstance(top_level_url, str) and top_level_url:
                    image_urls.append(top_level_url)

        if not image_urls:
            logger.warning(f"Failed to extract image URLs from response: {response_json}")
        return image_urls

    async def print_image_urls(self, response_json):
        """
        Log every image URL found in an API response.

        Args:
            response_json: Decoded JSON response of the API.
        """
        logger.info("\n正在解析图片地址...")
        image_urls = await self.get_image_urls(response_json)
        if not image_urls:
            logger.info("未找到有效的图片 URL")
            return
        for position, url in enumerate(image_urls, start=1):
            logger.info(f"图片 {position} URL: {url}")

    @classmethod
    def _guess_extension(cls, image_url):
        """Return "png", "jpg" or "webp" for *image_url*; "jpg" when unknown."""
        from urllib.parse import urlparse

        try:
            url_path = urlparse(image_url).path.lower()
        except ValueError:
            url_path = ""
        # Check the path first so signed URLs with a query string still match,
        # then the raw URL so anything matched before keeps matching.
        for candidate in (url_path, image_url.lower()):
            for suffixes, extension in cls._EXTENSIONS:
                if candidate.endswith(suffixes):
                    return extension
        return "jpg"

    @staticmethod
    def _apply_watermark(image_data, overlay_logo_path):
        """Paste the logo at the top-left of the image and return the new bytes."""
        import io
        from PIL import Image

        with Image.open(io.BytesIO(image_data)) as img, Image.open(overlay_logo_path) as logo:
            # Palette logos with transparency cannot be used as a paste mask;
            # convert them to RGBA so their alpha channel can mask the paste.
            if logo.mode == 'P' and 'transparency' in logo.info:
                overlay = logo.convert('RGBA')
            else:
                overlay = logo

            # Logo width is 20% of the image width, height follows the aspect
            # ratio; both are clamped to at least one pixel.
            target_logo_width = max(1, int(img.width * 0.2))
            aspect_ratio = logo.height / logo.width
            target_logo_height = max(1, int(target_logo_width * aspect_ratio))
            overlay = overlay.resize((target_logo_width, target_logo_height), Image.Resampling.LANCZOS)

            # Top-left corner with a padding of 3% of the image width.
            padding = int(img.width * 0.03)
            position = (padding, padding)
            if overlay.mode in ('RGBA', 'LA'):
                img.paste(overlay, position, overlay)
            else:
                img.paste(overlay, position)

            # Keep the source format; fall back to PNG when it is unknown.
            save_format = img.format if img.format else 'PNG'
            output_buffer = io.BytesIO()
            img.save(output_buffer, format=save_format)
            return output_buffer.getvalue()

    async def download_and_upload_to_obs(self, response_json, overlay_logo_path=None):
        """
        Download generated images and upload them to OBS without writing to disk.

        A failure on one image is logged and does not stop the remaining
        images. If watermarking fails, the unmodified image is uploaded.

        Args:
            response_json: Decoded JSON response of the API.
            overlay_logo_path: Optional path of a logo image to paste on each
                image as a watermark; ignored when the file does not exist.

        Returns:
            The CDN URLs of the successfully uploaded images.
        """
        logger.info("\n正在下载图片并上传到OBS...")
        image_urls = await self.get_image_urls(response_json)
        uploaded_urls = []
        if not image_urls:
            logger.info("未找到有效的图片 URL")
            return uploaded_urls

        total = len(image_urls)
        logger.info(f"找到 {len(image_urls)} 个图片URL开始下载并上传到OBS")
        download_timeout = aiohttp.ClientTimeout(total=self.DOWNLOAD_TIMEOUT)
        # One HTTP session is shared by all downloads of this call.
        async with aiohttp.ClientSession() as session:
            for position, image_url in enumerate(image_urls, start=1):
                try:
                    logger.info(f"正在处理图片 {position}/{total}: {image_url}")
                    file_guid = str(uuid.uuid4())
                    ext = self._guess_extension(image_url)
                    obs_key = f"{OBS_CLOUD_PREFIX}/{self.provider_code.capitalize()}/{file_guid}.{ext}"
                    logger.debug(f"构建OBS key: {obs_key}")

                    logger.info(f"正在下载图片: {image_url}")
                    async with session.get(image_url, timeout=download_timeout) as response:
                        response.raise_for_status()
                        image_data = await response.read()
                    logger.debug(f"成功下载图片,大小: {len(image_data)} 字节")

                    if overlay_logo_path and os.path.exists(overlay_logo_path):
                        try:
                            logger.info(f"正在添加水印: {overlay_logo_path}")
                            image_data = self._apply_watermark(image_data, overlay_logo_path)
                            logger.info("水印添加成功")
                        except Exception as e:
                            # Best effort: continue with the original image.
                            logger.error(f"添加水印失败: {e}")

                    logger.info(f"正在上传到OBS: {obs_key}")
                    # NOTE(review): raw image bytes are passed as base64_data;
                    # confirm ObsUploader.upload_base64_image accepts raw bytes.
                    success, result = self.obs_uploader.upload_base64_image(
                        object_key=obs_key,
                        base64_data=image_data,
                        bucket_name=OBS_BUCKET
                    )
                    if success:
                        cdn_url = f"https://{CDN_DOMAIN}/{obs_key}"
                        logger.info(f"✅ 图片 {position} 上传成功: {cdn_url}")
                        logger.info(f" OBS Key: {obs_key}")
                        uploaded_urls.append(cdn_url)
                    else:
                        logger.error(f"上传到OBS失败: {result}")
                except Exception as e:
                    logger.error(f"处理图片 {image_url} 失败: {str(e)}")
                    import traceback
                    logger.error(f"详细错误堆栈: {traceback.format_exc()}")

        logger.info(f"图片处理完成,成功上传 {len(uploaded_urls)}/{len(image_urls)} 个图片到OBS")
        return uploaded_urls
async def test_provider(provider_name):
    """
    Run a text-to-image smoke test against the given provider.

    Builds a BananaClient for ``provider_name``, requests one 1024x1024 image
    and logs the resulting URLs. Any exception is logged with its traceback
    and not propagated.
    """
    import traceback

    banner = '=' * 20
    print(f"\n{banner} 测试 {provider_name} 供应商 {banner}")
    try:
        client = BananaClient(provider=provider_name)

        # Test 1: text-to-image.
        logger.info("\n---------------- 文字生图测试 ----------------")
        text_prompt = "一只可爱的猫咪坐在月亮上,星空背景,卡通风格"
        logger.info(f"提示词: {text_prompt}")
        response = await client.generate_image(prompt=text_prompt, n=1, size="1024x1024")

        if not response:
            logger.error("文字生图失败!")
            return

        logger.info("文字生图成功!")
        await client.print_image_urls(response)
        # Optional: upload the result to OBS.
        # await client.download_and_upload_to_obs(response)
    except Exception as exc:
        logger.error(f"{provider_name} 测试失败: {exc}")
        logger.error(traceback.format_exc())
async def main():
    """
    Async entry point that exercises BananaClient against the TuZi provider.

    Configures INFO-level logging to stdout, runs the provider test and logs
    a completion banner.
    """
    # Send log records to the console.
    console_handler = logging.StreamHandler(sys.stdout)
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        handlers=[console_handler],
    )
    logger.info("开始 BananaClient 整合测试...")

    # TuZi is the default provider.
    await test_provider("TuZi")
    # GPTNB test is currently disabled.
    #await test_provider("GPTNB")

    separator = "=" * 50
    logger.info("\n" + separator)
    logger.info("测试全部完成!")
    logger.info(separator)
# Run the provider smoke test only when this file is executed as a script.
if __name__ == "__main__":
    asyncio.run(main())