start3406 committed commit b314d94 (verified · parent: 0ad3783)

Update app.py

Files changed (1): app.py (+166, -81)
app.py CHANGED
@@ -1,7 +1,8 @@
 import gradio as gr
 import torch
 from transformers import pipeline, set_seed
-from diffusers import StableDiffusionPipeline
+# Import AutoPipelineForText2Image for compatibility with different models
+from diffusers import AutoPipelineForText2Image
 import openai
 import os
 import time
@@ -15,7 +16,6 @@ openai_available = False
 
 if api_key:
     try:
-        openai.api_key = api_key
         # Starting with openai v1, client instantiation is preferred
         openai_client = openai.OpenAI(api_key=api_key)
         # Simple test to check if the key is valid (optional, but good)
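This hunk ends just before the validation call it announces. With the v1 client, any cheap authenticated request works as a key check; a minimal sketch (not part of the commit; `models.list()` is one such call):

    try:
        openai_client.models.list()  # raises openai.AuthenticationError on a bad key
        openai_available = True
    except openai.AuthenticationError:
        print("OpenAI API key rejected; prompt enhancement will use the basic fallback.")
        openai_available = False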
@@ -39,29 +39,31 @@ asr_pipeline = None
 try:
     print("Loading ASR pipeline (Whisper) on CPU...")
     # Force CPU usage with device=-1 or device="cpu"
-    asr_pipeline = pipeline("automatic-speech-recognition", model="openai/whisper-base", device=device)
+    # fp16 would be faster but requires a GPU; use float32 on CPU
+    asr_pipeline = pipeline("automatic-speech-recognition", model="openai/whisper-base", device=device, torch_dtype=torch.float32)
     print("ASR pipeline loaded successfully on CPU.")
 except Exception as e:
     print(f"Could not load ASR pipeline: {e}. Voice input will be disabled.")
     traceback.print_exc() # Print full traceback for debugging
 
-# 2. Text-to-image model (Stable Diffusion) - Step 2 (CPU)
+# 2. Text-to-image model (Tiny Text-to-Image) - resource-friendly model
 image_generator_pipe = None
+# Use the Tiny Text-to-Image model, which has very low resource requirements
+model_id = "hf-internal-testing/tiny-text-to-image"
 try:
-    print("Loading Stable Diffusion pipeline (v1.5) on CPU...")
-    print("WARNING: Stable Diffusion on CPU is VERY SLOW (expect minutes per image).")
-    model_id = "runwayml/stable-diffusion-v1-5"
-    # Use float32 for CPU
-    image_generator_pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32)
+    print(f"Loading Text-to-Image pipeline ({model_id}) on CPU...")
+    print("NOTE: Using a very small model for resource efficiency. Image quality will be lower than Stable Diffusion.")
+    # Use AutoPipelineForText2Image to detect the model type automatically
+    image_generator_pipe = AutoPipelineForText2Image.from_pretrained(model_id, torch_dtype=torch.float32)
     image_generator_pipe = image_generator_pipe.to(device)
-    print("Stable Diffusion pipeline loaded successfully on CPU.")
+    print(f"Text-to-Image pipeline ({model_id}) loaded successfully on CPU.")
 except Exception as e:
-    print(f"CRITICAL: Could not load Stable Diffusion pipeline: {e}. Image generation will fail.")
+    print(f"CRITICAL: Could not load Text-to-Image pipeline ({model_id}): {e}. Image generation will fail.")
     traceback.print_exc() # Print full traceback for debugging
     # Define a dummy object to prevent crashes later if loading failed
     class DummyPipe:
         def __call__(self, *args, **kwargs):
-            raise RuntimeError(f"Stable Diffusion model failed to load: {e}")
+            raise RuntimeError(f"Text-to-Image model failed to load: {e}")
     image_generator_pipe = DummyPipe()


@@ -73,17 +75,25 @@ def enhance_prompt_openai(short_prompt, style_modifier="cinematic", quality_boos
     if not openai_available or not openai_client:
         # Fallback or error if OpenAI key is missing/invalid
         print("OpenAI not available. Returning original prompt with modifiers.")
-        return f"{short_prompt}, {style_modifier}, {quality_boost}"
+        # Basic fallback prompt enhancement
+        if short_prompt:
+            return f"{short_prompt}, {style_modifier}, {quality_boost}"
+        else:
+            # If the short prompt is empty, the fallback should also indicate an error
+            raise gr.Error("Input description cannot be empty.")
+
+
     if not short_prompt:
         # Return an error message formatted for Gradio output
         raise gr.Error("Input description cannot be empty.")
 
     # Construct the prompt for the OpenAI model
     system_message = (
-        "You are an expert prompt engineer for AI image generation models like Stable Diffusion. "
-        "Expand the user's short description into a detailed, vivid, and coherent prompt. "
-        "Focus on visual details: subjects, objects, environment, lighting, atmosphere, composition. "
-        "Incorporate the requested style and quality keywords naturally. Avoid conversational text."
+        "You are an expert prompt engineer for AI image generation models. "
+        "Expand the user's short description into a detailed, vivid, and coherent prompt, suitable for smaller, faster text-to-image models. "
+        "Focus on clear subjects, objects, and main scene elements. "
+        "Incorporate the requested style and quality keywords naturally, but keep the overall prompt concise enough for smaller models. Avoid conversational text."
+        # Adjusting guidance for smaller models
     )
     user_message = (
         f"Enhance this description: \"{short_prompt}\". "
@@ -94,13 +104,13 @@ def enhance_prompt_openai(short_prompt, style_modifier="cinematic", quality_boos
 
     try:
         response = openai_client.chat.completions.create(
-            model="gpt-3.5-turbo", # Cost-effective choice, can use gpt-4 if needed/key allows
+            model="gpt-3.5-turbo", # Cost-effective choice
             messages=[
                 {"role": "system", "content": system_message},
                 {"role": "user", "content": user_message},
             ],
             temperature=0.7, # Controls creativity vs predictability
-            max_tokens=150, # Limit output length
+            max_tokens=100, # Limit output length - reduced for potentially shorter prompts for smaller models
             n=1, # Generate one response
             stop=None # Let the model decide when to stop
         )
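The unchanged line between this hunk and the next, which pulls the generated text out of the response, is elided by the diff view. With the v1 client it would read the first choice's message content, along the lines of:

    enhanced_prompt = response.choices[0].message.content.strip()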
@@ -108,7 +118,7 @@ def enhance_prompt_openai(short_prompt, style_modifier="cinematic", quality_boos
         print("OpenAI enhancement successful.")
         # Basic cleanup: remove potential quotes around the whole response
         if enhanced_prompt.startswith('"') and enhanced_prompt.endswith('"'):
-            enhanced_prompt = enhanced_prompt[1:-1]
+            enhanced_prompt = enhanced_prompt[1:-1]
         return enhanced_prompt
     except openai.AuthenticationError:
         print("OpenAI Authentication Error: Invalid API key?")
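One caveat about the `isinstance(image_generator_pipe, AutoPipelineForText2Image)` checks introduced in the next hunk (and reused for the button state and the launch warning further down): `AutoPipelineForText2Image.from_pretrained` returns an instance of whatever concrete pipeline class it resolves, not of `AutoPipelineForText2Image` itself, so the check stays false even after a successful load. And on a successful load `DummyPipe` is never defined, so the inner `isinstance(image_generator_pipe, DummyPipe)` check would raise `NameError`. A flag recorded at load time sidesteps both issues; a sketch, not part of the commit:

    image_generator_ok = False  # set alongside the loading code above
    try:
        image_generator_pipe = AutoPipelineForText2Image.from_pretrained(
            model_id, torch_dtype=torch.float32
        ).to(device)
        image_generator_ok = True  # then test this flag instead of isinstance(...)
    except Exception:
        traceback.print_exc()

The later call sites would become `if not image_generator_ok: raise gr.Error(...)` and `gr.Button(..., interactive=image_generator_ok)`.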
@@ -127,38 +137,61 @@ def enhance_prompt_openai(short_prompt, style_modifier="cinematic", quality_boos
 
 # Step 2: Prompt-to-Image (CPU)
 def generate_image_cpu(prompt, negative_prompt, guidance_scale, num_inference_steps):
-    """Generates image using Stable Diffusion on CPU."""
-    if not isinstance(image_generator_pipe, StableDiffusionPipeline):
-        raise gr.Error("Stable Diffusion model is not available (failed to load).")
+    """Generates image using the loaded model on CPU."""
+    # Check whether the loaded model is the expected pipeline type or the DummyPipe fallback
+    if not isinstance(image_generator_pipe, AutoPipelineForText2Image):
+        # If it's a DummyPipe or None for some reason
+        if isinstance(image_generator_pipe, DummyPipe):
+            # DummyPipe will raise its own error when called, so just let it
+            pass # The call below will raise the intended error
+        else:
+            # Handle unexpected case where pipe is not loaded correctly
+            raise gr.Error("Image generation pipeline is not available (failed to load model).")
+
+
     if not prompt or "[Error:" in prompt or "Error:" in prompt:
         # Check if the prompt itself is an error message from the previous step
         raise gr.Error("Cannot generate image due to invalid or missing prompt.")
 
     print(f"Generating image on CPU for prompt: {prompt[:100]}...") # Log truncated prompt
-    print(f"Negative prompt: {negative_prompt}")
-    print(f"Guidance scale: {guidance_scale}, Steps: {num_inference_steps}")
+    # Note: Negative prompt and guidance scale might have less impact or behave differently
+    # on very small models like tiny-text-to-image.
+    print(f"Negative prompt: {negative_prompt}") # Will likely be ignored by tiny model
+    print(f"Guidance scale: {guidance_scale}, Steps: {num_inference_steps}") # Steps might be fixed internally by tiny model
+
     start_time = time.time()
 
     try:
         # Use torch.inference_mode() or torch.no_grad() for efficiency
         with torch.no_grad():
             # Seed for reproducibility (optional, but good practice)
-            generator = torch.Generator(device=device).manual_seed(int(time.time()))
-            image = image_generator_pipe(
-                prompt=prompt,
-                negative_prompt=negative_prompt,
-                guidance_scale=float(guidance_scale),
-                num_inference_steps=int(num_inference_steps),
-                generator=generator,
-            ).images[0]
+            # generator = torch.Generator(device=device).manual_seed(int(time.time())) # Tiny model might not use generator param
+            # Tiny Text-to-Image pipeline call structure might be simpler
+            # Check model-specific documentation if parameters like guidance_scale, num_inference_steps, negative_prompt
+            # are actually supported. They might be ignored.
+            # Using a simple call that is generally compatible
+            output = image_generator_pipe(prompt=prompt) # Tiny model might only take prompt
+
+            # The output structure varies between pipelines, assuming it has .images
+            # if hasattr(output, 'images') and isinstance(output.images, list) and len(output.images) > 0:
+            #     image = output.images[0] # Access the first image
+            # else:
+            #     # Handle cases where output format is different
+            #     print("Warning: Pipeline output format unexpected. Assuming the output itself is the image.")
+            #     image = output # Assume output is the image if no .images
+
+            # Based on tiny-text-to-image, the output is likely a tuple where the first element is a list of images
+            image = output[0][0] # Access the first image in the first list of the tuple output structure
+
+
         end_time = time.time()
-        print(f"Image generated successfully on CPU in {end_time - start_time:.2f} seconds.")
+        print(f"Image generated successfully on CPU in {end_time - start_time:.2f} seconds (using {model_id}).")
         return image
     except Exception as e:
-        print(f"Error during image generation on CPU: {e}")
+        print(f"Error during image generation on CPU ({model_id}): {e}")
         traceback.print_exc()
         # Propagate error to Gradio UI
-        raise gr.Error(f"Image generation failed on CPU: {e}")
+        raise gr.Error(f"Image generation failed on CPU ({model_id}): {e}")
 
 
 # Bonus: Voice-to-Text (CPU)
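The `image = output[0][0]` indexing above is, as its own comments say, a guess about the return shape. Diffusers pipelines normally return an output object with an `.images` list, or a plain tuple when called with `return_dict=False`, so a helper that tries both shapes is safer; a sketch with a hypothetical `first_image` helper, not part of the commit:

    def first_image(output):
        """Extract the first image from a pipeline result, whatever its shape."""
        images = getattr(output, "images", None)
        if images:  # standard pipeline output object with an .images list
            return images[0]
        if isinstance(output, tuple) and output:  # return_dict=False style tuple
            first = output[0]
            return first[0] if isinstance(first, (list, tuple)) else first
        return output  # fall back to treating the output itself as the image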
@@ -174,6 +207,7 @@ def transcribe_audio(audio_file_path):
     start_time = time.time()
     try:
         # Ensure the pipeline uses the correct device (should be CPU based on loading)
+        # Ensure input is in expected format for Whisper pipeline (filepath or audio array)
         transcription = asr_pipeline(audio_file_path)["text"]
         end_time = time.time()
         print(f"Transcription successful in {end_time - start_time:.2f} seconds.")
@@ -201,20 +235,37 @@ def process_input(input_text, audio_file, style_choice, quality_choice, neg_prom
         print(f"Using text input: '{final_text_input}'")
     elif audio_file is not None:
         print("Processing audio input...")
-        transcribed_text, _ = transcribe_audio(audio_file)
-        if "[Error:" in transcribed_text:
-            # Display transcription error clearly
-            status_message = transcribed_text
-            print(status_message)
-            # Return error in prompt field, no image
-            return status_message, None
-        elif transcribed_text:
-            final_text_input = transcribed_text
-            print(f"Using transcribed audio input: '{final_text_input}'")
-        else:
-            status_message = "[Error: Audio input received but transcription was empty.]"
+        try:
+            # Gradio might pass a tuple (samplerate, audio_data) or a filepath depending on type="filepath" vs "numpy"
+            # transcribe_audio expects a filepath based on the Gradio component config
+            if isinstance(audio_file, tuple):
+                # If Gradio gives a tuple for some reason, try to save to a temp file or adjust transcribe_audio
+                # Assuming type="filepath" works as expected and passes a filepath
+                audio_filepath_to_transcribe = audio_file[0] # This might be incorrect depending on Gradio version/config
+                print(f"Warning: Gradio audio input was tuple, attempting to use first element as path: {audio_filepath_to_transcribe}")
+            else:
+                audio_filepath_to_transcribe = audio_file # This is expected for type="filepath"
+
+            transcribed_text, _ = transcribe_audio(audio_filepath_to_transcribe)
+
+            if "[Error:" in transcribed_text:
+                # Display transcription error clearly
+                status_message = transcribed_text
+                print(status_message)
+                return status_message, None # Return error in prompt field, no image
+            elif transcribed_text:
+                final_text_input = transcribed_text
+                print(f"Using transcribed audio input: '{final_text_input}'")
+            else:
+                status_message = "[Error: Audio input received but transcription was empty.]"
+                print(status_message)
+                return status_message, None # Return error
+        except Exception as e:
+            status_message = f"[Unexpected Audio Transcription Error: {e}]"
             print(status_message)
+            traceback.print_exc()
             return status_message, None # Return error
+
     else:
         status_message = "[Error: No input provided. Please enter text or record audio.]"
         print(status_message)
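A caution on the tuple branch above: Gradio's `type="numpy"` convention is `(sample_rate, data)`, so `audio_file[0]` would be the sample rate rather than a path, which is what the warning in the hunk hints at. Writing the array to a temporary WAV file is the safer normalization; a sketch assuming the `soundfile` package is available (it is not imported in app.py):

    import tempfile

    import soundfile as sf  # assumption: extra dependency, not in the commit

    def to_filepath(audio_value):
        """Normalize a Gradio audio value to a file path (hypothetical helper)."""
        if isinstance(audio_value, tuple):  # type="numpy": (sample_rate, np.ndarray)
            sample_rate, data = audio_value
            tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
            sf.write(tmp.name, data, sample_rate)
            return tmp.name
        return audio_value  # type="filepath": already a path string

With this helper, `transcribe_audio(to_filepath(audio_file))` handles both component configurations.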
@@ -224,7 +275,7 @@ def process_input(input_text, audio_file, style_choice, quality_choice, neg_prom
     if final_text_input:
         try:
             enhanced_prompt = enhance_prompt_openai(final_text_input, style_choice, quality_choice)
-            status_message = enhanced_prompt # Display the prompt
+            status_message = enhanced_prompt # Display the prompt initially
             print(f"Enhanced prompt: {enhanced_prompt}")
         except gr.Error as e:
             # Catch Gradio-specific errors from enhancement function
@@ -240,22 +291,31 @@ def process_input(input_text, audio_file, style_choice, quality_choice, neg_prom
         return status_message, None
 
     # 3. Generate Image (if prompt is valid)
+    # Check if the enhanced prompt step resulted in an error message
     if enhanced_prompt and not status_message.startswith("[Error:") and not status_message.startswith("[Prompt Enhancement Error:"):
         try:
             # Show "Generating..." message while waiting
-            gr.Info("Starting image generation on CPU... This will take a while (possibly several minutes).")
+            gr.Info(f"Starting image generation on CPU using {model_id}. This should be fast but quality is low.")
             generated_image = generate_image_cpu(enhanced_prompt, neg_prompt, guidance, steps)
             gr.Info("Image generation complete!")
         except gr.Error as e:
             # Catch Gradio errors from generation function
-            status_message = f"{enhanced_prompt}\n\n[Image Generation Error: {e}]" # Append error to prompt
+            # Prepend original enhanced prompt to the error message for context
+            status_message = f"{enhanced_prompt}\n\n[Image Generation Error: {e}]"
             print(f"Image Generation Error: {e}")
+            generated_image = None # Ensure image is None on error
         except Exception as e:
+            # Catch any other unexpected errors
             status_message = f"{enhanced_prompt}\n\n[Unexpected Image Generation Error: {e}]"
             print(f"Unexpected Image Generation Error: {e}")
             traceback.print_exc()
-            # Set image to None explicitly on error
-            generated_image = None
+            generated_image = None # Ensure image is None on error
+
+    else:
+        # If prompt enhancement failed, status_message already contains the error
+        # In this case, we just return the existing status_message and None image
+        print("Skipping image generation due to prompt enhancement failure.")
+
 
     # 4. Return results to Gradio UI
     # Return the status message (enhanced prompt or error) and the image (or None if error)
@@ -267,22 +327,31 @@ def process_input(input_text, audio_file, style_choice, quality_choice, neg_prom
 style_options = ["cinematic", "photorealistic", "anime", "fantasy art", "cyberpunk", "steampunk", "watercolor", "illustration", "low poly"]
 quality_options = ["highly detailed", "sharp focus", "intricate details", "4k", "masterpiece", "best quality", "professional lighting"]
 
-# Reduced steps for faster CPU generation attempt
-default_steps = 20
-max_steps = 50 # Limit max steps on CPU
+# Tiny model is very fast; steps/guidance might be ignored or have less effect
+# Keep sliders but note their limited impact on this specific model
+default_steps = 10 # Tiny model often uses few steps internally
+max_steps = 20 # Limit max steps as they might not matter much
 
 with gr.Blocks(theme=gr.themes.Soft()) as demo:
-    gr.Markdown("# AI Image Generator (CPU Version)")
+    gr.Markdown("# AI Image Generator (Resource-Friendly CPU Version)")
     gr.Markdown(
         "**Enter a short description or use voice input.** The app uses OpenAI (if API key is provided) "
-        "to create a detailed prompt, then generates an image using Stable Diffusion v1.5 **on the CPU**."
+        f"to create a detailed prompt, then generates an image using a **small, fast model ({model_id}) on the CPU**."
     )
-    # Add specific warning about CPU speed
-    gr.HTML("<p style='color:orange;font-weight:bold;'>⚠️ Warning: Image generation on CPU is very slow! Expect several minutes per image.</p>")
+    # Add specific warning about image quality for the tiny model
+    gr.HTML("<p style='color:orange;font-weight:bold;'>⚠️ Note: Using a small model for compatibility. Image quality and resolution will be significantly lower than models like Stable Diffusion.</p>")
 
     # Display OpenAI availability status
     if not openai_available:
         gr.Markdown("**Note:** OpenAI API key not found or invalid. Prompt enhancement will use a basic fallback.")
+    else:
+        gr.Markdown("**Note:** OpenAI API key found. Prompt will be enhanced using OpenAI.")
+
+
+    # Display model loading status
+    if not isinstance(image_generator_pipe, AutoPipelineForText2Image):
+        gr.Markdown(f"**CRITICAL:** Image generation model ({model_id}) failed to load. Image generation is disabled. Check logs.")
+
 
     with gr.Row():
         with gr.Column(scale=1):
@@ -294,27 +363,31 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
                 inp_audio = gr.Audio(sources=["microphone"], type="filepath", label="Or record your idea (clears text box if used)")
             else:
                 gr.Markdown("**Voice input disabled:** Whisper model failed to load.")
-                inp_audio = gr.Textbox(visible=False) # Hidden placeholder
+                # Using gr.State as a placeholder that holds None
+                inp_audio = gr.State(None)
 
             # --- Controls (Step 3 requirements met) ---
+            # Note: These controls might have limited effect on the small model
+            gr.Markdown("*(Optional controls - Note: These may have limited or no effect on the small model used)*")
             # Control 1: Dropdown
-            inp_style = gr.Dropdown(label="Base Style", choices=style_options, value="cinematic")
+            inp_style = gr.Dropdown(label="Base Style", choices=style_options, value="cinematic", interactive=True)
             # Control 2: Radio
-            inp_quality = gr.Radio(label="Quality Boost", choices=quality_options, value="highly detailed")
+            inp_quality = gr.Radio(label="Quality Boost", choices=quality_options, value="highly detailed", interactive=True)
             # Control 3: Textbox (Negative Prompt)
-            inp_neg_prompt = gr.Textbox(label="Negative Prompt (optional)", placeholder="e.g., blurry, low quality, text, watermark, signature, deformed")
+            inp_neg_prompt = gr.Textbox(label="Negative Prompt (optional)", placeholder="e.g., blurry, low quality, text, watermark", interactive=True)
             # Control 4: Slider (Guidance Scale)
-            inp_guidance = gr.Slider(minimum=1.0, maximum=15.0, step=0.5, value=7.0, label="Guidance Scale (CFG)") # Slightly lower max maybe better for CPU
+            inp_guidance = gr.Slider(minimum=1.0, maximum=15.0, step=0.5, value=3.0, label="Guidance Scale (CFG)", interactive=True) # Lower default for small model
             # Control 5: Slider (Inference Steps) - Reduced max/default
-            inp_steps = gr.Slider(minimum=10, maximum=max_steps, step=1, value=default_steps, label=f"Inference Steps (lower = faster but less detail, max {max_steps})")
+            inp_steps = gr.Slider(minimum=1, maximum=max_steps, step=1, value=default_steps, label=f"Inference Steps (lower = faster but less detail, max {max_steps})", interactive=True)
 
             # --- Action Button ---
-            btn_generate = gr.Button("Generate Image", variant="primary")
+            # Disable button if model failed to load
+            btn_generate = gr.Button("Generate Image", variant="primary", interactive=isinstance(image_generator_pipe, AutoPipelineForText2Image))
 
         with gr.Column(scale=1):
             # --- Outputs ---
             out_prompt = gr.Textbox(label="Generated Prompt / Status", interactive=False, lines=5) # Show prompt or error status here
-            out_image = gr.Image(label="Generated Image", type="pil")
+            out_image = gr.Image(label="Generated Image", type="pil", show_label=True) # Ensure label is shown
 
     # --- Event Handling ---
     # Define inputs list carefully, handling potentially invisible audio input
@@ -322,32 +395,44 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
     if asr_pipeline:
        inputs_list.append(inp_audio)
     else:
-        inputs_list.append(gr.State(None)) # Pass None if audio control doesn't exist
+        inputs_list.append(inp_audio) # Pass the gr.State(None) placeholder
+
 
     inputs_list.extend([inp_style, inp_quality, inp_neg_prompt, inp_guidance, inp_steps])
 
+    # Link button click to processing function
     btn_generate.click(
         fn=process_input,
         inputs=inputs_list,
         outputs=[out_prompt, out_image]
     )
 
-    # Clear text input if audio is used
+    # Clear text input if audio is used (only if ASR is available)
     if asr_pipeline:
-        def clear_text_on_audio(audio_data):
-            if audio_data is not None:
-                return "" # Clear text box
-            return gr.update() # No change if no audio data
-        inp_audio.change(fn=clear_text_on_audio, inputs=inp_audio, outputs=inp_text)
+        def clear_text_on_audio_change(audio_data):
+            # Check if audio_data is not None or empty (depending on how Gradio signals recording)
+            if audio_data is not None:
+                print("Audio input detected, clearing text box.")
+                return "" # Clear text box
+            # If audio_data becomes None (e.g., recording cleared), don't clear text
+            return gr.update()
+
+        # .change event fires when the value changes, including becoming None if cleared
+        inp_audio.change(fn=clear_text_on_audio_change, inputs=inp_audio, outputs=inp_text, api_name="clear_text_on_audio")
 
 
 # ---- Application Launch ----
 if __name__ == "__main__":
-    # Check again if SD loaded, maybe prevent launch? Or let it run and fail gracefully in UI.
-    if not isinstance(image_generator_pipe, StableDiffusionPipeline):
-        print("CRITICAL FAILURE: Stable Diffusion pipeline did not load. The application UI will load, but image generation WILL NOT WORK.")
-        # Optionally, you could raise an error here to stop the script if SD is essential
-        # raise RuntimeError("Failed to load Stable Diffusion pipeline, cannot start application.")
+    # Final check before launch
+    if not isinstance(image_generator_pipe, AutoPipelineForText2Image):
+        print("\n" + "="*50)
+        print("CRITICAL WARNING:")
+        print(f"Image generation model ({model_id}) failed to load during startup.")
+        print("The Gradio UI will launch, but the 'Generate Image' button will be disabled.")
+        print("Check the logs above for the specific model loading error.")
+        print("="*50 + "\n")
+
 
     # Launch the Gradio app
-    demo.launch(share=False) # share=True generates a public link if run locally
+    # Running on 0.0.0.0 is necessary for Hugging Face Spaces
+    demo.launch(share=False, server_name="0.0.0.0", server_port=7860)
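Binding to 0.0.0.0:7860 matches what Spaces expects. For running the same file locally, it is common to restrict the bind address outside a Space; a sketch keyed off the `SPACE_ID` environment variable (set inside Hugging Face Spaces; treating its absence as "running locally" is an assumption):

    on_space = os.environ.get("SPACE_ID") is not None  # SPACE_ID is set by HF Spaces
    demo.launch(
        share=False,
        server_name="0.0.0.0" if on_space else "127.0.0.1",
        server_port=7860,
    )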
 
 