carrycooldude commited on
Commit
9341c6d
·
1 Parent(s): 712b05f

Add app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -4
app.py CHANGED
@@ -4,8 +4,8 @@ import jax
4
  import jax.numpy as jnp
5
 
6
  # Dummy simple function to simulate diffusion sampling
7
- def generate_image(seed):
8
- key = jax.random.PRNGKey(seed)
9
  img = jax.random.normal(key, (64, 64, 3))
10
  img = (img - img.min()) / (img.max() - img.min()) # Normalize to [0, 1]
11
  return np.array(img)
@@ -13,10 +13,12 @@ def generate_image(seed):
13
  # Gradio Interface
14
  iface = gr.Interface(
15
  fn=generate_image,
16
- inputs=gr.Number(label="Random Seed"),
17
  outputs=gr.Image(shape=(64, 64)),
18
  title="JAX Diffusion Demo",
19
- description="A simple diffusion-style random image generator using JAX.",
 
 
20
  )
21
 
22
  if __name__ == "__main__":
 
4
  import jax.numpy as jnp
5
 
6
  # Dummy simple function to simulate diffusion sampling
7
+ def generate_image():
8
+ key = jax.random.PRNGKey(np.random.randint(0, 10000))
9
  img = jax.random.normal(key, (64, 64, 3))
10
  img = (img - img.min()) / (img.max() - img.min()) # Normalize to [0, 1]
11
  return np.array(img)
 
13
  # Gradio Interface
14
  iface = gr.Interface(
15
  fn=generate_image,
16
+ inputs=[],
17
  outputs=gr.Image(shape=(64, 64)),
18
  title="JAX Diffusion Demo",
19
+ description="Click 'Generate' to sample a random diffusion image!",
20
+ allow_flagging="never",
21
+ live=False
22
  )
23
 
24
  if __name__ == "__main__":