hriteshMaikap committed on
Commit
ba7a495
·
verified ·
1 Parent(s): 8579e22

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -48
app.py CHANGED
@@ -1,55 +1,12 @@
1
  import gradio as gr
2
- import torch
3
- import torchaudio
4
- import json
5
- import os
6
 
7
- # Import your model architecture
8
- from model import AudioLanguageClassifier, AudioLanguageClassifierConfig, AudioFeatureExtractor
9
-
10
- MODEL_DIR = "."
11
-
12
- # Load config and mappings
13
- with open(os.path.join(MODEL_DIR, "config.json")) as f:
14
- config_dict = json.load(f)
15
- with open(os.path.join(MODEL_DIR, "language_mappings.json")) as f:
16
- mappings = json.load(f)
17
- id_to_language = {int(k): v for k, v in mappings["id_to_language"].items()}
18
-
19
- config = AudioLanguageClassifierConfig(**config_dict)
20
- model = AudioLanguageClassifier(config)
21
- model.load_state_dict(torch.load(os.path.join(MODEL_DIR, "model.pt"), map_location="cpu"))
22
- model.eval()
23
-
24
- feature_extractor = AudioFeatureExtractor(config)
25
- max_length = 256 # Or whatever you used in training
26
 
27
  def predict_language(audio):
28
- waveform, sample_rate = torchaudio.load(audio)
29
- # Resample and mono
30
- if sample_rate != 16000:
31
- waveform = torchaudio.transforms.Resample(sample_rate, 16000)(waveform)
32
- if waveform.shape[0] > 1:
33
- waveform = torch.mean(waveform, dim=0, keepdim=True)
34
- features = feature_extractor(waveform)
35
- input_values = features["input_values"]
36
- _, height, width = input_values.shape
37
- # Pad/truncate
38
- if width < max_length:
39
- padding = torch.zeros(1, height, max_length - width)
40
- input_values = torch.cat([input_values, padding], dim=2)
41
- elif width > max_length:
42
- input_values = input_values[:, :, :max_length]
43
- with torch.no_grad():
44
- outputs = model(input_values=input_values)
45
- logits = outputs["logits"]
46
- probs = torch.nn.functional.softmax(logits, dim=1)[0]
47
- top_probs, top_ids = torch.topk(probs, 3)
48
- results = []
49
- for prob, pred_id in zip(top_probs, top_ids):
50
- lang = id_to_language[pred_id.item()]
51
- results.append(f"{lang}: {prob.item():.2f}")
52
- return "\n".join(results)
53
 
54
  demo = gr.Interface(
55
  fn=predict_language,
 
1
  import gradio as gr
2
+ from transformers import pipeline
 
 
 
3
 
4
+ classifier = pipeline("audio-classification", model="hriteshMaikap/languageClassifier")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  def predict_language(audio):
7
+ out = classifier(audio)
8
+ # out is a list of dicts: [{'label': 'Hindi', 'score': 0.98}, ...]
9
+ return "\n".join([f"{res['label']}: {res['score']:.2f}" for res in out])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  demo = gr.Interface(
12
  fn=predict_language,