mrfakename's picture
Upload 51 files
82bc972 verified
# cp from https://github.com/lifeiteng/vall-e/blob/main/valle/data/tokenizer.py
# Copyright 2023 (authors: Feiteng Li)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Optional, Pattern, Union
import numpy as np
import torch
import torchaudio
# from encodec import EncodecModel
# from encodec.utils import convert_audio
# from lhotse.features import FeatureExtractor
# from lhotse.utils import Seconds, compute_num_frames
from phonemizer.backend import EspeakBackend
from phonemizer.backend.espeak.language_switch import LanguageSwitch
from phonemizer.backend.espeak.words_mismatch import WordMismatch
from phonemizer.punctuation import Punctuation
from phonemizer.separator import Separator
try:
from pypinyin import Style, pinyin
from pypinyin.style._utils import get_finals, get_initials
except Exception:
pass
class PypinyinBackend:
"""PypinyinBackend for Chinese. Most codes is referenced from espnet.
There are two types pinyin or initials_finals, one is
just like "ni1 hao3", the other is like "n i1 h ao3".
"""
def __init__(
self,
backend="initials_finals",
punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
) -> None:
self.backend = backend
self.punctuation_marks = punctuation_marks
def phonemize(
self, text: List[str], separator: Separator, strip=True, njobs=1
) -> List[str]:
assert isinstance(text, List)
phonemized = []
for _text in text:
_text = re.sub(" +", " ", _text.strip())
_text = _text.replace(" ", separator.word)
phones = []
if self.backend == "pypinyin":
for n, py in enumerate(
pinyin(
_text, style=Style.TONE3, neutral_tone_with_five=True
)
):
if all([c in self.punctuation_marks for c in py[0]]):
if len(phones):
assert phones[-1] == separator.syllable
phones.pop(-1)
phones.extend(list(py[0]))
else:
phones.extend([py[0], separator.syllable])
elif self.backend == "pypinyin_initials_finals":
for n, py in enumerate(
pinyin(
_text, style=Style.TONE3, neutral_tone_with_five=True
)
):
if all([c in self.punctuation_marks for c in py[0]]):
if len(phones):
assert phones[-1] == separator.syllable
phones.pop(-1)
phones.extend(list(py[0]))
else:
if py[0][-1].isalnum():
initial = get_initials(py[0], strict=False)
if py[0][-1].isdigit():
final = (
get_finals(py[0][:-1], strict=False)
+ py[0][-1]
)
else:
final = get_finals(py[0], strict=False)
phones.extend(
[
initial,
separator.phone,
final,
separator.syllable,
]
)
else:
assert ValueError
else:
raise NotImplementedError
phonemized.append(
"".join(phones).rstrip(f"{separator.word}{separator.syllable}")
)
return phonemized
class TextTokenizer:
"""Phonemize Text."""
def __init__(
self,
language="en-us",
backend="espeak",
separator=Separator(word="_", syllable="-", phone="|"),
preserve_punctuation=True,
punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
with_stress: bool = False,
tie: Union[bool, str] = False,
language_switch: LanguageSwitch = "keep-flags",
words_mismatch: WordMismatch = "ignore",
) -> None:
if backend == "espeak":
phonemizer = EspeakBackend(
language,
punctuation_marks=punctuation_marks,
preserve_punctuation=preserve_punctuation,
with_stress=with_stress,
tie=tie,
language_switch=language_switch,
words_mismatch=words_mismatch,
)
elif backend in ["pypinyin", "pypinyin_initials_finals"]:
phonemizer = PypinyinBackend(
backend=backend,
punctuation_marks=punctuation_marks + separator.word,
)
else:
raise NotImplementedError(f"{backend}")
self.backend = phonemizer
self.separator = separator
def to_list(self, phonemized: str) -> List[str]:
fields = []
for word in phonemized.split(self.separator.word):
# "ɐ m|iː|n?" ɹ|ɪ|z|ɜː|v; h|ɪ|z.
pp = re.findall(r"\w+|[^\w\s]", word, re.UNICODE)
fields.extend(
[p for p in pp if p != self.separator.phone]
+ [self.separator.word]
)
assert len("".join(fields[:-1])) == len(phonemized) - phonemized.count(
self.separator.phone
)
return fields[:-1]
def __call__(self, text, strip=True) -> List[List[str]]:
if isinstance(text, str):
text = [text]
phonemized = self.backend.phonemize(
text, separator=self.separator, strip=strip, njobs=1
)
return [self.to_list(p) for p in phonemized]
def tokenize_text(tokenizer: TextTokenizer, text: str) -> List[str]:
phonemes = tokenizer([text.strip()])
return phonemes[0] # k2symbols
def remove_encodec_weight_norm(model):
from encodec.modules import SConv1d
from encodec.modules.seanet import SConvTranspose1d, SEANetResnetBlock
from torch.nn.utils import remove_weight_norm
encoder = model.encoder.model
for key in encoder._modules:
if isinstance(encoder._modules[key], SEANetResnetBlock):
remove_weight_norm(encoder._modules[key].shortcut.conv.conv)
block_modules = encoder._modules[key].block._modules
for skey in block_modules:
if isinstance(block_modules[skey], SConv1d):
remove_weight_norm(block_modules[skey].conv.conv)
elif isinstance(encoder._modules[key], SConv1d):
remove_weight_norm(encoder._modules[key].conv.conv)
decoder = model.decoder.model
for key in decoder._modules:
if isinstance(decoder._modules[key], SEANetResnetBlock):
remove_weight_norm(decoder._modules[key].shortcut.conv.conv)
block_modules = decoder._modules[key].block._modules
for skey in block_modules:
if isinstance(block_modules[skey], SConv1d):
remove_weight_norm(block_modules[skey].conv.conv)
elif isinstance(decoder._modules[key], SConvTranspose1d):
remove_weight_norm(decoder._modules[key].convtr.convtr)
elif isinstance(decoder._modules[key], SConv1d):
remove_weight_norm(decoder._modules[key].conv.conv)
# class AudioTokenizer:
# """EnCodec audio."""
# def __init__(
# self,
# bandwidth, float=6.0,
# device: Any = None,
# ) -> None:
# # Instantiate a pretrained EnCodec model
# model = EncodecModel.encodec_model_24khz()
# model.set_target_bandwidth(bandwidth=bandwidth)
# remove_encodec_weight_norm(model)
# if not device:
# device = torch.device("cpu")
# if torch.cuda.is_available():
# device = torch.device("cuda:0")
# self._device = device
# self.codec = model.to(device)
# self.sample_rate = model.sample_rate
# self.channels = model.channels
# @property
# def device(self):
# return self._device
# def encode(self, wav: torch.Tensor) -> torch.Tensor:
# return self.codec.encode(wav.to(self.device))
# def decode(self, frames: torch.Tensor) -> torch.Tensor:
# return self.codec.decode(frames)
# class AudioTokenizer:
# """EnCodec audio."""
# def __init__(
# self,
# bandwidth: float=6.0,
# device: Any = None,
# hificodec=False,
# signature = None
# ) -> None:
# self.hificodec = hificodec
# self.customized = True if signature != None else False
# if hificodec:
# import sys
# sys.path.append("/home/pyp/AcademiCodec")
# from academicodec.models.hificodec.vqvae import VQVAE
# config_path = "/home/pyp/AcademiCodec/egs/HiFi-Codec-16k-320d/config_16k_320d.json"
# model_path = "/home/pyp/AcademiCodec/egs/HiFi-Codec-16k-320d/checkpoint/HiFi-Codec-16k-320d"
# self.sample_rate = 16000
# self.channels = 1
# model = VQVAE(config_path, model_path, with_encoder=True)
# model.generator.remove_weight_norm()
# model.encoder.remove_weight_norm()
# model.eval()
# else:
# if signature != None:
# # use customized encodec model
# # import sys
# # sys.path.append("home/pyp/audiocraft")
# from audiocraft.solvers import CompressionSolver
# model_path = f'//sig/{signature}'
# model = CompressionSolver.model_from_checkpoint(model_path)
# self.sample_rate = model.sample_rate
# self.channels = model.channels
# else:
# # Instantiate a pretrained EnCodec model
# model = EncodecModel.encodec_model_24khz()
# model.set_target_bandwidth(bandwidth=bandwidth)
# remove_encodec_weight_norm(model)
# self.sample_rate = model.sample_rate
# self.channels = model.channels
# if not device:
# device = torch.device("cpu")
# if torch.cuda.is_available():
# device = torch.device("cuda:0")
# self._device = device
# self.codec = model.to(device)
# @property
# def device(self):
# return self._device
# def encode(self, wav: torch.Tensor) -> torch.Tensor:
# if self.hificodec:
# assert wav.ndim==3 and wav.shape[:2] == torch.Size((1,1)), wav.shape
# wav = wav.squeeze(0)
# codes = self.codec.encode(wav.to(self.device)) # [1,T,4]
# return [(codes.transpose(2,1),None)]
# elif self.customized:
# codes = self.codec.encode(wav.to(self.device))
# return [(codes[0], None)]
# return self.codec.encode(wav.to(self.device))
# def decode(self, frames: torch.Tensor) -> torch.Tensor:
# if self.hificodec:
# frames = frames[0][0] # [1,4,T]
# assert frames.shape[:2] == torch.Size((1,4))
# audio = self.codec(frames.transpose(2,1))
# assert audio.shape[0] == 1, audio.shape
# return audio
# elif self.customized:
# frames = frames[0][0] # [1,4,T]
# return self.codec.decode(frames)
# return self.codec.decode(frames)
# # try:
# # return self.codec.decode(frames)
# # except:
# # import logging
# # logging.info(f"error when decoding frame of shape: {frames[0][0].shape}")
# # self.codec.cpu()
# # ret = self.codec.cpu().decode([(frames[0][0].cpu(),None)])[0].to(self._device)
# # self.codec.to(self._device)
# # return [ret]
# def tokenize_audio(tokenizer: AudioTokenizer, audio_path: str, offset = -1, num_frames=-1):
# # Load and pre-process the audio waveform
# if offset != -1 and num_frames!=-1:
# wav, sr = torchaudio.load(audio_path, frame_offset=offset, num_frames=num_frames)
# else:
# wav, sr = torchaudio.load(audio_path)
# wav = convert_audio(wav, sr, tokenizer.sample_rate, tokenizer.channels)
# wav = wav.unsqueeze(0)
# # Extract discrete codes from EnCodec
# with torch.no_grad():
# encoded_frames = tokenizer.encode(wav)
# return encoded_frames
# @dataclass
# class AudioTokenConfig:
# frame_shift: Seconds = 320.0 / 24000
# num_quantizers: int = 8
# def to_dict(self) -> Dict[str, Any]:
# return asdict(self)
# @staticmethod
# def from_dict(data: Dict[str, Any]) -> "AudioTokenConfig":
# return AudioTokenConfig(**data)
# class AudioTokenExtractor(FeatureExtractor):
# name = "encodec"
# config_type = AudioTokenConfig
# def __init__(self, config: Optional[Any] = None):
# super(AudioTokenExtractor, self).__init__(config)
# self.tokenizer = AudioTokenizer()
# def extract(
# self, samples: Union[np.ndarray, torch.Tensor], sampling_rate: int
# ) -> np.ndarray:
# if not isinstance(samples, torch.Tensor):
# samples = torch.from_numpy(samples)
# if sampling_rate != self.tokenizer.sample_rate:
# samples = convert_audio(
# samples,
# sampling_rate,
# self.tokenizer.sample_rate,
# self.tokenizer.channels,
# )
# if len(samples.shape) == 2:
# samples = samples.unsqueeze(0)
# else:
# raise ValueError()
# device = self.tokenizer.device
# encoded_frames = self.tokenizer.encode(samples.detach().to(device))
# codes = encoded_frames[0][0] # [B, n_q, T]
# if True:
# duration = round(samples.shape[-1] / sampling_rate, ndigits=12)
# expected_num_frames = compute_num_frames(
# duration=duration,
# frame_shift=self.frame_shift,
# sampling_rate=sampling_rate,
# )
# assert abs(codes.shape[-1] - expected_num_frames) <= 1
# codes = codes[..., :expected_num_frames]
# return codes.cpu().squeeze(0).permute(1, 0).numpy()
# @property
# def frame_shift(self) -> Seconds:
# return self.config.frame_shift
# def feature_dim(self, sampling_rate: int) -> int:
# return self.config.num_quantizers
# def pad_tensor_list(self, tensor_list, device, padding_value=0):
# # 计算每个张量的长度
# lengths = [tensor.shape[0] for tensor in tensor_list]
# # 使用pad_sequence函数进行填充
# tensor_list = [torch.Tensor(t).to(device) for t in tensor_list]
# padded_tensor = torch.nn.utils.rnn.pad_sequence(
# tensor_list, batch_first=True, padding_value=padding_value
# )
# return padded_tensor, lengths
# def extract_batch(self, samples, sampling_rate, lengths) -> np.ndarray:
# samples = [wav.squeeze() for wav in samples]
# device = self.tokenizer.device
# samples, lengths = self.pad_tensor_list(samples, device)
# samples = samples.unsqueeze(1)
# if not isinstance(samples, torch.Tensor):
# samples = torch.from_numpy(samples)
# if len(samples.shape) != 3:
# raise ValueError()
# if sampling_rate != self.tokenizer.sample_rate:
# samples = [
# convert_audio(
# wav,
# sampling_rate,
# self.tokenizer.sample_rate,
# self.tokenizer.channels,
# )
# for wav in samples
# ]
# # Extract discrete codes from EnCodec
# with torch.no_grad():
# encoded_frames = self.tokenizer.encode(samples.detach().to(device))
# encoded_frames = encoded_frames[0][0] # [B, n_q, T]
# batch_codes = []
# for b, length in enumerate(lengths):
# codes = encoded_frames[b]
# duration = round(length / sampling_rate, ndigits=12)
# expected_num_frames = compute_num_frames(
# duration=duration,
# frame_shift=self.frame_shift,
# sampling_rate=sampling_rate,
# )
# batch_codes.append(codes[..., :expected_num_frames])
# return [codes.cpu().permute(1, 0).numpy() for codes in batch_codes]
if __name__ == "__main__":
model = EncodecModel.encodec_model_24khz()
model.set_target_bandwidth(6.0)
# model.cuda()
samples = torch.from_numpy(np.random.random([4, 1, 30000])).type(torch.float32)
codes_norm = model.encode(samples.cuda())
remove_encodec_weight_norm(model)
codes_raw = model.encode(samples.cuda())
assert torch.allclose(codes_raw[0][0], codes_norm[0][0])