Mimi-playground / app.py
thomwolf's picture
thomwolf HF Staff
update for zero-gpu
0c8c55f
raw
history blame
3.35 kB
import gradio as gr
import random
import time
from huggingface_hub import hf_hub_download
import numpy as np
import sphn
import torch
import spaces
from moshi.models import loaders
# Prefer a GPU whenever one is visible; otherwise run everything on the CPU.
device = "cuda" if torch.cuda.device_count() > 0 else "cpu"
# Load all 32 codebooks so the demo can decode using any prefix of them.
num_codebooks = 32


def _load_mimi():
    """Fetch the Mimi tokenizer checkpoint from the Hub and build the model.

    The model is created on ``device`` with ``num_codebooks`` codebooks and
    switched to eval mode before being returned.

    Returns:
        Tuple ``(checkpoint_path, model)``: the local path of the downloaded
        safetensors file and the loaded Mimi model.
    """
    print("loading mimi")
    checkpoint_path = hf_hub_download(
        loaders.DEFAULT_REPO, "tokenizer-e351c8d8-checkpoint125.safetensors"
    )
    model = loaders.get_mimi(checkpoint_path, device, num_codebooks=num_codebooks)
    model.eval()
    print("mimi loaded")
    return checkpoint_path, model


model_file, mimi = _load_mimi()
@spaces.GPU
def mimi_streaming_test(input_wave, max_duration_sec=10.0, codebook_counts=(1, 2, 4, 32)):
    """Round-trip an audio file through Mimi and decode it at several codebook depths.

    The audio is read, downmixed to a single channel, resampled to Mimi's
    sample rate and truncated to ``max_duration_sec``. It is then encoded one
    frame-sized chunk at a time inside Mimi's streaming context, as a live
    stream would be. The resulting codes are decoded once per entry of
    ``codebook_counts``, each time keeping only the first ``k`` codebooks.

    Args:
        input_wave: Path to an audio file (Gradio ``type="filepath"``), or
            ``None`` when the user submitted without any audio.
        max_duration_sec: Maximum amount of audio, in seconds, to process.
        codebook_counts: Number of leading codebooks kept for each decoded
            output. Values larger than the number of available codebooks
            simply use all of them. The default must stay in sync with the
            Gradio output components.

    Returns:
        A list of ``(sample_rate, waveform)`` tuples, one per entry of
        ``codebook_counts``, where ``waveform`` is a 1-D numpy array.

    Raises:
        gr.Error: If no audio was provided, or it is too short to yield a
            single Mimi frame.
    """
    if input_wave is None:
        raise gr.Error("Please record or upload some audio first.")

    sample_rate = mimi.sample_rate
    # Number of input samples Mimi consumes per output frame.
    pcm_chunk_size = int(sample_rate / mimi.frame_rate)

    # wget https://github.com/metavoiceio/metavoice-src/raw/main/assets/bria.mp3
    sample_pcm, sample_sr = sphn.read(input_wave)
    print("loaded pcm", sample_pcm.shape, sample_sr)
    # Mimi takes a single channel. Collapse any channel layout to (1, samples)
    # by averaging channels; for mono input this is a no-op.
    sample_pcm = sample_pcm.reshape(-1, sample_pcm.shape[-1]).mean(axis=0, keepdims=True)
    sample_pcm = sphn.resample(
        sample_pcm, src_sample_rate=sample_sr, dst_sample_rate=sample_rate
    )

    max_duration_len = int(sample_rate * max_duration_sec)
    # Slicing past the end is harmless, so no length check is needed.
    sample_pcm = torch.tensor(sample_pcm, device=device)[..., :max_duration_len]
    print("resampled pcm", sample_pcm.shape, sample_rate)
    # Add the batch dimension: (batch=1, channels=1, samples).
    sample_pcm = sample_pcm[None]

    print("streaming encoding...")
    start_time = time.time()
    all_codes = []
    # Keep Mimi's streaming state alive across chunks so every frame is encoded
    # with the context of the previous ones. The state is reset when the
    # context exits, so successive requests do not leak state into each other.
    with torch.no_grad(), mimi.streaming(1):
        for start_idx in range(0, sample_pcm.shape[-1], pcm_chunk_size):
            chunk = sample_pcm[..., start_idx:start_idx + pcm_chunk_size]
            codes = mimi.encode(chunk)
            # A chunk may yield no frame (presumably a partial trailing chunk
            # still buffered), so only keep non-empty outputs.
            if codes.shape[-1]:
                print(start_idx, codes.shape, end="\r")
                all_codes.append(codes)

    if not all_codes:
        raise gr.Error("The audio is too short to be encoded by Mimi.")
    all_codes_th = torch.cat(all_codes, dim=-1)
    print(f"codes {all_codes_th.shape} generated in {time.time() - start_time:.2f}s")

    pcm_list = []
    with torch.no_grad():
        for n_codebooks in codebook_counts:
            # Keep only the first n_codebooks residual codebooks.
            prefix_codes = all_codes_th[:, :n_codebooks, :]
            print(f"decoding {prefix_codes.shape[1]} codebooks, {prefix_codes.shape}")
            pcm = mimi.decode(prefix_codes)
            pcm_list.append((sample_rate, pcm[0, 0].cpu().numpy()))
    return pcm_list
# Input: a microphone recording or an uploaded file, passed to the handler as a path.
_audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath")
# One player per decoded codebook depth, in the same order as the list the
# handler returns (1, 2, 4 and 32 codebooks).
_audio_outputs = [gr.Audio(type="numpy", label="With 1 codebook")]
_audio_outputs += [
    gr.Audio(type="numpy", label=f"With {count} codebooks") for count in (2, 4, 32)
]

demo = gr.Interface(
    fn=mimi_streaming_test,
    inputs=_audio_input,
    outputs=_audio_outputs,
    examples=[["hello.mp3"]],
    title="Mimi tokenizer playground",
    description="Explore the quality of compression when using various number of code books in the Mimi model.",
)
demo.launch()