sanchit-gandhi committed
Commit 9f7a693 · 1 Parent(s): 9e87dbe

Update app.py

Files changed (1)
  1. app.py +11 -4
app.py CHANGED
@@ -1,9 +1,11 @@
 import torch
+import torch.nn.functional as F
 
 from transformers import WhisperForConditionalGeneration, WhisperProcessor
 from transformers.models.whisper.tokenization_whisper import LANGUAGES
 from transformers.pipelines.audio_utils import ffmpeg_read
 
+import librosa
 import gradio as gr
 
 
@@ -43,18 +45,20 @@ def transcribe(Microphone, File_Upload):
 
     audio_data = process_audio_file(file)
 
-    input_features = processor(sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt").input_features
+    input_features = processor(audio_data, return_tensors="pt").input_features
 
     with torch.no_grad():
        logits = model.forward(input_features, decoder_input_ids=decoder_input_ids).logits
 
     pred_ids = torch.argmax(logits, dim=-1)
+    probability = F.softmax(logits, dim=-1).max()
+
     lang_ids = processor.decode(pred_ids[0])
 
     lang_ids = lang_ids.lstrip("<|").rstrip("|>")
-    language = LANGUAGES[lang_ids]
+    language = LANGUAGES.get(lang_ids, "not detected")
 
-    return language
+    return language.capitalize(), probability.cpu().numpy()
 
 
 iface = gr.Interface(
@@ -63,7 +67,10 @@ iface = gr.Interface(
         gr.inputs.Audio(source="microphone", type='filepath', optional=True),
         gr.inputs.Audio(source="upload", type='filepath', optional=True),
     ],
-    outputs="text",
+    outputs=[
+        gr.outputs.Textbox(label="Language"),
+        gr.Number(label="Probability"),
+    ],
     layout="horizontal",
     theme="huggingface",
     title="Whisper Language Identification",