File size: 16,660 Bytes
c102ebc a63d56e c102ebc a63d56e c102ebc a63d56e c102ebc a63d56e c102ebc a63d56e c102ebc a63d56e c102ebc a63d56e c102ebc a63d56e c102ebc a63d56e c102ebc a63d56e c102ebc a63d56e c102ebc a63d56e c102ebc a63d56e c102ebc a63d56e c102ebc a63d56e c102ebc a63d56e c102ebc a63d56e c102ebc a63d56e c102ebc a63d56e c102ebc a63d56e c102ebc a63d56e c102ebc a63d56e c102ebc a63d56e c102ebc a63d56e c102ebc a63d56e c102ebc a63d56e c102ebc a63d56e c102ebc a63d56e c102ebc a63d56e c102ebc a63d56e c102ebc a63d56e c102ebc a63d56e c102ebc a63d56e c102ebc a63d56e c102ebc a63d56e c102ebc a63d56e c102ebc a63d56e c102ebc a63d56e c102ebc a63d56e c102ebc a63d56e c102ebc a63d56e c102ebc a63d56e c102ebc a63d56e c102ebc a63d56e c102ebc a63d56e c102ebc a63d56e c102ebc a63d56e c102ebc a63d56e c102ebc a63d56e c102ebc a63d56e c102ebc a63d56e c102ebc a63d56e c102ebc a63d56e c102ebc a63d56e c102ebc a63d56e c102ebc a63d56e c102ebc a63d56e c102ebc a63d56e c102ebc a63d56e c102ebc a63d56e c102ebc a63d56e c102ebc a63d56e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 |
import gradio as gr
import torch
from transformers import pipeline, set_seed
from diffusers import StableDiffusionPipeline
import openai
import os
import time
import traceback # For detailed error logging
# ---- Configuration & API Key ----
# The OpenAI key is expected in the environment (e.g. a Hugging Face Secret).
api_key = os.environ.get("OPENAI_API_KEY")
openai_client = None
openai_available = False

if not api_key:
    print("WARNING: OPENAI_API_KEY secret not found. Prompt enhancement via OpenAI is disabled.")
else:
    try:
        # Legacy module-level key assignment (pre-v1 API style), kept alongside
        # the explicit client that openai>=1.0 prefers.
        openai.api_key = api_key
        openai_client = openai.OpenAI(api_key=api_key)
        # A validation call such as openai_client.models.list() would confirm the
        # key but may use quota, so it is skipped.
        openai_available = True
        print("OpenAI API key found and client initialized.")
    except Exception as e:
        print(f"Error initializing OpenAI client: {e}")
        print("Proceeding without OpenAI features.")

# This deployment runs every model on the CPU.
device = "cpu"
print(f"Using device: {device}")
# ---- Model Loading (CPU Focused) ----

# 1. Speech-to-text model (Whisper) - bonus feature
asr_pipeline = None
try:
    print("Loading ASR pipeline (Whisper) on CPU...")
    # device="cpu" keeps transformers from moving the model to an accelerator.
    asr_pipeline = pipeline("automatic-speech-recognition", model="openai/whisper-base", device=device)
    print("ASR pipeline loaded successfully on CPU.")
except Exception as e:
    print(f"Could not load ASR pipeline: {e}. Voice input will be disabled.")
    traceback.print_exc()  # Full traceback for debugging

# 2. Text-to-image model (Stable Diffusion) - Step 2 (CPU)
image_generator_pipe = None
try:
    print("Loading Stable Diffusion pipeline (v1.5) on CPU...")
    print("WARNING: Stable Diffusion on CPU is VERY SLOW (expect minutes per image).")
    model_id = "runwayml/stable-diffusion-v1-5"
    # float32 is required on CPU (no half-precision support).
    image_generator_pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32)
    image_generator_pipe = image_generator_pipe.to(device)
    print("Stable Diffusion pipeline loaded successfully on CPU.")
except Exception as e:
    print(f"CRITICAL: Could not load Stable Diffusion pipeline: {e}. Image generation will fail.")
    traceback.print_exc()  # Full traceback for debugging

    # BUGFIX: capture the failure reason NOW. Python deletes the `except`
    # target variable when the handler block exits, so referencing `e` later
    # from inside DummyPipe.__call__ would raise NameError instead of the
    # intended RuntimeError.
    _sd_load_error = str(e)

    # Placeholder so later calls fail with a clear message instead of crashing
    # on a None pipeline.
    class DummyPipe:
        def __call__(self, *args, **kwargs):
            raise RuntimeError(f"Stable Diffusion model failed to load: {_sd_load_error}")

    image_generator_pipe = DummyPipe()
# ---- Core Function Definitions ----

# Step 1: Prompt-to-Prompt (using OpenAI API)
def enhance_prompt_openai(short_prompt, style_modifier="cinematic", quality_boost="photorealistic, highly detailed"):
    """Expand a short description into a detailed image-generation prompt.

    Uses the OpenAI chat API when a client is available; otherwise falls back
    to appending the style/quality modifiers to the original description.

    Args:
        short_prompt: User-supplied short description; must be non-empty.
        style_modifier: Style keyword(s) to weave into the prompt.
        quality_boost: Quality keyword(s) to weave into the prompt.

    Returns:
        The enhanced prompt string.

    Raises:
        gr.Error: On empty input or on any OpenAI API failure.
    """
    # BUGFIX: validate input BEFORE the availability fallback. Previously an
    # empty description with no API key returned a modifiers-only prompt like
    # ", cinematic, photorealistic, highly detailed" instead of raising.
    if not short_prompt:
        raise gr.Error("Input description cannot be empty.")
    if not openai_available or not openai_client:
        # Fallback when the OpenAI key is missing/invalid.
        print("OpenAI not available. Returning original prompt with modifiers.")
        return f"{short_prompt}, {style_modifier}, {quality_boost}"
    # Construct the chat messages for the OpenAI model.
    system_message = (
        "You are an expert prompt engineer for AI image generation models like Stable Diffusion. "
        "Expand the user's short description into a detailed, vivid, and coherent prompt. "
        "Focus on visual details: subjects, objects, environment, lighting, atmosphere, composition. "
        "Incorporate the requested style and quality keywords naturally. Avoid conversational text."
    )
    user_message = (
        f"Enhance this description: \"{short_prompt}\". "
        f"Style: '{style_modifier}'. Quality: '{quality_boost}'."
    )
    print(f"Sending request to OpenAI for prompt enhancement: {short_prompt}")
    try:
        response = openai_client.chat.completions.create(
            model="gpt-3.5-turbo",  # cost-effective; gpt-4 works if the key allows
            messages=[
                {"role": "system", "content": system_message},
                {"role": "user", "content": user_message},
            ],
            temperature=0.7,  # creativity vs predictability
            max_tokens=150,   # cap output length
            n=1,              # single completion
            stop=None,        # let the model decide when to stop
        )
        enhanced_prompt = response.choices[0].message.content.strip()
        print("OpenAI enhancement successful.")
        # Strip a pair of quotes the model sometimes wraps around the reply.
        if enhanced_prompt.startswith('"') and enhanced_prompt.endswith('"'):
            enhanced_prompt = enhanced_prompt[1:-1]
        return enhanced_prompt
    except openai.AuthenticationError:
        print("OpenAI Authentication Error: Invalid API key?")
        raise gr.Error("OpenAI Authentication Error: Check your API key.")
    except openai.RateLimitError:
        print("OpenAI Rate Limit Error: You've exceeded your quota or rate limit.")
        raise gr.Error("OpenAI Error: Rate limit exceeded.")
    except openai.APIError as e:
        print(f"OpenAI API Error: {e}")
        raise gr.Error(f"OpenAI API Error: {e}")
    except Exception as e:
        print(f"An unexpected error occurred during OpenAI call: {e}")
        traceback.print_exc()
        raise gr.Error(f"Prompt enhancement failed: {e}")
# Step 2: Prompt-to-Image (CPU)
def generate_image_cpu(prompt, negative_prompt, guidance_scale, num_inference_steps):
    """Generate one image with Stable Diffusion on the CPU.

    Args:
        prompt: Positive prompt text (must not be an error placeholder).
        negative_prompt: Features to steer away from.
        guidance_scale: Classifier-free guidance strength (coerced to float).
        num_inference_steps: Denoising steps (coerced to int).

    Returns:
        The generated PIL image.

    Raises:
        gr.Error: If the model is unavailable, the prompt is invalid, or
            generation fails.
    """
    if not isinstance(image_generator_pipe, StableDiffusionPipeline):
        raise gr.Error("Stable Diffusion model is not available (failed to load).")
    # Reject prompts that are really error messages from an earlier step.
    # ("Error:" also matches "[Error:", so a single membership test suffices —
    # the original duplicated check was redundant.)
    if not prompt or "Error:" in prompt:
        raise gr.Error("Cannot generate image due to invalid or missing prompt.")
    print(f"Generating image on CPU for prompt: {prompt[:100]}...")  # truncated log
    print(f"Negative prompt: {negative_prompt}")
    print(f"Guidance scale: {guidance_scale}, Steps: {num_inference_steps}")
    start_time = time.time()
    try:
        # no_grad avoids building autograd state during inference.
        with torch.no_grad():
            # Time-based seed: logged-reproducible, but varies per call.
            generator = torch.Generator(device=device).manual_seed(int(time.time()))
            image = image_generator_pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                guidance_scale=float(guidance_scale),
                num_inference_steps=int(num_inference_steps),
                generator=generator,
            ).images[0]
        end_time = time.time()
        print(f"Image generated successfully on CPU in {end_time - start_time:.2f} seconds.")
        return image
    except Exception as e:
        print(f"Error during image generation on CPU: {e}")
        traceback.print_exc()
        # Propagate the failure to the Gradio UI.
        raise gr.Error(f"Image generation failed on CPU: {e}")
# Bonus: Voice-to-Text (CPU)
def transcribe_audio(audio_file_path):
    """Run Whisper over an audio file and return (text, original_path).

    On failure the text element is an "[Error: ...]" message rather than a
    raised exception, so callers can surface it in the UI.
    """
    if not asr_pipeline:
        # The audio control is normally hidden when ASR failed to load;
        # this is a belt-and-braces fallback.
        return "[Error: ASR model not loaded]", audio_file_path
    if audio_file_path is None:
        # Nothing was recorded.
        return "", audio_file_path
    print(f"Transcribing audio file: {audio_file_path} on CPU...")
    started = time.time()
    try:
        # Pipeline was loaded on CPU, so this runs on CPU as well.
        text = asr_pipeline(audio_file_path)["text"]
    except Exception as e:
        print(f"Error during audio transcription on CPU: {e}")
        traceback.print_exc()
        # Error reported in the same tuple shape as the success path.
        return f"[Error: Transcription failed: {e}]", audio_file_path
    elapsed = time.time() - started
    print(f"Transcription successful in {elapsed:.2f} seconds.")
    print(f"Transcription result: {text}")
    return text, audio_file_path
# ---- Gradio Application Flow ----
def process_input(input_text, audio_file, style_choice, quality_choice, neg_prompt, guidance, steps):
"""Main function triggered by Gradio button."""
final_text_input = ""
enhanced_prompt = ""
generated_image = None
status_message = "" # To gather status/errors for the prompt box
# 1. Determine Input (Text or Audio)
if input_text and input_text.strip():
final_text_input = input_text.strip()
print(f"Using text input: '{final_text_input}'")
elif audio_file is not None:
print("Processing audio input...")
transcribed_text, _ = transcribe_audio(audio_file)
if "[Error:" in transcribed_text:
# Display transcription error clearly
status_message = transcribed_text
print(status_message)
# Return error in prompt field, no image
return status_message, None
elif transcribed_text:
final_text_input = transcribed_text
print(f"Using transcribed audio input: '{final_text_input}'")
else:
status_message = "[Error: Audio input received but transcription was empty.]"
print(status_message)
return status_message, None # Return error
else:
status_message = "[Error: No input provided. Please enter text or record audio.]"
print(status_message)
return status_message, None # Return error
# 2. Enhance Prompt (using OpenAI if available)
if final_text_input:
try:
enhanced_prompt = enhance_prompt_openai(final_text_input, style_choice, quality_choice)
status_message = enhanced_prompt # Display the prompt
print(f"Enhanced prompt: {enhanced_prompt}")
except gr.Error as e:
# Catch Gradio-specific errors from enhancement function
status_message = f"[Prompt Enhancement Error: {e}]"
print(status_message)
# Return the error, no image generation attempt
return status_message, None
except Exception as e:
# Catch any other unexpected errors
status_message = f"[Unexpected Prompt Enhancement Error: {e}]"
print(status_message)
traceback.print_exc()
return status_message, None
# 3. Generate Image (if prompt is valid)
if enhanced_prompt and not status_message.startswith("[Error:") and not status_message.startswith("[Prompt Enhancement Error:"):
try:
# Show "Generating..." message while waiting
gr.Info("Starting image generation on CPU... This will take a while (possibly several minutes).")
generated_image = generate_image_cpu(enhanced_prompt, neg_prompt, guidance, steps)
gr.Info("Image generation complete!")
except gr.Error as e:
# Catch Gradio errors from generation function
status_message = f"{enhanced_prompt}\n\n[Image Generation Error: {e}]" # Append error to prompt
print(f"Image Generation Error: {e}")
except Exception as e:
status_message = f"{enhanced_prompt}\n\n[Unexpected Image Generation Error: {e}]"
print(f"Unexpected Image Generation Error: {e}")
traceback.print_exc()
# Set image to None explicitly on error
generated_image = None
# 4. Return results to Gradio UI
# Return the status message (enhanced prompt or error) and the image (or None if error)
return status_message, generated_image
# ---- Gradio Interface Construction ----

# Choice lists for the style/quality controls.
style_options = ["cinematic", "photorealistic", "anime", "fantasy art", "cyberpunk", "steampunk", "watercolor", "illustration", "low poly"]
quality_options = ["highly detailed", "sharp focus", "intricate details", "4k", "masterpiece", "best quality", "professional lighting"]
# Step count reduced for faster CPU generation attempts.
default_steps = 20
max_steps = 50  # Cap the step slider on CPU

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# AI Image Generator (CPU Version)")
    gr.Markdown(
        "**Enter a short description or use voice input.** The app uses OpenAI (if API key is provided) "
        "to create a detailed prompt, then generates an image using Stable Diffusion v1.5 **on the CPU**."
    )
    # Prominent warning about CPU generation speed.
    gr.HTML("<p style='color:orange;font-weight:bold;'>⚠️ Warning: Image generation on CPU is very slow! Expect several minutes per image.</p>")
    # Surface OpenAI availability so the user knows whether the fallback is active.
    if not openai_available:
        gr.Markdown("**Note:** OpenAI API key not found or invalid. Prompt enhancement will use a basic fallback.")
    with gr.Row():
        with gr.Column(scale=1):
            # --- Inputs ---
            inp_text = gr.Textbox(label="Enter short description", placeholder="e.g., A cute robot drinking coffee on Mars")
            # Only offer the microphone control when the Whisper model loaded.
            if asr_pipeline:
                inp_audio = gr.Audio(sources=["microphone"], type="filepath", label="Or record your idea (clears text box if used)")
            else:
                gr.Markdown("**Voice input disabled:** Whisper model failed to load.")
                inp_audio = gr.Textbox(visible=False)  # Hidden placeholder so the name always exists
            # --- Controls (Step 3 requirements met) ---
            # Control 1: Dropdown
            inp_style = gr.Dropdown(label="Base Style", choices=style_options, value="cinematic")
            # Control 2: Radio
            inp_quality = gr.Radio(label="Quality Boost", choices=quality_options, value="highly detailed")
            # Control 3: Textbox (Negative Prompt)
            inp_neg_prompt = gr.Textbox(label="Negative Prompt (optional)", placeholder="e.g., blurry, low quality, text, watermark, signature, deformed")
            # Control 4: Slider (Guidance Scale) — max kept modest for CPU runs
            inp_guidance = gr.Slider(minimum=1.0, maximum=15.0, step=0.5, value=7.0, label="Guidance Scale (CFG)")
            # Control 5: Slider (Inference Steps) — reduced max/default for CPU
            inp_steps = gr.Slider(minimum=10, maximum=max_steps, step=1, value=default_steps, label=f"Inference Steps (lower = faster but less detail, max {max_steps})")
            # --- Action Button ---
            btn_generate = gr.Button("Generate Image", variant="primary")
        with gr.Column(scale=1):
            # --- Outputs ---
            out_prompt = gr.Textbox(label="Generated Prompt / Status", interactive=False, lines=5)  # Enhanced prompt or error status
            out_image = gr.Image(label="Generated Image", type="pil")
    # --- Event Handling ---
    # Build the inputs list; when the audio control is absent, pass a constant
    # None via gr.State so process_input's positional signature still matches.
    inputs_list = [inp_text]
    if asr_pipeline:
        inputs_list.append(inp_audio)
    else:
        inputs_list.append(gr.State(None))
    inputs_list.extend([inp_style, inp_quality, inp_neg_prompt, inp_guidance, inp_steps])
    btn_generate.click(
        fn=process_input,
        inputs=inputs_list,
        outputs=[out_prompt, out_image]
    )
    # Clear the text box whenever a new audio recording arrives.
    if asr_pipeline:
        def clear_text_on_audio(audio_data):
            if audio_data is not None:
                return ""  # Clear text box so audio takes priority
            return gr.update()  # No change when audio was cleared/absent
        inp_audio.change(fn=clear_text_on_audio, inputs=inp_audio, outputs=inp_text)
# ---- Application Launch ----
if __name__ == "__main__":
    # Launch the UI even if Stable Diffusion failed to load, so the failure is
    # visible in the app rather than the process dying silently.
    sd_ready = isinstance(image_generator_pipe, StableDiffusionPipeline)
    if not sd_ready:
        print("CRITICAL FAILURE: Stable Diffusion pipeline did not load. The application UI will load, but image generation WILL NOT WORK.")
        # To make SD mandatory instead, raise here, e.g.:
        # raise RuntimeError("Failed to load Stable Diffusion pipeline, cannot start application.")
    # share=True would expose a public link when running locally.
    demo.launch(share=False)