File size: 16,660 Bytes
c102ebc
 
 
 
a63d56e
c102ebc
 
a63d56e
c102ebc
a63d56e
 
 
 
 
c102ebc
a63d56e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c102ebc
 
a63d56e
 
c102ebc
 
 
a63d56e
 
 
 
c102ebc
 
a63d56e
c102ebc
a63d56e
c102ebc
 
a63d56e
 
c102ebc
a63d56e
 
c102ebc
a63d56e
c102ebc
a63d56e
 
 
c102ebc
 
a63d56e
c102ebc
 
 
a63d56e
c102ebc
a63d56e
 
 
 
 
 
 
c102ebc
a63d56e
 
 
 
 
 
 
 
 
c102ebc
a63d56e
 
 
 
 
 
c102ebc
 
a63d56e
 
 
 
 
 
 
 
 
 
c102ebc
a63d56e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c102ebc
a63d56e
 
 
 
c102ebc
a63d56e
 
 
c102ebc
a63d56e
 
 
c102ebc
 
a63d56e
c102ebc
 
a63d56e
c102ebc
 
a63d56e
 
 
 
 
 
 
 
 
 
 
 
 
c102ebc
 
a63d56e
 
 
 
c102ebc
 
a63d56e
c102ebc
a63d56e
c102ebc
a63d56e
 
c102ebc
a63d56e
c102ebc
a63d56e
 
c102ebc
a63d56e
c102ebc
a63d56e
 
c102ebc
a63d56e
c102ebc
a63d56e
 
 
c102ebc
 
 
a63d56e
c102ebc
 
a63d56e
c102ebc
a63d56e
 
 
c102ebc
a63d56e
c102ebc
 
a63d56e
c102ebc
a63d56e
c102ebc
a63d56e
 
 
 
 
 
 
c102ebc
a63d56e
c102ebc
a63d56e
 
 
c102ebc
a63d56e
 
 
c102ebc
a63d56e
 
c102ebc
a63d56e
 
 
c102ebc
a63d56e
 
 
 
 
c102ebc
a63d56e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c102ebc
a63d56e
 
 
c102ebc
 
a63d56e
c102ebc
a63d56e
 
 
 
 
 
c102ebc
 
a63d56e
 
 
 
 
 
 
 
 
 
 
c102ebc
 
 
a63d56e
 
 
 
 
 
 
 
 
 
 
 
 
 
c102ebc
a63d56e
 
 
 
 
 
 
 
c102ebc
 
 
a63d56e
 
 
 
 
 
 
 
 
 
 
 
 
c102ebc
 
 
a63d56e
c102ebc
 
 
a63d56e
c102ebc
 
 
a63d56e
 
c102ebc
 
 
a63d56e
c102ebc
a63d56e
 
 
 
 
c102ebc
a63d56e
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
import gradio as gr
import torch
from transformers import pipeline, set_seed
from diffusers import StableDiffusionPipeline
import openai
import os
import time
import traceback # For detailed error logging

# ---- Configuration & API Key ----
# The OpenAI key is read from the process environment (e.g. a Hugging Face
# secret). Both flags below stay in their "disabled" state unless the client is
# constructed successfully. The key itself is not validated here; an invalid key
# only surfaces when enhance_prompt_openai makes its first API call.
api_key = os.environ.get("OPENAI_API_KEY")
openai_client = None
openai_available = False

if not api_key:
    print("WARNING: OPENAI_API_KEY secret not found. Prompt enhancement via OpenAI is disabled.")
else:
    try:
        # Set the legacy module-level key and also build an openai>=1.0 client object.
        openai.api_key = api_key
        openai_client = openai.OpenAI(api_key=api_key)
        openai_available = True
        print("OpenAI API key found and client initialized.")
    except Exception as e:
        print(f"Error initializing OpenAI client: {e}")
        print("Proceeding without OpenAI features.")

# Every model in this app is placed on the CPU.
device = "cpu"
print(f"Using device: {device}")

# ---- Model Loading (CPU Focused) ----

# 1. Speech-to-text model (Whisper) - bonus feature.
# If loading fails, `asr_pipeline` stays None; the UI checks this to disable voice input.
asr_pipeline = None
try:
    print("Loading ASR pipeline (Whisper) on CPU...")
    asr_pipeline = pipeline(
        "automatic-speech-recognition",
        model="openai/whisper-base",
        device=device,  # "cpu" (device=-1 would also select the CPU)
    )
    print("ASR pipeline loaded successfully on CPU.")
except Exception as e:
    print(f"Could not load ASR pipeline: {e}. Voice input will be disabled.")
    traceback.print_exc()  # full traceback in the log for debugging

# 2. Text-to-image model (Stable Diffusion) - Step 2 (CPU)
image_generator_pipe = None
try:
    print("Loading Stable Diffusion pipeline (v1.5) on CPU...")
    print("WARNING: Stable Diffusion on CPU is VERY SLOW (expect minutes per image).")
    # NOTE(review): the "runwayml/stable-diffusion-v1-5" Hub repo has reportedly been
    # taken down; if the download fails, the "stable-diffusion-v1-5/stable-diffusion-v1-5"
    # mirror may be a drop-in replacement -- verify before switching.
    model_id = "runwayml/stable-diffusion-v1-5"
    # Use float32 weights on the CPU.
    image_generator_pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32)
    image_generator_pipe = image_generator_pipe.to(device)
    print("Stable Diffusion pipeline loaded successfully on CPU.")
except Exception as e:
    print(f"CRITICAL: Could not load Stable Diffusion pipeline: {e}. Image generation will fail.")
    traceback.print_exc()  # full traceback in the log for debugging

    class DummyPipe:
        """Stand-in for the pipeline that raises a descriptive error when called.

        The load-error text is stored on the instance at construction time.
        Python 3 unbinds the ``except ... as e`` name when the handler exits,
        so a method that referred to ``e`` lazily would raise NameError instead.
        """

        def __init__(self, reason=""):
            self.reason = reason

        def __call__(self, *args, **kwargs):
            raise RuntimeError(f"Stable Diffusion model failed to load: {self.reason}")

    # Installed so later code calling the pipeline gets a clear error instead of a crash.
    image_generator_pipe = DummyPipe(str(e))


# ---- Core Function Definitions ----

# Step 1: Prompt-to-Prompt (using OpenAI API)
def enhance_prompt_openai(short_prompt, style_modifier="cinematic", quality_boost="photorealistic, highly detailed"):
    """Expand a short description into a detailed Stable Diffusion prompt.

    Uses the OpenAI chat completions API when a client was initialised at
    startup. Otherwise it falls back to ``"<description>, <style>, <quality>"``.
    It also falls back when the model returns an empty reply.

    Args:
        short_prompt: The user's short description; surrounding whitespace is ignored.
        style_modifier: Style keyword(s) to blend into the prompt.
        quality_boost: Quality keyword(s) to blend into the prompt.

    Returns:
        The enhanced prompt string.

    Raises:
        gr.Error: If ``short_prompt`` is empty or blank, or if the OpenAI call
            fails (authentication, rate limit, API error, or an unexpected error).
    """
    # Validate first so blank input is rejected the same way whether or not
    # OpenAI is configured.
    if not short_prompt or not short_prompt.strip():
        raise gr.Error("Input description cannot be empty.")
    short_prompt = short_prompt.strip()
    fallback_prompt = f"{short_prompt}, {style_modifier}, {quality_boost}"

    if not openai_available or not openai_client:
        print("OpenAI not available. Returning original prompt with modifiers.")
        return fallback_prompt

    # Instructions for the model: act as a prompt engineer, return only the prompt.
    system_message = (
        "You are an expert prompt engineer for AI image generation models like Stable Diffusion. "
        "Expand the user's short description into a detailed, vivid, and coherent prompt. "
        "Focus on visual details: subjects, objects, environment, lighting, atmosphere, composition. "
        "Incorporate the requested style and quality keywords naturally. Avoid conversational text."
    )
    user_message = (
        f"Enhance this description: \"{short_prompt}\". "
        f"Style: '{style_modifier}'. Quality: '{quality_boost}'."
    )

    print(f"Sending request to OpenAI for prompt enhancement: {short_prompt}")

    try:
        response = openai_client.chat.completions.create(
            model="gpt-3.5-turbo",  # cost-effective choice; a newer model can be used if the key allows
            messages=[
                {"role": "system", "content": system_message},
                {"role": "user", "content": user_message},
            ],
            temperature=0.7,  # creativity vs. predictability
            max_tokens=150,  # cap the length of the expanded prompt
            n=1,  # a single completion
        )
        # The client types `content` as optional, so normalise None before stripping.
        enhanced_prompt = (response.choices[0].message.content or "").strip()
        # Drop one pair of wrapping double quotes; the length check keeps a lone '"' intact.
        if len(enhanced_prompt) >= 2 and enhanced_prompt.startswith('"') and enhanced_prompt.endswith('"'):
            enhanced_prompt = enhanced_prompt[1:-1].strip()
        if not enhanced_prompt:
            print("OpenAI returned an empty prompt. Returning original prompt with modifiers.")
            return fallback_prompt
        print("OpenAI enhancement successful.")
        return enhanced_prompt
    # Subclasses are listed before their base class openai.APIError.
    except openai.AuthenticationError:
        print("OpenAI Authentication Error: Invalid API key?")
        raise gr.Error("OpenAI Authentication Error: Check your API key.")
    except openai.RateLimitError:
        print("OpenAI Rate Limit Error: You've exceeded your quota or rate limit.")
        raise gr.Error("OpenAI Error: Rate limit exceeded.")
    except openai.APIError as e:
        print(f"OpenAI API Error: {e}")
        raise gr.Error(f"OpenAI API Error: {e}")
    except Exception as e:
        print(f"An unexpected error occurred during OpenAI call: {e}")
        traceback.print_exc()
        raise gr.Error(f"Prompt enhancement failed: {e}")


# Step 2: Prompt-to-Image (CPU)
def generate_image_cpu(prompt, negative_prompt, guidance_scale, num_inference_steps):
    """Generate one image from *prompt* with Stable Diffusion on the CPU.

    Args:
        prompt: Positive text prompt, normally the output of enhance_prompt_openai.
        negative_prompt: Text describing what to avoid; may be empty.
        guidance_scale: Classifier-free guidance scale (coerced to float).
        num_inference_steps: Number of denoising steps (coerced to int).

    Returns:
        The first image from the pipeline output. The UI displays it in a
        ``gr.Image(type="pil")`` component.

    Raises:
        gr.Error: If the model failed to load, if the prompt is missing, blank or
            one of this app's bracketed error messages, or if generation fails.
    """
    if not isinstance(image_generator_pipe, StableDiffusionPipeline):
        raise gr.Error("Stable Diffusion model is not available (failed to load).")

    # Reject missing or blank prompts and the bracketed status strings this app
    # produces (e.g. "[Error: ...]"). A plain substring test for "Error:" would
    # also reject legitimate prompts such as 'a sign reading "Error: 404"'.
    error_markers = ("[Error:", "[Prompt Enhancement Error:", "[Unexpected Prompt Enhancement Error:")
    if not prompt or not prompt.strip() or prompt.lstrip().startswith(error_markers):
        raise gr.Error("Cannot generate image due to invalid or missing prompt.")

    print(f"Generating image on CPU for prompt: {prompt[:100]}...")  # truncated for the log
    print(f"Negative prompt: {negative_prompt}")
    print(f"Guidance scale: {guidance_scale}, Steps: {num_inference_steps}")
    # The time-based seed changes every second. It is logged so a given result
    # can be reproduced later.
    seed = int(time.time())
    print(f"Seed: {seed}")
    start_time = time.time()

    try:
        with torch.no_grad():
            generator = torch.Generator(device=device).manual_seed(seed)
            image = image_generator_pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                guidance_scale=float(guidance_scale),
                num_inference_steps=int(num_inference_steps),
                generator=generator,
            ).images[0]
        elapsed = time.time() - start_time
        print(f"Image generated successfully on CPU in {elapsed:.2f} seconds.")
        return image
    except Exception as e:
        print(f"Error during image generation on CPU: {e}")
        traceback.print_exc()
        # Surface the failure in the Gradio UI.
        raise gr.Error(f"Image generation failed on CPU: {e}")


# Bonus: Voice-to-Text (CPU)
def transcribe_audio(audio_file_path):
    """Transcribe a recorded audio file to text with the Whisper ASR pipeline.

    Args:
        audio_file_path: Path to the audio file (the ``gr.Audio`` input uses
            ``type="filepath"``), or None when nothing was recorded.

    Returns:
        A ``(text, audio_file_path)`` tuple. ``text`` is one of:
        - the transcription with surrounding whitespace stripped;
        - ``""`` when no audio was given;
        - a string starting with ``"[Error:"`` when the model is unavailable or
          transcription fails.
    """
    if not asr_pipeline:
        # Normally the UI hides voice input in this case; this is a safety net.
        return "[Error: ASR model not loaded]", audio_file_path
    if audio_file_path is None:
        return "", audio_file_path  # no audio input

    print(f"Transcribing audio file: {audio_file_path} on CPU...")
    start_time = time.time()
    try:
        # Strip so a whitespace-only transcription reaches the caller as "".
        transcription = (asr_pipeline(audio_file_path)["text"] or "").strip()
        elapsed = time.time() - start_time
        print(f"Transcription successful in {elapsed:.2f} seconds.")
        print(f"Transcription result: {transcription}")
        return transcription, audio_file_path
    except Exception as e:
        print(f"Error during audio transcription on CPU: {e}")
        traceback.print_exc()
        # Keep the (text, path) tuple shape the caller unpacks.
        return f"[Error: Transcription failed: {e}]", audio_file_path


# ---- Gradio Application Flow ----

def process_input(input_text, audio_file, style_choice, quality_choice, neg_prompt, guidance, steps):
    """Main function triggered by Gradio button."""
    final_text_input = ""
    enhanced_prompt = ""
    generated_image = None
    status_message = "" # To gather status/errors for the prompt box

    # 1. Determine Input (Text or Audio)
    if input_text and input_text.strip():
        final_text_input = input_text.strip()
        print(f"Using text input: '{final_text_input}'")
    elif audio_file is not None:
        print("Processing audio input...")
        transcribed_text, _ = transcribe_audio(audio_file)
        if "[Error:" in transcribed_text:
            # Display transcription error clearly
            status_message = transcribed_text
            print(status_message)
            # Return error in prompt field, no image
            return status_message, None
        elif transcribed_text:
            final_text_input = transcribed_text
            print(f"Using transcribed audio input: '{final_text_input}'")
        else:
            status_message = "[Error: Audio input received but transcription was empty.]"
            print(status_message)
            return status_message, None # Return error
    else:
        status_message = "[Error: No input provided. Please enter text or record audio.]"
        print(status_message)
        return status_message, None # Return error

    # 2. Enhance Prompt (using OpenAI if available)
    if final_text_input:
        try:
            enhanced_prompt = enhance_prompt_openai(final_text_input, style_choice, quality_choice)
            status_message = enhanced_prompt # Display the prompt
            print(f"Enhanced prompt: {enhanced_prompt}")
        except gr.Error as e:
            # Catch Gradio-specific errors from enhancement function
            status_message = f"[Prompt Enhancement Error: {e}]"
            print(status_message)
            # Return the error, no image generation attempt
            return status_message, None
        except Exception as e:
             # Catch any other unexpected errors
             status_message = f"[Unexpected Prompt Enhancement Error: {e}]"
             print(status_message)
             traceback.print_exc()
             return status_message, None

    # 3. Generate Image (if prompt is valid)
    if enhanced_prompt and not status_message.startswith("[Error:") and not status_message.startswith("[Prompt Enhancement Error:"):
        try:
            # Show "Generating..." message while waiting
            gr.Info("Starting image generation on CPU... This will take a while (possibly several minutes).")
            generated_image = generate_image_cpu(enhanced_prompt, neg_prompt, guidance, steps)
            gr.Info("Image generation complete!")
        except gr.Error as e:
            # Catch Gradio errors from generation function
            status_message = f"{enhanced_prompt}\n\n[Image Generation Error: {e}]" # Append error to prompt
            print(f"Image Generation Error: {e}")
        except Exception as e:
             status_message = f"{enhanced_prompt}\n\n[Unexpected Image Generation Error: {e}]"
             print(f"Unexpected Image Generation Error: {e}")
             traceback.print_exc()
             # Set image to None explicitly on error
             generated_image = None

    # 4. Return results to Gradio UI
    # Return the status message (enhanced prompt or error) and the image (or None if error)
    return status_message, generated_image


# ---- Gradio Interface Construction ----

# Choices offered by the "Base Style" dropdown.
style_options = [
    "cinematic",
    "photorealistic",
    "anime",
    "fantasy art",
    "cyberpunk",
    "steampunk",
    "watercolor",
    "illustration",
    "low poly",
]
# Choices offered by the "Quality Boost" radio group.
quality_options = [
    "highly detailed",
    "sharp focus",
    "intricate details",
    "4k",
    "masterpiece",
    "best quality",
    "professional lighting",
]

# Step-slider settings, kept low because generation runs on the CPU.
default_steps = 20  # initial slider value
max_steps = 50  # slider upper bound

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# AI Image Generator (CPU Version)")
    gr.Markdown(
        "**Enter a short description or use voice input.** The app uses OpenAI (if API key is provided) "
        "to create a detailed prompt, then generates an image using Stable Diffusion v1.5 **on the CPU**."
    )
    # Prominent reminder that generation runs on the CPU and is slow.
    gr.HTML("<p style='color:orange;font-weight:bold;'>⚠️ Warning: Image generation on CPU is very slow! Expect several minutes per image.</p>")

    # Tell the user when prompt enhancement will use the plain fallback.
    if not openai_available:
        gr.Markdown("**Note:** OpenAI API key not found or invalid. Prompt enhancement will use a basic fallback.")

    with gr.Row():
        with gr.Column(scale=1):
            # --- Inputs ---
            inp_text = gr.Textbox(label="Enter short description", placeholder="e.g., A cute robot drinking coffee on Mars")

            # Offer a microphone input only if the Whisper model loaded.
            if asr_pipeline:
                inp_audio = gr.Audio(sources=["microphone"], type="filepath", label="Or record your idea (clears text box if used)")
            else:
                gr.Markdown("**Voice input disabled:** Whisper model failed to load.")
                # No microphone: a State holding None fills process_input's
                # `audio_file` argument, so the inputs list has the same shape
                # either way and no unused hidden component is rendered.
                inp_audio = gr.State(None)

            # --- Controls (Step 3 requirements) ---
            # Control 1: Dropdown
            inp_style = gr.Dropdown(label="Base Style", choices=style_options, value="cinematic")
            # Control 2: Radio
            inp_quality = gr.Radio(label="Quality Boost", choices=quality_options, value="highly detailed")
            # Control 3: Textbox (Negative Prompt)
            inp_neg_prompt = gr.Textbox(label="Negative Prompt (optional)", placeholder="e.g., blurry, low quality, text, watermark, signature, deformed")
            # Control 4: Slider (Guidance Scale)
            inp_guidance = gr.Slider(minimum=1.0, maximum=15.0, step=0.5, value=7.0, label="Guidance Scale (CFG)")
            # Control 5: Slider (Inference Steps), with a reduced max/default for CPU
            inp_steps = gr.Slider(minimum=10, maximum=max_steps, step=1, value=default_steps, label=f"Inference Steps (lower = faster but less detail, max {max_steps})")

            # --- Action Button ---
            btn_generate = gr.Button("Generate Image", variant="primary")

        with gr.Column(scale=1):
            # --- Outputs ---
            # Shows the enhanced prompt on success, or a bracketed error/status message.
            out_prompt = gr.Textbox(label="Generated Prompt / Status", interactive=False, lines=5)
            out_image = gr.Image(label="Generated Image", type="pil")

    # --- Event Handling ---
    # The order must match process_input's positional parameters.
    inputs_list = [inp_text, inp_audio, inp_style, inp_quality, inp_neg_prompt, inp_guidance, inp_steps]

    btn_generate.click(
        fn=process_input,
        inputs=inputs_list,
        outputs=[out_prompt, out_image],
    )

    # Recording audio clears the text box; otherwise process_input would prefer the typed text.
    if asr_pipeline:
        def clear_text_on_audio(audio_data):
            """Return "" to clear the text box when audio is present; otherwise leave it unchanged."""
            if audio_data is not None:
                return ""
            return gr.update()

        inp_audio.change(fn=clear_text_on_audio, inputs=inp_audio, outputs=inp_text)


# ---- Application Launch ----
if __name__ == "__main__":
    # If Stable Diffusion is unavailable, log a critical warning but still launch,
    # so the failure is reported in the UI instead of stopping the process.
    sd_loaded = isinstance(image_generator_pipe, StableDiffusionPipeline)
    if not sd_loaded:
        print("CRITICAL FAILURE: Stable Diffusion pipeline did not load. The application UI will load, but image generation WILL NOT WORK.")

    # Serve locally; share=True would also create a public link.
    demo.launch(share=False)