carrycooldude commited on
Commit
252d51a
·
1 Parent(s): 4067e74

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -23
app.py CHANGED
@@ -1,29 +1,33 @@
1
- # app.py
2
-
3
  import gradio as gr
4
  import jax
5
  import jax.numpy as jnp
6
- import numpy as np
7
- from PIL import Image
8
 
9
- # Dummy generator function Replace this with your real model inference!
10
- def generate_image(seed):
11
- key = jax.random.PRNGKey(seed)
12
- # Generate a fake "image" of size 64x64x3 (RGB)
13
- img = jax.random.uniform(key, (64, 64, 3), minval=0, maxval=1.0)
14
- img_np = np.array(img * 255, dtype=np.uint8)
15
- return Image.fromarray(img_np)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
- # Define Gradio Interface
18
- iface = gr.Interface(
19
- fn=generate_image,
20
- inputs=gr.Slider(0, 10000, value=42, step=1, label="Random Seed"),
21
- outputs=gr.Image(type="pil", label="Generated Image"),
22
- title="JAX Diffusion Demo",
23
- description="🎨 Generate random diffusion samples using JAX! \n\n(Replace dummy function with your trained model.)",
24
- theme="default",
25
- live=False
26
- )
27
 
28
- if __name__ == "__main__":
29
- iface.launch()
 
 
 
1
  import gradio as gr
2
  import jax
3
  import jax.numpy as jnp
 
 
4
 
5
+ # Dummy diffusion generation function (replace with your real one)
6
+ def generate_diffusion(prompt, steps):
7
+ key = jax.random.PRNGKey(0)
8
+ # For demo: Create random noise image
9
+ image = jax.random.uniform(key, (64, 64, 3))
10
+ image = jnp.clip(image, 0, 1)
11
+ return image
12
+
13
+ # Gradio Interface using Blocks
14
+ with gr.Blocks() as demo:
15
+ gr.Markdown("# 🌟 JAX Diffusion Demo")
16
+ gr.Markdown("Generate images using a simple diffusion model powered by **JAX**!")
17
+
18
+ with gr.Row():
19
+ with gr.Column():
20
+ prompt_input = gr.Textbox(label="Prompt", placeholder="Describe your image...")
21
+ steps_input = gr.Slider(10, 100, value=50, step=5, label="Diffusion Steps")
22
+ generate_button = gr.Button("Generate")
23
+
24
+ with gr.Column():
25
+ output_image = gr.Image(label="Generated Image")
26
 
27
+ generate_button.click(
28
+ fn=generate_diffusion,
29
+ inputs=[prompt_input, steps_input],
30
+ outputs=output_image
31
+ )
 
 
 
 
 
32
 
33
+ demo.launch()