alethanhson committed
Commit ca183c0 · 1 Parent(s): 86ecc51
Files changed (3):
  1. app.py +67 -9
  2. app_huggingface.py +131 -210
  3. generator.py +43 -5
app.py CHANGED
@@ -8,7 +8,20 @@ import torchaudio
 import gradio as gr
 import numpy as np
 
-from generator import Segment, Model, Generator
+from generator import Segment
+
+# Mock generator class used when the real model cannot be loaded
+class MockGenerator:
+    def __init__(self):
+        self.sample_rate = 24000
+        logging.info("Created mock generator with sample rate 24000")
+
+    def generate(self, text, speaker, context=None, max_audio_length_ms=10000, temperature=0.9, topk=50):
+        # Produce fake audio: just silence, with a duration proportional to the text length
+        duration_seconds = min(len(text) * 0.1, max_audio_length_ms / 1000)
+        samples = int(duration_seconds * self.sample_rate)
+        logging.info(f"Generating mock audio with {samples} samples")
+        return torch.zeros(samples, dtype=torch.float32)
 
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
@@ -25,21 +38,59 @@ def initialize_model():
     logger.info(f"Using device: {device}")
 
     try:
-        model = Model.from_pretrained("sesame/csm-1b")
-        model = model.to(device=device)
-        generator = Generator(model)
-        logger.info(f"Model loaded successfully on device: {device}")
-        return True
+        # Try loading the model via load_csm_1b
+        try:
+            from generator import load_csm_1b
+            generator = load_csm_1b(device=device)
+            logger.info("Model loaded successfully using load_csm_1b")
+            return True
+        except Exception as e:
+            logger.warning(f"Could not load model using load_csm_1b: {str(e)}")
+
+        # Try loading directly with a config
+        try:
+            from generator import Model, Generator
+            from huggingface_hub import hf_hub_download
+            import json
+
+            # Create a dummy config
+            class DummyConfig:
+                def __init__(self, **kwargs):
+                    for key, value in kwargs.items():
+                        setattr(self, key, value)
+
+            # Download the config from the HF Hub
+            config_file = hf_hub_download("sesame/csm-1b", "config.json")
+            with open(config_file, 'r') as f:
+                config_dict = json.load(f)
+
+            config = DummyConfig(**config_dict)
+            model = Model.from_pretrained("sesame/csm-1b", config=config)
+            model = model.to(device=device)
+            generator = Generator(model)
+            logger.info("Model loaded successfully using direct loading with config")
+            return True
+        except Exception as inner_e:
+            logger.error(f"Error loading model directly: {str(inner_e)}")
+
+        # Use the mock generator if the real model cannot be loaded
+        logger.warning("Using mock generator as fallback")
+        generator = MockGenerator()
+        return True
     except Exception as e:
         logger.error(f"Could not load model: {str(e)}")
-        return False
+        # Use the mock generator so the app keeps running
+        generator = MockGenerator()
+        return True
 
 def generate_speech(text, speaker_id, max_audio_length_ms=10000, temperature=0.9, topk=50, context_texts=None, context_speakers=None):
     global generator
 
     if generator is None:
         if not initialize_model():
-            return None, "Could not load model. Please try again later."
+            # Even if initialization fails, still create a mock generator
+            generator = MockGenerator()
+            logger.warning("Using mock generator as fallback")
 
     try:
         # Process context if provided
@@ -69,7 +120,14 @@ def generate_speech(text, speaker_id, max_audio_length_ms=10000, temperature=0.9
 
     except Exception as e:
         logger.error(f"Error generating audio: {str(e)}")
-        return None, f"Error generating audio: {str(e)}"
+        # On error, fall back to fake (silent) audio
+        mock_generator = MockGenerator()
+        audio = mock_generator.generate(
+            text=text,
+            speaker=int(speaker_id),
+            max_audio_length_ms=float(max_audio_length_ms)
+        )
+        return (mock_generator.sample_rate, audio.numpy()), f"Error, using silent audio: {str(e)}"
 
 def clear_context():
     return [], []
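The reworked `initialize_model()` in app.py is essentially a tiered-fallback loader: it first tries `load_csm_1b`, then direct construction with a config downloaded from the Hub, and finally drops to the silent `MockGenerator` so the Space keeps serving. A minimal, self-contained sketch of that pattern is shown below; the names `SilentGenerator`, `load_with_fallbacks`, and `broken_loader` are illustrative stand-ins, not part of this repository.

```python
# Sketch of the tiered-fallback loading pattern used in app.py's initialize_model().
# Names here (SilentGenerator, load_with_fallbacks, broken_loader) are illustrative.
import logging

import numpy as np

logging.basicConfig(level=logging.INFO)


class SilentGenerator:
    """Last-resort generator that only returns silence (same role as MockGenerator)."""

    sample_rate = 24000

    def generate(self, text, speaker=0, max_audio_length_ms=10000, **kwargs):
        duration_s = min(len(text) * 0.1, max_audio_length_ms / 1000)
        return np.zeros(int(duration_s * self.sample_rate), dtype=np.float32)


def load_with_fallbacks(loaders):
    """Try each loader in order; if every strategy fails, return the silent fallback."""
    for loader in loaders:
        try:
            return loader()
        except Exception as exc:
            logging.warning("Loader %s failed: %s", getattr(loader, "__name__", loader), exc)
    logging.warning("All loaders failed; using silent fallback generator")
    return SilentGenerator()


def broken_loader():
    # Stand-in for load_csm_1b / direct loading; always fails in this sketch.
    raise RuntimeError("pretend the checkpoint could not be loaded")


if __name__ == "__main__":
    gen = load_with_fallbacks([broken_loader])
    audio = gen.generate("Hello world", max_audio_length_ms=2000)
    print(type(gen).__name__, gen.sample_rate, audio.shape)
```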
app_huggingface.py CHANGED
@@ -2,40 +2,13 @@ import base64
 import io
 import logging
 from typing import List
-import os
-import sys
 
-import numpy as np
+import torch
+import torchaudio
 import gradio as gr
+import numpy as np
 
-# Mock class added to work around the import error
-class MockGenerator:
-    def __init__(self):
-        self.sample_rate = 24000
-        logging.info("Created mock generator with sample rate 24000")
-
-    def generate(self, text, speaker, context=None, max_audio_length_ms=10000, temperature=0.9, topk=50):
-        # Produce fake audio: just silence, with a duration proportional to the text length
-        duration_seconds = min(len(text) * 0.1, max_audio_length_ms / 1000)
-        samples = int(duration_seconds * self.sample_rate)
-        logging.info(f"Generating mock audio with {samples} samples")
-        return np.zeros(samples, dtype=np.float32)
-
-# Only perform the real imports when needed
-try:
-    import torch
-    import torchaudio
-    # Import only the required components
-    from generator import Segment
-    TORCH_AVAILABLE = True
-except ImportError:
-    TORCH_AVAILABLE = False
-    # Create a fake Segment class
-    class Segment:
-        def __init__(self, speaker, text, audio=None):
-            self.speaker = speaker
-            self.text = text
-            self.audio = audio if audio is not None else np.zeros(0, dtype=np.float32)
+from generator import Segment, Model, Generator
 
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
@@ -46,42 +19,19 @@ def initialize_model():
     global generator
     logger.info("Loading CSM 1B model...")
 
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    if device == "cpu":
+        logger.warning("GPU not available. Using CPU, performance may be slow!")
+    logger.info(f"Using device: {device}")
+
     try:
-        if not TORCH_AVAILABLE:
-            logger.warning("PyTorch is not available. Using mock generator.")
-            generator = MockGenerator()
-            return True
-
-        device = "cuda" if torch.cuda.is_available() else "cpu"
-        if device == "cpu":
-            logger.warning("GPU not available. Using CPU, performance may be slow!")
-        logger.info(f"Using device: {device}")
-
-        try:
-            # Try loading the model a different way, without load_csm_1b
-            from generator import Model, Generator
-            from huggingface_hub import hf_hub_download
-
-            try:
-                # Initialize the model directly from the pretrained checkpoint
-                model = Model.from_pretrained("sesame/csm-1b")
-                model = model.to(device=device)
-                generator = Generator(model)
-                logger.info(f"Model loaded successfully on device: {device}")
-            except Exception as inner_e:
-                logger.error(f"Error loading model directly: {str(inner_e)}")
-                # If direct loading fails, use the mock generator
-                logger.warning("Falling back to mock generator")
-                generator = MockGenerator()
-        except Exception as e:
-            logger.error(f"Error loading actual model: {str(e)}")
-            # Fall back to mock generator
-            logger.warning("Falling back to mock generator")
-            generator = MockGenerator()
-
+        model = Model.from_pretrained("sesame/csm-1b")
+        model = model.to(device=device)
+        generator = Generator(model)
+        logger.info(f"Model loaded successfully on device: {device}")
         return True
     except Exception as e:
-        logger.error(f"Could not initialize any generator: {str(e)}")
+        logger.error(f"Could not load model: {str(e)}")
         return False
 
 def generate_speech(text, speaker_id, max_audio_length_ms=10000, temperature=0.9, topk=50, context_texts=None, context_speakers=None):
@@ -97,13 +47,8 @@ def generate_speech(text, speaker_id, max_audio_length_ms=10000, temperature=0.9
     if context_texts and context_speakers:
         for ctx_text, ctx_speaker in zip(context_texts, context_speakers):
            if ctx_text and ctx_speaker is not None:
-                if TORCH_AVAILABLE:
-                    audio_tensor = torch.zeros(0, dtype=torch.float32)
-                else:
-                    audio_tensor = np.zeros(0, dtype=np.float32)
-
                 context_segments.append(
-                    Segment(text=ctx_text, speaker=int(ctx_speaker), audio=audio_tensor)
+                    Segment(text=ctx_text, speaker=int(ctx_speaker), audio=torch.zeros(0, dtype=torch.float32))
                 )
 
     # Generate audio from text
@@ -117,11 +62,7 @@ def generate_speech(text, speaker_id, max_audio_length_ms=10000, temperature=0.9
         )
 
         # Convert tensor to numpy array for Gradio
-        if TORCH_AVAILABLE and isinstance(audio, torch.Tensor):
-            audio_numpy = audio.cpu().numpy()
-        else:
-            audio_numpy = audio  # Already numpy from MockGenerator
-
+        audio_numpy = audio.cpu().numpy()
         sample_rate = generator.sample_rate
 
         return (sample_rate, audio_numpy), None
@@ -139,152 +80,132 @@ def add_context(text, speaker_id, context_texts, context_speakers):
     context_speakers.append(int(speaker_id))
     return context_texts, context_speakers
 
-def update_context_display(texts, speakers):
-    if not texts or not speakers:
-        return []
-    return [[text, speaker] for text, speaker in zip(texts, speakers)]
-
-def create_demo():
-    # Set up Gradio interface
-    demo = gr.Blocks(title="CSM 1B Demo")
-
-    with demo:
-        gr.Markdown("# CSM 1B - Conversational Speech Model")
-        gr.Markdown("Enter text to generate natural-sounding speech with the CSM 1B model")
-
-        if not TORCH_AVAILABLE:
-            gr.Markdown("⚠️ **WARNING: PyTorch is not available. Using a mock generator that produces silent audio.**")
-
-        with gr.Row():
-            with gr.Column(scale=2):
-                text_input = gr.Textbox(
-                    label="Text to convert to speech",
-                    placeholder="Enter your text here...",
-                    lines=3
-                )
-                speaker_id = gr.Slider(
-                    label="Speaker ID",
-                    minimum=0,
-                    maximum=10,
-                    step=1,
-                    value=0
-                )
-
-                with gr.Accordion("Advanced Options", open=False):
-                    max_length = gr.Slider(
-                        label="Maximum length (milliseconds)",
-                        minimum=1000,
-                        maximum=30000,
-                        step=1000,
-                        value=10000
-                    )
-                    temp = gr.Slider(
-                        label="Temperature",
-                        minimum=0.1,
-                        maximum=1.5,
-                        step=0.1,
-                        value=0.9
-                    )
-                    top_k = gr.Slider(
-                        label="Top K",
-                        minimum=10,
-                        maximum=100,
-                        step=10,
-                        value=50
-                    )
-
-                with gr.Accordion("Conversation Context", open=False):
-                    context_list = gr.State([])
-                    context_speakers_list = gr.State([])
-
-                    with gr.Row():
-                        context_text = gr.Textbox(label="Context text", lines=2)
-                        context_speaker = gr.Slider(
-                            label="Context speaker ID",
-                            minimum=0,
-                            maximum=10,
-                            step=1,
-                            value=0
-                        )
-
-                    with gr.Row():
-                        add_ctx_btn = gr.Button("Add Context")
-                        clear_ctx_btn = gr.Button("Clear All Context")
-
-                    context_display = gr.Dataframe(
-                        headers=["Text", "Speaker ID"],
-                        label="Current Context",
-                        interactive=False
-                    )
-
-                generate_btn = gr.Button("Generate Audio", variant="primary")
-
-            with gr.Column(scale=1):
-                audio_output = gr.Audio(label="Generated Audio", type="numpy")
-                error_output = gr.Textbox(label="Error Message", visible=False)
-
-        # Connect events
-        generate_btn.click(
-            fn=generate_speech,
-            inputs=[
-                text_input,
-                speaker_id,
-                max_length,
-                temp,
-                top_k,
-                context_list,
-                context_speakers_list
-            ],
-            outputs=[audio_output, error_output]
-        )
-
-        add_ctx_btn.click(
-            fn=add_context,
-            inputs=[
-                context_text,
-                context_speaker,
-                context_list,
-                context_speakers_list
-            ],
-            outputs=[context_list, context_speakers_list]
-        )
-
-        clear_ctx_btn.click(
-            fn=clear_context,
-            inputs=[],
-            outputs=[context_list, context_speakers_list]
-        )
-
-        # Update context display
-        context_list.change(
-            fn=update_context_display,
-            inputs=[context_list, context_speakers_list],
-            outputs=[context_display]
-        )
-
-        context_speakers_list.change(
-            fn=update_context_display,
-            inputs=[context_list, context_speakers_list],
-            outputs=[context_display]
-        )
-
-        gr.Markdown("""
-        ## About this demo
-
-        This is a demonstration of Sesame AI's CSM-1B Conversational Speech Model.
-
-        * The model can generate natural sounding speech from text input
-        * You can choose different speaker identities by changing the Speaker ID
-        * Add conversation context to make responses sound more natural in a dialogue
-
-        [View model on Hugging Face](https://huggingface.co/sesame/csm-1b)
-        """)
-
-    return demo
+# Set up Gradio interface
+with gr.Blocks(title="CSM 1B Demo") as demo:
+    gr.Markdown("# CSM 1B - Conversational Speech Model")
+    gr.Markdown("Enter text to generate natural-sounding speech with the CSM 1B model")
+
+    with gr.Row():
+        with gr.Column(scale=2):
+            text_input = gr.Textbox(
+                label="Text to convert to speech",
+                placeholder="Enter your text here...",
+                lines=3
+            )
+            speaker_id = gr.Slider(
+                label="Speaker ID",
+                minimum=0,
+                maximum=10,
+                step=1,
+                value=0
+            )
+
+            with gr.Accordion("Advanced Options", open=False):
+                max_length = gr.Slider(
+                    label="Maximum length (milliseconds)",
+                    minimum=1000,
+                    maximum=30000,
+                    step=1000,
+                    value=10000
+                )
+                temp = gr.Slider(
+                    label="Temperature",
+                    minimum=0.1,
+                    maximum=1.5,
+                    step=0.1,
+                    value=0.9
+                )
+                top_k = gr.Slider(
+                    label="Top K",
+                    minimum=10,
+                    maximum=100,
+                    step=10,
+                    value=50
+                )
+
+            with gr.Accordion("Conversation Context", open=False):
+                context_list = gr.State([])
+                context_speakers_list = gr.State([])
+
+                with gr.Row():
+                    context_text = gr.Textbox(label="Context text", lines=2)
+                    context_speaker = gr.Slider(
+                        label="Context speaker ID",
+                        minimum=0,
+                        maximum=10,
+                        step=1,
+                        value=0
+                    )
+
+                with gr.Row():
+                    add_ctx_btn = gr.Button("Add Context")
+                    clear_ctx_btn = gr.Button("Clear All Context")
+
+                context_display = gr.Dataframe(
+                    headers=["Text", "Speaker ID"],
+                    label="Current Context",
+                    interactive=False
+                )
+
+            generate_btn = gr.Button("Generate Audio", variant="primary")
+
+        with gr.Column(scale=1):
+            audio_output = gr.Audio(label="Generated Audio", type="numpy")
+            error_output = gr.Textbox(label="Error Message", visible=False)
+
+    # Connect events
+    generate_btn.click(
+        fn=generate_speech,
+        inputs=[
+            text_input,
+            speaker_id,
+            max_length,
+            temp,
+            top_k,
+            context_list,
+            context_speakers_list
+        ],
+        outputs=[audio_output, error_output]
+    )
+
+    add_ctx_btn.click(
+        fn=add_context,
+        inputs=[
+            context_text,
+            context_speaker,
+            context_list,
+            context_speakers_list
+        ],
+        outputs=[context_list, context_speakers_list]
+    )
+
+    clear_ctx_btn.click(
+        fn=clear_context,
+        inputs=[],
+        outputs=[context_list, context_speakers_list]
+    )
+
+    # Update context display
+    def update_context_display(texts, speakers):
+        if not texts or not speakers:
+            return []
+        return [[text, speaker] for text, speaker in zip(texts, speakers)]
+
+    context_list.change(
+        fn=update_context_display,
+        inputs=[context_list, context_speakers_list],
+        outputs=[context_display]
+    )
+
+    context_speakers_list.change(
+        fn=update_context_display,
+        inputs=[context_list, context_speakers_list],
+        outputs=[context_display]
+    )
 
 # Initialize model when page loads
 initialize_model()
 
-# Create and launch the demo
-demo = create_demo()
-demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
+# Configuration for Hugging Face Spaces
+demo.launch(share=False)
+
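The rewritten app_huggingface.py drops the TORCH_AVAILABLE/mock path, builds the Blocks UI at module level, and wires each button to a callback via `.click(fn=..., inputs=..., outputs=...)`. A stripped-down sketch of that wiring pattern follows; the sine-tone `speak()` function is a placeholder for `generator.generate()`, not the repository's model.

```python
# Minimal sketch of the Blocks wiring pattern used in app_huggingface.py:
# one click event maps UI inputs to a function and routes its outputs back.
import numpy as np
import gradio as gr

SAMPLE_RATE = 24000


def speak(text, speaker_id):
    # Placeholder "synthesis": a short sine tone whose length tracks the text length.
    duration = max(0.2, min(len(text) * 0.05, 3.0))
    t = np.linspace(0, duration, int(SAMPLE_RATE * duration), endpoint=False)
    tone = (0.2 * np.sin(2 * np.pi * (220 + 20 * int(speaker_id)) * t)).astype(np.float32)
    return (SAMPLE_RATE, tone), None


with gr.Blocks(title="Wiring sketch") as demo:
    text_input = gr.Textbox(label="Text")
    speaker = gr.Slider(0, 10, step=1, value=0, label="Speaker ID")
    audio_out = gr.Audio(label="Audio", type="numpy")
    error_out = gr.Textbox(label="Error", visible=False)
    gr.Button("Generate").click(fn=speak, inputs=[text_input, speaker],
                                outputs=[audio_out, error_out])

if __name__ == "__main__":
    demo.launch(share=False)
```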
generator.py CHANGED
@@ -164,8 +164,46 @@ class Generator:
 
 
 def load_csm_1b(device: str = "cuda") -> Generator:
-    model = Model.from_pretrained("sesame/csm-1b")
-    model.to(device=device, dtype=torch.bfloat16)
-
-    generator = Generator(model)
-    return generator
+    try:
+        # If silentcipher is installed, try loading the config from it
+        try:
+            from silentcipher import Config
+            model_path = "sesame/csm-1b"
+            config = Config.from_pretrained(model_path)
+            model = Model.from_pretrained(model_path, config=config)
+            model = model.to(device=device, dtype=torch.bfloat16)
+            generator = Generator(model)
+            return generator
+        except ImportError:
+            # If silentcipher cannot be imported, try another approach
+            pass
+
+        # Try to build the config from the pretrained model
+        import os
+        import json
+        try:
+            from huggingface_hub import hf_hub_download
+            config_file = hf_hub_download("sesame/csm-1b", "config.json")
+            with open(config_file, 'r') as f:
+                config_dict = json.load(f)
+
+            # Create a dummy config object
+            class DummyConfig:
+                def __init__(self, **kwargs):
+                    for key, value in kwargs.items():
+                        setattr(self, key, value)
+
+            config = DummyConfig(**config_dict)
+            model = Model.from_pretrained("sesame/csm-1b", config=config)
+            model = model.to(device=device, dtype=torch.bfloat16)
+            generator = Generator(model)
+            return generator
+        except Exception as e:
+            import logging
+            logging.error(f"Error loading model with config: {str(e)}")
+            raise RuntimeError(f"Could not load model: {str(e)}")
+
+    except Exception as e:
+        import logging
+        logging.error(f"Failed to load model: {str(e)}")
+        raise e
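With this change `load_csm_1b()` raises on failure instead of returning a partially initialized generator, so call sites need to catch and decide on a fallback. Below is a hypothetical call site illustrating that contract; the `generate()` keyword arguments mirror the calls made in the apps above and are assumptions, not verified against generator.py.

```python
# Hypothetical caller of the revised load_csm_1b(); error handling only, no new API.
import torch

from generator import load_csm_1b


def get_generator():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    try:
        return load_csm_1b(device=device)
    except Exception as exc:
        # Raised when silentcipher is absent and config.json cannot be fetched or applied.
        print(f"CSM-1B unavailable, continuing without TTS: {exc}")
        return None


if __name__ == "__main__":
    gen = get_generator()
    if gen is not None:
        # Keyword names follow the calls made in app.py / app_huggingface.py (assumed).
        audio = gen.generate(text="Hello from CSM", speaker=0, max_audio_length_ms=5000)
        print(f"Generated {audio.shape[0]} samples at {gen.sample_rate} Hz")
```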