|
|
|
import argparse |
|
import csv |
|
import os |
|
import random |
|
from statistics import StatisticsError, mean, median, mode, stdev |
|
|
|
import matplotlib.pyplot as plt |
|
import seaborn as sns |
|
from text.cmudict import CMUDict |
|
|
|
|
|
def get_audio_seconds(frames, ms_per_frame=12.5):
    """Convert a frame count into a duration in seconds.

    Args:
        frames: Number of frames.
        ms_per_frame: Milliseconds represented by a single frame.  Defaults
            to 12.5, the value this function previously hard-coded, so
            existing callers get the same result.

    Returns:
        The duration in seconds (``frames * ms_per_frame / 1000``).
    """
    return (frames * ms_per_frame) / 1000
|
|
|
|
|
def append_data_statistics(meta_data):
    """Attach audio-length summary statistics to every bucket of ``meta_data``.

    For each entry, the ``"audio_len"`` values of the items under ``"data"``
    are summarised and written back on the same entry as ``"mean"``,
    ``"median"``, ``"mode"`` and ``"std"``.

    Args:
        meta_data: Mapping whose values are dicts holding a ``"data"`` list
            of dicts, each with an ``"audio_len"`` number.

    Returns:
        The same ``meta_data`` object, modified in place.
    """
    for bucket in meta_data.values():
        lengths = [sample["audio_len"] for sample in bucket["data"]]
        average = mean(lengths)

        # The mode is taken over lengths rounded to two decimals; if
        # ``statistics.mode`` raises, fall back to the first raw length.
        try:
            most_common = mode([round(length, 2) for length in lengths])
        except StatisticsError:
            most_common = lengths[0]

        middle = median(lengths)

        # ``stdev`` raises for fewer than two samples; report 0 in that case.
        try:
            spread = stdev(lengths)
        except StatisticsError:
            spread = 0

        bucket["mean"] = average
        bucket["median"] = middle
        bucket["mode"] = most_common
        bucket["std"] = spread
    return meta_data
|
|
|
|
|
def process_meta_data(path):
    """Read a ``|``-delimited training file and group utterances by length.

    Every non-blank row contributes its third column (parsed as an integer
    frame count) and its fourth column (the utterance text).  Rows are
    grouped under ``len(utterance)`` and the grouped data is then passed
    through ``append_data_statistics``.

    Args:
        path: Path of the training file to read (opened as UTF-8).

    Returns:
        Dict mapping character count to a dict with a ``"data"`` list (plus
        whatever ``append_data_statistics`` adds).  Each ``"data"`` item holds
        ``"utt"``, ``"frames"``, ``"audio_len"`` and ``"row"`` (the first four
        columns re-joined with ``|``).

    Raises:
        ValueError: If the third column of a row is not an integer.
        IndexError: If a non-blank row has fewer than four columns.
    """
    meta_data = {}

    with open(path, "r", encoding="utf-8") as f:
        for row in csv.reader(f, delimiter="|"):
            # csv.reader yields an empty list for a blank line; skip those
            # instead of failing with IndexError on row[2].
            if not row:
                continue

            frames = int(row[2])
            utt = row[3]
            bucket = meta_data.setdefault(len(utt), {"data": []})
            bucket["data"].append(
                {
                    "utt": utt,
                    "frames": frames,
                    "audio_len": get_audio_seconds(frames),
                    "row": "{}|{}|{}|{}".format(row[0], row[1], row[2], row[3]),
                }
            )

    return append_data_statistics(meta_data)
|
|
|
|
|
def get_data_points(meta_data):
    """Collect per-bucket statistics into parallel lists for plotting.

    Args:
        meta_data: Mapping from character count to a dict containing
            ``"mean"``, ``"mode"``, ``"median"``, ``"std"`` and a ``"data"``
            list.

    Returns:
        Dict with the keys ``"x"``, ``"y_avg"``, ``"y_mode"``, ``"y_median"``,
        ``"y_std"`` and ``"y_num_samples"``.  All values are lists of equal
        length in the iteration order of ``meta_data``; ``"x"`` holds the
        character counts (the keys of ``meta_data``).
    """
    buckets = list(meta_data.values())
    return {
        # A real list of keys (not the mapping itself) so that "x" lines up
        # element-for-element with the y lists below.
        "x": list(meta_data),
        "y_avg": [bucket["mean"] for bucket in buckets],
        "y_mode": [bucket["mode"] for bucket in buckets],
        "y_median": [bucket["median"] for bucket in buckets],
        "y_std": [bucket["std"] for bucket in buckets],
        "y_num_samples": [len(bucket["data"]) for bucket in buckets],
    }
|
|
|
|
|
def save_training(file_path, meta_data):
    """Write every stored row to ``file_path`` in shuffled order.

    Args:
        file_path: Destination file; opened with mode ``"w+"`` as UTF-8.
        meta_data: Mapping whose values hold a ``"data"`` list of dicts, each
            with a ``"row"`` string.  One line is written per row.
    """
    lines = [
        sample["row"] + "\n"
        for bucket in meta_data.values()
        for sample in bucket["data"]
    ]

    random.shuffle(lines)
    with open(file_path, "w+", encoding="utf-8") as out:
        out.writelines(lines)
|
|
|
|
|
def plot(meta_data, save_path=None):
    """Draw one figure per statistic, each plotted against character length.

    Five figures are created in this order: mean, mode, median, standard
    deviation and number of samples.  Figures are neither shown nor closed
    here.

    Args:
        meta_data: Grouped data accepted by ``get_data_points``.
        save_path: Optional directory.  When truthy, each figure is saved
            there under a fixed base name passed to ``plt.savefig``.
    """
    graph_data = get_data_points(meta_data)
    x = graph_data["x"]

    # (y values, y-axis label, output base name) for each figure, in draw order.
    charts = (
        (graph_data["y_avg"], "avg seconds", "char_len_vs_avg_secs"),
        (graph_data["y_mode"], "mode seconds", "char_len_vs_mode_secs"),
        (graph_data["y_median"], "median seconds", "char_len_vs_med_secs"),
        (graph_data["y_std"], "standard deviation", "char_len_vs_std"),
        (graph_data["y_num_samples"], "number of samples", "char_len_vs_num_samples"),
    )

    for values, y_label, file_name in charts:
        plt.figure()
        plt.plot(x, values, "ro")
        plt.xlabel("character lengths", fontsize=30)
        plt.ylabel(y_label, fontsize=30)
        if save_path:
            plt.savefig(os.path.join(save_path, file_name))
|
|
|
|
|
def plot_phonemes(train_path, cmu_dict_path, save_path):
    """Draw a bar chart of phoneme occurrence counts for the training text.

    Every whitespace-separated word in the fourth column of each non-blank
    row is looked up with ``CMUDict.lookup``.  Words with a falsy lookup
    result are tallied under the key ``"None"``; otherwise the symbols of the
    first returned entry are counted individually.

    Args:
        train_path: Path of the ``|``-delimited training file.
        cmu_dict_path: Path handed to ``CMUDict``.
        save_path: Optional directory.  When truthy, the chart is saved there
            as ``phoneme_dist``.

    Raises:
        IndexError: If a non-blank row has fewer than four columns.
    """
    cmudict = CMUDict(cmu_dict_path)

    # "None" is seeded first so it is always present on the chart, even when
    # every word is found.
    phonemes = {"None": 0}

    with open(train_path, "r", encoding="utf-8") as f:
        for row in csv.reader(f, delimiter="|"):
            # csv.reader yields an empty list for a blank line; skip those
            # instead of failing with IndexError on row[3].
            if not row:
                continue

            for word in row[3].split():
                pho = cmudict.lookup(word)
                if not pho:
                    phonemes["None"] += 1
                    continue

                # Only the first entry is used -- presumably lookup returns a
                # list of alternative pronunciations; confirm against CMUDict.
                for phoneme in pho[0].split():
                    phonemes[phoneme] = phonemes.get(phoneme, 0) + 1

    x = list(phonemes)
    y = list(phonemes.values())

    plt.figure()
    # NOTE(review): this is set after plt.figure(), so it presumably only
    # affects figures created later, not this one -- confirm the intent.
    plt.rcParams["figure.figsize"] = (50, 20)
    barplot = sns.barplot(x=x, y=y)
    if save_path:
        fig = barplot.get_figure()
        fig.savefig(os.path.join(save_path, "phoneme_dist"))
|
|
|
|
|
def main():
    """Command-line entry point: analyse the training file and show charts.

    Builds the length statistics, draws the per-statistic figures, optionally
    draws the phoneme distribution when ``--cmu_dict_path`` is given, and
    finally calls ``plt.show()``.
    """
    options = _parse_args()
    train_file = options.train_file_path
    charts_dir = options.save_to

    meta_data = process_meta_data(train_file)
    plt.rcParams["figure.figsize"] = (10, 5)
    plot(meta_data, save_path=charts_dir)

    if options.cmu_dict_path:
        plt.rcParams["figure.figsize"] = (30, 10)
        plot_phonemes(train_file, options.cmu_dict_path, charts_dir)

    plt.show()


def _parse_args():
    """Define the command-line options and parse ``sys.argv``."""
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--train_file_path",
        required=True,
        help="this is the path to the train.txt file that the preprocess.py script creates",
    )
    cli.add_argument("--save_to", help="path to save charts of data to")
    cli.add_argument("--cmu_dict_path", help="give cmudict-0.7b to see phoneme distribution")
    return cli.parse_args()
|
|
|
|
|
# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|