Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -40,10 +40,10 @@ def get_lora_sd_pipeline(
|
|
40 |
|
41 |
def process_prompt(prompt, tokenizer, text_encoder, max_length=77):
|
42 |
tokens = tokenizer(prompt, truncation=False, return_tensors="pt")["input_ids"]
|
43 |
-
|
44 |
|
45 |
with torch.no_grad():
|
46 |
-
embeds = [text_encoder(
|
47 |
|
48 |
return torch.cat(embeds, dim=1)
|
49 |
|
|
|
40 |
|
41 |
def process_prompt(prompt, tokenizer, text_encoder, max_length=77):
|
42 |
tokens = tokenizer(prompt, truncation=False, return_tensors="pt")["input_ids"]
|
43 |
+
part_s = [tokens[:, i:i + max_length] for i in range(0, tokens.shape[1], max_length)]
|
44 |
|
45 |
with torch.no_grad():
|
46 |
+
embeds = [text_encoder(part.to(text_encoder.device))[0] for part in part_s]
|
47 |
|
48 |
return torch.cat(embeds, dim=1)
|
49 |
|