File size: 759 Bytes
d58b45f
712b05f
d58b45f
 
 
bd0c1e9
 
 
712b05f
 
 
252d51a
712b05f
 
 
bd0c1e9
a470cd2
712b05f
bd0c1e9
 
 
712b05f
252d51a
712b05f
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import gradio as gr
import numpy as np
import jax
import jax.numpy as jnp

# Function that uses a seed to generate an image
def generate_image(seed):
    """Return a seeded random-noise image rescaled to the range [0, 1].

    Args:
        seed: PRNG seed. ``gr.Number`` delivers floats (e.g. ``0.0``) unless
            ``precision=0`` is set, so any whole number is accepted and
            converted to ``int`` -- ``jax.random.PRNGKey`` rejects
            non-integer seeds.

    Returns:
        A ``(64, 64, 3)`` NumPy array of standard-normal samples min-max
        rescaled so the smallest value is 0 and the largest is 1. The same
        seed always yields the same image.

    Raises:
        ValueError: if ``seed`` is ``None`` (empty input box) or is not a
            whole number.
    """
    if seed is None:
        raise ValueError("A seed is required.")
    seed_int = int(seed)
    if seed_int != seed:
        raise ValueError(f"Seed must be a whole number, got {seed!r}.")

    key = jax.random.PRNGKey(seed_int)
    img = jax.random.normal(key, (64, 64, 3))

    lo, hi = img.min(), img.max()
    span = hi - lo
    # Min-max normalize to [0, 1]; a constant image would otherwise divide
    # by zero and produce NaNs.
    img = (img - lo) / span if span > 0 else jnp.zeros_like(img)
    return np.array(img)

# Gradio Interface: one numeric seed input -> one generated image output.
iface = gr.Interface(
    fn=generate_image,
    # precision=0 makes gr.Number deliver an int instead of a float;
    # jax.random.PRNGKey (used by generate_image) rejects non-integer seeds.
    inputs=gr.Number(label="Seed", value=0, precision=0),
    outputs=gr.Image(type="numpy", label="Generated Image"),
    title="JAX Diffusion Demo",
    # gr.Interface labels its run button "Submit" by default, so the
    # instructions must name that button.
    description="Enter a seed and click 'Submit' to sample a random diffusion image!",
    # NOTE(review): newer Gradio releases rename this to flagging_mode --
    # confirm against the pinned Gradio version before switching.
    allow_flagging="never",
    examples=[[42], [1234], [2025]],  # Optional examples
)

if __name__ == "__main__":
    # Start the Gradio web server only when run as a script, not on import.
    iface.launch()