File size: 3,039 Bytes
2c480a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import torch
import gradio as gr
from torchvision import transforms
from PIL import Image
import numpy as np
from model import model
import tempfile

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
# NOTE(review): `model` (imported from the project module) is not moved to
# `device` anywhere in this file — confirm model.py places it on the same device.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Encoder preprocessing: resize the PIL image to 32x32, then convert it to a
# float tensor with torchvision's ToTensor.
transform = transforms.Compose(
    [
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
    ]
)

# Upscales each decoded frame to 512x512 so the GIF is viewable.
resize_output = transforms.Resize((512, 512))

def interpolate_vectors(v1, v2, num_steps):
    """Linearly blend from ``v1`` to ``v2`` using evenly spaced weights.

    The weights are ``np.linspace(0, 1, num_steps)``, so the first weight is 0
    and, when ``num_steps >= 2``, the last weight is 1.

    Args:
        v1: Start value (anything supporting ``*`` by a scalar and ``+``,
            e.g. a tensor or ndarray).
        v2: End value, compatible with ``v1``.
        num_steps: Number of blended values to produce.

    Returns:
        A list of ``num_steps`` values ``v1 * (1 - w) + v2 * w``.
    """
    blended = []
    for weight in np.linspace(0, 1, num_steps):
        # Keep the vector operand on the left so its own __mul__ handles the
        # numpy scalar weight.
        blended.append(v1 * (1 - weight) + v2 * weight)
    return blended

def to_pil(img_tensor):
    """Convert a single decoded image tensor into a ``PIL.Image``.

    Args:
        img_tensor: Channel-first image tensor, either ``(1, C, H, W)`` (a batch
            of one; the leading dim is squeezed) or ``(C, H, W)``. Values are
            scaled by 255 and clipped, so they are expected in ``[0, 1]``.

    Returns:
        An RGB/RGBA image for multi-channel input, or a mode "L" image when
        the tensor has a single channel.
    """
    # detach() lets this work even outside torch.no_grad(); numpy() refuses
    # tensors that require grad.
    img = img_tensor.detach().squeeze(0).permute(1, 2, 0).cpu().numpy()
    img = (img * 255).clip(0, 255).astype(np.uint8)
    if img.shape[-1] == 1:
        # PIL cannot build an image from an (H, W, 1) uint8 array; drop the
        # singleton channel so single-channel outputs become grayscale.
        img = img[..., 0]
    return Image.fromarray(img)

def interpolate_images_gif(img1, img2, num_interpolations=10, duration=100):
    """Interpolate between two images in latent space and save the result as a GIF.

    Both images are converted to RGB, preprocessed with ``transform`` and
    encoded with ``model.encode``. The first value returned for each image is
    linearly interpolated with ``interpolate_vectors``. Each interpolated latent
    is decoded with ``model.decode``, converted with ``to_pil`` and upscaled with
    ``resize_output``. The frames are written to an animated, endlessly looping
    GIF in a temporary file that is intentionally not deleted, so the UI can
    serve it.

    Args:
        img1: First image as a numpy array (as delivered by
            ``gr.Image(type="numpy")``).
        img2: Second image, same format.
        num_interpolations: Number of frames, including both endpoints.
            Coerced to ``int`` because UI sliders may deliver floats.
        duration: Display time of each frame in milliseconds (forwarded to
            PIL's GIF writer).

    Returns:
        Filesystem path of the generated ``.gif`` file.

    Raises:
        gr.Error: If either image is missing or ``num_interpolations`` < 1.
    """
    if img1 is None or img2 is None:
        raise gr.Error("Please provide both images before interpolating.")
    num_steps = int(num_interpolations)
    if num_steps < 1:
        raise gr.Error("Number of interpolations must be at least 1.")

    # convert("RGB") normalizes grayscale / RGBA uploads to three channels.
    img1_tensor = transform(Image.fromarray(img1).convert("RGB")).unsqueeze(0).to(device)
    img2_tensor = transform(Image.fromarray(img2).convert("RGB")).unsqueeze(0).to(device)

    with torch.no_grad():
        # encode() returns a pair; only the first element (presumably the
        # latent mean) is interpolated.
        mu1, _ = model.encode(img1_tensor)
        mu2, _ = model.encode(img2_tensor)
        decoded_images = [
            resize_output(to_pil(model.decode(z)))
            for z in interpolate_vectors(mu1, mu2, num_steps)
        ]

    # delete=False keeps the file on disk after the handle is closed so it can
    # be served. Writing through the handle inside `with` guarantees the
    # descriptor is released instead of leaking one per request.
    with tempfile.NamedTemporaryFile(suffix=".gif", delete=False) as tmp_file:
        decoded_images[0].save(
            tmp_file,
            format="GIF",
            save_all=True,
            append_images=decoded_images[1:],
            duration=duration,
            loop=0,
        )
    return tmp_file.name

def get_interface():
    """Build the Gradio Blocks UI for latent-space interpolation.

    Layout, top to bottom:
    - a heading;
    - two side-by-side numpy image inputs;
    - sliders for frame count and per-frame duration;
    - the output image that shows the GIF;
    - an "Interpolate" button wired to ``interpolate_images_gif``;
    - a set of uncached examples from ``example_images/``.

    Returns:
        The constructed ``gr.Blocks`` object; it is not launched here.
    """
    # Each row: first image path, second image path, frame count, ms per frame.
    example_rows = [
        ["example_images/image1.jpg", "example_images/image2.jpg", 10, 100],
        ["example_images/image3.jpg", "example_images/image4.jpg", 15, 150],
        ["example_images/image5.jpg", "example_images/image6.jpg", 20, 200],
        ["example_images/image7.jpg", "example_images/image8.jpg", 25, 250],
    ]

    with gr.Blocks() as demo:
        gr.Markdown("## Latent Space Interpolation (GIF Output)")

        with gr.Row():
            first_image = gr.Image(label="First Image", type="numpy")
            second_image = gr.Image(label="Second Image", type="numpy")

        steps_slider = gr.Slider(5, 30, value=10, step=1, label="Number of Interpolations")
        duration_slider = gr.Slider(50, 500, value=100, step=10, label="Duration per Frame (ms)")
        gif_view = gr.Image(label="Interpolation GIF")
        interpolate_btn = gr.Button("Interpolate")

        # Ordered to match interpolate_images_gif's positional parameters.
        controls = [first_image, second_image, steps_slider, duration_slider]

        interpolate_btn.click(fn=interpolate_images_gif, inputs=controls, outputs=gif_view)

        gr.Examples(
            examples=example_rows,
            inputs=controls,
            outputs=gif_view,
            fn=interpolate_images_gif,
            cache_examples=False,
        )

    return demo