sjtu-deepvision commited on
Commit
cd178d4
·
verified ·
1 Parent(s): c45eb57

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -14
app.py CHANGED
@@ -4,10 +4,8 @@ import numpy as np
4
  import torch
5
  from PIL import Image
6
  import gradio as gr
7
- from gradio_imageslider import ImageSlider
8
 
9
  # 延迟 CUDA 初始化
10
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
  weight_dtype = torch.float32
12
 
13
  # 加载模型组件
@@ -19,6 +17,7 @@ from transformers import CLIPTextModel, AutoTokenizer
19
 
20
  pretrained_model_name_or_path = "sjtu-deepvision/dereflection-any-image-v0"
21
  pretrained_model_name_or_path2 = "stabilityai/stable-diffusion-2-1"
 
22
 
23
  # 加载模型
24
  controlnet = ControlNetVAEModel.from_pretrained(pretrained_model_name_or_path, subfolder="controlnet", torch_dtype=weight_dtype).to(device)
@@ -41,7 +40,6 @@ pipe = DAIPipeline(
41
  t_start=0,
42
  ).to(device)
43
 
44
- # 使用 spaces.GPU 包装推理函数
45
  @spaces.GPU
46
  def process_image(input_image):
47
  # 将 Gradio 输入转换为 PIL 图像
@@ -60,8 +58,7 @@ def process_image(input_image):
60
  processed_frame = (processed_frame[0] * 255).astype(np.uint8)
61
  processed_frame = Image.fromarray(processed_frame)
62
 
63
- # 返回输入图像和处理后的图像
64
- return input_image, processed_frame
65
 
66
  # 创建 Gradio 界面
67
  def create_gradio_interface():
@@ -77,18 +74,13 @@ def create_gradio_interface():
77
  input_image = gr.Image(label="Input Image", type="numpy")
78
  submit_btn = gr.Button("Remove Reflection", variant="primary")
79
  with gr.Column():
80
- # 使用 ImageSlider 显示前后对比
81
- output_slider = ImageSlider(
82
- label="Before & After",
83
- show_download_button=True,
84
- show_share_button=True,
85
- )
86
 
87
  # 添加示例
88
  gr.Examples(
89
  examples=example_images,
90
  inputs=input_image,
91
- outputs=output_slider,
92
  fn=process_image,
93
  cache_examples=False, # 缓存结果以加快加载速度
94
  label="Example Images",
@@ -98,7 +90,7 @@ def create_gradio_interface():
98
  submit_btn.click(
99
  fn=process_image,
100
  inputs=input_image,
101
- outputs=output_slider,
102
  )
103
 
104
  return demo
@@ -106,7 +98,7 @@ def create_gradio_interface():
106
  # 主函数
107
  def main():
108
  demo = create_gradio_interface()
109
- demo.queue().launch(show_api=False)
110
 
111
  if __name__ == "__main__":
112
  main()
 
4
  import torch
5
  from PIL import Image
6
  import gradio as gr
 
7
 
8
  # 延迟 CUDA 初始化
 
9
  weight_dtype = torch.float32
10
 
11
  # 加载模型组件
 
17
 
18
  pretrained_model_name_or_path = "sjtu-deepvision/dereflection-any-image-v0"
19
  pretrained_model_name_or_path2 = "stabilityai/stable-diffusion-2-1"
20
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
 
22
  # 加载模型
23
  controlnet = ControlNetVAEModel.from_pretrained(pretrained_model_name_or_path, subfolder="controlnet", torch_dtype=weight_dtype).to(device)
 
40
  t_start=0,
41
  ).to(device)
42
 
 
43
  @spaces.GPU
44
  def process_image(input_image):
45
  # 将 Gradio 输入转换为 PIL 图像
 
58
  processed_frame = (processed_frame[0] * 255).astype(np.uint8)
59
  processed_frame = Image.fromarray(processed_frame)
60
 
61
+ return processed_frame
 
62
 
63
  # 创建 Gradio 界面
64
  def create_gradio_interface():
 
74
  input_image = gr.Image(label="Input Image", type="numpy")
75
  submit_btn = gr.Button("Remove Reflection", variant="primary")
76
  with gr.Column():
77
+ output_image = gr.Image(label="Processed Image")
 
 
 
 
 
78
 
79
  # 添加示例
80
  gr.Examples(
81
  examples=example_images,
82
  inputs=input_image,
83
+ outputs=output_image,
84
  fn=process_image,
85
  cache_examples=False, # 缓存结果以加快加载速度
86
  label="Example Images",
 
90
  submit_btn.click(
91
  fn=process_image,
92
  inputs=input_image,
93
+ outputs=output_image,
94
  )
95
 
96
  return demo
 
98
  # 主函数
99
  def main():
100
  demo = create_gradio_interface()
101
+ demo.launch(server_name="0.0.0.0", server_port=7860)
102
 
103
  if __name__ == "__main__":
104
  main()