import gradio as gr
import torch
from transformers import pipeline
from diffusers import AutoPipelineForText2Image
import openai
import os
import time
from typing import Optional, Tuple, Union, TypedDict
from PIL import Image

# Provide the OpenAI API key via the environment (e.g. a Space secret) rather than
# hardcoding it here; a placeholder string would make the availability check below
# pass while every API call fails.
# os.environ["OPENAI_API_KEY"] = "sk-..."
# ---- Type definitions ----
class ModelConfig(TypedDict):
    model_id: str
    dtype: torch.dtype

class UIConfig(TypedDict):
    title: str
    description: str
    warning_css: str

# ---- Configuration ----
class AppConfig:
    # Hardware
    DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu"
    # Model
    MODEL: ModelConfig = {
        "model_id": "nota-ai/bk-sdm-tiny",
        "dtype": torch.float32
    }
    # UI
    UI: UIConfig = {
        "title": "🎨 Lightweight AI Image Generator (CPU/GPU)",
        "description": """\
💡 Tips: enter a short description, then pick a style and quality options\n
🚀 Voice input • automatic prompt enhancement • fast generation mode\n
⚠️ Note: the small model is fast but light on detail; concrete descriptions work best""",
        "warning_css": """
.warning {color: orange !important; border-left: 3px solid orange; padding: 10px;}
.success {color: green !important;}
"""
    }
    # Generation defaults
    DEFAULT_STEPS: int = 20
    MAX_STEPS: int = 40
    DEFAULT_GUIDANCE: float = 5.0

    # Error message template (static, since it is called as config.error_msg(...))
    @staticmethod
    def error_msg(message: str) -> str:
        return f"❌ Error: {message}"

config = AppConfig()
# ---- Startup checks ----
openai_client: Optional[openai.OpenAI] = None
openai_available: bool = False
if os.environ.get("OPENAI_API_KEY"):
    try:
        openai_client = openai.OpenAI(api_key=os.environ["OPENAI_API_KEY"])
        openai_available = True
        print("✅ OpenAI client initialized")
    except Exception as e:
        print(config.error_msg(f"OpenAI initialization failed: {e}"))
# ---- Model loading ----
class DummyPipe:
    """Fallback sentinel used when the image model fails to load."""
    def __call__(self, *args, **kwargs) -> None:
        raise RuntimeError("Image generation model is not loaded")

# Speech recognition model
asr_pipeline = None
try:
    asr_pipeline = pipeline(
        "automatic-speech-recognition",
        model="openai/whisper-base",
        device=config.DEVICE,
        torch_dtype=config.MODEL["dtype"]
    )
    print("✅ Speech recognition model loaded")
except Exception as e:
    print(config.error_msg(f"Speech model failed to load: {e}"))

# Image generation model
image_pipe: Union[AutoPipelineForText2Image, DummyPipe] = DummyPipe()
try:
    image_pipe = AutoPipelineForText2Image.from_pretrained(
        config.MODEL["model_id"],
        torch_dtype=config.MODEL["dtype"],
        use_safetensors=True
    ).to(config.DEVICE)
    print(f"✅ Image model {config.MODEL['model_id']} loaded")
except Exception as e:
    print(config.error_msg(f"Image model failed to load: {e}"))
# ---- Core functionality ----
def enhance_prompt(short_prompt: str, style: str, quality: list) -> str:
    """Prompt enhancement: template-based, with optional OpenAI rewriting."""
    if not short_prompt.strip():
        raise gr.Error("Description must not be empty")

    # Basic template enhancement
    base_prompt = f"{short_prompt.strip()}, {style}, {', '.join(quality)}"
    if not openai_available:
        return base_prompt

    try:
        response = openai_client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[{
                "role": "system",
                "content": "You are an AI art prompt expert. Expand the user's short description into a detailed prompt suited to a small text-to-image model."
            }, {
                "role": "user",
                "content": f"Please refine this prompt: '{base_prompt}'. Keep it concise, suitable for fast generation, and include the main visual elements."
            }],
            temperature=0.7,
            max_tokens=100
        )
        return response.choices[0].message.content.strip('"')
    except Exception as e:
        print(config.error_msg(f"Prompt enhancement failed: {e}"))
        return base_prompt
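# Example: with OpenAI unavailable, enhance_prompt falls back to plain templating:
#   enhance_prompt("a mechanical cat", "anime style", ["highly detailed", "4K resolution"])
#   -> "a mechanical cat, anime style, highly detailed, 4K resolution"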
def generate_image(prompt: str, neg_prompt: str, cfg: float, steps: int) -> Image.Image:
    """Core image generation."""
    if isinstance(image_pipe, DummyPipe):
        raise gr.Error("Image generation is unavailable: the model failed to load")
    try:
        with torch.no_grad():
            result = image_pipe(
                prompt=prompt,
                negative_prompt=neg_prompt,
                guidance_scale=cfg,
                num_inference_steps=steps,
                generator=torch.Generator(config.DEVICE).manual_seed(int(time.time()))
            )
        return result.images[0]
    except Exception as e:
        raise gr.Error(f"Generation failed: {str(e)}")
def transcribe_audio(audio_path: str) -> str:
    """Speech-to-text."""
    if not asr_pipeline or not audio_path:
        return ""
    try:
        # The ASR pipeline returns a dict like {"text": "..."}
        return asr_pipeline(audio_path)["text"].strip()
    except Exception as e:
        print(config.error_msg(f"Speech recognition failed: {e}"))
        return ""
# ---- UI logic ----
STYLE_OPTIONS = {
    "🎥 Cinematic": "cinematic lighting",
    "🖼️ Photorealistic": "photorealistic",
    "🇯🇵 Anime": "anime style",
    "🎨 Watercolor": "watercolor painting"
}
QUALITY_OPTIONS = [
    "highly detailed", "intricate composition",
    "professional lighting", "4K resolution"
]
def process_inputs(
    text: str,
    audio: Optional[str],
    style: str,
    quality: list,
    neg_prompt: str,
    cfg: float,
    steps: int
) -> Tuple[str, Optional[Image.Image]]:
    """Main processing pipeline."""
    try:
        # Input handling: prefer a transcribed voice note over the typed text
        final_text = text.strip()
        if audio and os.path.exists(audio):
            final_text = transcribe_audio(audio) or final_text

        # Prompt enhancement
        enhanced = enhance_prompt(final_text, STYLE_OPTIONS[style], quality)

        # Image generation
        start_time = time.time()
        image = generate_image(enhanced, neg_prompt, cfg, steps)
        time_cost = time.time() - start_time

        return f"✅ Generated successfully (took {time_cost:.1f}s)\n{enhanced}", image
    except Exception as e:
        return f"❌ Generation failed: {str(e)}", None
# ---- Gradio UI ----
with gr.Blocks(theme=gr.themes.Soft(), css=config.UI["warning_css"]) as app:
    # Header
    gr.Markdown(f"## {config.UI['title']}")
    gr.Markdown(config.UI["description"])

    # Status notices
    if not openai_available:
        gr.HTML("<div class='warning'>⚠️ OpenAI is not enabled; using basic prompt enhancement</div>")
    if isinstance(image_pipe, DummyPipe):
        gr.HTML("<div class='warning'>⚠️ Image generation is unavailable: the model failed to load</div>")

    with gr.Row():
        # Input column
        with gr.Column(scale=1):
            input_text = gr.Textbox(
                label="📝 Description",
                placeholder="e.g. a mechanical cat drinking coffee in a Martian café",
                max_lines=3
            )
            audio_input = gr.Audio(
                sources=["microphone"],
                type="filepath",
                label="🎤 Voice input",
                visible=bool(asr_pipeline)
            )
            with gr.Accordion("⚙️ Advanced settings", open=False):
                style_select = gr.Dropdown(
                    label="Art style",
                    choices=list(STYLE_OPTIONS.keys()),
                    value="🎥 Cinematic"
                )
                quality_check = gr.CheckboxGroup(
                    label="Quality boosters",
                    choices=QUALITY_OPTIONS,
                    value=["highly detailed"]
                )
                neg_prompt = gr.Textbox(
                    label="🚫 Negative prompt",
                    placeholder="Elements you do not want to appear..."
                )
                cfg_slider = gr.Slider(
                    1.0, 10.0,
                    value=config.DEFAULT_GUIDANCE,
                    label="Guidance scale"
                )
                steps_slider = gr.Slider(
                    5, config.MAX_STEPS,
                    value=config.DEFAULT_STEPS,
                    label="Inference steps"
                )
            generate_btn = gr.Button(
                "✨ Generate",
                variant="primary",
                interactive=not isinstance(image_pipe, DummyPipe)
            )
        # Output column
        with gr.Column(scale=1):
            prompt_output = gr.Textbox(
                label="📋 Final prompt",
                interactive=False,
                lines=4
            )
            image_output = gr.Image(
                label="🖼️ Result",
                type="pil",
                height=512,
                show_download_button=True
            )

    # Event wiring
    inputs = [input_text, audio_input, style_select, quality_check, neg_prompt, cfg_slider, steps_slider]
    generate_btn.click(process_inputs, inputs, [prompt_output, image_output])

    # Recording audio clears the text box
    if asr_pipeline:
        audio_input.change(
            lambda x: "" if x else gr.update(),
            audio_input, input_text
        )
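# Optional (an assumption, not in the original app): on shared hosts, enabling
# Gradio's request queue with `app.queue()` before launch helps long generations
# avoid request timeouts.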
# ---- Launch ----
if __name__ == "__main__":
    app.launch(server_name="0.0.0.0", server_port=7860)