JichenHu commited on
Commit
e0b6c9b
·
verified ·
1 Parent(s): 9a644b2

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -152
app.py CHANGED
@@ -1,109 +1,18 @@
1
- from __future__ import annotations
2
-
3
- import functools
4
  import os
5
- import tempfile
6
- import logging
7
-
8
  import gradio as gr
9
- import numpy as np
10
-
11
  from PIL import Image
12
  from gradio_imageslider import ImageSlider
13
- from tqdm import tqdm
14
-
15
- from pathlib import Path
16
- import gradio
17
- from gradio.utils import get_cache_folder
18
- from DAI.pipeline_all import DAIPipeline
19
-
20
- from diffusers import (
21
- AutoencoderKL,
22
- UNet2DConditionModel,
23
- )
24
-
25
- from transformers import CLIPTextModel, AutoTokenizer
26
-
27
- from DAI.controlnetvae import ControlNetVAEModel
28
-
29
- from DAI.decoder import CustomAutoencoderKL
30
-
31
- import torch
32
-
33
-
34
- class Examples(gradio.helpers.Examples):
35
- def __init__(self, *args, directory_name=None, **kwargs):
36
- super().__init__(*args, **kwargs, _initiated_directly=False)
37
- if directory_name is not None:
38
- self.cached_folder = get_cache_folder() / directory_name
39
- self.cached_file = Path(self.cached_folder) / "log.csv"
40
- self.create()
41
-
42
-
43
- default_seed = 2024
44
- default_batch_size = 1
45
-
46
-
47
- def process_image_check(path_input):
48
- logging.info(f"Input image path: {path_input}")
49
- if path_input is None:
50
- raise gr.Error(
51
- "Missing image in the first pane: upload a file or use one from the gallery below."
52
- )
53
-
54
- def resize_image(input_image, resolution):
55
- if not isinstance(input_image, Image.Image):
56
- raise ValueError("input_image should be a PIL Image object")
57
 
58
- input_image_np = np.asarray(input_image)
59
- H, W, C = input_image_np.shape
60
- H = float(H)
61
- W = float(W)
62
-
63
- k = float(resolution) / min(H, W)
64
-
65
- H *= k
66
- W *= k
67
- H = int(np.round(H / 64.0)) * 64
68
- W = int(np.round(W / 64.0)) * 64
69
-
70
- img = input_image.resize((W, H), Image.Resampling.LANCZOS)
71
-
72
- return img
73
 
74
- def process_image(
75
- pipe,
76
- vae_2,
77
- path_input,
78
- ):
79
  try:
80
- name_base, name_ext = os.path.splitext(os.path.basename(path_input))
81
- logging.info(f"Processing image {name_base}{name_ext}")
82
-
83
- path_output_dir = tempfile.mkdtemp()
84
- path_out_png = os.path.join(path_output_dir, f"{name_base}_delight.png")
85
  input_image = Image.open(path_input)
86
-
87
- pipe_out = pipe(
88
- image=input_image,
89
- prompt="remove glass reflection",
90
- vae_2=vae_2,
91
- processing_resolution=None,
92
- )
93
-
94
- processed_frame = (pipe_out.prediction.clip(-1, 1) + 1) / 2
95
- processed_frame = (processed_frame[0] * 255).astype(np.uint8)
96
- processed_frame = Image.fromarray(processed_frame)
97
- processed_frame.save(path_out_png)
98
- yield [input_image, path_out_png]
99
  except Exception as e:
100
- logging.error(f"Error processing image: {e}")
101
- yield [None, None]
102
 
103
 
104
- def run_demo_server(pipe, vae_2):
105
- process_pipe_image = functools.partial(process_image, pipe, vae_2)
106
-
107
  gradio_theme = gr.themes.Default()
108
 
109
  with gr.Blocks(
@@ -127,7 +36,7 @@ def run_demo_server(pipe, vae_2):
127
  """
128
  # Dereflection Any Image
129
  <p align="center">
130
- Upload an image to remove reflections.
131
  </p>
132
  """
133
  )
@@ -142,7 +51,7 @@ def run_demo_server(pipe, vae_2):
142
  )
143
  with gr.Row():
144
  image_submit_btn = gr.Button(
145
- value="Remove Reflection", variant="primary"
146
  )
147
  image_reset_btn = gr.Button(value="Reset")
148
  with gr.Column():
@@ -154,8 +63,7 @@ def run_demo_server(pipe, vae_2):
154
  elem_classes="slider",
155
  )
156
 
157
- Examples(
158
- fn=process_pipe_image,
159
  examples=sorted([
160
  os.path.join("files", "image", name)
161
  for name in os.listdir(os.path.join("files", "image"))
@@ -163,23 +71,13 @@ def run_demo_server(pipe, vae_2):
163
  ]),
164
  inputs=[image_input],
165
  outputs=[image_output_slider],
166
- cache_examples=True,
167
- directory_name="examples_image",
168
  )
169
 
170
  image_submit_btn.click(
171
- fn=process_image_check,
172
  inputs=image_input,
173
- outputs=None,
174
- preprocess=False,
175
- queue=False,
176
- ).success(
177
- fn=process_pipe_image,
178
- inputs=[
179
- image_input,
180
- ],
181
  outputs=[image_output_slider],
182
- concurrency_limit=1,
183
  )
184
 
185
  image_reset_btn.click(
@@ -192,49 +90,10 @@ def run_demo_server(pipe, vae_2):
192
  image_input,
193
  image_output_slider,
194
  ],
195
- queue=False,
196
- )
197
-
198
- demo.queue(api_open=True).launch(share=False)
199
-
200
-
201
- def main():
202
- pretrained_model_name_or_path = "JichenHu/dereflection-any-image-v0"
203
- pretrained_model_name_or_path2 = "stabilityai/stable-diffusion-2-1"
204
- revision = None
205
- variant = None
206
-
207
- controlnet = ControlNetVAEModel.from_pretrained(pretrained_model_name_or_path, subfolder="controlnet", torch_dtype=torch.float32)
208
- unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet", torch_dtype=torch.float32)
209
- vae_2 = CustomAutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae_2", torch_dtype=torch.float32)
210
-
211
- vae = AutoencoderKL.from_pretrained(
212
- pretrained_model_name_or_path2, subfolder="vae", revision=revision, variant=variant
213
- )
214
-
215
- text_encoder = CLIPTextModel.from_pretrained(
216
- pretrained_model_name_or_path2, subfolder="text_encoder", revision=revision, variant=variant
217
- )
218
- tokenizer = AutoTokenizer.from_pretrained(
219
- pretrained_model_name_or_path2,
220
- subfolder="tokenizer",
221
- revision=revision,
222
- use_fast=False,
223
- )
224
- pipe = DAIPipeline(
225
- vae=vae,
226
- text_encoder=text_encoder,
227
- tokenizer=tokenizer,
228
- unet=unet,
229
- controlnet=controlnet,
230
- safety_checker=None,
231
- scheduler=None,
232
- feature_extractor=None,
233
- t_start=0,
234
  )
235
 
236
- run_demo_server(pipe, vae_2)
237
 
238
 
239
  if __name__ == "__main__":
240
- main()
 
 
 
 
1
  import os
 
 
 
2
  import gradio as gr
 
 
3
  from PIL import Image
4
  from gradio_imageslider import ImageSlider
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
+ def process_image_direct(path_input):
 
 
 
 
8
  try:
 
 
 
 
 
9
  input_image = Image.open(path_input)
10
+ return [input_image, path_input]
 
 
 
 
 
 
 
 
 
 
 
 
11
  except Exception as e:
12
+ return [None, None]
 
13
 
14
 
15
+ def run_demo_server():
 
 
16
  gradio_theme = gr.themes.Default()
17
 
18
  with gr.Blocks(
 
36
  """
37
  # Dereflection Any Image
38
  <p align="center">
39
+ Upload an image to display it directly.
40
  </p>
41
  """
42
  )
 
51
  )
52
  with gr.Row():
53
  image_submit_btn = gr.Button(
54
+ value="Display Image", variant="primary"
55
  )
56
  image_reset_btn = gr.Button(value="Reset")
57
  with gr.Column():
 
63
  elem_classes="slider",
64
  )
65
 
66
+ gr.Examples(
 
67
  examples=sorted([
68
  os.path.join("files", "image", name)
69
  for name in os.listdir(os.path.join("files", "image"))
 
71
  ]),
72
  inputs=[image_input],
73
  outputs=[image_output_slider],
74
+ cache_examples=False,
 
75
  )
76
 
77
  image_submit_btn.click(
78
+ fn=process_image_direct,
79
  inputs=image_input,
 
 
 
 
 
 
 
 
80
  outputs=[image_output_slider],
 
81
  )
82
 
83
  image_reset_btn.click(
 
90
  image_input,
91
  image_output_slider,
92
  ],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  )
94
 
95
+ demo.launch(share=False)
96
 
97
 
98
  if __name__ == "__main__":
99
+ run_demo_server()