Update main.py
main.py CHANGED
@@ -24,39 +24,24 @@ top_k = 20
 from safetensors.torch import load_file
 
 def convert_to_16_bit_wav(data):
-    # Based on: https://docs.scipy.org/doc/scipy/reference/generated/scipy.io.wavfile.write.html
-    # breakpoint()
     if data.dtype == np.float32:
-        # warnings.warn(
-        #     "Audio data is not in 16-bit integer format."
-        #     "Trying to convert to 16-bit int format."
-        # )
         data = data / np.abs(data).max()
         data = data * 32767
         data = data.astype(np.int16)
     elif data.dtype == np.int32:
-        # warnings.warn(
-        #     "Audio data is not in 16-bit integer format."
-        #     "Trying to convert to 16-bit int format."
-        # )
         data = data / 65538
         data = data.astype(np.int16)
     elif data.dtype == np.int16:
         pass
     elif data.dtype == np.uint8:
-        # warnings.warn(
-        #     "Audio data is not in 16-bit integer format."
-        #     "Trying to convert to 16-bit int format."
-        # )
         data = data * 257 - 32768
         data = data.astype(np.int16)
     else:
-        raise ValueError("Audio data cannot be converted to
+        raise ValueError("Audio data cannot be converted to 16-bit int format.")
     return data
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-
 # Load the model with INT8 quantization
 model = AutoModelForCausalLM.from_pretrained(
     model_path,
@@ -71,7 +56,7 @@ ckpt_path = "audiotokenizer/SpeechTokenizer.pt"
 quantizer = SpeechTokenizer.load_from_checkpoint(config_path, ckpt_path)
 quantizer.eval()
 
-#
+# Freeze layers in the quantizer
 def freeze_entire_model(model):
     for n, p in model.named_parameters():
         p.requires_grad = False
@@ -81,7 +66,7 @@ for n, child in quantizer.named_children():
     child.to(device)
     child = freeze_entire_model(child)
 
-#
+# Create padding tokens for audio
 def get_audio_padding_tokens(quantizer):
     audio = torch.zeros((1, 1, 1)).to(device)
     codes = quantizer.encode(audio)
@@ -89,7 +74,7 @@ def get_audio_padding_tokens(quantizer):
     torch.cuda.empty_cache()
     return {"audio_tokens": codes.squeeze(1)}
 
-#
+# Decode audio from tokens
 def decode_audio(tokens, quantizer, pad_tokens, n_original_tokens):
     start = torch.nonzero(tokens == tokenizer(start_audio_token)["input_ids"][-1])
     end = torch.nonzero(tokens == tokenizer(end_audio_token)["input_ids"][-1])
@@ -112,9 +97,7 @@ def decode_audio(tokens, quantizer, pad_tokens, n_original_tokens):
     return xp
 
 
-#
-
-# Inference function for text input and audio output
+# Inference functions
 def infer_text_to_audio(text, model, tokenizer, quantizer, max_seq_length=1024, top_k=20):
     text_tokenized = tokenizer(text, return_tensors="pt")
     text_input_tokens = text_tokenized["input_ids"].to(device)
@@ -132,7 +115,6 @@ def infer_text_to_audio(text, model, tokenizer, quantizer, max_seq_length=1024,
 
     return audio_signal
 
-# Inference function for audio input and text output
 def infer_audio_to_text(audio_path, model, tokenizer, quantizer, max_seq_length=1024, top_k=20):
     audio_data, sample_rate = torchaudio.load(audio_path)
 
@@ -155,7 +137,7 @@ def infer_audio_to_text(audio_path, model, tokenizer, quantizer, max_seq_length=
 
     return decoded_text
 
-# Functions for
+# Functions for Gradio Interface
 def infer_text_to_audio_gr(text):
     audio_signal = infer_text_to_audio(text.strip().upper(), model, tokenizer, quantizer)
     return audio_signal
@@ -183,6 +165,42 @@ audio_to_text_interface = gr.Interface(
     allow_flagging='never'
 )
 
-#
+# Gradio Demo
 demo = gr.TabbedInterface([text_to_audio_interface, audio_to_text_interface], ["Text - Audio", "Audio - Text"])
+
+# Custom CSS for centered links
+custom_css = """
+<style>
+.center {
+    text-align: center;
+}
+</style>
+"""
+
+# Add Gradio description with centered links
+description = f"""
+# **Salt: Speech And Language Transformer**
+
+Welcome to the demo of **Salt**, a speech and language model. Vikhr Salt is capable of both **Text-to-Speech (T2S)** and **Speech-to-Text (S2T)** tasks, making it a versatile tool for transforming language into speech and vice versa. Built on a pre-trained large language model, Vikhr Salt incorporates audio tokens using cutting-edge techniques like **Encodec** and **SpeechTokenizer**, enabling robust performance across multiple modalities.
+
+## **🛠 Features**
+- **Text-to-Speech (T2S)**: Enter text and generate high-quality audio outputs.
+- **Speech-to-Text (S2T)**: Upload an audio file and convert it into accurate text.
+
+## **🚀 Try it out:**
+Explore the tabs to try the **Text - Audio** and **Audio - Text** modes!
+
+---
+
+<div class="center">
+### **📄 Preprint**
+[Read the paper](https://docs.google.com/document/d/1ZvV47W4BCyZM_JfDC1BKj-0ozwPck5t2yNB8jORVshI/edit?usp=sharing)
+
+### **📂 Code**
+[Explore the code](https://github.com/VikhrModels/Vikhr4o)
+</div>
+
+"""
+
+# Launch Gradio App
 demo.launch(share=True)
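As a sanity check on the conversion logic this commit finishes (the previously truncated `raise` is now a complete statement), here is a self-contained sketch: the function body is copied from the committed file, while the test harness is illustrative only. Note that the `uint8` branch depends on NumPy's integer-promotion rules (`257` does not fit in `uint8`), so the check below exercises just the `float32` path.

```python
import numpy as np

def convert_to_16_bit_wav(data):
    # Branches as committed: peak-normalize float32, rescale int32 and uint8,
    # pass int16 through unchanged, reject any other dtype.
    if data.dtype == np.float32:
        data = data / np.abs(data).max()  # scale into [-1, 1]
        data = data * 32767               # stretch to the int16 range
        data = data.astype(np.int16)
    elif data.dtype == np.int32:
        data = data / 65538               # divisor just over 2**16 keeps values inside int16
        data = data.astype(np.int16)
    elif data.dtype == np.int16:
        pass
    elif data.dtype == np.uint8:
        data = data * 257 - 32768         # intended to map [0, 255] onto [-32768, 32767]
        data = data.astype(np.int16)
    else:
        raise ValueError("Audio data cannot be converted to 16-bit int format.")
    return data

# Illustrative check: a 440 Hz float32 tone converts to full-scale int16.
tone = np.sin(2 * np.pi * 440 * np.linspace(0, 1, 16000)).astype(np.float32)
out = convert_to_16_bit_wav(tone)
assert out.dtype == np.int16 and int(np.abs(out).max()) == 32767
```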
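One loose end in the new code: `custom_css` and `description` are defined but never attached to the demo in the hunks shown. A hedged sketch of how they could be wired in with standard Gradio Blocks APIs (`gr.Blocks(css=...)` takes raw CSS rules without `<style>` tags; `render()` embeds an existing Interface). This is an illustration, not part of the commit:

```python
import gradio as gr

# Hypothetical wiring (not in this commit): surface the description and CSS
# by composing the two existing interfaces inside a Blocks layout.
with gr.Blocks(css=".center { text-align: center; }") as demo:
    gr.Markdown(description)  # `description` as defined in the committed file
    with gr.Tab("Text - Audio"):
        text_to_audio_interface.render()
    with gr.Tab("Audio - Text"):
        audio_to_text_interface.render()

demo.launch(share=True)
```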