Ruurd commited on
Commit
ccc6000
·
1 Parent(s): b7639f5

Load model directly

Browse files
Files changed (1) hide show
  1. app.py +17 -15
app.py CHANGED
@@ -34,26 +34,28 @@ def load_model():
34
  token=os.getenv("HF_TOKEN")
35
  )
36
 
37
- # 2. Prepare dynamic class loading like you did before
38
- torch.serialization.clear_safe_globals()
39
- unsafe_globals = torch.serialization.get_unsafe_globals_in_checkpoint(checkpoint_path)
40
- missing_class_names = [name.split(".")[-1] for name in unsafe_globals]
41
 
42
- safe_classes = [cls for name, cls in globals().items() if name in missing_class_names]
43
 
44
- for class_path in unsafe_globals:
45
- try:
46
- module_name, class_name = class_path.rsplit(".", 1)
47
- module = importlib.import_module(module_name)
48
- cls = getattr(module, class_name)
49
- safe_classes.append(cls)
50
- except (ImportError, AttributeError) as e:
51
- print(f"⚠️ Warning: Could not import {class_path} - {e}")
52
 
53
- torch.serialization.add_safe_globals(safe_classes)
54
 
55
  # 3. Actually load the full model
56
- model = torch.load(checkpoint_path, weights_only=True)
 
 
57
 
58
  # 4. Final setup
59
  model = disable_dropout(model)
 
34
  token=os.getenv("HF_TOKEN")
35
  )
36
 
37
+ # # 2. Prepare dynamic class loading like you did before
38
+ # torch.serialization.clear_safe_globals()
39
+ # unsafe_globals = torch.serialization.get_unsafe_globals_in_checkpoint(checkpoint_path)
40
+ # missing_class_names = [name.split(".")[-1] for name in unsafe_globals]
41
 
42
+ # safe_classes = [cls for name, cls in globals().items() if name in missing_class_names]
43
 
44
+ # for class_path in unsafe_globals:
45
+ # try:
46
+ # module_name, class_name = class_path.rsplit(".", 1)
47
+ # module = importlib.import_module(module_name)
48
+ # cls = getattr(module, class_name)
49
+ # safe_classes.append(cls)
50
+ # except (ImportError, AttributeError) as e:
51
+ # print(f"⚠️ Warning: Could not import {class_path} - {e}")
52
 
53
+ # torch.serialization.add_safe_globals(safe_classes)
54
 
55
  # 3. Actually load the full model
56
+ # model = torch.load(checkpoint_path, weights_only=True)
57
+ model = torch.load(checkpoint_path, map_location="cuda")
58
+
59
 
60
  # 4. Final setup
61
  model = disable_dropout(model)