File size: 3,159 Bytes
7bc83e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7b1ccbd
 
 
 
7bc83e9
 
 
 
 
 
 
 
7b1ccbd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7bc83e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# utils.py

import torch
import numpy as np
import matplotlib.pyplot as plt
from gmm import GaussianMixtureModel

def initialize_gmm(mu_list, Sigma_list, pi_list):
    """Build a GaussianMixtureModel from plain (nested) Python sequences.

    Each argument is converted to a float32 tensor, in the order means,
    covariances, weights, and passed positionally to the model constructor.
    Presumably mu_list is (K, d), Sigma_list is (K, d, d) and pi_list is (K,)
    -- TODO confirm against gmm.GaussianMixtureModel.
    """
    means, covariances, weights = (
        torch.tensor(values, dtype=torch.float32)
        for values in (mu_list, Sigma_list, pi_list)
    )
    return GaussianMixtureModel(means, covariances, weights)

def _axis_points(lo, hi, step):
    """Return ``lo, lo + step, lo + 2*step, ...`` up to the first value >= ``hi``.

    Same intent as ``np.arange(lo, hi + step, step)``, but the sample count is
    computed with a small tolerance. Floating-point rounding therefore cannot
    append a spurious extra point past ``hi``; for comparison,
    ``np.arange(1.0, 1.3, 0.1)`` has four elements and ends at 1.3.

    Raises:
        ValueError: if ``step`` is not strictly positive.
    """
    if step <= 0:
        raise ValueError(f"step must be positive, got {step!r}")
    # ceil(span / step) intervals cover [lo, hi]. The 1e-9 slack absorbs
    # rounding in the division so an exact multiple is not bumped up by one.
    count = int(np.ceil((hi - lo) / step - 1e-9)) + 1
    return lo + step * np.arange(max(count, 0))


def generate_grid(dx, *, lo=-10.0, hi=10.0, line_spacing=0.5):
    """Sample points along a square lattice of grid lines.

    Grid lines are placed as follows:
      * Vertical lines sit at x = lo, lo + line_spacing, ..., hi.
      * Horizontal lines sit at the same y values.
      * Each line is sampled every ``dx`` from ``lo`` to ``hi``.
      * If ``hi - lo`` is not a multiple of ``dx``, the last sample lies just
        past ``hi``, as with ``np.arange(lo, hi + dx, dx)``.

    Points are ordered as all vertical lines (increasing x), then all
    horizontal lines (increasing y).

    Args:
        dx: spacing between consecutive samples along a line; must be > 0.
        lo: lower bound of the square in both coordinates.
        hi: upper bound of the square in both coordinates.
        line_spacing: distance between adjacent grid lines; must be > 0.

    Returns:
        float32 tensor of shape (2 * n_lines * n_samples, 2).

    Raises:
        ValueError: if ``dx`` or ``line_spacing`` is not strictly positive.
    """
    line_positions = _axis_points(lo, hi, line_spacing)
    along_line = _axis_points(lo, hi, dx)
    ones = np.ones_like(along_line)
    vertical_lines = [np.stack([x * ones, along_line], axis=1) for x in line_positions]
    horizontal_lines = [np.stack([along_line, y * ones], axis=1) for y in line_positions]
    grid_points = np.concatenate(vertical_lines + horizontal_lines, axis=0)
    return torch.tensor(grid_points, dtype=torch.float32)

def generate_contours(dtheta, *, radii=(1, 2, 3)):
    """Sample points on concentric origin-centred circles.

    With the default radii these are the r = 1, 2, 3 rings used as
    standard-normal contours. Each ring gets ``int(2*pi / dtheta)`` points at
    evenly spaced angles in [0, 2*pi), starting at angle 0. Rings are
    concatenated in the order given by ``radii``.

    Args:
        dtheta: approximate angular step in radians; must be > 0.
        radii: circle radii (non-empty); defaults to 1, 2, 3.

    Returns:
        float32 tensor of shape (len(radii) * n_angles, 2).

    Raises:
        ValueError: if ``dtheta`` is not strictly positive.
    """
    if dtheta <= 0:
        raise ValueError(f"dtheta must be positive, got {dtheta!r}")
    n_angles = int(2 * np.pi / dtheta)
    # endpoint=False: angle 2*pi is the same point as angle 0, so a closed
    # interval would duplicate the first point of every ring.
    angles = np.linspace(0, 2 * np.pi, n_angles, endpoint=False)
    unit_circle = np.stack([np.cos(angles), np.sin(angles)], axis=1)
    std_normal_contours = np.concatenate([r * unit_circle for r in radii], axis=0)
    return torch.tensor(std_normal_contours, dtype=torch.float32)

def transform_std_to_gmm_contours(std_contours, mu, Sigma):
    """Map standard-normal points onto every Gaussian component.

    For each component k, every point z becomes ``mu_k + L_k z``, where
    ``L_k`` is the lower Cholesky factor of ``Sigma[k]``. Results are
    concatenated component by component: all points for k=0, then k=1, etc.

    Args:
        std_contours: (M, d) points in standard-normal coordinates.
        mu: component means with K*d elements, e.g. (K, d). A squeezed (d,)
            mean for a single component (K == 1) is accepted too.
        Sigma: (K, d, d) symmetric positive-definite covariances.

    Returns:
        Tensor of shape (K * M, d).
    """
    # Take K from Sigma rather than mu: a caller that squeezes a
    # single-component mean down to shape (d,) would otherwise make
    # mu.shape[0] == d and iterate over the wrong axis.
    n_components = Sigma.shape[0]
    means = mu.reshape(n_components, -1)
    chol = torch.linalg.cholesky(Sigma)  # batched: (K, d, d)
    # (M, d) @ (K, d, d) broadcasts to (K, M, d); row-vector form of L_k @ z.
    mapped = means[:, None, :] + std_contours @ chol.transpose(-1, -2)
    return mapped.reshape(-1, mapped.shape[-1])

def generate_intermediate_points(gmm, grid_points, std_normal_contours, gmm_samples, normal_samples, T, N):
    """Push sample, contour and grid point sets through both flow directions.

    Forward (``gmm.flow_gmm_to_normal``) is applied to, in order:
    ``gmm_samples``, the standard contours mapped onto every GMM component,
    and ``grid_points``. Backward (``gmm.flow_normal_to_gmm``) is then applied
    to ``normal_samples``, ``std_normal_contours`` and ``grid_points``. Every
    flow call receives ``T`` and ``N`` unchanged.

    Returns:
        A 6-tuple: the three forward results followed by the three backward
        results, each being whatever the corresponding ``gmm.flow_*`` call
        returns.
    """
    component_contours = transform_std_to_gmm_contours(std_normal_contours, gmm.mu.squeeze(), gmm.Sigma)

    # Each input is cloned before flowing -- presumably because the flow
    # methods may modify their argument in place (TODO: confirm in gmm.py).
    to_normal = tuple(
        gmm.flow_gmm_to_normal(points.clone(), T, N)
        for points in (gmm_samples, component_contours, grid_points)
    )
    to_gmm = tuple(
        gmm.flow_normal_to_gmm(points.clone(), T, N)
        for points in (normal_samples, std_normal_contours, grid_points)
    )
    return to_normal + to_gmm

def plot_samples_and_contours(samples, contours, grid_points, title):
    """Draw grid points, contour points and samples on a single 8x6 figure.

    Layers are drawn in this order:
      1. grid points (black),
      2. contours (blue),
      3. samples (red).

    Each input is indexed as ``[:, 0]`` / ``[:, 1]`` for the x1 / x2
    coordinates. The view is fixed to [-5, 5] x [-5, 5] with equal aspect.
    The figure is passed to ``plt.close`` before returning, so pyplot stops
    tracking it; the caller receives ``(fig, ax)``.
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    layers = (
        (grid_points, {"alpha": 0.5, "c": 'black', "s": 1, "label": 'Grid Points'}),
        (contours, {"alpha": 0.5, "s": 3, "c": 'blue', "label": 'Contours'}),
        (samples, {"alpha": 0.5, "c": 'red', "label": 'Samples'}),
    )
    for points, scatter_style in layers:
        ax.scatter(points[:, 0], points[:, 1], **scatter_style)

    ax.set(title=title, xlabel="x1", ylabel="x2")
    ax.grid(True)
    ax.legend(loc='upper right')
    ax.set_xlim(-5, 5)
    ax.set_ylim(-5, 5)
    ax.set_aspect('equal', adjustable='box')
    plt.close(fig)
    return fig, ax