Spaces:
Running on Zero

Ruurd committed on
Commit
0e3f268
·
1 Parent(s): ccc6000
Files changed (1) hide show
  1. app.py +4 -29
app.py CHANGED
@@ -8,6 +8,7 @@ from llama_diffusion_model import disable_dropout
8
  import os
9
  import importlib
10
  from huggingface_hub import hf_hub_download
 
11
 
12
  hf_token = os.getenv("HF_TOKEN")
13
 
@@ -26,42 +27,16 @@ token_probabilities = np.array([token_probs_dict[str(i)] for i in range(len(toke
26
 
27
 
28
  def load_model():
29
-
30
- # 1. Download the checkpoint
31
- checkpoint_path = hf_hub_download(
32
- repo_id="ruurd/tini_model",
33
  filename="diffusion-model.pth",
34
  token=os.getenv("HF_TOKEN")
35
  )
36
 
37
- # # 2. Prepare dynamic class loading like you did before
38
- # torch.serialization.clear_safe_globals()
39
- # unsafe_globals = torch.serialization.get_unsafe_globals_in_checkpoint(checkpoint_path)
40
- # missing_class_names = [name.split(".")[-1] for name in unsafe_globals]
41
-
42
- # safe_classes = [cls for name, cls in globals().items() if name in missing_class_names]
43
-
44
- # for class_path in unsafe_globals:
45
- # try:
46
- # module_name, class_name = class_path.rsplit(".", 1)
47
- # module = importlib.import_module(module_name)
48
- # cls = getattr(module, class_name)
49
- # safe_classes.append(cls)
50
- # except (ImportError, AttributeError) as e:
51
- # print(f"⚠️ Warning: Could not import {class_path} - {e}")
52
-
53
- # torch.serialization.add_safe_globals(safe_classes)
54
-
55
- # 3. Actually load the full model
56
- # model = torch.load(checkpoint_path, weights_only=True)
57
- model = torch.load(checkpoint_path, map_location="cuda")
58
-
59
-
60
- # 4. Final setup
61
  model = disable_dropout(model)
62
  model.to("cuda")
63
  model.eval()
64
-
65
  return model
66
 
67
 
 
8
  import os
9
  import importlib
10
  from huggingface_hub import hf_hub_download
11
+ from llama_diffusion_model import CustomTransformerModel, CustomTransformerConfig, disable_dropout
12
 
13
  hf_token = os.getenv("HF_TOKEN")
14
 
 
27
 
28
 
29
  def load_model():
30
+ ckpt_path = hf_hub_download(
31
+ repo_id="ruurd/diffusion-llama",
 
 
32
  filename="diffusion-model.pth",
33
  token=os.getenv("HF_TOKEN")
34
  )
35
 
36
+ model = torch.load(ckpt_path, map_location="cuda") # no weights_only, no globals hack
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  model = disable_dropout(model)
38
  model.to("cuda")
39
  model.eval()
 
40
  return model
41
 
42