Spaces:
Sleeping
Sleeping
Create run_csm.py
Browse files- run_csm.py +117 -0
run_csm.py
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import torchaudio
|
4 |
+
from huggingface_hub import hf_hub_download
|
5 |
+
from generator import load_csm_1b, Segment
|
6 |
+
from dataclasses import dataclass
|
7 |
+
|
8 |
+
# Disable Triton compilation
|
9 |
+
os.environ["NO_TORCH_COMPILE"] = "1"
|
10 |
+
|
11 |
+
# Default prompts are available at https://hf.co/sesame/csm-1b
|
12 |
+
prompt_filepath_conversational_a = hf_hub_download(
|
13 |
+
repo_id="sesame/csm-1b",
|
14 |
+
filename="prompts/conversational_a.wav"
|
15 |
+
)
|
16 |
+
prompt_filepath_conversational_b = hf_hub_download(
|
17 |
+
repo_id="sesame/csm-1b",
|
18 |
+
filename="prompts/conversational_b.wav"
|
19 |
+
)
|
20 |
+
|
21 |
+
SPEAKER_PROMPTS = {
|
22 |
+
"conversational_a": {
|
23 |
+
"text": (
|
24 |
+
"like revising for an exam I'd have to try and like keep up the momentum because I'd "
|
25 |
+
"start really early I'd be like okay I'm gonna start revising now and then like "
|
26 |
+
"you're revising for ages and then I just like start losing steam I didn't do that "
|
27 |
+
"for the exam we had recently to be fair that was a more of a last minute scenario "
|
28 |
+
"but like yeah I'm trying to like yeah I noticed this yesterday that like Mondays I "
|
29 |
+
"sort of start the day with this not like a panic but like a"
|
30 |
+
),
|
31 |
+
"audio": prompt_filepath_conversational_a
|
32 |
+
},
|
33 |
+
"conversational_b": {
|
34 |
+
"text": (
|
35 |
+
"like a super Mario level. Like it's very like high detail. And like, once you get "
|
36 |
+
"into the park, it just like, everything looks like a computer game and they have all "
|
37 |
+
"these, like, you know, if, if there's like a, you know, like in a Mario game, they "
|
38 |
+
"will have like a question block. And if you like, you know, punch it, a coin will "
|
39 |
+
"come out. So like everyone, when they come into the park, they get like this little "
|
40 |
+
"bracelet and then you can go punching question blocks around."
|
41 |
+
),
|
42 |
+
"audio": prompt_filepath_conversational_b
|
43 |
+
}
|
44 |
+
}
|
45 |
+
|
46 |
+
def load_prompt_audio(audio_path: str, target_sample_rate: int) -> torch.Tensor:
|
47 |
+
audio_tensor, sample_rate = torchaudio.load(audio_path)
|
48 |
+
audio_tensor = audio_tensor.squeeze(0)
|
49 |
+
# Resample is lazy so we can always call it
|
50 |
+
audio_tensor = torchaudio.functional.resample(
|
51 |
+
audio_tensor, orig_freq=sample_rate, new_freq=target_sample_rate
|
52 |
+
)
|
53 |
+
return audio_tensor
|
54 |
+
|
55 |
+
def prepare_prompt(text: str, speaker: int, audio_path: str, sample_rate: int) -> Segment:
|
56 |
+
audio_tensor = load_prompt_audio(audio_path, sample_rate)
|
57 |
+
return Segment(text=text, speaker=speaker, audio=audio_tensor)
|
58 |
+
|
59 |
+
def _select_device() -> str:
    """Return the compute device string, preferring CUDA.

    MPS is deliberately not considered because of its float64 limitations
    (see the original inline note).
    """
    return "cuda" if torch.cuda.is_available() else "cpu"


def main():
    """Generate a short two-speaker conversation with CSM-1B and save it.

    Loads the model, conditions it on the two reference prompts declared in
    SPEAKER_PROMPTS, generates each scripted utterance in turn (feeding prior
    generations back as context), and writes the concatenated audio to
    ``full_conversation.wav`` in the current directory.
    """
    device = _select_device()
    print(f"Using device: {device}")

    # Load model
    generator = load_csm_1b(device)

    # Prepare prompts: speaker ids 0 and 1 here must match the speaker_id
    # values used in the conversation below.
    prompt_segments = [
        prepare_prompt(
            SPEAKER_PROMPTS[prompt_name]["text"],
            speaker_id,
            SPEAKER_PROMPTS[prompt_name]["audio"],
            generator.sample_rate,
        )
        for speaker_id, prompt_name in enumerate(
            ("conversational_a", "conversational_b")
        )
    ]

    # Scripted conversation to synthesize.
    conversation = [
        {"text": "Hey how are you doing?", "speaker_id": 0},
        {"text": "Pretty good, pretty good. How about you?", "speaker_id": 1},
        {"text": "I'm great! So happy to be speaking with you today.", "speaker_id": 0},
        {"text": "Me too! This is some cool stuff, isn't it?", "speaker_id": 1}
    ]

    # Generate each utterance, accumulating prior turns as context so later
    # lines stay prosodically consistent with earlier ones.
    generated_segments = []
    for utterance in conversation:
        print(f"Generating: {utterance['text']}")
        audio_tensor = generator.generate(
            text=utterance['text'],
            speaker=utterance['speaker_id'],
            context=prompt_segments + generated_segments,
            max_audio_length_ms=10_000,
        )
        generated_segments.append(
            Segment(
                text=utterance['text'],
                speaker=utterance['speaker_id'],
                audio=audio_tensor,
            )
        )

    # Concatenate all generations into one waveform and write it out.
    # torchaudio.save expects (channels, samples), hence the unsqueeze.
    all_audio = torch.cat([seg.audio for seg in generated_segments], dim=0)
    torchaudio.save(
        "full_conversation.wav",
        all_audio.unsqueeze(0).cpu(),
        generator.sample_rate
    )
    print("Successfully generated full_conversation.wav")


if __name__ == "__main__":
    main()