tangchao / app.py
tangchao5355's picture
Update app.py
a4780f4 verified
import gradio as gr
import torch
from transformers import pipeline, set_seed
from diffusers import AutoPipelineForText2Image
import openai
import os
import time
import traceback
from typing import Optional, Tuple, Union, Literal, TypedDict
from PIL import Image
# The OpenAI API key is read from the OPENAI_API_KEY environment variable set
# by the deployment environment. Do not assign a key here: a hard-coded
# placeholder would overwrite a real key supplied by the environment and make
# the app treat an invalid key as a working one.
import os
# ---- Type definitions ----
class ModelConfig(TypedDict):
    """Settings used when loading the image-generation model."""

    model_id: str  # model identifier handed to from_pretrained
    dtype: torch.dtype  # tensor dtype handed over as torch_dtype
    timeout: int  # value handed over as the timeout argument
class UIConfig(TypedDict):
    """Static text and styling used by the Gradio interface."""

    title: str  # heading rendered at the top of the page
    description: str  # Markdown rendered below the heading
    warning_css: str  # CSS handed to gr.Blocks
# ---- Configuration ----
class AppConfig:
    """Central place for hardware, model, interface and sampling settings."""

    # Hardware: use the GPU when CUDA is available, otherwise the CPU.
    DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu"

    # Model settings.
    MODEL: ModelConfig = {
        "model_id": "nota-ai/bk-sdm-tiny",
        "dtype": torch.float32,
        "timeout": 300,
    }

    # Interface text and CSS. The multi-line strings start at column 0 on
    # purpose so that no indentation leaks into the rendered Markdown/CSS.
    UI: UIConfig = {
        "title": "🎨 轻量级AI图像生成器(CPU/GPU版)",
        "description": """\
💡 使用技巧:输入简短描述后选择风格和质量选项\n
🚀 支持语音输入 • 自动提示词优化 • 快速生成模式\n
⚠️ 注意:小模型生成速度快但细节有限,建议使用具体描述""",
        "warning_css": """
.warning {color: orange !important; border-left: 3px solid orange; padding: 10px;}
.success {color: green !important;}
""",
    }

    # Generation parameters.
    DEFAULT_STEPS: int = 20
    MAX_STEPS: int = 40
    DEFAULT_GUIDANCE: float = 5.0

    @staticmethod
    def error_msg(message: str) -> str:
        """Return *message* prefixed with the app's error marker."""
        return f"❌ 错误:{message}"


# Shared configuration object used by the rest of the module.
config = AppConfig()
# ---- Startup checks ----
# The client stays None (and the flag False) unless an API key is present in
# the environment and the client object can be constructed.
openai_client: Optional[openai.OpenAI] = None
openai_available: bool = False

_api_key = os.environ.get("OPENAI_API_KEY")
if _api_key:
    try:
        openai_client = openai.OpenAI(api_key=_api_key)
        openai_available = True
        print("✅ OpenAI 客户端初始化成功")
    except Exception as exc:
        print(config.error_msg(f"OpenAI 初始化失败: {exc}"))
# ---- Model loading ----
class DummyPipe:
    """Placeholder used while no image-generation pipeline is loaded."""

    def __call__(self, *args, **kwargs) -> None:
        """Raise RuntimeError regardless of the arguments given."""
        raise RuntimeError("图像生成模型未加载")
# Speech-recognition model. The name stays bound to None when loading fails,
# which the rest of the module uses to disable voice input.
asr_pipeline = None
try:
    asr_pipeline = pipeline(
        "automatic-speech-recognition",
        model="openai/whisper-base",
        device=config.DEVICE,
        torch_dtype=config.MODEL["dtype"],
    )
    print("✅ 语音识别模型加载成功")
except Exception as exc:
    # Best effort: report the failure and continue without voice input.
    print(config.error_msg(f"语音模型加载失败: {exc}"))
# Image-generation model. Starts out as a DummyPipe and is only replaced when
# the real pipeline loads and moves to the configured device successfully.
image_pipe: Union[AutoPipelineForText2Image, DummyPipe] = DummyPipe()
try:
    image_pipe = AutoPipelineForText2Image.from_pretrained(
        config.MODEL["model_id"],
        torch_dtype=config.MODEL["dtype"],
        use_safetensors=True,
        resume_download=True,
        timeout=config.MODEL["timeout"],
    ).to(config.DEVICE)
    print(f"✅ 图像模型 {config.MODEL['model_id']} 加载成功")
except Exception as exc:
    # Best effort: report the failure and keep the placeholder in place.
    print(config.error_msg(f"图像模型加载失败: {exc}"))
# ---- Core functions ----
def enhance_prompt(short_prompt: str, style: str, quality: list) -> str:
    """Build the generation prompt, optionally refined through OpenAI.

    Args:
        short_prompt: The user's short description.
        style: Prompt fragment describing the selected art style.
        quality: Selected quality keywords; may be empty or None.

    Returns:
        The OpenAI-refined prompt when the service is available and returns
        usable text, otherwise the comma-joined base prompt.

    Raises:
        gr.Error: If ``short_prompt`` is empty or only whitespace.
    """
    subject = short_prompt.strip()
    if not subject:
        raise gr.Error("描述内容不能为空")

    # Join only the non-empty pieces so that an empty quality selection does
    # not leave a dangling ", " at the end of the prompt.
    pieces = [subject, style, *(quality or [])]
    base_prompt = ", ".join(piece for piece in pieces if piece)

    if not openai_available or openai_client is None:
        return base_prompt

    try:
        response = openai_client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[{
                "role": "system",
                "content": "你是一个AI绘画提示词专家,请把用户的简短描述扩展为适合小模型使用的详细提示词。"
            }, {
                "role": "user",
                "content": f"请优化这个提示词:'{base_prompt}'。要求:保持简洁,适合快速生成,包含主要视觉元素。"
            }],
            temperature=0.7,
            max_tokens=100
        )
        content = response.choices[0].message.content
        # The reply may be missing, blank, or wrapped in quotes and spaces.
        optimized = (content or "").strip().strip('"').strip()
        # Never hand an empty prompt to the image model.
        return optimized or base_prompt
    except Exception as e:
        # Prompt refinement is optional: report and fall back to the base.
        print(config.error_msg(f"提示词优化失败: {e}"))
        return base_prompt
def generate_image(prompt: str, neg_prompt: str, cfg: float, steps: int) -> Image.Image:
    """Run the image pipeline once and return the first generated image.

    Args:
        prompt: Text prompt for the image.
        neg_prompt: Text describing what should not appear.
        cfg: Guidance scale.
        steps: Number of inference steps; fractional values are rounded.

    Returns:
        The first image of the pipeline result.

    Raises:
        gr.Error: If the pipeline is not loaded or generation fails.
    """
    if isinstance(image_pipe, DummyPipe):
        raise gr.Error("图像生成功能不可用:模型加载失败")
    try:
        # A slider without an explicit step can deliver a fractional value,
        # while the pipeline needs a whole, positive number of steps.
        step_count = max(1, int(round(steps)))
        with torch.no_grad():
            result = image_pipe(
                prompt=prompt,
                negative_prompt=neg_prompt,
                guidance_scale=float(cfg),
                num_inference_steps=step_count,
                # Seed from the clock so successive runs differ.
                generator=torch.Generator(config.DEVICE).manual_seed(int(time.time())),
            )
        return result.images[0]
    except Exception as e:
        # Keep the original exception attached for debugging.
        raise gr.Error(f"生成失败: {str(e)}") from e
def transcribe_audio(audio_path: str) -> str:
    """Transcribe the audio file at *audio_path* to text.

    Returns an empty string when no speech pipeline is loaded, when no path
    is given, or when recognition raises.
    """
    if asr_pipeline and audio_path:
        try:
            recognized = asr_pipeline(audio_path)
            return recognized["text"].strip()
        except Exception as exc:
            print(config.error_msg(f"语音识别失败: {exc}"))
    return ""
# ---- Interface logic ----
# Dropdown label -> prompt fragment appended for that art style.
STYLE_OPTIONS = {
    "🎥 电影风格": "cinematic lighting",
    "🖼️ 照片写实": "photorealistic",
    "🇯🇵 二次元": "anime style",
    "🎨 水彩艺术": "watercolor painting",
}

# Quality keywords offered as checkboxes.
QUALITY_OPTIONS = [
    "高清细节",
    "复杂构图",
    "专业光影",
    "4K分辨率",
]
def process_inputs(
    text: str,
    audio: Optional[str],
    style: str,
    quality: list,
    neg_prompt: str,
    cfg: float,
    steps: int
) -> Tuple[str, Optional[Image.Image]]:
    """Run the full pipeline: input -> prompt -> image.

    Returns a ``(status_text, image)`` pair. Any exception is turned into a
    failure message with ``None`` in place of the image.
    """
    try:
        # Typed text is the default; a non-empty transcript replaces it.
        final_text = text.strip()
        if audio and os.path.exists(audio):
            transcript = transcribe_audio(audio)
            if transcript:
                final_text = transcript

        # Prompt refinement.
        style_fragment = STYLE_OPTIONS[style]
        enhanced = enhance_prompt(final_text, style_fragment, quality)

        # Image generation, timed for the status line.
        started = time.time()
        image = generate_image(enhanced, neg_prompt, cfg, steps)
        time_cost = time.time() - started

        status = f"✅ 生成成功(耗时:{time_cost:.1f}s)\n{enhanced}"
        return status, image
    except Exception as exc:
        return f"❌ 生成失败:{str(exc)}", None
# ---- Gradio interface ----
with gr.Blocks(theme=gr.themes.Soft(), css=config.UI["warning_css"]) as app:
    # Header
    gr.Markdown(f"## {config.UI['title']}")
    gr.Markdown(config.UI["description"])

    # Status banners for features that failed to initialise.
    if not openai_available:
        gr.HTML("<div class='warning'>⚠️ OpenAI服务未启用,使用基础提示优化</div>")
    if isinstance(image_pipe, DummyPipe):
        gr.HTML("<div class='warning'>⚠️ 图像生成功能不可用:模型加载失败</div>")

    with gr.Row():
        # Input column
        with gr.Column(scale=1):
            input_text = gr.Textbox(
                label="📝 输入描述",
                placeholder="例:机械猫在火星咖啡馆喝咖啡",
                max_lines=3
            )
            # Voice input is hidden when the speech model is unavailable.
            audio_input = gr.Audio(
                sources=["microphone"],
                type="filepath",
                label="🎤 语音输入",
                visible=bool(asr_pipeline)
            )
            with gr.Accordion("⚙️ 高级参数", open=False):
                style_select = gr.Dropdown(
                    label="艺术风格",
                    choices=list(STYLE_OPTIONS.keys()),
                    value="🎥 电影风格"
                )
                quality_check = gr.CheckboxGroup(
                    label="质量增强",
                    choices=QUALITY_OPTIONS,
                    value=["高清细节"]
                )
                neg_prompt = gr.Textbox(
                    label="🚫 排除内容",
                    placeholder="输入不希望出现的元素..."
                )
                cfg_slider = gr.Slider(
                    1.0, 10.0,
                    value=config.DEFAULT_GUIDANCE,
                    label="生成引导强度"
                )
                # step=1 keeps the value a whole number: without an explicit
                # step the slider uses a fractional increment, and the
                # pipeline needs an integer number of inference steps.
                steps_slider = gr.Slider(
                    5, config.MAX_STEPS,
                    value=config.DEFAULT_STEPS,
                    step=1,
                    label="迭代步数"
                )
            # The button is disabled when the image model failed to load.
            generate_btn = gr.Button(
                "✨ 开始生成",
                variant="primary",
                interactive=not isinstance(image_pipe, DummyPipe)
            )
        # Output column
        with gr.Column(scale=1):
            prompt_output = gr.Textbox(
                label="📋 生成提示",
                interactive=False,
                lines=4
            )
            image_output = gr.Image(
                label="🖼️ 生成结果",
                type="pil",
                height=512,
                show_download_button=True
            )

    # Event wiring
    inputs = [input_text, audio_input, style_select, quality_check, neg_prompt, cfg_slider, steps_slider]
    generate_btn.click(process_inputs, inputs, [prompt_output, image_output])

    # Clear the text box when a recording arrives; leave it alone otherwise.
    if asr_pipeline:
        audio_input.change(
            lambda x: "" if x else gr.update(),
            audio_input, input_text
        )
# ---- Launch ----
if __name__ == "__main__":
    # Serve on all network interfaces, port 7860.
    app.launch(
        server_name="0.0.0.0",
        server_port=7860,
    )