lisonallen commited on
Commit
7d0da85
·
1 Parent(s): 0cfce88

Update app.py to address model loading issues and improve backup image generation

Browse files
Files changed (2) hide show
  1. app.py +152 -75
  2. requirements.txt +3 -3
app.py CHANGED
@@ -55,7 +55,7 @@ except Exception as e:
55
 
56
  # 创建一个备用图像
57
  def create_backup_image(prompt=""):
58
- logger.info(f"创建备用图像: {prompt}")
59
  img = PILImage.new('RGB', (512, 512), color=(240, 240, 250))
60
 
61
  try:
@@ -63,80 +63,72 @@ def create_backup_image(prompt=""):
63
  draw = ImageDraw.Draw(img)
64
  font = ImageFont.load_default()
65
 
66
- draw.text((20, 20), f"提示词: {prompt}", fill=(0, 0, 0), font=font)
67
- draw.text((20, 60), "模型加载失败,无法生成图像", fill=(255, 0, 0), font=font)
 
68
 
69
  except Exception as e:
70
- logger.error(f"创建备用图像时出错: {e}")
71
 
72
  return img
73
 
74
- # 预加载 AI 模型
75
- model = None
76
 
77
- def load_model():
78
- global model
79
- if model is not None:
80
- return model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
 
82
  try:
83
- logger.info("开始加载AI模型...")
84
 
85
- # 延迟导入,确保所有依赖都已正确安装
86
- import torch
87
- from diffusers import StableDiffusionPipeline
88
 
89
- # 使用较低版本的模型
90
- model_id = "CompVis/stable-diffusion-v1-4"
91
-
92
- # 设置加载参数
93
- load_options = {
94
- "revision": "fp16" if torch.cuda.is_available() else None,
95
- "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32,
96
- "safety_checker": None
97
- }
98
-
99
- logger.info(f"使用模型: {model_id}")
100
- pipe = StableDiffusionPipeline.from_pretrained(model_id, **load_options)
101
-
102
- # 转移到适当的设备
103
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
 
 
 
 
 
 
104
  pipe = pipe.to(device)
105
 
106
- # 优化
107
  if torch.cuda.is_available():
108
  pipe.enable_attention_slicing()
 
109
 
110
- logger.info("AI模型加载成功")
111
- model = pipe
112
- return model
113
- except Exception as e:
114
- logger.error(f"AI模型加载失败: {e}")
115
- return None
116
-
117
- # AI 图像生成函数
118
- def generate_ai_image(prompt, seed=None):
119
- # 尝试加载模型
120
- pipe = load_model()
121
- if pipe is None:
122
- logger.error("AI模型不可用")
123
- return None
124
-
125
- try:
126
- logger.info(f"使用AI生成图像: {prompt}")
127
-
128
- # 设置生成参数
129
- if seed is None:
130
- seed = random.randint(0, 2147483647)
131
-
132
- # 确定正确的设备
133
- generator = torch.Generator("cuda" if torch.cuda.is_available() else "cpu").manual_seed(seed)
134
 
135
  # 生成图像
 
136
  image = pipe(
137
  prompt=prompt,
138
  guidance_scale=7.5,
139
- num_inference_steps=5, # 降低步数以加快速度
140
  generator=generator,
141
  height=512,
142
  width=512
@@ -146,47 +138,132 @@ def generate_ai_image(prompt, seed=None):
146
  if torch.cuda.is_available():
147
  torch.cuda.empty_cache()
148
 
149
- logger.info(f"AI图像生成成功,种子: {seed}")
150
  return image
 
 
 
 
 
 
 
 
 
151
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
  except Exception as e:
153
- logger.error(f"AI图像生成失败: {e}")
154
- return None
 
155
 
156
  # 入口点函数 - 处理请求并生成图像
157
  def generate_image(prompt):
158
  # 处理空提示
159
  if not prompt or prompt.strip() == "":
160
  prompt = "a beautiful landscape"
161
- logger.info(f"输入为空,使用默认提示词: {prompt}")
162
 
163
- logger.info(f"收到提示词: {prompt}")
164
 
165
  # 尝试使用AI生成
166
- image = generate_ai_image(prompt)
 
 
 
 
 
 
167
 
168
- # 检查结果
169
- if image is not None:
170
- return image
171
- else:
172
- logger.warning("使用备用生成器")
173
- return create_backup_image(prompt)
174
 
175
  # 创建Gradio界面
176
  def create_demo():
177
- with gr.Blocks(title="AI 文本到图像生成器") as demo:
178
- gr.Markdown("# AI 文本到图像生成器")
179
- gr.Markdown("输���文本描述,AI将为你生成相应的图像。")
180
 
181
  with gr.Row():
182
  with gr.Column(scale=3):
183
  # 输入区域
184
  prompt_input = gr.Textbox(
185
- label="输入提示词",
186
- placeholder="描述你想要的图像,例如:一只可爱的猫,日落下的山脉...",
187
  lines=2
188
  )
189
- generate_button = gr.Button("生成图像", variant="primary")
190
 
191
  # 示例
192
  gr.Examples(
@@ -201,7 +278,7 @@ def create_demo():
201
 
202
  # 输出区域
203
  with gr.Column(scale=5):
204
- output_image = gr.Image(label="生成的图像", type="pil")
205
 
206
  # 绑定按钮事件
207
  generate_button.click(
@@ -225,11 +302,11 @@ demo = create_demo()
225
  # 启动应用
226
  if __name__ == "__main__":
227
  try:
228
- logger.info("启动Gradio界面...")
229
  demo.launch(
230
  server_name="0.0.0.0",
231
  show_api=False,
232
  share=False
233
  )
234
  except Exception as e:
235
- logger.error(f"启动失败: {e}")
 
55
 
56
  # 创建一个备用图像
57
  def create_backup_image(prompt=""):
58
+ logger.info(f"Creating backup image for: {prompt}")
59
  img = PILImage.new('RGB', (512, 512), color=(240, 240, 250))
60
 
61
  try:
 
63
  draw = ImageDraw.Draw(img)
64
  font = ImageFont.load_default()
65
 
66
+ # 使用英文消息避免编码问题
67
+ draw.text((20, 20), f"Prompt: {prompt}", fill=(0, 0, 0), font=font)
68
+ draw.text((20, 60), "Model loading failed. Showing placeholder image.", fill=(255, 0, 0), font=font)
69
 
70
  except Exception as e:
71
+ logger.error(f"Error creating backup image: {e}")
72
 
73
  return img
74
 
75
+ # 预加载图像用于快速响应
76
+ PLACEHOLDER_IMAGE = create_backup_image("placeholder")
77
 
78
+ # 尝试导入必要的AI库
79
+ try:
80
+ import torch
81
+ from diffusers import StableDiffusionPipeline
82
+
83
+ HAS_AI_LIBS = True
84
+ logger.info("Successfully imported AI libraries")
85
+ except ImportError as e:
86
+ logger.error(f"Failed to import AI libraries: {e}")
87
+ HAS_AI_LIBS = False
88
+
89
+ # AI 模型加载和图像生成
90
+ def generate_ai_image(prompt, seed=None):
91
+ if not HAS_AI_LIBS:
92
+ logger.error("AI libraries not available")
93
+ return PLACEHOLDER_IMAGE
94
+
95
+ # 设置随机种子
96
+ if seed is None:
97
+ seed = random.randint(0, 2147483647)
98
 
99
  try:
100
+ logger.info(f"Generating image for: {prompt}")
101
 
102
+ # 使用兼容的旧版本API加载模型
103
+ model_id = "runwayml/stable-diffusion-v1-5"
104
+ logger.info(f"Loading model: {model_id}")
105
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  device = "cuda" if torch.cuda.is_available() else "cpu"
107
+ torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
108
+
109
+ # 加载模型 - 使用兼容的低级API
110
+ pipe = StableDiffusionPipeline.from_pretrained(
111
+ model_id,
112
+ torch_dtype=torch_dtype,
113
+ use_auth_token=False, # 明确不使用认证
114
+ revision="main", # 使用主分支
115
+ safety_checker=None, # 禁用安全检查器
116
+ )
117
  pipe = pipe.to(device)
118
 
119
+ # 优化内存
120
  if torch.cuda.is_available():
121
  pipe.enable_attention_slicing()
122
+ torch.cuda.empty_cache()
123
 
124
+ logger.info("Model loaded, generating image...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
 
126
  # 生成图像
127
+ generator = torch.Generator(device).manual_seed(seed)
128
  image = pipe(
129
  prompt=prompt,
130
  guidance_scale=7.5,
131
+ num_inference_steps=4, # 最小步数
132
  generator=generator,
133
  height=512,
134
  width=512
 
138
  if torch.cuda.is_available():
139
  torch.cuda.empty_cache()
140
 
141
+ logger.info(f"Image generation successful with seed: {seed}")
142
  return image
143
+
144
+ except Exception as e:
145
+ logger.error(f"AI image generation failed: {e}")
146
+ return create_backup_image(prompt)
147
+
148
+ # 使用简单的规则生成图像作为备用方案
149
+ def generate_rule_based_image(prompt):
150
+ """当AI模型不可用时使用规则生成图像"""
151
+ logger.info(f"Using rule-based generator for: {prompt}")
152
 
153
+ # 创建基础图像
154
+ img = PILImage.new('RGB', (512, 512), color=(240, 240, 250))
155
+
156
+ try:
157
+ from PIL import ImageDraw, ImageFont
158
+ draw = ImageDraw.Draw(img)
159
+
160
+ # 提取关键词
161
+ prompt_lower = prompt.lower()
162
+
163
+ # 设置默认颜色和形状
164
+ bg_color = (240, 240, 250) # 浅蓝背景
165
+ shape_color = (64, 64, 128) # 深蓝形状
166
+
167
+ # 基于关键词调整颜色
168
+ if "red" in prompt_lower:
169
+ shape_color = (200, 50, 50)
170
+ elif "blue" in prompt_lower:
171
+ shape_color = (50, 50, 200)
172
+ elif "green" in prompt_lower:
173
+ shape_color = (50, 200, 50)
174
+ elif "yellow" in prompt_lower:
175
+ shape_color = (200, 200, 50)
176
+
177
+ # 画一个基本形状
178
+ if "cat" in prompt_lower or "kitten" in prompt_lower:
179
+ # 猫头
180
+ draw.ellipse((156, 156, 356, 356), fill=shape_color)
181
+ # 猫眼睛
182
+ draw.ellipse((206, 206, 236, 236), fill=(255, 255, 255))
183
+ draw.ellipse((276, 206, 306, 236), fill=(255, 255, 255))
184
+ # 猫瞳孔
185
+ draw.ellipse((216, 216, 226, 226), fill=(0, 0, 0))
186
+ draw.ellipse((286, 216, 296, 226), fill=(0, 0, 0))
187
+ # 猫鼻子
188
+ draw.polygon([(256, 256), (246, 276), (266, 276)], fill=(255, 150, 150))
189
+ # 猫耳朵
190
+ draw.polygon([(156, 156), (176, 96), (216, 156)], fill=shape_color)
191
+ draw.polygon([(356, 156), (336, 96), (296, 156)], fill=shape_color)
192
+ elif "landscape" in prompt_lower or "mountain" in prompt_lower:
193
+ # 天空
194
+ draw.rectangle([(0, 0), (512, 300)], fill=(100, 150, 250))
195
+ # 山脉
196
+ draw.polygon([(0, 300), (150, 100), (300, 300)], fill=(100, 100, 100))
197
+ draw.polygon([(200, 300), (400, 150), (512, 300)], fill=(80, 80, 80))
198
+ # 地面
199
+ draw.rectangle([(0, 300), (512, 512)], fill=(100, 200, 100))
200
+ elif "castle" in prompt_lower or "building" in prompt_lower:
201
+ # 天空
202
+ draw.rectangle([(0, 0), (512, 200)], fill=(150, 200, 250))
203
+ # 主塔
204
+ draw.rectangle([(200, 200), (312, 400)], fill=shape_color)
205
+ # 塔顶
206
+ draw.polygon([(180, 200), (256, 100), (332, 200)], fill=(180, 0, 0))
207
+ # 小塔
208
+ draw.rectangle([(150, 300), (200, 400)], fill=shape_color)
209
+ draw.rectangle([(312, 300), (362, 400)], fill=shape_color)
210
+ # 城墙
211
+ draw.rectangle([(100, 400), (412, 450)], fill=shape_color)
212
+ # 地面
213
+ draw.rectangle([(0, 450), (512, 512)], fill=(100, 150, 100))
214
+ else:
215
+ # 默认绘制几何形状
216
+ draw.rectangle([(100, 100), (412, 412)], outline=(0, 0, 0), width=2)
217
+ draw.ellipse((150, 150, 362, 362), fill=shape_color)
218
+ draw.polygon([(256, 100), (412, 412), (100, 412)], fill=(shape_color[0]//2, shape_color[1]//2, shape_color[2]//2))
219
+
220
+ # 添加提示词和说明
221
+ font = ImageFont.load_default()
222
+ draw.text((10, 10), f"Prompt: {prompt}", fill=(0, 0, 0), font=font)
223
+ draw.text((10, 30), "Generated with rules (AI model unavailable)", fill=(100, 100, 100), font=font)
224
+
225
  except Exception as e:
226
+ logger.error(f"Error in rule-based image generation: {e}")
227
+
228
+ return img
229
 
230
  # 入口点函数 - 处理请求并生成图像
231
  def generate_image(prompt):
232
  # 处理空提示
233
  if not prompt or prompt.strip() == "":
234
  prompt = "a beautiful landscape"
235
+ logger.info(f"Empty prompt, using default: {prompt}")
236
 
237
+ logger.info(f"Received prompt: {prompt}")
238
 
239
  # 尝试使用AI生成
240
+ if HAS_AI_LIBS:
241
+ try:
242
+ image = generate_ai_image(prompt)
243
+ if image is not None:
244
+ return image
245
+ except Exception as e:
246
+ logger.error(f"Error using AI generation: {e}")
247
 
248
+ # 如果AI不可用或失败,使用规则生成
249
+ logger.warning("Using rule-based image generation")
250
+ return generate_rule_based_image(prompt)
 
 
 
251
 
252
  # 创建Gradio界面
253
  def create_demo():
254
+ with gr.Blocks(title="Text to Image Generator") as demo:
255
+ gr.Markdown("# Text to Image Generator")
256
+ gr.Markdown("Enter a text description to generate an image.")
257
 
258
  with gr.Row():
259
  with gr.Column(scale=3):
260
  # 输入区域
261
  prompt_input = gr.Textbox(
262
+ label="Prompt",
263
+ placeholder="Describe the image you want, e.g.: a cute cat, sunset over mountains...",
264
  lines=2
265
  )
266
+ generate_button = gr.Button("Generate Image", variant="primary")
267
 
268
  # 示例
269
  gr.Examples(
 
278
 
279
  # 输出区域
280
  with gr.Column(scale=5):
281
+ output_image = gr.Image(label="Generated Image", type="pil")
282
 
283
  # 绑定按钮事件
284
  generate_button.click(
 
302
  # 启动应用
303
  if __name__ == "__main__":
304
  try:
305
+ logger.info("Starting Gradio interface...")
306
  demo.launch(
307
  server_name="0.0.0.0",
308
  show_api=False,
309
  share=False
310
  )
311
  except Exception as e:
312
+ logger.error(f"Failed to launch: {e}")
requirements.txt CHANGED
@@ -1,8 +1,8 @@
1
  accelerate==0.15.0
2
- diffusers==0.14.0
3
- huggingface-hub==0.13.3
4
  torch==1.13.1
5
- transformers==4.26.0
6
  safetensors==0.3.1
7
  gradio==3.24.1
8
  Pillow==9.5.0
 
1
  accelerate==0.15.0
2
+ diffusers==0.10.2
3
+ huggingface-hub==0.11.1
4
  torch==1.13.1
5
+ transformers==4.25.1
6
  safetensors==0.3.1
7
  gradio==3.24.1
8
  Pillow==9.5.0