carrycooldude commited on
Commit
c771993
ยท
1 Parent(s): 9f382a2

Add app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -20
app.py CHANGED
@@ -1,34 +1,30 @@
1
- import gradio as gr
2
- import numpy as np
3
  import jax
4
  import jax.numpy as jnp
 
 
 
5
 
6
- # Function to generate an image based on a seed
7
  def generate_image(seed):
8
- if seed is None:
9
- seed = 0 # default fallback
10
- key = jax.random.PRNGKey(int(seed))
11
- img = jax.random.normal(key, (64, 64, 3))
12
- img = (img - img.min()) / (img.max() - img.min()) # Normalize to [0, 1]
13
- return np.array(img)
14
 
15
- # Define the Gradio interface
16
  with gr.Blocks() as demo:
17
- gr.Markdown("# ๐ŸŒŸ JAX Diffusion Demo")
18
- gr.Markdown("Enter a random seed to generate a diffusion-based image.")
19
 
20
  with gr.Row():
21
- seed_input = gr.Number(label="Seed", value=0)
22
- generate_button = gr.Button("Generate")
23
-
24
- output_image = gr.Image(label="Generated Image")
25
 
26
  generate_button.click(
27
  fn=generate_image,
28
- inputs=seed_input,
29
  outputs=output_image
30
  )
31
 
32
- if __name__ == "__main__":
33
- demo.launch()
34
-
 
 
 
1
  import jax
2
  import jax.numpy as jnp
3
+ import gradio as gr
4
+ import numpy as np
5
+ import matplotlib.pyplot as plt
6
 
7
+ # Dummy diffusion sampling function
8
  def generate_image(seed):
9
+ key = jax.random.PRNGKey(seed)
10
+ img = jax.random.uniform(key, (64, 64, 3))
11
+ img = np.array(img) # Convert to numpy for display
12
+ return img
 
 
13
 
 
14
  with gr.Blocks() as demo:
15
+ gr.Markdown("# ๐ŸŒ€ JAX Diffusion Demo")
16
+ gr.Markdown("Generate random images using JAX diffusion (dummy example).")
17
 
18
  with gr.Row():
19
+ seed_slider = gr.Slider(minimum=0, maximum=10000, step=1, value=42, label="Seed")
20
+ generate_button = gr.Button("Generate Image")
21
+
22
+ output_image = gr.Image(type="numpy", label="Generated Image", shape=(64, 64))
23
 
24
  generate_button.click(
25
  fn=generate_image,
26
+ inputs=seed_slider,
27
  outputs=output_image
28
  )
29
 
30
+ demo.launch()