alethanhson commited on
Commit
2f09003
·
1 Parent(s): e0af3c6
Files changed (3) hide show
  1. app.py +4 -2
  2. app_huggingface.py +17 -4
  3. generator.py +23 -5
app.py CHANGED
@@ -8,7 +8,7 @@ import torchaudio
8
  import gradio as gr
9
  import numpy as np
10
 
11
- from generator import load_csm_1b, Segment
12
 
13
  logging.basicConfig(level=logging.INFO)
14
  logger = logging.getLogger(__name__)
@@ -25,7 +25,9 @@ def initialize_model():
25
  logger.info(f"Using device: {device}")
26
 
27
  try:
28
- generator = load_csm_1b(device=device)
 
 
29
  logger.info(f"Model loaded successfully on device: {device}")
30
  return True
31
  except Exception as e:
 
8
  import gradio as gr
9
  import numpy as np
10
 
11
+ from generator import Segment, Model, Generator
12
 
13
  logging.basicConfig(level=logging.INFO)
14
  logger = logging.getLogger(__name__)
 
25
  logger.info(f"Using device: {device}")
26
 
27
  try:
28
+ model = Model.from_pretrained("sesame/csm-1b")
29
+ model = model.to(device=device)
30
+ generator = Generator(model)
31
  logger.info(f"Model loaded successfully on device: {device}")
32
  return True
33
  except Exception as e:
app_huggingface.py CHANGED
@@ -25,7 +25,8 @@ class MockGenerator:
25
  try:
26
  import torch
27
  import torchaudio
28
- from generator import load_csm_1b, Segment
 
29
  TORCH_AVAILABLE = True
30
  except ImportError:
31
  TORCH_AVAILABLE = False
@@ -57,9 +58,21 @@ def initialize_model():
57
  logger.info(f"Using device: {device}")
58
 
59
  try:
60
- # Try to use the actual model
61
- generator = load_csm_1b(device=device)
62
- logger.info(f"Model loaded successfully on device: {device}")
 
 
 
 
 
 
 
 
 
 
 
 
63
  except Exception as e:
64
  logger.error(f"Error loading actual model: {str(e)}")
65
  # Fall back to mock generator
 
25
  try:
26
  import torch
27
  import torchaudio
28
+ # Chỉ import các thành phần cần thiết
29
+ from generator import Segment
30
  TORCH_AVAILABLE = True
31
  except ImportError:
32
  TORCH_AVAILABLE = False
 
58
  logger.info(f"Using device: {device}")
59
 
60
  try:
61
+ # Cố gắng tải model theo cách khác, không sử dụng load_csm_1b
62
+ from generator import Model, Generator
63
+ from huggingface_hub import hf_hub_download
64
+
65
+ try:
66
+ # Trực tiếp khởi tạo mô hình từ pretrained
67
+ model = Model.from_pretrained("sesame/csm-1b")
68
+ model = model.to(device=device)
69
+ generator = Generator(model)
70
+ logger.info(f"Model loaded successfully on device: {device}")
71
+ except Exception as inner_e:
72
+ logger.error(f"Error loading model directly: {str(inner_e)}")
73
+ # Nếu không thể tải trực tiếp, sử dụng generator giả
74
+ logger.warning("Falling back to mock generator")
75
+ generator = MockGenerator()
76
  except Exception as e:
77
  logger.error(f"Error loading actual model: {str(e)}")
78
  # Fall back to mock generator
generator.py CHANGED
@@ -164,8 +164,26 @@ class Generator:
164
 
165
 
166
  def load_csm_1b(device: str = "cuda") -> Generator:
167
- model = Model.from_pretrained("sesame/csm-1b")
168
- model.to(device=device, dtype=torch.bfloat16)
169
-
170
- generator = Generator(model)
171
- return generator
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
 
165
 
166
  def load_csm_1b(device: str = "cuda") -> Generator:
167
+ try:
168
+ # Try the simple approach first
169
+ model = Model.from_pretrained("sesame/csm-1b")
170
+ model = model.to(device=device)
171
+
172
+ generator = Generator(model)
173
+ return generator
174
+ except Exception as e:
175
+ # Log the error for debugging
176
+ import logging
177
+ logging.error(f"Error in standard model loading: {str(e)}")
178
+
179
+ # Try alternative loading if config is available
180
+ try:
181
+ from silentcipher import Config
182
+ model_path = "sesame/csm-1b"
183
+ config = Config.from_pretrained(model_path)
184
+ model = Model.from_pretrained(model_path, config=config)
185
+ model = model.to(device=device)
186
+ return Generator(model)
187
+ except ImportError:
188
+ # Config not available, try direct initialization
189
+ raise RuntimeError("Could not load model with either method. Original error: " + str(e))