LPX55's picture
Update app.py
a280553 verified
import gradio as gr
import numpy as np
from io import BytesIO
from PIL import Image, ImageOps
import zipfile
import os
import atexit
import shutil
import cv2
import imageio
import torchvision.transforms.functional as TF
# Directory that holds every generated GIF / zip file; cleanup_generated_files
# removes it again when the interpreter exits.
GENERATED_FILES_DIR = "generated_files"
# exist_ok=True replaces the racy "check exists, then create" pattern: creating
# an already-existing directory is simply a no-op.
os.makedirs(GENERATED_FILES_DIR, exist_ok=True)
@atexit.register  # atexit.register returns the function, so it works as a decorator.
def cleanup_generated_files():
    """Remove GENERATED_FILES_DIR and all of its contents.

    Registered to run at interpreter exit so that generated GIF and zip files
    do not outlive the process. Does nothing if the path does not exist.
    """
    if not os.path.exists(GENERATED_FILES_DIR):
        return
    shutil.rmtree(GENERATED_FILES_DIR)
def split_image_grid(image, grid_cols, grid_rows):
    """Split a grid (sprite-sheet) image into its individual cells.

    Cells are returned row by row: left to right within a row, rows from top
    to bottom. If the image size is not an exact multiple of the grid size,
    the leftover pixels along the right and bottom edges are discarded.

    Args:
        image: A PIL image or a NumPy array (H x W or H x W x C).
        grid_cols: Number of grid columns. Int-like values such as the floats
            a UI slider may deliver (e.g. 2.0) are accepted and truncated.
        grid_rows: Number of grid rows (same rules as ``grid_cols``).

    Returns:
        A list of ``grid_cols * grid_rows`` NumPy arrays, one per cell. Each
        array is an independent, C-contiguous copy.

    Raises:
        ValueError: If ``image`` is None, a grid count is below 1, or the grid
            is finer than the image (a cell would be 0 pixels wide or tall).
    """
    if image is None:
        raise ValueError("No input image provided; please upload a grid image.")

    grid_cols = int(grid_cols)
    grid_rows = int(grid_rows)
    if grid_cols < 1 or grid_rows < 1:
        raise ValueError("Grid columns and rows must both be at least 1.")

    # np.asarray handles both PIL images and arrays; slicing an in-bounds box
    # yields the same pixels as PIL's crop().
    pixels = np.asarray(image)
    height, width = pixels.shape[:2]
    cell_width = width // grid_cols
    cell_height = height // grid_rows
    if cell_width == 0 or cell_height == 0:
        raise ValueError(
            f"A {grid_cols}x{grid_rows} grid is too fine for a {width}x{height} image."
        )

    # .copy() gives every frame its own contiguous buffer, so later in-place
    # edits of one frame never affect the source image or other frames.
    return [
        pixels[
            row * cell_height:(row + 1) * cell_height,
            col * cell_width:(col + 1) * cell_width,
        ].copy()
        for row in range(grid_rows)
        for col in range(grid_cols)
    ]
def interpolate_frames(frames, interpolation_factor=2):
    """Insert linearly blended in-between frames between consecutive frames.

    For each consecutive pair (a, b), the output contains ``a`` followed by
    ``interpolation_factor - 1`` cross-fades a*(1-t) + b*t with
    t = 1/f, 2/f, ..., (f-1)/f. The last input frame is appended once at the
    end. A factor of 1 (or less) inserts nothing and returns the frames as-is.

    Args:
        frames: Sequence of same-sized image arrays, as accepted by
            ``cv2.addWeighted``.
        interpolation_factor: Int-like number of output steps per input pair.
            Float values (e.g. 2.0 from a UI slider) are truncated to int.

    Returns:
        A new list of frames. It is empty if ``frames`` is empty.
    """
    frames = list(frames)
    if not frames:
        return []

    factor = max(1, int(interpolation_factor))
    interpolated_frames = []
    for current, following in zip(frames, frames[1:]):
        interpolated_frames.append(current)
        for step in range(1, factor):
            t = step / factor
            # dst = current * (1 - t) + following * t + 0
            interpolated_frames.append(cv2.addWeighted(current, 1 - t, following, t, 0))
    interpolated_frames.append(frames[-1])
    return interpolated_frames
def enhance_gif(images):
    """Normalize each frame to 8-bit RGB and stretch its contrast.

    Args:
        images: Iterable of image arrays accepted by ``PIL.Image.fromarray``.

    Returns:
        A list of NumPy arrays holding the auto-contrasted RGB frames.
    """
    enhanced_images = []
    for img in images:
        # Convert BEFORE autocontrast: ImageOps.autocontrast only supports "L"
        # and "RGB" images, so RGBA / palette frames would otherwise raise.
        # For L and RGB inputs the result matches autocontrast-then-convert.
        rgb_frame = Image.fromarray(img).convert("RGB")
        enhanced_images.append(np.array(ImageOps.autocontrast(rgb_frame)))
    return enhanced_images
def create_gif_imageio(images, fps=10, ping_pong_animation=False, loop=0):
    """Encode frames as an animated GIF in GENERATED_FILES_DIR.

    Args:
        images: List of frame arrays accepted by ``PIL.Image.fromarray``.
        fps: Playback rate. It is turned into a per-frame duration of
            ``1000 / fps``. NOTE(review): imageio's pillow-based GIF writer reads
            ``duration`` in milliseconds, but legacy imageio versions used
            seconds, so confirm the installed version.
        ping_pong_animation: If true, the interior frames are appended in
            reverse order, so the animation plays forward and then backward
            (a b c d -> a b c d c b).
        loop: Loop count passed through to ``imageio.mimsave``.

    Returns:
        Path to "output_enhanced.gif" inside GENERATED_FILES_DIR. The file is
        overwritten on every call.
    """
    frame_duration = 1000 / fps  # Convert FPS to milliseconds
    sequence = images
    if ping_pong_animation:
        # Reverse pass without repeating the last or the first frame.
        sequence = images + images[-2:0:-1]
    output_path = os.path.join(GENERATED_FILES_DIR, "output_enhanced.gif")
    pil_frames = list(map(Image.fromarray, sequence))
    imageio.mimsave(output_path, pil_frames, duration=frame_duration, loop=loop)
    return output_path
def process_image(image, grid_cols_input, grid_rows_input):
    """Split an uploaded grid image into cells and bundle them in a zip file.

    Args:
        image: The uploaded grid image.
        grid_cols_input: Number of grid columns.
        grid_rows_input: Number of grid rows.

    Returns:
        The path of the zip archive written by ``zip_images``.
    """
    return zip_images(split_image_grid(image, grid_cols_input, grid_rows_input))
def process_image_to_gif(image, grid_cols_input, grid_rows_input, fps_input, ping_pong_toggle, interpolation_factor):
    """Convert an uploaded grid image into an animated, contrast-enhanced GIF.

    Pipeline: split the grid into frames, add interpolated in-between frames,
    auto-contrast every frame, then encode the GIF.

    Args:
        image: The uploaded grid image.
        grid_cols_input: Number of grid columns.
        grid_rows_input: Number of grid rows.
        fps_input: GIF playback rate in frames per second.
        ping_pong_toggle: Whether the animation plays forward and then backward.
        interpolation_factor: Output steps per pair of consecutive frames.

    Returns:
        A tuple ``(preview_path, gif_path)``. Both entries are the path of the
        generated GIF: one feeds the preview widget, the other the download widget.
    """
    frames = split_image_grid(image, grid_cols_input, grid_rows_input)
    interpolated_frames = interpolate_frames(frames, interpolation_factor)
    enhanced_frames = enhance_gif(interpolated_frames)
    gif_file = create_gif_imageio(
        enhanced_frames, fps=fps_input, ping_pong_animation=ping_pong_toggle, loop=0
    )
    # The preview shows the same file that was just written. Reuse the returned
    # path instead of rebuilding it from a hard-coded filename.
    return gif_file, gif_file
def zip_images(images):
    """Save each frame as a PNG entry in GENERATED_FILES_DIR/output.zip.

    Entries are named image_0.png, image_1.png, ... in input order. The
    archive is overwritten on every call.

    Args:
        images: Iterable of image arrays accepted by ``PIL.Image.fromarray``.

    Returns:
        The path of the zip archive.
    """
    archive_path = os.path.join(GENERATED_FILES_DIR, "output.zip")
    with zipfile.ZipFile(archive_path, 'w') as archive:
        for index, frame in enumerate(images):
            png_bytes = BytesIO()
            Image.fromarray(frame).save(png_bytes, format='PNG')
            # getvalue() returns the whole buffer regardless of the stream position.
            archive.writestr(f'image_{index}.png', png_bytes.getvalue())
    return archive_path
# Gradio interface: upload a grid image, pick the grid size and animation
# settings, then export either an animated GIF or a zip of the individual cells.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(label="Input Grid Image", type="pil")
        with gr.Column():
            with gr.Row():
                grid_cols_input = gr.Slider(1, 10, value=2, step=1, label="Grid Columns")
                grid_rows_input = gr.Slider(1, 10, value=2, step=1, label="Grid Rows")
            with gr.Row():
                fps_input = gr.Slider(1, 30, value=10, step=1, label="FPS (Frames per Second)")
                ping_pong_toggle = gr.Checkbox(label="Ping-Pong Effect")
                interpolation_factor_input = gr.Slider(1, 10, value=2, step=1, label="Interpolation Factor")
            with gr.Row():
                gif_button = gr.Button("Create GIF")
                zip_button = gr.Button("Create Zip File")
    with gr.Row():
        preview_image = gr.Image(label="Preview GIF")
    with gr.Row():
        gif_output = gr.File(label="Download GIF")
        zip_output = gr.File(label="Download Zip File")

    # Event listeners must be attached inside the Blocks context.
    zip_button.click(
        process_image,
        inputs=[image_input, grid_cols_input, grid_rows_input],
        outputs=zip_output,
    )
    gif_button.click(
        process_image_to_gif,
        inputs=[
            image_input,
            grid_cols_input,
            grid_rows_input,
            fps_input,
            ping_pong_toggle,
            interpolation_factor_input,
        ],
        outputs=[preview_image, gif_output],
    )

# Start the server only when run as a script, so importing this module
# (e.g. from tests) does not launch it as a side effect.
if __name__ == "__main__":
    demo.launch(show_error=True)