preston-cell committed
Commit c4e4a14 · verified · 1 Parent(s): 0286478

Create run_csm.py

Files changed (1)
  1. run_csm.py +117 -0
run_csm.py ADDED
@@ -0,0 +1,117 @@
import os
import torch
import torchaudio
from huggingface_hub import hf_hub_download
from generator import load_csm_1b, Segment
from dataclasses import dataclass
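
# NOTE: `generator` is assumed to be the local module from the sesame/csm
# repository that defines load_csm_1b and Segment; keep it next to this script.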

# Disable Triton compilation
os.environ["NO_TORCH_COMPILE"] = "1"

# Default prompts are available at https://hf.co/sesame/csm-1b
prompt_filepath_conversational_a = hf_hub_download(
    repo_id="sesame/csm-1b",
    filename="prompts/conversational_a.wav"
)
prompt_filepath_conversational_b = hf_hub_download(
    repo_id="sesame/csm-1b",
    filename="prompts/conversational_b.wav"
)
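
# hf_hub_download returns the local cache path for each prompt WAV, so repeated
# runs reuse the cached files instead of downloading them again.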

SPEAKER_PROMPTS = {
    "conversational_a": {
        "text": (
            "like revising for an exam I'd have to try and like keep up the momentum because I'd "
            "start really early I'd be like okay I'm gonna start revising now and then like "
            "you're revising for ages and then I just like start losing steam I didn't do that "
            "for the exam we had recently to be fair that was a more of a last minute scenario "
            "but like yeah I'm trying to like yeah I noticed this yesterday that like Mondays I "
            "sort of start the day with this not like a panic but like a"
        ),
        "audio": prompt_filepath_conversational_a
    },
    "conversational_b": {
        "text": (
            "like a super Mario level. Like it's very like high detail. And like, once you get "
            "into the park, it just like, everything looks like a computer game and they have all "
            "these, like, you know, if, if there's like a, you know, like in a Mario game, they "
            "will have like a question block. And if you like, you know, punch it, a coin will "
            "come out. So like everyone, when they come into the park, they get like this little "
            "bracelet and then you can go punching question blocks around."
        ),
        "audio": prompt_filepath_conversational_b
    }
}
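
# Each prompt pairs a reference transcript with its audio clip; the model uses
# them as context so the generated speech matches each speaker's voice.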

def load_prompt_audio(audio_path: str, target_sample_rate: int) -> torch.Tensor:
    audio_tensor, sample_rate = torchaudio.load(audio_path)
    audio_tensor = audio_tensor.squeeze(0)
    # Resample is lazy so we can always call it
    audio_tensor = torchaudio.functional.resample(
        audio_tensor, orig_freq=sample_rate, new_freq=target_sample_rate
    )
    return audio_tensor

def prepare_prompt(text: str, speaker: int, audio_path: str, sample_rate: int) -> Segment:
    audio_tensor = load_prompt_audio(audio_path, sample_rate)
    return Segment(text=text, speaker=speaker, audio=audio_tensor)
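
# A Segment bundles a transcript, a speaker id, and an audio tensor; the
# generator takes a list of Segments as conversational context.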

def main():
    # Select the best available device, skipping MPS due to float64 limitations
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    print(f"Using device: {device}")

    # Load model
    generator = load_csm_1b(device)

    # Prepare prompts
    prompt_a = prepare_prompt(
        SPEAKER_PROMPTS["conversational_a"]["text"],
        0,
        SPEAKER_PROMPTS["conversational_a"]["audio"],
        generator.sample_rate
    )

    prompt_b = prepare_prompt(
        SPEAKER_PROMPTS["conversational_b"]["text"],
        1,
        SPEAKER_PROMPTS["conversational_b"]["audio"],
        generator.sample_rate
    )
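
    # The speaker ids used above (0 and 1) must match the speaker_id values in
    # the conversation below.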

    # Generate conversation
    conversation = [
        {"text": "Hey how are you doing?", "speaker_id": 0},
        {"text": "Pretty good, pretty good. How about you?", "speaker_id": 1},
        {"text": "I'm great! So happy to be speaking with you today.", "speaker_id": 0},
        {"text": "Me too! This is some cool stuff, isn't it?", "speaker_id": 1}
    ]

    # Generate each utterance
    generated_segments = []
    prompt_segments = [prompt_a, prompt_b]

    for utterance in conversation:
        print(f"Generating: {utterance['text']}")
        audio_tensor = generator.generate(
            text=utterance['text'],
            speaker=utterance['speaker_id'],
            context=prompt_segments + generated_segments,
            max_audio_length_ms=10_000,
        )
        generated_segments.append(Segment(text=utterance['text'], speaker=utterance['speaker_id'], audio=audio_tensor))
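
    # Each finished utterance is appended to the running context, so later turns
    # are conditioned on the audio of the earlier ones.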

    # Concatenate all generations
    all_audio = torch.cat([seg.audio for seg in generated_segments], dim=0)
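    # torchaudio.save expects a (channels, samples) tensor, so restore the
    # channel dimension and move the audio to the CPU before writing.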
    torchaudio.save(
        "full_conversation.wav",
        all_audio.unsqueeze(0).cpu(),
        generator.sample_rate
    )
    print("Successfully generated full_conversation.wav")

if __name__ == "__main__":
    main()