sjtu-deepvision commited on
Commit
9a81057
·
verified ·
1 Parent(s): cd178d4

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -5
app.py CHANGED
@@ -1,9 +1,10 @@
1
- import spaces # 必须放在最前面
2
  import os
3
  import numpy as np
4
  import torch
5
  from PIL import Image
6
  import gradio as gr
 
7
 
8
  # 延迟 CUDA 初始化
9
  weight_dtype = torch.float32
@@ -58,7 +59,8 @@ def process_image(input_image):
58
  processed_frame = (processed_frame[0] * 255).astype(np.uint8)
59
  processed_frame = Image.fromarray(processed_frame)
60
 
61
- return processed_frame
 
62
 
63
  # 创建 Gradio 界面
64
  def create_gradio_interface():
@@ -74,13 +76,18 @@ def create_gradio_interface():
74
  input_image = gr.Image(label="Input Image", type="numpy")
75
  submit_btn = gr.Button("Remove Reflection", variant="primary")
76
  with gr.Column():
77
- output_image = gr.Image(label="Processed Image")
 
 
 
 
 
78
 
79
  # 添加示例
80
  gr.Examples(
81
  examples=example_images,
82
  inputs=input_image,
83
- outputs=output_image,
84
  fn=process_image,
85
  cache_examples=False, # 缓存结果以加快加载速度
86
  label="Example Images",
@@ -90,7 +97,7 @@ def create_gradio_interface():
90
  submit_btn.click(
91
  fn=process_image,
92
  inputs=input_image,
93
- outputs=output_image,
94
  )
95
 
96
  return demo
 
1
+ import spaces
2
  import os
3
  import numpy as np
4
  import torch
5
  from PIL import Image
6
  import gradio as gr
7
+ from gradio_imageslider import ImageSlider # 导入 ImageSlider 组件
8
 
9
  # 延迟 CUDA 初始化
10
  weight_dtype = torch.float32
 
59
  processed_frame = (processed_frame[0] * 255).astype(np.uint8)
60
  processed_frame = Image.fromarray(processed_frame)
61
 
62
+ # 返回输入图像和处理后的图像
63
+ return input_image, processed_frame
64
 
65
  # 创建 Gradio 界面
66
  def create_gradio_interface():
 
76
  input_image = gr.Image(label="Input Image", type="numpy")
77
  submit_btn = gr.Button("Remove Reflection", variant="primary")
78
  with gr.Column():
79
+ # 使用 ImageSlider 显示前后对比
80
+ output_slider = ImageSlider(
81
+ label="Before & After",
82
+ show_download_button=True,
83
+ show_share_button=True,
84
+ )
85
 
86
  # 添加示例
87
  gr.Examples(
88
  examples=example_images,
89
  inputs=input_image,
90
+ outputs=output_slider,
91
  fn=process_image,
92
  cache_examples=False, # 缓存结果以加快加载速度
93
  label="Example Images",
 
97
  submit_btn.click(
98
  fn=process_image,
99
  inputs=input_image,
100
+ outputs=output_slider,
101
  )
102
 
103
  return demo