carrycooldude commited on
Commit
d58b45f
·
1 Parent(s): decb75d

Add app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -0
app.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+
3
+ import gradio as gr
4
+ import jax
5
+ import jax.numpy as jnp
6
+ import numpy as np
7
+ from PIL import Image
8
+
9
+ # Dummy generator function — Replace this with your real model inference!
10
+ def generate_image(seed):
11
+ key = jax.random.PRNGKey(seed)
12
+ # Generate a fake "image" of size 64x64x3 (RGB)
13
+ img = jax.random.uniform(key, (64, 64, 3), minval=0, maxval=1.0)
14
+ img_np = np.array(img * 255, dtype=np.uint8)
15
+ return Image.fromarray(img_np)
16
+
17
+ # Define Gradio Interface
18
+ iface = gr.Interface(
19
+ fn=generate_image,
20
+ inputs=gr.Slider(0, 10000, value=42, step=1, label="Random Seed"),
21
+ outputs=gr.Image(type="pil", label="Generated Image"),
22
+ title="JAX Diffusion Demo",
23
+ description="🎨 Generate random diffusion samples using JAX! \n\n(Replace dummy function with your trained model.)",
24
+ theme="default",
25
+ live=False
26
+ )
27
+
28
+ if __name__ == "__main__":
29
+ iface.launch()