hriteshMaikap committed
Commit 018dafc · verified · Parent: 4c524a0

Create app.py

Files changed (1): app.py +287 -0
app.py ADDED
@@ -0,0 +1,287 @@
import gradio as gr
import torch
import torch.nn as nn
import torchaudio
import json
import numpy as np
from huggingface_hub import hf_hub_download
from transformers import PretrainedConfig, PreTrainedModel

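# Assumed environment (not pinned in this commit): gradio>=4, where gr.Audio
# takes `sources=` rather than the older `source=`, plus torch, torchaudio,
# transformers, huggingface_hub and numpy.
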
# Model architecture (must match the training code so the checkpoint weights load)
class AudioLanguageClassifierConfig(PretrainedConfig):
    model_type = "audio-language-classifier"

    def __init__(
        self,
        num_labels=12,
        sampling_rate=16000,
        num_mel_bins=128,
        feature_size=512,
        num_transformer_layers=4,
        num_attention_heads=4,
        intermediate_size=1024,
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.num_labels = num_labels
        self.sampling_rate = sampling_rate
        self.num_mel_bins = num_mel_bins
        self.feature_size = feature_size
        self.num_transformer_layers = num_transformer_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob

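# Illustrative sketch: the config round-trips through JSON the same way the
# app later reads config.json from the Hub. The values are this file's
# defaults, not necessarily the published checkpoint's.
#   cfg = AudioLanguageClassifierConfig()
#   cfg2 = AudioLanguageClassifierConfig(**json.loads(cfg.to_json_string()))
#   assert cfg2.num_mel_bins == 128
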
class AudioFeatureExtractor:
    def __init__(self, config):
        self.config = config
        self.mel_spectrogram = torchaudio.transforms.MelSpectrogram(
            sample_rate=config.sampling_rate,
            n_fft=1024,
            hop_length=512,
            n_mels=config.num_mel_bins
        )
        self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()

    def __call__(self, audio_data, padding=True, max_length=None, truncation=True, **kwargs):
        if isinstance(audio_data, np.ndarray):
            audio_data = torch.from_numpy(audio_data)

        # Ensure a channel dimension: (num_samples,) -> (1, num_samples)
        if audio_data.ndim == 1:
            audio_data = audio_data.unsqueeze(0)

        # Convert to a log-mel spectrogram
        mel_spec = self.mel_spectrogram(audio_data)
        log_mel_spec = self.amplitude_to_db(mel_spec)

        # Normalize to zero mean, unit variance
        mean = log_mel_spec.mean()
        std = log_mel_spec.std()
        log_mel_spec = (log_mel_spec - mean) / (std + 1e-10)

        # Truncate along the time axis if requested
        if max_length is not None and truncation and log_mel_spec.shape[-1] > max_length:
            log_mel_spec = log_mel_spec[..., :max_length]

        return {"input_values": log_mel_spec}

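# Shape check (sketch): with torchaudio's default center=True, the number of
# STFT frames is n_samples // hop_length + 1. One second of 16 kHz mono audio
# (16000 samples) therefore yields 16000 // 512 + 1 = 32 frames, i.e. an
# input_values tensor of shape (1, 128, 32):
#   out = AudioFeatureExtractor(AudioLanguageClassifierConfig())(
#       np.zeros(16000, dtype=np.float32))
#   assert out["input_values"].shape == (1, 128, 32)
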
class AudioLanguageClassifier(PreTrainedModel):
    config_class = AudioLanguageClassifierConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # CNN feature extractor
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )

        # Global average pooling removes the dependency on input length
        self.global_pool = nn.AdaptiveAvgPool2d((4, 4))

        # Fixed size after global pooling: 64 channels x 4 x 4
        self.flattened_size = 64 * 4 * 4

        # Projection layer with fixed input size
        self.projection = nn.Linear(self.flattened_size, config.feature_size)

        # Transformer for sequence modeling
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.feature_size,
            nhead=config.num_attention_heads,
            dim_feedforward=config.intermediate_size,
            dropout=config.hidden_dropout_prob,
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=config.num_transformer_layers
        )

        # Classification head; fall back to 12 labels if the config does not define num_labels
        num_labels = getattr(config, "num_labels", 12)
        self.classifier = nn.Linear(config.feature_size, num_labels)

    def forward(self, input_values=None, labels=None, **kwargs):
        batch_size = input_values.size(0)

        # Extract features using the CNN
        x = self.feature_extractor(input_values)

        # Apply global pooling to get a fixed size
        x = self.global_pool(x)

        # Flatten
        x = x.view(batch_size, -1)

        # Project to the transformer dimension
        x = self.projection(x)

        # Add a sequence dimension for the transformer: (batch_size, 1, feature_size)
        x = x.unsqueeze(1)

        # Transformer encoding
        x = self.transformer_encoder(x)

        # Classify from the first (and only) token representation
        x = x[:, 0, :]
        logits = self.classifier(x)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits, labels)

        return {"loss": loss, "logits": logits} if loss is not None else {"logits": logits}

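# Shape walk-through for a batch of spectrograms of shape (B, 1, 128, 256):
#   conv/pool x2   -> (B, 64, 32, 64)
#   adaptive pool  -> (B, 64, 4, 4)
#   flatten        -> (B, 1024)
#   projection     -> (B, 512)
#   unsqueeze      -> (B, 1, 512), a length-1 "sequence" for the transformer
#   classifier     -> (B, num_labels)
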
# Download and load the model, its configuration, and the language mappings
def load_model():
    repo_id = "hriteshMaikap/languageClassifier"

    try:
        model_path = hf_hub_download(repo_id=repo_id, filename="model.pt")
        config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
        mappings_path = hf_hub_download(repo_id=repo_id, filename="language_mappings.json")

        # Load the config
        with open(config_path, "r") as f:
            config_dict = json.load(f)

        config = AudioLanguageClassifierConfig(**config_dict)

        # Load the model weights
        model = AudioLanguageClassifier(config)
        model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
        model.eval()

        # Load the language mappings
        with open(mappings_path, "r") as f:
            mappings = json.load(f)

        id_to_language = {int(k): v for k, v in mappings["id_to_language"].items()}

        # Prefer id2label from the config when it holds real language names
        # rather than the default LABEL_0, LABEL_1, ... placeholders
        if hasattr(config, "id2label") and config.id2label:
            if not all(v == f"LABEL_{k}" for k, v in config.id2label.items()):
                id_to_language = {int(k): v for k, v in config.id2label.items()}

        return model, config, id_to_language

    except Exception as e:
        raise gr.Error(f"Failed to load the language classification model: {e}") from e

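# Expected repo layout (sketch; the language names shown are illustrative,
# not taken from the actual checkpoint):
#   model.pt                - state_dict as loaded by load_state_dict above
#   config.json             - AudioLanguageClassifierConfig as JSON
#   language_mappings.json  - e.g. {"id_to_language": {"0": "Hindi", "1": "Tamil", ...}}
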
# Prepare the feature extractor and model at startup
try:
    model, config, id_to_language = load_model()
    feature_extractor = AudioFeatureExtractor(config)
    languages = list(id_to_language.values())
except Exception as e:
    model, config, id_to_language = None, None, {}
    languages = []
    print(f"Error initializing model: {e}")

# Process one audio clip and return {language: probability} for gr.Label
def classify_language(audio):
    if model is None or config is None:
        return {"Error": 1.0}

    if audio is None:
        return {"No audio detected": 1.0}

    try:
        # Gradio delivers (sample_rate, numpy_array); stereo arrays arrive
        # with channels last, i.e. shape (num_samples, num_channels)
        sr, waveform = audio

        # Convert to a float torch tensor
        waveform = torch.tensor(waveform).float()

        # Downmix to mono and add a channel dimension: (1, num_samples)
        if waveform.ndim > 1:
            waveform = waveform.mean(dim=1)
        waveform = waveform.unsqueeze(0)

        # Resample to 16 kHz if needed
        if sr != 16000:
            resampler = torchaudio.transforms.Resample(sr, 16000)
            waveform = resampler(waveform)

        # Extract log-mel features
        features = feature_extractor(waveform, max_length=256)
        input_values = features["input_values"]

        # Pad or truncate the time axis to a fixed 256 frames
        _, height, width = input_values.shape
        max_length = 256
        if width < max_length:
            padding = torch.zeros(1, height, max_length - width)
            input_values = torch.cat([input_values, padding], dim=2)
        elif width > max_length:
            input_values = input_values[:, :, :max_length]

        # Add an explicit batch dimension: (1, 1, n_mels, 256)
        input_values = input_values.unsqueeze(0)

        # Get the prediction
        with torch.no_grad():
            outputs = model(input_values=input_values)
            logits = outputs["logits"]
            probs = torch.nn.functional.softmax(logits, dim=1)[0]

        # Top 3 predictions (or all classes if there are fewer than 3)
        num_classes = min(3, len(id_to_language))
        top_probs, top_ids = torch.topk(probs, num_classes)

        # Format results as {language: probability}
        results = {}
        for prob, pred_id in zip(top_probs, top_ids):
            lang = id_to_language.get(pred_id.item(), f"Unknown-{pred_id.item()}")
            results[lang] = float(prob)

        return results

    except Exception as e:
        gr.Warning(f"Error processing audio: {e}")
        return {"Error processing audio": 1.0}

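# Local smoke test (sketch, assuming the model loaded): one second of silence
# should come back as a dict mapping up to three language names to
# probabilities, which is exactly the format gr.Label expects.
#   print(classify_language((16000, np.zeros(16000, dtype=np.float32))))
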
# Create the Gradio interface. gr.Audio has no "tuple" type; type="numpy"
# yields the (sample_rate, array) tuple that classify_language unpacks.
demo = gr.Interface(
    fn=classify_language,
    inputs=gr.Audio(sources=["microphone", "upload"], type="numpy"),
    outputs=gr.Label(num_top_classes=3),
    title="Indian Language Identification",
    description="Record or upload audio to identify the Indian language being spoken. "
                "Supported languages: "
                + (", ".join(languages) if languages else "error loading language list"),
    examples=[],
    article="""
    <div style="text-align: center;">
        <p>This model identifies various Indian languages from audio input. For best results:</p>
        <ul style="display: inline-block; text-align: left;">
            <li>Speak clearly with minimal background noise</li>
            <li>A recording length of 3-5 seconds is ideal</li>
            <li>Speak a full sentence or phrase</li>
        </ul>
        <p>Model by <a href="https://huggingface.co/hriteshMaikap/languageClassifier" target="_blank">hriteshMaikap</a></p>
    </div>
    """
)

# Launch the app
if __name__ == "__main__":
    demo.launch()
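
# When running locally, demo.launch(share=True) (a standard Gradio option)
# would additionally expose a temporary public URL; on Hugging Face Spaces
# the plain launch above is sufficient.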