github-actions[bot] committed
Commit 166c454 · 1 Parent(s): d35ffe9

Sync with https://github.com/mozilla-ai/speech-to-text-finetune

Files changed (1): app.py +85 -35
app.py CHANGED
@@ -1,34 +1,39 @@
 import os
 import gradio as gr
 import spaces
+from huggingface_hub import get_collection, HfApi
 from transformers import pipeline, Pipeline
 
 is_hf_space = os.getenv("IS_HF_SPACE")
-model_ids = [
-    "",
-    "mozilla-ai/whisper-small-gl (Galician)",
-    "mozilla-ai/whisper-small-el (Greek)",
-    "mozilla-ai/whisper-small-fr (French)",
-    "mozilla-ai/whisper-small-sv (Swedish)",
-    "openai/whisper-tiny (Multilingual)",
-    "openai/whisper-small (Multilingual)",
-    "openai/whisper-medium (Multilingual)",
-    "openai/whisper-large-v3 (Multilingual)",
-    "openai/whisper-large-v3-turbo (Multilingual)",
-]
-
-
-def _load_local_model(model_dir: str) -> Pipeline:
-    from transformers import (
-        WhisperProcessor,
-        WhisperTokenizer,
-        WhisperFeatureExtractor,
-        WhisperForConditionalGeneration,
+
+
+def get_dropdown_model_ids():
+    mozilla_ai_model_ids = []
+    # Get model ids from collection and append the language in () from the model's metadata
+    for model_i in get_collection(
+        "mozilla-ai/common-voice-whisper-67b847a74ad7561781aa10fd"
+    ).items:
+        model_metadata = HfApi().model_info(model_i.item_id)
+        language = model_metadata.card_data.model_name.split("on ")[1]
+        mozilla_ai_model_ids.append(model_i.item_id + f" ({language})")
+
+    return (
+        [""]
+        + mozilla_ai_model_ids
+        + [
+            "openai/whisper-tiny (Multilingual)",
+            "openai/whisper-small (Multilingual)",
+            "openai/whisper-medium (Multilingual)",
+            "openai/whisper-large-v3 (Multilingual)",
+            "openai/whisper-large-v3-turbo (Multilingual)",
+        ]
     )
 
+
+def _load_local_model(model_dir: str) -> Pipeline | str:
+    from transformers import WhisperProcessor, WhisperForConditionalGeneration
+
     processor = WhisperProcessor.from_pretrained(model_dir)
-    tokenizer = WhisperTokenizer.from_pretrained(model_dir, task="transcribe")
-    feature_extractor = WhisperFeatureExtractor.from_pretrained(model_dir)
     model = WhisperForConditionalGeneration.from_pretrained(model_dir)
 
     try:
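
Note: get_dropdown_model_ids() rebuilds the same "repo_id (Language)" entries that the removed hard-coded model_ids list held, but reads them from the Hugging Face collection at startup. The split("on ")[1] parsing assumes every model card in the collection is named along the lines of "Whisper Small finetuned on Galician"; a card name without "on " would raise an IndexError. A sketch of the expected return value, with the mozilla-ai entries shown as they appeared in the old list (the real ones depend on the collection's contents at call time):

    [
        "",  # empty first entry so the dropdown starts unselected
        "mozilla-ai/whisper-small-gl (Galician)",
        "mozilla-ai/whisper-small-el (Greek)",
        "mozilla-ai/whisper-small-fr (French)",
        "mozilla-ai/whisper-small-sv (Swedish)",
        "openai/whisper-tiny (Multilingual)",
        "openai/whisper-small (Multilingual)",
        "openai/whisper-medium (Multilingual)",
        "openai/whisper-large-v3 (Multilingual)",
        "openai/whisper-large-v3-turbo (Multilingual)",
    ]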
@@ -36,29 +41,52 @@ def _load_local_model(model_dir: str) -> Pipeline:
             task="automatic-speech-recognition",
             model=model,
             processor=processor,
-            tokenizer=tokenizer,
-            feature_extractor=feature_extractor,
+            chunk_length_s=30,  # max input duration for whisper
         )
     except Exception as e:
         return str(e)
 
 
-def _load_hf_model(model_repo_id: str) -> Pipeline:
+def _load_hf_model(model_repo_id: str) -> Pipeline | str:
     try:
         return pipeline(
             "automatic-speech-recognition",
             model=model_repo_id,
+            chunk_length_s=30,  # max input duration for whisper
         )
     except Exception as e:
         return str(e)
 
 
+# Copied from https://github.com/openai/whisper/blob/517a43ecd132a2089d85f4ebc044728a71d49f6e/whisper/utils.py#L50
+def format_timestamp(
+    seconds: float, always_include_hours: bool = False, decimal_marker: str = "."
+):
+    assert seconds >= 0, "non-negative timestamp expected"
+    milliseconds = round(seconds * 1000.0)
+
+    hours = milliseconds // 3_600_000
+    milliseconds -= hours * 3_600_000
+
+    minutes = milliseconds // 60_000
+    milliseconds -= minutes * 60_000
+
+    seconds = milliseconds // 1_000
+    milliseconds -= seconds * 1_000
+
+    hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
+    return (
+        f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
+    )
+
+
 @spaces.GPU(duration=30)
 def transcribe(
     dropdown_model_id: str,
     hf_model_id: str,
     local_model_id: str,
     audio: gr.Audio,
+    show_timestamps: bool,
 ) -> str:
     if dropdown_model_id and not hf_model_id and not local_model_id:
         dropdown_model_id = dropdown_model_id.split(" (")[0]
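
Note: format_timestamp turns a number of seconds into "mm:ss.mmm", including the hours field only when it is nonzero or always_include_hours is set. A few worked values (assuming this commit's app.py is importable as app):

    from app import format_timestamp

    assert format_timestamp(63.5) == "01:03.500"
    assert format_timestamp(3723.045) == "01:02:03.045"  # hours field appears automatically
    assert format_timestamp(5.0, always_include_hours=True) == "00:00:05.000"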
@@ -74,7 +102,21 @@ def transcribe(
     if isinstance(pipe, str):
         # Exception raised when loading
         return f"⚠️ Error: {pipe}"
-    text = pipe(audio)["text"]
+
+    output = pipe(
+        audio,
+        generate_kwargs={"task": "transcribe"},
+        batch_size=16,
+        return_timestamps=show_timestamps,
+    )
+    text = output["text"]
+    if show_timestamps:
+        timestamps = output["chunks"]
+        timestamps = [
+            f"[{format_timestamp(chunk['timestamp'][0])} -> {format_timestamp(chunk['timestamp'][1])}] {chunk['text']}"
+            for chunk in timestamps
+        ]
+        text = "\n".join(str(feature) for feature in timestamps)
     return text
 
 
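Note: with return_timestamps=True, the transformers ASR pipeline returns the full text plus per-chunk (start, end) offsets in seconds, which the list comprehension above renders as one bracketed line per chunk. An illustrative output dict (values invented):

    output = {
        "text": " Hello world. How are you?",
        "chunks": [
            {"text": " Hello world.", "timestamp": (0.0, 1.4)},
            {"text": " How are you?", "timestamp": (1.4, 2.9)},
        ],
    }
    # rendered by the comprehension above as:
    # [00:00.000 -> 00:01.400]  Hello world.
    # [00:01.400 -> 00:02.900]  How are you?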
@@ -88,7 +130,7 @@ def setup_gradio_demo():
             """
         )
         ### Model selection ###
-
+        model_ids = get_dropdown_model_ids()
         with gr.Row():
             with gr.Column():
                 dropdown_model = gr.Dropdown(
@@ -106,19 +148,27 @@ def setup_gradio_demo():
             )
 
         ### Transcription ###
-        audio_input = gr.Audio(
-            sources=["microphone", "upload"],
-            type="filepath",
-            label="Record a message / Upload audio file",
-            show_download_button=True,
-            max_length=30,
-        )
+        with gr.Group():
+            audio_input = gr.Audio(
+                sources=["microphone", "upload"],
+                type="filepath",
+                label="Record a message / Upload audio file",
+                show_download_button=True,
+            )
+            timestamps_check = gr.Checkbox(label="Show timestamps")
+
         transcribe_button = gr.Button("Transcribe")
         transcribe_output = gr.Text(label="Output")
 
         transcribe_button.click(
             fn=transcribe,
-            inputs=[dropdown_model, user_model, local_model, audio_input],
+            inputs=[
+                dropdown_model,
+                user_model,
+                local_model,
+                audio_input,
+                timestamps_check,
+            ],
             outputs=transcribe_output,
         )
 
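Note: Gradio passes component values to the callback positionally, so the inputs list must keep the same order as transcribe's parameters (dropdown_model_id, hf_model_id, local_model_id, audio, show_timestamps). A quick local smoke test of the new checkbox path, assuming this commit's app.py is importable, that spaces.GPU is a no-op outside a Space, and with an illustrative audio path:

    from app import transcribe

    # transcribe() strips the " (Language)" suffix from dropdown ids itself
    print(transcribe("openai/whisper-tiny (Multilingual)", "", "", "sample.wav", True))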