|
import matplotlib |
|
import numpy |
|
import soundfile as sf |
|
from matplotlib import pyplot as plt |
|
from matplotlib import cm |
|
matplotlib.use("tkAgg") |
|
from sklearn.manifold import TSNE |
|
from sklearn.decomposition import PCA |
|
|
|
from tqdm import tqdm |
|
|
|
from Preprocessing.ProsodicConditionExtractor import ProsodicConditionExtractor |
|
|
|
|
|
class Visualizer:
    """
    Embeds audio files with a ProsodicConditionExtractor and plots 2D
    projections (t-SNE and optionally PCA) of the embeddings, coloured by label.
    """

    def __init__(self, sr=48000, device="cpu"):
        """
        Args:
            sr: The sampling rate of the audios you want to visualize.
            device: The device handed to the ProsodicConditionExtractor. It is
                remembered so that an extractor which has to be rebuilt for a
                different sampling rate ends up on the same device.
        """
        self.tsne = TSNE(n_jobs=-1)
        self.pca = PCA(n_components=2)
        self.device = device
        self.pros_cond_ext = ProsodicConditionExtractor(sr=sr, device=device)
        self.sr = sr

    def visualize_speaker_embeddings(self, label_to_filepaths, title_of_plot, save_file_path=None, include_pca=True, legend=True):
        """
        Embed every given audio file, project the embeddings to 2D and plot them.

        Audios shorter than one second are skipped. If an audio has a sampling
        rate different from the one the current extractor was built for, the
        extractor is rebuilt for the new rate (on the same device).

        Args:
            label_to_filepaths: Mapping from a label to an iterable of paths of
                audio files belonging to that label.
            title_of_plot: Title of the plot. If a PCA plot is included too,
                " t-SNE" and " PCA" are appended to tell the plots apart.
            save_file_path: If None, the plots are shown interactively.
                Otherwise the plots are written to disk: with include_pca=False
                the t-SNE plot goes to save_file_path; with include_pca=True the
                PCA plot goes to save_file_path and the t-SNE plot goes next to
                it with "_tsne" inserted before the file extension, so that
                neither plot overwrites the other.
            include_pca: Whether to produce a PCA plot in addition to t-SNE.
            legend: Whether to draw a legend with the labels.

        Raises:
            ValueError: If no audio of at least one second was found, i.e.
                there is nothing to plot.
        """
        label_list = list()
        embedding_list = list()
        for label in tqdm(label_to_filepaths):
            for filepath in tqdm(label_to_filepaths[label]):
                wave, sr = sf.read(filepath)
                if len(wave) / sr < 1:
                    # less than one second of audio, skip it
                    continue
                if self.sr != sr:
                    print("One of the Audios you included doesn't match the sampling rate of this visualizer object, "
                          "creating a new condition extractor. Results will be correct, but if there are too many cases "
                          "of changing samplingrate, this will run very slowly.")
                    # keep the device that was requested at construction time
                    self.pros_cond_ext = ProsodicConditionExtractor(sr=sr, device=self.device)
                    self.sr = sr
                embedding = self.pros_cond_ext.extract_condition_from_reference_wave(wave)
                # The result is presumably a torch tensor (it offers .squeeze() and
                # .numpy()); detach it and move it to host memory first, since
                # .numpy() refuses tensors that track gradients or live on an accelerator.
                embedding_list.append(embedding.squeeze().detach().cpu().numpy())
                label_list.append(label)

        if not embedding_list:
            raise ValueError("No audio of at least one second was found, so there are no embeddings to visualize.")
        embeddings_as_array = numpy.array(embedding_list)

        if include_pca:
            tsne_title = title_of_plot + " t-SNE"
        else:
            tsne_title = title_of_plot
        if include_pca and save_file_path is not None:
            # the PCA plot below is written to save_file_path itself, so the
            # t-SNE plot needs its own file to not be overwritten
            tsne_save_path = self._suffixed_path(save_file_path, "_tsne")
        else:
            tsne_save_path = save_file_path

        dimensionality_reduced_embeddings_tsne = self.tsne.fit_transform(embeddings_as_array)
        self._plot_embeddings(projected_data=dimensionality_reduced_embeddings_tsne,
                              labels=label_list,
                              title=tsne_title,
                              save_file_path=tsne_save_path,
                              legend=legend)

        if include_pca:
            dimensionality_reduced_embeddings_pca = self.pca.fit_transform(embeddings_as_array)
            self._plot_embeddings(projected_data=dimensionality_reduced_embeddings_pca,
                                  labels=label_list,
                                  title=title_of_plot + " PCA",
                                  save_file_path=save_file_path,
                                  legend=legend)

    @staticmethod
    def _suffixed_path(path, suffix):
        """
        Return path as a string with suffix inserted in front of the file extension.
        """
        import os  # local import, the module itself does not import os

        root, extension = os.path.splitext(path)
        return root + suffix + extension

    @staticmethod
    def _group_indices_by_label(labels):
        """
        Map every distinct label to the list of positions at which it occurs.

        The labels are kept in order of first appearance, which makes the
        colour assignment and the legend order reproducible between runs.
        """
        label_to_indices = dict()
        for index, label in enumerate(labels):
            label_to_indices.setdefault(label, list()).append(index)
        return label_to_indices

    def _plot_embeddings(self, projected_data, labels, title, save_file_path, legend):
        """
        Draw one scatter plot of 2D points, one colour per distinct label.

        Args:
            projected_data: The 2D points, one row per entry of labels.
            labels: The label of each point, in the same order as projected_data.
            title: Title drawn above the plot.
            save_file_path: Where to save the figure. If None, it is shown instead.
            legend: Whether to draw a legend with the labels.
        """
        label_to_indices = self._group_indices_by_label(labels)
        colors = cm.gist_rainbow(numpy.linspace(0, 1, len(label_to_indices)))
        points = numpy.asarray(projected_data)

        fig, ax = plt.subplots()
        for color, (label, indices) in zip(colors, label_to_indices.items()):
            ax.scatter(x=points[indices, 0],
                       y=points[indices, 1],
                       # A single-row 2D array is unambiguously one RGBA colour. A flat
                       # RGBA sequence would be read as values to colormap whenever a
                       # label happens to have exactly four points.
                       c=color.reshape(1, -1),
                       label=label,
                       alpha=0.9)
        if legend:
            ax.legend()
        fig.tight_layout()
        ax.axis('off')
        fig.subplots_adjust(top=0.9, bottom=0.0, right=1.0, left=0.0)
        ax.set_title(title)
        if save_file_path is not None:
            fig.savefig(save_file_path)
        else:
            plt.show()
        plt.close(fig)
|
|