chenjoya commited on
Commit
821526b
·
verified ·
1 Parent(s): fddefb0

Update demo/infer.py

Browse files
Files changed (1) hide show
  1. demo/infer.py +1 -1
demo/infer.py CHANGED
@@ -206,7 +206,7 @@ class LiveCCDemoInfer:
206
  return_tensors="pt",
207
  return_attention_mask=False
208
  )
209
- inputs.to('cuda')
210
  if past_ids is not None:
211
  inputs['input_ids'] = torch.cat([past_ids, inputs.input_ids], dim=1)
212
  outputs = self.model.generate(
 
206
  return_tensors="pt",
207
  return_attention_mask=False
208
  )
209
+ inputs.to(self.model.device)
210
  if past_ids is not None:
211
  inputs['input_ids'] = torch.cat([past_ids, inputs.input_ids], dim=1)
212
  outputs = self.model.generate(