Spaces:
Runtime error
Runtime error
feat(model): load model from safetensors
Browse files- inference.py +4 -3
inference.py
CHANGED
@@ -24,6 +24,7 @@ from utils.snac_utils import get_snac, generate_audio_data
|
|
24 |
import whisper
|
25 |
from tqdm import tqdm
|
26 |
from huggingface_hub import snapshot_download
|
|
|
27 |
|
28 |
|
29 |
torch.set_printoptions(sci_mode=False)
|
@@ -351,14 +352,14 @@ def load_model(ckpt_dir, device):
|
|
351 |
whispermodel = whisper.load_model("small").to(device)
|
352 |
text_tokenizer = Tokenizer(ckpt_dir)
|
353 |
fabric = L.Fabric(devices=1, strategy="auto")
|
354 |
-
config = Config.from_file(ckpt_dir + "/
|
355 |
config.post_adapter = False
|
356 |
|
357 |
with fabric.init_module(empty_init=False):
|
358 |
model = GPT(config)
|
359 |
|
360 |
model = fabric.setup(model)
|
361 |
-
state_dict =
|
362 |
model.load_state_dict(state_dict, strict=True)
|
363 |
model.to(device).eval()
|
364 |
|
@@ -366,7 +367,7 @@ def load_model(ckpt_dir, device):
|
|
366 |
|
367 |
|
368 |
def download_model(ckpt_dir):
|
369 |
-
repo_id = "
|
370 |
snapshot_download(repo_id, local_dir=ckpt_dir, revision="main")
|
371 |
|
372 |
|
|
|
24 |
import whisper
|
25 |
from tqdm import tqdm
|
26 |
from huggingface_hub import snapshot_download
|
27 |
+
from safetensors.torch import load_file
|
28 |
|
29 |
|
30 |
torch.set_printoptions(sci_mode=False)
|
|
|
352 |
whispermodel = whisper.load_model("small").to(device)
|
353 |
text_tokenizer = Tokenizer(ckpt_dir)
|
354 |
fabric = L.Fabric(devices=1, strategy="auto")
|
355 |
+
config = Config.from_file(ckpt_dir + "/config.json")
|
356 |
config.post_adapter = False
|
357 |
|
358 |
with fabric.init_module(empty_init=False):
|
359 |
model = GPT(config)
|
360 |
|
361 |
model = fabric.setup(model)
|
362 |
+
state_dict = load_file(ckpt_dir + "/lit_model.safetensors")
|
363 |
model.load_state_dict(state_dict, strict=True)
|
364 |
model.to(device).eval()
|
365 |
|
|
|
367 |
|
368 |
|
369 |
def download_model(ckpt_dir):
|
370 |
+
repo_id = "leafspark/mini-omni-safetensors"
|
371 |
snapshot_download(repo_id, local_dir=ckpt_dir, revision="main")
|
372 |
|
373 |
|