File size: 2,625 Bytes
0c9bb32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
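"""
Generate a 4x4 grid of sample images from a trained DiT checkpoint by
integrating the learned flow ODE starting from Gaussian noise, and save the
grid to samples.png.
"""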

import argparse
import pickle

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import yaml
from flax import nnx
from jax.experimental import ode

from model import DiT, DiTConfig


def parse_args():
    """Parse command-line arguments for the sampling script.

    ``--ckpt`` is required: the script cannot do anything without a
    checkpoint, and leaving it optional previously surfaced as a confusing
    ``TypeError`` from ``open(None)`` instead of a usage error.

    Returns:
        argparse.Namespace with ``config`` (str), ``ckpt`` (str) and
        ``seed`` (int).
    """
    parser = argparse.ArgumentParser(
        description="Generate images from a trained model checkpoint."
    )
    parser.add_argument(
        "--config", type=str, default="config.yaml", help="Path to config file"
    )
    parser.add_argument(
        "--ckpt", type=str, required=True, help="Path to checkpoint file"
    )
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    return parser.parse_args()


def load_config(config_path):
    """Load a YAML configuration file.

    The file is read explicitly as UTF-8 so that parsing does not depend on
    the platform's locale encoding (e.g. cp1252 on Windows).

    Args:
        config_path: Path to the YAML file.

    Returns:
        Whatever ``yaml.safe_load`` produces for the document; ``main``
        expects a mapping with a ``"model"`` section.
    """
    with open(config_path, encoding="utf-8") as f:
        config = yaml.safe_load(f)
    return config


@jax.jit
def sample_images(graphdef, state, rng):
    """Draw a batch of 16 images by integrating the learned flow ODE.

    The model rebuilt from ``graphdef``/``state`` is used as the vector field
    of an ODE.
    Standard-normal noise of shape (16, 64, 64, 3) is integrated from t=0 to
    t=1 with ``jax.experimental.ode.odeint``.
    The final state is clipped to [0, 1].

    Args:
        graphdef: nnx graph definition of the model (from ``nnx.split``).
        state: nnx state holding the parameters that match ``graphdef``.
        rng: JAX PRNG key used to draw the initial noise.

    Returns:
        Array of shape (16, 64, 64, 3), float32, values clipped to [0, 1].
    """
    flow = nnx.merge(graphdef, state)

    def flow_fn(y, t):
        # odeint passes a scalar t; the model receives it as a length-1
        # vector (presumably broadcast over the batch inside DiT -- confirm).
        return flow(y, t[None])

    x = jax.random.normal(rng, shape=(16, 64, 64, 3), dtype=jnp.float32)
    # Only the endpoint is used. odeint's adaptive step sequence does not
    # depend on how many output times are requested (outputs are obtained by
    # interpolating each accepted step). Requesting just [0, 1] therefore
    # yields the same final state without materialising a full
    # (1000, 16, 64, 64, 3) trajectory (~786 MB of float32).
    ts = jnp.array([0.0, 1.0], dtype=jnp.float32)
    trajectory = ode.odeint(flow_fn, x, ts)
    return jnp.clip(trajectory[-1], 0, 1)


def plot_new_images(graphdef, state, seed, output_path="samples.png"):
    """Sample 16 images and save them as a borderless 4x4 grid.

    Args:
        graphdef: nnx graph definition of the model.
        state: nnx state holding the trained parameters.
        seed: Integer seed; it seeds ``nnx.Rngs`` to produce the noise key.
        output_path: Destination of the saved figure. Defaults to
            ``samples.png`` in the current working directory, as before.
    """
    images = sample_images(graphdef, state, nnx.Rngs(seed)())
    # Pull the whole batch to host memory once instead of once per imshow.
    images = jax.device_get(images)

    fig, axes = plt.subplots(4, 4, figsize=(2, 2))
    try:
        for ax, image in zip(axes.flat, images):
            ax.imshow(image)
            ax.axis("off")
        # Remove all margins and gaps so the tiles touch each other.
        fig.subplots_adjust(
            left=0, bottom=0, top=1, right=1, wspace=0, hspace=0
        )
        fig.savefig(output_path)
    finally:
        # Release the figure even if drawing or saving raises.
        plt.close(fig)


def main():
    """Build the model skeleton, load checkpoint weights, and save samples.

    Raises:
        SystemExit: if no checkpoint path was provided.
    """
    args = parse_args()
    if args.ckpt is None:
        # Fail with a clear message instead of a TypeError from open(None).
        raise SystemExit("error: --ckpt is required")

    config = load_config(args.config)
    model_cfg = config["model"]

    dit_config = DiTConfig(
        input_dim=model_cfg["input_dim"],
        hidden_dim=model_cfg["hidden_dim"],
        num_blocks=model_cfg["num_blocks"],
        num_heads=model_cfg["num_heads"],
        patch_size=model_cfg["patch_size"],
        patch_stride=model_cfg["patch_stride"],
        time_freq_dim=model_cfg["time_freq_dim"],
        time_max_period=model_cfg["time_max_period"],
        mlp_ratio=model_cfg["mlp_ratio"],
        use_bias=model_cfg["use_bias"],
        padding=model_cfg["padding"],
        pos_embed_cls_token=model_cfg["pos_embed_cls_token"],
        pos_embed_extra_tokens=model_cfg["pos_embed_extra_tokens"],
    )

    # eval_shape builds the model abstractly (no real parameter allocation);
    # only its graph definition is kept, the weights come from the checkpoint.
    abstract_flow = nnx.eval_shape(lambda: DiT(dit_config, rngs=nnx.Rngs(0)))
    graphdef, _ = nnx.split(abstract_flow)

    # NOTE(security): pickle.load can execute arbitrary code; only load
    # checkpoints from trusted sources.
    with open(args.ckpt, "rb") as f:
        state = pickle.load(f)
    # Heuristic: if the loaded object is not the model state itself (no
    # "time_embedding" entry), assume it is a container whose first element
    # is the model state (presumably saved together with other training
    # state -- confirm against the training script).
    if "time_embedding" not in state:
        state = state[0]

    plot_new_images(graphdef, state, args.seed)


# Run the sampler only when executed as a script, not when imported.
if __name__ == "__main__":
    main()