import jax
import jax.numpy as jnp
import gradio as gr
import numpy as np


# Dummy diffusion sampling function
def generate_image(seed, height=64, width=64):
    """Return a random RGB image derived deterministically from *seed*.

    Placeholder for a real diffusion sampler: it draws uniform noise with
    JAX so the UI can be exercised end to end.

    Args:
        seed: Seed for the JAX PRNG. Coerced to ``int`` because
            ``jax.random.PRNGKey`` rejects non-integer seeds, and values
            coming from a UI slider or an API call may be floats.
        height: Image height in pixels (default 64, the original size).
        width: Image width in pixels (default 64, the original size).

    Returns:
        A NumPy array of shape ``(height, width, 3)`` with values in
        ``[0, 1)`` (float32 under JAX's default precision settings).
    """
    key = jax.random.PRNGKey(int(seed))
    img = jax.random.uniform(key, (height, width, 3))
    # Convert the jax.Array to a (writable) NumPy array for Gradio to display.
    return np.array(img)


with gr.Blocks() as demo:
    gr.Markdown("# 🌀 JAX Diffusion Demo")
    gr.Markdown("Generate random images using JAX diffusion (dummy example).")

    with gr.Row():
        seed_slider = gr.Slider(
            minimum=0, maximum=10000, step=1, value=42, label="Seed"
        )
        generate_button = gr.Button("Generate Image")

    # No ``shape=`` argument: gr.Image dropped it (Gradio 4.x), and the
    # component simply displays whatever array generate_image returns.
    output_image = gr.Image(type="numpy", label="Generated Image")

    generate_button.click(
        fn=generate_image,
        inputs=seed_slider,
        outputs=output_image,
    )


if __name__ == "__main__":
    # Launch only when run as a script, so importing this module (tests,
    # other tooling) does not start the web server as a side effect.
    demo.launch()