import gradio as gr
import torch
from transformers import pipeline, set_seed
from diffusers import StableDiffusionPipeline
import os
import time
# ---- Configuration & model loading (done once at app startup) ----
# Use the GPU when available, otherwise fall back to CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# 1. Speech-to-text model (Whisper) - bonus feature
asr_pipeline = None
try:
    print("Loading ASR pipeline (Whisper)...")
    # A smaller model saves resources; swap in openai/whisper-medium or large if needed.
    # Whisper also runs on CPU, so parts of the app that don't need the GPU can force CPU (-1).
    asr_pipeline = pipeline(
        "automatic-speech-recognition",
        model="openai/whisper-base",
        device=device if device == "cuda" else -1,
    )
    print("ASR pipeline loaded.")
except Exception as e:
    print(f"Could not load ASR pipeline: {e}. Voice input will be disabled.")
# 2. Prompt enhancement model (LLM) - Step 1
prompt_enhancer_pipeline = None
try:
    print("Loading Prompt Enhancer pipeline (GPT-2)...")
    # GPT-2 is used here as an example; a real deployment should use a stronger
    # instruction-tuned model such as Mistral or Llama.
    # Note: GPT-2 will not produce especially high-quality SD prompts; this only
    # demonstrates the structure. With enough memory/GPU, swap in
    # 'mistralai/Mistral-7B-Instruct-v0.1' or similar.
    # text-generation also runs on CPU.
    prompt_enhancer_pipeline = pipeline(
        "text-generation",
        model="gpt2",
        device=device if device == "cuda" else -1,
    )
    print("Prompt Enhancer pipeline loaded.")
except Exception as e:
    print(f"Could not load Prompt Enhancer pipeline: {e}. Prompt enhancement might fail.")
# 3. Text-to-image model (Stable Diffusion) - Step 2
image_generator_pipe = None
try:
    print("Loading Stable Diffusion pipeline (v1.5)...")
    model_id = "runwayml/stable-diffusion-v1-5"
    image_generator_pipe = StableDiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    )
    image_generator_pipe = image_generator_pipe.to(device)
    # If memory is tight, CPU offloading can be enabled (requires the accelerate library)
    # image_generator_pipe.enable_model_cpu_offload()
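    # Attention slicing is another memory saver available on this diffusers
    # pipeline; it trades some speed for lower peak VRAM:
    # image_generator_pipe.enable_attention_slicing()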
print("Stable Diffusion pipeline loaded.")
except Exception as e:
print(f"Could not load Stable Diffusion pipeline: {e}. Image generation will fail.")
# 如果模型加载失败,创建一个虚拟对象以避免后续代码出错
class DummyPipe:
def __call__(self, *args, **kwargs):
# 返回一个占位符错误信息或图像
raise RuntimeError(f"Stable Diffusion model failed to load: {e}")
image_generator_pipe = DummyPipe()
# ---- Core functions ----
# Step 1: Prompt-to-Prompt
def enhance_prompt(short_prompt, style_modifier="cinematic", quality_boost="photorealistic, highly detailed"):
    """Expand a short description into a detailed prompt using the LLM."""
    if not prompt_enhancer_pipeline:
        return f"[Error: LLM not loaded] Original prompt: {short_prompt}"
    if not short_prompt:
        return "[Error: Input description is empty]"
    # Build the instruction for the LLM.
    # Note: this instruction is likely too complex for GPT-2; it works better
    # with instruction-tuned models such as Mistral.
    input_text = (
        f"Generate a detailed and vivid prompt for an AI image generator based on the following description. "
        f"Incorporate the style '{style_modifier}' and quality boost '{quality_boost}'. "
        f"Focus on visual details, lighting, composition, and mood. "
        f"Description: \"{short_prompt}\"\n\n"
        f"Detailed Prompt:"
    )
    try:
        # Seed from the clock for (loosely) reproducible results
        set_seed(int(time.time()))
        # max_length caps the total generated length (including the input)
        # num_return_sequences: how many candidates to return
        # temperature controls randomness; lower values are more conservative
        # no_repeat_ngram_size discourages repeated phrases
        outputs = prompt_enhancer_pipeline(
            input_text,
            max_length=150,  # cap the output so it does not run on
            num_return_sequences=1,
            temperature=0.7,
            no_repeat_ngram_size=2,
            pad_token_id=prompt_enhancer_pipeline.tokenizer.eos_token_id,  # silence the padding warning
        )
        generated_text = outputs[0]['generated_text']
        # Extract the enhanced prompt from the LLM's full output:
        # the simple approach is to take everything after "Detailed Prompt:"
        enhanced = generated_text.split("Detailed Prompt:")[-1].strip()
        # Strip any echo of the original input or instruction from the start
        if short_prompt in enhanced[:len(short_prompt) + 5]:
            enhanced = enhanced.replace(short_prompt, "", 1).strip(' ,"')
        # Append the base style and quality terms if the LLM left them out
        if style_modifier not in enhanced:
            enhanced += f", {style_modifier}"
        if quality_boost not in enhanced:
            enhanced += f", {quality_boost}"
        return enhanced
    except Exception as e:
        print(f"Error during prompt enhancement: {e}")
        return f"[Error: Prompt enhancement failed] Original prompt: {short_prompt}"
# Step 2: Prompt-to-Image
def generate_image(prompt, negative_prompt, guidance_scale, num_inference_steps):
    """Generate an image with Stable Diffusion."""
    if not isinstance(image_generator_pipe, StableDiffusionPipeline):
        # gr.Error surfaces the message in the UI
        raise gr.Error("Stable Diffusion model is not available; it failed to load at startup.")
    if not prompt or "[Error:" in prompt:
        raise gr.Error("Cannot generate image due to invalid or missing prompt.")
    print(f"Generating image for prompt: {prompt}")
    print(f"Negative prompt: {negative_prompt}")
    print(f"Guidance scale: {guidance_scale}, Steps: {num_inference_steps}")
    try:
        # Seed the generator from the clock
        generator = torch.Generator(device=device).manual_seed(int(time.time()))
        # Run inference
        with torch.inference_mode():  # saves memory
            image = image_generator_pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                guidance_scale=float(guidance_scale),
                num_inference_steps=int(num_inference_steps),
                generator=generator,
            ).images[0]
        print("Image generated successfully.")
        return image
    except Exception as e:
        print(f"Error during image generation: {e}")
        # Re-raise as gr.Error so the message shows up in the UI
        raise gr.Error(f"Image generation failed: {e}")
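# Illustrative direct call, bypassing the UI (the argument values are arbitrary examples):
#   img = generate_image("a castle at dusk, cinematic", "blurry", 7.5, 30)
#   img.save("castle.png")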
# Bonus: Voice-to-Text
def transcribe_audio(audio_file_path):
    """Transcribe an audio file to text."""
    if not asr_pipeline:
        return "[Error: ASR model not loaded]", ""  # error message plus empty path
    if audio_file_path is None:
        return "", ""  # no audio input
    print(f"Transcribing audio file: {audio_file_path}")
    try:
        transcription = asr_pipeline(audio_file_path)["text"]
        print(f"Transcription result: {transcription}")
        return transcription, audio_file_path  # text plus path (may be used for display)
    except Exception as e:
        print(f"Error during audio transcription: {e}")
        return f"[Error: Transcription failed: {e}]", audio_file_path
# ---- Gradio app flow ----
def process_input(input_text, audio_file, style_choice, quality_choice, neg_prompt, guidance, steps):
    """Handle text or voice input, then produce the enhanced prompt and the image."""
    final_text_input = ""
    transcription_source = ""  # notes where the input came from
    # Prefer the text box when it is filled in
    if input_text and input_text.strip():
        final_text_input = input_text.strip()
        transcription_source = " (from text input)"
    # Fall back to voice input when the text box is empty and audio is provided
    elif audio_file is not None:
        transcribed_text, _ = transcribe_audio(audio_file)
        if transcribed_text and "[Error:" not in transcribed_text:
            final_text_input = transcribed_text
            transcription_source = " (from audio input)"
        elif "[Error:" in transcribed_text:
            # Transcription failed: surface the error and skip image generation
            return transcribed_text, None
        else:
            # Audio was empty or transcription returned nothing
            return "[Error: Please provide input via text or voice]", None
    else:
        # No usable input at all
        return "[Error: Please provide input via text or voice]", None
    print(f"Using input: '{final_text_input}'{transcription_source}")
    # Step 1: Enhance prompt
    enhanced_prompt = enhance_prompt(final_text_input, style_modifier=style_choice, quality_boost=quality_choice)
    print(f"Enhanced prompt: {enhanced_prompt}")
    # Step 2: Generate image (only if prompt enhancement succeeded)
    generated_image = None
    if "[Error:" not in enhanced_prompt:
        try:
            generated_image = generate_image(enhanced_prompt, neg_prompt, guidance, steps)
        except gr.Error as e:
            # generate_image raised gr.Error: append its message to the prompt
            # output and do not attempt to show an image
            enhanced_prompt = f"{enhanced_prompt}\n\n[Image Generation Error: {e}]"
        except Exception as e:
            # Catch any other unexpected error
            enhanced_prompt = f"{enhanced_prompt}\n\n[Unexpected Image Generation Error: {e}]"
    # Hand the results back to the Gradio interface
    return enhanced_prompt, generated_image
# ---- Gradio UI (Step 3: Controls & Step 4: Layout) ----
# Style and quality-boost options for the Dropdown/Radio controls
style_options = ["cinematic", "photorealistic", "anime", "fantasy art", "cyberpunk", "steampunk", "watercolor"]
quality_options = ["highly detailed", "sharp focus", "intricate details", "4k", "masterpiece", "best quality"]
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# AI Image Generator: From Idea to Image")
    gr.Markdown("Enter a short description (or use voice input), and the app will enhance it into a detailed prompt and generate an image using Stable Diffusion.")
    with gr.Row():
        with gr.Column(scale=1):
            # Input area
            inp_text = gr.Textbox(label="Enter short description here", placeholder="e.g., A magical treehouse in the sky")
            # Bonus: voice input control, shown only when the ASR model loaded
            inp_audio = gr.Audio(sources=["microphone"], type="filepath", label="Or record your idea (clears text box if used)", visible=asr_pipeline is not None)
            # Step 3: a variety of controls
            # Control 1: Dropdown
            inp_style = gr.Dropdown(label="Choose Base Style", choices=style_options, value="cinematic")
            # Control 2: Radio (a CheckboxGroup would allow multiple selections)
            inp_quality = gr.Radio(label="Quality Boost", choices=quality_options, value="highly detailed")
            # Control 3: Textbox for the negative prompt
            inp_neg_prompt = gr.Textbox(label="Negative Prompt (optional)", placeholder="e.g., blurry, low quality, text, watermark")
            # Control 4: Slider
            inp_guidance = gr.Slider(minimum=1.0, maximum=20.0, step=0.5, value=7.5, label="Guidance Scale (CFG)")
            # Control 5: Slider
            inp_steps = gr.Slider(minimum=10, maximum=100, step=1, value=30, label="Inference Steps")
            # Submit button
            btn_generate = gr.Button("Generate Image", variant="primary")
        with gr.Column(scale=1):
            # Output area
            out_prompt = gr.Textbox(label="Generated Prompt", interactive=False)  # read-only output
            out_image = gr.Image(label="Generated Image", type="pil")  # generated image
    # Wire up the button click
    btn_generate.click(
        fn=process_input,
        inputs=[inp_text, inp_audio, inp_style, inp_quality, inp_neg_prompt, inp_guidance, inp_steps],
        outputs=[out_prompt, out_image]
    )
    # (Optional) Clear the text box after a recording so voice input clearly takes priority
    if asr_pipeline:
        def clear_text_on_audio(audio_data):
            if audio_data is not None:
                return ""  # an empty string clears the text box
            return gr.update()  # otherwise leave the text box unchanged
        inp_audio.change(fn=clear_text_on_audio, inputs=inp_audio, outputs=inp_text)
# ---- Launch the app ----
if __name__ == "__main__":
    # Set a Hugging Face Hub token if models must be loaded from a private repo:
    # from huggingface_hub import login
    # login("YOUR_HF_TOKEN")  # uncomment and replace when running locally
    # On Hugging Face Spaces the port is usually managed by the platform.
    # share=True creates a public link (useful when running locally).
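    # For long-running generations, enabling Gradio's request queue helps avoid
    # HTTP timeouts (a sketch using the standard Blocks.queue API):
    # demo.queue().launch(share=False)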
    demo.launch(share=False)