Zackh committed on
Commit
45e163c
·
1 Parent(s): ef55fce

switch to loading with from_pretrained

Browse files
Files changed (3) hide show
  1. app.py +1 -2
  2. generator.py +4 -12
  3. models.py +17 -3
app.py CHANGED
@@ -102,8 +102,7 @@ SPEAKER_PROMPTS = {
102
  }
103
 
104
  device = "cuda" if torch.cuda.is_available() else "cpu"
105
- model_path = hf_hub_download(repo_id="sesame/csm-1b", filename="ckpt.pt")
106
- generator = load_csm_1b(model_path, device)
107
 
108
 
109
  @spaces.GPU(duration=gpu_timeout)
 
102
  }
103
 
104
  device = "cuda" if torch.cuda.is_available() else "cpu"
105
+ generator = load_csm_1b(device=device)
 
106
 
107
 
108
  @spaces.GPU(duration=gpu_timeout)
generator.py CHANGED
@@ -5,7 +5,7 @@ from typing import List, Tuple
5
  import torch
6
  import torchaudio
7
  from huggingface_hub import hf_hub_download
8
- from models import Model, ModelArgs
9
  from moshi.models import loaders
10
  from tokenizers.processors import TemplateProcessing
11
  from transformers import AutoTokenizer
@@ -166,17 +166,9 @@ class Generator:
166
  return audio
167
 
168
 
169
- def load_csm_1b(ckpt_path: str = "ckpt.pt", device: str = "cuda") -> Generator:
170
- model_args = ModelArgs(
171
- backbone_flavor="llama-1B",
172
- decoder_flavor="llama-100M",
173
- text_vocab_size=128256,
174
- audio_vocab_size=2051,
175
- audio_num_codebooks=32,
176
- )
177
- model = Model(model_args).to(device=device, dtype=torch.bfloat16)
178
- state_dict = torch.load(ckpt_path)
179
- model.load_state_dict(state_dict)
180
 
181
  generator = Generator(model)
182
  return generator
 
5
  import torch
6
  import torchaudio
7
  from huggingface_hub import hf_hub_download
8
+ from models import Model
9
  from moshi.models import loaders
10
  from tokenizers.processors import TemplateProcessing
11
  from transformers import AutoTokenizer
 
166
  return audio
167
 
168
 
169
+ def load_csm_1b(device: str = "cuda") -> Generator:
170
+ model = Model.from_pretrained("sesame/csm-1b")
171
+ model.to(device=device, dtype=torch.bfloat16)
 
 
 
 
 
 
 
 
172
 
173
  generator = Generator(model)
174
  return generator
models.py CHANGED
@@ -1,8 +1,9 @@
1
- from dataclasses import dataclass
2
 
3
  import torch
4
  import torch.nn as nn
5
  import torchtune
 
6
  from torchtune.models import llama3_2
7
 
8
 
@@ -95,7 +96,20 @@ class ModelArgs:
95
  audio_num_codebooks: int
96
 
97
 
98
- class Model(nn.Module):
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  def __init__(self, args: ModelArgs):
100
  super().__init__()
101
  self.args = args
@@ -110,7 +124,7 @@ class Model(nn.Module):
110
  self.codebook0_head = nn.Linear(backbone_dim, args.audio_vocab_size, bias=False)
111
  self.audio_head = nn.Parameter(torch.empty(args.audio_num_codebooks - 1, decoder_dim, args.audio_vocab_size))
112
 
113
- def setup_caches(self, max_batch_size: int) -> torch.Tensor:
114
  """Setup KV caches and return a causal mask."""
115
  dtype = next(self.parameters()).dtype
116
  device = next(self.parameters()).device
 
1
+ from dataclasses import asdict, dataclass
2
 
3
  import torch
4
  import torch.nn as nn
5
  import torchtune
6
+ from huggingface_hub import PyTorchModelHubMixin
7
  from torchtune.models import llama3_2
8
 
9
 
 
96
  audio_num_codebooks: int
97
 
98
 
99
+ class Model(
100
+ nn.Module,
101
+ PyTorchModelHubMixin,
102
+ repo_url="https://github.com/SesameAILabs/csm",
103
+ pipeline_tag="text-to-speech",
104
+ license="apache-2.0",
105
+ coders={
106
+ # Tells the class how to serialize and deserialize config.json
107
+ ModelArgs : (
108
+ lambda x: asdict(x), # Encoder: how to convert a `ModelArgs` to a valid jsonable value?
109
+ lambda data: ModelArgs(**data), # Decoder: how to reconstruct a `ModelArgs` from a dictionary?
110
+ )
111
+ }
112
+ ):
113
  def __init__(self, args: ModelArgs):
114
  super().__init__()
115
  self.args = args
 
124
  self.codebook0_head = nn.Linear(backbone_dim, args.audio_vocab_size, bias=False)
125
  self.audio_head = nn.Parameter(torch.empty(args.audio_num_codebooks - 1, decoder_dim, args.audio_vocab_size))
126
 
127
+ def setup_caches(self, max_batch_size: int) -> None:
128
  """Setup KV caches and return a causal mask."""
129
  dtype = next(self.parameters()).dtype
130
  device = next(self.parameters()).device