nareauow committed on
Commit d658e55 · verified · 1 Parent(s): 9cebd8e
Files changed (1)
  1. app.py +287 -0
app.py ADDED
@@ -0,0 +1,287 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import numpy as np
+ from scipy.fftpack import idct
+ import gradio as gr
+ import matplotlib.pyplot as plt
+ from huggingface_hub import hf_hub_download
+ from transformers import Speech2TextForConditionalGeneration, Speech2TextProcessor
+ from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
+ from datasets import load_dataset
+ import soundfile as sf
+
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ print(f"Using device: {device}")
+
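+ # Runtime dependencies implied by the imports above (assumed, not pinned by this
+ # commit): torch, transformers, datasets, gradio, soundfile, scipy, numpy,
+ # matplotlib, huggingface_hub.
+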
+ # Load speech-to-text model
+ try:
+     speech_recognizer = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-small-librispeech-asr").to(device)
+     speech_processor = Speech2TextProcessor.from_pretrained("facebook/s2t-small-librispeech-asr")
+     print("Speech recognition model loaded successfully!")
+ except Exception as e:
+     print(f"Error loading speech recognition model: {e}")
+     speech_recognizer = None
+     speech_processor = None
+
+ # Load text-to-speech models
+ try:
+     # Load processor and model
+     tts_processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
+     tts_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts").to(device)
+     tts_vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to(device)
+
+     # Load speaker embeddings
+     embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
+     speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0).to(device)
+     print("Text-to-speech models loaded successfully!")
+ except Exception as e:
+     print(f"Error loading text-to-speech models: {e}")
+     tts_processor = None
+     tts_model = None
+     tts_vocoder = None
+     speaker_embeddings = None
+
+ # CNN model
+ class modele_CNN(nn.Module):
+     def __init__(self, num_classes=7, dropout=0.3):
+         super(modele_CNN, self).__init__()
+         self.conv1 = nn.Conv2d(1, 16, 3, padding=1)
+         self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
+         self.conv3 = nn.Conv2d(32, 64, 3, padding=1)
+         self.pool = nn.MaxPool2d(2, 2)
+         self.fc1 = nn.Linear(64 * 1 * 62, 128)
+         self.fc2 = nn.Linear(128, num_classes)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x):
+         x = self.pool(F.relu(self.conv1(x)))
+         x = self.pool(F.relu(self.conv2(x)))
+         x = self.pool(F.relu(self.conv3(x)))
+         x = x.view(x.size(0), -1)
+         x = self.dropout(F.relu(self.fc1(x)))
+         x = self.fc2(x)
+         return x
+
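+ # Note: fc1's input size (64 * 1 * 62) matches the (1, 1, 13, 496) MFCC tensor
+ # produced by AudioProcessor.process_audio below (13 coefficients x 496 frames
+ # for a 32000-sample clip), after three 2x2 max-pooling stages.
+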
+ # Audio processor
+ class AudioProcessor:
+     def Mel2Hz(self, mel): return 700 * (np.power(10, mel/2595)-1)
+     def Hz2Mel(self, freq): return 2595 * np.log10(1+freq/700)
+     def Hz2Ind(self, freq, fs, Tfft): return (freq*Tfft/fs).astype(int)
+
+     def hamming(self, T):
+         if T <= 1:
+             return np.ones(T)
+         return 0.54-0.46*np.cos(2*np.pi*np.arange(T)/(T-1))
+
+     def FiltresMel(self, fs, nf=36, Tfft=512, fmin=100, fmax=8000):
+         Indices = self.Hz2Ind(self.Mel2Hz(np.linspace(self.Hz2Mel(fmin), self.Hz2Mel(min(fmax, fs/2)), nf+2)), fs, Tfft)
+         filtres = np.zeros((int(Tfft/2), nf))
+         for i in range(nf): filtres[Indices[i]:Indices[i+2], i] = self.hamming(Indices[i+2]-Indices[i])
+         return filtres
+
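+     # spectrogram: slice the signal into Hamming-windowed frames of length T with hop p
+     # and return the FFT magnitudes and phases; mfcc: normalise the signal, take log
+     # mel filter-bank energies, and compress them to nc cepstral coefficients with scipy's idct.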
+     def spectrogram(self, x, T, p, Tfft):
+         S = []
+         for i in range(0, len(x)-T, p): S.append(x[i:i+T]*self.hamming(T))
+         S = np.fft.fft(S, Tfft)
+         return np.abs(S), np.angle(S)
+
+     def mfcc(self, data, filtres, nc=13, T=256, p=64, Tfft=512):
+         data = (data[1]-np.mean(data[1]))/np.std(data[1])
+         amp, ph = self.spectrogram(data, T, p, Tfft)
+         amp_f = np.log10(np.dot(amp[:, :int(Tfft/2)], filtres)+1)
+         return idct(amp_f, n=nc, norm='ortho')
+
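+     # process_audio: resample to 16 kHz (linear interpolation), pad or crop to
+     # audio_length samples (2 s at 16 kHz), and return the (1, 1, 13, n_frames)
+     # MFCC tensor expected by modele_CNN.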
+     def process_audio(self, audio_data, sr, audio_length=32000):
+         audio_data = np.asarray(audio_data, dtype=np.float32)
+         # Down-mix multi-channel recordings to mono before resampling
+         if audio_data.ndim > 1:
+             audio_data = audio_data.mean(axis=1)
+         if sr != 16000:
+             audio_resampled = np.interp(
+                 np.linspace(0, len(audio_data), int(16000 * len(audio_data) / sr)),
+                 np.arange(len(audio_data)),
+                 audio_data
+             )
+             sgn = audio_resampled
+             fs = 16000
+         else:
+             sgn = audio_data
+             fs = sr
+
+         sgn = np.array(sgn, dtype=np.float32)
+
+         if len(sgn) > audio_length:
+             sgn = sgn[:audio_length]
+         else:
+             sgn = np.pad(sgn, (0, audio_length - len(sgn)), mode='constant')
+
+         filtres = self.FiltresMel(fs)
+         sgn_features = self.mfcc([fs, sgn], filtres)
+
+         mfcc_tensor = torch.tensor(sgn_features.T, dtype=torch.float32)
+         mfcc_tensor = mfcc_tensor.unsqueeze(0).unsqueeze(0)
+
+         return mfcc_tensor
+
+ # Speech recognition function
+ def recognize_speech(audio_path):
+     if speech_recognizer is None or speech_processor is None:
+         return "Speech recognition model not available"
+
+     try:
+         # Read audio file
+         audio_data, sr = sf.read(audio_path)
+
+         # Resample to 16kHz if needed
+         if sr != 16000:
+             audio_data = np.interp(
+                 np.linspace(0, len(audio_data), int(16000 * len(audio_data) / sr)),
+                 np.arange(len(audio_data)),
+                 audio_data
+             )
+             sr = 16000
+
+         # Process audio
+         inputs = speech_processor(audio_data, sampling_rate=sr, return_tensors="pt")
+         inputs = {k: v.to(device) for k, v in inputs.items()}
+
+         # Generate transcription
+         generated_ids = speech_recognizer.generate(**inputs)
+         transcription = speech_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
+         return transcription
+     except Exception as e:
+         return f"Speech recognition error: {str(e)}"
+
+ # Speech synthesis function
+ def synthesize_speech(text):
+     if tts_processor is None or tts_model is None or tts_vocoder is None or speaker_embeddings is None:
+         return None
+
+     try:
+         # Preprocess text
+         inputs = tts_processor(text=text, return_tensors="pt").to(device)
+
+         # Generate speech with speaker embeddings
+         spectrogram = tts_model.generate_speech(inputs["input_ids"], speaker_embeddings)
+
+         # Convert to waveform
+         with torch.no_grad():
+             speech = tts_vocoder(spectrogram)
+
+         # Convert to numpy array and normalize (guard against an all-zero signal)
+         speech = speech.cpu().numpy()
+         peak = np.max(np.abs(speech))
+         if peak > 0:
+             speech = speech / peak
+
+         return (16000, speech.squeeze())
+     except Exception as e:
+         print(f"Speech synthesis error: {str(e)}")
+         return None
+
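+ # Note: the speaker names in predict_speaker below are assumed to follow the label
+ # order used to train the nareauow/my_speech_recognition checkpoints; adjust the
+ # list if a checkpoint uses a different class ordering.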
+ # Prediction function
+ def predict_speaker(audio, model, processor):
+     if audio is None:
+         return "No audio detected.", None, None, None
+
+     try:
+         audio_data, sr = sf.read(audio)
+         input_tensor = processor.process_audio(audio_data, sr)
+
+         device = next(model.parameters()).device
+         input_tensor = input_tensor.to(device)
+
+         with torch.no_grad():
+             output = model(input_tensor)
+             probabilities = F.softmax(output, dim=1)
+             confidence, predicted_class = torch.max(probabilities, 1)
+
+         speakers = ["George", "Jackson", "Lucas", "Nicolas", "Theo", "Yweweler", "Narimene"]
+         predicted_speaker = speakers[predicted_class.item()]
+
+         result = f"Recognized speaker: {predicted_speaker} (confidence: {confidence.item()*100:.2f}%)"
+
+         probs_dict = {speakers[i]: float(prob) for i, prob in enumerate(probabilities[0].cpu().numpy())}
+
+         # Recognize speech
+         recognized_text = recognize_speech(audio)
+
+         return result, probs_dict, recognized_text, predicted_speaker
+
+     except Exception as e:
+         # Return four values so the caller's unpacking never fails
+         return f"Error: {str(e)}", None, None, None
+
+ # Load the classifier model from the Hugging Face Hub
+ def load_model(model_id="nareauow/my_speech_recognition", model_filename="model_3.pth"):
+     try:
+         model_path = hf_hub_download(repo_id=model_id, filename=model_filename)
+         model = modele_CNN(num_classes=7, dropout=0.)
+         model.load_state_dict(torch.load(model_path, map_location=device))
+         model.to(device)
+         model.eval()
+         print("Model loaded successfully!")
+         return model
+     except Exception as e:
+         print(f"Loading error: {e}")
+         return None
+
+ # Gradio Interface
+ def create_interface():
+     processor = AudioProcessor()
+
+     with gr.Blocks(title="Speaker Recognition") as interface:
+         gr.Markdown("# 🗣️ Speaker Recognition")
+         gr.Markdown("Record your voice for 2 seconds to identify who is speaking.")
+
+         with gr.Row():
+             with gr.Column():
+                 model_selector = gr.Dropdown(
+                     choices=["model_1.pth", "model_2.pth", "model_3.pth"],
+                     value="model_3.pth",
+                     label="Choose the model"
+                 )
+                 audio_input = gr.Audio(sources=["microphone"], type="filepath", label="🎙️ Speak here")
+                 record_btn = gr.Button("Recognize")
+             with gr.Column():
+                 result_text = gr.Textbox(label="Result")
+                 plot_output = gr.Plot(label="Confidence per speaker")
+                 recognized_text = gr.Textbox(label="Recognized text")
+                 audio_output = gr.Audio(label="Speech synthesis", type="numpy")
+
+         def recognize(audio, selected_model):
+             model = load_model(model_filename=selected_model)
+             res, probs, text, locuteur = predict_speaker(audio, model, processor)
+
+             # Generate plot
+             fig = None
+             if probs:
+                 fig, ax = plt.subplots()
+                 ax.bar(probs.keys(), probs.values(), color='skyblue')
+                 ax.set_ylim([0, 1])
+                 ax.set_ylabel("Confidence")
+                 ax.set_xlabel("Speakers")
+                 plt.xticks(rotation=45)
+
+             # Generate speech synthesis if text was recognized
+             synth_audio = None
+             if text and "error" not in text.lower():
+                 synth_text = f"{locuteur} said: {text}"
+                 synth_audio = synthesize_speech(synth_text)
+
+             return res, fig, text, synth_audio
+
+         record_btn.click(fn=recognize,
+                          inputs=[audio_input, model_selector],
+                          outputs=[result_text, plot_output, recognized_text, audio_output])
+
+         gr.Markdown("""### How to use
+ - Choose the model.
+ - Click 🎙️ to record your voice.
+ - Click **Recognize** to get the prediction.
+ """)
+
+     return interface
+
+ # Launch
+ if __name__ == "__main__":
+     app = create_interface()
+     app.launch(share=True)