Spaces:
Sleeping
Sleeping
alethanhson
commited on
Commit
·
2f09003
1
Parent(s):
e0af3c6
fix
Browse files- app.py +4 -2
- app_huggingface.py +17 -4
- generator.py +23 -5
app.py
CHANGED
@@ -8,7 +8,7 @@ import torchaudio
|
|
8 |
import gradio as gr
|
9 |
import numpy as np
|
10 |
|
11 |
-
from generator import
|
12 |
|
13 |
logging.basicConfig(level=logging.INFO)
|
14 |
logger = logging.getLogger(__name__)
|
@@ -25,7 +25,9 @@ def initialize_model():
|
|
25 |
logger.info(f"Using device: {device}")
|
26 |
|
27 |
try:
|
28 |
-
|
|
|
|
|
29 |
logger.info(f"Model loaded successfully on device: {device}")
|
30 |
return True
|
31 |
except Exception as e:
|
|
|
8 |
import gradio as gr
|
9 |
import numpy as np
|
10 |
|
11 |
+
from generator import Segment, Model, Generator
|
12 |
|
13 |
logging.basicConfig(level=logging.INFO)
|
14 |
logger = logging.getLogger(__name__)
|
|
|
25 |
logger.info(f"Using device: {device}")
|
26 |
|
27 |
try:
|
28 |
+
model = Model.from_pretrained("sesame/csm-1b")
|
29 |
+
model = model.to(device=device)
|
30 |
+
generator = Generator(model)
|
31 |
logger.info(f"Model loaded successfully on device: {device}")
|
32 |
return True
|
33 |
except Exception as e:
|
app_huggingface.py
CHANGED
@@ -25,7 +25,8 @@ class MockGenerator:
|
|
25 |
try:
|
26 |
import torch
|
27 |
import torchaudio
|
28 |
-
|
|
|
29 |
TORCH_AVAILABLE = True
|
30 |
except ImportError:
|
31 |
TORCH_AVAILABLE = False
|
@@ -57,9 +58,21 @@ def initialize_model():
|
|
57 |
logger.info(f"Using device: {device}")
|
58 |
|
59 |
try:
|
60 |
-
#
|
61 |
-
generator
|
62 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
except Exception as e:
|
64 |
logger.error(f"Error loading actual model: {str(e)}")
|
65 |
# Fall back to mock generator
|
|
|
25 |
try:
|
26 |
import torch
|
27 |
import torchaudio
|
28 |
+
# Chỉ import các thành phần cần thiết
|
29 |
+
from generator import Segment
|
30 |
TORCH_AVAILABLE = True
|
31 |
except ImportError:
|
32 |
TORCH_AVAILABLE = False
|
|
|
58 |
logger.info(f"Using device: {device}")
|
59 |
|
60 |
try:
|
61 |
+
# Cố gắng tải model theo cách khác, không sử dụng load_csm_1b
|
62 |
+
from generator import Model, Generator
|
63 |
+
from huggingface_hub import hf_hub_download
|
64 |
+
|
65 |
+
try:
|
66 |
+
# Trực tiếp khởi tạo mô hình từ pretrained
|
67 |
+
model = Model.from_pretrained("sesame/csm-1b")
|
68 |
+
model = model.to(device=device)
|
69 |
+
generator = Generator(model)
|
70 |
+
logger.info(f"Model loaded successfully on device: {device}")
|
71 |
+
except Exception as inner_e:
|
72 |
+
logger.error(f"Error loading model directly: {str(inner_e)}")
|
73 |
+
# Nếu không thể tải trực tiếp, sử dụng generator giả
|
74 |
+
logger.warning("Falling back to mock generator")
|
75 |
+
generator = MockGenerator()
|
76 |
except Exception as e:
|
77 |
logger.error(f"Error loading actual model: {str(e)}")
|
78 |
# Fall back to mock generator
|
generator.py
CHANGED
@@ -164,8 +164,26 @@ class Generator:
|
|
164 |
|
165 |
|
166 |
def load_csm_1b(device: str = "cuda") -> Generator:
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
164 |
|
165 |
|
166 |
def load_csm_1b(device: str = "cuda") -> Generator:
|
167 |
+
try:
|
168 |
+
# Try the simple approach first
|
169 |
+
model = Model.from_pretrained("sesame/csm-1b")
|
170 |
+
model = model.to(device=device)
|
171 |
+
|
172 |
+
generator = Generator(model)
|
173 |
+
return generator
|
174 |
+
except Exception as e:
|
175 |
+
# Log the error for debugging
|
176 |
+
import logging
|
177 |
+
logging.error(f"Error in standard model loading: {str(e)}")
|
178 |
+
|
179 |
+
# Try alternative loading if config is available
|
180 |
+
try:
|
181 |
+
from silentcipher import Config
|
182 |
+
model_path = "sesame/csm-1b"
|
183 |
+
config = Config.from_pretrained(model_path)
|
184 |
+
model = Model.from_pretrained(model_path, config=config)
|
185 |
+
model = model.to(device=device)
|
186 |
+
return Generator(model)
|
187 |
+
except ImportError:
|
188 |
+
# Config not available, try direct initialization
|
189 |
+
raise RuntimeError("Could not load model with either method. Original error: " + str(e))
|