import json
import os
import types
from urllib.parse import urlparse
import cv2
import diffusers
import gradio as gr
import numpy as np
import torch
from einops import rearrange
from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf
from PIL import Image, ImageOps
from safetensors.torch import load_file
from torch.nn import functional as F
from torchdiffeq import odeint_adjoint as odeint
import spaces
from echoflow.common import instantiate_class_from_config, unscale_latents
from echoflow.common.models import (
ContrastiveModel,
DiffuserSTDiT,
ResNet18,
SegDiTTransformer2DModel,
)
torch.set_grad_enabled(False)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float32
# 4f4 latent space
B, T, C, H, W = 1, 64, 4, 28, 28
VIEWS = ["A4C", "PSAX", "PLAX"]
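# The demo chains four model stages on top of this shared 4f4 latent space:
#   1. lifm - latent image flow-matching model, conditioned on an LV mask and a view class
#   2. reid - per-view re-identification backbones used as a privacy filter
#   3. lvfm - latent video flow-matching model, conditioned on the image and an ejection fraction
#   4. vae  - decodes latent images/videos back to pixel space
# B, T, C, H, W are the latent batch size, frame count, channel count, and spatial size.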
def load_model(path):
if path.startswith("http"):
parsed_url = urlparse(path)
if "huggingface.co" in parsed_url.netloc:
parts = parsed_url.path.strip("/").split("/")
repo_id = "/".join(parts[:2])
subfolder = None
if len(parts) > 3:
subfolder = "/".join(parts[4:])
local_root = "./tmp"
local_dir = os.path.join(local_root, repo_id.replace("/", "_"))
if subfolder:
local_dir = os.path.join(local_root, subfolder)
os.makedirs(local_root, exist_ok=True)
config_file = hf_hub_download(
repo_id=repo_id,
subfolder=subfolder,
filename="config.json",
local_dir=local_root,
repo_type="model",
token=os.getenv("READ_HF_TOKEN"),
local_dir_use_symlinks=False,
)
assert os.path.exists(config_file)
hf_hub_download(
repo_id=repo_id,
filename="diffusion_pytorch_model.safetensors",
subfolder=subfolder,
local_dir=local_root,
local_dir_use_symlinks=False,
token=os.getenv("READ_HF_TOKEN"),
)
path = local_dir
model_root = os.path.join(config_file.split("config.json")[0])
json_path = os.path.join(model_root, "config.json")
assert os.path.exists(json_path)
with open(json_path, "r") as f:
config = json.load(f)
klass_name = config["_class_name"]
klass = getattr(diffusers, klass_name, None) or globals().get(klass_name, None)
assert (
klass is not None
), f"Could not find class {klass_name} in diffusers or global scope."
assert hasattr(
klass, "from_pretrained"
), f"Class {klass_name} does not support 'from_pretrained'."
return klass.from_pretrained(path)
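# Example of the URL handling above: the checkpoint URL used below,
# https://huggingface.co/HReynaud/EchoFlow/tree/main/lifm/FMiT-S2-4f4, splits into
# parts = ["HReynaud", "EchoFlow", "tree", "main", "lifm", "FMiT-S2-4f4"], giving
# repo_id = "HReynaud/EchoFlow" and subfolder = "lifm/FMiT-S2-4f4"; the config and
# weights land in ./tmp/lifm/FMiT-S2-4f4/ and are loaded via <class>.from_pretrained.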
def load_reid(path):
parsed_url = urlparse(path)
parts = parsed_url.path.strip("/").split("/")
repo_id = "/".join(parts[:2])
subfolder = "/".join(parts[4:])
local_root = "./tmp"
config_file = hf_hub_download(
repo_id=repo_id,
subfolder=subfolder,
filename="config.yaml",
local_dir=local_root,
repo_type="model",
token=os.getenv("READ_HF_TOKEN"),
local_dir_use_symlinks=False,
)
weights_file = hf_hub_download(
repo_id=repo_id,
subfolder=subfolder,
filename="backbone.safetensors",
local_dir=local_root,
repo_type="model",
token=os.getenv("READ_HF_TOKEN"),
local_dir_use_symlinks=False,
)
config = OmegaConf.load(config_file)
backbone = instantiate_class_from_config(config.backbone)
backbone = ContrastiveModel.patch_backbone(
backbone, config.model.args.in_channels, config.model.args.out_channels
)
state_dict = load_file(weights_file)
backbone.load_state_dict(state_dict)
backbone = backbone.to(device, dtype=dtype)
backbone.eval()
return backbone
def get_vae_scaler(path):
    scaler = torch.load(path, map_location="cpu")
scaler = {k: v.to(device) for k, v in scaler.items()}
return scaler
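# The scaler loaded from assets/scaling.pt is a dict of tensors handed to
# unscale_latents before VAE decoding, presumably the normalisation statistics used
# when the latents were produced (the exact contents are defined in echoflow.common).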
generator = torch.Generator(device=device).manual_seed(0)
lifm = load_model("https://huggingface.co/HReynaud/EchoFlow/tree/main/lifm/FMiT-S2-4f4")
lifm = lifm.to(device, dtype=dtype)
lifm.eval()
vae = load_model("https://huggingface.co/HReynaud/EchoFlow/tree/main/vae/avae-4f4")
vae = vae.to(device, dtype=dtype)
vae.eval()
vae_scaler = get_vae_scaler("assets/scaling.pt")
reid = {
"anatomies": {
"A4C": torch.cat(
[
torch.load("assets/anatomies_dynamic.pt"),
torch.load("assets/anatomies_ped_a4c.pt"),
],
dim=0,
),
"PSAX": torch.load("assets/anatomies_ped_psax.pt"),
"PLAX": torch.load("assets/anatomies_lvh.pt"),
},
"models": {
"A4C": load_reid(
"https://huggingface.co/HReynaud/EchoFlow/tree/main/reid/dynamic-4f4"
),
"PSAX": load_reid(
"https://huggingface.co/HReynaud/EchoFlow/tree/main/reid/ped_psax-4f4"
),
"PLAX": load_reid(
"https://huggingface.co/HReynaud/EchoFlow/tree/main/reid/lvh-4f4"
),
},
"tau": {
"A4C": 0.9997,
"PSAX": 0.9953,
"PLAX": 0.9950,
},
}
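# Everything the privacy filter needs, keyed by cardiac view:
#   "anatomies" - tensors that appear to be precomputed ReID embeddings of the training anatomies
#   "models"    - the ReID backbone that embeds a generated latent into the same space
#   "tau"       - per-view correlation threshold above which check_privacy rejects a sample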
lvfm = load_model("https://huggingface.co/HReynaud/EchoFlow/tree/main/lvfm/FMvT-S2-4f4")
lvfm = lvfm.to(device, dtype=dtype)
lvfm.eval()
def load_default_mask():
"""Load the default mask from disk. If not found, return a blank black mask."""
default_mask_path = os.path.join("assets", "default_mask.png")
try:
if os.path.exists(default_mask_path):
mask = Image.open(default_mask_path).convert("L")
# Ensure the mask is square and of proper size
mask = mask.resize((400, 400), Image.Resampling.LANCZOS)
            # Stretch contrast so the mask is close to binary (0 or 255)
            mask = ImageOps.autocontrast(mask, cutoff=0)
return np.array(mask)
except Exception as e:
print(f"Error loading default mask: {e}")
# Return a blank black mask if no default mask is found
return np.zeros((400, 400), dtype=np.uint8)
def preprocess_mask(mask):
"""Ensure mask is properly formatted for the model."""
if mask is None:
return np.zeros((112, 112), dtype=np.uint8)
# Check if mask is an EditorValue with multiple parts
if isinstance(mask, dict) and "composite" in mask:
# Use the composite image from the ImageEditor
mask = mask["composite"]
# If mask is already a numpy array, convert to PIL for processing
if isinstance(mask, np.ndarray):
mask_pil = Image.fromarray(mask)
else:
mask_pil = mask
# Ensure the mask is in L mode (grayscale)
mask_pil = mask_pil.convert("L")
# Apply contrast to make it binary (0 or 255)
mask_pil = ImageOps.autocontrast(mask_pil, cutoff=0)
# Threshold to ensure binary values
mask_pil = mask_pil.point(lambda p: 255 if p > 127 else 0)
# Print sizes for debugging
# print(f"Original mask size: {mask_pil.size}")
# Resize to 112x112 for the model
mask_pil = mask_pil.resize((112, 112), Image.Resampling.LANCZOS)
# Convert back to numpy array
return np.array(mask_pil)
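# gr.ImageEditor (see create_demo below) hands preprocess_mask an EditorValue dict of
# the form {"background": ..., "layers": [...], "composite": ...}; only "composite" is
# used. The mask is binarised and resized to 112x112, i.e. 4x the 28x28 latent grid,
# and generate_latent_image then downsamples it again to the latent resolution.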
@spaces.GPU
def generate_latent_image(mask, class_selection, sampling_steps=50):
"""Generate a latent image based on mask, class selection, and sampling steps"""
# Mask
mask = preprocess_mask(mask)
mask = torch.from_numpy(mask).to(device, dtype=dtype)
mask = mask.unsqueeze(0).unsqueeze(0)
mask = F.interpolate(mask, size=(H, W), mode="bilinear", align_corners=False)
mask = 1.0 * (mask > 0)
# print(mask.shape, mask.min(), mask.max(), mask.mean(), mask.std())
# Class
class_idx = VIEWS.index(class_selection)
class_idx = torch.tensor([class_idx], device=device, dtype=torch.long)
# Timesteps
timesteps = torch.linspace(
1.0, 0.0, steps=sampling_steps + 1, device=device, dtype=dtype
)
forward_kwargs = {
"class_labels": class_idx, # B x 1
"segmentation": mask, # B x 1 x H x W
}
z_1 = torch.randn(
(B, C, H, W),
device=device,
dtype=dtype,
generator=generator,
)
lifm.forward_original = lifm.forward
def new_forward(self, t, y, *args, **kwargs):
kwargs = {**kwargs, **forward_kwargs}
return self.forward_original(y, t.view(1), *args, **kwargs).sample
lifm.forward = types.MethodType(new_forward, lifm)
# Use odeint to integrate
with torch.autocast("cuda"):
latent_image = odeint(
lifm,
z_1,
timesteps,
atol=1e-5,
rtol=1e-5,
adjoint_params=lifm.parameters(),
method="euler",
)[-1]
lifm.forward = lifm.forward_original
latent_image = latent_image.detach().cpu().numpy()
    # the VAE decode happens separately (see decode_latent_to_pixel)
return latent_image # B x C x H x W
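# The monkey-patched forward above adapts the diffusers-style call
# lifm(y, t, class_labels=..., segmentation=...) to the f(t, y) signature expected by
# torchdiffeq, and the flow is integrated from t=1 (pure noise z_1) down to t=0.
# With method="euler" on the fixed `timesteps` grid this amounts to the explicit loop
# sketched here (for intuition only; the app keeps torchdiffeq for the actual stepping):
#
#     z = z_1
#     for t_cur, t_next in zip(timesteps[:-1], timesteps[1:]):
#         v = lifm.forward_original(z, t_cur.view(1), **forward_kwargs).sample
#         z = z + (t_next - t_cur) * v  # step size is negative: moving from t=1 to t=0
#     latent_image = z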
@spaces.GPU
def decode_images(latents, vae):
"""Decode latent representations to pixel space using a VAE.
Args:
latents: A numpy array of shape [B, C, H, W] for single image
or [B, C, T, H, W] for sequences/animations
vae: The VAE model for decoding
Returns:
        numpy array of decoded images: [H, W, 3] for a single image (the batch
        dimension is squeezed out) or [B, C, T, H, W] for sequences
"""
if latents is None:
return None
# Convert to torch tensor if needed
if not isinstance(latents, torch.Tensor):
latents = torch.from_numpy(latents).to(device, dtype=dtype)
# Unscale latents
latents = unscale_latents(latents, vae_scaler)
# Handle both single images and sequences
is_sequence = len(latents.shape) == 5 # B C T H W
# print("Sequence:", is_sequence)
if is_sequence:
B, C, T, H, W = latents.shape
latents = rearrange(latents[0], "c t h w -> t c h w")
else:
B, C, H, W = latents.shape
# print("Latents:", latents.shape)
with torch.no_grad():
# Decode latents to pixel space
# decode one by one
decoded = []
for i in range(latents.shape[0]):
decoded.append(vae.decode(latents[i : i + 1].float()).sample)
decoded = torch.cat(decoded, dim=0)
decoded = (decoded + 1) * 128
decoded = decoded.clamp(0, 255).to(torch.uint8).cpu()
if is_sequence:
# Reshape back to [B, C, T, H, W] for sequences
decoded = rearrange(decoded, "t c h w -> c t h w").unsqueeze(0)
else:
decoded = decoded.squeeze()
decoded = decoded.permute(1, 2, 0)
# print("Decoded:", decoded.shape)
return decoded.numpy()
def decode_latent_to_pixel(latent_image):
"""Decode a single latent image to pixel space"""
global vae
if latent_image is None:
return None
# Add batch dimension if needed
if len(latent_image.shape) == 3:
latent_image = latent_image[None, ...]
decoded_image = decode_images(latent_image, vae)
decoded_image = cv2.resize(
decoded_image, (400, 400), interpolation=cv2.INTER_NEAREST
)
return decoded_image
def check_privacy(latent_image_numpy, class_selection):
"""Check if the latent image is too similar to database images"""
latent_image = torch.from_numpy(latent_image_numpy).to(device, dtype=dtype)
reid_model = reid["models"][class_selection].to(device, dtype=dtype)
real_anatomies = reid["anatomies"][class_selection] # already scaled
tau = reid["tau"][class_selection]
with torch.no_grad():
features = reid_model(latent_image).sigmoid().cpu()
        # correlate the generated anatomy (last row) against every database anatomy
        corr = torch.corrcoef(torch.cat([real_anatomies, features], dim=0))[-1, :-1]
corr = corr.max()
if corr > tau:
return (
None,
f"⚠️ **Warning:** Generated image is too similar to training data. Privacy check failed (corr = {corr:.4f} / tau = {tau:.4f})",
)
else:
return (
latent_image_numpy,
f"✅ **Success:** Generated image passed privacy check (corr = {corr:.4f} / tau = {tau:.4f})",
)
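# In short, the privacy criterion is: embed the generated latent with the view-specific
# ReID model, take its maximum Pearson correlation against the stored training-set
# embeddings, and reject the sample whenever that maximum exceeds the per-view tau.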
@spaces.GPU
def generate_animation(
latent_image, ejection_fraction, sampling_steps=50, cfg_scale=1.0
):
"""Generate an animated sequence of latent images based on EF"""
# print(
# f"Generating animation with EF = {ejection_fraction}, steps = {sampling_steps}, CFG = {cfg_scale}"
# )
# print(latent_image.shape, type(latent_image))
if latent_image is None:
return None
lvefs = torch.tensor([ejection_fraction / 100.0], device=device, dtype=dtype)
lvefs = lvefs[:, None, None].to(device, dtype)
uncond_lvefs = -1 * torch.ones_like(lvefs)
ref_images = torch.from_numpy(latent_image).to(device, dtype)
ref_images = ref_images[:, :, None, :, :] # B x C x 1 x H x W
ref_images = ref_images.repeat(1, 1, T, 1, 1) # B x C x T x H x W
uncond_images = torch.zeros_like(ref_images)
timesteps = torch.linspace(
1.0, 0.0, steps=sampling_steps + 1, device=device, dtype=dtype
)
forward_kwargs = {
"encoder_hidden_states": lvefs,
"cond_image": ref_images,
}
z_1 = torch.randn(
(B, C, T, H, W),
device=device,
dtype=dtype,
generator=generator,
)
# print(
# z_1.shape,
# forward_kwargs["encoder_hidden_states"].shape,
# forward_kwargs["cond_image"].shape,
# )
lvfm.forward_original = lvfm.forward
def new_forward(self, t, y, *args, **kwargs):
kwargs = {**kwargs, **forward_kwargs}
# y has shape (B, C, T, H, W)
pred = self.forward_original(y, t.repeat(y.size(0)), *args, **kwargs).sample
if cfg_scale != 1.0:
uncond_kwargs = {
"encoder_hidden_states": uncond_lvefs,
"cond_image": uncond_images,
}
uncond_pred = self.forward_original(
y, t.repeat(y.size(0)), *args, **uncond_kwargs
).sample
pred = uncond_pred + cfg_scale * (pred - uncond_pred)
return pred
lvfm.forward = types.MethodType(new_forward, lvfm)
with torch.autocast("cuda"):
synthetic_video = odeint(
lvfm,
z_1,
timesteps,
atol=1e-5,
rtol=1e-5,
adjoint_params=lvfm.parameters(),
method="euler",
)[-1]
lvfm.forward = lvfm.forward_original
# print("Synthetic video:", synthetic_video.shape)
return synthetic_video # B x C x T x H x W
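# When cfg_scale != 1, new_forward applies classifier-free guidance to the predicted
# velocity field:
#     v = v_uncond + cfg_scale * (v_cond - v_uncond)
# where the unconditional branch replaces the EF embedding with -1 and the reference
# image with zeros. cfg_scale == 1 reduces to the conditional prediction, which is why
# the second (unconditional) forward pass is skipped in that case.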
def decode_animation(latent_animation):
"""Decode a latent animation to pixel space"""
global vae
if latent_animation is None:
return None
# Convert to torch tensor if needed
if not isinstance(latent_animation, torch.Tensor):
latent_animation = torch.from_numpy(latent_animation).to(device, dtype=dtype)
# Ensure shape is B x C x T x H x W
if len(latent_animation.shape) == 4: # [T, C, H, W]
latent_animation = latent_animation[None, ...] # Add batch dimension
# Decode using VAE
decoded = decode_images(
latent_animation, vae
) # Returns B x C x T x H x W numpy array
# Remove batch dimension and transpose to T x H x W x C
decoded = np.transpose(decoded[0], (1, 2, 3, 0)) # [T, H, W, C]
# Resize frames to 400x400
decoded = np.stack(
[
cv2.resize(frame, (400, 400), interpolation=cv2.INTER_NEAREST)
for frame in decoded
]
)
# Save to temporary file
temp_file = "temp_video_2.mp4"
fps = 32
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
out = cv2.VideoWriter(temp_file, fourcc, fps, (400, 400))
    # Write frames (cv2.VideoWriter expects BGR, the decoded frames are RGB)
    for frame in decoded:
        out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
out.release()
return temp_file
def convert_latent_to_display(latent_image):
"""Convert multi-channel latent image to grayscale for display"""
if latent_image is None:
return None
# Check shape
if len(latent_image.shape) == 4: # [B, C, H, W]
# Remove batch dimension and average across channels
display_image = np.squeeze(latent_image, axis=0) # [C, H, W]
display_image = np.mean(display_image, axis=0) # [H, W]
elif len(latent_image.shape) == 3: # [C, H, W]
# Average across channels
display_image = np.mean(latent_image, axis=0) # [H, W]
else:
display_image = latent_image
# Normalize to 0-1 range
display_image = (display_image - display_image.min()) / (
display_image.max() - display_image.min() + 1e-8
)
# Convert to grayscale image
display_image = (display_image * 255).astype(np.uint8)
    # Resize to a larger display size (400x400) with nearest-neighbour interpolation
display_image = cv2.resize(
display_image, (400, 400), interpolation=cv2.INTER_NEAREST
)
return display_image
def latent_animation_to_grayscale(latent_animation):
"""Convert multi-channel latent animation to grayscale for display"""
if latent_animation is None:
return None
# print("Input shape:", latent_animation.shape)
# Convert to numpy if it's a torch tensor
if torch.is_tensor(latent_animation):
latent_animation = latent_animation.detach().cpu().numpy()
# Handle shape B x C x T x H x W -> T x H x W
if len(latent_animation.shape) == 5: # [B, C, T, H, W]
latent_animation = np.squeeze(latent_animation, axis=0) # [C, T, H, W]
latent_animation = np.transpose(latent_animation, (1, 0, 2, 3)) # [T, C, H, W]
# print("After transpose:", latent_animation.shape)
# Average across channels
latent_animation = np.mean(latent_animation, axis=1) # [T, H, W]
# print("After channel reduction:", latent_animation.shape)
# Normalize each frame independently
min_vals = latent_animation.min(axis=(1, 2), keepdims=True)
max_vals = latent_animation.max(axis=(1, 2), keepdims=True)
latent_animation = (latent_animation - min_vals) / (max_vals - min_vals + 1e-8)
# Convert to uint8
latent_animation = (latent_animation * 255).astype(np.uint8)
# print("Before resize:", latent_animation.shape)
# Resize each frame
resized_frames = []
for frame in latent_animation:
resized = cv2.resize(frame, (400, 400), interpolation=cv2.INTER_NEAREST)
resized_frames.append(resized)
# Stack back into video
grayscale_video = np.stack(resized_frames)
# print("Final shape:", grayscale_video.shape)
# Add a dummy channel dimension for grayscale video
grayscale_video = grayscale_video[..., None].repeat(3, axis=-1) # Convert to RGB
# print("Output shape with channels:", grayscale_video.shape)
# Save to temporary file
temp_file = "temp_video.mp4"
fps = 32
# Create VideoWriter object
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
out = cv2.VideoWriter(temp_file, fourcc, fps, (400, 400))
# Write frames
for frame in grayscale_video:
out.write(frame)
out.release()
return temp_file
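# Both video helpers above finish by writing uint8 frames to an mp4 with OpenCV. A
# shared writer along the lines of this sketch (hypothetical helper, not wired into
# the app, assuming a [T, H, W, 3] uint8 RGB frame stack) would remove the duplication:
def _write_mp4_sketch(frames, path, fps=32):
    """Write a stack of uint8 RGB frames [T, H, W, 3] to an mp4 file and return its path."""
    height, width = frames[0].shape[:2]
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    writer = cv2.VideoWriter(path, fourcc, fps, (width, height))
    for frame in frames:
        # cv2.VideoWriter expects BGR channel order
        writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    writer.release()
    return path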
def create_demo():
# Define the theme and layout
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown("# EchoFlow Demo")
gr.Markdown("## Dataset Generation Pipeline")
gr.Markdown(
"""
### 🎯 Purpose
This demo showcases EchoFlow's ability to generate synthetic echocardiogram images and videos while preserving patient privacy. The pipeline consists of four main steps:
1. **Latent Image Generation**: Draw a mask to indicate the region where the Left Ventricle should appear. Select the desired cardiac view, and click "Generate Latent Image". This outputs a latent image, which can be decoded into a pixel space image by clicking "Decode to Pixel Space".
2. **Privacy Filter**: When clicking "Run Privacy Check", the generated image will be checked against a database of all training anatomies to ensure it is sufficiently different from real patient data.
3. **Latent Video Generation**: If the privacy check passes, the latent image can be animated into a video with the desired Ejection Fraction.
4. **Video Decoding**: The video can be decoded into a pixel space video by clicking "Decode Video".
### ⚙️ Parameters
- **Sampling Steps**: Higher values produce better quality but take longer
- **Ejection Fraction**: Controls the strength of heart contraction in the animation
- **CFG Scale**: Controls how closely the animation follows the specified conditions
"""
)
# Main container with 4 columns
with gr.Row():
# Column 1: Latent Image Generation
with gr.Column():
gr.Markdown("### Latent Image Generation")
with gr.Row():
# Input mask (binary image)
with gr.Column(scale=1):
# gr.Markdown("#### Mask Condition")
gr.Markdown("Draw the LV mask (white = region of interest)")
# Create a black background for the canvas
black_background = np.zeros((400, 400), dtype=np.uint8)
# Load the default mask image if it exists
try:
mask_image = Image.open("assets/seg.png").convert("L")
mask_image = mask_image.resize(
(400, 400), Image.Resampling.LANCZOS
)
# Make it binary (0 or 255)
mask_image = ImageOps.autocontrast(mask_image, cutoff=0)
mask_image = mask_image.point(
lambda p: 255 if p > 127 else 0
)
mask_array = np.array(mask_image)
# Create the editor value structure
editor_value = {
"background": black_background, # Black background
"layers": [mask_array], # The mask as an editable layer
"composite": mask_array, # The composite image (what's displayed)
}
except Exception as e:
print(f"Error loading mask image: {e}")
# Fall back to empty canvas
editor_value = black_background
mask_input = gr.ImageEditor(
label="Binary Mask",
height=400,
width=400,
image_mode="L",
value=editor_value,
type="numpy",
brush=gr.Brush(
colors=["#ffffff"],
color_mode="fixed",
default_size=20,
default_color="#ffffff",
),
eraser=gr.Eraser(default_size=20),
# show_label=False,
show_download_button=True,
sources=[],
canvas_size=(400, 400),
fixed_canvas=True,
                        layers=False,  # layers disabled; the drawn mask is read from the composite image
)
# # Class selection
# with gr.Column(scale=1):
# gr.Markdown("#### View Condition")
class_selection = gr.Radio(
                    choices=VIEWS,
label="View Class",
value="A4C",
)
# gr.Markdown("#### Sampling Steps")
sampling_steps = gr.Slider(
minimum=1,
maximum=200,
value=100,
step=1,
label="Number of Sampling Steps",
info="Higher values = better quality but slower generation",
)
# Generate button
generate_btn = gr.Button("Generate Latent Image", variant="primary")
# Display area for latent image (grayscale visualization)
latent_image_display = gr.Image(
label="Latent Image",
type="numpy",
height=400,
width=400,
# show_label=False,
)
# Decode button (initially disabled)
decode_btn = gr.Button(
"Decode to Pixel Space (Optional)",
interactive=False,
variant="primary",
)
# Display area for decoded image
decoded_image_display = gr.Image(
label="Decoded Image",
type="numpy",
height=400,
width=400,
# show_label=False,
)
# Column 2: Privacy Filter
with gr.Column():
gr.Markdown("### Privacy Filter")
gr.Markdown(
"Checks if the generated image is too similar to training data"
)
# Privacy check button
privacy_btn = gr.Button(
"Run Privacy Check", interactive=False, variant="primary"
)
# Display area for privacy result status
privacy_status = gr.Markdown("No image processed yet")
# Display area for privacy-filtered latent image
filtered_latent_display = gr.Image(
label="Filtered Latent Image", type="numpy", height=400, width=400
)
# Column 3: Animation
with gr.Column():
gr.Markdown("### Latent Video Generation")
# Ejection Fraction slider
ef_slider = gr.Slider(
minimum=0,
maximum=100,
value=65,
label="Ejection Fraction (%)",
info="Higher values = stronger contraction",
)
# Add sampling steps slider for animation
animation_steps = gr.Slider(
minimum=1,
maximum=200,
value=100,
step=1,
label="Number of Sampling Steps",
info="Higher values = better quality but slower generation",
)
# Add CFG slider
cfg_slider = gr.Slider(
minimum=0,
maximum=10,
value=1,
step=1,
label="Classifier-Free Guidance Scale",
                    info="Controls how closely the animation follows the specified conditions",
)
# Animate button
animate_btn = gr.Button(
"Generate Video", interactive=False, variant="primary"
)
# Display area for latent animation (grayscale)
latent_animation_display = gr.Video(
label="Latent Video", format="mp4", autoplay=True, loop=True
)
# Column 4: Video Decoding
with gr.Column():
gr.Markdown("### Video Decoding")
# Decode animation button
decode_animation_btn = gr.Button(
"Decode Video", interactive=False, variant="primary"
)
# Display area for decoded animation
decoded_animation_display = gr.Video(
label="Decoded Video", format="mp4", autoplay=True, loop=True
)
# Hidden state variables to store the full latent representations
latent_image_state = gr.State(None)
filtered_latent_state = gr.State(None)
latent_animation_state = gr.State(None)
# Event handlers
generate_btn.click(
fn=generate_latent_image,
inputs=[mask_input, class_selection, sampling_steps],
outputs=[latent_image_state],
queue=True,
).then(
fn=convert_latent_to_display,
inputs=[latent_image_state],
outputs=[latent_image_display],
queue=False,
).then(
fn=lambda x: gr.Button(
interactive=x is not None
), # Properly update button state
inputs=[latent_image_state],
outputs=[decode_btn],
queue=False,
).then(
fn=lambda x: gr.Button(
interactive=x is not None
), # Properly update button state
inputs=[latent_image_state],
outputs=[privacy_btn],
queue=False,
)
decode_btn.click(
fn=decode_latent_to_pixel,
inputs=[latent_image_state],
outputs=[decoded_image_display],
queue=True,
).then(
fn=lambda x: gr.Button(
interactive=x is not None
), # Properly update button state
inputs=[decoded_image_display],
outputs=[privacy_btn],
queue=False,
)
privacy_btn.click(
fn=check_privacy,
inputs=[latent_image_state, class_selection],
outputs=[filtered_latent_state, privacy_status],
queue=True,
).then(
fn=convert_latent_to_display,
inputs=[filtered_latent_state],
outputs=[filtered_latent_display],
queue=False,
).then(
fn=lambda x: gr.Button(
interactive=x is not None
), # Properly update button state
inputs=[filtered_latent_state],
outputs=[animate_btn],
queue=False,
)
animate_btn.click(
fn=generate_animation,
inputs=[filtered_latent_state, ef_slider, animation_steps, cfg_slider],
outputs=[latent_animation_state],
queue=True,
).then(
fn=latent_animation_to_grayscale,
inputs=[latent_animation_state],
outputs=[latent_animation_display],
queue=False,
).then(
fn=lambda x: gr.Button(
interactive=x is not None
), # Properly update button state
inputs=[latent_animation_state],
outputs=[decode_animation_btn],
queue=False,
)
decode_animation_btn.click(
fn=decode_animation,
        inputs=[latent_animation_state],
outputs=[decoded_animation_display],
queue=True,
)
return demo
if __name__ == "__main__":
demo = create_demo()
demo.launch()