tangchao5355 commited on
Commit
925ae35
·
verified ·
1 Parent(s): b7205d1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -41
app.py CHANGED
@@ -1,83 +1,126 @@
1
- # app.py
2
  import gradio as gr
3
  import torch
4
  from transformers import pipeline, AutoTokenizer, T5ForConditionalGeneration
5
  from diffusers import StableDiffusionPipeline
6
  import speech_recognition as sr
7
- from io import BytesIO
 
8
 
9
- # ========== Step 1: Prompt Enhancement ==========
10
- prompt_model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-small")
11
- prompt_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
 
 
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  def enhance_prompt(raw_input, style_choice):
14
  template = f"Generate a detailed Stable Diffusion prompt about: {raw_input} in {style_choice} style."
15
  inputs = prompt_tokenizer(template, return_tensors="pt")
16
  outputs = prompt_model.generate(inputs.input_ids, max_length=100)
17
  return prompt_tokenizer.decode(outputs[0], skip_special_tokens=True)
18
 
19
- # ========== Step 2: Image Generation ==========
20
- sd_pipe = StableDiffusionPipeline.from_pretrained(
21
- "runwayml/stable-diffusion-v1-5",
22
- torch_dtype=torch.float32,
23
- use_safetensors=True
24
- )
25
- sd_pipe.enable_attention_slicing() # 降低内存消耗
26
-
27
  def generate_image(enhanced_prompt, steps=20, guidance=7.5):
28
- return sd_pipe(
29
- enhanced_prompt,
30
- num_inference_steps=int(steps),
31
- guidance_scale=guidance,
32
- generator=torch.Generator().manual_seed(42)
33
- ).images[0]
 
 
 
 
 
 
 
34
 
35
- # ========== Step 3: Voice Input ==========
 
 
 
 
 
 
 
 
 
 
36
  recognizer = sr.Recognizer()
37
 
38
  def audio_to_text(audio_file):
39
- with sr.AudioFile(audio_file) as source:
40
- audio = recognizer.record(source)
41
- return recognizer.recognize_whisper(audio, model="tiny.en")
 
 
 
 
 
 
42
 
43
- # ========== Gradio Interface ==========
44
- with gr.Blocks(title="AI Art Studio") as app:
45
- gr.Markdown("## 🎨 AI Art Generator (CPU Optimized)")
46
 
47
  with gr.Row():
48
  with gr.Column(scale=2):
49
- # ===== 交互控件 =====
50
- input_type = gr.Radio(["Text", "Voice"], label="输入方式")
51
- voice_input = gr.Audio(source="upload", type="filepath", visible=False)
52
- text_input = gr.Textbox(label="输入描述", placeholder="描述你想生成的画面...")
 
 
 
 
 
 
53
 
 
54
  style_choice = gr.Dropdown(
55
- ["Digital Art", "Oil Painting", "Anime", "Photorealistic"],
56
- value="Digital Art",
57
  label="艺术风格"
58
  )
59
 
 
60
  generate_btn = gr.Button("生成作品", variant="primary")
61
 
 
62
  with gr.Accordion("高级设置", open=False):
63
  steps_slider = gr.Slider(10, 30, value=20, step=1, label="生成步数")
64
  guidance_slider = gr.Slider(5.0, 10.0, value=7.5, label="创意自由度")
65
-
66
  with gr.Column(scale=3):
67
- # ===== 输出展示 =====
68
  prompt_output = gr.Textbox(label="优化后的Prompt", interactive=False)
69
- image_output = gr.Image(label="生成结果", show_label=False)
70
 
71
- # ===== 交互逻辑 =====
72
  input_type.change(
73
- fn=lambda x: (gr.update(visible=x=="Voice"), gr.update(visible=x=="Text")),
74
  inputs=input_type,
75
- outputs=[voice_input, text_input]
76
  )
77
 
78
  generate_btn.click(
79
- fn=lambda x,t: audio_to_text(x) if t=="Voice" else t,
80
- inputs=[voice_input, input_type],
81
  outputs=text_input
82
  ).success(
83
  fn=enhance_prompt,
@@ -89,6 +132,5 @@ with gr.Blocks(title="AI Art Studio") as app:
89
  outputs=image_output
90
  )
91
 
92
- # ========== Step 4: Huggingface Deployment ==========
93
  if __name__ == "__main__":
94
  app.launch(server_name="0.0.0.0", server_port=7860)
 
 
1
  import gradio as gr
2
  import torch
3
  from transformers import pipeline, AutoTokenizer, T5ForConditionalGeneration
4
  from diffusers import StableDiffusionPipeline
5
  import speech_recognition as sr
6
+ import gc
7
+ from accelerate import init_empty_weights
8
 
9
+ # ===== 模型初始化 =====
10
+ def load_models():
11
+ # Prompt增强模型
12
+ prompt_model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-small")
13
+ prompt_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
14
 
15
+ # Stable Diffusion管道
16
+ sd_pipe = StableDiffusionPipeline.from_pretrained(
17
+ "runwayml/stable-diffusion-v1-5",
18
+ torch_dtype=torch.float32,
19
+ use_safetensors=True,
20
+ variant="fp16",
21
+ device_map="auto",
22
+ offload_state_dict=True
23
+ )
24
+ sd_pipe.enable_attention_slicing()
25
+ sd_pipe.enable_sequential_cpu_offload()
26
+
27
+ return prompt_model, prompt_tokenizer, sd_pipe
28
+
29
+ prompt_model, prompt_tokenizer, sd_pipe = load_models()
30
+
31
+ # ===== 核心功能 =====
32
  def enhance_prompt(raw_input, style_choice):
33
  template = f"Generate a detailed Stable Diffusion prompt about: {raw_input} in {style_choice} style."
34
  inputs = prompt_tokenizer(template, return_tensors="pt")
35
  outputs = prompt_model.generate(inputs.input_ids, max_length=100)
36
  return prompt_tokenizer.decode(outputs[0], skip_special_tokens=True)
37
 
 
 
 
 
 
 
 
 
38
  def generate_image(enhanced_prompt, steps=20, guidance=7.5):
39
+ try:
40
+ image = sd_pipe(
41
+ enhanced_prompt,
42
+ num_inference_steps=int(steps),
43
+ guidance_scale=guidance,
44
+ generator=torch.Generator().manual_seed(42)
45
+ ).images[0]
46
+ finally:
47
+ # 清理内存
48
+ gc.collect()
49
+ with init_empty_weights():
50
+ reload_models()
51
+ return image
52
 
53
+ def reload_models():
54
+ global sd_pipe
55
+ del sd_pipe
56
+ sd_pipe = StableDiffusionPipeline.from_pretrained(
57
+ "runwayml/stable-diffusion-v1-5",
58
+ torch_dtype=torch.float32,
59
+ device_map="auto",
60
+ offload_folder="offload"
61
+ )
62
+
63
+ # ===== 语音处理 =====
64
  recognizer = sr.Recognizer()
65
 
66
  def audio_to_text(audio_file):
67
+ if not audio_file:
68
+ return ""
69
+ try:
70
+ with sr.AudioFile(audio_file) as source:
71
+ audio = recognizer.record(source)
72
+ return recognizer.recognize_whisper(audio, model="tiny.en")
73
+ except Exception as e:
74
+ print(f"语音识别错误: {e}")
75
+ return ""
76
 
77
+ # ===== Gradio界面 =====
78
+ with gr.Blocks(title="AI Art Studio", css=".gradio-container {max-width: 800px !important}") as app:
79
+ gr.Markdown("## 🎨 AI 艺术生成器 (CPU优化版)")
80
 
81
  with gr.Row():
82
  with gr.Column(scale=2):
83
+ # 输入控件
84
+ input_type = gr.Radio(["文字", "语音"], label="输入方式", value="文字")
85
+ voice_input = gr.Audio(
86
+ sources=["upload"],
87
+ type="filepath",
88
+ visible=False,
89
+ label="上传语音文件",
90
+ elem_classes="voice-input"
91
+ )
92
+ text_input = gr.Textbox(label="输入描述", placeholder="例:空中的魔法树屋...", lines=3)
93
 
94
+ # 风格选择
95
  style_choice = gr.Dropdown(
96
+ ["数字艺术", "油画", "动漫", "照片写实"],
97
+ value="数字艺术",
98
  label="艺术风格"
99
  )
100
 
101
+ # 生成按钮
102
  generate_btn = gr.Button("生成作品", variant="primary")
103
 
104
+ # 高级设置
105
  with gr.Accordion("高级设置", open=False):
106
  steps_slider = gr.Slider(10, 30, value=20, step=1, label="生成步数")
107
  guidance_slider = gr.Slider(5.0, 10.0, value=7.5, label="创意自由度")
108
+
109
  with gr.Column(scale=3):
110
+ # 输出展示
111
  prompt_output = gr.Textbox(label="优化后的Prompt", interactive=False)
112
+ image_output = gr.Image(label="生成结果", show_label=False, elem_id="output-image")
113
 
114
+ # 交互逻辑
115
  input_type.change(
116
+ fn=lambda x: gr.update(visible=x == "语音"),
117
  inputs=input_type,
118
+ outputs=voice_input
119
  )
120
 
121
  generate_btn.click(
122
+ fn=audio_to_text,
123
+ inputs=voice_input,
124
  outputs=text_input
125
  ).success(
126
  fn=enhance_prompt,
 
132
  outputs=image_output
133
  )
134
 
 
135
  if __name__ == "__main__":
136
  app.launch(server_name="0.0.0.0", server_port=7860)