WYBar commited on
Commit
e06d242
·
1 Parent(s): d9bc274

add device

Browse files
Files changed (2) hide show
  1. app.py +2 -0
  2. app_test.py +2 -0
app.py CHANGED
@@ -330,6 +330,8 @@ def construction_layout():
330
  print("after .to(device)")
331
  model = model.bfloat16()
332
  model.eval()
 
 
333
  return model, quantizer, tokenizer, params_dict["width"], params_dict["height"], device
334
 
335
  @torch.no_grad()
 
330
  print("after .to(device)")
331
  model = model.bfloat16()
332
  model.eval()
333
+ quantizer = quantizer.to("cuda")
334
+ tokenizer = tokenizer.to("cuda")
335
  return model, quantizer, tokenizer, params_dict["width"], params_dict["height"], device
336
 
337
  @torch.no_grad()
app_test.py CHANGED
@@ -330,6 +330,8 @@ def construction_layout():
330
  print("after .to(device)")
331
  model = model.bfloat16()
332
  model.eval()
 
 
333
  return model, quantizer, tokenizer, params_dict["width"], params_dict["height"], device
334
 
335
  @torch.no_grad()
 
330
  print("after .to(device)")
331
  model = model.bfloat16()
332
  model.eval()
333
+ quantizer = quantizer.to("cuda")
334
+ tokenizer = tokenizer.to("cuda")
335
  return model, quantizer, tokenizer, params_dict["width"], params_dict["height"], device
336
 
337
  @torch.no_grad()