carrycooldude commited on
Commit
9f382a2
·
1 Parent(s): bd0c1e9

Add app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -13
app.py CHANGED
@@ -3,23 +3,32 @@ import numpy as np
3
  import jax
4
  import jax.numpy as jnp
5
 
6
- # Function that uses a seed to generate an image
7
  def generate_image(seed):
8
- key = jax.random.PRNGKey(seed)
 
 
9
  img = jax.random.normal(key, (64, 64, 3))
10
  img = (img - img.min()) / (img.max() - img.min()) # Normalize to [0, 1]
11
  return np.array(img)
12
 
13
- # Gradio Interface
14
- iface = gr.Interface(
15
- fn=generate_image,
16
- inputs=gr.Number(label="Seed", value=0),
17
- outputs=gr.Image(type="numpy", label="Generated Image"),
18
- title="JAX Diffusion Demo",
19
- description="Enter a seed and click 'Generate' to sample a random diffusion image!",
20
- allow_flagging="never",
21
- examples=[[42], [1234], [2025]] # Optional examples
22
- )
 
 
 
 
 
 
23
 
24
  if __name__ == "__main__":
25
- iface.launch()
 
 
3
  import jax
4
  import jax.numpy as jnp
5
 
6
+ # Function to generate an image based on a seed
7
  def generate_image(seed):
8
+ if seed is None:
9
+ seed = 0 # default fallback
10
+ key = jax.random.PRNGKey(int(seed))
11
  img = jax.random.normal(key, (64, 64, 3))
12
  img = (img - img.min()) / (img.max() - img.min()) # Normalize to [0, 1]
13
  return np.array(img)
14
 
15
+ # Define the Gradio interface
16
+ with gr.Blocks() as demo:
17
+ gr.Markdown("# 🌟 JAX Diffusion Demo")
18
+ gr.Markdown("Enter a random seed to generate a diffusion-based image.")
19
+
20
+ with gr.Row():
21
+ seed_input = gr.Number(label="Seed", value=0)
22
+ generate_button = gr.Button("Generate")
23
+
24
+ output_image = gr.Image(label="Generated Image")
25
+
26
+ generate_button.click(
27
+ fn=generate_image,
28
+ inputs=seed_input,
29
+ outputs=output_image
30
+ )
31
 
32
  if __name__ == "__main__":
33
+ demo.launch()
34
+