jax-diffusion / app.py
carrycooldude's picture
Add app.py
bd0c1e9
raw
history blame
759 Bytes
import gradio as gr
import numpy as np
import jax
import jax.numpy as jnp
def generate_image(seed):
    """Return a 64x64 RGB noise image derived deterministically from ``seed``.

    Args:
        seed: Seed for the JAX PRNG. Integers are used as-is. Floats (a
            numeric UI input may deliver e.g. ``42.0``) are truncated toward
            zero. ``None`` (an emptied input box) is treated as ``0``.

    Returns:
        A NumPy float array of shape ``(64, 64, 3)`` whose values are
        min-max scaled into ``[0, 1]``.
    """
    # jax.random.PRNGKey rejects non-integer seeds, so coerce before use.
    seed = 0 if seed is None else int(seed)
    key = jax.random.PRNGKey(seed)
    img = jax.random.normal(key, (64, 64, 3))
    # Min-max scale the Gaussian samples into [0, 1].
    lo, hi = img.min(), img.max()
    img = (img - lo) / (hi - lo)
    return np.array(img)
# Gradio interface: one numeric seed input, one image output.
iface = gr.Interface(
    fn=generate_image,
    # precision=0 makes Gradio round the value and pass it as an int;
    # jax.random.PRNGKey only accepts integer seeds.
    inputs=gr.Number(label="Seed", value=0, precision=0),
    outputs=gr.Image(type="numpy", label="Generated Image"),
    title="JAX Diffusion Demo",
    description="Enter a seed and click 'Generate' to sample a random diffusion image!",
    allow_flagging="never",
    examples=[[42], [1234], [2025]],  # Optional examples
)

# Start the web server only when executed as a script, not on import.
if __name__ == "__main__":
    iface.launch()