sjtu-deepvision committed on
Commit
7cdacae
·
verified ·
1 Parent(s): 1cedc13

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -18
app.py CHANGED
@@ -3,19 +3,24 @@ import numpy as np
3
  import torch
4
  from PIL import Image
5
  import gradio as gr
 
 
 
 
 
 
 
 
6
  from DAI.pipeline_all import DAIPipeline
7
  from DAI.controlnetvae import ControlNetVAEModel
8
  from DAI.decoder import CustomAutoencoderKL
9
  from diffusers import AutoencoderKL, UNet2DConditionModel
10
  from transformers import CLIPTextModel, AutoTokenizer
11
 
12
- # Initialize device and model paths
13
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
- weight_dtype = torch.float32
15
  pretrained_model_name_or_path = "sjtu-deepvision/dereflection-any-image-v0"
16
  pretrained_model_name_or_path2 = "stabilityai/stable-diffusion-2-1"
17
 
18
- # Load the model components
19
  controlnet = ControlNetVAEModel.from_pretrained(pretrained_model_name_or_path, subfolder="controlnet", torch_dtype=weight_dtype).to(device)
20
  unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet", torch_dtype=weight_dtype).to(device)
21
  vae_2 = CustomAutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae_2", torch_dtype=weight_dtype).to(device)
@@ -23,7 +28,7 @@ vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path2, subfolder="v
23
  text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path2, subfolder="text_encoder").to(device)
24
  tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path2, subfolder="tokenizer", use_fast=False)
25
 
26
- # Create the pipeline
27
  pipe = DAIPipeline(
28
  vae=vae,
29
  text_encoder=text_encoder,
@@ -36,12 +41,13 @@ pipe = DAIPipeline(
36
  t_start=0,
37
  ).to(device)
38
 
39
- # Function to process the image
 
40
  def process_image(input_image):
41
- # Convert Gradio input to PIL Image
42
  input_image = Image.fromarray(input_image)
43
 
44
- # Process the image
45
  pipe_out = pipe(
46
  image=input_image,
47
  prompt="remove glass reflection",
@@ -49,16 +55,17 @@ def process_image(input_image):
49
  processing_resolution=None,
50
  )
51
 
52
- # Convert the output to an image
53
  processed_frame = (pipe_out.prediction.clip(-1, 1) + 1) / 2
54
  processed_frame = (processed_frame[0] * 255).astype(np.uint8)
55
  processed_frame = Image.fromarray(processed_frame)
56
 
57
- return processed_frame
 
58
 
59
- # Gradio interface
60
  def create_gradio_interface():
61
- # Example images
62
  example_images = [
63
  os.path.join("files", "image", f"{i}.png") for i in range(1, 9)
64
  ]
@@ -70,27 +77,33 @@ def create_gradio_interface():
70
  input_image = gr.Image(label="Input Image", type="numpy")
71
  submit_btn = gr.Button("Remove Reflection", variant="primary")
72
  with gr.Column():
73
- output_image = gr.Image(label="Processed Image")
 
 
 
 
 
74
 
75
- # Add examples
76
  gr.Examples(
77
  examples=example_images,
78
  inputs=input_image,
79
- outputs=output_image,
80
  fn=process_image,
81
- cache_examples=False, # Caching disabled: example outputs are computed on demand
82
  label="Example Images",
83
  )
84
 
 
85
  submit_btn.click(
86
  fn=process_image,
87
  inputs=input_image,
88
- outputs=output_image,
89
  )
90
 
91
  return demo
92
 
93
- # Main function to launch the Gradio app
94
  def main():
95
  demo = create_gradio_interface()
96
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
3
  import torch
4
  from PIL import Image
5
  import gradio as gr
6
+ from gradio_imageslider import ImageSlider
7
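+ import spaces # Must be imported before CUDA is initialized so ZeroGPU can set up (NOTE: torch is imported above; only CUDA init has to come later)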
8
+
9
+ # 延迟 CUDA 初始化
10
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
+ weight_dtype = torch.float32
12
+
13
+ # 加载模型组件
14
  from DAI.pipeline_all import DAIPipeline
15
  from DAI.controlnetvae import ControlNetVAEModel
16
  from DAI.decoder import CustomAutoencoderKL
17
  from diffusers import AutoencoderKL, UNet2DConditionModel
18
  from transformers import CLIPTextModel, AutoTokenizer
19
 
 
 
 
20
  pretrained_model_name_or_path = "sjtu-deepvision/dereflection-any-image-v0"
21
  pretrained_model_name_or_path2 = "stabilityai/stable-diffusion-2-1"
22
 
23
+ # 加载模型
24
  controlnet = ControlNetVAEModel.from_pretrained(pretrained_model_name_or_path, subfolder="controlnet", torch_dtype=weight_dtype).to(device)
25
  unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet", torch_dtype=weight_dtype).to(device)
26
  vae_2 = CustomAutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae_2", torch_dtype=weight_dtype).to(device)
 
28
  text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path2, subfolder="text_encoder").to(device)
29
  tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path2, subfolder="tokenizer", use_fast=False)
30
 
31
+ # 创建推理管道
32
  pipe = DAIPipeline(
33
  vae=vae,
34
  text_encoder=text_encoder,
 
41
  t_start=0,
42
  ).to(device)
43
 
44
+ # 使用 spaces.GPU 包装推理函数
45
+ @spaces.GPU
46
  def process_image(input_image):
47
+ # Gradio 输入转换为 PIL 图像
48
  input_image = Image.fromarray(input_image)
49
 
50
+ # 处理图像
51
  pipe_out = pipe(
52
  image=input_image,
53
  prompt="remove glass reflection",
 
55
  processing_resolution=None,
56
  )
57
 
58
+ # 将输出转换为图像
59
  processed_frame = (pipe_out.prediction.clip(-1, 1) + 1) / 2
60
  processed_frame = (processed_frame[0] * 255).astype(np.uint8)
61
  processed_frame = Image.fromarray(processed_frame)
62
 
63
+ # 返回输入图像和处理后的图像
64
+ return input_image, processed_frame
65
 
66
+ # 创建 Gradio 界面
67
  def create_gradio_interface():
68
+ # 示例图像
69
  example_images = [
70
  os.path.join("files", "image", f"{i}.png") for i in range(1, 9)
71
  ]
 
77
  input_image = gr.Image(label="Input Image", type="numpy")
78
  submit_btn = gr.Button("Remove Reflection", variant="primary")
79
  with gr.Column():
80
+ # 使用 ImageSlider 显示前后对比
81
+ output_slider = ImageSlider(
82
+ label="Before & After",
83
+ show_download_button=True,
84
+ show_share_button=True,
85
+ )
86
 
87
+ # 添加示例
88
  gr.Examples(
89
  examples=example_images,
90
  inputs=input_image,
91
+ outputs=output_slider,
92
  fn=process_image,
93
+ cache_examples=False, # Caching disabled: example outputs are computed on demand
94
  label="Example Images",
95
  )
96
 
97
+ # 绑定按钮点击事件
98
  submit_btn.click(
99
  fn=process_image,
100
  inputs=input_image,
101
+ outputs=output_slider,
102
  )
103
 
104
  return demo
105
 
106
+ # 主函数
107
  def main():
108
  demo = create_gradio_interface()
109
  demo.launch(server_name="0.0.0.0", server_port=7860)