sjtu-deepvision commited on
Commit
68a7cdb
·
verified ·
1 Parent(s): c9620cb

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -9
app.py CHANGED
@@ -4,6 +4,7 @@ import numpy as np
4
  import torch
5
  from PIL import Image
6
  import gradio as gr
 
7
 
8
  # 延迟 CUDA 初始化
9
  weight_dtype = torch.float32
@@ -45,11 +46,22 @@ def process_image(input_image, resolution_choice):
45
  # 将 Gradio 输入转换为 PIL 图像
46
  input_image = Image.fromarray(input_image)
47
 
48
- # 根据用户选择设置处理分辨率
49
  if resolution_choice == "768":
50
- processing_resolution = None
51
- else:
52
- processing_resolution = 0 # 使用原始分辨率
 
 
 
 
 
 
 
 
 
 
 
53
 
54
  # 处理图像
55
  pipe_out = pipe(
@@ -64,7 +76,7 @@ def process_image(input_image, resolution_choice):
64
  processed_frame = (processed_frame[0] * 255).astype(np.uint8)
65
  processed_frame = Image.fromarray(processed_frame)
66
 
67
- return processed_frame
68
 
69
  # 创建 Gradio 界面
70
  def create_gradio_interface():
@@ -88,17 +100,18 @@ def create_gradio_interface():
88
  value="768", # 默认选择原始分辨率
89
  )
90
  gr.Markdown(
91
- "Select the resolution for processing the image. Higher resolution may take longer to process."
92
  )
93
  submit_btn = gr.Button("Remove Reflection", variant="primary")
94
  with gr.Column():
95
- output_image = gr.Image(label="Processed Image")
 
96
 
97
  # 添加示例
98
  gr.Examples(
99
  examples=example_images,
100
  inputs=[input_image, resolution_choice], # 输入组件列表
101
- outputs=output_image,
102
  fn=process_image,
103
  cache_examples=False, # 缓存结果以加快加载速度
104
  label="Example Images",
@@ -108,7 +121,7 @@ def create_gradio_interface():
108
  submit_btn.click(
109
  fn=process_image,
110
  inputs=[input_image, resolution_choice], # 输入组件列表
111
- outputs=output_image,
112
  )
113
 
114
  return demo
 
4
  import torch
5
  from PIL import Image
6
  import gradio as gr
7
+ from gradio_imageslider import ImageSlider
8
 
9
  # 延迟 CUDA 初始化
10
  weight_dtype = torch.float32
 
46
  # 将 Gradio 输入转换为 PIL 图像
47
  input_image = Image.fromarray(input_image)
48
 
49
+ # 如果 resolution_choice 为 '768',将 input_image resize 到最大边 768
50
  if resolution_choice == "768":
51
+ max_size = 768
52
+ width, height = input_image.size
53
+ if max(width, height) > max_size:
54
+ scaling_factor = max_size / max(width, height)
55
+ new_width = int(width * scaling_factor)
56
+ new_height = int(height * scaling_factor)
57
+ input_image = input_image.resize((new_width, new_height), Image.LANCZOS)
58
+
59
+ # 根据用户选择设置处理分辨率
60
+ # if resolution_choice == "768":
61
+ # processing_resolution = None
62
+ # else:
63
+ # processing_resolution = 0 # 使用原始分辨率
64
+ processing_resolution = 0 # 使用原始分辨率
65
 
66
  # 处理图像
67
  pipe_out = pipe(
 
76
  processed_frame = (processed_frame[0] * 255).astype(np.uint8)
77
  processed_frame = Image.fromarray(processed_frame)
78
 
79
+ return input_image, processed_frame
80
 
81
  # 创建 Gradio 界面
82
  def create_gradio_interface():
 
100
  value="768", # 默认选择原始分辨率
101
  )
102
  gr.Markdown(
103
+ "Select the resolution for processing the image. Higher resolution may take longer to process. 768 is recommended for faster processing and stable performance."
104
  )
105
  submit_btn = gr.Button("Remove Reflection", variant="primary")
106
  with gr.Column():
107
+ # output_image = gr.Image(label="Processed Image")
108
+ output_slider = ImageSlider(label="Processed image", type="pil")
109
 
110
  # 添加示例
111
  gr.Examples(
112
  examples=example_images,
113
  inputs=[input_image, resolution_choice], # 输入组件列表
114
+ outputs=output_slider,
115
  fn=process_image,
116
  cache_examples=False, # 缓存结果以加快加载速度
117
  label="Example Images",
 
121
  submit_btn.click(
122
  fn=process_image,
123
  inputs=[input_image, resolution_choice], # 输入组件列表
124
+ outputs=output_slider,
125
  )
126
 
127
  return demo