start3406 committed on
Commit a63d56e · verified · 1 Parent(s): ec797fe

Update app.py

Files changed (1):
  app.py +255 -183

app.py CHANGED
@@ -2,280 +2,352 @@ import gradio as gr
  import torch
  from transformers import pipeline, set_seed
  from diffusers import StableDiffusionPipeline
  import os
  import time

- # ---- Configuration and model loading (runs once at app startup) ----

- # Use the GPU if one is available, otherwise fall back to the CPU
- device = "cuda" if torch.cuda.is_available() else "cpu"
  print(f"Using device: {device}")

  # 1. Speech-to-text model (Whisper) - bonus feature
  asr_pipeline = None
  try:
-     print("Loading ASR pipeline (Whisper)...")
-     # Use a smaller model to save resources; swap in openai/whisper-medium or whisper-large if needed
-     # Parts of the app that do not need the GPU can be forced onto the CPU
-     asr_pipeline = pipeline("automatic-speech-recognition", model="openai/whisper-base", device=device if device == "cuda" else -1)  # Whisper also runs on CPU
-     print("ASR pipeline loaded.")
  except Exception as e:
      print(f"Could not load ASR pipeline: {e}. Voice input will be disabled.")

- # 2. Prompt enhancement model (LLM) - Step 1
- prompt_enhancer_pipeline = None
- try:
-     print("Loading Prompt Enhancer pipeline (GPT-2)...")
-     # GPT-2 is only an example; a stronger instruction-tuned model such as Mistral or Llama is recommended in practice
-     # Note: GPT-2 may not generate especially high-quality SD prompts; this only demonstrates the structure
-     # If resources allow, swap in 'mistralai/Mistral-7B-Instruct-v0.1' or similar, at the cost of more memory/GPU
-     prompt_enhancer_pipeline = pipeline("text-generation", model="gpt2", device=device if device == "cuda" else -1)  # text-generation also runs on CPU
-     print("Prompt Enhancer pipeline loaded.")
- except Exception as e:
-     print(f"Could not load Prompt Enhancer pipeline: {e}. Prompt enhancement might fail.")
-
- # 3. Text-to-image model (Stable Diffusion) - Step 2
  image_generator_pipe = None
  try:
-     print("Loading Stable Diffusion pipeline (v1.5)...")
      model_id = "runwayml/stable-diffusion-v1-5"
-     image_generator_pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16 if device == "cuda" else torch.float32)
      image_generator_pipe = image_generator_pipe.to(device)
-     # If memory is tight, CPU offloading can be enabled (requires the accelerate library)
-     # image_generator_pipe.enable_model_cpu_offload()
-     print("Stable Diffusion pipeline loaded.")
  except Exception as e:
-     print(f"Could not load Stable Diffusion pipeline: {e}. Image generation will fail.")
-     # If the model failed to load, create a dummy object so later code does not crash
      class DummyPipe:
          def __call__(self, *args, **kwargs):
-             # Return a placeholder error instead of an image
-             raise RuntimeError(f"Stable Diffusion model failed to load: {e}")
      image_generator_pipe = DummyPipe()


- # ---- Core function definitions ----

- # Step 1: Prompt-to-Prompt
- def enhance_prompt(short_prompt, style_modifier="cinematic", quality_boost="photorealistic, highly detailed"):
-     """Enhance a short description using the LLM."""
-     if not prompt_enhancer_pipeline:
-         return f"[Error: LLM not loaded] Original prompt: {short_prompt}"
      if not short_prompt:
-         return "[Error: Input description is empty]"
-
-     # Build the instruction for the LLM
-     # Note: this instruction may be too complex for GPT-2; it works better with models like Mistral
-     input_text = (
-         f"Generate a detailed and vivid prompt for an AI image generator based on the following description. "
-         f"Incorporate the style '{style_modifier}' and quality boost '{quality_boost}'. "
-         f"Focus on visual details, lighting, composition, and mood. "
-         f"Description: \"{short_prompt}\"\n\n"
-         f"Detailed Prompt:"
      )

      try:
-         # Seed for (somewhat) reproducible results
-         set_seed(int(time.time()))
-         # max_length caps the total generated length (including the input)
-         # num_return_sequences sets how many results to return
-         # temperature controls randomness; lower values are more conservative
-         # no_repeat_ngram_size avoids repeated phrases
-         outputs = prompt_enhancer_pipeline(
-             input_text,
-             max_length=150,  # Cap the output length so it does not run on
-             num_return_sequences=1,
-             temperature=0.7,
-             no_repeat_ngram_size=2,
-             pad_token_id=prompt_enhancer_pipeline.tokenizer.eos_token_id  # Avoid the padding warning
          )
-         generated_text = outputs[0]['generated_text']
-
-         # Extract the enhanced-prompt part from the LLM's full output
-         # Simple approach: take everything after "Detailed Prompt:"
-         enhanced = generated_text.split("Detailed Prompt:")[-1].strip()
-         # Clean up any remaining traces of the original input or the instruction
-         if short_prompt in enhanced[:len(short_prompt)+5]:  # The output starts with the original input
-             enhanced = enhanced.replace(short_prompt, "", 1).strip(' ,"')
-
-         # Append the base style and quality keywords if the LLM did not include them
-         if style_modifier not in enhanced:
-             enhanced += f", {style_modifier}"
-         if quality_boost not in enhanced:
-             enhanced += f", {quality_boost}"
-
-         return enhanced
-
      except Exception as e:
-         print(f"Error during prompt enhancement: {e}")
-         return f"[Error: Prompt enhancement failed] Original prompt: {short_prompt}"
 
 
- # Step 2: Prompt-to-Image
- def generate_image(prompt, negative_prompt, guidance_scale, num_inference_steps):
-     """Generate an image with Stable Diffusion."""
      if not isinstance(image_generator_pipe, StableDiffusionPipeline):
-         raise gr.Error(f"Stable Diffusion model is not available. Load error: {image_generator_pipe}")  # gr.Error surfaces the message in the UI
-     if not prompt or "[Error:" in prompt:
          raise gr.Error("Cannot generate image due to invalid or missing prompt.")

-     print(f"Generating image for prompt: {prompt}")
      print(f"Negative prompt: {negative_prompt}")
      print(f"Guidance scale: {guidance_scale}, Steps: {num_inference_steps}")

      try:
-         # Set the random seed
-         generator = torch.Generator(device=device).manual_seed(int(time.time()))
-
-         # Run inference
-         with torch.inference_mode():  # Saves memory
-             image = image_generator_pipe(
-                 prompt=prompt,
-                 negative_prompt=negative_prompt,
-                 guidance_scale=float(guidance_scale),
-                 num_inference_steps=int(num_inference_steps),
-                 generator=generator
-             ).images[0]
-         print("Image generated successfully.")
          return image
      except Exception as e:
-         print(f"Error during image generation: {e}")
-         # Pass the underlying error to Gradio so it can be shown in the UI
-         raise gr.Error(f"Image generation failed: {e}")
 

- # Bonus: Voice-to-Text
  def transcribe_audio(audio_file_path):
-     """Transcribe an audio file to text."""
      if not asr_pipeline:
-         return "[Error: ASR model not loaded]", ""  # Return the error message and an empty path
      if audio_file_path is None:
-         return "", ""  # No audio input

-     print(f"Transcribing audio file: {audio_file_path}")
      try:
-         # Transcribe the audio
          transcription = asr_pipeline(audio_file_path)["text"]
          print(f"Transcription result: {transcription}")
-         return transcription, audio_file_path  # Return the text and the path (possibly useful for display)
      except Exception as e:
-         print(f"Error during audio transcription: {e}")
          return f"[Error: Transcription failed: {e}]", audio_file_path


- # ---- Gradio application flow ----

  def process_input(input_text, audio_file, style_choice, quality_choice, neg_prompt, guidance, steps):
-     """Process the input (text or voice), then build the prompt and generate the image."""
      final_text_input = ""
-     transcription_source = ""  # Marks where the input came from

-     # Prefer the text-box input
      if input_text and input_text.strip():
          final_text_input = input_text.strip()
-         transcription_source = " (from text input)"
-     # If the text box is empty and an audio file exists, use the voice input
      elif audio_file is not None:
          transcribed_text, _ = transcribe_audio(audio_file)
-         if transcribed_text and "[Error:" not in transcribed_text:
              final_text_input = transcribed_text
-             transcription_source = " (from audio input)"
-         elif "[Error:" in transcribed_text:
-             # If speech recognition failed, return the error message directly
-             return transcribed_text, None  # Return the error and skip image generation
          else:
-             # The audio was empty or the transcription came back empty
-             return "[Error: Please provide input via text or voice]", None
      else:
-         # No valid input at all
-         return "[Error: Please provide input via text or voice]", None
-
-     print(f"Using input: '{final_text_input}'{transcription_source}")

-     # Step 1: Enhance prompt
-     enhanced_prompt = enhance_prompt(final_text_input, style_modifier=style_choice, quality_boost=quality_choice)
-     print(f"Enhanced prompt: {enhanced_prompt}")
-
-     # Step 2: Generate image (only if prompt enhancement succeeded)
-     generated_image = None
-     if "[Error:" not in enhanced_prompt:
          try:
-             generated_image = generate_image(enhanced_prompt, neg_prompt, guidance, steps)
          except gr.Error as e:
-             # If generate_image raised gr.Error, report its message through enhanced_prompt in the UI
-             enhanced_prompt = f"{enhanced_prompt}\n\n[Image Generation Error: {e}]"
-             # Do not try to display an image
          except Exception as e:
-             # Catch any other unexpected errors
-             enhanced_prompt = f"{enhanced_prompt}\n\n[Unexpected Image Generation Error: {e}]"

-     # Return the results to the Gradio interface
-     return enhanced_prompt, generated_image
 


- # ---- Gradio interface construction (Step 3: Controls & Step 4: Layout) ----

- # Selectable style and quality-boost options (for the Dropdown/Radio controls)
- style_options = ["cinematic", "photorealistic", "anime", "fantasy art", "cyberpunk", "steampunk", "watercolor"]
- quality_options = ["highly detailed", "sharp focus", "intricate details", "4k", "masterpiece", "best quality"]

  with gr.Blocks(theme=gr.themes.Soft()) as demo:
-     gr.Markdown("# AI Image Generator: From Idea to Image")
-     gr.Markdown("Enter a short description (or use voice input), and the app will enhance it into a detailed prompt and generate an image using Stable Diffusion.")

      with gr.Row():
          with gr.Column(scale=1):
-             # Input area
-             inp_text = gr.Textbox(label="Enter short description here", placeholder="e.g., A magical treehouse in the sky")
-
-             # Bonus feature: voice input control
-             inp_audio = gr.Audio(sources=["microphone"], type="filepath", label="Or record your idea (clears text box if used)", visible=asr_pipeline is not None)  # Only shown if the ASR model loaded successfully
-
-             # Step 3: Use a variety of controls
-             # Control 1: Dropdown
-             inp_style = gr.Dropdown(label="Choose Base Style", choices=style_options, value="cinematic")
-
-             # Control 2: Radio (single choice) - a CheckboxGroup could allow multi-select instead
              inp_quality = gr.Radio(label="Quality Boost", choices=quality_options, value="highly detailed")
-
-             # Control 3: Textbox (for the negative prompt)
-             inp_neg_prompt = gr.Textbox(label="Negative Prompt (optional)", placeholder="e.g., blurry, low quality, text, watermark")
-
-             # Control 4: Slider
-             inp_guidance = gr.Slider(minimum=1.0, maximum=20.0, step=0.5, value=7.5, label="Guidance Scale (CFG)")
-
-             # Control 5: Slider
-             inp_steps = gr.Slider(minimum=10, maximum=100, step=1, value=30, label="Inference Steps")
-
-             # Submit button
              btn_generate = gr.Button("Generate Image", variant="primary")

          with gr.Column(scale=1):
-             # Output area
-             out_prompt = gr.Textbox(label="Generated Prompt", interactive=False)  # The output text box is read-only
-             out_image = gr.Image(label="Generated Image", type="pil")  # Output image

-     # Wire up the button-click event
      btn_generate.click(
          fn=process_input,
-         inputs=[inp_text, inp_audio, inp_style, inp_quality, inp_neg_prompt, inp_guidance, inp_steps],
          outputs=[out_prompt, out_image]
      )

-     # (Optional) Automatically clear the text box after the user records audio, so voice input takes priority
      if asr_pipeline:
          def clear_text_on_audio(audio_data):
              if audio_data is not None:
-                 return ""  # Return an empty string to clear the text box
-             return gr.update()  # Otherwise leave the text box unchanged (gr.update() is a no-op placeholder)
          inp_audio.change(fn=clear_text_on_audio, inputs=inp_audio, outputs=inp_text)


- # ---- Application launch ----
  if __name__ == "__main__":
-     # Set the Hugging Face Hub token (needed only when loading models from a private repo)
-     # from huggingface_hub import login
-     # login("YOUR_HF_TOKEN")  # Uncomment and fill in when running locally

-     # On Hugging Face Spaces the port is usually managed by the platform
-     # share=True creates a public link (useful when running locally)
-     demo.launch(share=False)
 
  import torch
  from transformers import pipeline, set_seed
  from diffusers import StableDiffusionPipeline
+ import openai
  import os
  import time
+ import traceback  # For detailed error logging

+ # ---- Configuration & API Key ----
+ # Check for OpenAI API Key in Hugging Face Secrets
+ api_key = os.environ.get("OPENAI_API_KEY")
+ openai_client = None
+ openai_available = False

+ if api_key:
+     try:
+         openai.api_key = api_key
+         # Starting with openai v1, client instantiation is preferred
+         openai_client = openai.OpenAI(api_key=api_key)
+         # Simple test to check if the key is valid (optional, but good)
+         # openai_client.models.list()  # This call might incur small cost/quota usage
+         openai_available = True
+         print("OpenAI API key found and client initialized.")
+     except Exception as e:
+         print(f"Error initializing OpenAI client: {e}")
+         print("Proceeding without OpenAI features.")
+ else:
+     print("WARNING: OPENAI_API_KEY secret not found. Prompt enhancement via OpenAI is disabled.")
+
+ # Force CPU usage
+ device = "cpu"
  print(f"Using device: {device}")

+ # ---- Model Loading (CPU Focused) ----
+
  # 1. Speech-to-text model (Whisper) - bonus feature
  asr_pipeline = None
  try:
+     print("Loading ASR pipeline (Whisper) on CPU...")
+     # Force CPU usage with device=-1 or device="cpu"
+     asr_pipeline = pipeline("automatic-speech-recognition", model="openai/whisper-base", device=device)
+     print("ASR pipeline loaded successfully on CPU.")
  except Exception as e:
      print(f"Could not load ASR pipeline: {e}. Voice input will be disabled.")
+     traceback.print_exc()  # Print full traceback for debugging

+ # 2. Text-to-image model (Stable Diffusion) - Step 2 (CPU)
  image_generator_pipe = None
  try:
+     print("Loading Stable Diffusion pipeline (v1.5) on CPU...")
+     print("WARNING: Stable Diffusion on CPU is VERY SLOW (expect minutes per image).")
      model_id = "runwayml/stable-diffusion-v1-5"
+     # Use float32 for CPU
+     image_generator_pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32)
      image_generator_pipe = image_generator_pipe.to(device)
+     print("Stable Diffusion pipeline loaded successfully on CPU.")
  except Exception as e:
+     print(f"CRITICAL: Could not load Stable Diffusion pipeline: {e}. Image generation will fail.")
+     traceback.print_exc()  # Print full traceback for debugging
+     # Define a dummy object to prevent crashes later if loading failed
+     load_error = e  # capture the message: `e` is unbound once the except block exits
      class DummyPipe:
          def __call__(self, *args, **kwargs):
+             raise RuntimeError(f"Stable Diffusion model failed to load: {load_error}")
      image_generator_pipe = DummyPipe()
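
With everything pinned to the CPU, peak RAM replaces VRAM as the main constraint. Attention slicing is one optional diffusers knob that trades a little speed for a smaller memory footprint; a sketch of enabling it after loading (an assumption, not part of this commit):

    if isinstance(image_generator_pipe, StableDiffusionPipeline):
        image_generator_pipe.enable_attention_slicing()  # lower peak memory, slightly slower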


+ # ---- Core Function Definitions ----

+ # Step 1: Prompt-to-Prompt (using OpenAI API)
+ def enhance_prompt_openai(short_prompt, style_modifier="cinematic", quality_boost="photorealistic, highly detailed"):
+     """Uses the OpenAI API to enhance the short description."""
+     if not openai_available or not openai_client:
+         # Fallback if the OpenAI key is missing/invalid
+         print("OpenAI not available. Returning original prompt with modifiers.")
+         return f"{short_prompt}, {style_modifier}, {quality_boost}"
      if not short_prompt:
+         # Surface the problem directly in the Gradio UI
+         raise gr.Error("Input description cannot be empty.")
+
+     # Construct the prompt for the OpenAI model
+     system_message = (
+         "You are an expert prompt engineer for AI image generation models like Stable Diffusion. "
+         "Expand the user's short description into a detailed, vivid, and coherent prompt. "
+         "Focus on visual details: subjects, objects, environment, lighting, atmosphere, composition. "
+         "Incorporate the requested style and quality keywords naturally. Avoid conversational text."
      )
+     user_message = (
+         f"Enhance this description: \"{short_prompt}\". "
+         f"Style: '{style_modifier}'. Quality: '{quality_boost}'."
+     )
+
+     print(f"Sending request to OpenAI for prompt enhancement: {short_prompt}")

      try:
+         response = openai_client.chat.completions.create(
+             model="gpt-3.5-turbo",  # Cost-effective choice; can use gpt-4 if needed/key allows
+             messages=[
+                 {"role": "system", "content": system_message},
+                 {"role": "user", "content": user_message},
+             ],
+             temperature=0.7,  # Controls creativity vs. predictability
+             max_tokens=150,  # Limit output length
+             n=1,  # Generate one response
+             stop=None  # Let the model decide when to stop
          )
+         enhanced_prompt = response.choices[0].message.content.strip()
+         print("OpenAI enhancement successful.")
+         # Basic cleanup: remove potential quotes around the whole response
+         if enhanced_prompt.startswith('"') and enhanced_prompt.endswith('"'):
+             enhanced_prompt = enhanced_prompt[1:-1]
+         return enhanced_prompt
+     except openai.AuthenticationError:
+         print("OpenAI Authentication Error: Invalid API key?")
+         raise gr.Error("OpenAI Authentication Error: Check your API key.")
+     except openai.RateLimitError:
+         print("OpenAI Rate Limit Error: You've exceeded your quota or rate limit.")
+         raise gr.Error("OpenAI Error: Rate limit exceeded.")
+     except openai.APIError as e:
+         print(f"OpenAI API Error: {e}")
+         raise gr.Error(f"OpenAI API Error: {e}")
      except Exception as e:
+         print(f"An unexpected error occurred during the OpenAI call: {e}")
+         traceback.print_exc()
+         raise gr.Error(f"Prompt enhancement failed: {e}")
+
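
Note that when no API key is configured, enhance_prompt_openai degrades to plain string concatenation instead of failing, so the rest of the pipeline keeps working. For example, with openai_available left False:

    enhance_prompt_openai("a cute robot drinking coffee", "anime", "4k")
    # -> "a cute robot drinking coffee, anime, 4k"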
 
+ # Step 2: Prompt-to-Image (CPU)
+ def generate_image_cpu(prompt, negative_prompt, guidance_scale, num_inference_steps):
+     """Generates image using Stable Diffusion on CPU."""
      if not isinstance(image_generator_pipe, StableDiffusionPipeline):
+         raise gr.Error("Stable Diffusion model is not available (failed to load).")
+     if not prompt or "[Error:" in prompt or "Error:" in prompt:
+         # Check if the prompt itself is an error message from the previous step
          raise gr.Error("Cannot generate image due to invalid or missing prompt.")

+     print(f"Generating image on CPU for prompt: {prompt[:100]}...")  # Log truncated prompt
      print(f"Negative prompt: {negative_prompt}")
      print(f"Guidance scale: {guidance_scale}, Steps: {num_inference_steps}")
+     start_time = time.time()

      try:
+         # Use torch.inference_mode() or torch.no_grad() for efficiency
+         with torch.no_grad():
+             # Seed for reproducibility (optional, but good practice)
+             generator = torch.Generator(device=device).manual_seed(int(time.time()))
+             image = image_generator_pipe(
+                 prompt=prompt,
+                 negative_prompt=negative_prompt,
+                 guidance_scale=float(guidance_scale),
+                 num_inference_steps=int(num_inference_steps),
+                 generator=generator,
+             ).images[0]
+         end_time = time.time()
+         print(f"Image generated successfully on CPU in {end_time - start_time:.2f} seconds.")
          return image
      except Exception as e:
+         print(f"Error during image generation on CPU: {e}")
+         traceback.print_exc()
+         # Propagate error to Gradio UI
+         raise gr.Error(f"Image generation failed on CPU: {e}")
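
Because the generator is seeded with int(time.time()), every click yields a different image. When comparing guidance or step settings on a slow CPU, a fixed seed is a useful variant (a sketch, not part of this commit):

    generator = torch.Generator(device="cpu").manual_seed(42)  # reproducible across runs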
 

+ # Bonus: Voice-to-Text (CPU)
  def transcribe_audio(audio_file_path):
+     """Transcribes audio to text using Whisper on CPU."""
      if not asr_pipeline:
+         # This case should ideally be handled by hiding the control, but double-check
+         return "[Error: ASR model not loaded]", audio_file_path
      if audio_file_path is None:
+         return "", audio_file_path  # No audio input

+     print(f"Transcribing audio file: {audio_file_path} on CPU...")
+     start_time = time.time()
      try:
+         # Ensure the pipeline uses the correct device (should be CPU based on loading)
          transcription = asr_pipeline(audio_file_path)["text"]
+         end_time = time.time()
+         print(f"Transcription successful in {end_time - start_time:.2f} seconds.")
          print(f"Transcription result: {transcription}")
+         return transcription, audio_file_path
      except Exception as e:
+         print(f"Error during audio transcription on CPU: {e}")
+         traceback.print_exc()
+         # Return error message in the expected tuple format
          return f"[Error: Transcription failed: {e}]", audio_file_path
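
transcribe_audio always returns a (text, path) tuple and encodes failures in the text rather than raising, which is why callers scan for the "[Error:" marker. Hypothetical usage ("recording.wav" is a placeholder path):

    text, _ = transcribe_audio("recording.wav")
    if "[Error:" in text:
        ...  # surface the message instead of generating an image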
 

+ # ---- Gradio Application Flow ----

  def process_input(input_text, audio_file, style_choice, quality_choice, neg_prompt, guidance, steps):
+     """Main function triggered by the Gradio button."""
      final_text_input = ""
+     enhanced_prompt = ""
+     generated_image = None
+     status_message = ""  # To gather status/errors for the prompt box

+     # 1. Determine Input (Text or Audio)
      if input_text and input_text.strip():
          final_text_input = input_text.strip()
+         print(f"Using text input: '{final_text_input}'")
      elif audio_file is not None:
+         print("Processing audio input...")
          transcribed_text, _ = transcribe_audio(audio_file)
+         if "[Error:" in transcribed_text:
+             # Display transcription error clearly
+             status_message = transcribed_text
+             print(status_message)
+             # Return error in prompt field, no image
+             return status_message, None
+         elif transcribed_text:
              final_text_input = transcribed_text
+             print(f"Using transcribed audio input: '{final_text_input}'")
          else:
+             status_message = "[Error: Audio input received but transcription was empty.]"
+             print(status_message)
+             return status_message, None  # Return error
      else:
+         status_message = "[Error: No input provided. Please enter text or record audio.]"
+         print(status_message)
+         return status_message, None  # Return error

+     # 2. Enhance Prompt (using OpenAI if available)
+     if final_text_input:
          try:
+             enhanced_prompt = enhance_prompt_openai(final_text_input, style_choice, quality_choice)
+             status_message = enhanced_prompt  # Display the prompt
+             print(f"Enhanced prompt: {enhanced_prompt}")
          except gr.Error as e:
+             # Catch Gradio-specific errors from the enhancement function
+             status_message = f"[Prompt Enhancement Error: {e}]"
+             print(status_message)
+             # Return the error, no image generation attempt
+             return status_message, None
          except Exception as e:
+             # Catch any other unexpected errors
+             status_message = f"[Unexpected Prompt Enhancement Error: {e}]"
+             print(status_message)
+             traceback.print_exc()
+             return status_message, None
+
+     # 3. Generate Image (if prompt is valid)
+     if enhanced_prompt and not status_message.startswith("[Error:") and not status_message.startswith("[Prompt Enhancement Error:"):
+         try:
+             # Show "Generating..." message while waiting
+             gr.Info("Starting image generation on CPU... This will take a while (possibly several minutes).")
+             generated_image = generate_image_cpu(enhanced_prompt, neg_prompt, guidance, steps)
+             gr.Info("Image generation complete!")
+         except gr.Error as e:
+             # Catch Gradio errors from the generation function
+             status_message = f"{enhanced_prompt}\n\n[Image Generation Error: {e}]"  # Append error to prompt
+             print(f"Image Generation Error: {e}")
+         except Exception as e:
+             status_message = f"{enhanced_prompt}\n\n[Unexpected Image Generation Error: {e}]"
+             print(f"Unexpected Image Generation Error: {e}")
+             traceback.print_exc()
+             # Set image to None explicitly on error
+             generated_image = None

+     # 4. Return results to Gradio UI
+     # Return the status message (enhanced prompt or error) and the image (or None if error)
+     return status_message, generated_image
 

+ # ---- Gradio Interface Construction ----

+ style_options = ["cinematic", "photorealistic", "anime", "fantasy art", "cyberpunk", "steampunk", "watercolor", "illustration", "low poly"]
+ quality_options = ["highly detailed", "sharp focus", "intricate details", "4k", "masterpiece", "best quality", "professional lighting"]
+
+ # Reduced steps for faster CPU generation attempt
+ default_steps = 20
+ max_steps = 50  # Limit max steps on CPU

  with gr.Blocks(theme=gr.themes.Soft()) as demo:
+     gr.Markdown("# AI Image Generator (CPU Version)")
+     gr.Markdown(
+         "**Enter a short description or use voice input.** The app uses OpenAI (if an API key is provided) "
+         "to create a detailed prompt, then generates an image using Stable Diffusion v1.5 **on the CPU**."
+     )
+     # Add a specific warning about CPU speed
+     gr.HTML("<p style='color:orange;font-weight:bold;'>⚠️ Warning: Image generation on CPU is very slow! Expect several minutes per image.</p>")
+
+     # Display OpenAI availability status
+     if not openai_available:
+         gr.Markdown("**Note:** OpenAI API key not found or invalid. Prompt enhancement will use a basic fallback.")

      with gr.Row():
          with gr.Column(scale=1):
+             # --- Inputs ---
+             inp_text = gr.Textbox(label="Enter short description", placeholder="e.g., A cute robot drinking coffee on Mars")
+
+             # Only show the Audio input if the ASR model loaded successfully
+             if asr_pipeline:
+                 inp_audio = gr.Audio(sources=["microphone"], type="filepath", label="Or record your idea (clears text box if used)")
+             else:
+                 gr.Markdown("**Voice input disabled:** Whisper model failed to load.")
+                 inp_audio = gr.Textbox(visible=False)  # Hidden placeholder
+
+             # --- Controls (Step 3 requirements met) ---
+             # Control 1: Dropdown
+             inp_style = gr.Dropdown(label="Base Style", choices=style_options, value="cinematic")
+             # Control 2: Radio
              inp_quality = gr.Radio(label="Quality Boost", choices=quality_options, value="highly detailed")
+             # Control 3: Textbox (Negative Prompt)
+             inp_neg_prompt = gr.Textbox(label="Negative Prompt (optional)", placeholder="e.g., blurry, low quality, text, watermark, signature, deformed")
+             # Control 4: Slider (Guidance Scale)
+             inp_guidance = gr.Slider(minimum=1.0, maximum=15.0, step=0.5, value=7.0, label="Guidance Scale (CFG)")  # Slightly lower max maybe better for CPU
+             # Control 5: Slider (Inference Steps) - Reduced max/default
+             inp_steps = gr.Slider(minimum=10, maximum=max_steps, step=1, value=default_steps, label=f"Inference Steps (lower = faster but less detail, max {max_steps})")
+
+             # --- Action Button ---
              btn_generate = gr.Button("Generate Image", variant="primary")

          with gr.Column(scale=1):
+             # --- Outputs ---
+             out_prompt = gr.Textbox(label="Generated Prompt / Status", interactive=False, lines=5)  # Show prompt or error status here
+             out_image = gr.Image(label="Generated Image", type="pil")
+
+     # --- Event Handling ---
+     # Define the inputs list carefully, handling the potentially invisible audio input
+     inputs_list = [inp_text]
+     if asr_pipeline:
+         inputs_list.append(inp_audio)
+     else:
+         inputs_list.append(gr.State(None))  # Pass None if the audio control doesn't exist
+
+     inputs_list.extend([inp_style, inp_quality, inp_neg_prompt, inp_guidance, inp_steps])

      btn_generate.click(
          fn=process_input,
+         inputs=inputs_list,
          outputs=[out_prompt, out_image]
      )
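
The gr.State(None) placeholder above keeps process_input's positional signature stable: the callback always receives seven values in the same order whether or not the microphone control exists. When ASR is unavailable, the audio slot is simply None; roughly equivalent to:

    process_input("a castle at dusk", None, "cinematic", "highly detailed", "", 7.0, 20)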

+     # Clear text input if audio is used
      if asr_pipeline:
          def clear_text_on_audio(audio_data):
              if audio_data is not None:
+                 return ""  # Clear text box
+             return gr.update()  # No change if no audio data
          inp_audio.change(fn=clear_text_on_audio, inputs=inp_audio, outputs=inp_text)


+ # ---- Application Launch ----
  if __name__ == "__main__":
+     # Check again if SD loaded; maybe prevent launch? Or let it run and fail gracefully in the UI.
+     if not isinstance(image_generator_pipe, StableDiffusionPipeline):
+         print("CRITICAL FAILURE: Stable Diffusion pipeline did not load. The application UI will load, but image generation WILL NOT WORK.")
+         # Optionally, you could raise an error here to stop the script if SD is essential
+         # raise RuntimeError("Failed to load Stable Diffusion pipeline, cannot start application.")

+     # Launch the Gradio app
+     demo.launch(share=False)  # share=True generates a public link if run locally