barreloflube commited on
Commit
a5df595
·
1 Parent(s): 20c6eca

Refactor get_prompt_attention function to include device parameter

Browse files
Files changed (1) hide show
  1. tabs/images/handlers.py +5 -4
tabs/images/handlers.py CHANGED
@@ -187,16 +187,16 @@ def get_control_mode(controlnet_config: ControlNetReq):
187
 
188
  def get_prompt_attention(pipeline, prompt, negative_prompt):
189
  if isinstance(pipeline, flux_pipes):
190
- prompt_embeds, pooled_prompt_embeds = get_weighted_text_embeddings_flux1(pipeline, prompt)
191
  return prompt_embeds, None, pooled_prompt_embeds, None
192
  elif isinstance(pipeline, sd_pipes):
193
- prompt_embeds, prompt_neg_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = get_weighted_text_embeddings_sdxl(pipeline, prompt, negative_prompt)
194
  return prompt_embeds, prompt_neg_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
195
 
196
 
197
  def cleanup(pipeline, loras = None, embeddings = None):
198
  if loras:
199
- pipeline.disable_lora()
200
  pipeline.unload_lora_weights()
201
  if embeddings:
202
  pipeline.unload_textual_inversion()
@@ -210,9 +210,10 @@ def gen_img(request: BaseReq | BaseImg2ImgReq | BaseInpaintReq, progress=gr.Prog
210
  pipeline_args = get_pipe(request)
211
  pipeline = pipeline_args["pipeline"]
212
  try:
213
- progress(0.5, "Configuring Pipeline")
214
  positive_prompt_embeds, negative_prompt_embeds, positive_prompt_pooled, negative_prompt_pooled = get_prompt_attention(pipeline, request.prompt, request.negative_prompt)
215
 
 
216
  # Common Args
217
  args = {
218
  'prompt_embeds': positive_prompt_embeds,
 
187
 
188
  def get_prompt_attention(pipeline, prompt, negative_prompt):
189
  if isinstance(pipeline, flux_pipes):
190
+ prompt_embeds, pooled_prompt_embeds = get_weighted_text_embeddings_flux1(pipeline, prompt, device=device)
191
  return prompt_embeds, None, pooled_prompt_embeds, None
192
  elif isinstance(pipeline, sd_pipes):
193
+ prompt_embeds, prompt_neg_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = get_weighted_text_embeddings_sdxl(pipeline, prompt, negative_prompt, device=device)
194
  return prompt_embeds, prompt_neg_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
195
 
196
 
197
  def cleanup(pipeline, loras = None, embeddings = None):
198
  if loras:
199
+ # pipeline.disable_lora()
200
  pipeline.unload_lora_weights()
201
  if embeddings:
202
  pipeline.unload_textual_inversion()
 
210
  pipeline_args = get_pipe(request)
211
  pipeline = pipeline_args["pipeline"]
212
  try:
213
+ progress(0.3, "Getting Prompt Embeddings")
214
  positive_prompt_embeds, negative_prompt_embeds, positive_prompt_pooled, negative_prompt_pooled = get_prompt_attention(pipeline, request.prompt, request.negative_prompt)
215
 
216
+ progress(0.5, "Configuring Pipeline")
217
  # Common Args
218
  args = {
219
  'prompt_embeds': positive_prompt_embeds,