Commit
·
a5df595
1
Parent(s):
20c6eca
Refactor get_prompt_attention function to include device parameter
Browse files- tabs/images/handlers.py +5 -4
tabs/images/handlers.py
CHANGED
@@ -187,16 +187,16 @@ def get_control_mode(controlnet_config: ControlNetReq):
|
|
187 |
|
188 |
def get_prompt_attention(pipeline, prompt, negative_prompt):
|
189 |
if isinstance(pipeline, flux_pipes):
|
190 |
-
prompt_embeds, pooled_prompt_embeds = get_weighted_text_embeddings_flux1(pipeline, prompt)
|
191 |
return prompt_embeds, None, pooled_prompt_embeds, None
|
192 |
elif isinstance(pipeline, sd_pipes):
|
193 |
-
prompt_embeds, prompt_neg_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = get_weighted_text_embeddings_sdxl(pipeline, prompt, negative_prompt)
|
194 |
return prompt_embeds, prompt_neg_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
195 |
|
196 |
|
197 |
def cleanup(pipeline, loras = None, embeddings = None):
|
198 |
if loras:
|
199 |
-
pipeline.disable_lora()
|
200 |
pipeline.unload_lora_weights()
|
201 |
if embeddings:
|
202 |
pipeline.unload_textual_inversion()
|
@@ -210,9 +210,10 @@ def gen_img(request: BaseReq | BaseImg2ImgReq | BaseInpaintReq, progress=gr.Prog
|
|
210 |
pipeline_args = get_pipe(request)
|
211 |
pipeline = pipeline_args["pipeline"]
|
212 |
try:
|
213 |
-
progress(0.
|
214 |
positive_prompt_embeds, negative_prompt_embeds, positive_prompt_pooled, negative_prompt_pooled = get_prompt_attention(pipeline, request.prompt, request.negative_prompt)
|
215 |
|
|
|
216 |
# Common Args
|
217 |
args = {
|
218 |
'prompt_embeds': positive_prompt_embeds,
|
|
|
187 |
|
188 |
def get_prompt_attention(pipeline, prompt, negative_prompt):
|
189 |
if isinstance(pipeline, flux_pipes):
|
190 |
+
prompt_embeds, pooled_prompt_embeds = get_weighted_text_embeddings_flux1(pipeline, prompt, device=device)
|
191 |
return prompt_embeds, None, pooled_prompt_embeds, None
|
192 |
elif isinstance(pipeline, sd_pipes):
|
193 |
+
prompt_embeds, prompt_neg_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = get_weighted_text_embeddings_sdxl(pipeline, prompt, negative_prompt, device=device)
|
194 |
return prompt_embeds, prompt_neg_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
195 |
|
196 |
|
197 |
def cleanup(pipeline, loras = None, embeddings = None):
|
198 |
if loras:
|
199 |
+
# pipeline.disable_lora()
|
200 |
pipeline.unload_lora_weights()
|
201 |
if embeddings:
|
202 |
pipeline.unload_textual_inversion()
|
|
|
210 |
pipeline_args = get_pipe(request)
|
211 |
pipeline = pipeline_args["pipeline"]
|
212 |
try:
|
213 |
+
progress(0.3, "Getting Prompt Embeddings")
|
214 |
positive_prompt_embeds, negative_prompt_embeds, positive_prompt_pooled, negative_prompt_pooled = get_prompt_attention(pipeline, request.prompt, request.negative_prompt)
|
215 |
|
216 |
+
progress(0.5, "Configuring Pipeline")
|
217 |
# Common Args
|
218 |
args = {
|
219 |
'prompt_embeds': positive_prompt_embeds,
|