lisonallen committed
Commit e2bc0a8 · 1 Parent(s): 407a5fa

Refactor app.py for improved error handling and a simplified Gradio interface; downgrade the gradio version in requirements.txt

Files changed (2)
  1. app.py +70 -113
  2. requirements.txt +1 -1
app.py CHANGED
@@ -30,27 +30,28 @@ except Exception as e:
 from diffusers import DiffusionPipeline
 import torch
 
-device = "cuda" if torch.cuda.is_available() else "cpu"
-model_repo_id = "stabilityai/sdxl-turbo"  # Replace with the model you would like to use
-
-logger.info(f"Using device: {device}")
-logger.info(f"Loading model: {model_repo_id}")
-
-if torch.cuda.is_available():
-    torch_dtype = torch.float16
-else:
-    torch_dtype = torch.float32
-
+# Wrap setup in try/except so an error does not break module import
 try:
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    model_repo_id = "stabilityai/sdxl-turbo"
+
+    logger.info(f"Using device: {device}")
+    logger.info(f"Loading model: {model_repo_id}")
+
+    if torch.cuda.is_available():
+        torch_dtype = torch.float16
+    else:
+        torch_dtype = torch.float32
+
     pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
     pipe = pipe.to(device)
     logger.info("Model loaded successfully")
+
+    MAX_SEED = np.iinfo(np.int32).max
+    MAX_IMAGE_SIZE = 1024
 except Exception as e:
-    logger.error(f"Error loading model: {str(e)}")
-    raise
-
-MAX_SEED = np.iinfo(np.int32).max
-MAX_IMAGE_SIZE = 1024
+    logger.error(f"Error during setup: {str(e)}")
+    # Do not re-raise here, so the Gradio interface can still load
 
 # @spaces.GPU #[uncomment to use ZeroGPU]
 def infer(
@@ -88,113 +89,69 @@ def infer(
         return image, seed
     except Exception as e:
         logger.error(f"Error in inference: {str(e)}")
-        raise gr.Error(f"Error generating image: {str(e)}")
+        return None, seed  # Return None instead of raising
 
+# Example prompts
 examples = [
     "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
     "An astronaut riding a green horse",
     "A delicious ceviche cheesecake slice",
 ]
 
-css = """
-#col-container {
-    margin: 0 auto;
-    max-width: 640px;
-}
-"""
-
-try:
-    with gr.Blocks(css=css) as demo:
-        with gr.Column(elem_id="col-container"):
-            gr.Markdown(" # Text-to-Image Gradio Template")
-
-            with gr.Row():
-                prompt = gr.Text(
-                    label="Prompt",
-                    show_label=False,
-                    max_lines=1,
-                    placeholder="Enter your prompt",
-                    container=False,
-                )
-
-                run_button = gr.Button("Run", scale=0, variant="primary")
-
-            result = gr.Image(label="Result", show_label=False)
-
-            with gr.Accordion("Advanced Settings", open=False):
-                negative_prompt = gr.Text(
-                    label="Negative prompt",
-                    max_lines=1,
-                    placeholder="Enter a negative prompt",
-                    visible=True,  # changed to visible
-                )
-
-                seed = gr.Slider(
-                    label="Seed",
-                    minimum=0,
-                    maximum=MAX_SEED,
-                    step=1,
-                    value=0,
-                )
-
-                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-
-                with gr.Row():
-                    width = gr.Slider(
-                        label="Width",
-                        minimum=256,
-                        maximum=MAX_IMAGE_SIZE,
-                        step=32,
-                        value=1024,
-                    )
-
-                    height = gr.Slider(
-                        label="Height",
-                        minimum=256,
-                        maximum=MAX_IMAGE_SIZE,
-                        step=32,
-                        value=1024,
-                    )
-
-                with gr.Row():
-                    guidance_scale = gr.Slider(
-                        label="Guidance scale",
-                        minimum=0.0,
-                        maximum=10.0,
-                        step=0.1,
-                        value=0.0,
-                    )
-
-                    num_inference_steps = gr.Slider(
-                        label="Number of inference steps",
-                        minimum=1,
-                        maximum=50,
-                        step=1,
-                        value=2,
-                    )
-
-            gr.Examples(examples=examples, inputs=[prompt])
-        gr.on(
-            triggers=[run_button.click, prompt.submit],
-            fn=infer,
-            inputs=[
-                prompt,
-                negative_prompt,
-                seed,
-                randomize_seed,
-                width,
-                height,
-                guidance_scale,
-                num_inference_steps,
-            ],
-            outputs=[result, seed],
-        )
+# Simplified CSS
+css = "#col-container { margin: 0 auto; max-width: 640px; }"
+
+# Build the simplified Gradio interface
+with gr.Blocks(css=css) as demo:
+    with gr.Column(elem_id="col-container"):
+        gr.Markdown("# Text-to-Image Generator")
+
+        # Main input area
+        prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt")
+        run_button = gr.Button("Generate Image")
+
+        # Result display
+        result = gr.Image(label="Generated Image")
+        seed_text = gr.Number(label="Seed Used")
+
+        # Advanced settings (collapsed)
+        with gr.Accordion("Advanced Settings", open=False):
+            negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="What to exclude from the image")
+
+            # Seed settings
+            seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
+            randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
+
+            # Image size
+            with gr.Row():
+                width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=512)
+                height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=512)
+
+            # Generation parameters
+            with gr.Row():
+                guidance_scale = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=10.0, step=0.1, value=0.0)
+                num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=50, step=1, value=2)
+
+        # Examples
+        gr.Examples(examples, inputs=prompt)
 
-    logger.info("Gradio interface created successfully")
-except Exception as e:
-    logger.error(f"Error creating Gradio interface: {str(e)}")
-    raise
-
+    # Wire up the event handler
+    run_button.click(
+        fn=infer,
+        inputs=[
+            prompt,
+            negative_prompt,
+            seed,
+            randomize_seed,
+            width,
+            height,
+            guidance_scale,
+            num_inference_steps,
+        ],
+        outputs=[result, seed_text],
+    )
+
+# Launch the app
 if __name__ == "__main__":
     try:
         logger.info("Starting Gradio app")
requirements.txt CHANGED
@@ -4,4 +4,4 @@ invisible_watermark
 torch
 transformers
 xformers
-gradio==3.45.0
+gradio==3.34.0
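A quick way to confirm the pinned version is the one actually running, assuming dependencies are reinstalled with pip install -r requirements.txt:

# Sanity check after reinstalling dependencies
import gradio as gr

print(gr.__version__)  # expected to print 3.34.0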