Spaces:
Sleeping
Sleeping
alethanhson
committed on
Commit
·
86ecc51
1
Parent(s):
2f09003
fix
Browse files - generator.py +5 -23
generator.py
CHANGED
@@ -164,26 +164,8 @@ class Generator:
|
|
164 |
|
165 |
|
166 |
def load_csm_1b(device: str = "cuda") -> Generator:
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
generator = Generator(model)
|
173 |
-
return generator
|
174 |
-
except Exception as e:
|
175 |
-
# Log the error for debugging
|
176 |
-
import logging
|
177 |
-
logging.error(f"Error in standard model loading: {str(e)}")
|
178 |
-
|
179 |
-
# Try alternative loading if config is available
|
180 |
-
try:
|
181 |
-
from silentcipher import Config
|
182 |
-
model_path = "sesame/csm-1b"
|
183 |
-
config = Config.from_pretrained(model_path)
|
184 |
-
model = Model.from_pretrained(model_path, config=config)
|
185 |
-
model = model.to(device=device)
|
186 |
-
return Generator(model)
|
187 |
-
except ImportError:
|
188 |
-
# Config not available, try direct initialization
|
189 |
-
raise RuntimeError("Could not load model with either method. Original error: " + str(e))
|
|
|
164 |
|
165 |
|
166 |
def load_csm_1b(device: str = "cuda") -> Generator:
|
167 |
+
model = Model.from_pretrained("sesame/csm-1b")
|
168 |
+
model.to(device=device, dtype=torch.bfloat16)
|
169 |
+
|
170 |
+
generator = Generator(model)
|
171 |
+
return generator
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|