File size: 891 Bytes
d58b45f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# app.py

import gradio as gr
import jax
import jax.numpy as jnp
import numpy as np
from PIL import Image

def generate_image(seed, size=64):
    """Return a random RGB noise image as a placeholder for model output.

    Dummy generator: replace the body with real model inference while
    keeping the ``seed -> PIL.Image`` contract the Gradio UI relies on.

    Args:
        seed: Integer-valued PRNG seed. It is coerced with ``int()`` because
            a UI number widget may deliver it as a float (e.g. ``42.0``), and
            recent JAX versions reject non-integer seeds in
            ``jax.random.PRNGKey``.
        size: Side length, in pixels, of the square output image.
            Defaults to 64, the size this demo has always produced.

    Returns:
        A ``size`` x ``size`` RGB ``PIL.Image.Image`` filled with uniform
        noise. JAX's PRNG is deterministic, so equal ``seed``/``size``
        arguments give identical images.

    Raises:
        ValueError: If ``size`` is not positive.
    """
    if size <= 0:
        raise ValueError(f"size must be positive, got {size!r}")

    key = jax.random.PRNGKey(int(seed))
    # Uniform floats in [0, 1) laid out as (height, width, 3 RGB channels).
    noise = jax.random.uniform(key, (size, size, 3), minval=0.0, maxval=1.0)
    # Scale [0, 1) to [0, 255) and truncate to 8-bit pixel values.
    pixels = np.asarray(noise * 255).astype(np.uint8)
    return Image.fromarray(pixels)

# Gradio UI: a seed slider (0-10000, step 1, default 42) feeds
# `generate_image`, and the returned PIL image is shown in an Image output.
iface = gr.Interface(
    generate_image,
    inputs=gr.Slider(
        minimum=0,
        maximum=10000,
        value=42,
        step=1,
        label="Random Seed",
    ),
    outputs=gr.Image(
        type="pil",
        label="Generated Image",
    ),
    title="JAX Diffusion Demo",
    description="🎨 Generate random diffusion samples using JAX! \n\n(Replace dummy function with your trained model.)",
    theme="default",
    # Run only when the user submits, not on every slider change.
    live=False,
)

if __name__ == "__main__":
    # Serve the demo when this file is executed directly.
    iface.launch()