soiz1 commited on
Commit
e0d04f2
·
verified ·
1 Parent(s): 6b39384

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +87 -54
app.py CHANGED
@@ -1,4 +1,3 @@
1
- from flask import Flask, request, jsonify, render_template
2
  import cv2
3
  import os
4
  from PIL import Image
@@ -7,18 +6,32 @@ import torch
7
  from torch.autograd import Variable
8
  from torchvision import transforms
9
  import torch.nn.functional as F
10
- import uuid
11
-
12
- import gdown
13
- import matplotlib.pyplot as plt
14
  import warnings
 
15
 
16
- app = Flask(__name__)
 
 
 
 
 
 
 
17
 
18
- # モデル設定と初期化コード
19
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
20
 
 
 
 
 
 
 
21
  class GOSNormalize(object):
 
 
 
22
  def __init__(self, mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]):
23
  self.mean = mean
24
  self.std = std
@@ -34,16 +47,20 @@ def load_image(im_path, hypar):
34
  im, im_shp = im_preprocess(im, hypar["cache_size"])
35
  im = torch.divide(im,255.0)
36
  shape = torch.from_numpy(np.array(im_shp))
37
- return transform(im).unsqueeze(0), shape.unsqueeze(0)
38
 
39
  def build_model(hypar,device):
40
- net = hypar["model"]
 
 
41
  if(hypar["model_digit"]=="half"):
42
  net.half()
43
  for layer in net.modules():
44
  if isinstance(layer, nn.BatchNorm2d):
45
  layer.float()
 
46
  net.to(device)
 
47
  if(hypar["restore_model"]!=""):
48
  net.load_state_dict(torch.load(hypar["model_path"]+"/"+hypar["restore_model"], map_location=device))
49
  net.to(device)
@@ -51,46 +68,60 @@ def build_model(hypar,device):
51
  return net
52
 
53
  def predict(net, inputs_val, shapes_val, hypar, device):
 
 
 
54
  net.eval()
 
55
  if(hypar["model_digit"]=="full"):
56
  inputs_val = inputs_val.type(torch.FloatTensor)
57
  else:
58
  inputs_val = inputs_val.type(torch.HalfTensor)
59
-
60
- inputs_val_v = Variable(inputs_val, requires_grad=False).to(device)
61
- ds_val = net(inputs_val_v)[0]
62
- pred_val = ds_val[0][0,:,:,:]
 
 
 
63
  pred_val = torch.squeeze(F.upsample(torch.unsqueeze(pred_val,0),(shapes_val[0][0],shapes_val[0][1]),mode='bilinear'))
 
64
  ma = torch.max(pred_val)
65
  mi = torch.min(pred_val)
66
- pred_val = (pred_val-mi)/(ma-mi)
 
67
  if device == 'cuda': torch.cuda.empty_cache()
68
- return (pred_val.detach().cpu().numpy()*255).astype(np.uint8)
69
-
70
- # モデル初期化
71
- hypar = {
72
- "model_path": "./saved_models",
73
- "restore_model": "isnet.pth",
74
- "interm_sup": False,
75
- "model_digit": "full",
76
- "seed": 0,
77
- "cache_size": [1024, 1024],
78
- "input_size": [1024, 1024],
79
- "crop_size": [1024, 1024],
80
- "model": ISNetDIS()
81
- }
82
 
 
83
  net = build_model(hypar, device)
84
 
85
- # 結果を保存するディレクトリを作成
86
- os.makedirs('static/results', exist_ok=True)
 
 
 
 
 
87
 
88
- @app.route('/')
89
  def index():
90
  return render_template('index.html')
91
 
92
  @app.route('/api/remove_bg', methods=['POST'])
93
- def remove_bg():
94
  if 'image' not in request.files:
95
  return jsonify({'error': 'No image provided'}), 400
96
 
@@ -98,40 +129,42 @@ def remove_bg():
98
  if file.filename == '':
99
  return jsonify({'error': 'No image selected'}), 400
100
 
101
- # 一時ファイルとして保存
102
- temp_path = f"static/temp_{uuid.uuid4().hex}.png"
103
- file.save(temp_path)
104
 
105
  try:
106
- # 画像処理
107
- image_tensor, orig_size = load_image(temp_path, hypar)
108
  mask = predict(net, image_tensor, orig_size, hypar, device)
109
 
 
110
  pil_mask = Image.fromarray(mask).convert('L')
111
- im_rgb = Image.open(temp_path).convert("RGB")
112
-
113
- # 結果を保存
114
- result_id = uuid.uuid4().hex
115
- rgba_path = f"static/results/{result_id}_rgba.png"
116
- mask_path = f"static/results/{result_id}_mask.png"
117
-
118
  im_rgba = im_rgb.copy()
119
  im_rgba.putalpha(pil_mask)
120
- im_rgba.save(rgba_path)
121
- pil_mask.save(mask_path)
122
 
123
- # 一時ファイルを削除
124
- os.remove(temp_path)
 
 
 
 
125
 
126
  return jsonify({
127
- 'rgba_url': f"/{rgba_path}",
128
- 'mask_url': f"/{mask_path}"
129
  })
130
  except Exception as e:
131
- # エラーが発生したら一時ファイルを削除
132
- if os.path.exists(temp_path):
133
- os.remove(temp_path)
134
  return jsonify({'error': str(e)}), 500
135
 
 
 
 
 
 
 
 
 
136
  if __name__ == '__main__':
137
- app.run(debug=True)
 
 
1
  import cv2
2
  import os
3
  from PIL import Image
 
6
  from torch.autograd import Variable
7
  from torchvision import transforms
8
  import torch.nn.functional as F
9
+ from flask import Flask, request, jsonify, render_template, send_from_directory
 
 
 
10
  import warnings
11
+ warnings.filterwarnings("ignore")
12
 
13
+ # Clone repository and setup (only run once)
14
+ if not os.path.exists("DIS"):
15
+ os.system("git clone https://github.com/xuebinqin/DIS")
16
+ os.system("mv DIS/IS-Net/* .")
17
+
18
+ # Project imports
19
+ from data_loader_cache import normalize, im_reader, im_preprocess
20
+ from models import *
21
 
22
+ # Setup device
23
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
24
 
25
+ # Download official weights if not exists
26
+ if not os.path.exists("saved_models"):
27
+ os.mkdir("saved_models")
28
+ if not os.path.exists("saved_models/isnet.pth"):
29
+ os.system("mv isnet.pth saved_models/")
30
+
31
  class GOSNormalize(object):
32
+ '''
33
+ Normalize the Image using torch.transforms
34
+ '''
35
  def __init__(self, mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]):
36
  self.mean = mean
37
  self.std = std
 
47
  im, im_shp = im_preprocess(im, hypar["cache_size"])
48
  im = torch.divide(im,255.0)
49
  shape = torch.from_numpy(np.array(im_shp))
50
+ return transform(im).unsqueeze(0), shape.unsqueeze(0) # make a batch of image, shape
51
 
52
  def build_model(hypar,device):
53
+ net = hypar["model"]#GOSNETINC(3,1)
54
+
55
+ # convert to half precision
56
  if(hypar["model_digit"]=="half"):
57
  net.half()
58
  for layer in net.modules():
59
  if isinstance(layer, nn.BatchNorm2d):
60
  layer.float()
61
+
62
  net.to(device)
63
+
64
  if(hypar["restore_model"]!=""):
65
  net.load_state_dict(torch.load(hypar["model_path"]+"/"+hypar["restore_model"], map_location=device))
66
  net.to(device)
 
68
  return net
69
 
70
  def predict(net, inputs_val, shapes_val, hypar, device):
71
+ '''
72
+ Given an Image, predict the mask
73
+ '''
74
  net.eval()
75
+
76
  if(hypar["model_digit"]=="full"):
77
  inputs_val = inputs_val.type(torch.FloatTensor)
78
  else:
79
  inputs_val = inputs_val.type(torch.HalfTensor)
80
+
81
+ inputs_val_v = Variable(inputs_val, requires_grad=False).to(device) # wrap inputs in Variable
82
+ ds_val = net(inputs_val_v)[0] # list of 6 results
83
+
84
+ pred_val = ds_val[0][0,:,:,:] # B x 1 x H x W # we want the first one which is the most accurate prediction
85
+
86
+ ## recover the prediction spatial size to the orignal image size
87
  pred_val = torch.squeeze(F.upsample(torch.unsqueeze(pred_val,0),(shapes_val[0][0],shapes_val[0][1]),mode='bilinear'))
88
+
89
  ma = torch.max(pred_val)
90
  mi = torch.min(pred_val)
91
+ pred_val = (pred_val-mi)/(ma-mi) # max = 1
92
+
93
  if device == 'cuda': torch.cuda.empty_cache()
94
+ return (pred_val.detach().cpu().numpy()*255).astype(np.uint8) # it is the mask we need
95
+
96
+ # Set Parameters
97
+ hypar = {} # paramters for inferencing
98
+ hypar["model_path"] ="./saved_models" ## load trained weights from this path
99
+ hypar["restore_model"] = "isnet.pth" ## name of the to-be-loaded weights
100
+ hypar["interm_sup"] = False ## indicate if activate intermediate feature supervision
101
+ hypar["model_digit"] = "full" ## indicates "half" or "full" accuracy of float number
102
+ hypar["seed"] = 0
103
+ hypar["cache_size"] = [1024, 1024] ## cached input spatial resolution
104
+ hypar["input_size"] = [1024, 1024] ## model input spatial size
105
+ hypar["crop_size"] = [1024, 1024] ## random crop size from the input
106
+ hypar["model"] = ISNetDIS()
 
107
 
108
+ # Build Model
109
  net = build_model(hypar, device)
110
 
111
+ # Flask app
112
+ app = Flask(__name__)
113
+ app.config['UPLOAD_FOLDER'] = 'uploads'
114
+ app.config['RESULT_FOLDER'] = 'results'
115
+
116
+ os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
117
+ os.makedirs(app.config['RESULT_FOLDER'], exist_ok=True)
118
 
119
+ @app.route('/', methods=['GET'])
120
  def index():
121
  return render_template('index.html')
122
 
123
  @app.route('/api/remove_bg', methods=['POST'])
124
+ def remove_background():
125
  if 'image' not in request.files:
126
  return jsonify({'error': 'No image provided'}), 400
127
 
 
129
  if file.filename == '':
130
  return jsonify({'error': 'No image selected'}), 400
131
 
132
+ # Save uploaded file
133
+ upload_path = os.path.join(app.config['UPLOAD_FOLDER'], file.filename)
134
+ file.save(upload_path)
135
 
136
  try:
137
+ # Process image
138
+ image_tensor, orig_size = load_image(upload_path, hypar)
139
  mask = predict(net, image_tensor, orig_size, hypar, device)
140
 
141
+ # Create results
142
  pil_mask = Image.fromarray(mask).convert('L')
143
+ im_rgb = Image.open(upload_path).convert("RGB")
 
 
 
 
 
 
144
  im_rgba = im_rgb.copy()
145
  im_rgba.putalpha(pil_mask)
 
 
146
 
147
+ # Save results
148
+ result_rgba_path = os.path.join(app.config['RESULT_FOLDER'], f"rgba_{file.filename}")
149
+ result_mask_path = os.path.join(app.config['RESULT_FOLDER'], f"mask_{file.filename}")
150
+
151
+ im_rgba.save(result_rgba_path, format="PNG")
152
+ pil_mask.save(result_mask_path, format="PNG")
153
 
154
  return jsonify({
155
+ 'rgba_image': f"/results/rgba_{file.filename}",
156
+ 'mask_image': f"/results/mask_{file.filename}"
157
  })
158
  except Exception as e:
 
 
 
159
  return jsonify({'error': str(e)}), 500
160
 
161
+ @app.route('/results/<filename>')
162
+ def serve_result(filename):
163
+ return send_from_directory(app.config['RESULT_FOLDER'], filename)
164
+
165
+ @app.route('/uploads/<filename>')
166
+ def serve_upload(filename):
167
+ return send_from_directory(app.config['UPLOAD_FOLDER'], filename)
168
+
169
  if __name__ == '__main__':
170
+ app.run(host='0.0.0.0', port=5000, debug=True)