GavinHuang commited on
Commit
ac5f4c0
·
1 Parent(s): 1e2be4f

fix: enhance audio processing in transcribe function to handle tuple input and resample to 16kHz

Browse files
Files changed (2) hide show
  1. app.py +19 -23
  2. requirements.txt +2 -1
app.py CHANGED
@@ -34,30 +34,26 @@ def transcribe(audio, state=""):
34
  # Try extracting the first element
35
  audio = audio[1] if len(audio) > 1 else None
36
 
37
- if isinstance(audio, np.ndarray):
38
- audio_buffer.append(audio)
39
- # Process if buffer has enough data (e.g., 5 seconds at 16kHz)
40
- if len(np.concatenate(audio_buffer)) >= 5 * 16000:
41
- # Concatenate and preprocess
42
- audio_data = np.concatenate(audio_buffer)
43
- audio_data = audio_data.mean(axis=1) if audio_data.ndim > 1 else audio_data # To mono
44
- temp_file = "temp_audio.wav"
45
- sf.write(temp_file, audio_data, samplerate=16000)
46
- print("Transcribing audio...")
47
-
48
- # Transcribe
49
- if torch.cuda.is_available():
50
- model = model.cuda()
51
- transcription = model.transcribe([temp_file])[0]
52
- print(f"Transcription: {transcription}")
53
- model = model.cpu()
54
- os.remove(temp_file)
55
- print("Temporary file removed.")
56
 
57
- # Clear buffer
58
- audio_buffer = []
59
- new_state = state + " " + transcription if state else transcription
60
- return new_state, new_state
61
  return state, state
62
 
63
  # Define the Gradio interface
 
34
  # Try extracting the first element
35
  audio = audio[1] if len(audio) > 1 else None
36
 
37
+ if isinstance(audio, tuple) and len(audio) == 2 and isinstance(audio[1], np.ndarray):
38
+ # Handle tuple of (sample_rate, audio_array)
39
+ print(f"Tuple contents: {audio}")
40
+ sample_rate, audio_data = audio
41
+ # Resample to 16kHz for NeMo
42
+ if sample_rate != 16000:
43
+ print(f"Resampling from {sample_rate}Hz to 16000Hz")
44
+ audio_data = librosa.resample(audio_data.astype(float), orig_sr=sample_rate, target_sr=16000)
45
+ # Save to temporary WAV file
46
+ temp_file = "temp_audio.wav"
47
+ sf.write(temp_file, audio_data, samplerate=16000)
48
+ print(f"Processing temporary audio file: {temp_file}")
49
+ transcription = model.transcribe([temp_file])[0]
50
+ os.remove(temp_file) # Clean up
51
+ print("Temporary file removed.")
 
 
 
 
52
 
53
+ # Clear buffer
54
+ audio_buffer = []
55
+ new_state = state + " " + transcription if state else transcription
56
+ return new_state, new_state
57
  return state, state
58
 
59
  # Define the Gradio interface
requirements.txt CHANGED
@@ -4,4 +4,5 @@ nemo_toolkit[asr]>=1.18.0
4
  omegaconf>=2.2.0
5
  numpy>=1.22.0
6
  cuda-python>=12.3
7
- soundfile
 
 
4
  omegaconf>=2.2.0
5
  numpy>=1.22.0
6
  cuda-python>=12.3
7
+ soundfile
8
+ librosa