hriteshMaikap committed on
Commit 6fd65bb · verified · 1 Parent(s): ba7a495

Update app.py

Files changed (1)
  1. app.py +264 -11
app.py CHANGED
@@ -1,20 +1,273 @@
  import gradio as gr
- from transformers import pipeline

- classifier = pipeline("audio-classification", model="hriteshMaikap/languageClassifier")

- def predict_language(audio):
-     out = classifier(audio)
-     # out is a list of dicts: [{'label': 'Hindi', 'score': 0.98}, ...]
-     return "\n".join([f"{res['label']}: {res['score']:.2f}" for res in out])

  demo = gr.Interface(
-     fn=predict_language,
-     inputs=gr.Audio(source="microphone", type="filepath"),
-     outputs="text",
-     title="Indian Language Identifier",
-     description="Record audio and classify the spoken Indian language."
  )

  if __name__ == "__main__":
      demo.launch()
  import gradio as gr
+ import torch
+ import torch.nn as nn
+ import torchaudio
+ import json
+ import numpy as np
+ from huggingface_hub import hf_hub_download
+ from transformers import PretrainedConfig, PreTrainedModel

+ # Define model architecture
+ class AudioLanguageClassifierConfig(PretrainedConfig):
+     model_type = "audio-language-classifier"

+     def __init__(
+         self,
+         num_labels=10,  # Changed from 12 to 10 to match the saved model
+         sampling_rate=16000,
+         num_mel_bins=128,
+         feature_size=512,
+         num_transformer_layers=4,
+         num_attention_heads=4,
+         intermediate_size=1024,
+         hidden_dropout_prob=0.1,
+         attention_probs_dropout_prob=0.1,
+         **kwargs
+     ):
+         super().__init__(**kwargs)
+         self.num_labels = num_labels
+         self.sampling_rate = sampling_rate
+         self.num_mel_bins = num_mel_bins
+         self.feature_size = feature_size
+         self.num_transformer_layers = num_transformer_layers
+         self.num_attention_heads = num_attention_heads
+         self.intermediate_size = intermediate_size
+         self.hidden_dropout_prob = hidden_dropout_prob
+         self.attention_probs_dropout_prob = attention_probs_dropout_prob

+ class AudioFeatureExtractor:
+     def __init__(self, config):
+         self.config = config
+         self.mel_spectrogram = torchaudio.transforms.MelSpectrogram(
+             sample_rate=config.sampling_rate,
+             n_fft=1024,
+             hop_length=512,
+             n_mels=config.num_mel_bins
+         )
+         self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()
+
+     def __call__(self, audio_data, padding=True, max_length=None, truncation=True, **kwargs):
+         if isinstance(audio_data, np.ndarray):
+             audio_data = torch.from_numpy(audio_data)
+
+         # Ensure it's in the expected shape
+         if audio_data.ndim == 1:
+             audio_data = audio_data.unsqueeze(0)  # Add channel dimension
+
+         # Convert to mel spectrogram
+         mel_spec = self.mel_spectrogram(audio_data)
+         log_mel_spec = self.amplitude_to_db(mel_spec)
+
+         # Normalization
+         mean = log_mel_spec.mean()
+         std = log_mel_spec.std()
+         log_mel_spec = (log_mel_spec - mean) / (std + 1e-10)
+
+         # Handle max length/truncation
+         if max_length is not None and truncation and log_mel_spec.shape[-1] > max_length:
+             log_mel_spec = log_mel_spec[..., :max_length]
+
+         return {"input_values": log_mel_spec}
+
+ class AudioLanguageClassifier(PreTrainedModel):
+     config_class = AudioLanguageClassifierConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.config = config
+
+         # CNN feature extractor
+         self.feature_extractor = nn.Sequential(
+             nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
+             nn.ReLU(),
+             nn.MaxPool2d(2, 2),
+             nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
+             nn.ReLU(),
+             nn.MaxPool2d(2, 2)
+         )
+
+         # Global average pooling to eliminate size dependency
+         self.global_pool = nn.AdaptiveAvgPool2d((4, 4))
+
+         # Fixed size after global pooling
+         self.flattened_size = 64 * 4 * 4
+
+         # Projection layer with fixed input size
+         self.projection = nn.Linear(self.flattened_size, config.feature_size)
+
+         # Transformer for sequence modeling
+         encoder_layer = nn.TransformerEncoderLayer(
+             d_model=config.feature_size,
+             nhead=config.num_attention_heads,
+             dim_feedforward=config.intermediate_size,
+             dropout=config.hidden_dropout_prob,
+             batch_first=True
+         )
+         self.transformer_encoder = nn.TransformerEncoder(
+             encoder_layer,
+             num_layers=config.num_transformer_layers
+         )
+
+         # Classification head
+         self.classifier = nn.Linear(config.feature_size, config.num_labels)
+
+     def forward(
+         self,
+         input_values=None,
+         labels=None,
+         **kwargs
+     ):
+         batch_size = input_values.size(0)
+
+         # Extract features using CNN
+         x = self.feature_extractor(input_values)
+
+         # Apply global pooling to get fixed size
+         x = self.global_pool(x)
+
+         # Flatten
+         x = x.view(batch_size, -1)
+
+         # Project to transformer dimension
+         x = self.projection(x)
+
+         # Add sequence dimension for transformer
+         x = x.unsqueeze(1)  # [batch_size, 1, feature_size]
+
+         # Transformer encoding
+         x = self.transformer_encoder(x)
+
+         # Classification
+         x = x[:, 0, :]  # Take first token representation
+         logits = self.classifier(x)
+
+         loss = None
+         if labels is not None:
+             loss_fct = nn.CrossEntropyLoss()
+             loss = loss_fct(logits, labels)
+
+         return {"loss": loss, "logits": logits} if loss is not None else {"logits": logits}
+
+ # Function to load the model and its configuration
+ def load_model():
+     # Download the model files
+     repo_id = "hriteshMaikap/languageClassifier"
+
+     try:
+         model_path = hf_hub_download(repo_id=repo_id, filename="model.pt")
+         config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
+         mappings_path = hf_hub_download(repo_id=repo_id, filename="language_mappings.json")
+
+         # Load the config
+         with open(config_path, "r") as f:
+             config_dict = json.load(f)
+
+         # IMPORTANT: Override num_labels to 10 since the model was trained with 10 classes
+         config_dict["num_labels"] = 10
+
+         config = AudioLanguageClassifierConfig(**config_dict)
+
+         # Load the model
+         model = AudioLanguageClassifier(config)
+         model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
+         model.eval()
+
+         # Load language mappings
+         with open(mappings_path, "r") as f:
+             mappings = json.load(f)
+
+         id_to_language = {int(k): v for k, v in mappings["id_to_language"].items()}
+
+         return model, config, id_to_language
+
+     except Exception as e:
+         gr.Warning(f"Error loading model: {e}")
+         # Return placeholders with error message
+         raise gr.Error(f"Failed to load the language classification model: {e}")
+
+ # Function to process audio and make predictions
+ def classify_language(audio):
+     try:
+         # Load model on first inference
+         global model, config, id_to_language, feature_extractor
+         if 'model' not in globals() or model is None:
+             model, config, id_to_language = load_model()
+             feature_extractor = AudioFeatureExtractor(config)
+
+         # Get audio data
+         sr, waveform = audio
+
+         # Convert to torch tensor
+         waveform = torch.tensor(waveform).float()
+
+         # Ensure mono
+         if waveform.ndim > 1 and waveform.shape[0] > 1:
+             waveform = torch.mean(waveform, dim=0, keepdim=True)
+         elif waveform.ndim == 1:
+             waveform = waveform.unsqueeze(0)
+
+         # Resample to 16kHz if needed
+         if sr != 16000:
+             resampler = torchaudio.transforms.Resample(sr, 16000)
+             waveform = resampler(waveform)
+
+         # Extract features
+         features = feature_extractor(waveform, max_length=256)
+         input_values = features["input_values"]
+
+         # Pad or truncate to fixed length
+         _, height, width = input_values.shape
+         max_length = 256
+         if width < max_length:
+             padding = torch.zeros(1, height, max_length - width)
+             input_values = torch.cat([input_values, padding], dim=2)
+         elif width > max_length:
+             input_values = input_values[:, :, :max_length]
+
+         # Get prediction
+         with torch.no_grad():
+             outputs = model(input_values=input_values)
+             logits = outputs["logits"]
+             probs = torch.nn.functional.softmax(logits, dim=1)[0]
+
+         # Get top predictions
+         num_classes = min(3, len(id_to_language))
+         top_probs, top_ids = torch.topk(probs, num_classes)
+
+         # Format results
+         results = {}
+         for i, (prob, pred_id) in enumerate(zip(top_probs, top_ids)):
+             lang = id_to_language.get(pred_id.item(), f"Unknown-{pred_id.item()}")
+             results[lang] = float(prob)
+
+         return results
+
+     except Exception as e:
+         return {"Error": 1.0, "Details": str(e)}
+
+ # Create the Gradio interface
  demo = gr.Interface(
+     fn=classify_language,
+     # Changed type from "tuple" to "numpy" to fix the error
+     inputs=gr.Audio(sources=["microphone", "upload"], type="numpy"),
+     outputs=gr.Label(num_top_classes=3),
+     title="Indian Language Identification",
+     description="Record or upload audio to identify the Indian language being spoken.",
+     examples=[],
+     article="""
+     <div style="text-align: center;">
+         <p>This model identifies various Indian languages from audio input. For best results:</p>
+         <ul style="display: inline-block; text-align: left;">
+             <li>Speak clearly with minimal background noise</li>
+             <li>Recording length of 3-5 seconds is ideal</li>
+             <li>Make sure to speak a full sentence or phrase</li>
+         </ul>
+     </div>
+     """
  )

+ # Launch the app
  if __name__ == "__main__":
+     # Initialize model as None to lazy-load on first inference
+     model, config, id_to_language, feature_extractor = None, None, None, None
      demo.launch()
  demo.launch()