Update app.py
app.py
CHANGED
@@ -46,13 +46,13 @@ except Exception as e:
     print(f"Could not load ASR pipeline: {e}. Voice input will be disabled.")
     traceback.print_exc() # Print full traceback for debugging
 
-# 2. Text-to-image model (
+# 2. Text-to-image model (nota-ai/bk-sdm-tiny) - resource-friendly model
 image_generator_pipe = None
-#
-model_id = "
+# Use the nota-ai/bk-sdm-tiny model
+model_id = "nota-ai/bk-sdm-tiny"
 try:
     print(f"Loading Text-to-Image pipeline ({model_id}) on CPU...")
-    print("NOTE: Using a
+    print("NOTE: Using a small model for resource efficiency. Image quality and details may differ from larger models.")
     # Use AutoPipelineForText2Image to automatically detect the model type
     image_generator_pipe = AutoPipelineForText2Image.from_pretrained(model_id, torch_dtype=torch.float32)
     image_generator_pipe = image_generator_pipe.to(device)
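A caveat on the load-status checks used later in this diff: `AutoPipelineForText2Image.from_pretrained` returns an instance of whatever concrete pipeline class it resolves to (a `StableDiffusionPipeline` for this checkpoint), not of `AutoPipelineForText2Image` itself, so `isinstance(image_generator_pipe, AutoPipelineForText2Image)` will generally be False even after a successful load. A minimal sketch of a more robust check, assuming only standard diffusers behavior (`DiffusionPipeline` is the base class of all diffusers pipelines):

# Sketch: load on CPU, then verify the load against the diffusers base class.
import torch
from diffusers import AutoPipelineForText2Image, DiffusionPipeline

model_id = "nota-ai/bk-sdm-tiny"
image_generator_pipe = None  # sentinel: stays None if loading fails

try:
    image_generator_pipe = AutoPipelineForText2Image.from_pretrained(
        model_id, torch_dtype=torch.float32
    ).to("cpu")
except Exception as e:
    print(f"Could not load {model_id}: {e}")

# True for any successfully loaded pipeline, whatever its concrete class
model_loaded = isinstance(image_generator_pipe, DiffusionPipeline)

Checking `image_generator_pipe is not None` works just as well and sidesteps the type question entirely.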
@@ -155,7 +155,7 @@ def generate_image_cpu(prompt, negative_prompt, guidance_scale, num_inference_st
 
     print(f"Generating image on CPU for prompt: {prompt[:100]}...") # Log truncated prompt
     # Note: Negative prompt and guidance scale might have less impact or behave differently
-    # on very small models
+    # on very small models.
     print(f"Negative prompt: {negative_prompt}") # Will likely be ignored by tiny model
     print(f"Guidance scale: {guidance_scale}, Steps: {num_inference_steps}") # Steps might be fixed internally by tiny model
 
@@ -166,23 +166,25 @@ def generate_image_cpu(prompt, negative_prompt, guidance_scale, num_inference_st
     with torch.no_grad():
         # Seed for reproducibility (optional, but good practice)
         # generator = torch.Generator(device=device).manual_seed(int(time.time())) # Tiny model might not use generator param
-        #
-
-
-
-
-
-
-
-
-
-
-
-        #
-
-
-
-
+        # Call the pipeline - assuming standard parameters are accepted
+        output = image_generator_pipe(
+            prompt=prompt,
+            # It's possible tiny models ignore some parameters, but passing them is safer
+            negative_prompt=negative_prompt,
+            guidance_scale=float(guidance_scale),
+            num_inference_steps=int(num_inference_steps),
+            # generator=generator, # Omit if tiny model pipeline doesn't accept it
+            # height and width might need to be specified or limited for tiny models
+            # height=..., width=...
+        )
+
+        # Access the generated image(s). Assuming standard diffusers output structure (.images[0])
+        if hasattr(output, 'images') and isinstance(output.images, list) and len(output.images) > 0:
+            image = output.images[0] # Access the first image
+        else:
+            # Handle cases where output format is different (less common for AutoPipelines)
+            print("Warning: Pipeline output format unexpected. Attempting to use the output directly.")
+            image = output # Assume output is the image
 
     end_time = time.time()
     print(f"Image generated successfully on CPU in {end_time - start_time:.2f} seconds (using {model_id}).")
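The commented-out `generator` line above is the standard diffusers mechanism for reproducibility, and since bk-sdm-tiny is a distilled Stable Diffusion checkpoint, its pipeline should accept the parameter. A minimal sketch of a seeded call, assuming `image_generator_pipe` is the pipeline loaded earlier:

# Sketch: deterministic CPU generation with a fixed seed.
import torch

generator = torch.Generator(device="cpu").manual_seed(42)

output = image_generator_pipe(
    prompt="a cinematic photo of a lighthouse at dusk",
    negative_prompt="blurry, low quality",
    guidance_scale=5.0,
    num_inference_steps=20,
    generator=generator,  # same seed + settings -> same image
)
output.images[0].save("out.png")  # standard diffusers output: a list of PIL images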
@@ -208,7 +210,19 @@ def transcribe_audio(audio_file_path):
     try:
         # Ensure the pipeline uses the correct device (should be CPU based on loading)
         # Ensure input is in expected format for Whisper pipeline (filepath or audio array)
-
+        if isinstance(audio_file_path, tuple): # Handle case where Gradio might pass tuple
+            # Assuming tuple is (samplerate, numpy_array), need to save to temp file or process directly
+            # For simplicity with type="filepath", assume it passes path directly
+            print("Warning: Audio input was tuple, expecting filepath. This might fail.")
+            # Attempting to process numpy array if it's the second element
+            if isinstance(audio_file_path[1], torch.Tensor) or isinstance(audio_file_path[1], list) or isinstance(audio_file_path[1], (int, float)):
+                # This path is complex, sticking to filepath assumption for now
+                pass # Let the pipeline call below handle potential error
+            audio_input_for_pipeline = audio_file_path # Pass original tuple, let pipeline handle
+        else:
+            audio_input_for_pipeline = audio_file_path # Expected filepath
+
+        transcription = asr_pipeline(audio_input_for_pipeline)["text"]
         end_time = time.time()
         print(f"Transcription successful in {end_time - start_time:.2f} seconds.")
         print(f"Transcription result: {transcription}")
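The tuple branch above passes the raw tuple through and hopes the pipeline copes. If Gradio were configured with `type="numpy"`, the input would arrive as a `(sample_rate, numpy_array)` pair, and the transformers ASR pipeline accepts exactly that data as a dict with `raw` and `sampling_rate` keys. A minimal normalization sketch (the helper name `to_asr_input` is illustrative, not from the commit):

# Sketch: normalize Gradio audio input for the transformers ASR pipeline.
import numpy as np

def to_asr_input(audio):
    if isinstance(audio, tuple):  # Gradio type="numpy" gives (sample_rate, data)
        sample_rate, data = audio
        data = np.asarray(data, dtype=np.float32)
        if data.ndim > 1:
            data = data.mean(axis=1)  # downmix stereo to mono
        if np.abs(data).max() > 1.0:
            data /= 32768.0  # Gradio delivers int16 PCM; scale into [-1.0, 1.0]
        return {"raw": data, "sampling_rate": int(sample_rate)}
    return audio  # type="filepath" gives a path string; pass through as-is

transcription = asr_pipeline(to_asr_input(audio_file_path))["text"]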
@@ -236,17 +250,8 @@ def process_input(input_text, audio_file, style_choice, quality_choice, neg_prom
     elif audio_file is not None:
         print("Processing audio input...")
         try:
-            #
-
-            if isinstance(audio_file, tuple):
-                # If Gradio gives tuple for some reason, try to save to temp file or adjust transcribe_audio
-                # Assuming type="filepath" works as expected and passes filepath
-                audio_filepath_to_transcribe = audio_file[0] # This might be incorrect depending on Gradio version/config
-                print(f"Warning: Gradio audio input was tuple, attempting to use first element as path: {audio_filepath_to_transcribe}")
-            else:
-                audio_filepath_to_transcribe = audio_file # This is expected for type="filepath"
-
-            transcribed_text, _ = transcribe_audio(audio_filepath_to_transcribe)
+            # transcribe_audio handles different Gradio audio output types potentially
+            transcribed_text, _ = transcribe_audio(audio_file)
 
             if "[Error:" in transcribed_text:
                 # Display transcription error clearly
@@ -295,7 +300,7 @@ def process_input(input_text, audio_file, style_choice, quality_choice, neg_prom
     if enhanced_prompt and not status_message.startswith("[Error:") and not status_message.startswith("[Prompt Enhancement Error:"):
         try:
             # Show "Generating..." message while waiting
-            gr.Info(f"Starting image generation on CPU using {model_id}. This should be
+            gr.Info(f"Starting image generation on CPU using {model_id}. This should be faster than full SD, but might still take time.")
            generated_image = generate_image_cpu(enhanced_prompt, neg_prompt, guidance, steps)
             gr.Info("Image generation complete!")
         except gr.Error as e:
@@ -327,19 +332,21 @@ def process_input(input_text, audio_file, style_choice, quality_choice, neg_prom
 style_options = ["cinematic", "photorealistic", "anime", "fantasy art", "cyberpunk", "steampunk", "watercolor", "illustration", "low poly"]
 quality_options = ["highly detailed", "sharp focus", "intricate details", "4k", "masterpiece", "best quality", "professional lighting"]
 
-#
-
-
-
+# Adjust steps/guidance defaults for a smaller model, still might be ignored by some pipelines
+default_steps = 20
+max_steps = 40 # Adjusted max steps
+default_guidance = 5.0 # Adjusted default guidance
 
 with gr.Blocks(theme=gr.themes.Soft()) as demo:
-    gr.Markdown("# AI Image Generator (
+    gr.Markdown("# AI Image Generator (CPU Version - Using Small Model)")
     gr.Markdown(
         "**Enter a short description or use voice input.** The app uses OpenAI (if API key is provided) "
-        f"to create a detailed prompt, then generates an image using a **small
+        f"to create a detailed prompt, then generates an image using a **small model ({model_id}) on the CPU**."
     )
-    # Add specific warning about
-    gr.HTML("<p style='color:orange;font-weight:bold;'>⚠️ Note: Using a small model for compatibility
+    # Add specific warning about CPU speed and potential resource issues for this specific model
+    gr.HTML("<p style='color:orange;font-weight:bold;'>⚠️ Note: Using a small model for better compatibility on CPU. Generation should be faster than full Stable Diffusion, but quality/details may differ.</p>")
+    gr.HTML("<p style='color:red;font-weight:bold;'>⏰ CPU generation can still take 1-5 minutes per image depending on load and model specifics.</p>")
+
 
     # Display OpenAI availability status
     if not openai_available:
@@ -347,10 +354,10 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
     else:
         gr.Markdown("**Note:** OpenAI API key found. Prompt will be enhanced using OpenAI.")
 
-
     # Display Model loading status
+    # Check against AutoPipelineForText2Image type
     if not isinstance(image_generator_pipe, AutoPipelineForText2Image):
-        gr.Markdown(f"**CRITICAL:** Image generation model ({model_id}) failed to load. Image generation is disabled. Check logs.")
+        gr.Markdown(f"**CRITICAL:** Image generation model ({model_id}) failed to load. Image generation is disabled. Check Space logs for details.")
 
 
     with gr.Row():
@@ -366,19 +373,19 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
             # Using gr.State as a placeholder that holds None
             inp_audio = gr.State(None)
 
-            # --- Controls
-            # Note: These controls might have
-            gr.Markdown("*(Optional controls - Note:
+            # --- Controls ---
+            # Note: These controls might have less impact than on larger models
+            gr.Markdown("*(Optional controls - Note: Their impact might vary on this small model)*")
             # Control 1: Dropdown
-            inp_style = gr.Dropdown(label="Base Style", choices=style_options, value="cinematic"
+            inp_style = gr.Dropdown(label="Base Style", choices=style_options, value="cinematic")
             # Control 2: Radio
-            inp_quality = gr.Radio(label="Quality Boost", choices=quality_options, value="highly detailed"
+            inp_quality = gr.Radio(label="Quality Boost", choices=quality_options, value="highly detailed")
             # Control 3: Textbox (Negative Prompt)
-            inp_neg_prompt = gr.Textbox(label="Negative Prompt (optional)", placeholder="e.g., blurry, low quality, text, watermark
+            inp_neg_prompt = gr.Textbox(label="Negative Prompt (optional)", placeholder="e.g., blurry, low quality, text, watermark, signature, deformed")
             # Control 4: Slider (Guidance Scale)
-            inp_guidance = gr.Slider(minimum=1.0, maximum=
-            # Control 5: Slider (Inference Steps) -
-            inp_steps = gr.Slider(minimum=
+            inp_guidance = gr.Slider(minimum=1.0, maximum=10.0, step=0.5, value=default_guidance, label="Guidance Scale (CFG)") # Lower max guidance
+            # Control 5: Slider (Inference Steps) - Adjusted max/default
+            inp_steps = gr.Slider(minimum=5, maximum=max_steps, step=1, value=default_steps, label=f"Inference Steps (lower = faster but less detail, max {max_steps})") # Lower min steps
 
     # --- Action Button ---
     # Disable button if model failed to load
@@ -397,7 +404,6 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
     else:
         inputs_list.append(inp_audio) # Pass the gr.State(None) placeholder
 
-
     inputs_list.extend([inp_style, inp_quality, inp_neg_prompt, inp_guidance, inp_steps])
 
     # Link button click to processing function
@@ -424,12 +430,13 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
 # ---- Application Launch ----
 if __name__ == "__main__":
     # Final check before launch
+    # Check against AutoPipelineForText2Image type
     if not isinstance(image_generator_pipe, AutoPipelineForText2Image):
         print("\n" + "="*50)
         print("CRITICAL WARNING:")
         print(f"Image generation model ({model_id}) failed to load during startup.")
         print("The Gradio UI will launch, but the 'Generate Image' button will be disabled.")
-        print("Check the logs above for the specific model loading error.")
+        print("Check the Space logs above for the specific model loading error.")
         print("="*50 + "\n")
 
 