Lifeinhockey committed (verified)
Commit b7b1936 · 1 Parent(s): 7d35736

Update app.py

Files changed (1): app.py (+158 −91)
app.py CHANGED
@@ -1,58 +1,111 @@
 import gradio as gr
 import numpy as np
 import random
-
-# import spaces #[uncomment to use ZeroGPU]
-from diffusers import DiffusionPipeline
 import torch
-
-device = "cuda" if torch.cuda.is_available() else "cpu"
-model_repo_id = "stabilityai/sdxl-turbo"  # Replace to the model you would like to use
-
-if torch.cuda.is_available():
-    torch_dtype = torch.float16
-else:
-    torch_dtype = torch.float32
-
-pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
-pipe = pipe.to(device)

 MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 1024

-
-# @spaces.GPU #[uncomment to use ZeroGPU]
 def infer(
-    model,
-    prompt,
-    negative_prompt,
-    seed,
-    width,
-    height,
-    guidance_scale,
-    num_inference_steps,
-    progress=gr.Progress(track_tqdm=True),
-):
-
-    global model_repo_id
-    if model != model_repo_id:
-        print(model, model_repo_id)
-        pipe = DiffusionPipeline.from_pretrained(model, torch_dtype=torch_dtype)
-        pipe = pipe.to(device)
-
-    generator = torch.Generator().manual_seed(seed)
-
-    image = pipe(
-        prompt=prompt,
-        negative_prompt=negative_prompt,
-        guidance_scale=guidance_scale,
-        num_inference_steps=num_inference_steps,
-        width=width,
-        height=height,
-        generator=generator,
-    ).images[0]
-
-    return image, seed

 examples = [
@@ -87,7 +140,6 @@ available_models = [
     "stabilityai/stable-diffusion-3-medium-diffusers",
     "stabilityai/stable-diffusion-3.5-large",
     "stabilityai/stable-diffusion-3.5-large-turbo",
-
 ]

 with gr.Blocks(css=css) as demo:
@@ -95,48 +147,62 @@ with gr.Blocks(css=css) as demo:
     with gr.Column(elem_id="col-container"):
         gr.Markdown(" # Text-to-Image Gradio Template from V. Gorsky")

-        model = gr.Dropdown(
-            label="Model Selection",
-            choices=available_models,
-            value="stable-diffusion-v1-5/stable-diffusion-v1-5",
-            interactive=True
-        )
-        prompt = gr.Text(
-            label="Prompt",
-            show_label=False,
-            max_lines=1,
-            placeholder="Enter your prompt",
-            container=False,
-        )
-
-        negative_prompt = gr.Text(
-            label="Negative prompt",
-            max_lines=1,
-            placeholder="Enter a negative prompt",
-            visible=True,
-        )
-
-        seed = gr.Slider(
-            label="Seed",
-            minimum=0,
-            maximum=MAX_SEED,
-            step=1,
-            value=0,
-        )
-        guidance_scale = gr.Slider(
-            label="Guidance scale",
-            minimum=0.0,
-            maximum=10.0,
-            step=0.1,
-            value=7.5,  # Replace with defaults that work for your model
-        )
-        num_inference_steps = gr.Slider(
-            label="Number of inference steps",
-            minimum=1,
-            maximum=100,
             step=1,
-            value=30,  # Replace with defaults that work for your model
-        )

         with gr.Accordion("Advanced Settings", open=False):
             with gr.Row():
@@ -158,15 +224,15 @@ with gr.Blocks(css=css) as demo:

         gr.Examples(examples=examples, inputs=[prompt])
         gr.Examples(examples=examples_negative, inputs=[negative_prompt])
-
         run_button = gr.Button("Run", scale=0, variant="primary")
         result = gr.Image(label="Result", show_label=False)
-
         gr.on(
             triggers=[run_button.click, prompt.submit],
             fn=infer,
             inputs=[
-                model,
                 prompt,
                 negative_prompt,
                 seed,
@@ -174,6 +240,7 @@ with gr.Blocks(css=css) as demo:
                 height,
                 guidance_scale,
                 num_inference_steps,
             ],
             outputs=[result, seed],
         )
 
+import os
 import gradio as gr
 import numpy as np
 import random
 import torch
+from diffusers import (
+    DiffusionPipeline,
+    StableDiffusionPipeline
+)
+from peft import PeftModel, LoraConfig
+
+
+def get_lora_sd_pipeline(
+    ckpt_dir='./lora_man_animestyle',
+    base_model_name_or_path=None,
+    dtype=torch.float16,
+    adapter_name="default"
+):
+    unet_sub_dir = os.path.join(ckpt_dir, "unet")
+    text_encoder_sub_dir = os.path.join(ckpt_dir, "text_encoder")
+
+    if os.path.exists(text_encoder_sub_dir) and base_model_name_or_path is None:
+        config = LoraConfig.from_pretrained(text_encoder_sub_dir)
+        base_model_name_or_path = config.base_model_name_or_path
+
+    if base_model_name_or_path is None:
+        raise ValueError("Please specify the base model name or path")
+
+    pipe = StableDiffusionPipeline.from_pretrained(base_model_name_or_path, torch_dtype=dtype)
+    before_params = pipe.unet.parameters()
+    pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_sub_dir, adapter_name=adapter_name)
+    pipe.unet.set_adapter(adapter_name)
+    after_params = pipe.unet.parameters()
+    print("Parameters changed:", any(torch.any(b != a) for b, a in zip(before_params, after_params)))
+
+    if os.path.exists(text_encoder_sub_dir):
+        pipe.text_encoder = PeftModel.from_pretrained(pipe.text_encoder, text_encoder_sub_dir, adapter_name=adapter_name)
+
+    if dtype in (torch.float16, torch.bfloat16):
+        pipe.unet.half()
+        pipe.text_encoder.half()
+
+    return pipe
+
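The "Parameters changed" check above iterates two lazy parameter generators over a module that PEFT modifies in place, so it typically does not demonstrate that the adapter attached (LoRA also leaves the base weights untouched until they are fused). A more direct sanity check, as a sketch assuming the helper above and a local './lora_man_animestyle' checkpoint:

pipe = get_lora_sd_pipeline(
    ckpt_dir='./lora_man_animestyle',
    base_model_name_or_path='stable-diffusion-v1-5/stable-diffusion-v1-5',
)
# PEFT injects tensors whose names contain "lora_"; a nonzero count
# confirms the adapter is actually wired into the UNet.
lora_tensors = [n for n, _ in pipe.unet.named_parameters() if "lora_" in n]
print(f"LoRA tensors injected: {len(lora_tensors)}")
print(f"Active adapter(s): {pipe.unet.active_adapters}")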
+def process_prompt(prompt, tokenizer, text_encoder, max_length=77):
+    tokens = tokenizer(prompt, truncation=False, return_tensors="pt")["input_ids"]
+    chunks = [tokens[:, i:i + max_length] for i in range(0, tokens.shape[1], max_length)]
+
+    with torch.no_grad():
+        embeds = [text_encoder(chunk.to(text_encoder.device))[0] for chunk in chunks]
+
+    return torch.cat(embeds, dim=1)
+
+def align_embeddings(prompt_embeds, negative_prompt_embeds):
+    max_length = max(prompt_embeds.shape[1], negative_prompt_embeds.shape[1])
+    return torch.nn.functional.pad(prompt_embeds, (0, 0, 0, max_length - prompt_embeds.shape[1])), \
+           torch.nn.functional.pad(negative_prompt_embeds, (0, 0, 0, max_length - negative_prompt_embeds.shape[1]))
+
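process_prompt works around CLIP's 77-token context limit by encoding the prompt in 77-token chunks and concatenating the hidden states along the sequence axis; align_embeddings then zero-pads the shorter of the two embedding tensors so prompt and negative prompt have matching shapes, as the pipeline requires. An illustrative sketch (the prompt text is hypothetical; pipe_default is constructed just below):

long_prompt = "a detailed portrait of a man in anime style, " * 20  # well over 77 tokens
negative = "blurry, low quality"

pe = process_prompt(long_prompt, pipe_default.tokenizer, pipe_default.text_encoder)
ne = process_prompt(negative, pipe_default.tokenizer, pipe_default.text_encoder)
print(pe.shape, ne.shape)    # e.g. [1, ~200, 768] vs [1, <77, 768] for SD 1.5

pe, ne = align_embeddings(pe, ne)
print(pe.shape == ne.shape)  # True: both padded to the longer sequence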
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+model_id_default = "stable-diffusion-v1-5/stable-diffusion-v1-5"
+torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+
+pipe_default = get_lora_sd_pipeline(ckpt_dir='./lora_man_animestyle', base_model_name_or_path=model_id_default, dtype=torch_dtype).to(device)

 MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 1024

 def infer(
+    model_id,
+    prompt,
+    negative_prompt,
+    seed=4,
+    width=512,
+    height=512,
+    guidance_scale=7.5,
+    num_inference_steps=20,
+    lora_scale=0.5,
+    progress=gr.Progress(track_tqdm=True)
+):
+    generator = torch.Generator(device).manual_seed(seed)
+
+    if model_id != model_id_default:
+        pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch_dtype).to(device)
+        prompt_embeds = process_prompt(prompt, pipe.tokenizer, pipe.text_encoder)
+        negative_prompt_embeds = process_prompt(negative_prompt, pipe.tokenizer, pipe.text_encoder)
+        prompt_embeds, negative_prompt_embeds = align_embeddings(prompt_embeds, negative_prompt_embeds)
+    else:
+        pipe = pipe_default
+        prompt_embeds = process_prompt(prompt, pipe.tokenizer, pipe.text_encoder)
+        negative_prompt_embeds = process_prompt(negative_prompt, pipe.tokenizer, pipe.text_encoder)
+        prompt_embeds, negative_prompt_embeds = align_embeddings(prompt_embeds, negative_prompt_embeds)
+        print(f"LoRA adapter loaded: {pipe.unet.active_adapters}")
+        print(f"LoRA scale applied: {lora_scale}")
+        pipe.fuse_lora(lora_scale=lora_scale)
+
+    params = {
+        'prompt_embeds': prompt_embeds,
+        'negative_prompt_embeds': negative_prompt_embeds,
+        'guidance_scale': guidance_scale,
+        'num_inference_steps': num_inference_steps,
+        'width': width,
+        'height': height,
+        'generator': generator,
+    }
+
+    return pipe(**params).images[0], seed
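A caveat around fuse_lora in the default branch above: fusing merges the adapter deltas into the base weights at the given scale, and repeated fuses compound, so each call to infer on pipe_default deepens the LoRA influence. A sketch of resetting between requests, assuming diffusers' fuse_lora/unfuse_lora pair behaves as documented on this PEFT-wrapped pipeline:

pipe_default.unfuse_lora()              # undo the previous merge, if any
pipe_default.fuse_lora(lora_scale=0.8)  # re-merge at the new scale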
 

 examples = [

     "stabilityai/stable-diffusion-3-medium-diffusers",
     "stabilityai/stable-diffusion-3.5-large",
     "stabilityai/stable-diffusion-3.5-large-turbo",
 ]

 with gr.Blocks(css=css) as demo:
     with gr.Column(elem_id="col-container"):
         gr.Markdown(" # Text-to-Image Gradio Template from V. Gorsky")

+        with gr.Row():
+            model_id = gr.Dropdown(
+                label="Model Selection",
+                choices=available_models,
+                value="stable-diffusion-v1-5/stable-diffusion-v1-5",
+                interactive=True
+            )
+
+        with gr.Row():
+            prompt = gr.Text(
+                label="Prompt",
+                show_label=False,
+                max_lines=1,
+                placeholder="Enter your prompt",
+                container=False,
+            )
+            negative_prompt = gr.Text(
+                label="Negative prompt",
+                max_lines=1,
+                placeholder="Enter a negative prompt",
+                visible=True,
+            )
+
+        with gr.Row():
+            lora_scale = gr.Slider(
+                label="LoRA scale",
+                minimum=0.0,
+                maximum=1.0,
+                step=0.1,
+                value=0.5,
+            )
+
+        with gr.Row():
+            seed = gr.Slider(
+                label="Seed",
+                minimum=0,
+                maximum=MAX_SEED,
                 step=1,
+                value=0,
+            )
+            guidance_scale = gr.Slider(
+                label="Guidance scale",
+                minimum=0.0,
+                maximum=10.0,
+                step=0.1,
+                value=7.5,  # Replace with defaults that work for your model
+            )
+            num_inference_steps = gr.Slider(
+                label="Number of inference steps",
+                minimum=1,
+                maximum=100,
+                step=1,
+                value=30,  # Replace with defaults that work for your model
+            )

         with gr.Accordion("Advanced Settings", open=False):
             with gr.Row():
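If free-text model ids are wanted in addition to the preset choices, gr.Dropdown supports allow_custom_value; a sketch, not part of the commit:

model_id = gr.Dropdown(
    label="Model Selection",
    choices=available_models,
    value="stable-diffusion-v1-5/stable-diffusion-v1-5",
    allow_custom_value=True,  # lets the user type any Hub model id
    interactive=True,
)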

         gr.Examples(examples=examples, inputs=[prompt])
         gr.Examples(examples=examples_negative, inputs=[negative_prompt])
+
         run_button = gr.Button("Run", scale=0, variant="primary")
         result = gr.Image(label="Result", show_label=False)
+
         gr.on(
             triggers=[run_button.click, prompt.submit],
             fn=infer,
             inputs=[
+                model_id,
                 prompt,
                 negative_prompt,
                 seed,

                 height,
                 guidance_scale,
                 num_inference_steps,
+                lora_scale,
             ],
             outputs=[result, seed],
         )
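gr.on passes the inputs list to fn positionally, so the component order must match infer's parameter order (model_id, prompt, negative_prompt, seed, width, height, guidance_scale, num_inference_steps, lora_scale), and outputs=[result, seed] expects infer to return two values. A self-contained sketch of the mechanism, with hypothetical component names:

import gradio as gr

def greet(name, excited):
    # Values arrive in the order of the inputs list, not by keyword.
    return f"Hello, {name}{'!' if excited else '.'}"

with gr.Blocks() as sketch:
    name = gr.Text(label="Name")
    excited = gr.Checkbox(label="Excited")
    out = gr.Text(label="Greeting")
    btn = gr.Button("Go")
    gr.on(triggers=[btn.click], fn=greet, inputs=[name, excited], outputs=[out])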