soiz1 committed on
Commit
a9b3658
·
verified ·
1 Parent(s): 8df85b7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +127 -176
app.py CHANGED
@@ -1,182 +1,133 @@
1
- <!DOCTYPE html>
2
- <html lang="en">
3
- <head>
4
- <meta charset="UTF-8">
5
- <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
- <title>Highly Accurate Dichotomous Image Segmentation</title>
7
- <style>
8
- body {
9
- font-family: Arial, sans-serif;
10
- max-width: 800px;
11
- margin: 0 auto;
12
- padding: 20px;
13
- line-height: 1.6;
14
- }
15
- .container {
16
- display: flex;
17
- flex-direction: column;
18
- gap: 20px;
19
- }
20
- .upload-section {
21
- border: 2px dashed #ccc;
22
- padding: 20px;
23
- text-align: center;
24
- border-radius: 5px;
25
- }
26
- .results {
27
- display: flex;
28
- gap: 20px;
29
- flex-wrap: wrap;
30
- }
31
- .result-box {
32
- flex: 1;
33
- min-width: 300px;
34
- }
35
- img {
36
- max-width: 100%;
37
- height: auto;
38
- border: 1px solid #ddd;
39
- border-radius: 4px;
40
- }
41
- button {
42
- background-color: #4CAF50;
43
- color: white;
44
- padding: 10px 15px;
45
- border: none;
46
- border-radius: 4px;
47
- cursor: pointer;
48
- font-size: 16px;
49
- }
50
- button:hover {
51
- background-color: #45a049;
52
- }
53
- .code-block {
54
- background-color: #f5f5f5;
55
- padding: 15px;
56
- border-radius: 5px;
57
- overflow-x: auto;
58
- }
59
- </style>
60
- </head>
61
- <body>
62
- <div class="container">
63
- <h1>Highly Accurate Dichotomous Image Segmentation</h1>
64
- <p>This is a demo for DIS, a model that can remove the background from a given image.</p>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
- <div class="upload-section">
67
- <h2>Upload Image</h2>
68
- <input type="file" id="imageInput" accept="image/*">
69
- <button onclick="processImage()">Remove Background</button>
70
- </div>
71
 
72
- <div class="results">
73
- <div class="result-box">
74
- <h3>Original Image</h3>
75
- <img id="originalImage" src="" alt="Original image will appear here" style="display: none;">
76
- </div>
77
- <div class="result-box">
78
- <h3>Result (RGBA)</h3>
79
- <img id="resultImage" src="" alt="Result will appear here" style="display: none;">
80
- </div>
81
- <div class="result-box">
82
- <h3>Mask</h3>
83
- <img id="maskImage" src="" alt="Mask will appear here" style="display: none;">
84
- </div>
85
- </div>
86
 
87
- <div>
88
- <h2>API Usage Example</h2>
89
- <p>You can also use the API directly with this JavaScript code:</p>
90
- <div class="code-block">
91
- <pre><code>
92
- async function removeBackground(imageFile) {
93
- const formData = new FormData();
94
- formData.append('image', imageFile);
95
-
96
- try {
97
- const response = await fetch('/api/remove_bg', {
98
- method: 'POST',
99
- body: formData
100
- });
101
 
102
- if (!response.ok) {
103
- throw new Error(`HTTP error! status: ${response.status}`);
104
- }
105
 
106
- const data = await response.json();
107
- console.log('Result:', data);
108
- return data;
109
- } catch (error) {
110
- console.error('Error:', error);
111
- throw error;
112
- }
113
- }
 
114
 
115
- // Usage example:
116
- // const fileInput = document.querySelector('input[type="file"]');
117
- // removeBackground(fileInput.files[0])
118
- // .then(data => {
119
- // // Handle response data
120
- // document.getElementById('resultImage').src = data.rgba_url;
121
- // document.getElementById('maskImage').src = data.mask_url;
122
- // });
123
- </code></pre>
124
- </div>
125
- </div>
126
- </div>
127
-
128
- <script>
129
- function processImage() {
130
- const fileInput = document.getElementById('imageInput');
131
- if (!fileInput.files || fileInput.files.length === 0) {
132
- alert('Please select an image first');
133
- return;
134
- }
135
-
136
- const file = fileInput.files[0];
137
- const reader = new FileReader();
138
-
139
- reader.onload = function(e) {
140
- document.getElementById('originalImage').src = e.target.result;
141
- document.getElementById('originalImage').style.display = 'block';
142
- };
143
- reader.readAsDataURL(file);
144
-
145
- removeBackground(file)
146
- .then(data => {
147
- document.getElementById('resultImage').src = data.rgba_url;
148
- document.getElementById('resultImage').style.display = 'block';
149
- document.getElementById('maskImage').src = data.mask_url;
150
- document.getElementById('maskImage').style.display = 'block';
151
- })
152
- .catch(error => {
153
- console.error('Error:', error);
154
- alert('An error occurred while processing the image');
155
- });
156
- }
157
-
158
- async function removeBackground(imageFile) {
159
- const formData = new FormData();
160
- formData.append('image', imageFile);
161
-
162
- try {
163
- const response = await fetch('/api/remove_bg', {
164
- method: 'POST',
165
- body: formData
166
- });
167
-
168
- if (!response.ok) {
169
- throw new Error(`HTTP error! status: ${response.status}`);
170
- }
171
-
172
- const data = await response.json();
173
- console.log('Result:', data);
174
- return data;
175
- } catch (error) {
176
- console.error('Error:', error);
177
- throw error;
178
- }
179
- }
180
- </script>
181
- </body>
182
- </html>
 
1
+ from flask import Flask, request, jsonify, render_template
2
+ import cv2
3
+ import os
4
+ from PIL import Image
5
+ import numpy as np
6
+ import torch
7
+ from torch.autograd import Variable
8
+ from torchvision import transforms
9
+ import torch.nn.functional as F
10
+ import uuid
11
+
12
app = Flask(__name__)

# Model configuration and initialization code.
# Prefer the GPU when CUDA reports itself available, otherwise use the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
16
+
17
class GOSNormalize(object):
    """Callable transform that normalizes an image tensor per channel.

    Args:
        mean: Per-channel means passed to the normalize call.
        std: Per-channel standard deviations passed to the normalize call.
    """

    # Tuples rather than list literals: default argument objects are created
    # once and shared by every call, so they must not be mutable.
    def __init__(self, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        self.mean = mean
        self.std = std

    def __call__(self, image):
        """Return ``image`` normalized with this instance's mean and std."""
        # The bare name ``normalize`` is never imported in this file, so the
        # original raised NameError on first use. torchvision's functional
        # normalize (presumably the intended helper) is reachable through the
        # already-imported ``transforms`` package.
        return transforms.functional.normalize(image, self.mean, self.std)
25
+
26
# Single-step preprocessing pipeline: one GOSNormalize configured with a mean
# of 0.5 and a std of 1.0 for each of the three channels.
_gos_normalize = GOSNormalize([0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
transform = transforms.Compose([_gos_normalize])
27
+
28
def load_image(im_path, hypar):
    """Read an image file and turn it into a batched model input.

    Args:
        im_path: Path of the image file to read.
        hypar: Settings mapping; only ``hypar["cache_size"]`` is read here.

    Returns:
        A ``(tensor, shape)`` pair, each given a new leading dimension via
        ``unsqueeze(0)``: the transformed image and the shape value reported
        by ``im_preprocess``.
    """
    # NOTE(review): im_reader and im_preprocess are not imported anywhere in
    # the visible file -- confirm where they are meant to come from.
    pixels = im_reader(im_path)
    pixels, shape_info = im_preprocess(pixels, hypar["cache_size"])

    # presumably rescales 0-255 pixel values into 0-1 -- TODO confirm
    pixels = torch.divide(pixels, 255.0)

    shape_tensor = torch.from_numpy(np.array(shape_info))
    batched_image = transform(pixels).unsqueeze(0)
    return batched_image, shape_tensor.unsqueeze(0)
34
+
35
def build_model(hypar, device):
    """Prepare ``hypar["model"]`` for inference on ``device``.

    Args:
        hypar: Settings mapping. Reads ``"model"`` (the network instance),
            ``"model_digit"`` (``"half"`` converts the weights to float16)
            and ``"restore_model"`` (checkpoint file name, ``""`` to skip
            loading). ``"model_path"`` is read only when a checkpoint is
            requested.
        device: Device the network is moved to and the checkpoint is mapped
            to.

    Returns:
        The same network object, moved to ``device`` and set to eval mode.
    """
    net = hypar["model"]

    if hypar["model_digit"] == "half":
        net.half()
        # Batch-norm layers are switched back to float32 after the cast.
        for layer in net.modules():
            # ``torch.nn`` is spelled out because this file never binds the
            # bare name ``nn``; the original raised NameError on this path.
            if isinstance(layer, torch.nn.BatchNorm2d):
                layer.float()

    net.to(device)

    if hypar["restore_model"] != "":
        checkpoint = os.path.join(hypar["model_path"], hypar["restore_model"])
        # load_state_dict copies values into the existing parameters, so the
        # network stays on ``device``; no second ``.to(device)`` is needed.
        net.load_state_dict(torch.load(checkpoint, map_location=device))

    net.eval()
    return net
48
+
49
def predict(net, inputs_val, shapes_val, hypar, device):
    """Run the network on one batch and return the first mask as uint8.

    Args:
        net: Network whose call result is indexed as ``net(x)[0][0]`` to
            obtain a 4-D prediction tensor.
        inputs_val: Input batch tensor.
        shapes_val: Indexable whose ``[0][0]`` and ``[0][1]`` give the height
            and width the prediction is resized to.
        hypar: Settings mapping; ``"model_digit" == "full"`` selects float32
            inputs, any other value float16.
        device: Device the inputs are moved to before the forward pass.

    Returns:
        ``numpy.ndarray`` of dtype ``uint8``: the first sample's prediction,
        bilinearly resized and min-max scaled to the 0-255 range. A constant
        prediction yields an all-zero array.
    """
    net.eval()

    if hypar["model_digit"] == "full":
        inputs_val = inputs_val.type(torch.FloatTensor)
    else:
        inputs_val = inputs_val.type(torch.HalfTensor)

    # Plain ints: the size argument should not be fed tensor elements.
    target_size = (int(shapes_val[0][0]), int(shapes_val[0][1]))

    # Inference only, so skip autograd bookkeeping. This replaces the
    # deprecated ``Variable(..., requires_grad=False)`` wrapper.
    with torch.no_grad():
        ds_val = net(inputs_val.to(device))[0]
        pred_val = ds_val[0][0, :, :, :]
        # F.interpolate is the supported replacement for the deprecated
        # F.upsample; align_corners=False matches the old default.
        pred_val = torch.squeeze(
            F.interpolate(
                torch.unsqueeze(pred_val, 0),
                size=target_size,
                mode='bilinear',
                align_corners=False,
            )
        )

        ma = torch.max(pred_val)
        mi = torch.min(pred_val)
        span = ma - mi
        if span > 0:
            pred_val = (pred_val - mi) / span
        else:
            # A constant prediction would divide by zero and yield NaN.
            pred_val = torch.zeros_like(pred_val)

    # Accept 'cuda', 'cuda:0' and torch.device objects alike.
    if str(device).startswith('cuda'):
        torch.cuda.empty_cache()

    return (pred_val.detach().cpu().numpy() * 255).astype(np.uint8)
65
+
66
# Model initialization.
# In the visible file, build_model reads "model", "model_digit",
# "restore_model" and "model_path"; load_image reads "cache_size"; predict
# reads "model_digit".
# NOTE(review): ISNetDIS is not imported anywhere in the visible file --
# confirm where it is meant to come from.
hypar = {}
hypar["model_path"] = "./saved_models"
hypar["restore_model"] = "isnet.pth"
hypar["interm_sup"] = False
hypar["model_digit"] = "full"
hypar["seed"] = 0
hypar["cache_size"] = [1024, 1024]
hypar["input_size"] = [1024, 1024]
hypar["crop_size"] = [1024, 1024]
hypar["model"] = ISNetDIS()

net = build_model(hypar, device)

# Create the directory where results are saved.
os.makedirs('static/results', exist_ok=True)
83
+
84
@app.route('/')
def index():
    """Render the ``index.html`` template for the site root."""
    page = render_template('index.html')
    return page
87
+
88
@app.route('/api/remove_bg', methods=['POST'])
def remove_bg():
    """Remove the background from an uploaded image.

    Expects a POST request carrying the picture in the ``image`` file field.

    Returns:
        JSON with ``rgba_url`` and ``mask_url`` pointing at the saved result
        files on success; otherwise JSON with an ``error`` message and
        status 400 (no usable upload) or 500 (processing failed).
    """
    import tempfile  # local import: only this handler needs it

    if 'image' not in request.files:
        return jsonify({'error': 'No image provided'}), 400

    upload = request.files['image']
    if upload.filename == '':
        return jsonify({'error': 'No image selected'}), 400

    # Stage the raw upload in the system temp directory rather than under
    # ``static/``: the result URLs below show that tree is served to clients,
    # and untrusted input should not be reachable there.
    fd, temp_path = tempfile.mkstemp(prefix="temp_", suffix=".png")
    os.close(fd)

    try:
        upload.save(temp_path)

        # Image processing.
        image_tensor, orig_size = load_image(temp_path, hypar)
        mask = predict(net, image_tensor, orig_size, hypar, device)

        pil_mask = Image.fromarray(mask).convert('L')
        # The context manager closes the source file handle; ``convert``
        # returns an independent in-memory image, so no extra copy is needed.
        with Image.open(temp_path) as source:
            im_rgba = source.convert("RGB")

        # Save the results.
        result_id = uuid.uuid4().hex
        rgba_path = f"static/results/{result_id}_rgba.png"
        mask_path = f"static/results/{result_id}_mask.png"

        im_rgba.putalpha(pil_mask)
        im_rgba.save(rgba_path)
        pil_mask.save(mask_path)

        return jsonify({
            'rgba_url': f"/{rgba_path}",
            'mask_url': f"/{mask_path}"
        })
    except Exception as e:
        # Request boundary: log the traceback, then answer with JSON.
        app.logger.exception("Background removal failed")
        return jsonify({'error': str(e)}), 500
    finally:
        # Delete the temporary file on success and on failure alike.
        if os.path.exists(temp_path):
            os.remove(temp_path)
131
 
132
if __name__ == '__main__':
    # Debug mode turns on the interactive debugger, which lets anyone who can
    # reach the server run arbitrary code. Keep it off unless the developer
    # opts in with FLASK_DEBUG=1.
    app.run(debug=os.environ.get("FLASK_DEBUG") == "1")