alethanhson committed on
Commit
86ecc51
·
1 Parent(s): 2f09003
Files changed (1) hide show
  1. generator.py +5 -23
generator.py CHANGED
@@ -164,26 +164,8 @@ class Generator:
164
 
165
 
166
  def load_csm_1b(device: str = "cuda") -> Generator:
167
- try:
168
- # Try the simple approach first
169
- model = Model.from_pretrained("sesame/csm-1b")
170
- model = model.to(device=device)
171
-
172
- generator = Generator(model)
173
- return generator
174
- except Exception as e:
175
- # Log the error for debugging
176
- import logging
177
- logging.error(f"Error in standard model loading: {str(e)}")
178
-
179
- # Try alternative loading if config is available
180
- try:
181
- from silentcipher import Config
182
- model_path = "sesame/csm-1b"
183
- config = Config.from_pretrained(model_path)
184
- model = Model.from_pretrained(model_path, config=config)
185
- model = model.to(device=device)
186
- return Generator(model)
187
- except ImportError:
188
- # Config not available, try direct initialization
189
- raise RuntimeError("Could not load model with either method. Original error: " + str(e))
 
164
 
165
 
166
  def load_csm_1b(device: str = "cuda") -> Generator:
167
+ model = Model.from_pretrained("sesame/csm-1b")
168
+ model.to(device=device, dtype=torch.bfloat16)
169
+
170
+ generator = Generator(model)
171
+ return generator