import gradio as gr
import torch
from transformers import pipeline, set_seed
from diffusers import AutoPipelineForText2Image
import openai
import os
import time
import traceback
from typing import Optional, Tuple, Union, Literal, TypedDict
from PIL import Image

# NOTE: the OpenAI API key is read from the environment only
# (e.g. `export OPENAI_API_KEY=...` before launching). Do not hard-code a key
# here: assigning a placeholder to os.environ would overwrite a real key and
# make the app believe OpenAI is available while every request fails.


# ---- Type definitions ----
class ModelConfig(TypedDict):
    """Settings used when loading the text-to-image model."""

    model_id: str  # Hugging Face Hub repository id
    dtype: torch.dtype  # passed as `torch_dtype` to both model loaders
    timeout: int  # forwarded to `from_pretrained` as `timeout`


class UIConfig(TypedDict):
    """Static texts and CSS for the Gradio interface."""

    title: str
    description: str
    warning_css: str


# ---- Configuration ----
class AppConfig:
    """Class-level application configuration (hardware, model, UI, defaults)."""

    # Hardware: use the GPU whenever PyTorch can see one.
    DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu"

    # Model
    MODEL: ModelConfig = {
        "model_id": "nota-ai/bk-sdm-tiny",
        "dtype": torch.float32,
        "timeout": 300,
    }

    # UI
    UI: UIConfig = {
        "title": "🎨 轻量级AI图像生成器(CPU/GPU版)",
        "description": """\
💡 使用技巧:输入简短描述后选择风格和质量选项\n
🚀 支持语音输入 • 自动提示词优化 • 快速生成模式\n
⚠️ 注意:小模型生成速度快但细节有限,建议使用具体描述""",
        "warning_css": """
.warning {color: orange !important; border-left: 3px solid orange; padding: 10px;}
.success {color: green !important;}
""",
    }

    # Generation parameters
    DEFAULT_STEPS: int = 20
    MAX_STEPS: int = 40
    DEFAULT_GUIDANCE: float = 5.0

    @staticmethod
    def error_msg(message: str) -> str:
        """Return *message* formatted as an error line for logs / the UI."""
        return f"❌ 错误:{message}"


config = AppConfig()

# ---- Optional OpenAI client (used only for prompt enhancement) ----
openai_client: Optional[openai.OpenAI] = None
openai_available: bool = False

if os.environ.get("OPENAI_API_KEY"):
    try:
        openai_client = openai.OpenAI(api_key=os.environ["OPENAI_API_KEY"])
        openai_available = True
        print("✅ OpenAI 客户端初始化成功")
    except Exception as e:
        print(config.error_msg(f"OpenAI 初始化失败: {e}"))


# ---- Model loading ----
class DummyPipe:
    """Stand-in for the image pipeline when loading failed; calling it raises."""

    def __call__(self, *args, **kwargs) -> None:
        raise RuntimeError("图像生成模型未加载")


# Speech-recognition model (optional; the audio input is hidden when absent).
asr_pipeline = None
try:
    asr_pipeline = pipeline(
        "automatic-speech-recognition",
        model="openai/whisper-base",
        device=config.DEVICE,
        torch_dtype=config.MODEL["dtype"],
    )
    print("✅ 语音识别模型加载成功")
except Exception as e:
    print(config.error_msg(f"语音模型加载失败: {e}"))

# Image-generation model; remains a DummyPipe if loading fails.
image_pipe: Union[AutoPipelineForText2Image, DummyPipe] = DummyPipe()
try:
    # NOTE(review): recent huggingface_hub releases deprecate `resume_download`,
    # and it is unclear whether diffusers honours `timeout` here — confirm.
    image_pipe = AutoPipelineForText2Image.from_pretrained(
        config.MODEL["model_id"],
        torch_dtype=config.MODEL["dtype"],
        use_safetensors=True,
        resume_download=True,
        timeout=config.MODEL["timeout"],
    ).to(config.DEVICE)
    print(f"✅ 图像模型 {config.MODEL['model_id']} 加载成功")
except Exception as e:
    print(config.error_msg(f"图像模型加载失败: {e}"))


# ---- Prompt vocabulary ----
# Display label -> English keywords appended to the prompt.
STYLE_OPTIONS = {
    "🎥 电影风格": "cinematic lighting",
    "🖼️ 照片写实": "photorealistic",
    "🇯🇵 二次元": "anime style",
    "🎨 水彩艺术": "watercolor painting",
}

# Quality labels shown in the UI.
QUALITY_OPTIONS = ["高清细节", "复杂构图", "专业光影", "4K分辨率"]

# Quality label -> English keywords. The text encoder of Stable-Diffusion-style
# models is trained on English, so (like STYLE_OPTIONS) the Chinese labels are
# translated before they reach the prompt. Unknown labels pass through as-is.
QUALITY_TERMS = {
    "高清细节": "highly detailed",
    "复杂构图": "intricate composition",
    "专业光影": "professional lighting",
    "4K分辨率": "4k resolution",
}


# ---- Core functions ----
def enhance_prompt(short_prompt: str, style: str, quality: list) -> str:
    """Build the final prompt for the image model.

    Joins the description, style keywords and quality keywords into a
    comma-separated base prompt. When an OpenAI client is available, asks
    gpt-3.5-turbo to expand it; any OpenAI failure (or an empty reply) falls
    back to the base prompt.

    Args:
        short_prompt: User description; must contain non-whitespace text.
        style: Style keywords (normally a value of STYLE_OPTIONS); may be empty.
        quality: Selected quality labels; may be empty or None.

    Returns:
        The prompt string to pass to the image pipeline.

    Raises:
        gr.Error: If *short_prompt* is empty or whitespace-only.
    """
    if not short_prompt or not short_prompt.strip():
        raise gr.Error("描述内容不能为空")

    # Drop empty parts so an empty quality selection (or empty style) does
    # not leave dangling ", " separators in the prompt.
    quality_terms = [QUALITY_TERMS.get(q, q) for q in (quality or [])]
    parts = [short_prompt.strip(), style, *quality_terms]
    base_prompt = ", ".join(part for part in parts if part)

    if not openai_available or openai_client is None:
        return base_prompt

    try:
        response = openai_client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[
                {
                    "role": "system",
                    "content": "你是一个AI绘画提示词专家,请把用户的简短描述扩展为适合小模型使用的详细提示词。",
                },
                {
                    "role": "user",
                    "content": f"请优化这个提示词:'{base_prompt}'。要求:使用英文,保持简洁,适合快速生成,包含主要视觉元素,只输出提示词本身。",
                },
            ],
            temperature=0.7,
            max_tokens=100,
        )
        # `content` can be None; strip surrounding whitespace and quotes.
        content = response.choices[0].message.content or ""
        enhanced = content.strip().strip('"').strip()
        return enhanced or base_prompt
    except Exception as e:
        print(config.error_msg(f"提示词优化失败: {e}"))
        return base_prompt


def generate_image(prompt: str, neg_prompt: str, cfg: float, steps: int) -> Image.Image:
    """Run the text-to-image pipeline and return the first generated image.

    Args:
        prompt: Positive prompt.
        neg_prompt: Negative prompt (may be empty).
        cfg: Guidance scale.
        steps: Number of inference steps; coerced to int because UI sliders
            may deliver floats.

    Raises:
        gr.Error: If the model failed to load or generation raised.
    """
    if isinstance(image_pipe, DummyPipe):
        raise gr.Error("图像生成功能不可用:模型加载失败")

    try:
        # Seed from wall-clock seconds so repeated prompts vary between runs.
        generator = torch.Generator(config.DEVICE).manual_seed(int(time.time()))
        with torch.no_grad():
            result = image_pipe(
                prompt=prompt,
                negative_prompt=neg_prompt,
                guidance_scale=float(cfg),
                num_inference_steps=int(steps),
                generator=generator,
            )
        return result.images[0]
    except Exception as e:
        raise gr.Error(f"生成失败: {str(e)}") from e


def transcribe_audio(audio_path: str) -> str:
    """Transcribe *audio_path* with the ASR pipeline.

    Returns "" when ASR is unavailable, no path was given, or recognition
    fails (the failure is logged).
    """
    if asr_pipeline is None or not audio_path:
        return ""
    try:
        return asr_pipeline(audio_path)["text"].strip()
    except Exception as e:
        print(config.error_msg(f"语音识别失败: {e}"))
        return ""


# ---- UI callback ----
def process_inputs(
    text: str,
    audio: Optional[str],
    style: str,
    quality: list,
    neg_prompt: str,
    cfg: float,
    steps: int,
) -> Tuple[str, Optional[Image.Image]]:
    """Main click handler: input -> enhanced prompt -> image.

    A successful transcription of *audio* takes precedence over *text*.

    Returns:
        (status message, image), where image is None on failure. Errors are
        reported in the status message rather than raised.
    """
    try:
        final_text = (text or "").strip()
        if audio and os.path.exists(audio):
            final_text = transcribe_audio(audio) or final_text

        # A cleared dropdown yields no style; enhance_prompt skips empty parts.
        enhanced = enhance_prompt(final_text, STYLE_OPTIONS.get(style, ""), quality)

        start_time = time.perf_counter()
        image = generate_image(enhanced, neg_prompt, cfg, int(steps))
        time_cost = time.perf_counter() - start_time

        return f"✅ 生成成功(耗时:{time_cost:.1f}s)\n{enhanced}", image
    except gr.Error as e:
        # Expected, user-facing failures: their message is already complete
        # (use `.message`, since str() of a gr.Error may be repr-quoted).
        return f"❌ {getattr(e, 'message', e)}", None
    except Exception as e:
        # Unexpected failure: keep the full traceback in the server log.
        traceback.print_exc()
        return f"❌ 生成失败:{str(e)}", None


# ---- Gradio interface ----
with gr.Blocks(theme=gr.themes.Soft(), css=config.UI["warning_css"]) as app:
    # Header
    gr.Markdown(f"## {config.UI['title']}")
    gr.Markdown(config.UI["description"])

    # Status banners (styled by the `.warning` CSS class above).
    if not openai_available:
        gr.HTML("<div class='warning'>⚠️ OpenAI服务未启用,使用基础提示优化</div>")
    if isinstance(image_pipe, DummyPipe):
        gr.HTML("<div class='warning'>⚠️ 图像生成功能不可用:模型加载失败</div>")

    with gr.Row():
        # Input column
        with gr.Column(scale=1):
            input_text = gr.Textbox(
                label="📝 输入描述",
                placeholder="例:机械猫在火星咖啡馆喝咖啡",
                max_lines=3,
            )
            audio_input = gr.Audio(
                sources=["microphone"],
                type="filepath",
                label="🎤 语音输入",
                visible=asr_pipeline is not None,
            )

            with gr.Accordion("⚙️ 高级参数", open=False):
                style_select = gr.Dropdown(
                    label="艺术风格",
                    choices=list(STYLE_OPTIONS.keys()),
                    value="🎥 电影风格",
                )
                quality_check = gr.CheckboxGroup(
                    label="质量增强",
                    choices=QUALITY_OPTIONS,
                    value=["高清细节"],
                )
                neg_prompt = gr.Textbox(
                    label="🚫 排除内容",
                    placeholder="输入不希望出现的元素...",
                )
                cfg_slider = gr.Slider(
                    1.0,
                    10.0,
                    value=config.DEFAULT_GUIDANCE,
                    label="生成引导强度",
                )
                # step=1: without it Gradio derives a fractional step and
                # would pass non-integer step counts to the pipeline.
                steps_slider = gr.Slider(
                    5,
                    config.MAX_STEPS,
                    value=config.DEFAULT_STEPS,
                    step=1,
                    label="迭代步数",
                )

            generate_btn = gr.Button(
                "✨ 开始生成",
                variant="primary",
                interactive=not isinstance(image_pipe, DummyPipe),
            )

        # Output column
        with gr.Column(scale=1):
            prompt_output = gr.Textbox(
                label="📋 生成提示",
                interactive=False,
                lines=4,
            )
            image_output = gr.Image(
                label="🖼️ 生成结果",
                type="pil",
                height=512,
                show_download_button=True,
            )

    # Event wiring
    inputs = [
        input_text,
        audio_input,
        style_select,
        quality_check,
        neg_prompt,
        cfg_slider,
        steps_slider,
    ]
    generate_btn.click(process_inputs, inputs, [prompt_output, image_output])

    # Clear the text box once an audio recording is provided.
    if asr_pipeline is not None:
        audio_input.change(
            lambda x: "" if x else gr.update(),
            audio_input,
            input_text,
        )


# ---- Entry point ----
if __name__ == "__main__":
    app.launch(server_name="0.0.0.0", server_port=7860)