start3406 commited on
Commit
c102ebc
·
verified ·
1 Parent(s): 298a72b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +281 -0
app.py ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import pipeline, set_seed
4
+ from diffusers import StableDiffusionPipeline
5
+ import os
6
+ import time
7
+
8
+ # ---- 配置与模型加载 (在应用启动时加载一次) ----
9
+
10
+ # 检查是否有可用的GPU,否则使用CPU
11
+ device = "cuda" if torch.cuda.is_available() else "cpu"
12
+ print(f"Using device: {device}")
13
+
14
+ # 1. 语音转文本模型 (Whisper) - 加分项
15
+ asr_pipeline = None
16
+ try:
17
+ print("Loading ASR pipeline (Whisper)...")
18
+ # 使用较小的模型以节省资源,可根据需要替换 openai/whisper-medium 或 large
19
+ # 在不需要GPU的应用部分可以强制使用CPU
20
+ asr_pipeline = pipeline("automatic-speech-recognition", model="openai/whisper-base", device=device if device == "cuda" else -1) # whisper在CPU上也可以运行
21
+ print("ASR pipeline loaded.")
22
+ except Exception as e:
23
+ print(f"Could not load ASR pipeline: {e}. Voice input will be disabled.")
24
+
25
+ # 2. 提示词增强模型 (LLM) - Step 1
26
+ prompt_enhancer_pipeline = None
27
+ try:
28
+ print("Loading Prompt Enhancer pipeline (GPT-2)...")
29
+ # 使用 GPT-2 作为示例,实际应用中建议使用更强大的指令微调模型如 Mistral 或 Llama
30
+ # 注意:GPT-2 可能不会生成特别高质量的SD提示词,这里仅作结构演示
31
+ # 如果资源允许,可以替换为 'mistralai/Mistral-7B-Instruct-v0.1' 等,但需要更多内存/GPU
32
+ prompt_enhancer_pipeline = pipeline("text-generation", model="gpt2", device=device if device == "cuda" else -1) # text-generation在CPU上也可以运行
33
+ print("Prompt Enhancer pipeline loaded.")
34
+ except Exception as e:
35
+ print(f"Could not load Prompt Enhancer pipeline: {e}. Prompt enhancement might fail.")
36
+
37
+ # 3. 文本到图像模型 (Stable Diffusion) - Step 2
38
+ image_generator_pipe = None
39
+ try:
40
+ print("Loading Stable Diffusion pipeline (v1.5)...")
41
+ model_id = "runwayml/stable-diffusion-v1-5"
42
+ image_generator_pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16 if device == "cuda" else torch.float32)
43
+ image_generator_pipe = image_generator_pipe.to(device)
44
+ # 如果内存不足,可以启用CPU offloading (需要 accelerate库)
45
+ # image_generator_pipe.enable_model_cpu_offload()
46
+ print("Stable Diffusion pipeline loaded.")
47
+ except Exception as e:
48
+ print(f"Could not load Stable Diffusion pipeline: {e}. Image generation will fail.")
49
+ # 如果模型加载失败,创建一个虚拟对象以避免后续代码出错
50
+ class DummyPipe:
51
+ def __call__(self, *args, **kwargs):
52
+ # 返回一个占位符错误信息或图像
53
+ raise RuntimeError(f"Stable Diffusion model failed to load: {e}")
54
+ image_generator_pipe = DummyPipe()
55
+
56
+
57
+ # ---- 核心功能函数 ----
58
+
59
+ # Step 1: Prompt-to-Prompt
60
+ def enhance_prompt(short_prompt, style_modifier="cinematic", quality_boost="photorealistic, highly detailed"):
61
+ """使用LLM增强简短描述"""
62
+ if not prompt_enhancer_pipeline:
63
+ return f"[Error: LLM not loaded] Original prompt: {short_prompt}"
64
+ if not short_prompt:
65
+ return "[Error: Input description is empty]"
66
+
67
+ # 构建给LLM的指令
68
+ # 注意:这个指令对GPT-2来说可能太复杂,对Mistral等更有效
69
+ input_text = (
70
+ f"Generate a detailed and vivid prompt for an AI image generator based on the following description. "
71
+ f"Incorporate the style '{style_modifier}' and quality boost '{quality_boost}'. "
72
+ f"Focus on visual details, lighting, composition, and mood. "
73
+ f"Description: \"{short_prompt}\"\n\n"
74
+ f"Detailed Prompt:"
75
+ )
76
+
77
+ try:
78
+ # 设置种子以获得可复现的(某种程度上的)结果
79
+ set_seed(int(time.time()))
80
+ # max_length 控制生成文本的总长度 (包括输入)
81
+ # num_return_sequences 返回多少个结果
82
+ # temperature 控制随机性,较低的值更保守
83
+ # no_repeat_ngram_size 避免重复短语
84
+ outputs = prompt_enhancer_pipeline(
85
+ input_text,
86
+ max_length=150, # 限制输出长度,避免过长
87
+ num_return_sequences=1,
88
+ temperature=0.7,
89
+ no_repeat_ngram_size=2,
90
+ pad_token_id=prompt_enhancer_pipeline.tokenizer.eos_token_id # 避免padding warning
91
+ )
92
+ generated_text = outputs[0]['generated_text']
93
+
94
+ # 从LLM的完整输出中提取增强后的提示词部分
95
+ # 简单方法:取 "Detailed Prompt:" 之后的内容
96
+ enhanced = generated_text.split("Detailed Prompt:")[-1].strip()
97
+ # 进一步清理可能包含的原始输入或指令痕迹
98
+ if short_prompt in enhanced[:len(short_prompt)+5]: # 如果开头包含原始输入
99
+ enhanced = enhanced.replace(short_prompt, "", 1).strip(' ,"')
100
+
101
+ # 添加基础的风格和质量词,如果LLM没有包含的话
102
+ if style_modifier not in enhanced:
103
+ enhanced += f", {style_modifier}"
104
+ if quality_boost not in enhanced:
105
+ enhanced += f", {quality_boost}"
106
+
107
+ return enhanced
108
+
109
+ except Exception as e:
110
+ print(f"Error during prompt enhancement: {e}")
111
+ return f"[Error: Prompt enhancement failed] Original prompt: {short_prompt}"
112
+
113
+ # Step 2: Prompt-to-Image
114
+ def generate_image(prompt, negative_prompt, guidance_scale, num_inference_steps):
115
+ """使用Stable Diffusion生成图像"""
116
+ if not isinstance(image_generator_pipe, StableDiffusionPipeline):
117
+ raise gr.Error(f"Stable Diffusion model is not available. Load error: {image_generator_pipe}") # 使用gr.Error在UI上显示错误
118
+ if not prompt or "[Error:" in prompt:
119
+ raise gr.Error("Cannot generate image due to invalid or missing prompt.")
120
+
121
+ print(f"Generating image for prompt: {prompt}")
122
+ print(f"Negative prompt: {negative_prompt}")
123
+ print(f"Guidance scale: {guidance_scale}, Steps: {num_inference_steps}")
124
+
125
+ try:
126
+ # 设置随机种子
127
+ generator = torch.Generator(device=device).manual_seed(int(time.time()))
128
+
129
+ # 执行推理
130
+ with torch.inference_mode(): # 节省内存
131
+ image = image_generator_pipe(
132
+ prompt=prompt,
133
+ negative_prompt=negative_prompt,
134
+ guidance_scale=float(guidance_scale),
135
+ num_inference_steps=int(num_inference_steps),
136
+ generator=generator
137
+ ).images[0]
138
+ print("Image generated successfully.")
139
+ return image
140
+ except Exception as e:
141
+ print(f"Error during image generation: {e}")
142
+ # 将底层错误传递给 Gradio,使其能在 UI 中显示
143
+ raise gr.Error(f"Image generation failed: {e}")
144
+
145
+
146
+ # Bonus: Voice-to-Text
147
+ def transcribe_audio(audio_file_path):
148
+ """将音频文件转录为文本"""
149
+ if not asr_pipeline:
150
+ return "[Error: ASR model not loaded]", "" # 返回错误信息和空路径
151
+ if audio_file_path is None:
152
+ return "", "" # 没有音频输入
153
+
154
+ print(f"Transcribing audio file: {audio_file_path}")
155
+ try:
156
+ # 转录音频
157
+ transcription = asr_pipeline(audio_file_path)["text"]
158
+ print(f"Transcription result: {transcription}")
159
+ return transcription, audio_file_path # 返回文本和路径(可能用于显示)
160
+ except Exception as e:
161
+ print(f"Error during audio transcription: {e}")
162
+ return f"[Error: Transcription failed: {e}]", audio_file_path
163
+
164
+
165
+ # ---- Gradio 应用流程 ----
166
+
167
+ def process_input(input_text, audio_file, style_choice, quality_choice, neg_prompt, guidance, steps):
168
+ """处理输入(文本或语音),生成提示词和图像"""
169
+ final_text_input = ""
170
+ transcription_source = "" # 用于标记来源
171
+
172
+ # 优先使用文本框输入
173
+ if input_text and input_text.strip():
174
+ final_text_input = input_text.strip()
175
+ transcription_source = " (from text input)"
176
+ # 如果文本框为空,且有音频文件,则使用语音输入
177
+ elif audio_file is not None:
178
+ transcribed_text, _ = transcribe_audio(audio_file)
179
+ if transcribed_text and "[Error:" not in transcribed_text:
180
+ final_text_input = transcribed_text
181
+ transcription_source = " (from audio input)"
182
+ elif "[Error:" in transcribed_text:
183
+ # 如果语音识别出错,直接返回错误信息
184
+ return transcribed_text, None # 返回错误提示,不生成图像
185
+ else:
186
+ # 音频为空或识别为空
187
+ return "[Error: Please provide input via text or voice]", None
188
+ else:
189
+ # 没有有效输入
190
+ return "[Error: Please provide input via text or voice]", None
191
+
192
+ print(f"Using input: '{final_text_input}'{transcription_source}")
193
+
194
+ # Step 1: Enhance prompt
195
+ enhanced_prompt = enhance_prompt(final_text_input, style_modifier=style_choice, quality_boost=quality_choice)
196
+ print(f"Enhanced prompt: {enhanced_prompt}")
197
+
198
+ # Step 2: Generate image (如果提示词增强成功)
199
+ generated_image = None
200
+ if "[Error:" not in enhanced_prompt:
201
+ try:
202
+ generated_image = generate_image(enhanced_prompt, neg_prompt, guidance, steps)
203
+ except gr.Error as e:
204
+ # 如果 generate_image 抛出 gr.Error,将其信息作为 enhanced_prompt 返回给UI
205
+ enhanced_prompt = f"{enhanced_prompt}\n\n[Image Generation Error: {e}]"
206
+ # 不再尝试显示图片
207
+ except Exception as e:
208
+ # 捕获其他意外错误
209
+ enhanced_prompt = f"{enhanced_prompt}\n\n[Unexpected Image Generation Error: {e}]"
210
+
211
+ # 返回结果给Gradio界面
212
+ return enhanced_prompt, generated_image
213
+
214
+
215
+ # ---- Gradio 界面构建 (Step 3: Controls & Step 4: Layout) ----
216
+
217
+ # 定义可选的风格和质量提升选项 (用于Dropdown/Radio)
218
+ style_options = ["cinematic", "photorealistic", "anime", "fantasy art", "cyberpunk", "steampunk", "watercolor"]
219
+ quality_options = ["highly detailed", "sharp focus", "intricate details", "4k", "masterpiece", "best quality"]
220
+
221
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
222
+ gr.Markdown("# AI Image Generator: From Idea to Image")
223
+ gr.Markdown("Enter a short description (or use voice input), and the app will enhance it into a detailed prompt and generate an image using Stable Diffusion.")
224
+
225
+ with gr.Row():
226
+ with gr.Column(scale=1):
227
+ # 输入区域
228
+ inp_text = gr.Textbox(label="Enter short description here", placeholder="e.g., A magical treehouse in the sky")
229
+
230
+ # 加分项:语音输入控件
231
+ inp_audio = gr.Audio(sources=["microphone"], type="filepath", label="Or record your idea (clears text box if used)", visible=asr_pipeline is not None) # 只有ASR加载成功才显示
232
+
233
+ # Step 3: 使用不同控件
234
+ # 控件1: Dropdown (下拉菜单)
235
+ inp_style = gr.Dropdown(label="Choose Base Style", choices=style_options, value="cinematic")
236
+
237
+ # 控件2: Radio (单选框) - 也可以用 CheckboxGroup 实现多选
238
+ inp_quality = gr.Radio(label="Quality Boost", choices=quality_options, value="highly detailed")
239
+
240
+ # 控件3: Textbox (用于Negative Prompt)
241
+ inp_neg_prompt = gr.Textbox(label="Negative Prompt (optional)", placeholder="e.g., blurry, low quality, text, watermark")
242
+
243
+ # 控件4: Slider (滑块)
244
+ inp_guidance = gr.Slider(minimum=1.0, maximum=20.0, step=0.5, value=7.5, label="Guidance Scale (CFG)")
245
+
246
+ # 控件5: Slider (滑块)
247
+ inp_steps = gr.Slider(minimum=10, maximum=100, step=1, value=30, label="Inference Steps")
248
+
249
+ # 提交按钮
250
+ btn_generate = gr.Button("Generate Image", variant="primary")
251
+
252
+ with gr.Column(scale=1):
253
+ # 输出区域
254
+ out_prompt = gr.Textbox(label="Generated Prompt", interactive=False) # 输出文本框不可编辑
255
+ out_image = gr.Image(label="Generated Image", type="pil") # 输出图像
256
+
257
+ # 设置按钮点击事件
258
+ btn_generate.click(
259
+ fn=process_input,
260
+ inputs=[inp_text, inp_audio, inp_style, inp_quality, inp_neg_prompt, inp_guidance, inp_steps],
261
+ outputs=[out_prompt, out_image]
262
+ )
263
+
264
+ # (可选) 当用户录音后,可以自动清空文本框,以明确优先使用语音
265
+ if asr_pipeline:
266
+ def clear_text_on_audio(audio_data):
267
+ if audio_data is not None:
268
+ return "" # 返回空字符串清空文本框
269
+ return gr.update() # 否则不改变文本框内容 (gr.update()是占位符)
270
+ inp_audio.change(fn=clear_text_on_audio, inputs=inp_audio, outputs=inp_text)
271
+
272
+
273
+ # ---- 启动应用 ----
274
+ if __name__ == "__main__":
275
+ # 设置Hugging Face Hub Token (如果需要从私有仓库加载模型)
276
+ # from huggingface_hub import login
277
+ # login("YOUR_HF_TOKEN") # 在本地运行时取消注释并替换
278
+
279
+ # 在Hugging Face Spaces上运行时,端口通常由平台管理
280
+ # share=True 会创建一个公共链接 (如果在本地运行需要)
281
+ demo.launch(share=False)