Spaces:
Running on Zero

Ruurd committed on
Commit
932e0b0
·
verified ·
1 Parent(s): b36c7a9

Try different monkey-patch

Browse files
Files changed (1) hide show
  1. app.py +16 -8
app.py CHANGED
@@ -47,16 +47,23 @@ def load_model():
47
  )
48
 
49
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
50
- model = torch.load(ckpt_path, map_location=device)
51
 
52
- # 🔧 Monkey-patch for missing PEFT attribute
53
- def add_cast_input_dtype_enabled(module):
54
- for child in module.children():
55
- add_cast_input_dtype_enabled(child)
56
- if isinstance(module, torch.nn.Linear) and not hasattr(module, "cast_input_dtype_enabled"):
57
- module.cast_input_dtype_enabled = False
58
 
59
- add_cast_input_dtype_enabled(model)
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
  model = disable_dropout(model)
62
  model.to(device)
@@ -64,6 +71,7 @@ def load_model():
64
  return model
65
 
66
 
 
67
  rng = np.random.default_rng()
68
 
69
  # --- Utility Functions ---
 
47
  )
48
 
49
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
50
 
51
+ # Step 1: Create model from scratch
52
+ model = CustomTransformerModel(CustomTransformerConfig())
 
 
 
 
53
 
54
+ # Step 2: Load state_dict from full checkpoint
55
+ full_model = torch.load(ckpt_path, map_location=device)
56
+
57
+ # This handles either a full model object or just a state_dict
58
+ try:
59
+ state_dict = full_model.state_dict()
60
+ except AttributeError:
61
+ state_dict = full_model # already a state_dict
62
+
63
+ # Step 3: Load weights non-strictly; missing/unexpected keys are printed below
64
+ missing, unexpected = model.load_state_dict(state_dict, strict=False)
65
+ print("Missing keys:", missing)
66
+ print("Unexpected keys:", unexpected)
67
 
68
  model = disable_dropout(model)
69
  model.to(device)
 
71
  return model
72
 
73
 
74
+
75
  rng = np.random.default_rng()
76
 
77
  # --- Utility Functions ---