MaxMilan1 commited on
Commit
05d3d42
·
1 Parent(s): ad9ba71
Files changed (1) hide show
  1. app.py +6 -6
app.py CHANGED
@@ -2,6 +2,7 @@ import spaces
2
  import gradio as gr
3
  import torch
4
  from diffusers import DiffusionPipeline
 
5
 
6
  model_id = "stabilityai/stable-diffusion-2-1"
7
  pipe = DiffusionPipeline.from_pretrained(model_id)
@@ -12,8 +13,9 @@ pipe.to("cuda")
12
  @spaces.GPU
13
  def generate_image(prompt):
14
 
15
- images = pipe(prompt).images
16
- return images
 
17
 
18
  _TITLE = "Shoe Generator"
19
  with gr.Blocks(_TITLE) as ShoeGen:
@@ -22,10 +24,8 @@ with gr.Blocks(_TITLE) as ShoeGen:
22
  prompt = gr.Textbox(label="Enter a prompt")
23
  button_gen = gr.Button("Generate Image")
24
  with gr.Column():
25
- # show images
26
- gallery = gr.Gallery(label="Generated images", show_label=False, elem_id="gallery").style(grid=[2], height="auto")
27
-
28
 
29
- button_gen.click(generate_image, inputs=[prompt], outputs=gallery)
30
 
31
  ShoeGen.launch()
 
2
  import gradio as gr
3
  import torch
4
  from diffusers import DiffusionPipeline
5
+ import rembg
6
 
7
  model_id = "stabilityai/stable-diffusion-2-1"
8
  pipe = DiffusionPipeline.from_pretrained(model_id)
 
13
  @spaces.GPU
14
  def generate_image(prompt):
15
 
16
+ image = pipe(prompt).images
17
+ image = rembg.remove(image)
18
+ return image
19
 
20
  _TITLE = "Shoe Generator"
21
  with gr.Blocks(_TITLE) as ShoeGen:
 
24
  prompt = gr.Textbox(label="Enter a prompt")
25
  button_gen = gr.Button("Generate Image")
26
  with gr.Column():
27
+ image = gr.Image(label="Generated Image", show_download_button=True)
 
 
28
 
29
+ button_gen.click(generate_image, inputs=[prompt], outputs=image)
30
 
31
  ShoeGen.launch()