leafspark commited on
Commit
75d5350
·
verified ·
1 Parent(s): 641ee6f

feat(model): load model from safetensors

Browse files
Files changed (1) hide show
  1. inference.py +4 -3
inference.py CHANGED
@@ -24,6 +24,7 @@ from utils.snac_utils import get_snac, generate_audio_data
24
  import whisper
25
  from tqdm import tqdm
26
  from huggingface_hub import snapshot_download
 
27
 
28
 
29
  torch.set_printoptions(sci_mode=False)
@@ -351,14 +352,14 @@ def load_model(ckpt_dir, device):
351
  whispermodel = whisper.load_model("small").to(device)
352
  text_tokenizer = Tokenizer(ckpt_dir)
353
  fabric = L.Fabric(devices=1, strategy="auto")
354
- config = Config.from_file(ckpt_dir + "/model_config.yaml")
355
  config.post_adapter = False
356
 
357
  with fabric.init_module(empty_init=False):
358
  model = GPT(config)
359
 
360
  model = fabric.setup(model)
361
- state_dict = lazy_load(ckpt_dir + "/lit_model.pth")
362
  model.load_state_dict(state_dict, strict=True)
363
  model.to(device).eval()
364
 
@@ -366,7 +367,7 @@ def load_model(ckpt_dir, device):
366
 
367
 
368
  def download_model(ckpt_dir):
369
- repo_id = "gpt-omni/mini-omni"
370
  snapshot_download(repo_id, local_dir=ckpt_dir, revision="main")
371
 
372
 
 
24
  import whisper
25
  from tqdm import tqdm
26
  from huggingface_hub import snapshot_download
27
+ from safetensors.torch import load_file
28
 
29
 
30
  torch.set_printoptions(sci_mode=False)
 
352
  whispermodel = whisper.load_model("small").to(device)
353
  text_tokenizer = Tokenizer(ckpt_dir)
354
  fabric = L.Fabric(devices=1, strategy="auto")
355
+ config = Config.from_file(ckpt_dir + "/config.json")
356
  config.post_adapter = False
357
 
358
  with fabric.init_module(empty_init=False):
359
  model = GPT(config)
360
 
361
  model = fabric.setup(model)
362
+ state_dict = load_file(ckpt_dir + "/lit_model.safetensors")
363
  model.load_state_dict(state_dict, strict=True)
364
  model.to(device).eval()
365
 
 
367
 
368
 
369
  def download_model(ckpt_dir):
370
+ repo_id = "leafspark/mini-omni-safetensors"
371
  snapshot_download(repo_id, local_dir=ckpt_dir, revision="main")
372
 
373