Ruurd commited on
Commit
b36c7a9
·
verified ·
1 Parent(s): 526493a

Monkey-patch (temporarily) for LoRA layers

Browse files
Files changed (1) hide show
  1. app.py +24 -0
app.py CHANGED
@@ -25,6 +25,20 @@ with open("token_probabilities.json") as f:
25
  token_probs_dict = json.load(f)
26
  token_probabilities = np.array([token_probs_dict[str(i)] for i in range(len(token_probs_dict))], dtype=np.float32)
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  def load_model():
29
  ckpt_path = hf_hub_download(
30
  repo_id="ruurd/tini_bi_m",
@@ -34,6 +48,16 @@ def load_model():
34
 
35
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
36
  model = torch.load(ckpt_path, map_location=device)
 
 
 
 
 
 
 
 
 
 
37
  model = disable_dropout(model)
38
  model.to(device)
39
  model.eval()
 
25
  token_probs_dict = json.load(f)
26
  token_probabilities = np.array([token_probs_dict[str(i)] for i in range(len(token_probs_dict))], dtype=np.float32)
27
 
28
+ # def load_model():
29
+ # ckpt_path = hf_hub_download(
30
+ # repo_id="ruurd/tini_bi_m",
31
+ # filename="diffusion-model.pth",
32
+ # token=os.getenv("HF_TOKEN")
33
+ # )
34
+
35
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
36
+ # model = torch.load(ckpt_path, map_location=device)
37
+ # model = disable_dropout(model)
38
+ # model.to(device)
39
+ # model.eval()
40
+ # return model
41
+
42
  def load_model():
43
  ckpt_path = hf_hub_download(
44
  repo_id="ruurd/tini_bi_m",
 
48
 
49
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
50
  model = torch.load(ckpt_path, map_location=device)
51
+
52
+ # 🔧 Monkey-patch for missing PEFT attribute
53
+ def add_cast_input_dtype_enabled(module):
54
+ for child in module.children():
55
+ add_cast_input_dtype_enabled(child)
56
+ if isinstance(module, torch.nn.Linear) and not hasattr(module, "cast_input_dtype_enabled"):
57
+ module.cast_input_dtype_enabled = False
58
+
59
+ add_cast_input_dtype_enabled(model)
60
+
61
  model = disable_dropout(model)
62
  model.to(device)
63
  model.eval()