tnt306 commited on
Commit
6bce319
·
1 Parent(s): c9a8555

Added map_location

Browse files
Files changed (1) hide show
  1. app.py +1 -1
app.py CHANGED
@@ -20,7 +20,7 @@ D_LABITEMS = pd.read_csv("D_LABITEMS.csv", header = "infer", sep = ",", encoding
20
 
21
  def load_model():
22
  path = r"final_model.pt"
23
- kwargs, state = torch.load(path, weights_only=False)
24
  model = VariationalGNN(**kwargs).to(device)
25
  model.load_state_dict(state)
26
  return model
 
20
 
21
  def load_model():
22
  path = r"final_model.pt"
23
+ kwargs, state = torch.load(path, weights_only=False, map_location=device)
24
  model = VariationalGNN(**kwargs).to(device)
25
  model.load_state_dict(state)
26
  return model