Dakerqi commited on
Commit
efcd292
·
verified ·
1 Parent(s): 0aadd5a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -5
app.py CHANGED
@@ -63,7 +63,6 @@ def encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompt
63
  elif isinstance(caption, str):
64
  captions.append(caption)
65
  elif isinstance(caption, (list, np.ndarray)):
66
- # take a random caption if there are multiple
67
  captions.append(random.choice(caption) if is_train else caption[0])
68
 
69
  with torch.no_grad():
@@ -76,12 +75,14 @@ def encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompt
76
  return_tensors="pt",
77
  )
78
 
79
- text_input_ids = text_inputs.input_ids
80
- prompt_masks = text_inputs.attention_mask
 
 
81
 
82
  prompt_embeds = text_encoder(
83
- input_ids=text_input_ids.cuda(),
84
- attention_mask=prompt_masks.cuda(),
85
  output_hidden_states=True,
86
  ).hidden_states[-2]
87
 
 
63
  elif isinstance(caption, str):
64
  captions.append(caption)
65
  elif isinstance(caption, (list, np.ndarray)):
 
66
  captions.append(random.choice(caption) if is_train else caption[0])
67
 
68
  with torch.no_grad():
 
75
  return_tensors="pt",
76
  )
77
 
78
+
79
+ device = text_encoder.device
80
+ text_input_ids = text_inputs.input_ids.to(device)
81
+ prompt_masks = text_inputs.attention_mask.to(device)
82
 
83
  prompt_embeds = text_encoder(
84
+ input_ids=text_input_ids,
85
+ attention_mask=prompt_masks,
86
  output_hidden_states=True,
87
  ).hidden_states[-2]
88