Lifeinhockey committed on
Commit
04519b1
·
verified ·
1 Parent(s): 0fa9e72

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -6
app.py CHANGED
@@ -38,7 +38,7 @@ def get_lora_sd_pipeline(
38
 
39
  return pipe
40
 
41
- def process_prompt(prompt, tokenizer, text_encoder, max_length=77):
42
  tokens = tokenizer(prompt, truncation=False, return_tensors="pt")["input_ids"]
43
  part_s = [tokens[:, i:i + max_length] for i in range(0, tokens.shape[1], max_length)]
44
 
@@ -79,13 +79,13 @@ def infer(
79
 
80
  if model_id != model_id_default:
81
  pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch_dtype).to(device)
82
- prompt_embeds = process_prompt(prompt, pipe.tokenizer, pipe.text_encoder)
83
- negative_prompt_embeds = process_prompt(negative_prompt, pipe.tokenizer, pipe.text_encoder)
84
  prompt_embeds, negative_prompt_embeds = align_embeddings(prompt_embeds, negative_prompt_embeds)
85
  else:
86
  pipe = pipe_default
87
- prompt_embeds = process_prompt(prompt, pipe.tokenizer, pipe.text_encoder)
88
- negative_prompt_embeds = process_prompt(negative_prompt, pipe.tokenizer, pipe.text_encoder)
89
  prompt_embeds, negative_prompt_embeds = align_embeddings(prompt_embeds, negative_prompt_embeds)
90
  print(f"LoRA adapter loaded: {pipe.unet.active_adapters}")
91
  print(f"LoRA scale applied: {lora_scale}")
@@ -241,4 +241,3 @@ with gr.Blocks(css=css) as demo:
241
 
242
  if __name__ == "__main__":
243
  demo.launch()
244
-
 
38
 
39
  return pipe
40
 
41
+ def long_prompt_encoder(prompt, tokenizer, text_encoder, max_length=77):
42
  tokens = tokenizer(prompt, truncation=False, return_tensors="pt")["input_ids"]
43
  part_s = [tokens[:, i:i + max_length] for i in range(0, tokens.shape[1], max_length)]
44
 
 
79
 
80
  if model_id != model_id_default:
81
  pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch_dtype).to(device)
82
+ prompt_embeds = long_prompt_encoder(prompt, pipe.tokenizer, pipe.text_encoder)
83
+ negative_prompt_embeds = long_prompt_encoder(negative_prompt, pipe.tokenizer, pipe.text_encoder)
84
  prompt_embeds, negative_prompt_embeds = align_embeddings(prompt_embeds, negative_prompt_embeds)
85
  else:
86
  pipe = pipe_default
87
+ prompt_embeds = long_prompt_encoder(prompt, pipe.tokenizer, pipe.text_encoder)
88
+ negative_prompt_embeds = long_prompt_encoder(negative_prompt, pipe.tokenizer, pipe.text_encoder)
89
  prompt_embeds, negative_prompt_embeds = align_embeddings(prompt_embeds, negative_prompt_embeds)
90
  print(f"LoRA adapter loaded: {pipe.unet.active_adapters}")
91
  print(f"LoRA scale applied: {lora_scale}")
 
241
 
242
  if __name__ == "__main__":
243
  demo.launch()