sjtu-deepvision commited on
Commit
6d71d5c
·
verified ·
1 Parent(s): 43b9d60

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -21
app.py CHANGED
@@ -4,7 +4,6 @@ import numpy as np
4
  import torch
5
  from PIL import Image
6
  import gradio as gr
7
- from gradio_imageslider import ImageSlider
8
 
9
  # 延迟 CUDA 初始化
10
  weight_dtype = torch.float32
@@ -46,22 +45,11 @@ def process_image(input_image, resolution_choice):
46
  # 将 Gradio 输入转换为 PIL 图像
47
  input_image = Image.fromarray(input_image)
48
 
49
- # 如果 resolution_choice 为 '768',将 input_image resize 到最大边 768
50
- if resolution_choice == "768":
51
- max_size = 768
52
- width, height = input_image.size
53
- if max(width, height) > max_size:
54
- scaling_factor = max_size / max(width, height)
55
- new_width = int(width * scaling_factor)
56
- new_height = int(height * scaling_factor)
57
- input_image = input_image.resize((new_width, new_height), Image.LANCZOS)
58
-
59
  # 根据用户选择设置处理分辨率
60
- # if resolution_choice == "768":
61
- # processing_resolution = None
62
- # else:
63
- # processing_resolution = 0 # 使用原始分辨率
64
- processing_resolution = 0 # 使用原始分辨率
65
 
66
  # 处理图像
67
  pipe_out = pipe(
@@ -76,7 +64,7 @@ def process_image(input_image, resolution_choice):
76
  processed_frame = (processed_frame[0] * 255).astype(np.uint8)
77
  processed_frame = Image.fromarray(processed_frame)
78
 
79
- return input_image, processed_frame
80
 
81
  # 创建 Gradio 界面
82
  def create_gradio_interface():
@@ -104,14 +92,13 @@ def create_gradio_interface():
104
  )
105
  submit_btn = gr.Button("Remove Reflection", variant="primary")
106
  with gr.Column():
107
- # output_image = gr.Image(label="Processed Image")
108
- output_slider = ImageSlider(label="Processed image", type="pil")
109
 
110
  # 添加示例
111
  gr.Examples(
112
  examples=example_images,
113
  inputs=[input_image, resolution_choice], # 输入组件列表
114
- outputs=output_slider,
115
  fn=process_image,
116
  cache_examples=False, # 缓存结果以加快加载速度
117
  label="Example Images",
@@ -121,7 +108,7 @@ def create_gradio_interface():
121
  submit_btn.click(
122
  fn=process_image,
123
  inputs=[input_image, resolution_choice], # 输入组件列表
124
- outputs=output_slider,
125
  )
126
 
127
  return demo
 
4
  import torch
5
  from PIL import Image
6
  import gradio as gr
 
7
 
8
  # 延迟 CUDA 初始化
9
  weight_dtype = torch.float32
 
45
  # 将 Gradio 输入转换为 PIL 图像
46
  input_image = Image.fromarray(input_image)
47
 
 
 
 
 
 
 
 
 
 
 
48
  # 根据用户选择设置处理分辨率
49
+ if resolution_choice == "768":
50
+ processing_resolution = None
51
+ else:
52
+ processing_resolution = 0 # 使用原始分辨率
 
53
 
54
  # 处理图像
55
  pipe_out = pipe(
 
64
  processed_frame = (processed_frame[0] * 255).astype(np.uint8)
65
  processed_frame = Image.fromarray(processed_frame)
66
 
67
+ return processed_frame
68
 
69
  # 创建 Gradio 界面
70
  def create_gradio_interface():
 
92
  )
93
  submit_btn = gr.Button("Remove Reflection", variant="primary")
94
  with gr.Column():
95
+ output_image = gr.Image(label="Processed Image")
 
96
 
97
  # 添加示例
98
  gr.Examples(
99
  examples=example_images,
100
  inputs=[input_image, resolution_choice], # 输入组件列表
101
+ outputs=output_image,
102
  fn=process_image,
103
  cache_examples=False, # 缓存结果以加快加载速度
104
  label="Example Images",
 
108
  submit_btn.click(
109
  fn=process_image,
110
  inputs=[input_image, resolution_choice], # 输入组件列表
111
+ outputs=output_image,
112
  )
113
 
114
  return demo