# SincVAD_Demo / app.py
import torch
import torch.nn as nn
import gradio as gr
import numpy as np
import torchaudio
import torchaudio.transforms as T
import time
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties
from model.tinyvad import TinyVAD
# Font configuration
font_path = './fonts/Times_New_Roman.ttf'
font_prop = FontProperties(fname=font_path, size=18)
# Model and Processing Parameters
WINDOW_SIZE = 0.63
SINC_CONV = False
SSM = False
TARGET_SAMPLE_RATE = 16000
# Model Initialization
model = TinyVAD(1, 32, 64, patch_size=8, num_blocks=2,
                sinc_conv=SINC_CONV, ssm=SSM)
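# NOTE: the positional args (1, 32, 64) follow the TinyVAD signature in
# model/tinyvad.py (not shown here); 1 matches the single-channel log-mel
# input, while 32/64 are assumed to be internal channel widths.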
checkpoint_path = './sincvad.ckpt'
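# weights_only=True restricts torch.load to tensor data (no arbitrary pickle
# code); strict=False below tolerates checkpoint/model key mismatches instead
# of raising on them.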
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'), weights_only=True)
model.load_state_dict(checkpoint, strict=False)
model.eval()
# Audio Processing Transforms
mel_spectrogram = T.MelSpectrogram(sample_rate=TARGET_SAMPLE_RATE, n_mels=64, win_length=400, hop_length=160)
log_mel_spectrogram = T.AmplitudeToDB()
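# At 16 kHz these settings give a 25 ms analysis window (400 samples) with a
# 10 ms hop (160 samples), so one 0.63 s chunk produces a 64-frame x 64-mel input.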
# Chunking Parameters
chunk_duration = WINDOW_SIZE
shift_duration = WINDOW_SIZE * 0.875  # hop is 87.5% of the window, i.e. 12.5% overlap between consecutive chunks
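# At 16 kHz: chunk = 10080 samples (0.63 s), shift = 8820 samples (~0.55 s),
# so consecutive chunks overlap by 1260 samples (~79 ms).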
def predict(audio_input, threshold):
    """
    Predict voice activity in an audio file with detailed processing and visualization.

    Args:
        audio_input (str): Path to the audio file
        threshold (float): Decision threshold for speech/non-speech classification

    Yields:
        Intermediate and final prediction results
    """
    start_time = time.time()

    try:
        # Load and preprocess audio
        waveform, orig_sample_rate = torchaudio.load(audio_input)

        # Resample if necessary
        if orig_sample_rate != TARGET_SAMPLE_RATE:
            print(f"Resampling from {orig_sample_rate} Hz to {TARGET_SAMPLE_RATE} Hz")
            resampler = T.Resample(orig_freq=orig_sample_rate, new_freq=TARGET_SAMPLE_RATE)
            waveform = resampler(waveform)

        # Ensure mono channel
        if waveform.size(0) > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
    except Exception as e:
        print(f"Error loading audio file: {e}")
        yield "Error loading audio file.", None, None, None
        return
    # Audio duration checks and padding
    audio_duration = waveform.size(1) / TARGET_SAMPLE_RATE
    print(f"Audio duration: {audio_duration:.2f} seconds")
    print(f"Original sample rate: {orig_sample_rate} Hz")
    print(f"Current sample rate: {TARGET_SAMPLE_RATE} Hz")

    if audio_duration < chunk_duration:
        required_length = int(chunk_duration * TARGET_SAMPLE_RATE)
        padding_length = required_length - waveform.size(1)
        waveform = torch.nn.functional.pad(waveform, (0, padding_length))

    # Chunk processing parameters
    chunk_size = int(chunk_duration * TARGET_SAMPLE_RATE)
    shift_size = int(shift_duration * TARGET_SAMPLE_RATE)
    num_chunks = (waveform.size(1) - chunk_size) // shift_size + 1
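    # e.g. a 5 s clip: (80000 - 10080) // 8820 + 1 = 8 chunks; a final partial
    # window, if any, is skipped by the size guard inside the loop below.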
    predictions = []
    time_stamps = []
    detailed_predictions = []

    # Initialize plot
    fig, ax = plt.subplots(figsize=(12, 5))
    ax.set_xlabel('Time (seconds)', fontproperties=font_prop)
    ax.set_ylabel('Probability', fontproperties=font_prop)
    ax.set_title('Voice Activity Detection Probability Over Time', fontproperties=font_prop)
    ax.axhline(y=threshold, color='tab:red', linestyle='--', label='Threshold')
    ax.grid(True)
    ax.set_ylim([-0.05, 1.05])

    # Process audio in chunks
    for i in range(num_chunks):
        start_idx = i * shift_size
        end_idx = start_idx + chunk_size
        chunk = waveform[:, start_idx:end_idx]
        if chunk.size(1) < chunk_size:
            break

        # Feature extraction
        inputs = mel_spectrogram(chunk)
        inputs = log_mel_spectrogram(inputs).unsqueeze(0)
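        # chunk (1, 10080) -> log-mel (1, 64, 64); unsqueeze adds the batch dim,
        # giving the (batch, channel, n_mels, frames) layout the model consumes.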
        # Model inference
        with torch.no_grad():
            outputs = model(inputs)
            outputs = torch.sigmoid(outputs)
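        # The model emits a single logit per chunk (hence .item() below);
        # sigmoid maps it to a speech probability in (0, 1).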
        # Process outputs
        predictions.append(outputs.item())
        time_stamps.append(start_idx / TARGET_SAMPLE_RATE)
        detailed_predictions.append({
            'start_time': start_idx / TARGET_SAMPLE_RATE,
            'output': outputs.item(),
        })

        # Update plot dynamically
        ax.clear()
        ax.set_xlabel('Time (seconds)', fontproperties=font_prop)
        ax.set_ylabel('Probability', fontproperties=font_prop)
        ax.set_title('Speech Probability Over Time', fontproperties=font_prop)
        ax.axhline(y=threshold, color='tab:red', linestyle='--', label='Threshold')
        ax.grid(True)
        ax.set_ylim([-0.05, 1.05])
        ax.plot(time_stamps, predictions, label='Speech Probability', color='tab:blue')
        plt.tight_layout()

        # Yield intermediate progress
        yield "Processing...", None, None, fig
    # Detailed logging
    print("Detailed Predictions:")
    for pred in detailed_predictions:
        print(f"Start Time: {pred['start_time']:.2f}s, Output: {pred['output']:.4f}")

    # Final prediction processing
    avg_output = max(0, min(1, np.mean(predictions)))
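    # Sigmoid outputs are already in (0, 1), so the clamp is purely defensive
    # (e.g. against float round-off in the mean).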
    prediction_time = time.time() - start_time

    prediction = "Speech" if avg_output > threshold else "Non-speech"
    probability = f'{(float(avg_output) * 100):.2f}'
    inference_time = f'{prediction_time:.4f}'

    print(f"Final Prediction: {prediction}")
    print(f"Average Probability: {probability}%")
    print(f"Number of chunks processed: {num_chunks}")

    # Final result
    yield prediction, probability, inference_time, fig
# Gradio Interface
with gr.Blocks() as demo:
    gr.Image("./img/logo.png", elem_id="logo", height=100)

    # Title and Description
    gr.Markdown("<h1 style='text-align: center; color: black;'>Voice Activity Detection using SincVAD</h1>")
    gr.Markdown("<h3 style='text-align: center; color: black;'>Upload or record audio to predict speech activity and view the probability curve.</h3>")

    # Interface Layout
    with gr.Row():
        with gr.Column():
            # Separate recording and file upload
            audio_input = gr.Audio(type="filepath", label="Upload or Record Audio")
            threshold_input = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.1, label="Threshold")
        with gr.Column():
            prediction_output = gr.Textbox(label="Prediction")
            probability_output = gr.Number(label="Average Probability (%)")
            time_output = gr.Textbox(label="Inference Time (seconds)")
            plot_output = gr.Plot(label="Probability Curve")

    # Prediction Trigger
    predict_btn = gr.Button("Start Prediction")
    predict_btn.click(
        predict,
        [audio_input, threshold_input],
        [prediction_output, probability_output, time_output, plot_output],
        api_name="predict"
    )
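    # Each yield from predict() maps positionally onto the four outputs above.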
# Launch Configuration
if __name__ == "__main__":
    demo.queue()  # Enable queue to support generators
    demo.launch()