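"""Gradio demo: enhance a short description (typed, or spoken via Whisper) into a
detailed Stable Diffusion prompt, then generate an image from it."""
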
import gradio as gr
import torch
from transformers import pipeline, set_seed
from diffusers import StableDiffusionPipeline
import os
import time

# ---- Configuration and model loading (done once at app startup) ----

# Use the GPU if available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# 1. Speech-to-text model (Whisper) - bonus feature
asr_pipeline = None
try:
    print("Loading ASR pipeline (Whisper)...")
    # Use a smaller model to save resources; swap in openai/whisper-medium or large if needed.
    # transformers pipelines take a device index: 0 for the first GPU, -1 for the CPU.
    asr_pipeline = pipeline(
        "automatic-speech-recognition",
        model="openai/whisper-base",
        device=0 if device == "cuda" else -1,  # Whisper also runs on CPU
    )
    print("ASR pipeline loaded.")
except Exception as e:
    print(f"Could not load ASR pipeline: {e}. Voice input will be disabled.")

# 2. Prompt enhancement model (LLM) - Step 1
prompt_enhancer_pipeline = None
try:
    print("Loading Prompt Enhancer pipeline (GPT-2)...")
    # GPT-2 is used here as an example; a real application should use a stronger
    # instruction-tuned model such as Mistral or Llama.
    # Note: GPT-2 may not generate especially high-quality SD prompts; this is a structural demo.
    # If resources allow, swap in 'mistralai/Mistral-7B-Instruct-v0.1' or similar,
    # which needs more memory/GPU.
    prompt_enhancer_pipeline = pipeline(
        "text-generation",
        model="gpt2",
        device=0 if device == "cuda" else -1,  # text-generation also runs on CPU
    )
    print("Prompt Enhancer pipeline loaded.")
except Exception as e:
    print(f"Could not load Prompt Enhancer pipeline: {e}. Prompt enhancement might fail.")

# 3. Text-to-image model (Stable Diffusion) - Step 2
image_generator_pipe = None
try:
    print("Loading Stable Diffusion pipeline (v1.5)...")
    model_id = "runwayml/stable-diffusion-v1-5"
    image_generator_pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16 if device == "cuda" else torch.float32)
    image_generator_pipe = image_generator_pipe.to(device)
    # If memory is tight, CPU offloading can be enabled (requires the accelerate library)
    # image_generator_pipe.enable_model_cpu_offload()
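    # Another common memory saver on smaller GPUs (standard diffusers API):
    # image_generator_pipe.enable_attention_slicing()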
    print("Stable Diffusion pipeline loaded.")
except Exception as e:
    print(f"Could not load Stable Diffusion pipeline: {e}. Image generation will fail.")
    # If loading fails, keep a dummy object so later code doesn't crash.
    # Capture the message now: Python 3 clears the `except` variable when the
    # block exits, so referencing `e` inside the dummy later would raise NameError.
    load_error_msg = str(e)
    class DummyPipe:
        def __call__(self, *args, **kwargs):
            # Surface the original load failure at call time
            raise RuntimeError(f"Stable Diffusion model failed to load: {load_error_msg}")
    image_generator_pipe = DummyPipe()


# ---- Core functions ----

# Step 1: Prompt-to-Prompt
def enhance_prompt(short_prompt, style_modifier="cinematic", quality_boost="photorealistic, highly detailed"):
    """使用LLM增强简短描述"""
    if not prompt_enhancer_pipeline:
        return f"[Error: LLM not loaded] Original prompt: {short_prompt}"
    if not short_prompt:
        return "[Error: Input description is empty]"

    # Build the instruction for the LLM.
    # Note: this instruction may be too complex for GPT-2; it works better with models like Mistral.
    input_text = (
        f"Generate a detailed and vivid prompt for an AI image generator based on the following description. "
        f"Incorporate the style '{style_modifier}' and quality boost '{quality_boost}'. "
        f"Focus on visual details, lighting, composition, and mood. "
        f"Description: \"{short_prompt}\"\n\n"
        f"Detailed Prompt:"
    )

    try:
        # Seed for (somewhat) reproducible results
        set_seed(int(time.time()))
        # max_length caps the total generated length (input included)
        # num_return_sequences: how many candidates to return
        # temperature controls randomness; lower values are more conservative
        # no_repeat_ngram_size avoids repeated phrases
        outputs = prompt_enhancer_pipeline(
            input_text,
            max_length=150,  # cap the total output length
            num_return_sequences=1,
            temperature=0.7,
            no_repeat_ngram_size=2,
            pad_token_id=prompt_enhancer_pipeline.tokenizer.eos_token_id  # avoid the padding warning
        )
        generated_text = outputs[0]['generated_text']

        # Extract the enhanced prompt from the LLM's full output.
        # Simple approach: take everything after "Detailed Prompt:"
        enhanced = generated_text.split("Detailed Prompt:")[-1].strip()
        # Clean up traces of the original input or the instruction
        if short_prompt in enhanced[:len(short_prompt)+5]:  # output starts with the original input
            enhanced = enhanced.replace(short_prompt, "", 1).strip(' ,"')

        # Append the base style and quality terms if the LLM left them out
        if style_modifier not in enhanced:
            enhanced += f", {style_modifier}"
        if quality_boost not in enhanced:
            enhanced += f", {quality_boost}"

        return enhanced

    except Exception as e:
        print(f"Error during prompt enhancement: {e}")
        return f"[Error: Prompt enhancement failed] Original prompt: {short_prompt}"

# Step 2: Prompt-to-Image
def generate_image(prompt, negative_prompt, guidance_scale, num_inference_steps):
    """使用Stable Diffusion生成图像"""
    if not isinstance(image_generator_pipe, StableDiffusionPipeline):
         raise gr.Error(f"Stable Diffusion model is not available. Load error: {image_generator_pipe}") # 使用gr.Error在UI上显示错误
    if not prompt or "[Error:" in prompt:
        raise gr.Error("Cannot generate image due to invalid or missing prompt.")

    print(f"Generating image for prompt: {prompt}")
    print(f"Negative prompt: {negative_prompt}")
    print(f"Guidance scale: {guidance_scale}, Steps: {num_inference_steps}")

    try:
        # Fresh random seed for each call
        generator = torch.Generator(device=device).manual_seed(int(time.time()))

        # Run inference
        with torch.inference_mode():  # saves memory
            image = image_generator_pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                guidance_scale=float(guidance_scale),
                num_inference_steps=int(num_inference_steps),
                generator=generator
            ).images[0]
        print("Image generated successfully.")
        return image
    except Exception as e:
        print(f"Error during image generation: {e}")
        # Propagate the underlying error to Gradio so it appears in the UI
        raise gr.Error(f"Image generation failed: {e}")


# Bonus: Voice-to-Text
def transcribe_audio(audio_file_path):
    """将音频文件转录为文本"""
    if not asr_pipeline:
        return "[Error: ASR model not loaded]", "" # 返回错误信息和空路径
    if audio_file_path is None:
        return "", "" # 没有音频输入

    print(f"Transcribing audio file: {audio_file_path}")
    try:
        # Run the transcription
        transcription = asr_pipeline(audio_file_path)["text"]
        print(f"Transcription result: {transcription}")
        return transcription, audio_file_path  # text plus the path (possibly for display)
    except Exception as e:
        print(f"Error during audio transcription: {e}")
        return f"[Error: Transcription failed: {e}]", audio_file_path


# ---- Gradio application flow ----

def process_input(input_text, audio_file, style_choice, quality_choice, neg_prompt, guidance, steps):
    """处理输入(文本或语音),生成提示词和图像"""
    final_text_input = ""
    transcription_source = "" # 用于标记来源

    # Prefer the text box input
    if input_text and input_text.strip():
        final_text_input = input_text.strip()
        transcription_source = " (from text input)"
    # If the text box is empty and an audio file exists, fall back to voice input
    elif audio_file is not None:
        transcribed_text, _ = transcribe_audio(audio_file)
        if transcribed_text and "[Error:" not in transcribed_text:
            final_text_input = transcribed_text
            transcription_source = " (from audio input)"
        elif "[Error:" in transcribed_text:
             # 如果语音识别出错,直接返回错误信息
             return transcribed_text, None # 返回错误提示,不生成图像
        else:
             # 音频为空或识别为空
             return "[Error: Please provide input via text or voice]", None
    else:
        # No valid input at all
        return "[Error: Please provide input via text or voice]", None

    print(f"Using input: '{final_text_input}'{transcription_source}")

    # Step 1: Enhance prompt
    enhanced_prompt = enhance_prompt(final_text_input, style_modifier=style_choice, quality_boost=quality_choice)
    print(f"Enhanced prompt: {enhanced_prompt}")

    # Step 2: Generate the image (only if prompt enhancement succeeded)
    generated_image = None
    if "[Error:" not in enhanced_prompt:
        try:
            generated_image = generate_image(enhanced_prompt, neg_prompt, guidance, steps)
        except gr.Error as e:
            # generate_image raised gr.Error: append its message to the prompt shown in the UI
            enhanced_prompt = f"{enhanced_prompt}\n\n[Image Generation Error: {e}]"
            # No image will be shown
        except Exception as e:
            # Catch any other unexpected error
            enhanced_prompt = f"{enhanced_prompt}\n\n[Unexpected Image Generation Error: {e}]"

    # Return the results to the Gradio UI
    return enhanced_prompt, generated_image


# ---- Gradio UI (Step 3: Controls & Step 4: Layout) ----

# Style and quality-boost options (for the Dropdown/Radio controls)
style_options = ["cinematic", "photorealistic", "anime", "fantasy art", "cyberpunk", "steampunk", "watercolor"]
quality_options = ["highly detailed", "sharp focus", "intricate details", "4k", "masterpiece", "best quality"]

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# AI Image Generator: From Idea to Image")
    gr.Markdown("Enter a short description (or use voice input), and the app will enhance it into a detailed prompt and generate an image using Stable Diffusion.")

    with gr.Row():
        with gr.Column(scale=1):
            # Input area
            inp_text = gr.Textbox(label="Enter short description here", placeholder="e.g., A magical treehouse in the sky")

            # Bonus: voice input control (shown only if the ASR pipeline loaded)
            inp_audio = gr.Audio(
                sources=["microphone"],
                type="filepath",
                label="Or record your idea (clears text box if used)",
                visible=asr_pipeline is not None,
            )

            # Step 3: A variety of controls
            # Control 1: Dropdown
            inp_style = gr.Dropdown(label="Choose Base Style", choices=style_options, value="cinematic")

            # Control 2: Radio (a CheckboxGroup could allow multi-select instead)
            inp_quality = gr.Radio(label="Quality Boost", choices=quality_options, value="highly detailed")

            # Control 3: Textbox (negative prompt)
            inp_neg_prompt = gr.Textbox(label="Negative Prompt (optional)", placeholder="e.g., blurry, low quality, text, watermark")

            # Control 4: Slider
            inp_guidance = gr.Slider(minimum=1.0, maximum=20.0, step=0.5, value=7.5, label="Guidance Scale (CFG)")

            # Control 5: Slider
            inp_steps = gr.Slider(minimum=10, maximum=100, step=1, value=30, label="Inference Steps")
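
            # Note: the defaults follow common SD v1.5 practice; CFG around 7-8 and
            # 25-50 steps usually balance prompt adherence against artifacts.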

            # Submit button
            btn_generate = gr.Button("Generate Image", variant="primary")

        with gr.Column(scale=1):
            # Output area
            out_prompt = gr.Textbox(label="Generated Prompt", interactive=False)  # read-only
            out_image = gr.Image(label="Generated Image", type="pil")

    # Wire the button click to the processing function
    btn_generate.click(
        fn=process_input,
        inputs=[inp_text, inp_audio, inp_style, inp_quality, inp_neg_prompt, inp_guidance, inp_steps],
        outputs=[out_prompt, out_image]
    )

    # (Optional) Clear the text box when audio is recorded, so voice input clearly takes priority
    if asr_pipeline:
        def clear_text_on_audio(audio_data):
            if audio_data is not None:
                return ""  # empty string clears the text box
            return gr.update()  # otherwise leave the text box unchanged
        inp_audio.change(fn=clear_text_on_audio, inputs=inp_audio, outputs=inp_text)


# ---- Launch the app ----
if __name__ == "__main__":
    # Set a Hugging Face Hub token if models must be pulled from private repos
    # from huggingface_hub import login
    # login("YOUR_HF_TOKEN")  # uncomment and fill in when running locally

    # On Hugging Face Spaces the port is managed by the platform.
    # share=True creates a public link (useful when running locally).
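    # Since generation can take tens of seconds, calling demo.queue() before
    # demo.launch() is a common Gradio pattern so concurrent requests wait in
    # line instead of timing out.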
    demo.launch(share=False)