nareauow committed (verified)
Commit d3d626b · Parent(s): d698901

Update app.py

Files changed (1): app.py (+183 -35)
app.py CHANGED
@@ -1,37 +1,185 @@
- with gr.Row():
-     with gr.Column():
-         # Dropdown to select the model
-         model_selector = gr.Dropdown(
-             choices=["model_1.pth", "model_2.pth", "model_3.pth"],
-             value="model_3.pth",
-             label="Choisissez le modèle"
          )
-
-         # Tabs for microphone input and audio upload
-         with gr.Tab("Microphone"):
-             mic_input = gr.Audio(sources=["microphone"], type="filepath", label="🎙️ Enregistrer depuis le microphone")
-
-         with gr.Tab("Upload Audio"):
-             file_input = gr.Audio(sources=["upload"], type="filepath", label="📁 Télécharger un fichier audio")
-
-         # Button to start recognition
-         record_btn = gr.Button("Reconnaître")
-
-     with gr.Column():
-         # Result, confidence plot and recognized text
-         result_text = gr.Textbox(label="Résultat")
-         plot_output = gr.Plot(label="Confiance par locuteur")
-         recognized_text = gr.Textbox(label="Texte reconnu")
-         audio_output = gr.Audio(label="Synthèse vocale", visible=False)
-
- # Click handler for recognition
  def recognize(audio, selected_model):
-     # Audio processing and model loading...
-     pass  # replace with your processing code
-
- # Wire the "Reconnaître" button to the function
- record_btn.click(
-     fn=recognize,
-     inputs=[mic_input, file_input, model_selector],  # two separate inputs instead of a single Union
-     outputs=[result_text, plot_output, recognized_text, audio_output]
- )
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import numpy as np
+ import scipy.io.wavfile as wav
+ from scipy.fftpack import idct
+ import gradio as gr
+ import os
+ import matplotlib.pyplot as plt
+ from huggingface_hub import hf_hub_download
+
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ print(f"Using device: {device}")
+
+ # CNN model
+ class modele_CNN(nn.Module):
+     def __init__(self, num_classes=8, dropout=0.3):
+         super(modele_CNN, self).__init__()
+         self.conv1 = nn.Conv2d(1, 16, 3, padding=1)
+         self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
+         self.conv3 = nn.Conv2d(32, 64, 3, padding=1)
+         self.pool = nn.MaxPool2d(2, 2)
+         self.fc1 = nn.Linear(64 * 1 * 62, 128)
+         self.fc2 = nn.Linear(128, num_classes)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x):
+         x = self.pool(F.relu(self.conv1(x)))
+         x = self.pool(F.relu(self.conv2(x)))
+         x = self.pool(F.relu(self.conv3(x)))
+         x = x.view(x.size(0), -1)
+         x = self.dropout(F.relu(self.fc1(x)))
+         x = self.fc2(x)
+         return x
+
+ # Audio processor
+ class AudioProcessor:
+     def Mel2Hz(self, mel): return 700 * (np.power(10, mel/2595)-1)
+     def Hz2Mel(self, freq): return 2595 * np.log10(1+freq/700)
+     def Hz2Ind(self, freq, fs, Tfft): return (freq*Tfft/fs).astype(int)
+
+     def hamming(self, T):
+         if T <= 1:
+             return np.ones(T)
+         return 0.54-0.46*np.cos(2*np.pi*np.arange(T)/(T-1))
+
+     def FiltresMel(self, fs, nf=36, Tfft=512, fmin=100, fmax=8000):
+         Indices = self.Hz2Ind(self.Mel2Hz(np.linspace(self.Hz2Mel(fmin), self.Hz2Mel(min(fmax, fs/2)), nf+2)), fs, Tfft)
+         filtres = np.zeros((int(Tfft/2), nf))
+         for i in range(nf): filtres[Indices[i]:Indices[i+2], i] = self.hamming(Indices[i+2]-Indices[i])
+         return filtres
+
+     def spectrogram(self, x, T, p, Tfft):
+         S = []
+         for i in range(0, len(x)-T, p): S.append(x[i:i+T]*self.hamming(T))
+         S = np.fft.fft(S, Tfft)
+         return np.abs(S), np.angle(S)
+
+     def mfcc(self, data, filtres, nc=13, T=256, p=64, Tfft=512):
+         data = (data[1]-np.mean(data[1]))/np.std(data[1])
+         amp, ph = self.spectrogram(data, T, p, Tfft)
+         amp_f = np.log10(np.dot(amp[:, :int(Tfft/2)], filtres)+1)
+         return idct(amp_f, n=nc, norm='ortho')
+
+     def process_audio(self, audio_data, sr, audio_length=32000):
+         if sr != 16000:
+             audio_resampled = np.interp(
+                 np.linspace(0, len(audio_data), int(16000 * len(audio_data) / sr)),
+                 np.arange(len(audio_data)),
+                 audio_data
              )
+             sgn = audio_resampled
+             fs = 16000
+         else:
+             sgn = audio_data
+             fs = sr
+
+         sgn = np.array(sgn, dtype=np.float32)
+
+         if len(sgn) > audio_length:
+             sgn = sgn[:audio_length]
+         else:
+             sgn = np.pad(sgn, (0, audio_length - len(sgn)), mode='constant')
+
+         filtres = self.FiltresMel(fs)
+         sgn_features = self.mfcc([fs, sgn], filtres)
+
+         mfcc_tensor = torch.tensor(sgn_features.T, dtype=torch.float32)
+         mfcc_tensor = mfcc_tensor.unsqueeze(0).unsqueeze(0)
+
+         return mfcc_tensor
+
+ # Prediction function
+ def predict_speaker(audio, model, processor):
+     if audio is None:
+         return "Aucun audio détecté.", None
+
+     try:
+         import soundfile as sf
+         audio_data, sr = sf.read(audio)  # read the audio file directly
+         input_tensor = processor.process_audio(audio_data, sr)
+
+         device = next(model.parameters()).device
+         input_tensor = input_tensor.to(device)
+
+         with torch.no_grad():
+             output = model(input_tensor)
+             print(output)
+             probabilities = F.softmax(output, dim=1)
+             confidence, predicted_class = torch.max(probabilities, 1)
+
+         speakers = ["George", "Jackson", "Lucas", "Nicolas", "Theo", "Yweweler", "Narimene"]
+         predicted_speaker = speakers[predicted_class.item()]
+
+         result = f"Locuteur reconnu : {predicted_speaker} (confiance : {confidence.item()*100:.2f}%)"
+
+         probs_dict = {speakers[i]: float(probs) for i, probs in enumerate(probabilities[0].cpu().numpy())}
+
+         return result, probs_dict
+
+     except Exception as e:
+         return f"Erreur : {str(e)}", None
+
+ # Load a model checkpoint from the Hub
+ def load_model(model_id="nareauow/my_speech_recognition", model_filename="model_3.pth"):
+     try:
+         model_path = hf_hub_download(repo_id=model_id, filename=model_filename)
+         model = modele_CNN(num_classes=7, dropout=0.)
+         model.load_state_dict(torch.load(model_path, map_location=device))
+         model.to(device)
+         model.eval()
+         print("Modèle chargé avec succès !")
+         return model
+     except Exception as e:
+         print(f"Erreur de chargement: {e}")
+         return None
+
+ # Gradio interface
+ def create_interface():
+     processor = AudioProcessor()
+
+     with gr.Blocks(title="Reconnaissance de Locuteur") as interface:
+         gr.Markdown("# 🗣️ Reconnaissance de Locuteur")
+         gr.Markdown("Enregistrez votre voix pendant 2 secondes pour identifier qui parle.")
+
+         with gr.Row():
+             with gr.Column():
+                 model_selector = gr.Dropdown(
+                     choices=["model_1.pth", "model_2.pth", "model_3.pth"],
+                     value="model_3.pth",
+                     label="Choisissez le modèle"
+                 )
+                 audio_input = gr.Audio(sources=["microphone"], type="filepath", label="🎙️ Parlez ici")
+                 record_btn = gr.Button("Reconnaître")
+             with gr.Column():
+                 result_text = gr.Textbox(label="Résultat")
+                 plot_output = gr.Plot(label="Confiance par locuteur")
+
          def recognize(audio, selected_model):
+             model = load_model(model_filename=selected_model)  # load the selected checkpoint
+             res, probs = predict_speaker(audio, model, processor)
+             fig = None
+             if probs:
+                 fig, ax = plt.subplots()
+                 ax.bar(probs.keys(), probs.values(), color='skyblue')
+                 ax.set_ylim([0, 1])
+                 ax.set_ylabel("Confiance")
+                 ax.set_xlabel("Locuteurs")
+                 plt.xticks(rotation=45)
+             return res, fig
+
+         record_btn.click(fn=recognize, inputs=[audio_input, model_selector], outputs=[result_text, plot_output])
+
+         gr.Markdown("""### Comment utiliser ?
+         - Choisissez le modèle.
+         - Cliquez sur 🎙️ pour enregistrer votre voix.
+         - Cliquez sur **Reconnaître** pour obtenir la prédiction.
+         """)
+
+     return interface
+
+ # Launch the app
+ if __name__ == "__main__":
+     app = create_interface()
+     app.launch(share=True)
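
Note on the fixed feature geometry: fc1 = nn.Linear(64 * 1 * 62, 128) only lines up because process_audio always produces 32 000 samples at 16 kHz, which with T=256 and p=64 yields 496 frames of 13 MFCC coefficients, i.e. a (1, 1, 13, 496) tensor that three conv/pool stages reduce to 64 × 1 × 62. Below is a minimal sanity-check sketch, not part of the commit; it assumes the committed app.py is importable as a module (its __main__ guard keeps the Gradio app from launching on import) and that torch, numpy, and the other imports are installed.

import numpy as np
import torch

from app import AudioProcessor, modele_CNN  # assumed importable from the committed app.py

processor = AudioProcessor()
dummy = np.random.randn(32000).astype(np.float32)     # synthetic 2-second clip at 16 kHz
features = processor.process_audio(dummy, sr=16000)   # expected shape: (1, 1, 13, 496)
print(features.shape)

model = modele_CNN(num_classes=7, dropout=0.0)         # untrained weights, shape check only
with torch.no_grad():
    logits = model(features)                           # expected shape: (1, 7)
print(logits.shape)

Because clips longer or shorter than 2 seconds are truncated or zero-padded to 32 000 samples inside process_audio, the same shapes hold for any recording length.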