drhead commited on
Commit
eb6df80
·
1 Parent(s): e6d942c

use hf hub download

Browse files
Files changed (1) hide show
  1. app.py +4 -1
app.py CHANGED
@@ -10,6 +10,7 @@ import torch
10
  from torchvision.transforms import transforms
11
  from torchvision.transforms import InterpolationMode
12
  import torchvision.transforms.functional as TF
 
13
 
14
  torch.set_grad_enabled(False)
15
 
@@ -137,7 +138,9 @@ class GatedHead(torch.nn.Module):
137
 
138
  model.head = GatedHead(min(model.head.weight.shape), 9083)
139
 
140
- safetensors.torch.load_model(model, "JTP_PILOT2-2-e3-vit_so400m_patch14_siglip_384.safetensors")
 
 
141
  model.eval()
142
 
143
  with open("tagger_tags.json", "r") as file:
 
10
  from torchvision.transforms import transforms
11
  from torchvision.transforms import InterpolationMode
12
  import torchvision.transforms.functional as TF
13
+ from huggingface_hub import hf_hub_download
14
 
15
  torch.set_grad_enabled(False)
16
 
 
138
 
139
  model.head = GatedHead(min(model.head.weight.shape), 9083)
140
 
141
+ hf_hub_download(repo_id="RedRocket/JointTaggerProject", subfolder="JTP_PILOT2", filename="JTP_PILOT2-e3-vit_so400m_patch14_siglip_384.safetensors")
142
+
143
+ safetensors.torch.load_model(model, "JTP_PILOT2-e3-vit_so400m_patch14_siglip_384.safetensors")
144
  model.eval()
145
 
146
  with open("tagger_tags.json", "r") as file: