Lifeinhockey committed
Commit 2eb05f5 · verified · 1 parent: 7d4603f

Update app.py

Files changed (1): app.py +11 -16
app.py CHANGED
```diff
@@ -9,7 +9,7 @@ MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 1024
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-model_id_default = "stable-diffusion-v1-5/stable-diffusion-v1-5"
+model_default = "stable-diffusion-v1-5/stable-diffusion-v1-5"
 torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
 
 def get_lora_sd_pipeline(
@@ -34,7 +34,6 @@ def get_lora_sd_pipeline(
     pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_sub_dir, adapter_name=adapter_name)
     pipe.unet.set_adapter(adapter_name)
     after_params = pipe.unet.parameters()
-    print("Parameters changed:", any(torch.any(b != a) for b, a in zip(before_params, after_params)))
 
     if os.path.exists(text_encoder_sub_dir):
         pipe.text_encoder = PeftModel.from_pretrained(pipe.text_encoder, text_encoder_sub_dir, adapter_name=adapter_name)
```
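A note on the dropped debug line: `before_params` and `after_params` are lazy generators, both only consumed inside the print after the adapter has already been injected, so the check effectively compared post-load state with post-load state; removing it loses nothing. If a sanity check is still wanted, a sketch of a more direct alternative (hypothetical, not part of this commit) is to inspect the PEFT wrapper itself:

```python
# Hypothetical sanity check, not in commit 2eb05f5: rather than diffing
# parameter generators, confirm the UNet is wrapped and the adapter is active.
from peft import PeftModel

assert isinstance(pipe.unet, PeftModel)
print("Active adapter:", pipe.unet.active_adapter)  # e.g. "default"
```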
```diff
@@ -48,10 +47,8 @@ def get_lora_sd_pipeline(
 def long_prompt_encoder(prompt, tokenizer, text_encoder, max_length=77):
     tokens = tokenizer(prompt, truncation=False, return_tensors="pt")["input_ids"]
     part_s = [tokens[:, i:i + max_length] for i in range(0, tokens.shape[1], max_length)]
-
     with torch.no_grad():
         embeds = [text_encoder(part.to(text_encoder.device))[0] for part in part_s]
-
     return torch.cat(embeds, dim=1)
 
 def align_embeddings(prompt_embeds, negative_prompt_embeds):
```
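For context, `long_prompt_encoder` works around CLIP's 77-token window: it tokenizes without truncation, splits the token ids into `max_length` chunks, encodes each chunk separately, and concatenates the hidden states along the sequence axis; `align_embeddings` then pads the positive and negative embeddings to a common length so classifier-free guidance can use them together. A usage sketch, assuming the `pipe_default` defined later in this file:

```python
# Illustrative only: a prompt longer than 77 CLIP tokens is encoded
# chunk by chunk and the per-chunk hidden states are concatenated.
long_prompt = ", ".join(30 * ["a man in anime style, highly detailed"])
embeds = long_prompt_encoder(long_prompt, pipe_default.tokenizer, pipe_default.text_encoder)
print(embeds.shape)  # (1, n_tokens, hidden_dim), with n_tokens > 77
```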
```diff
@@ -59,25 +56,25 @@ def align_embeddings(prompt_embeds, negative_prompt_embeds):
     return torch.nn.functional.pad(prompt_embeds, (0, 0, 0, max_length - prompt_embeds.shape[1])), \
            torch.nn.functional.pad(negative_prompt_embeds, (0, 0, 0, max_length - negative_prompt_embeds.shape[1]))
 
-pipe_default = get_lora_sd_pipeline(ckpt_dir='./lora_man_animestyle', base_model_name_or_path=model_id_default, dtype=torch_dtype).to(device)
+pipe_default = get_lora_sd_pipeline(ckpt_dir='./lora_man_animestyle', base_model_name_or_path=model_default, dtype=torch_dtype).to(device)
 
 def infer(
     prompt,
     negative_prompt,
     width=512,
     height=512,
-    num_inference_steps=20,
-    model_id='stable-diffusion-v1-5/stable-diffusion-v1-5',
-    seed=4,
-    guidance_scale=7.5,
-    lora_scale=0.5,
+    num_inference_steps,
+    model,
+    seed,
+    guidance_scale,
+    lora_scale,
     progress=gr.Progress(track_tqdm=True)
 ):
 
     generator = torch.Generator(device).manual_seed(seed)
 
-    if model_id != model_id_default:
-        pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch_dtype).to(device)
+    if model != model_default:
+        pipe = StableDiffusionPipeline.from_pretrained(model, torch_dtype=torch_dtype).to(device)
         prompt_embeds = long_prompt_encoder(prompt, pipe.tokenizer, pipe.text_encoder)
         negative_prompt_embeds = long_prompt_encoder(negative_prompt, pipe.tokenizer, pipe.text_encoder)
         prompt_embeds, negative_prompt_embeds = align_embeddings(prompt_embeds, negative_prompt_embeds)
```
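One caveat this hunk introduces: Python forbids parameters without defaults after parameters with defaults, so the new signature (`width=512, height=512` followed by the bare `num_inference_steps, model, ...`) raises a SyntaxError at import time. Since Gradio supplies every value positionally from the `inputs` list anyway (see the last hunk), a minimal fix is to drop the remaining defaults too; a sketch, not part of this commit:

```python
# Hypothetical corrected signature (not in commit 2eb05f5): all generation
# parameters are positional, matching the order of the Gradio inputs list.
def infer(
    prompt,
    negative_prompt,
    width,
    height,
    num_inference_steps,
    model,
    seed,
    guidance_scale,
    lora_scale,
    progress=gr.Progress(track_tqdm=True),
):
    ...
```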
```diff
@@ -86,8 +83,6 @@ def infer(
         prompt_embeds = long_prompt_encoder(prompt, pipe.tokenizer, pipe.text_encoder)
         negative_prompt_embeds = long_prompt_encoder(negative_prompt, pipe.tokenizer, pipe.text_encoder)
         prompt_embeds, negative_prompt_embeds = align_embeddings(prompt_embeds, negative_prompt_embeds)
-        print(f"LoRA adapter loaded: {pipe.unet.active_adapters}")
-        print(f"LoRA scale applied: {lora_scale}")
         pipe.fuse_lora(lora_scale=lora_scale)
 
     params = {
```
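For context, `fuse_lora` bakes the LoRA delta into the base weights at the given `lora_scale`, so the scale is fixed at fuse time. Because `infer` fuses on every call without ever unfusing, changing the LoRA scale between runs may not take effect cleanly; a hedged sketch of switching scales with an explicit unfuse (`unfuse_lora` is the diffusers counterpart):

```python
# Illustrative only: the scale is baked in when fusing, so switch scales
# by unfusing first instead of fusing again on top of a fused model.
pipe.fuse_lora(lora_scale=0.5)   # merge adapter at half strength
# ... run generation ...
pipe.unfuse_lora()               # restore the unfused base weights
pipe.fuse_lora(lora_scale=1.0)   # merge again at full strength
```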
```diff
@@ -139,7 +134,7 @@ with gr.Blocks(css=css) as demo:
     gr.Markdown(" # Text-to-Image Gradio Template from V. Gorsky")
 
     with gr.Row():
-        model_id = gr.Dropdown(
+        model = gr.Dropdown(
             label="Model Selection",
             choices=available_models,
             value="stable-diffusion-v1-5/stable-diffusion-v1-5",
@@ -228,7 +223,7 @@ with gr.Blocks(css=css) as demo:
         width,
         height,
         num_inference_steps,
-        model_id,
+        model,
         seed,
         guidance_scale,
         lora_scale,
```
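For context, this last hunk edits the `inputs` list of the run button's click handler; Gradio passes the component values to `infer` positionally, so the list order must match the parameter order exactly. A sketch of the surrounding wiring (`run_button` and `result` are assumed names, not shown in the diff):

```python
# Illustrative wiring; component names outside this hunk are assumptions.
run_button.click(
    fn=infer,
    inputs=[
        prompt,
        negative_prompt,
        width,
        height,
        num_inference_steps,
        model,
        seed,
        guidance_scale,
        lora_scale,
    ],
    outputs=[result],
)
```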
 