dezzman committed
Commit 8a52ce7 · verified · 1 Parent(s): b9af0ed

Update app.py

Files changed (1)
app.py +87 -216
app.py CHANGED
@@ -1,260 +1,131 @@
  import gradio as gr
- import numpy as np
- import random
- import os
-
- # import spaces #[uncomment to use ZeroGPU]
- from diffusers import DiffusionPipeline, StableDiffusionPipeline
- from peft import PeftModel, LoraConfig
  import torch
- from typing import Optional
-

  def get_lora_sd_pipeline(
      ckpt_dir='./lora_logos',
      base_model_name_or_path=None,
      dtype=torch.float16,
      adapter_name="default"
- ):
      unet_sub_dir = os.path.join(ckpt_dir, "unet")
      text_encoder_sub_dir = os.path.join(ckpt_dir, "text_encoder")
      if os.path.exists(text_encoder_sub_dir) and base_model_name_or_path is None:
          config = LoraConfig.from_pretrained(text_encoder_sub_dir)
          base_model_name_or_path = config.base_model_name_or_path
-
      if base_model_name_or_path is None:
          raise ValueError("Please specify the base model name or path")
-
      pipe = StableDiffusionPipeline.from_pretrained(base_model_name_or_path, torch_dtype=dtype)
      pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_sub_dir, adapter_name=adapter_name)
-
      if os.path.exists(text_encoder_sub_dir):
-         pipe.text_encoder = PeftModel.from_pretrained(
-             pipe.text_encoder, text_encoder_sub_dir, adapter_name=adapter_name
-         )
-
      if dtype in (torch.float16, torch.bfloat16):
          pipe.unet.half()
          pipe.text_encoder.half()
      return pipe

- def split_prompt(prompt, tokenizer, max_length=77):
-     print(prompt)
-     print(type(prompt))
-     print(str(prompt))
-     tokens = tokenizer(prompt, truncation=False)["input_ids"]
-     chunks = [tokens[i:i + max_length] for i in range(0, len(tokens), max_length)]
-     return chunks
-
- def get_prompt_embeds(prompt_chunks, text_encoder):
-     prompt_embeds = []
-     for chunk in prompt_chunks:
-         chunk_tensor = torch.tensor([chunk]).to(text_encoder.device)
-         with torch.no_grad():
-             embeds = text_encoder(chunk_tensor)[0]
-         prompt_embeds.append(embeds)
-     return torch.cat(prompt_embeds, dim=1)

- def shape_alignment(prompt_embeds, negative_prompt_embeds):
      max_length = max(prompt_embeds.shape[1], negative_prompt_embeds.shape[1])

-     def pad_to_max_length(tensor, target_length):
-         padding = target_length - tensor.shape[1]
-         if padding > 0:
-             pad_tensor = torch.zeros(
-                 tensor.shape[0], padding, tensor.shape[2], device=tensor.device
-             )
-             tensor = torch.cat([tensor, pad_tensor], dim=1)
-         return tensor
-
-     prompt_embeds = pad_to_max_length(prompt_embeds, max_length)
-     negative_prompt_embeds = pad_to_max_length(negative_prompt_embeds, max_length)
-
-     assert prompt_embeds.shape == negative_prompt_embeds.shape, "Shapes do not match!"
-     return prompt_embeds, negative_prompt_embeds
-
- def prompts_embeddings(prompt, negative_promt, tokenizer, text_encoder):
-     prompt_chunks = split_prompt(prompt, tokenizer)
-     negative_prompt_chunks = split_prompt(negative_prompt, tokenizer)
-
-     prompt_embeds = get_prompt_embeds(prompt_chunks, text_encoder)
-     negative_prompt_embeds = get_prompt_embeds(negative_prompt_chunks, text_encoder)
-
-     prompt_embeds, negative_prompt_embeds = shape_alignment(prompt_embeds, negative_prompt_embeds)
-
-     return prompt_embeds, negative_prompt_embeds
-
-
- device = "cuda" if torch.cuda.is_available() else "cpu"
  model_id_default = "CompVis/stable-diffusion-v1-4"

- if torch.cuda.is_available():
-     torch_dtype = torch.float16
- else:
-     torch_dtype = torch.float32
-
-
- pipe_default = get_lora_sd_pipeline(
-     ckpt_dir='./lora_logos',
-     base_model_name_or_path=model_id_default,
-     dtype=torch_dtype,
- )
- # pipe_default = DiffusionPipeline.from_pretrained(model_id_default, torch_dtype=torch_dtype)
- pipe_default = pipe_default.to(device)
-
- MAX_SEED = np.iinfo(np.int32).max
- MAX_IMAGE_SIZE = 1024

-
- # @spaces.GPU #[uncomment to use ZeroGPU]
  def infer(
-     prompt: str,
-     negative_prompt: str,
-     width: int,
-     height: int,
-     num_inference_steps: Optional[int] = 20,
-     model_id: Optional[str] = 'CompVis/stable-diffusion-v1-4',
-     seed: Optional[int] = 42,
-     guidance_scale: Optional[float] = 7.0,
-     lora_scale: Optional[float] = 0.5,
-     progress=gr.Progress(track_tqdm=True),
- ):
-     generator = torch.Generator().manual_seed(seed)

      params = {
-         # 'prompt': prompt,
-         # 'negative_prompt': negative_prompt,
          'guidance_scale': guidance_scale,
          'num_inference_steps': num_inference_steps,
          'width': width,
          'height': height,
          'generator': generator,
      }
-
-     if model_id != model_id_default:
-         pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch_dtype)
-         pipe = pipe.to(device)
-         image = pipe(**params).images[0]
-     else:
-         prompt_embeds, negative_prompt_embeds = prompts_embeddings(
-             prompt,
-             negative_prompt,
-             pipe_default.tokenizer,
-             pipe_default.text_encoder
-         )
-         params['prompt_embeds'] = prompt_embeds
-         params['negative_prompt_embeds']=negative_prompt_embeds
-         pipe_default.fuse_lora(lora_scale=lora_scale)
-         image = pipe_default(**params).images[0]
-
-     return image
-
- css = """
- #col-container {
-     margin: 0 auto;
-     max-width: 640px;
- }
- """
-
- with gr.Blocks(css=css) as demo:
-     with gr.Column(elem_id="col-container"):

-         gr.Markdown(" # DEMO Text-to-Image")
-
-         with gr.Row():
-             model_id = gr.Textbox(
-                 label="Model ID",
-                 max_lines=1,
-                 placeholder="Enter model id like 'CompVis/stable-diffusion-v1-4'",
-                 value="CompVis/stable-diffusion-v1-4"
-             )
-
-         prompt = gr.Textbox(
-             label="Prompt",
-             max_lines=1,
-             placeholder="Enter your prompt",
-         )
-
-         negative_prompt = gr.Textbox(
-             label="Negative prompt",
-             max_lines=1,
-             placeholder="Enter a negative prompt",
-         )
-
-         with gr.Row():
-             seed = gr.Number(
-                 label="Seed",
-                 minimum=0,
-                 maximum=MAX_SEED,
-                 step=1,
-                 value=42,
-             )
-
-         with gr.Row():
-             guidance_scale = gr.Slider(
-                 label="Guidance scale",
-                 minimum=0.0,
-                 maximum=10.0,
-                 step=0.1,
-                 value=7.0,
-             )
-
-         with gr.Row():
-             lora_scale = gr.Slider(
-                 label="LoRA scale",
-                 minimum=0.0,
-                 maximum=1.0,
-                 step=0.1,
-                 value=0.5,
-             )
-
-         with gr.Row():
-             num_inference_steps = gr.Slider(
-                 label="Number of inference steps",
-                 minimum=1,
-                 maximum=50,
-                 step=1,
-                 value=20,
-             )
-
          with gr.Accordion("Optional Settings", open=False):
-             with gr.Row():
-                 width = gr.Slider(
-                     label="Width",
-                     minimum=256,
-                     maximum=MAX_IMAGE_SIZE,
-                     step=32,
-                     value=512,
-                 )
-
-             with gr.Row():
-                 height = gr.Slider(
-                     label="Height",
-                     minimum=256,
-                     maximum=MAX_IMAGE_SIZE,
-                     step=32,
-                     value=512,
-                 )
-
-         run_button = gr.Button("Run", scale=1, variant="primary")
-         result = gr.Image(label="Result", show_label=False)

-     gr.on(
-         triggers=[run_button.click, prompt.submit],
-         fn=infer,
          inputs=[
-             prompt,
-             negative_prompt,
-             width,
-             height,
-             num_inference_steps,
-             model_id,
-             seed,
-             guidance_scale,
-             lora_scale,
-         ],
-         outputs=[result],
-     )

  if __name__ == "__main__":
-     demo.launch()
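A quick equivalence check on the padding change: the removed pad_to_max_length helper and the torch.nn.functional.pad call that replaces it in align_embeddings zero-pad the sequence dimension identically. A minimal sketch, assuming 3-D (batch, seq_len, hidden) embeddings like those the CLIP text encoder returns:

import torch
import torch.nn.functional as F

def pad_to_max_length(tensor, target_length):
    # the removed helper: zero-pad dim 1 (sequence length) up to target_length
    padding = target_length - tensor.shape[1]
    if padding > 0:
        pad_tensor = torch.zeros(tensor.shape[0], padding, tensor.shape[2], device=tensor.device)
        tensor = torch.cat([tensor, pad_tensor], dim=1)
    return tensor

embeds = torch.randn(1, 77, 768)   # (batch, seq_len, hidden), e.g. one 77-token chunk
target = 154                       # e.g. the opposing prompt produced two chunks

old_way = pad_to_max_length(embeds, target)
new_way = F.pad(embeds, (0, 0, 0, target - embeds.shape[1]))   # pads dim 1 on the right

assert old_way.shape == new_way.shape == (1, 154, 768)
assert torch.equal(old_way, new_way)

F.pad reads its pad tuple from the last dimension backwards, so (0, 0, 0, n) leaves the hidden dimension alone and appends n zero vectors along the sequence axis.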
 
  import gradio as gr
  import torch
+ from diffusers import StableDiffusionPipeline
+ from peft import PeftModel, LoraConfig
+ import os

  def get_lora_sd_pipeline(
      ckpt_dir='./lora_logos',
      base_model_name_or_path=None,
      dtype=torch.float16,
      adapter_name="default"
+ ):
      unet_sub_dir = os.path.join(ckpt_dir, "unet")
      text_encoder_sub_dir = os.path.join(ckpt_dir, "text_encoder")
+
      if os.path.exists(text_encoder_sub_dir) and base_model_name_or_path is None:
          config = LoraConfig.from_pretrained(text_encoder_sub_dir)
          base_model_name_or_path = config.base_model_name_or_path
+
      if base_model_name_or_path is None:
          raise ValueError("Please specify the base model name or path")
+
      pipe = StableDiffusionPipeline.from_pretrained(base_model_name_or_path, torch_dtype=dtype)
      pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_sub_dir, adapter_name=adapter_name)
+
      if os.path.exists(text_encoder_sub_dir):
+         pipe.text_encoder = PeftModel.from_pretrained(pipe.text_encoder, text_encoder_sub_dir, adapter_name=adapter_name)
+
      if dtype in (torch.float16, torch.bfloat16):
          pipe.unet.half()
          pipe.text_encoder.half()
+
      return pipe

+ def process_prompt(prompt, tokenizer, text_encoder, max_length=77):
+     tokens = tokenizer(prompt, truncation=False, return_tensors="pt")["input_ids"]
+     chunks = [tokens[:, i:i + max_length] for i in range(0, tokens.shape[1], max_length)]
+
+     with torch.no_grad():
+         embeds = [text_encoder(chunk.to(text_encoder.device))[0] for chunk in chunks]
+
+     return torch.cat(embeds, dim=1)

+ def align_embeddings(prompt_embeds, negative_prompt_embeds):
      max_length = max(prompt_embeds.shape[1], negative_prompt_embeds.shape[1])
+     return torch.nn.functional.pad(prompt_embeds, (0, 0, 0, max_length - prompt_embeds.shape[1])), \
+            torch.nn.functional.pad(negative_prompt_embeds, (0, 0, 0, max_length - negative_prompt_embeds.shape[1]))

+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  model_id_default = "CompVis/stable-diffusion-v1-4"
+ torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

+ pipe_default = get_lora_sd_pipeline(ckpt_dir='./lora_logos', base_model_name_or_path=model_id_default, dtype=torch_dtype).to(device)

  def infer(
+     prompt,
+     negative_prompt,
+     width=512,
+     height=512,
+     num_inference_steps=20,
+     model_id='CompVis/stable-diffusion-v1-4',
+     seed=42,
+     guidance_scale=7.0,
+     lora_scale=0.5
+ ):
+     generator = torch.Generator(device).manual_seed(seed)
+
+     print(prompt)
+     print(type(prompt))

+     print(negative_prompt)
+     print(type(negative_prompt))
+
+     if model_id != model_id_default:
+         pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch_dtype).to(device)
+         prompt_embeds = process_prompt(prompt, pipe.tokenizer, pipe.text_encoder)
+         negative_prompt_embeds = process_prompt(negative_prompt, pipe.tokenizer, pipe.text_encoder)
+         prompt_embeds, negative_prompt_embeds = align_embeddings(prompt_embeds, negative_prompt_embeds)
+     else:
+         pipe = pipe_default
+         prompt_embeds = process_prompt(prompt, pipe.tokenizer, pipe.text_encoder)
+         negative_prompt_embeds = process_prompt(negative_prompt, pipe.tokenizer, pipe.text_encoder)
+         prompt_embeds, negative_prompt_embeds = align_embeddings(prompt_embeds, negative_prompt_embeds)
+         pipe.fuse_lora(lora_scale=lora_scale)
+
      params = {
+         'prompt_embeds': prompt_embeds,
+         'negative_prompt_embeds': negative_prompt_embeds,
          'guidance_scale': guidance_scale,
          'num_inference_steps': num_inference_steps,
          'width': width,
          'height': height,
          'generator': generator,
      }
+
+     return pipe(**params).images[0]
+
+ with gr.Blocks() as demo:
+     with gr.Column():
+         gr.Markdown("# DEMO Text-to-Image")
+         model_id = gr.Textbox(label="Model ID", value=model_id_default)
+         prompt = gr.Textbox(label="Prompt")
+         negative_prompt = gr.Textbox(label="Negative prompt")
+         seed = gr.Number(label="Seed", value=42)
+         guidance_scale = gr.Slider(label="Guidance scale", minimum=0.0, maximum=10.0, value=7.0)
+         lora_scale = gr.Slider(label="LoRA scale", minimum=0.0, maximum=1.0, value=0.5)
+         num_inference_steps = gr.Slider(label="Number of inference steps", minimum=1, maximum=50, value=20)

          with gr.Accordion("Optional Settings", open=False):
+             width = gr.Slider(label="Width", minimum=256, maximum=1024, value=512, step=32)
+             height = gr.Slider(label="Height", minimum=256, maximum=1024, value=512, step=32)
+
+         run_button = gr.Button("Run")
+         result = gr.Image(label="Result")

+     run_button.click(
+         fn=infer,
          inputs=[
+             prompt,
+             negative_prompt,
+             width,
+             height,
+             num_inference_steps,
+             model_id, seed,
+             guidance_scale,
+             lora_scale
+         ],
+         outputs=result)

  if __name__ == "__main__":
+     demo.launch()
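A minimal, self-contained sketch of the new long-prompt path (process_prompt plus align_embeddings): token ids are split into 77-token chunks, each chunk is encoded separately, the chunk embeddings are concatenated along the sequence dimension, and the positive and negative embeddings are padded to a common length, which StableDiffusionPipeline expects when both are passed as prompt_embeds/negative_prompt_embeds. The ToyTokenizer and ToyTextEncoder classes are hypothetical stand-ins so the sketch runs without downloading a model; the app itself passes pipe.tokenizer and pipe.text_encoder.

import torch
import torch.nn.functional as F

# process_prompt and align_embeddings as defined in the new app.py
def process_prompt(prompt, tokenizer, text_encoder, max_length=77):
    tokens = tokenizer(prompt, truncation=False, return_tensors="pt")["input_ids"]
    chunks = [tokens[:, i:i + max_length] for i in range(0, tokens.shape[1], max_length)]
    with torch.no_grad():
        embeds = [text_encoder(chunk.to(text_encoder.device))[0] for chunk in chunks]
    return torch.cat(embeds, dim=1)

def align_embeddings(prompt_embeds, negative_prompt_embeds):
    max_length = max(prompt_embeds.shape[1], negative_prompt_embeds.shape[1])
    return F.pad(prompt_embeds, (0, 0, 0, max_length - prompt_embeds.shape[1])), \
           F.pad(negative_prompt_embeds, (0, 0, 0, max_length - negative_prompt_embeds.shape[1]))

# hypothetical stand-ins: one "token" per word, embeddings shaped like CLIP's (batch, seq_len, 768)
class ToyTokenizer:
    def __call__(self, text, truncation=False, return_tensors="pt"):
        ids = torch.arange(len(text.split()), dtype=torch.long).unsqueeze(0)
        return {"input_ids": ids}

class ToyTextEncoder(torch.nn.Module):
    def __init__(self, hidden_size=768):
        super().__init__()
        self.embedding = torch.nn.Embedding(1024, hidden_size)

    @property
    def device(self):
        return self.embedding.weight.device

    def forward(self, input_ids):
        # mimic CLIPTextModel: element [0] is the last hidden state
        return (self.embedding(input_ids),)

tokenizer, text_encoder = ToyTokenizer(), ToyTextEncoder()
long_prompt = "minimalist geometric fox logo " * 25   # 100 "tokens" -> chunks of 77 and 23
negative_prompt = "blurry low quality"                 # 3 "tokens" -> one short chunk

prompt_embeds = process_prompt(long_prompt, tokenizer, text_encoder)         # (1, 100, 768)
negative_embeds = process_prompt(negative_prompt, tokenizer, text_encoder)   # (1, 3, 768)
prompt_embeds, negative_embeds = align_embeddings(prompt_embeds, negative_embeds)
assert prompt_embeds.shape == negative_embeds.shape == (1, 100, 768)

Passing prompt_embeds this way bypasses the 77-token truncation that applies when plain prompt strings are given to the pipeline, which is the point of the chunking; the zero-padding only serves to give both embedding tensors the same shape.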