carrycooldude commited on
Commit
bd0c1e9
·
1 Parent(s): a470cd2

Add app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -6
app.py CHANGED
@@ -3,9 +3,9 @@ import numpy as np
3
  import jax
4
  import jax.numpy as jnp
5
 
6
- # Dummy simple function to simulate diffusion sampling
7
- def generate_image():
8
- key = jax.random.PRNGKey(np.random.randint(0, 10000))
9
  img = jax.random.normal(key, (64, 64, 3))
10
  img = (img - img.min()) / (img.max() - img.min()) # Normalize to [0, 1]
11
  return np.array(img)
@@ -13,11 +13,12 @@ def generate_image():
13
  # Gradio Interface
14
  iface = gr.Interface(
15
  fn=generate_image,
16
- inputs=[],
17
  outputs=gr.Image(type="numpy", label="Generated Image"),
18
  title="JAX Diffusion Demo",
19
- description="Click 'Generate' to sample a random diffusion image!",
20
- allow_flagging="never"
 
21
  )
22
 
23
  if __name__ == "__main__":
 
3
  import jax
4
  import jax.numpy as jnp
5
 
6
+ # Function that uses a seed to generate an image
7
+ def generate_image(seed):
8
+ key = jax.random.PRNGKey(seed)
9
  img = jax.random.normal(key, (64, 64, 3))
10
  img = (img - img.min()) / (img.max() - img.min()) # Normalize to [0, 1]
11
  return np.array(img)
 
13
  # Gradio Interface
14
  iface = gr.Interface(
15
  fn=generate_image,
16
+ inputs=gr.Number(label="Seed", value=0),
17
  outputs=gr.Image(type="numpy", label="Generated Image"),
18
  title="JAX Diffusion Demo",
19
+ description="Enter a seed and click 'Generate' to sample a random diffusion image!",
20
+ allow_flagging="never",
21
+ examples=[[42], [1234], [2025]] # Optional examples
22
  )
23
 
24
  if __name__ == "__main__":