willsh1997 commited on
Commit
8d20d8f
·
verified ·
1 Parent(s): 22de30c

move everything to same device

Browse files
Files changed (1) hide show
  1. app.py +15 -14
app.py CHANGED
@@ -314,10 +314,11 @@ class customUnClipPipeline(UnCLIPImageVariationPipeline):
314
 
315
 
316
  ### ADDITIONAL PIPELINE CODE FOR KARLO
 
317
  pipe = customUnClipPipeline.from_pretrained("kakaobrain/karlo-v1-alpha-image-variations", torch_dtype=torch.float32, trust_remote_code=True,
318
- # accelerator='ort', device='cpu'
319
  )
320
- pipe.to('cuda')
321
  # pipe.enable_model_cpu_offload()
322
 
323
 
@@ -334,11 +335,11 @@ def load_img_from_URL(URL):
334
 
335
  def embed_img(input_image):
336
  tokens = pipe.feature_extractor(input_image)
337
- img_model = pipe.image_encoder.to('cpu')
338
  with torch.no_grad():
339
  embeds = img_model(torch.tensor(tokens.pixel_values[0]).unsqueeze(0))
340
 
341
- return embeds.image_embeds.to('cpu')
342
 
343
 
344
  def localimg_2_embed(image_dir):
@@ -389,16 +390,16 @@ def image_grid(imgs, rows, cols):
389
  return grid
390
 
391
 
392
- chaosclicker_willtensor = localimg_2_embed('willpaint-imgs/chaosclicker-willpaint.png').to('cpu')
393
- contentcnsr_willtensor = localimg_2_embed('willpaint-imgs/contentconnoisseur-willpaint.png').to('cpu')
394
- digdaydrmr_willtensor = localimg_2_embed('willpaint-imgs/digitaldaydreamer-willpaint.png').to('cpu')
395
- ecoexplr_willtensor = localimg_2_embed('willpaint-imgs/ecoexplorer-willpaint.png').to('cpu')
396
- fandomfox_willtensor = localimg_2_embed('willpaint-imgs/fandomfox-willpaint.png').to('cpu')
397
- mememaven_willtensor = localimg_2_embed('willpaint-imgs/mememaven-willpaint.png').to('cpu')
398
- newsnerd_willtensor = localimg_2_embed('willpaint-imgs/newnerd-willpaint.png').to('cpu')
399
- nostalgicnvgtr_willtensor = localimg_2_embed('willpaint-imgs/nostalgicnavigator-willpaint.png').to('cpu')
400
- scrollseeker_willtensor = localimg_2_embed('willpaint-imgs/scrollseeker-willpaint.png').to('cpu')
401
- trendtracker_willtensor = localimg_2_embed('willpaint-imgs/trendtracker-willpaint.png').to('cpu')
402
 
403
 
404
  will_cand_tensors = torch.cat([chaosclicker_willtensor,
 
314
 
315
 
316
  ### ADDITIONAL PIPELINE CODE FOR KARLO
317
+ torch_device = 'cuda'
318
  pipe = customUnClipPipeline.from_pretrained("kakaobrain/karlo-v1-alpha-image-variations", torch_dtype=torch.float32, trust_remote_code=True,
319
+ # accelerator='ort', device=torch_device
320
  )
321
+ pipe.to(torch_device)
322
  # pipe.enable_model_cpu_offload()
323
 
324
 
 
335
 
336
  def embed_img(input_image):
337
  tokens = pipe.feature_extractor(input_image)
338
+ img_model = pipe.image_encoder.to(torch_device)
339
  with torch.no_grad():
340
  embeds = img_model(torch.tensor(tokens.pixel_values[0]).unsqueeze(0))
341
 
342
+ return embeds.image_embeds.to(torch_device)
343
 
344
 
345
  def localimg_2_embed(image_dir):
 
390
  return grid
391
 
392
 
393
+ chaosclicker_willtensor = localimg_2_embed('willpaint-imgs/chaosclicker-willpaint.png').to(torch_device)
394
+ contentcnsr_willtensor = localimg_2_embed('willpaint-imgs/contentconnoisseur-willpaint.png').to(torch_device)
395
+ digdaydrmr_willtensor = localimg_2_embed('willpaint-imgs/digitaldaydreamer-willpaint.png').to(torch_device)
396
+ ecoexplr_willtensor = localimg_2_embed('willpaint-imgs/ecoexplorer-willpaint.png').to(torch_device)
397
+ fandomfox_willtensor = localimg_2_embed('willpaint-imgs/fandomfox-willpaint.png').to(torch_device)
398
+ mememaven_willtensor = localimg_2_embed('willpaint-imgs/mememaven-willpaint.png').to(torch_device)
399
+ newsnerd_willtensor = localimg_2_embed('willpaint-imgs/newnerd-willpaint.png').to(torch_device)
400
+ nostalgicnvgtr_willtensor = localimg_2_embed('willpaint-imgs/nostalgicnavigator-willpaint.png').to(torch_device)
401
+ scrollseeker_willtensor = localimg_2_embed('willpaint-imgs/scrollseeker-willpaint.png').to(torch_device)
402
+ trendtracker_willtensor = localimg_2_embed('willpaint-imgs/trendtracker-willpaint.png').to(torch_device)
403
 
404
 
405
  will_cand_tensors = torch.cat([chaosclicker_willtensor,