Peijie commited on
Commit
11e3b9b
·
1 Parent(s): 81867df

directly load model weight to gpu

Browse files
Files changed (1) hide show
  1. utils/load_model.py +1 -1
utils/load_model.py CHANGED
@@ -34,7 +34,7 @@ def load_xclip(device: str = "cuda:0",
34
  "loss_sym_box_label": 0, "loss_xclip": 0}
35
  model = OwlViTForClassification(owlvit_det_model=owlvit_det_model, num_classes=n_classes, device=device, weight_dict=weight_dict, logits_from_teacher=use_teacher_logits, custom_box_head=custom_box_head)
36
  if model_path is not None:
37
- ckpt = torch.load(model_path, map_location='cpu')
38
  model.load_state_dict(ckpt, strict=False)
39
  model.to(device)
40
  return model, owlvit_det_processor
 
34
  "loss_sym_box_label": 0, "loss_xclip": 0}
35
  model = OwlViTForClassification(owlvit_det_model=owlvit_det_model, num_classes=n_classes, device=device, weight_dict=weight_dict, logits_from_teacher=use_teacher_logits, custom_box_head=custom_box_head)
36
  if model_path is not None:
37
+ ckpt = torch.load(model_path, map_location=device, weights_only=True)
38
  model.load_state_dict(ckpt, strict=False)
39
  model.to(device)
40
  return model, owlvit_det_processor