RishabA commited on
Commit
8fb1bf9
·
verified ·
1 Parent(s): 7c82dff

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -5
app.py CHANGED
@@ -8,6 +8,10 @@ from model import (
8
  from huggingface_hub import hf_hub_download
9
  import json
10
 
 
 
 
 
11
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
 
13
  # Download config and checkpoint files from HF Hub
@@ -50,7 +54,6 @@ def generate_image(text_prompt, mask_upload):
50
  This function returns a generator that yields an intermediate
51
  decoded image at every timestep from the diffusion process.
52
  """
53
- # sample_ddpm_inference is assumed to be a generator function (using yield)
54
  return sample_ddpm_inference(unet, vae, text_prompt, mask_upload, device)
55
 
56
 
@@ -80,14 +83,12 @@ with gr.Blocks(css=css_str) as demo:
80
  )
81
  mask_input = gr.Image(type="pil", label="Optional Mask for Conditioning")
82
  generate_button = gr.Button("Generate Image")
83
- output_image = gr.Image(label="Generated Image", type="pil")
84
 
85
- # Adding stream=True allows Gradio to process the generator output
86
  generate_button.click(
87
  fn=generate_image,
88
  inputs=[text_input, mask_input],
89
- outputs=output_image,
90
- stream=True,
91
  )
92
 
93
  if __name__ == "__main__":
 
8
  from huggingface_hub import hf_hub_download
9
  import json
10
 
11
+
12
+ print("Gradio version:", gr.__version__)
13
+
14
+
15
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
 
17
  # Download config and checkpoint files from HF Hub
 
54
  This function returns a generator that yields an intermediate
55
  decoded image at every timestep from the diffusion process.
56
  """
 
57
  return sample_ddpm_inference(unet, vae, text_prompt, mask_upload, device)
58
 
59
 
 
83
  )
84
  mask_input = gr.Image(type="pil", label="Optional Mask for Conditioning")
85
  generate_button = gr.Button("Generate Image")
86
+ # output_image = gr.Image(label="Generated Image", type="pil")
87
 
 
88
  generate_button.click(
89
  fn=generate_image,
90
  inputs=[text_input, mask_input],
91
+ outputs="image",
 
92
  )
93
 
94
  if __name__ == "__main__":