lucas-ventura committed on
Commit
ea75f09
·
verified ·
1 Parent(s): 33c56ea

Update tools/extract/asr.py

Browse files
Files changed (1) hide show
  1. tools/extract/asr.py +5 -2
tools/extract/asr.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from pathlib import Path
2
 
3
  import torch
@@ -22,7 +23,6 @@ class ASRProcessor:
22
  def __init__(self, model_name="large-v2", compute_type="float16"):
23
  self.model_name = model_name
24
  # Check if whisperx.load_model accepts compute_type parameter
25
- import inspect
26
 
27
  if "compute_type" in inspect.signature(whisperx.load_model).parameters:
28
  self.model = whisperx.load_model(
@@ -34,7 +34,10 @@ class ASRProcessor:
34
  def get_asr(self, audio_file, return_duration=True):
35
  assert Path(audio_file).exists(), f"File {audio_file} does not exist"
36
  audio = whisperx.load_audio(audio_file)
37
- result = self.model.transcribe(audio, batch_size=1)
 
 
 
38
  language = result["language"]
39
  duration = audio.shape[0] / SAMPLE_RATE
40
 
 
1
+ import inspect
2
  from pathlib import Path
3
 
4
  import torch
 
23
  def __init__(self, model_name="large-v2", compute_type="float16"):
24
  self.model_name = model_name
25
  # Check if whisperx.load_model accepts compute_type parameter
 
26
 
27
  if "compute_type" in inspect.signature(whisperx.load_model).parameters:
28
  self.model = whisperx.load_model(
 
34
  def get_asr(self, audio_file, return_duration=True):
35
  assert Path(audio_file).exists(), f"File {audio_file} does not exist"
36
  audio = whisperx.load_audio(audio_file)
37
+ if "batch_size" in inspect.signature(self.model.transcribe).parameters:
38
+ result = self.model.transcribe(audio, batch_size=1)
39
+ else:
40
+ result = self.model.transcribe(audio)
41
  language = result["language"]
42
  duration = audio.shape[0] / SAMPLE_RATE
43