hriteshMaikap committed on
Commit 8579e22 · verified · 1 Parent(s): f3090af

Update app.py

Files changed (1)
  1. app.py +53 -263
app.py CHANGED
@@ -1,273 +1,63 @@
  import gradio as gr
  import torch
- import torch.nn as nn
  import torchaudio
  import json
- import numpy as np
- from huggingface_hub import hf_hub_download
- from transformers import PretrainedConfig, PreTrainedModel

- # Define model architecture
- class AudioLanguageClassifierConfig(PretrainedConfig):
-     model_type = "audio-language-classifier"
-
-     def __init__(
-         self,
-         num_labels=10,  # Changed from 12 to 10 to match the saved model
-         sampling_rate=16000,
-         num_mel_bins=128,
-         feature_size=512,
-         num_transformer_layers=4,
-         num_attention_heads=4,
-         intermediate_size=1024,
-         hidden_dropout_prob=0.1,
-         attention_probs_dropout_prob=0.1,
-         **kwargs
-     ):
-         super().__init__(**kwargs)
-         self.num_labels = num_labels
-         self.sampling_rate = sampling_rate
-         self.num_mel_bins = num_mel_bins
-         self.feature_size = feature_size
-         self.num_transformer_layers = num_transformer_layers
-         self.num_attention_heads = num_attention_heads
-         self.intermediate_size = intermediate_size
-         self.hidden_dropout_prob = hidden_dropout_prob
-         self.attention_probs_dropout_prob = attention_probs_dropout_prob
-
- class AudioFeatureExtractor:
-     def __init__(self, config):
-         self.config = config
-         self.mel_spectrogram = torchaudio.transforms.MelSpectrogram(
-             sample_rate=config.sampling_rate,
-             n_fft=1024,
-             hop_length=512,
-             n_mels=config.num_mel_bins
-         )
-         self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()
-
-     def __call__(self, audio_data, padding=True, max_length=None, truncation=True, **kwargs):
-         if isinstance(audio_data, np.ndarray):
-             audio_data = torch.from_numpy(audio_data)
-
-         # Ensure it's in the expected shape
-         if audio_data.ndim == 1:
-             audio_data = audio_data.unsqueeze(0)  # Add channel dimension
-
-         # Convert to mel spectrogram
-         mel_spec = self.mel_spectrogram(audio_data)
-         log_mel_spec = self.amplitude_to_db(mel_spec)
-
-         # Normalization
-         mean = log_mel_spec.mean()
-         std = log_mel_spec.std()
-         log_mel_spec = (log_mel_spec - mean) / (std + 1e-10)
-
-         # Handle max length/truncation
-         if max_length is not None and truncation and log_mel_spec.shape[-1] > max_length:
-             log_mel_spec = log_mel_spec[..., :max_length]
-
-         return {"input_values": log_mel_spec}
-
- class AudioLanguageClassifier(PreTrainedModel):
-     config_class = AudioLanguageClassifierConfig
-
-     def __init__(self, config):
-         super().__init__(config)
-         self.config = config
-
-         # CNN feature extractor
-         self.feature_extractor = nn.Sequential(
-             nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
-             nn.ReLU(),
-             nn.MaxPool2d(2, 2),
-             nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
-             nn.ReLU(),
-             nn.MaxPool2d(2, 2)
-         )
-
-         # Global average pooling to eliminate size dependency
-         self.global_pool = nn.AdaptiveAvgPool2d((4, 4))
-
-         # Fixed size after global pooling
-         self.flattened_size = 64 * 4 * 4
-
-         # Projection layer with fixed input size
-         self.projection = nn.Linear(self.flattened_size, config.feature_size)
-
-         # Transformer for sequence modeling
-         encoder_layer = nn.TransformerEncoderLayer(
-             d_model=config.feature_size,
-             nhead=config.num_attention_heads,
-             dim_feedforward=config.intermediate_size,
-             dropout=config.hidden_dropout_prob,
-             batch_first=True
-         )
-         self.transformer_encoder = nn.TransformerEncoder(
-             encoder_layer,
-             num_layers=config.num_transformer_layers
-         )
-
-         # Classification head
-         self.classifier = nn.Linear(config.feature_size, config.num_labels)
-
-     def forward(
-         self,
-         input_values=None,
-         labels=None,
-         **kwargs
-     ):
-         batch_size = input_values.size(0)
-
-         # Extract features using CNN
-         x = self.feature_extractor(input_values)
-
-         # Apply global pooling to get fixed size
-         x = self.global_pool(x)
-
-         # Flatten
-         x = x.view(batch_size, -1)
-
-         # Project to transformer dimension
-         x = self.projection(x)
-
-         # Add sequence dimension for transformer
-         x = x.unsqueeze(1)  # [batch_size, 1, feature_size]
-
-         # Transformer encoding
-         x = self.transformer_encoder(x)
-
-         # Classification
-         x = x[:, 0, :]  # Take first token representation
-         logits = self.classifier(x)
-
-         loss = None
-         if labels is not None:
-             loss_fct = nn.CrossEntropyLoss()
-             loss = loss_fct(logits, labels)
-
-         return {"loss": loss, "logits": logits} if loss is not None else {"logits": logits}
-
- # Function to load the model and its configuration
- def load_model():
-     # Download the model files
-     repo_id = "hriteshMaikap/languageClassifier"
-
-     try:
-         model_path = hf_hub_download(repo_id=repo_id, filename="model.pt")
-         config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
-         mappings_path = hf_hub_download(repo_id=repo_id, filename="language_mappings.json")
-
-         # Load the config
-         with open(config_path, "r") as f:
-             config_dict = json.load(f)
-
-         # IMPORTANT: Override num_labels to 10 since the model was trained with 10 classes
-         config_dict["num_labels"] = 10
-
-         config = AudioLanguageClassifierConfig(**config_dict)
-
-         # Load the model
-         model = AudioLanguageClassifier(config)
-         model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
-         model.eval()
-
-         # Load language mappings
-         with open(mappings_path, "r") as f:
-             mappings = json.load(f)
-
-         id_to_language = {int(k): v for k, v in mappings["id_to_language"].items()}
-
-         return model, config, id_to_language
-
-     except Exception as e:
-         gr.Warning(f"Error loading model: {e}")
-         # Return placeholders with error message
-         raise gr.Error(f"Failed to load the language classification model: {e}")
-
- # Function to process audio and make predictions
- def classify_language(audio):
-     try:
-         # Load model on first inference
-         global model, config, id_to_language, feature_extractor
-         if 'model' not in globals() or model is None:
-             model, config, id_to_language = load_model()
-             feature_extractor = AudioFeatureExtractor(config)
-
-         # Get audio data
-         sr, waveform = audio
-
-         # Convert to torch tensor
-         waveform = torch.tensor(waveform).float()
-
-         # Ensure mono
-         if waveform.ndim > 1 and waveform.shape[0] > 1:
-             waveform = torch.mean(waveform, dim=0, keepdim=True)
-         elif waveform.ndim == 1:
-             waveform = waveform.unsqueeze(0)
-
-         # Resample to 16kHz if needed
-         if sr != 16000:
-             resampler = torchaudio.transforms.Resample(sr, 16000)
-             waveform = resampler(waveform)
-
-         # Extract features
-         features = feature_extractor(waveform, max_length=256)
-         input_values = features["input_values"]
-
-         # Pad or truncate to fixed length
-         _, height, width = input_values.shape
-         max_length = 256
-         if width < max_length:
-             padding = torch.zeros(1, height, max_length - width)
-             input_values = torch.cat([input_values, padding], dim=2)
-         elif width > max_length:
-             input_values = input_values[:, :, :max_length]
-
-         # Get prediction
-         with torch.no_grad():
-             outputs = model(input_values=input_values)
-             logits = outputs["logits"]
-             probs = torch.nn.functional.softmax(logits, dim=1)[0]
-
-         # Get top predictions
-         num_classes = min(3, len(id_to_language))
-         top_probs, top_ids = torch.topk(probs, num_classes)
-
-         # Format results
-         results = {}
-         for i, (prob, pred_id) in enumerate(zip(top_probs, top_ids)):
-             lang = id_to_language.get(pred_id.item(), f"Unknown-{pred_id.item()}")
-             results[lang] = float(prob)
-
-         return results
-
-     except Exception as e:
-         return {"Error": 1.0, "Details": str(e)}
-
- # Create the Gradio interface
  demo = gr.Interface(
-     fn=classify_language,
-     # Changed type from "tuple" to "numpy" to fix the error
-     inputs=gr.Audio(sources=["microphone", "upload"], type="numpy"),
-     outputs=gr.Label(num_top_classes=3),
-     title="Indian Language Identification",
-     description="Record or upload audio to identify the Indian language being spoken.",
-     examples=[],
-     article="""
-     <div style="text-align: center;">
-         <p>This model identifies various Indian languages from audio input. For best results:</p>
-         <ul style="display: inline-block; text-align: left;">
-             <li>Speak clearly with minimal background noise</li>
-             <li>Recording length of 3-5 seconds is ideal</li>
-             <li>Make sure to speak a full sentence or phrase</li>
-         </ul>
-     </div>
-     """
  )

- # Launch the app
  if __name__ == "__main__":
-     # Initialize model as None to lazy-load on first inference
-     model, config, id_to_language, feature_extractor = None, None, None, None
  demo.launch()
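The rewritten app.py below loads model.pt, config.json and language_mappings.json from the Space's own directory (MODEL_DIR = ".") instead of fetching them with hf_hub_download as the removed code did. If the weights are also meant to stay on the hriteshMaikap/languageClassifier model repo, a small resolver could fall back to the Hub when a local copy is missing. This is only a sketch of that idea, not part of the commit; resolve_file is a hypothetical helper.

import os
from huggingface_hub import hf_hub_download  # used only by the optional fallback

def resolve_file(filename, repo_id="hriteshMaikap/languageClassifier", local_dir="."):
    # Prefer the copy bundled with the Space; otherwise pull it from the model repo.
    local_path = os.path.join(local_dir, filename)
    if os.path.exists(local_path):
        return local_path
    return hf_hub_download(repo_id=repo_id, filename=filename)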
 
  import gradio as gr
  import torch
  import torchaudio
  import json
+ import os
+
+ # Import your model architecture
+ from model import AudioLanguageClassifier, AudioLanguageClassifierConfig, AudioFeatureExtractor
+
+ MODEL_DIR = "."
+
+ # Load config and mappings
+ with open(os.path.join(MODEL_DIR, "config.json")) as f:
+     config_dict = json.load(f)
+ with open(os.path.join(MODEL_DIR, "language_mappings.json")) as f:
+     mappings = json.load(f)
+ id_to_language = {int(k): v for k, v in mappings["id_to_language"].items()}
+
+ config = AudioLanguageClassifierConfig(**config_dict)
+ model = AudioLanguageClassifier(config)
+ model.load_state_dict(torch.load(os.path.join(MODEL_DIR, "model.pt"), map_location="cpu"))
+ model.eval()
+
+ feature_extractor = AudioFeatureExtractor(config)
+ max_length = 256  # Or whatever you used in training
+
+ def predict_language(audio):
+     waveform, sample_rate = torchaudio.load(audio)
+     # Resample and mono
+     if sample_rate != 16000:
+         waveform = torchaudio.transforms.Resample(sample_rate, 16000)(waveform)
+     if waveform.shape[0] > 1:
+         waveform = torch.mean(waveform, dim=0, keepdim=True)
+     features = feature_extractor(waveform)
+     input_values = features["input_values"]
+     _, height, width = input_values.shape
+     # Pad/truncate
+     if width < max_length:
+         padding = torch.zeros(1, height, max_length - width)
+         input_values = torch.cat([input_values, padding], dim=2)
+     elif width > max_length:
+         input_values = input_values[:, :, :max_length]
+     with torch.no_grad():
+         outputs = model(input_values=input_values)
+         logits = outputs["logits"]
+         probs = torch.nn.functional.softmax(logits, dim=1)[0]
+     top_probs, top_ids = torch.topk(probs, 3)
+     results = []
+     for prob, pred_id in zip(top_probs, top_ids):
+         lang = id_to_language[pred_id.item()]
+         results.append(f"{lang}: {prob.item():.2f}")
+     return "\n".join(results)
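The new imports assume a local model.py exposing AudioLanguageClassifier, AudioLanguageClassifierConfig and AudioFeatureExtractor, presumably the classes this commit removes from app.py. Since predict_language now takes a file path (matching type="filepath" below), it can be sanity-checked outside the Gradio UI. A minimal sketch, assuming the weight and config files sit next to app.py; sample.wav is a placeholder path.

# Hypothetical smoke test, run from the Space directory; "sample.wav" is a placeholder.
from app import predict_language  # importing app also loads model.pt and the config

if __name__ == "__main__":
    print(predict_language("sample.wav"))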
  demo = gr.Interface(
+     fn=predict_language,
+     inputs=gr.Audio(source="microphone", type="filepath"),
+     outputs="text",
+     title="Indian Language Identifier",
+     description="Record audio and classify the spoken Indian language."
  )

  if __name__ == "__main__":
      demo.launch()
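One compatibility note on the Audio input: the removed interface passed sources=["microphone", "upload"] (a list), while the new one passes source="microphone" (a string). The string form matches Gradio 3.x; on Gradio 4.x the component only accepts the plural sources keyword, so if the Space pins a 4.x release the call would need to look roughly like this (a sketch, not part of the commit):

# Gradio 4.x spelling of the same input component (assumes gradio>=4 is installed).
inputs = gr.Audio(sources=["microphone"], type="filepath")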