diff --git a/Config/Config.py b/Config/Config.py index b314652..de13f46 100644 --- a/Config/Config.py +++ b/Config/Config.py @@ -57,3 +57,16 @@ TEMP_IMAGE_DIR = r"d:\dsWork\aiData\Output" # False: 默认按各个供应商自己的 PRICE_FLATTEN_TO_24H 决定 # True: 强制所有供应商都铺平成 24 小时整点数组 PRICE_FLATTEN_TO_24H_GLOBAL = True + +# Banana Client配置 +BANANA_MODEL = 'TuZi' +# BANANA_MODEL = 'GPTNB' + +# GPTNB的API KEY +GPTNB_API_KEY = "sk-amQHwiEzPIZIB2KuF5A10dC23a0e4b02B48a7a2b6aFa0662" +GPTNB_BASE_URL = "https://goapi.gptnb.ai" + +# 兔子平台 配置 +TUZI_BASE_URL = "https://api.tu-zi.com" +# 兔子平台API 令牌 +TUZI_API_KEY = "sk-FCwlaMANKdSlXlY7HkzZncSp0N5gecfpQdk0iR059Hfk4dQ1" diff --git a/Config/__pycache__/Config.cpython-310.pyc b/Config/__pycache__/Config.cpython-310.pyc index 75f6f99..0121438 100644 Binary files a/Config/__pycache__/Config.cpython-310.pyc and b/Config/__pycache__/Config.cpython-310.pyc differ diff --git a/Util/BananaClient.py b/Util/BananaClient.py new file mode 100644 index 0000000..a3a15ab --- /dev/null +++ b/Util/BananaClient.py @@ -0,0 +1,502 @@ +import asyncio +import logging +import os +import re +import sys +import uuid +from Config.Config import * +import aiohttp + +# 配置日志 +# 获取项目根目录 +project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +sys.path.append(project_root) + +from Util.ObsUtil import ObsUploader + +# 配置日志 +logger = logging.getLogger(__name__) + +class BananaClient: + """ + 通用 Banana 图片生成客户端,支持 TuZi 和 GPTNB 供应商 + """ + + PROVIDERS = { + "TuZi": { + "api_key": TUZI_API_KEY, + "base_url": TUZI_BASE_URL, + "default_model": "gemini-3-pro-image-preview", + "models": { + "gemini-3-pro-image-preview": "gemini-3-pro-image-preview", + "gemini-3-pro-image-preview-2k": "gemini-3-pro-image-preview-2k", + "gemini-3-pro-image-preview-4k": "gemini-3-pro-image-preview-4k" + } + }, + "GPTNB": { + "api_key": GPTNB_API_KEY, + "base_url": GPTNB_BASE_URL, + "default_model": "nano-banana-2", + "models": { + "nano-banana-2": "nano-banana-2", + "nano-banana-2-4k": "nano-banana-2-4k", + "nano-banana-2-hd": "nano-banana-2-hd" + } + } + } + def __init__(self, provider=None, api_key=None, base_url=None, model=None, response_format="url", timeout=120): + """ + 初始化 BananaClient + + Args: + provider: 供应商名称,"TuZi" 或 "GPTNB"。如果为None则从配置读取BANANA_MODEL + api_key: API Key,如果不提供则使用配置中的默认值 + base_url: Base URL,如果不提供则使用配置中的默认值 + model: 模型名称,如果不提供则使用供应商的默认模型 + response_format: 响应格式,默认为 "url" + timeout: 超时时间,默认为 120 秒 + """ + if provider is None: + # 尝试从全局配置获取 + try: + provider = BANANA_MODEL + except NameError: + provider = None + + if provider: + logger.info(f"从配置中读取 Banana 提供商: {provider}") + else: + provider = "TuZi" + logger.warning("未在配置中找到 BANANA_MODEL,默认使用 TuZi") + + if provider not in self.PROVIDERS: + raise ValueError(f"不支持的供应商: {provider}。支持的供应商: {list(self.PROVIDERS.keys())}") + + self.provider_name = provider + provider_config = self.PROVIDERS[provider] + + self.api_key = api_key or provider_config["api_key"] + self.base_url = base_url or provider_config["base_url"] + self.model = model or provider_config["default_model"] + self.available_models = provider_config["models"] + + self.response_format = response_format + self.timeout = timeout + self.obs_uploader = ObsUploader() + + # 统一 provider 标识用于 OBS 路径等 (全小写) + self.provider_code = provider.lower() + + logger.info(f"初始化 BananaClient - 供应商: {self.provider_name}, base_url: {self.base_url}, 模型: {self.model}") + + async def prepare_image_data(self, image_paths): + """ + 准备图片数据,将本地文件上传到OBS,URL直接使用,包装为正确的API调用格式 + + 参数: + image_paths: 单个图片路径字符串或图片路径列表,可以是本地文件路径或URL + + 返回: + 单个图片的API调用格式数据或图片API调用格式数据列表 + """ + + async def process_single_image(image_path): + """处理单个图片路径,支持本地文件和URL""" + image_path = image_path.strip() + logger.info(f"开始处理图片: {image_path}") + + try: + # 检查是否是URL + if re.match(r'^https?://', image_path): + # 是URL,直接使用 + logger.info(f"图片是URL类型,直接使用: {image_path}") + return { + "type": "image_url", + "image_url": {"url": image_path} + } + else: + # 是本地文件,上传到OBS + logger.info(f"图片是本地文件,开始上传到OBS: {image_path}") + if not os.path.exists(image_path): + raise FileNotFoundError(f"本地图片文件不存在: {image_path}") + if not os.path.isfile(image_path): + raise IsADirectoryError(f"指定路径不是文件: {image_path}") + + file_size = os.path.getsize(image_path) + logger.info(f"本地文件存在,大小: {file_size} bytes") + + # 生成GUID作为文件名 + file_guid = str(uuid.uuid4()) + + # 从文件名获取扩展名 + _, ext = os.path.splitext(image_path) + ext = ext.lower() if ext else '.jpg' # 默认使用jpg + + # 构建OBS key,使用临时文件前缀 + obs_key = f"{OBS_TMP_PREFIX}/{file_guid}{ext}" + logger.debug(f"构建OBS key: {obs_key}") + + # 上传到OBS + logger.info(f"正在上传到OBS: {obs_key}") + success, result = self.obs_uploader.upload_file( + file_path=image_path, + object_key=obs_key, + bucket_name=OBS_BUCKET + ) + + if success: + # 构建CDN访问URL + cdn_url = f"https://{CDN_DOMAIN}/{obs_key}" + logger.info(f"✅ 图片上传到OBS成功: {cdn_url}") + logger.info(f" OBS Key: {obs_key}") + + return { + "type": "image_url", + "image_url": {"url": cdn_url} + } + else: + logger.error(f"上传到OBS失败: {result}") + raise Exception(f"上传到OBS失败: {result}") + + except Exception as e: + logger.error(f"处理单个图片时出错: {image_path} - {str(e)}") + raise + + try: + logger.info(f"准备图片数据: {image_paths}") + # 如果是单个图片路径字符串 + if isinstance(image_paths, str): + return await process_single_image(image_paths) + # 如果是图片路径列表 + elif isinstance(image_paths, list): + result = [] + for i, image_path in enumerate(image_paths): + logger.info(f"处理列表中的图片 {i+1}/{len(image_paths)}: {image_path}") + result.append(await process_single_image(image_path)) + logger.info(f"成功处理 {len(result)} 张图片") + return result + else: + raise TypeError(f"image_paths参数类型错误,应为str或list,实际为{type(image_paths)}") + except Exception as e: + logger.error(f"准备图片数据时出错: {e}") + raise + + async def generate_image(self, prompt, image_paths=None, n=1, size="1024x1024"): + """ + 异步调用 API 生成图片(支持文字生图和图片+文字改图) + + 参数: + prompt: 图片描述 + image_paths: 图片路径或URL,可以是单个字符串或列表(用于改图功能) + n: 生成图片数量,默认为 1 + size: 图片尺寸,默认为 "1024x1024" + + 返回: + response_json: API 响应的 JSON 数据 + """ + + # 确保使用有效的模型 + current_model = self.model + if current_model not in self.available_models: + logger.warning(f"警告: 模型 {current_model} 可能不在 {self.provider_name} 的支持列表中: {list(self.available_models.keys())}") + + logger.info(f"开始生成图片 - 提供商: {self.provider_name}, 模型: {current_model}, 尺寸: {size}, 数量: {n}") + logger.info(f"提示词: {prompt}") + if image_paths: + logger.info(f"图片输入: {image_paths}") + + # 构建API调用消息 + if image_paths: + # 处理图片数据 + image_data = await self.prepare_image_data(image_paths) + + # 构建payload,支持改图功能 + image_urls = [] + if isinstance(image_data, list): + for item in image_data: + if "image_url" in item and "url" in item["image_url"]: + image_urls.append(item["image_url"]["url"]) + else: + if "image_url" in image_data and "url" in image_data["image_url"]: + image_urls.append(image_data["image_url"]["url"]) + + # 如果只有一个URL,直接传递字符串;如果有多个,传递列表 + final_image_input = image_urls[0] if len(image_urls) == 1 else image_urls + + payload = { + "prompt": prompt, + "n": n, + "model": current_model, + "size": size, + "response_format": self.response_format, + "image": final_image_input, # 支持单个URL或URL列表 + "timeout": self.timeout + } + logger.info(f"构建改图请求,使用模型: {current_model},尺寸: {size}") + else: + # 仅文字生图 + payload = { + "prompt": prompt, + "n": n, + "model": current_model, + "size": size, + "response_format": self.response_format, + "timeout": self.timeout + } + logger.info(f"构建文字生图请求,使用模型: {current_model},尺寸: {size}") + + endpoint = f"{self.base_url}/v1/images/generations" + + headers = { + 'Content-Type': 'application/json', + "Authorization": f"Bearer {self.api_key}" + } + logger.debug(f"请求头: {headers}") + logger.debug(f"请求payload: {payload}") + + try: + logger.info(f"正在异步调用{self.provider_name} API...") + logger.info(f"使用端点: {endpoint}") + + async with aiohttp.ClientSession() as session: + async with session.post( + endpoint, + json=payload, + headers=headers, + timeout=self.timeout + ) as response: + response.raise_for_status() # 检查HTTP响应状态 + response_json = await response.json() + + logger.info("API调用完成!") + logger.debug(f"API响应: {response_json}") # 使用debug级别记录完整响应 + return response_json + + except Exception as e: + logger.error(f"API调用失败: {str(e)}") + raise + + async def get_image_urls(self, response_json): + """ + 异步获取图片 URL 列表 + + 参数: + response_json: API 响应的 JSON 数据 + + 返回: + image_urls: 图片 URL 列表 + """ + if response_json is None: + logger.info("未提供响应数据") + return [] + + # 检查数据结构并提取图片地址 + image_urls = [] + + # 处理不同的响应格式 + if "data" in response_json: + data = response_json["data"] + + # 格式1: data是字符串URL + if isinstance(data, str): + image_urls.append(data) + + # 格式2: data是包含url字段的对象列表 + elif isinstance(data, list): + for item in data: + if isinstance(item, dict): + # 检查是否有url字段 + if "url" in item: + image_urls.append(item["url"]) + # 检查是否有嵌套的数据结构 + elif "data" in item: + if isinstance(item["data"], str): + image_urls.append(item["data"]) + elif isinstance(item["data"], list): + for nested_item in item["data"]: + if isinstance(nested_item, dict) and "url" in nested_item: + image_urls.append(nested_item["url"]) + elif isinstance(nested_item, str): + image_urls.append(nested_item) + # 格式3: item本身是字符串URL + elif isinstance(item, str): + image_urls.append(item) + + # 格式4: data是一个对象,包含url字段 + elif isinstance(data, dict): + if "url" in data: + image_urls.append(data["url"]) + # 检查是否有嵌套的数据结构 + elif "data" in data: + if isinstance(data["data"], str): + image_urls.append(data["data"]) + elif isinstance(data["data"], list): + for item in data["data"]: + if isinstance(item, dict) and "url" in item: + image_urls.append(item["url"]) + elif isinstance(item, str): + image_urls.append(item) + + # Fallback for TuZi provider which might return 'url' at top level or nested differently + if not image_urls and "url" in response_json: + image_urls.append(response_json["url"]) + + # Log if no URLs found but response exists + if not image_urls: + logger.warning(f"Failed to extract image URLs from response: {response_json}") + + return image_urls + + async def print_image_urls(self, response_json): + """ + 异步打印图片 URL + + 参数: + response_json: API 响应的 JSON 数据 + """ + logger.info("\n正在解析图片地址...") + image_urls = await self.get_image_urls(response_json) + + if image_urls: + for i, url in enumerate(image_urls): + logger.info(f"图片 {i+1} URL: {url}") + else: + logger.info("未找到有效的图片 URL") + + async def download_and_upload_to_obs(self, response_json): + """ + 异步从临时图片URL下载图片并直接上传到OBS(不落盘) + + 参数: + response_json: API 响应的 JSON 数据 + + 返回: + uploaded_urls: 上传到OBS的图片URL列表 + """ + logger.info("\n正在下载图片并上传到OBS...") + + # 获取图片URL列表 + image_urls = await self.get_image_urls(response_json) + uploaded_urls = [] + + if not image_urls: + logger.info("未找到有效的图片 URL") + return uploaded_urls + + logger.info(f"找到 {len(image_urls)} 个图片URL,开始下载并上传到OBS") + + for i, image_url in enumerate(image_urls): + try: + logger.info(f"正在处理图片 {i+1}/{len(image_urls)}: {image_url}") + + # 生成GUID作为文件名 + file_guid = str(uuid.uuid4()) + + # 从URL获取文件扩展名 + if image_url.lower().endswith('.png'): + ext = 'png' + elif image_url.lower().endswith('.jpg') or image_url.lower().endswith('.jpeg'): + ext = 'jpg' + elif image_url.lower().endswith('.webp'): + ext = 'webp' + else: + ext = 'jpg' # 默认使用jpg + + # 构建OBS key + obs_key = f"{OBS_CLOUD_PREFIX}/{self.provider_code.capitalize()}/{file_guid}.{ext}" + logger.debug(f"构建OBS key: {obs_key}") + + # 使用异步方式下载图片 + logger.info(f"正在下载图片: {image_url}") + async with aiohttp.ClientSession() as session: + async with session.get(image_url, timeout=120) as response: + response.raise_for_status() + image_data = await response.read() + logger.debug(f"成功下载图片,大小: {len(image_data)} 字节") + + # 直接上传到OBS(不落盘) + logger.info(f"正在上传到OBS: {obs_key}") + success, result = self.obs_uploader.upload_base64_image( + object_key=obs_key, + base64_data=image_data, + bucket_name=OBS_BUCKET + ) + + if success: + # 构建OBS访问URL + obs_url = f"https://{OBS_BUCKET}.{OBS_SERVER}/{obs_key}" + # 构建CDN访问URL + cdn_url = f"https://{CDN_DOMAIN}/{obs_key}" + logger.info(f"✅ 图片 {i+1} 上传成功: {cdn_url}") + logger.info(f" OBS Key: {obs_key}") + uploaded_urls.append(cdn_url) + else: + logger.error(f"上传到OBS失败: {result}") + + except Exception as e: + logger.error(f"处理图片 {image_url} 失败: {str(e)}") + import traceback + logger.error(f"详细错误堆栈: {traceback.format_exc()}") + + logger.info(f"图片处理完成,成功上传 {len(uploaded_urls)}/{len(image_urls)} 个图片到OBS") + return uploaded_urls + + +async def test_provider(provider_name): + """测试指定供应商""" + print(f"\n{'='*20} 测试 {provider_name} 供应商 {'='*20}") + + try: + # 创建客户端实例 + client = BananaClient(provider=provider_name) + + # 测试1: 文字生图 + logger.info("\n---------------- 文字生图测试 ----------------") + text_prompt = "一只可爱的猫咪坐在月亮上,星空背景,卡通风格" + logger.info(f"提示词: {text_prompt}") + + response = await client.generate_image( + prompt=text_prompt, + n=1, + size="1024x1024" + ) + + if response: + logger.info("文字生图成功!") + await client.print_image_urls(response) + # 可选:上传到OBS + # await client.download_and_upload_to_obs(response) + else: + logger.error("文字生图失败!") + + except Exception as e: + logger.error(f"{provider_name} 测试失败: {e}") + import traceback + logger.error(traceback.format_exc()) + +async def main(): + """ + 异步主函数,用于测试 BananaClient 的功能 + """ + # 配置控制台日志输出 + logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + handlers=[logging.StreamHandler(sys.stdout)] + ) + + logger.info("开始 BananaClient 整合测试...") + + # 测试 TuZi (默认) + await test_provider("TuZi") + + # 测试 GPTNB + #await test_provider("GPTNB") + + logger.info("\n" + "="*50) + logger.info("测试全部完成!") + logger.info("="*50) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/static/Images/login_logo.png b/static/Images/login_logo.png new file mode 100644 index 0000000..ee5c25d Binary files /dev/null and b/static/Images/login_logo.png differ