|
|
|
|
|
import torch |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
from matplotlib.animation import FuncAnimation |
|
from gmm import GaussianMixtureModel |
|
|
|
def initialize_gmm(mu_list, Sigma_list, pi_list):
    """Build a ``GaussianMixtureModel`` from plain (nested) Python sequences.

    Each argument is converted to a ``torch.float32`` tensor, and the three
    tensors are handed positionally to ``GaussianMixtureModel`` in the order
    ``(mu, Sigma, pi)``.

    Args:
        mu_list: Values for the first constructor argument (named ``mu`` here;
            presumably the component means -- confirm against ``gmm``).
        Sigma_list: Values for the second constructor argument (presumably the
            component covariances -- confirm against ``gmm``).
        pi_list: Values for the third constructor argument (presumably the
            mixture weights -- confirm against ``gmm``).

    Returns:
        The object returned by ``GaussianMixtureModel(mu, Sigma, pi)``.
    """
    # Conversion happens in argument order, before the constructor is called.
    float32_params = (
        torch.tensor(values, dtype=torch.float32)
        for values in (mu_list, Sigma_list, pi_list)
    )
    return GaussianMixtureModel(*float32_params)
|
|
|
def generate_grid(dx):
    """Sample the points of a square line grid covering [-10, 10] x [-10, 10].

    Grid lines sit at every multiple of 0.5 from -10 to 10 (41 vertical and
    41 horizontal lines). Each line is sampled with step ``dx`` along its
    length, starting at -10.

    Args:
        dx: Positive spacing between consecutive sample points on a line.

    Returns:
        A ``torch.float32`` tensor of shape ``(P, 2)`` holding ``(x, y)``
        pairs: all points of the vertical lines first (left to right), then
        all points of the horizontal lines (bottom to top).

    Raises:
        ValueError: If ``dx`` is not strictly positive.
    """
    if dx <= 0:
        raise ValueError(f"dx must be positive, got {dx!r}")

    line_positions = np.arange(-10, 10.5, 0.5)

    # Coordinates along a line are computed once and the constant column is
    # sized from them. Deriving the count separately as int(20 / dx + 1)
    # disagrees with np.arange whenever dx does not divide 20 evenly (arange
    # rounds the count up, int() rounds it down), which made np.stack fail.
    # When dx does not divide 20, np.arange's half-open range means the last
    # sample can lie past 10 by less than dx.
    along_line = np.arange(-10, 10 + dx, dx)

    vertical_lines = [
        np.stack([np.full(along_line.shape, x), along_line], axis=1)
        for x in line_positions
    ]
    horizontal_lines = [
        np.stack([along_line, np.full(along_line.shape, y)], axis=1)
        for y in line_positions
    ]

    grid_points = np.concatenate(vertical_lines + horizontal_lines, axis=0)
    return torch.tensor(grid_points, dtype=torch.float32)
|
|
|
def generate_contours(dtheta):
    """Sample three origin-centred circles of radius 1, 2 and 3.

    The angle range [0, 2*pi] is sampled with ``int(2 * pi / dtheta)`` evenly
    spaced values, both endpoints included. The first and last point of each
    circle therefore coincide geometrically, and the actual angular spacing
    is ``2 * pi / (n - 1)`` rather than exactly ``dtheta``.

    Args:
        dtheta: Approximate angular step in radians; it determines the number
            of samples per circle.

    Returns:
        A ``torch.float32`` tensor of shape ``(3 * n, 2)`` with the
        ``(x, y)`` points of the radius-1 circle first, then radius 2, then
        radius 3.
    """
    sample_count = int(2 * np.pi / dtheta)
    theta = np.linspace(0, 2 * np.pi, sample_count)
    unit_x = np.cos(theta)
    unit_y = np.sin(theta)

    rings = []
    for radius in (1, 2, 3):
        rings.append(np.column_stack((radius * unit_x, radius * unit_y)))

    return torch.tensor(np.vstack(rings), dtype=torch.float32)
|
|
|
def generate_intermediate_points(gmm, grid_points, std_normal_contours, gmm_samples, T, N):
    """Run both flow directions of ``gmm`` on the point sets to visualise.

    ``gmm.flow_gmm_to_normal`` and ``gmm.flow_normal_to_gmm`` are each called
    three times with ``(points, T, N)``; the six results are returned
    unchanged.

    NOTE(review): the first GMM->normal result starts from ``grid_points``,
    not ``gmm_samples``, so it repeats the same call as the third result.
    ``gmm_samples`` is only used as the start of the first normal->GMM flow.
    This mirrors the existing behaviour -- confirm against the callers
    whether ``gmm_samples`` was intended for the first call.

    Args:
        gmm: Object providing ``flow_gmm_to_normal`` and ``flow_normal_to_gmm``.
        grid_points: Grid point set passed to both flows.
        std_normal_contours: Contour point set passed to both flows.
        gmm_samples: Sample point set passed to ``flow_normal_to_gmm``.
        T: Forwarded to the flow methods (presumably the flow time -- confirm).
        N: Forwarded to the flow methods (presumably the step count -- confirm).

    Returns:
        A 6-tuple: the GMM->normal results for
        ``(grid_points, std_normal_contours, grid_points)`` followed by the
        normal->GMM results for
        ``(gmm_samples, std_normal_contours, grid_points)``.
    """
    # Call order is kept exactly as before in case the flows are stateful.
    toward_normal = [
        gmm.flow_gmm_to_normal(start, T, N)
        for start in (grid_points, std_normal_contours, grid_points)
    ]
    toward_gmm = [
        gmm.flow_normal_to_gmm(start, T, N)
        for start in (gmm_samples, std_normal_contours, grid_points)
    ]
    return (*toward_normal, *toward_gmm)
|
|
|
def plot_samples_and_contours(samples, contours, grid_points, title):
    """Create a static scatter plot of grid points, contours and samples.

    Args:
        samples: Indexable as ``[:, 0]`` / ``[:, 1]``; drawn in red.
        contours: Indexable as ``[:, 0]`` / ``[:, 1]``; drawn in blue.
        grid_points: Indexable as ``[:, 0]`` / ``[:, 1]``; drawn in black.
        title: Axes title.

    Returns:
        ``(fig, ax)`` for an 8x6 figure with equal aspect ratio, both axes
        limited to [-5, 5], a grid and an upper-right legend. The figure has
        already been passed to ``plt.close``.
    """
    fig, ax = plt.subplots(figsize=(8, 6))

    # Layers are added in this order: grid, contours, samples.
    layers = (
        (grid_points, {"alpha": 0.5, "c": "black", "s": 1, "label": "Grid Points"}),
        (contours, {"alpha": 0.5, "s": 3, "c": "blue", "label": "Contours"}),
        (samples, {"alpha": 0.5, "c": "red", "label": "Samples"}),
    )
    for points, style in layers:
        ax.scatter(points[:, 0], points[:, 1], **style)

    ax.set_title(title)
    ax.set_xlabel("x1")
    ax.set_ylabel("x2")
    ax.grid(True)
    ax.legend(loc="upper right")

    for set_limits in (ax.set_xlim, ax.set_ylim):
        set_limits(-5, 5)
    ax.set_aspect("equal", adjustable="box")

    # The figure is closed in pyplot before being returned -- presumably so
    # an interactive backend does not display it on its own.
    plt.close(fig)
    return fig, ax
|
|
|
def create_animation(fig, ax, frames, intermediate_points, intermediate_samples, intermediate_contours, intermediate_grid):
    """Animate grid, sample and contour scatters over precomputed frames.

    Three empty scatter collections are added to ``ax``. On every frame each
    one has its offsets replaced by ``sequence[frame].numpy()`` from the
    matching sequence.

    NOTE(review): the black "Grid Points" scatter is driven by
    ``intermediate_points``, and ``intermediate_grid`` is accepted but never
    read. This mirrors the existing behaviour -- confirm which of the two
    the callers expect to be drawn.

    Args:
        fig: Figure passed to ``FuncAnimation``.
        ax: Axes that receives the scatter collections.
        frames: Passed through as ``FuncAnimation``'s ``frames`` argument.
        intermediate_points: Per-frame data for the black scatter; each entry
            must provide ``.numpy()``.
        intermediate_samples: Per-frame data for the red scatter.
        intermediate_contours: Per-frame data for the blue scatter.
        intermediate_grid: Unused.

    Returns:
        The ``FuncAnimation`` instance, created with ``blit=True``.
    """
    grid_artist = ax.scatter([], [], c='black', alpha=0.5, s=1, label='Grid Points')
    contour_artist = ax.scatter([], [], c='blue', alpha=0.5, s=3, label='Contours')
    sample_artist = ax.scatter([], [], c='red', alpha=0.5, label='Samples')

    # (artist, per-frame data) pairs in update order: grid, samples, contours.
    tracks = (
        (grid_artist, intermediate_points),
        (sample_artist, intermediate_samples),
        (contour_artist, intermediate_contours),
    )

    def update(frame):
        """Move every scatter to its positions for ``frame``."""
        for artist, trajectory in tracks:
            artist.set_offsets(trajectory[frame].numpy())
        # With blit=True the callback returns the artists it changed.
        return grid_artist, sample_artist, contour_artist

    return FuncAnimation(fig, update, frames=frames, blit=True)