pheodoraa committed on
Commit
2548d5a
·
verified ·
1 Parent(s): facd705
Files changed (1) hide show
  1. app.py +9 -8
app.py CHANGED
@@ -8,13 +8,13 @@ try:
8
  asr_model = EncoderASR.from_hparams(
9
  source="speechbrain/asr-wav2vec2-dvoice-darija",
10
  savedir="tmp_model",
11
- run_opts={"device": "cpu"} # Ensure compatibility with CPU if needed
12
  )
13
  except Exception as e:
14
  print(f"Error loading model: {str(e)}")
15
 
16
  def transcribe(audio):
17
- """Transcribe audio to text using SpeechBrain ASR model."""
18
  if audio is None:
19
  return "No audio file uploaded. Please upload a valid file."
20
 
@@ -22,20 +22,21 @@ def transcribe(audio):
22
  # Load audio
23
  waveform, sample_rate = torchaudio.load(audio)
24
 
25
- # Convert to single-channel (mono) if stereo
26
  if waveform.shape[0] > 1:
27
  waveform = torch.mean(waveform, dim=0, keepdim=True)
28
 
29
- # Ensure correct sample rate (16kHz expected by the model)
30
  if sample_rate != 16000:
31
  resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
32
  waveform = resampler(waveform)
33
 
34
- # Compute waveform length as a relative fraction
35
- wav_lens = torch.tensor([waveform.shape[1] / waveform.shape[1]], dtype=torch.float32)
 
36
 
37
- # Add batch dimension (SpeechBrain expects a batch format)
38
- waveform = waveform.unsqueeze(0)
39
 
40
  # Transcribe
41
  transcription = asr_model.transcribe_batch(waveform, wav_lens)
 
8
  asr_model = EncoderASR.from_hparams(
9
  source="speechbrain/asr-wav2vec2-dvoice-darija",
10
  savedir="tmp_model",
11
+ run_opts={"device": "cpu"} # Ensures compatibility with CPU environments
12
  )
13
  except Exception as e:
14
  print(f"Error loading model: {str(e)}")
15
 
16
  def transcribe(audio):
17
+ """Transcribe uploaded audio to text using SpeechBrain ASR."""
18
  if audio is None:
19
  return "No audio file uploaded. Please upload a valid file."
20
 
 
22
  # Load audio
23
  waveform, sample_rate = torchaudio.load(audio)
24
 
25
+ # Convert stereo to mono if needed
26
  if waveform.shape[0] > 1:
27
  waveform = torch.mean(waveform, dim=0, keepdim=True)
28
 
29
+ # Resample if sample rate is not 16kHz
30
  if sample_rate != 16000:
31
  resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
32
  waveform = resampler(waveform)
33
 
34
+ # Ensure waveform is 2D (1, time_steps)
35
+ waveform = waveform.squeeze(0) # Remove channel dim if present
36
+ waveform = waveform.unsqueeze(0) # Add batch dimension -> (1, time_steps)
37
 
38
+ # Compute wav_lens as a relative fraction
39
+ wav_lens = torch.tensor([waveform.shape[1] / waveform.shape[1]], dtype=torch.float32)
40
 
41
  # Transcribe
42
  transcription = asr_model.transcribe_batch(waveform, wav_lens)