File size: 3,039 Bytes
2c480a0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 |
import torch
import gradio as gr
from torchvision import transforms
from PIL import Image
import numpy as np
from model import model
import tempfile
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = (
    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)

# Input preprocessing: resize every image to 32x32, then convert it to a tensor.
transform = transforms.Compose(
    [
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
    ]
)

# Upscales each decoded frame to 512x512 before it is written to the GIF.
resize_output = transforms.Resize((512, 512))
def interpolate_vectors(v1, v2, num_steps):
    """Return ``num_steps`` evenly spaced linear blends going from ``v1`` to ``v2``.

    Args:
        v1: Start value. Anything that supports multiplication by a float and
            addition with ``v2`` (number, NumPy array, tensor, ...).
        v2: End value, compatible with ``v1``.
        num_steps: Number of blends to produce, endpoints included. Coerced
            with ``int()`` (fractional parts are truncated) because UI
            sliders may hand over values such as ``10.0`` and
            ``np.linspace`` rejects a non-integer count.

    Returns:
        A list of ``num_steps`` items ``v1 * (1 - w) + v2 * w`` where ``w``
        runs linearly from 0 to 1. The list is empty when ``num_steps`` is 0
        and holds only the ``w = 0`` blend when ``num_steps`` is 1.
    """
    step_count = int(num_steps)
    # .tolist() yields plain Python floats, so the arithmetic below never
    # mixes NumPy scalar types with whatever type v1/v2 happen to be.
    weights = np.linspace(0, 1, step_count).tolist()
    return [v1 * (1 - w) + v2 * w for w in weights]
def to_pil(img_tensor):
    """Convert an image tensor into a ``PIL.Image``.

    The tensor has dimension 0 squeezed away (when it has size 1), its first
    remaining axis moved last, and is copied to the CPU as a NumPy array.
    Values are then multiplied by 255, clipped to ``[0, 255]`` and cast to
    ``uint8`` (the cast truncates rather than rounds).

    Args:
        img_tensor: Tensor that is 3-D once dimension 0 is squeezed.
            NOTE(review): presumably a (1, C, H, W) decoder output with
            values nominally in [0, 1] -- confirm against ``model.decode``.

    Returns:
        The image built by ``Image.fromarray`` from the uint8 array.
    """
    channels_first = img_tensor.squeeze(0)
    channels_last = channels_first.permute(1, 2, 0)
    pixels = channels_last.cpu().numpy()
    pixels = np.clip(pixels * 255, 0, 255).astype(np.uint8)
    return Image.fromarray(pixels)
def interpolate_images_gif(img1, img2, num_interpolations=10, duration=100):
    """Blend two images in the model's latent space and save the frames as a GIF.

    Both images are converted to RGB, preprocessed with ``transform``,
    encoded with ``model.encode``, linearly interpolated between the two
    codes, decoded with ``model.decode``, resized with ``resize_output`` and
    written out as a looping animated GIF.

    Args:
        img1: First image as an array accepted by ``Image.fromarray``.
        img2: Second image, same format as ``img1``.
        num_interpolations: Number of frames, endpoints included. Coerced
            with ``int()`` because UI sliders may deliver floats.
        duration: Display time of each frame in milliseconds.

    Returns:
        Path of the GIF file. It is created with ``delete=False``, so it is
        not removed automatically.

    Raises:
        gr.Error: If either image is missing (``None``).
        ValueError: If ``num_interpolations`` is smaller than 1.
    """
    if img1 is None or img2 is None:
        # Surface a readable message in the UI instead of the obscure
        # AttributeError that Image.fromarray(None) would raise.
        raise gr.Error("Please provide both images before interpolating.")

    frame_count = int(num_interpolations)
    if frame_count < 1:
        raise ValueError("num_interpolations must be at least 1")

    first_rgb = Image.fromarray(img1).convert('RGB')
    second_rgb = Image.fromarray(img2).convert('RGB')
    img1_tensor = transform(first_rgb).unsqueeze(0).to(device)
    img2_tensor = transform(second_rgb).unsqueeze(0).to(device)

    # NOTE(review): the inputs are moved to `device`, but nothing in this
    # function moves `model` there or switches it to eval mode -- confirm
    # that the `model` module takes care of both.
    with torch.no_grad():
        # Only the first value returned by encode() is used (presumably the
        # latent mean); the second one is ignored.
        mu1, _ = model.encode(img1_tensor)
        mu2, _ = model.encode(img2_tensor)
        # Decoding stays inside no_grad so the outputs carry no autograd
        # history and to_pil() can convert them to NumPy.
        decoded_images = [
            resize_output(to_pil(model.decode(z)))
            for z in interpolate_vectors(mu1, mu2, frame_count)
        ]

    # Reserve a unique path, then close the handle before Pillow writes to
    # it: an open handle leaks a file descriptor on every call and prevents
    # the file from being reopened by name on Windows.
    with tempfile.NamedTemporaryFile(suffix=".gif", delete=False) as tmp_file:
        gif_path = tmp_file.name

    first_frame, *remaining_frames = decoded_images
    first_frame.save(
        gif_path,
        save_all=True,
        append_images=remaining_frames,
        duration=duration,
        loop=0,  # 0 = loop forever
    )
    return gif_path
def get_interface():
    """Build the Gradio Blocks UI for latent-space interpolation.

    The page holds two image inputs, a slider for the number of
    interpolations, a slider for the per-frame duration, an output image for
    the GIF, a button wired to ``interpolate_images_gif`` and a set of
    clickable examples (not cached).

    Returns:
        The ``gr.Blocks`` instance; launching it is left to the caller.
    """
    # Each row: first image path, second image path, number of steps,
    # frame duration in milliseconds.
    example_rows = [
        ["example_images/image1.jpg", "example_images/image2.jpg", 10, 100],
        ["example_images/image3.jpg", "example_images/image4.jpg", 15, 150],
        ["example_images/image5.jpg", "example_images/image6.jpg", 20, 200],
        ["example_images/image7.jpg", "example_images/image8.jpg", 25, 250],
    ]

    with gr.Blocks() as iface:
        gr.Markdown("## Latent Space Interpolation (GIF Output)")

        # NOTE(review): the original indentation was lost; placing only the
        # two image inputs inside the row is the assumed layout -- confirm.
        with gr.Row():
            first_image = gr.Image(label="First Image", type="numpy")
            second_image = gr.Image(label="Second Image", type="numpy")

        steps_slider = gr.Slider(
            minimum=5,
            maximum=30,
            value=10,
            step=1,
            label="Number of Interpolations",
        )
        duration_slider = gr.Slider(
            minimum=50,
            maximum=500,
            value=100,
            step=10,
            label="Duration per Frame (ms)",
        )
        gif_view = gr.Image(label="Interpolation GIF")

        trigger = gr.Button("Interpolate")
        trigger.click(
            fn=interpolate_images_gif,
            inputs=[first_image, second_image, steps_slider, duration_slider],
            outputs=gif_view,
        )

        gr.Examples(
            examples=example_rows,
            inputs=[first_image, second_image, steps_slider, duration_slider],
            outputs=gif_view,
            fn=interpolate_images_gif,
            cache_examples=False,
        )

    return iface
|