alethanhson committed on
Commit
9605f46
·
1 Parent(s): ca183c0
Files changed (5) hide show
  1. Procfile +1 -1
  2. app.py +9 -67
  3. app_huggingface.py +227 -137
  4. generator.py +5 -43
  5. setup.sh +28 -0
Procfile CHANGED
@@ -1 +1 @@
1
- web: python app_huggingface.py
 
1
+ web: bash setup.sh && python app_huggingface.py
app.py CHANGED
@@ -8,20 +8,7 @@ import torchaudio
8
  import gradio as gr
9
  import numpy as np
10
 
11
- from generator import Segment
12
-
13
- # Tạo một lớp generator giả để sử dụng khi không thể tải model thật
14
- class MockGenerator:
15
- def __init__(self):
16
- self.sample_rate = 24000
17
- logging.info("Created mock generator with sample rate 24000")
18
-
19
- def generate(self, text, speaker, context=None, max_audio_length_ms=10000, temperature=0.9, topk=50):
20
- # Tạo âm thanh giả - chỉ là silence với độ dài tỷ lệ với text
21
- duration_seconds = min(len(text) * 0.1, max_audio_length_ms / 1000)
22
- samples = int(duration_seconds * self.sample_rate)
23
- logging.info(f"Generating mock audio with {samples} samples")
24
- return torch.zeros(samples, dtype=torch.float32)
25
 
26
  logging.basicConfig(level=logging.INFO)
27
  logger = logging.getLogger(__name__)
@@ -38,59 +25,21 @@ def initialize_model():
38
  logger.info(f"Using device: {device}")
39
 
40
  try:
41
- # Thử tải mô hình qua hàm load_csm_1b
42
- try:
43
- from generator import load_csm_1b
44
- generator = load_csm_1b(device=device)
45
- logger.info("Model loaded successfully using load_csm_1b")
46
- return True
47
- except Exception as e:
48
- logger.warning(f"Could not load model using load_csm_1b: {str(e)}")
49
-
50
- # Thử tải trực tiếp với config
51
- try:
52
- from generator import Model, Generator
53
- from huggingface_hub import hf_hub_download
54
- import json
55
-
56
- # Tạo dummy config
57
- class DummyConfig:
58
- def __init__(self, **kwargs):
59
- for key, value in kwargs.items():
60
- setattr(self, key, value)
61
-
62
- # Tải config từ HF Hub
63
- config_file = hf_hub_download("sesame/csm-1b", "config.json")
64
- with open(config_file, 'r') as f:
65
- config_dict = json.load(f)
66
-
67
- config = DummyConfig(**config_dict)
68
- model = Model.from_pretrained("sesame/csm-1b", config=config)
69
- model = model.to(device=device)
70
- generator = Generator(model)
71
- logger.info("Model loaded successfully using direct loading with config")
72
- return True
73
- except Exception as inner_e:
74
- logger.error(f"Error loading model directly: {str(inner_e)}")
75
-
76
- # Sử dụng mock generator nếu không thể tải model thật
77
- logger.warning("Using mock generator as fallback")
78
- generator = MockGenerator()
79
- return True
80
  except Exception as e:
81
  logger.error(f"Could not load model: {str(e)}")
82
- # Sử dụng mock generator để ứng dụng vẫn chạy được
83
- generator = MockGenerator()
84
- return True
85
 
86
  def generate_speech(text, speaker_id, max_audio_length_ms=10000, temperature=0.9, topk=50, context_texts=None, context_speakers=None):
87
  global generator
88
 
89
  if generator is None:
90
  if not initialize_model():
91
- # Ngay cả khi không khởi tạo được, vẫn tạo một mock generator
92
- generator = MockGenerator()
93
- logger.warning("Using mock generator as fallback")
94
 
95
  try:
96
  # Process context if provided
@@ -120,14 +69,7 @@ def generate_speech(text, speaker_id, max_audio_length_ms=10000, temperature=0.9
120
 
121
  except Exception as e:
122
  logger.error(f"Error generating audio: {str(e)}")
123
- # Trong trường hợp lỗi, tạo âm thanh giả
124
- mock_generator = MockGenerator()
125
- audio = mock_generator.generate(
126
- text=text,
127
- speaker=int(speaker_id),
128
- max_audio_length_ms=float(max_audio_length_ms)
129
- )
130
- return (mock_generator.sample_rate, audio.numpy()), f"Error, using silent audio: {str(e)}"
131
 
132
  def clear_context():
133
  return [], []
 
8
  import gradio as gr
9
  import numpy as np
10
 
11
+ from generator import Segment, Model, Generator
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  logging.basicConfig(level=logging.INFO)
14
  logger = logging.getLogger(__name__)
 
25
  logger.info(f"Using device: {device}")
26
 
27
  try:
28
+ model = Model.from_pretrained("sesame/csm-1b")
29
+ model = model.to(device=device)
30
+ generator = Generator(model)
31
+ logger.info(f"Model loaded successfully on device: {device}")
32
+ return True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  except Exception as e:
34
  logger.error(f"Could not load model: {str(e)}")
35
+ return False
 
 
36
 
37
  def generate_speech(text, speaker_id, max_audio_length_ms=10000, temperature=0.9, topk=50, context_texts=None, context_speakers=None):
38
  global generator
39
 
40
  if generator is None:
41
  if not initialize_model():
42
+ return None, "Could not load model. Please try again later."
 
 
43
 
44
  try:
45
  # Process context if provided
 
69
 
70
  except Exception as e:
71
  logger.error(f"Error generating audio: {str(e)}")
72
+ return None, f"Error generating audio: {str(e)}"
 
 
 
 
 
 
 
73
 
74
  def clear_context():
75
  return [], []
app_huggingface.py CHANGED
@@ -2,13 +2,40 @@ import base64
2
  import io
3
  import logging
4
  from typing import List
 
 
5
 
6
- import torch
7
- import torchaudio
8
- import gradio as gr
9
  import numpy as np
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
- from generator import Segment, Model, Generator
 
 
 
 
 
12
 
13
  logging.basicConfig(level=logging.INFO)
14
  logger = logging.getLogger(__name__)
@@ -19,39 +46,76 @@ def initialize_model():
19
  global generator
20
  logger.info("Loading CSM 1B model...")
21
 
22
- device = "cuda" if torch.cuda.is_available() else "cpu"
23
- if device == "cpu":
24
- logger.warning("GPU not available. Using CPU, performance may be slow!")
25
- logger.info(f"Using device: {device}")
 
26
 
 
27
  try:
28
- model = Model.from_pretrained("sesame/csm-1b")
29
- model = model.to(device=device)
30
- generator = Generator(model)
31
- logger.info(f"Model loaded successfully on device: {device}")
32
- return True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  except Exception as e:
34
- logger.error(f"Could not load model: {str(e)}")
35
- return False
 
36
 
37
  def generate_speech(text, speaker_id, max_audio_length_ms=10000, temperature=0.9, topk=50, context_texts=None, context_speakers=None):
38
  global generator
39
 
40
  if generator is None:
41
  if not initialize_model():
42
- return None, "Could not load model. Please try again later."
 
43
 
44
  try:
45
- # Process context if provided
 
 
 
 
 
 
46
  context_segments = []
47
  if context_texts and context_speakers:
48
  for ctx_text, ctx_speaker in zip(context_texts, context_speakers):
49
  if ctx_text and ctx_speaker is not None:
 
 
 
 
 
 
50
  context_segments.append(
51
- Segment(text=ctx_text, speaker=int(ctx_speaker), audio=torch.zeros(0, dtype=torch.float32))
52
  )
53
 
54
- # Generate audio from text
55
  audio = generator.generate(
56
  text=text,
57
  speaker=int(speaker_id),
@@ -61,15 +125,22 @@ def generate_speech(text, speaker_id, max_audio_length_ms=10000, temperature=0.9
61
  topk=int(topk),
62
  )
63
 
64
- # Convert tensor to numpy array for Gradio
65
- audio_numpy = audio.cpu().numpy()
 
 
 
 
66
  sample_rate = generator.sample_rate
67
 
68
  return (sample_rate, audio_numpy), None
69
 
70
  except Exception as e:
71
  logger.error(f"Error generating audio: {str(e)}")
72
- return None, f"Error generating audio: {str(e)}"
 
 
 
73
 
74
  def clear_context():
75
  return [], []
@@ -80,132 +151,151 @@ def add_context(text, speaker_id, context_texts, context_speakers):
80
  context_speakers.append(int(speaker_id))
81
  return context_texts, context_speakers
82
 
83
- # Set up Gradio interface
84
- with gr.Blocks(title="CSM 1B Demo") as demo:
85
- gr.Markdown("# CSM 1B - Conversational Speech Model")
86
- gr.Markdown("Enter text to generate natural-sounding speech with the CSM 1B model")
 
 
 
 
87
 
88
- with gr.Row():
89
- with gr.Column(scale=2):
90
- text_input = gr.Textbox(
91
- label="Text to convert to speech",
92
- placeholder="Enter your text here...",
93
- lines=3
94
- )
95
- speaker_id = gr.Slider(
96
- label="Speaker ID",
97
- minimum=0,
98
- maximum=10,
99
- step=1,
100
- value=0
101
- )
102
-
103
- with gr.Accordion("Advanced Options", open=False):
104
- max_length = gr.Slider(
105
- label="Maximum length (milliseconds)",
106
- minimum=1000,
107
- maximum=30000,
108
- step=1000,
109
- value=10000
110
- )
111
- temp = gr.Slider(
112
- label="Temperature",
113
- minimum=0.1,
114
- maximum=1.5,
115
- step=0.1,
116
- value=0.9
117
  )
118
- top_k = gr.Slider(
119
- label="Top K",
120
- minimum=10,
121
- maximum=100,
122
- step=10,
123
- value=50
124
  )
125
-
126
- with gr.Accordion("Conversation Context", open=False):
127
- context_list = gr.State([])
128
- context_speakers_list = gr.State([])
129
 
130
- with gr.Row():
131
- context_text = gr.Textbox(label="Context text", lines=2)
132
- context_speaker = gr.Slider(
133
- label="Context speaker ID",
134
- minimum=0,
135
- maximum=10,
136
- step=1,
137
- value=0
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  )
139
 
140
- with gr.Row():
141
- add_ctx_btn = gr.Button("Add Context")
142
- clear_ctx_btn = gr.Button("Clear All Context")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
 
144
- context_display = gr.Dataframe(
145
- headers=["Text", "Speaker ID"],
146
- label="Current Context",
147
- interactive=False
148
- )
149
 
150
- generate_btn = gr.Button("Generate Audio", variant="primary")
 
 
151
 
152
- with gr.Column(scale=1):
153
- audio_output = gr.Audio(label="Generated Audio", type="numpy")
154
- error_output = gr.Textbox(label="Error Message", visible=False)
155
-
156
- # Connect events
157
- generate_btn.click(
158
- fn=generate_speech,
159
- inputs=[
160
- text_input,
161
- speaker_id,
162
- max_length,
163
- temp,
164
- top_k,
165
- context_list,
166
- context_speakers_list
167
- ],
168
- outputs=[audio_output, error_output]
169
- )
170
-
171
- add_ctx_btn.click(
172
- fn=add_context,
173
- inputs=[
174
- context_text,
175
- context_speaker,
176
- context_list,
177
- context_speakers_list
178
- ],
179
- outputs=[context_list, context_speakers_list]
180
- )
181
-
182
- clear_ctx_btn.click(
183
- fn=clear_context,
184
- inputs=[],
185
- outputs=[context_list, context_speakers_list]
186
- )
187
-
188
- # Update context display
189
- def update_context_display(texts, speakers):
190
- if not texts or not speakers:
191
- return []
192
- return [[text, speaker] for text, speaker in zip(texts, speakers)]
193
-
194
- context_list.change(
195
- fn=update_context_display,
196
- inputs=[context_list, context_speakers_list],
197
- outputs=[context_display]
198
- )
 
 
 
 
 
 
 
199
 
200
- context_speakers_list.change(
201
- fn=update_context_display,
202
- inputs=[context_list, context_speakers_list],
203
- outputs=[context_display]
204
- )
205
 
206
- # Initialize model when page loads
207
  initialize_model()
208
 
209
- # Configuration for Hugging Face Spaces
210
- demo.launch(share=False)
 
211
 
 
2
  import io
3
  import logging
4
  from typing import List
5
+ import os
6
+ import sys
7
 
 
 
 
8
  import numpy as np
9
+ import gradio as gr
10
+
11
+ # Import các module cần thiết
12
+ try:
13
+ import torch
14
+ import torchaudio
15
+ HAS_TORCH = True
16
+ except ImportError:
17
+ HAS_TORCH = False
18
+ logging.warning("PyTorch not available. Using mock generator.")
19
+
20
+ # Tạo lớp Mock để sử dụng khi không có PyTorch hoặc model bị lỗi
21
+ class MockGenerator:
22
+ def __init__(self):
23
+ self.sample_rate = 24000
24
+ logging.info("Created mock generator with sample rate 24000")
25
+
26
+ def generate(self, text, speaker, context=None, max_audio_length_ms=10000, temperature=0.9, topk=50):
27
+ # Tạo âm thanh giả - chỉ là silence với độ dài tỷ lệ với text
28
+ duration_seconds = min(len(text) * 0.1, max_audio_length_ms / 1000)
29
+ samples = int(duration_seconds * self.sample_rate)
30
+ logging.info(f"Generating mock audio with {samples} samples")
31
+ return np.zeros(samples, dtype=np.float32)
32
 
33
+ # Định nghĩa lớp Segment giả khi cần
34
+ class MockSegment:
35
+ def __init__(self, text, speaker, audio=None):
36
+ self.text = text
37
+ self.speaker = speaker
38
+ self.audio = audio if audio is not None else np.zeros(0, dtype=np.float32)
39
 
40
  logging.basicConfig(level=logging.INFO)
41
  logger = logging.getLogger(__name__)
 
46
  global generator
47
  logger.info("Loading CSM 1B model...")
48
 
49
+ # Nếu không PyTorch, sử dụng mock
50
+ if not HAS_TORCH:
51
+ logger.warning("PyTorch not available. Using mock generator.")
52
+ generator = MockGenerator()
53
+ return True
54
 
55
+ # Có PyTorch, thử tải model thật
56
  try:
57
+ # Kiểm tra và tải các thư viện cần thiết
58
+ import sys
59
+ # Thêm thư mục hiện tại vào PATH để đảm bảo import được các module cần thiết
60
+ if os.getcwd() not in sys.path:
61
+ sys.path.append(os.getcwd())
62
+
63
+ # Thử import từ generator module (theo hướng dẫn chính thức)
64
+ try:
65
+ from generator import load_csm_1b, Segment
66
+
67
+ device = "cuda" if torch.cuda.is_available() else "cpu"
68
+ if device == "cpu":
69
+ logger.warning("GPU not available. Using CPU, performance may be slow!")
70
+ logger.info(f"Using device: {device}")
71
+
72
+ # Tải model theo cách chính thức
73
+ generator = load_csm_1b(device=device)
74
+ logger.info(f"Model loaded successfully on device: {device}")
75
+ return True
76
+ except Exception as e:
77
+ logger.error(f"Error loading model: {str(e)}")
78
+ # Tải mock generator trong trường hợp lỗi
79
+ logger.warning("Falling back to mock generator")
80
+ generator = MockGenerator()
81
+ return True
82
+
83
  except Exception as e:
84
+ logger.error(f"Critical error: {str(e)}")
85
+ generator = MockGenerator()
86
+ return True
87
 
88
  def generate_speech(text, speaker_id, max_audio_length_ms=10000, temperature=0.9, topk=50, context_texts=None, context_speakers=None):
89
  global generator
90
 
91
  if generator is None:
92
  if not initialize_model():
93
+ # Sử dụng mock generator nếu không khởi tạo được
94
+ generator = MockGenerator()
95
 
96
  try:
97
+ # Xác định Segment class để sử dụng
98
+ try:
99
+ from generator import Segment
100
+ except ImportError:
101
+ Segment = MockSegment
102
+
103
+ # Xử lý context nếu có
104
  context_segments = []
105
  if context_texts and context_speakers:
106
  for ctx_text, ctx_speaker in zip(context_texts, context_speakers):
107
  if ctx_text and ctx_speaker is not None:
108
+ # Tạo audio tensor rỗng cho context
109
+ if HAS_TORCH:
110
+ audio_tensor = torch.zeros(0, dtype=torch.float32)
111
+ else:
112
+ audio_tensor = np.zeros(0, dtype=np.float32)
113
+
114
  context_segments.append(
115
+ Segment(text=ctx_text, speaker=int(ctx_speaker), audio=audio_tensor)
116
  )
117
 
118
+ # Generate audio từ text
119
  audio = generator.generate(
120
  text=text,
121
  speaker=int(speaker_id),
 
125
  topk=int(topk),
126
  )
127
 
128
+ # Chuyển đổi tensor sang numpy array cho Gradio
129
+ if HAS_TORCH and isinstance(audio, torch.Tensor):
130
+ audio_numpy = audio.cpu().numpy()
131
+ else:
132
+ audio_numpy = audio # Đã là numpy từ MockGenerator
133
+
134
  sample_rate = generator.sample_rate
135
 
136
  return (sample_rate, audio_numpy), None
137
 
138
  except Exception as e:
139
  logger.error(f"Error generating audio: {str(e)}")
140
+ # Sử dụng mock generator trong trường hợp lỗi
141
+ mock_gen = MockGenerator()
142
+ audio = mock_gen.generate(text=text, speaker=int(speaker_id), max_audio_length_ms=float(max_audio_length_ms))
143
+ return (mock_gen.sample_rate, audio), f"Error generating audio, using silent audio: {str(e)}"
144
 
145
  def clear_context():
146
  return [], []
 
151
  context_speakers.append(int(speaker_id))
152
  return context_texts, context_speakers
153
 
154
+ def update_context_display(texts, speakers):
155
+ if not texts or not speakers:
156
+ return []
157
+ return [[text, speaker] for text, speaker in zip(texts, speakers)]
158
+
159
+ def create_demo():
160
+ # Set up Gradio interface
161
+ demo = gr.Blocks(title="CSM 1B Demo")
162
 
163
+ with demo:
164
+ gr.Markdown("# CSM 1B - Conversational Speech Model")
165
+ gr.Markdown("Enter text to generate natural-sounding speech with the CSM 1B model")
166
+
167
+ if not HAS_TORCH:
168
+ gr.Markdown("⚠️ **WARNING: PyTorch is not available. Using a mock generator that produces silent audio.**")
169
+
170
+ with gr.Row():
171
+ with gr.Column(scale=2):
172
+ text_input = gr.Textbox(
173
+ label="Text to convert to speech",
174
+ placeholder="Enter your text here...",
175
+ lines=3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
  )
177
+ speaker_id = gr.Slider(
178
+ label="Speaker ID",
179
+ minimum=0,
180
+ maximum=10,
181
+ step=1,
182
+ value=0
183
  )
 
 
 
 
184
 
185
+ with gr.Accordion("Advanced Options", open=False):
186
+ max_length = gr.Slider(
187
+ label="Maximum length (milliseconds)",
188
+ minimum=1000,
189
+ maximum=30000,
190
+ step=1000,
191
+ value=10000
192
+ )
193
+ temp = gr.Slider(
194
+ label="Temperature",
195
+ minimum=0.1,
196
+ maximum=1.5,
197
+ step=0.1,
198
+ value=0.9
199
+ )
200
+ top_k = gr.Slider(
201
+ label="Top K",
202
+ minimum=10,
203
+ maximum=100,
204
+ step=10,
205
+ value=50
206
  )
207
 
208
+ with gr.Accordion("Conversation Context", open=False):
209
+ context_list = gr.State([])
210
+ context_speakers_list = gr.State([])
211
+
212
+ with gr.Row():
213
+ context_text = gr.Textbox(label="Context text", lines=2)
214
+ context_speaker = gr.Slider(
215
+ label="Context speaker ID",
216
+ minimum=0,
217
+ maximum=10,
218
+ step=1,
219
+ value=0
220
+ )
221
+
222
+ with gr.Row():
223
+ add_ctx_btn = gr.Button("Add Context")
224
+ clear_ctx_btn = gr.Button("Clear All Context")
225
+
226
+ context_display = gr.Dataframe(
227
+ headers=["Text", "Speaker ID"],
228
+ label="Current Context",
229
+ interactive=False
230
+ )
231
 
232
+ generate_btn = gr.Button("Generate Audio", variant="primary")
 
 
 
 
233
 
234
+ with gr.Column(scale=1):
235
+ audio_output = gr.Audio(label="Generated Audio", type="numpy")
236
+ error_output = gr.Textbox(label="Error Message", visible=False)
237
 
238
+ # Connect events
239
+ generate_btn.click(
240
+ fn=generate_speech,
241
+ inputs=[
242
+ text_input,
243
+ speaker_id,
244
+ max_length,
245
+ temp,
246
+ top_k,
247
+ context_list,
248
+ context_speakers_list
249
+ ],
250
+ outputs=[audio_output, error_output]
251
+ )
252
+
253
+ add_ctx_btn.click(
254
+ fn=add_context,
255
+ inputs=[
256
+ context_text,
257
+ context_speaker,
258
+ context_list,
259
+ context_speakers_list
260
+ ],
261
+ outputs=[context_list, context_speakers_list]
262
+ ).then(
263
+ fn=update_context_display,
264
+ inputs=[context_list, context_speakers_list],
265
+ outputs=[context_display]
266
+ )
267
+
268
+ clear_ctx_btn.click(
269
+ fn=clear_context,
270
+ inputs=[],
271
+ outputs=[context_list, context_speakers_list]
272
+ ).then(
273
+ fn=lambda: [],
274
+ inputs=[],
275
+ outputs=[context_display]
276
+ )
277
+
278
+ gr.Markdown("""
279
+ ## About CSM-1B
280
+
281
+ CSM (Conversational Speech Model) is a speech generation model from Sesame that generates audio from text inputs.
282
+ The model can generate a variety of voices and works best when provided with conversational context.
283
+
284
+ ### Features:
285
+ - Generate natural-sounding speech from text
286
+ - Choose different speaker identities (0-10)
287
+ - Adjust temperature to control output variability
288
+ - Add conversation context for more natural responses
289
+
290
+ [View on Hugging Face](https://huggingface.co/sesame/csm-1b) | [GitHub Repository](https://github.com/SesameAILabs/csm)
291
+ """)
292
 
293
+ return demo
 
 
 
 
294
 
295
+ # Khởi tạo model
296
  initialize_model()
297
 
298
+ # Tạo khởi chạy demo
299
+ demo = create_demo()
300
+ demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
301
 
generator.py CHANGED
@@ -164,46 +164,8 @@ class Generator:
164
 
165
 
166
  def load_csm_1b(device: str = "cuda") -> Generator:
167
- try:
168
- # Nếu silentcipher được cài đặt, thử tải config từ đó
169
- try:
170
- from silentcipher import Config
171
- model_path = "sesame/csm-1b"
172
- config = Config.from_pretrained(model_path)
173
- model = Model.from_pretrained(model_path, config=config)
174
- model = model.to(device=device, dtype=torch.bfloat16)
175
- generator = Generator(model)
176
- return generator
177
- except ImportError:
178
- # Nếu không thể import silentcipher, thử cách khác
179
- pass
180
-
181
- # Cố gắng tạo config từ pretrained model
182
- import os
183
- import json
184
- try:
185
- from huggingface_hub import hf_hub_download
186
- config_file = hf_hub_download("sesame/csm-1b", "config.json")
187
- with open(config_file, 'r') as f:
188
- config_dict = json.load(f)
189
-
190
- # Tạo dummy config object
191
- class DummyConfig:
192
- def __init__(self, **kwargs):
193
- for key, value in kwargs.items():
194
- setattr(self, key, value)
195
-
196
- config = DummyConfig(**config_dict)
197
- model = Model.from_pretrained("sesame/csm-1b", config=config)
198
- model = model.to(device=device, dtype=torch.bfloat16)
199
- generator = Generator(model)
200
- return generator
201
- except Exception as e:
202
- import logging
203
- logging.error(f"Error loading model with config: {str(e)}")
204
- raise RuntimeError(f"Could not load model: {str(e)}")
205
-
206
- except Exception as e:
207
- import logging
208
- logging.error(f"Failed to load model: {str(e)}")
209
- raise e
 
164
 
165
 
166
  def load_csm_1b(device: str = "cuda") -> Generator:
167
+ model = Model.from_pretrained("sesame/csm-1b")
168
+ model.to(device=device, dtype=torch.bfloat16)
169
+
170
+ generator = Generator(model)
171
+ return generator
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
setup.sh ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ # Log in to Hugging Face to access model
4
+ echo "Logging in to Hugging Face..."
5
+ if [ -n "$HF_TOKEN" ]; then
6
+ echo "Using provided HF_TOKEN"
7
+ huggingface-cli login --token $HF_TOKEN
8
+ else
9
+ echo "No HF_TOKEN provided, trying to use cached credentials"
10
+ fi
11
+
12
+ # Clone repository if needed
13
+ if [ ! -d "./csm" ]; then
14
+ echo "Cloning CSM repository..."
15
+ git clone https://github.com/SesameAILabs/csm.git
16
+ cd csm
17
+ # Copy files back to parent directory
18
+ cp -r generator.py models.py watermarking.py ../
19
+ cd ..
20
+ else
21
+ echo "CSM repository already exists"
22
+ fi
23
+
24
+ # Install additional dependencies
25
+ echo "Installing additional dependencies..."
26
+ pip install -q git+https://github.com/SesameAILabs/csm.git
27
+
28
+ echo "Setup complete! Ready to start the application."