wolo-wolo committed
Commit 4d10ed1 · Parent(s): e46e042
V1.0
Browse files
- app.py +249 -329
- engine_finetune.py +58 -251
- models_vit.py +1 -1
- util/crop.py +1 -1
- util/datasets.py +1 -1
- util/lars.py +1 -1
- util/loss_contrastive.py +1 -1
- util/lr_decay.py +1 -1
- util/lr_sched.py +1 -1
- util/metrics.py +29 -1
- util/misc.py +1 -1
- util/pos_embed.py +1 -1
app.py
CHANGED
@@ -1,220 +1,114 @@
# -*- coding: utf-8 -*-
- # Author: Gaojian Wang@ZJUICSR
# --------------------------------------------------------
# This source code is licensed under the Attribution-NonCommercial 4.0 International License.
# You can find the license in the LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
- # pip uninstall nvidia_cublas_cu11

import sys
- sys.path.append('..')
import os
os.system(f'pip install dlib')
- import
import numpy as np
from PIL import Image
import gradio as gr

import models_vit
from util.datasets import build_dataset
- import
- from engine_finetune import test_all
- import dlib
- from huggingface_hub import hf_hub_download
-
- P = os.path.abspath(__file__)
- FRAME_SAVE_PATH = os.path.join(P[:-6], 'frame')
- CKPT_SAVE_PATH = os.path.join(P[:-6], 'checkpoints')
- CKPT_LIST = ['DfD-Checkpoint_Fine-tuned_on_FF++',
-              'FAS-Checkpoint_Fine-tuned_on_MCIO']
- CKPT_NAME = {'DfD-Checkpoint_Fine-tuned_on_FF++': 'finetuned_models/FF++_c23_32frames/checkpoint-min_val_loss.pth',
-              'FAS-Checkpoint_Fine-tuned_on_MCIO': 'finetuned_models/MCIO_protocol/Both_MCIO/checkpoint-min_val_loss.pth'}
- os.makedirs(FRAME_SAVE_PATH, exist_ok=True)
- os.makedirs(CKPT_SAVE_PATH, exist_ok=True)


def get_args_parser():
-     parser = argparse.ArgumentParser('
-     parser.add_argument('--batch_size', default=64, type=int,
-                         help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus')
    parser.add_argument('--epochs', default=50, type=int)
-     parser.add_argument('--accum_iter', default=1, type=int,
-     parser.add_argument('--model', default='vit_large_patch16', type=str, metavar='MODEL',
-                         help='Name of model to train')
-     parser.add_argument('--input_size', default=224, type=int,
-                         help='images input size')
-     parser.add_argument('--normalize_from_IMN', action='store_true',
-                         help='cal mean and std from imagenet, else from pretrain datasets')
    parser.set_defaults(normalize_from_IMN=True)
-     parser.add_argument('--apply_simple_augment', action='store_true',
-     parser.add_argument('--
-     parser.add_argument('--
-     parser.add_argument('--
-     parser.add_argument('--
-     parser.add_argument('--
-     parser.add_argument('--
-     parser.add_argument('--
-     parser.add_argument('--
-                         help='epochs to warmup LR')
-
-     # Augmentation parameters
-     parser.add_argument('--color_jitter', type=float, default=None, metavar='PCT',
-                         help='Color jitter factor (enabled only when not using Auto/RandAug)')
-     parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
-                         help='Use AutoAugment policy. "v0" or "original". " + "(default: rand-m9-mstd0.5-inc1)'),
-     parser.add_argument('--smoothing', type=float, default=0.1,
-                         help='Label smoothing (default: 0.1)')
-
-     # * Random Erase params
-     parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
-                         help='Random erase prob (default: 0.25)')
-     parser.add_argument('--remode', type=str, default='pixel',
-                         help='Random erase mode (default: "pixel")')
-     parser.add_argument('--recount', type=int, default=1,
-                         help='Random erase count (default: 1)')
-     parser.add_argument('--resplit', action='store_true', default=False,
-                         help='Do not random erase first (clean) augmentation split')
-
-     # * Mixup params
-     parser.add_argument('--mixup', type=float, default=0,
-                         help='mixup alpha, mixup enabled if > 0.')
-     parser.add_argument('--cutmix', type=float, default=0,
-                         help='cutmix alpha, cutmix enabled if > 0.')
-     parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None,
-                         help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
-     parser.add_argument('--mixup_prob', type=float, default=1.0,
-                         help='Probability of performing mixup or cutmix when either/both is enabled')
-     parser.add_argument('--mixup_switch_prob', type=float, default=0.5,
-                         help='Probability of switching to cutmix when both mixup and cutmix enabled')
-     parser.add_argument('--mixup_mode', type=str, default='batch',
-                         help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
-
-     # * Finetuning params
-     parser.add_argument('--finetune', default='',
-                         help='finetune from checkpoint')
    parser.add_argument('--global_pool', action='store_true')
    parser.set_defaults(global_pool=True)
-     parser.add_argument('--cls_token', action='store_false', dest='global_pool',
-     parser.add_argument('--
-     parser.add_argument('--nb_classes', default=1000, type=int,
-                         help='number of the classification types')
-
-     parser.add_argument('--output_dir', default='',
-                         help='path where to save, empty for no saving')
-     parser.add_argument('--log_dir', default='',
-                         help='path where to tensorboard log')
-     parser.add_argument('--device', default='cuda',
-                         help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
-     parser.add_argument('--resume', default='',
-     parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
-                         help='start epoch')
-     parser.add_argument('--eval', action='store_true',
-                         help='Perform evaluation only')
    parser.set_defaults(eval=True)
-     parser.add_argument('--dist_eval', action='store_true', default=False,
-                         help='Enabling distributed evaluation (recommended during training for faster monitor')
    parser.add_argument('--num_workers', default=10, type=int)
-     parser.add_argument('--pin_mem', action='store_true',
-                         help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
    parser.set_defaults(pin_mem=True)
-
-     # distributed training parameters
-     parser.add_argument('--world_size', default=1, type=int,
-                         help='number of distributed processes')
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
-     parser.add_argument('--dist_url', default='env://',
-                         help='url used to set up distributed training')
    return parser


-     if os.path.isfile(args.resume) == False:
-         hf_hub_download(local_dir=CKPT_SAVE_PATH,
-                         repo_id='Wolowolo/fsfm-3c/' + CKPT_NAME[ckpt],
-                         filename=ckpt)
-     checkpoint = torch.load(args.resume, map_location='cpu')
-     model.load_state_dict(checkpoint['model'])
    model.eval()
-     return gr.update()


def get_boundingbox(face, width, height, minsize=None):
-     """
-     From FF++:
-     https://github.com/ondyari/FaceForensics/blob/master/classification/detect_from_video.py
-     Expects a dlib face to generate a quadratic bounding box.
-     :param face: dlib face class
-     :param width: frame width
-     :param height: frame height
-     :param cfg.face_scale: bounding box size multiplier to get a bigger face region
-     :param minsize: set minimum bounding box size
-     :return: x, y, bounding_box_size in opencv form
-     """
-     x1 = face.left()
-     y1 = face.top()
-     x2 = face.right()
-     y2 = face.bottom()
    size_bb = int(max(x2 - x1, y2 - y1) * 1.3)
-     if minsize:
-         size_bb = minsize
    center_x, center_y = (x1 + x2) // 2, (y1 + y2) // 2
-     # Check for out of bounds, x-y top left corner
-     x1 = max(int(center_x - size_bb // 2), 0)
-     y1 = max(int(center_y - size_bb // 2), 0)
-     # Check for too big bb size for given x, y
    size_bb = min(width - x1, size_bb)
    size_bb = min(height - y1, size_bb)
    return x1, y1, size_bb

@@ -222,200 +116,226 @@ def extract_face(frame):
    face_detector = dlib.get_frontal_face_detector()
    image = np.array(frame.convert('RGB'))
    faces = face_detector(image, 1)
-     if
-         # For now only take the biggest face
        face = faces[0]
-         # Face crop and rescale(follow FF++)
        x, y, size = get_boundingbox(face, image.shape[1], image.shape[0])
-         # Get the landmarks/parts for the face in box d only with the five key points
        cropped_face = image[y:y + size, x:x + size]
-         # cropped_face = cv2.resize(cropped_face, (224, 224), interpolation=cv2.INTER_CUBIC)
        return Image.fromarray(cropped_face)
-     return None


def get_frame_index_uniform_sample(total_frame_num, extract_frame_num):
-     return interval.tolist()
-
- import cv2


- def extract_face_from_fixed_num_frames(src_video, dst_path, num_frames=None
-     """
-     1) extract specific num of frames from videos in [1st(index 0) frame, last frame] with uniform sample interval
-     2) extract face from frame with specific enlarge size
-     """
    video_capture = cv2.VideoCapture(src_video)
-     total_frames = video_capture.get(
-     # extract from the 1st(index 0) frame
-     if num_frames is not None:
-         frame_indices = get_frame_index_uniform_sample(total_frames, num_frames)
-     else:
-         frame_indices = range(int(total_frames))
    for frame_index in frame_indices:
        video_capture.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
        ret, frame = video_capture.read()
-         image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
-         img = extract_face(image)
-         if img == None:
-             continue
-         img = img.resize((224, 224), Image.BICUBIC)
        if not ret:
            continue
-         img
    video_capture.release()
-     # cv2.destroyAllWindows()
    return frame_indices


- def
-     num_frames = 32
-     files = os.listdir(FRAME_SAVE_PATH)
-     num_files = len(files)
-     frame_path = os.path.join(FRAME_SAVE_PATH, str(num_files))
    os.makedirs(frame_path, exist_ok=True)
    os.makedirs(os.path.join(frame_path, '0'), exist_ok=True)
-     frame_indices = extract_face_from_fixed_num_frames(video, frame_path, num_frames=num_frames, device=device)
-
-     args.data_path = frame_path
-     args.batch_size = 32
-     dataset_val = build_dataset(is_train=False, args=args)
-     sampler_val = torch.utils.data.SequentialSampler(dataset_val)
-     data_loader_val = torch.utils.data.DataLoader(
-         dataset_val, sampler=sampler_val,
-         batch_size=args.batch_size,
-         num_workers=args.num_workers,
-         pin_memory=args.pin_mem,
-         drop_last=False
-     )
-
-     frame_preds_list, video_pred_list = test_all(data_loader_val, model, device)
-
-     real_prob_frames = [round(1. - fake_score, 2) for fake_score in video_pred_list]
-     frame_results = {f"frame_{frame}": f"{int(real_prob_frames[i] * 100)}%" for i, frame in enumerate(frame_indices)}
-
-     real_prob_video = int(round(1. - (sum(video_pred_list) / len(video_pred_list)), 2) * 100)
-     if real_prob_video > 50:
-         result_message = "real" if 'FAS' not in ckpt_select_dropdown else 'spoof'
-     else:
-         result_message = "fake" if 'FAS' not in ckpt_select_dropdown else 'real'
-     prob = 1 - real_prob_image if real_prob_video <= 50 else real_prob_video
-     image_results = (f"The face in this image may be {result_message} with probability is {real_prob_image}%")
-     video_results = (f"The face in this video may be {result_message} with probability {prob}")
-     return video_results


- def FSFM3C_image_detection(image, ckpt_select_dropdown):
-     files = os.listdir(FRAME_SAVE_PATH)
-     num_files = len(files)
-     frame_path = os.path.join(FRAME_SAVE_PATH, str(num_files))
-     os.makedirs(frame_path, exist_ok=True)
-     os.makedirs(os.path.join(frame_path, '0'), exist_ok=True)
-
-     save_img_name = f"frame_0.png"
    img = extract_face(image)
    if img is None:
-         return
    img = img.resize((224, 224), Image.BICUBIC)
-     img.save(os.path.join(frame_path, '0',
    args.data_path = frame_path
    args.batch_size = 1
    dataset_val = build_dataset(is_train=False, args=args)
    sampler_val = torch.utils.data.SequentialSampler(dataset_val)
-     data_loader_val = torch.utils.data.DataLoader(
-     with

    ckpt_select_dropdown = gr.Dropdown(
-         label="Select the Model
        multiselect=False,
-         value='
        interactive=True,
    )
-     video_submit_btn = gr.Button("Submit")
-     output_results_video = gr.Textbox(label="Detection Result")

    image_submit_btn.click(
        fn=FSFM3C_image_detection,
-         inputs=[image
        outputs=[output_results_image],
    )
    video_submit_btn.click(
        fn=FSFM3C_video_detection,
-         inputs=[video,
        outputs=[output_results_video],
    )
-         fn=load_model,
-         inputs=[ckpt_select_dropdown],
-         outputs=[ckpt_select_dropdown],
-     )

if __name__ == "__main__":
    gr.close_all()
    demo.queue()
-     demo.launch()

# -*- coding: utf-8 -*-
+ # Author: Gaojian Wang@ZJUICSR; TongWu@ZJUICSR
# --------------------------------------------------------
# This source code is licensed under the Attribution-NonCommercial 4.0 International License.
# You can find the license in the LICENSE file in the root directory of this source tree.
# --------------------------------------------------------

import sys
import os
os.system(f'pip install dlib')
+ import dlib
+ import argparse
import numpy as np
from PIL import Image
+ import cv2
+ import torch
+ from huggingface_hub import hf_hub_download
import gradio as gr

import models_vit
from util.datasets import build_dataset
+ from engine_finetune import test_two_class, test_multi_class


def get_args_parser():
+     parser = argparse.ArgumentParser('FSFM3C fine-tuning&Testing for image classification', add_help=False)
+     parser.add_argument('--batch_size', default=64, type=int, help='Batch size per GPU')
    parser.add_argument('--epochs', default=50, type=int)
+     parser.add_argument('--accum_iter', default=1, type=int, help='Accumulate gradient iterations')
+     parser.add_argument('--model', default='vit_large_patch16', type=str, metavar='MODEL', help='Name of model to train')
+     parser.add_argument('--input_size', default=224, type=int, help='images input size')
+     parser.add_argument('--normalize_from_IMN', action='store_true', help='cal mean and std from imagenet')
    parser.set_defaults(normalize_from_IMN=True)
+     parser.add_argument('--apply_simple_augment', action='store_true', help='apply simple data augment')
+     parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT', help='Drop path rate')
+     parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM', help='Clip gradient norm')
+     parser.add_argument('--weight_decay', type=float, default=0.05, help='weight decay')
+     parser.add_argument('--lr', type=float, default=None, metavar='LR', help='learning rate')
+     parser.add_argument('--blr', type=float, default=1e-3, metavar='LR', help='base learning rate')
+     parser.add_argument('--layer_decay', type=float, default=0.75, help='layer-wise lr decay')
+     parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR', help='lower lr bound')
+     parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N', help='epochs to warmup LR')
+     parser.add_argument('--color_jitter', type=float, default=None, metavar='PCT', help='Color jitter factor')
+     parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', help='Use AutoAugment policy')
+     parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing')
+     parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', help='Random erase prob')
+     parser.add_argument('--remode', type=str, default='pixel', help='Random erase mode')
+     parser.add_argument('--recount', type=int, default=1, help='Random erase count')
+     parser.add_argument('--resplit', action='store_true', default=False, help='Do not random erase first augmentation split')
+     parser.add_argument('--mixup', type=float, default=0, help='mixup alpha')
+     parser.add_argument('--cutmix', type=float, default=0, help='cutmix alpha')
+     parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None, help='cutmix min/max ratio')
+     parser.add_argument('--mixup_prob', type=float, default=1.0, help='Probability of performing mixup or cutmix')
+     parser.add_argument('--mixup_switch_prob', type=float, default=0.5, help='Probability of switching to cutmix')
+     parser.add_argument('--mixup_mode', type=str, default='batch', help='How to apply mixup/cutmix params')
+     parser.add_argument('--finetune', default='', help='finetune from checkpoint')
    parser.add_argument('--global_pool', action='store_true')
    parser.set_defaults(global_pool=True)
+     parser.add_argument('--cls_token', action='store_false', dest='global_pool', help='Use class token for classification')
+     parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str, help='dataset path')
+     parser.add_argument('--nb_classes', default=1000, type=int, help='number of the classification types')
+     parser.add_argument('--output_dir', default='', help='path where to save')
+     parser.add_argument('--log_dir', default='', help='path where to tensorboard log')
+     parser.add_argument('--device', default='cuda', help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
+     parser.add_argument('--resume', default='', help='resume from checkpoint')
+     parser.add_argument('--start_epoch', default=0, type=int, metavar='N', help='start epoch')
+     parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
    parser.set_defaults(eval=True)
+     parser.add_argument('--dist_eval', action='store_true', default=False, help='Enabling distributed evaluation')
    parser.add_argument('--num_workers', default=10, type=int)
+     parser.add_argument('--pin_mem', action='store_true', help='Pin CPU memory in DataLoader')
    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
    parser.set_defaults(pin_mem=True)
+     parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
+     parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    return parser


+ def load_model(select_skpt):
+     global ckpt, device, model, checkpoint
+     if select_skpt not in CKPT_NAME:
+         return gr.update(), "Select a correct model"
+     ckpt = select_skpt
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     args.nb_classes = CKPT_CLASS[ckpt]
+     model = models_vit.__dict__[CKPT_MODEL[ckpt]](
+         num_classes=args.nb_classes,
+         drop_path_rate=args.drop_path,
+         global_pool=args.global_pool,
+     ).to(device)
+
+     args.resume = os.path.join(CKPT_SAVE_PATH, ckpt)
+     args.resume = CKPT_PATH[ckpt]
+     checkpoint = torch.load(args.resume, map_location=device)
+     model.load_state_dict(checkpoint['model'], strict=False)
    model.eval()
+     return gr.update(), f"[Loaded Model Successfully:] {args.resume}] "

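Note on checkpoint resolution: load_model currently points args.resume at the absolute local paths in CKPT_PATH further down, while hf_hub_download is imported but never called. A minimal sketch of resolving the checkpoints through the Hugging Face Hub instead is shown below; the repo id 'Wolowolo/fsfm-3c' is taken from the previous revision of this file and the relative filenames from the commented-out CKPT_PATH block later in this diff, so both are assumptions about the actual repo layout.

    # Sketch only (not part of this commit): fetch a checkpoint from the Hub and return its local path.
    # repo_id and filenames are assumptions taken from the earlier revision / the commented-out CKPT_PATH.
    HUB_CKPT_FILES = {
        'DfD-Checkpoint_Fine-tuned_on_FF++': 'finetuned_models/FF++_c23_32frames/checkpoint-min_val_loss.pth',
        'FAS-Checkpoint_Fine-tuned_on_MCIO': 'finetuned_models/MCIO_protocol/Both_MCIO/checkpoint-min_val_loss.pth',
    }

    def resolve_checkpoint_from_hub(ckpt_key):
        # Downloads on first use, then reuses the copy cached under CKPT_SAVE_PATH.
        return hf_hub_download(repo_id='Wolowolo/fsfm-3c',
                               filename=HUB_CKPT_FILES[ckpt_key],
                               local_dir=CKPT_SAVE_PATH)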

def get_boundingbox(face, width, height, minsize=None):
+     x1, y1, x2, y2 = face.left(), face.top(), face.right(), face.bottom()
    size_bb = int(max(x2 - x1, y2 - y1) * 1.3)
+     if minsize and size_bb < minsize:
+         size_bb = minsize
    center_x, center_y = (x1 + x2) // 2, (y1 + y2) // 2
+     x1, y1 = max(int(center_x - size_bb // 2), 0), max(int(center_y - size_bb // 2), 0)
    size_bb = min(width - x1, size_bb)
    size_bb = min(height - y1, size_bb)
    return x1, y1, size_bb


def extract_face(frame):
    face_detector = dlib.get_frontal_face_detector()
    image = np.array(frame.convert('RGB'))
    faces = face_detector(image, 1)
+     if faces:
        face = faces[0]
        x, y, size = get_boundingbox(face, image.shape[1], image.shape[0])
        cropped_face = image[y:y + size, x:x + size]
        return Image.fromarray(cropped_face)
+     return None


def get_frame_index_uniform_sample(total_frame_num, extract_frame_num):
+     return np.linspace(0, total_frame_num - 1, num=extract_frame_num, dtype=int).tolist()


+ def extract_face_from_fixed_num_frames(src_video, dst_path, num_frames=None):
    video_capture = cv2.VideoCapture(src_video)
+     total_frames = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))
+     frame_indices = get_frame_index_uniform_sample(total_frames, num_frames) if num_frames else range(total_frames)
    for frame_index in frame_indices:
        video_capture.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
        ret, frame = video_capture.read()
        if not ret:
            continue
+         image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+         img = extract_face(image)
+         if img:
+             img = img.resize((224, 224), Image.BICUBIC)
+             save_img_name = f"frame_{frame_index}.png"
+             img.save(os.path.join(dst_path, '0', save_img_name))
    video_capture.release()
    return frame_indices

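As a quick sanity check of the uniform sampler above (plain arithmetic, not part of the commit): for a 300-frame video and an 8-frame budget it picks indices spread evenly from the first to the last frame.

    # Example of get_frame_index_uniform_sample's behaviour for a 300-frame video and 8 frames.
    import numpy as np
    print(np.linspace(0, 300 - 1, num=8, dtype=int).tolist())
    # [0, 42, 85, 128, 170, 213, 256, 299]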
+ def FSFM3C_image_detection(image):
+     frame_path = os.path.join(FRAME_SAVE_PATH, str(len(os.listdir(FRAME_SAVE_PATH))))
    os.makedirs(frame_path, exist_ok=True)
    os.makedirs(os.path.join(frame_path, '0'), exist_ok=True)
    img = extract_face(image)
    if img is None:
+         return 'No face detected, please upload a clear face!'
    img = img.resize((224, 224), Image.BICUBIC)
+     img.save(os.path.join(frame_path, '0', "frame_0.png"))
    args.data_path = frame_path
    args.batch_size = 1
    dataset_val = build_dataset(is_train=False, args=args)
    sampler_val = torch.utils.data.SequentialSampler(dataset_val)
+     data_loader_val = torch.utils.data.DataLoader(dataset_val, sampler=sampler_val, batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=args.pin_mem, drop_last=False)
+
+     if CKPT_CLASS[ckpt] > 2:
+         frame_preds_list, video_pred_list = test_multi_class(data_loader_val, model, device)
+         class_names = ['Real or Bonafide', 'Deepfake', 'Diffusion or AIGC generated', 'Spoofing or Presentation-attack']
+         avg_video_pred = np.mean(video_pred_list, axis=0)
+         max_prob_index = np.argmax(avg_video_pred)
+         max_prob_class = class_names[max_prob_index]
+         probabilities = [f"{class_names[i]}: {prob * 100:.1f}%" for i, prob in enumerate(avg_video_pred)]
+         image_results = f"The largest face in this image may be {max_prob_class} with probability: \n [{', '.join(probabilities)}]"
+         return image_results
+
+     if CKPT_CLASS[ckpt] == 2:
+         frame_preds_list, video_pred_list = test_two_class(data_loader_val, model, device)
+         if ckpt == 'DfD-Checkpoint_Fine-tuned_on_FF++':
+             prob = sum(video_pred_list) / len(video_pred_list)
+             label = "Deepfake" if prob <= 0.5 else "Real"
+             prob = prob if label == "Real" else 1 - prob
+         if ckpt == 'FAS-Checkpoint_Fine-tuned_on_MCIO':
+             prob = sum(video_pred_list) / len(video_pred_list)
+             label = "Spoofing" if prob <= 0.5 else "Bonafide"
+             prob = prob if label == "Bonafide" else 1 - prob
+         image_results = f"The largest face in this image may be {label} with probability {prob * 100:.1f}%"
+         return image_results
+
+
+ def FSFM3C_video_detection(video, num_frames):
+     try:
+         frame_path = os.path.join(FRAME_SAVE_PATH, str(len(os.listdir(FRAME_SAVE_PATH))))
+         os.makedirs(frame_path, exist_ok=True)
+         os.makedirs(os.path.join(frame_path, '0'), exist_ok=True)
+         frame_indices = extract_face_from_fixed_num_frames(video, frame_path, num_frames=num_frames)
+         args.data_path = frame_path
+         args.batch_size = num_frames
+         dataset_val = build_dataset(is_train=False, args=args)
+         sampler_val = torch.utils.data.SequentialSampler(dataset_val)
+         data_loader_val = torch.utils.data.DataLoader(dataset_val, sampler=sampler_val, batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=args.pin_mem, drop_last=False)
+
+         if CKPT_CLASS[ckpt] > 2:
+             frame_preds_list, video_pred_list = test_multi_class(data_loader_val, model, device)
+             class_names = ['Real or Bonafide', 'Deepfake', 'Diffusion or AIGC generated', 'Spoofing or Presentation-attack']
+             avg_video_pred = np.mean(video_pred_list, axis=0)
+             max_prob_index = np.argmax(avg_video_pred)
+             max_prob_class = class_names[max_prob_index]
+             probabilities = [f"{class_names[i]}: {prob * 100:.1f}%" for i, prob in enumerate(avg_video_pred)]
+
+             frame_results = {f"frame_{frame_indices[i]}": [f"{class_names[j]}: {prob * 100:.1f}%" for j, prob in enumerate(frame_preds_list[i])] for i in range(len(frame_indices))}
+             video_results = (f"The largest face in this image may be {max_prob_class} with probability: \n [{', '.join(probabilities)}]\n \n"
+                              f"The frame-level detection results ['frame_index': 'probabilities']: \n{frame_results}")
+             return video_results
+
+         if CKPT_CLASS[ckpt] == 2:
+             frame_preds_list, video_pred_list = test_two_class(data_loader_val, model, device)
+             if ckpt == 'DfD-Checkpoint_Fine-tuned_on_FF++':
+                 prob = sum(video_pred_list) / len(video_pred_list)
+                 label = "Deepfake" if prob <= 0.5 else "Real"
+                 prob = prob if label == "Real" else 1 - prob
+                 frame_results = {f"frame_{frame_indices[i]}": f"{(frame_preds_list[i]) * 100:.1f}%" for i in
+                                  range(len(frame_indices))} if label == "Real" else {f"frame_{frame_indices[i]}": f"{(1 - frame_preds_list[i]) * 100:.1f}%" for i in
+                                  range(len(frame_indices))}
+
+             if ckpt == 'FAS-Checkpoint_Fine-tuned_on_MCIO':
+                 prob = sum(video_pred_list) / len(video_pred_list)
+                 label = "Spoofing" if prob <= 0.5 else "Bonafide"
+                 prob = prob if label == "Bonafide" else 1 - prob
+                 frame_results = {f"frame_{frame_indices[i]}": f"{(frame_preds_list[i]) * 100:.1f}%" for i in
+                                  range(len(frame_indices))} if label == "Bonafide" else {f"frame_{frame_indices[i]}": f"{(1 - frame_preds_list[i]) * 100:.1f}%" for i in
+                                  range(len(frame_indices))}
+
+             video_results = (f"The largest face in this image may be {label} with probability {prob * 100:.1f}%\n \n"
+                              f"The frame-level detection results ['frame_index': 'real_face_probability']: \n{frame_results}")
+             return video_results
+     except Exception as e:
+         return f"Error occurred. Please provide a clear face video or reduce the number of frames."
+
+ # Paths and Constants
+ P = os.path.abspath(__file__)
+ FRAME_SAVE_PATH = os.path.join(os.path.dirname(P), 'frame')
+ CKPT_SAVE_PATH = os.path.join(os.path.dirname(P), 'checkpoints')
+ os.makedirs(FRAME_SAVE_PATH, exist_ok=True)
+ os.makedirs(CKPT_SAVE_PATH, exist_ok=True)
+ CKPT_NAME = [
+     '✨Unified-detector_v1_Fine-tuned_on_4_classes',
+     'DfD-Checkpoint_Fine-tuned_on_FF++',
+     'FAS-Checkpoint_Fine-tuned_on_MCIO',
+ ]
+ # CKPT_PATH = {
+ #     '✨Unified-detector_v1_Fine-tuned_on_4_classes': 'finetuned_models/Unified-detector/v1_Fine-tuned_on_4_classes/checkpoint-min_val_loss.pth',
+ #     'DfD-Checkpoint_Fine-tuned_on_FF++': 'finetuned_models/FF++_c23_32frames/checkpoint-min_val_loss.pth',
+ #     'FAS-Checkpoint_Fine-tuned_on_MCIO': 'finetuned_models/MCIO_protocol/Both_MCIO/checkpoint-min_val_loss.pth',
+ # }
+ CKPT_PATH = {
+     '✨Unified-detector_v1_Fine-tuned_on_4_classes': './checkpoints/checkpoint-min_train_loss.pth',
+     'DfD-Checkpoint_Fine-tuned_on_FF++': '/mnt/localDisk2/wgj/FSFM/released/FSFM-main/fsfm-3c/finuetune/cross_dataset_DfD/checkpoint/finetuned_models/ft_on_FF++_c23_32frames/pt_from_VF2_ViT-B_epoch600/checkpoint-min_val_loss.pth',
+     'FAS-Checkpoint_Fine-tuned_on_MCIO': '/mnt/localDisk2/wgj/FSFM/FSFM-3C/codespace/fsfm-3c/finuetune/cross_dataset_DfD/finetuned_models/FAS_MCIO/checkpoint-199.pth',
+ }
+ CKPT_CLASS = {
+     '✨Unified-detector_v1_Fine-tuned_on_4_classes': 4,
+     'DfD-Checkpoint_Fine-tuned_on_FF++': 2,
+     'FAS-Checkpoint_Fine-tuned_on_MCIO': 2
+ }
+ CKPT_MODEL = {
+     '✨Unified-detector_v1_Fine-tuned_on_4_classes': 'vit_base_patch16',
+     'DfD-Checkpoint_Fine-tuned_on_FF++': 'vit_base_patch16',
+     'FAS-Checkpoint_Fine-tuned_on_MCIO': 'vit_base_patch16',
+ }
+
+
+ with gr.Blocks(css=".custom-label { font-weight: bold !important; font-size: 16px !important; }") as demo:
+     gr.HTML("<h1 style='text-align: center;'>🦱 Real Facial Image&Video Detection <br> Against Face Forgery (Deepfake/Diffusion) and Spoofing (Presentation-attacks)</h1>")
+     gr.Markdown("<b>☉ Powered by the fine-tuned model that is pre-trained from [FSFM-3C](https://fsfm-3c.github.io/)</b> <br> "
+                 "<b>☉ Release (Continuously updating) </b> <br> <b>[V1.0]</b> 2025/02/22-Current🎉: "
+                 "1) Updated <b>[✨Unified-detector_v1] for Unified Physical-Digital Face Attack&Forgery Detection, a vanilla ViT-B/16-224 (FSFM Pre-trained) that could identify Real&Bonafide, Deepfake, Diffusion&AIGC, Spoofing&Presentation-attacks facial images or videos </b> ; 2) Provided the selection of the number of video frames (uniformly sampling, more frames are too time-consuming, and we would be grateful if you support us to open paid GPU acceleration); 3) Fixed the errors of V0.1 including loading model and prediction. <br>"
+                 "<b>[V0.1]</b> 2024/12-2025/02/21: "
+                 "Create this page with basic detectors [DfD-Checkpoint_Fine-tuned_on_FF++, FAS-Checkpoint_Fine-tuned_on_MCIO] that follow the paper implementation. <br> ")
+     gr.Markdown("- Please <b>provide a facial image or video(<100s)</b>, and <b>select the model</b> for detection: <br> <b>[suggest] [✨Unified-detector_v1_Fine-tuned_on_4_classes]</b> <b>a (FSFM Pre-trained) ViT-B/16-224 for Both Real/Deepfake/Diffusion/Spoofing facial images&videos Detection <b> <br> <b>[DfD-Checkpoint_Fine-tuned_on_FF++]</b> for deepfake detection, FSFM ViT-B/16-224 fine-tuned on the FF++_c23 train&val sets (4 manipulations, 32 frames per video) <br> <b>[FAS-Checkpoint_Fine-tuned_on_MCIO]</b> for face anti-spoofing, FSFM ViT-B/16-224 fine-tuned on the MCIO datasets (2 frames per video)")
+
+
+     with gr.Row():
        ckpt_select_dropdown = gr.Dropdown(
+             label="Select the Model for Detection ⬇️",
+             elem_classes="custom-label",
+             choices=['Choose Model Here 🖱️'] + CKPT_NAME + ['continuously updating...'],
            multiselect=False,
+             value='Choose Model Here 🖱️',
            interactive=True,
        )
+         model_loading_status = gr.Textbox(label="Model Loading Status")
+     with gr.Row():
+         with gr.Column(scale=5):
+             gr.Markdown("### Image Detection")
+             image = gr.Image(label="Upload/Capture/Paste your image", type="pil")
+             image_submit_btn = gr.Button("Submit")
+             output_results_image = gr.Textbox(label="Detection Result")
+         with gr.Column(scale=5):
+             gr.Markdown("### Video Detection")
+             video = gr.Video(label="Upload/Capture your video")
+             frame_slider = gr.Slider(minimum=1, maximum=32, step=1, value=32, label="Number of Frames for Detection")
+             video_submit_btn = gr.Button("Submit")
+             output_results_video = gr.Textbox(label="Detection Result")

+     ckpt_select_dropdown.change(
+         fn=load_model,
+         inputs=[ckpt_select_dropdown],
+         outputs=[ckpt_select_dropdown, model_loading_status],
+     )
    image_submit_btn.click(
        fn=FSFM3C_image_detection,
+         inputs=[image],
        outputs=[output_results_image],
    )
    video_submit_btn.click(
        fn=FSFM3C_video_detection,
+         inputs=[video, frame_slider],
        outputs=[output_results_video],
    )


if __name__ == "__main__":
+     args = get_args_parser()
+     args = args.parse_args()
+     ckpt = 'DfD-Checkpoint_Fine-tuned_on_FF++'
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     args.nb_classes = CKPT_CLASS[ckpt]
+     model = models_vit.__dict__[CKPT_MODEL[ckpt]](
+         num_classes=args.nb_classes,
+         drop_path_rate=args.drop_path,
+         global_pool=args.global_pool,
+     ).to(device)
+     args.resume = os.path.join(CKPT_SAVE_PATH, ckpt)
+     args.resume = CKPT_PATH[ckpt]
+     checkpoint = torch.load(args.resume, map_location=device)
+     model.load_state_dict(checkpoint['model'], strict=False)
+     model.eval()
+
    gr.close_all()
    demo.queue()
+     # demo.launch()
+     demo.launch(server_name="0.0.0.0", server_port=8888)

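The __main__ block above pins the server to 0.0.0.0:8888 and keeps the plain demo.launch() commented out. A small optional variation (an assumption, not part of this commit) is to take the port from an environment variable, called APP_PORT here purely for illustration, and fall back to Gradio's defaults when it is unset, which is what a hosted Space expects.

    # Sketch only: choose the launch mode from the environment; APP_PORT is a hypothetical variable.
    port = os.environ.get("APP_PORT")
    if port:
        demo.launch(server_name="0.0.0.0", server_port=int(port))
    else:
        demo.launch()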
engine_finetune.py
CHANGED
@@ -1,323 +1,130 @@
# -*- coding: utf-8 -*-
- # Author: Gaojian Wang@ZJUICSR
# --------------------------------------------------------
# This source code is licensed under the Attribution-NonCommercial 4.0 International License.
# You can find the license in the LICENSE file in the root directory of this source tree.
# --------------------------------------------------------

- import math
- import sys
- from typing import Iterable, Optional
-
import numpy as np
import torch
-
- from timm.data import Mixup
- from timm.utils import accuracy

import util.misc as misc
- import util.lr_sched as lr_sched
from util.metrics import *

- import torch.nn.functional as F
-
-
- def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
-                     data_loader: Iterable, optimizer: torch.optim.Optimizer,
-                     device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
-                     mixup_fn: Optional[Mixup] = None, log_writer=None,
-                     args=None):
-     model.train(True)
-     metric_logger = misc.MetricLogger(delimiter="  ")
-     metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
-     header = 'Epoch: [{}]'.format(epoch)
-     print_freq = 20
-
-     accum_iter = args.accum_iter
-
-     optimizer.zero_grad()
-
-     if log_writer is not None:
-         print('log_dir: {}'.format(log_writer.log_dir))
-
-     for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
-
-         # we use a per iteration (instead of per epoch) lr scheduler
-         if data_iter_step % accum_iter == 0:
-             lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)
-
-         samples = samples.to(device, non_blocking=True)
-         targets = targets.to(device, non_blocking=True)
-
-         if mixup_fn is not None:
-             samples, targets = mixup_fn(samples, targets)
-
-         with torch.cuda.amp.autocast():
-             # outputs = model(samples)
-             outputs = model(samples).to(device, non_blocking=True)  # modified
-             loss = criterion(outputs, targets)
-
-         loss_value = loss.item()
-
-         if not math.isfinite(loss_value):
-             print("Loss is {}, stopping training".format(loss_value))
-             sys.exit(1)
-
-         loss /= accum_iter
-         loss_scaler(loss, optimizer, clip_grad=max_norm,
-                     parameters=model.parameters(), create_graph=False,
-                     update_grad=(data_iter_step + 1) % accum_iter == 0)
-         if (data_iter_step + 1) % accum_iter == 0:
-             optimizer.zero_grad()
-
-         torch.cuda.synchronize()
-
-         metric_logger.update(loss=loss_value)
-         min_lr = 10.
-         max_lr = 0.
-         for group in optimizer.param_groups:
-             min_lr = min(min_lr, group["lr"])
-             max_lr = max(max_lr, group["lr"])
-
-         metric_logger.update(lr=max_lr)
-
-         loss_value_reduce = misc.all_reduce_mean(loss_value)
-         if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
-             """ We use epoch_1000x as the x-axis in tensorboard.
-             This calibrates different curves when batch size changes.
-             """
-             epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
-             log_writer.add_scalar('loss', loss_value_reduce, epoch_1000x)
-             log_writer.add_scalar('lr', max_lr, epoch_1000x)
-
-     # gather the stats from all processes
-     metric_logger.synchronize_between_processes()
-     print("Averaged stats:", metric_logger)
-     return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
-
-
- @torch.no_grad()
- def evaluate(data_loader, model, device):
-     criterion = torch.nn.CrossEntropyLoss()
-
-     metric_logger = misc.MetricLogger(delimiter="  ")
-     header = 'Test:'
-
-     # switch to evaluation mode
-     model.eval()
-
-     for batch in metric_logger.log_every(data_loader, 10, header):
-         images = batch[0]
-         target = batch[-1]
-         images = images.to(device, non_blocking=True)
-         target = target.to(device, non_blocking=True)
-
-         # compute output
-         with torch.cuda.amp.autocast():
-             # output = model(images)
-             output = model(images).to(device, non_blocking=True)  # modified
-             loss = criterion(output, target)
-
-         # acc1, acc5 = accuracy(output, target, topk=(1, 5))
-         acc = float(accuracy(output, target, topk=(1,))[0])
-         preds = (F.softmax(output, dim=1)[:, 1].detach().cpu().numpy())
-         trues = (target.detach().cpu().numpy())
-         auc_score = roc_auc_score(trues, preds) * 100.
-
-         batch_size = images.shape[0]
-         metric_logger.update(loss=loss.item())
-         # metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
-         # metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
-         metric_logger.meters['acc'].update(acc, n=batch_size)
-         metric_logger.meters['auc'].update(auc_score, n=batch_size)
-
-     # gather the stats from all processes
-     metric_logger.synchronize_between_processes()
-     # print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
-     #       .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
-     print('* Acc {acc.global_avg:.3f} Auc {auc.global_avg:.3f} loss {losses.global_avg:.3f}'
-           .format(acc=metric_logger.acc, auc=metric_logger.auc, losses=metric_logger.loss))
-
-     return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
-
-
- @torch.no_grad()
- def test_ori(data_loader, model, device):
-     criterion = torch.nn.CrossEntropyLoss()
-
-     metric_logger = misc.MetricLogger(delimiter="  ")
-     header = 'Test:'
-
-     # switch to evaluation mode
-     model.eval()
-
-     labels = np.array([])
-     preds = np.array([])
-
-     for batch in metric_logger.log_every(data_loader, 10, header):
-         images = batch[0]
-         target = batch[-1]
-         images = images.to(device, non_blocking=True)
-         target = target.to(device, non_blocking=True)
-
-         # compute output
-         with torch.cuda.amp.autocast():
-             # output = model(images)
-             output = model(images).to(device, non_blocking=True)  # modified
-             loss = criterion(output, target)
-
-         # acc1, acc5 = accuracy(output, target, topk=(1, 5))
-         acc = float(accuracy(output, target, topk=(1,))[0])
-         pred = (F.softmax(output, dim=1)[:, 1].detach().cpu().numpy())
-         preds = np.append(preds, pred)
-         label = (target.detach().cpu().numpy())
-         labels = np.append(labels, label)
-
-         batch_size = images.shape[0]
-         metric_logger.update(loss=loss.item())
-         # metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
-         # metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
-         metric_logger.meters['acc'].update(acc, n=batch_size)
-
-     # gather the stats from all processes
-     metric_logger.synchronize_between_processes()
-     # print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
-     #       .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
-     auc_score = roc_auc_score(labels, preds) * 100.
-     metric_logger.meters['auc'].update(auc_score)
-     print('* Acc {acc.global_avg:.3f} Auc {auc.global_avg:.3f} loss {losses.global_avg:.3f}'
-           .format(acc=metric_logger.acc, auc=metric_logger.auc, losses=metric_logger.loss))
-
-     return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


@torch.no_grad()
- def
    criterion = torch.nn.CrossEntropyLoss()

-     metric_logger = misc.MetricLogger(delimiter="  ")
-     header = 'Test:'
    # switch to evaluation mode
    model.eval()

    frame_labels = np.array([])  # int label
    frame_preds = np.array([])  # pred logit
    frame_y_preds = np.array([])  # pred int

-     # for batch in metric_logger.log_every(data_loader, print_freq=len(data_loader), header=header):
    for batch in data_loader:
        images = batch[0]  # torch.Size([BS, C, H, W])
        target = batch[1]  # torch.Size([BS])
        images = images.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

-         #
-         # output = model(images)
-         output = model(images).to(device, non_blocking=True)  # modified
-         loss = criterion(output, target)

        frame_pred = (F.softmax(output, dim=1)[:, 1].detach().cpu().numpy())
        frame_preds = np.append(frame_preds, frame_pred)
        frame_y_pred = np.argmax(output.detach().cpu().numpy(), axis=1)
        frame_y_preds = np.append(frame_y_preds, frame_y_pred)

        frame_label = (target.detach().cpu().numpy())
        frame_labels = np.append(frame_labels, frame_label)

-     metric_logger.meters['frame_acc'].update(frame_level_acc(frame_labels, frame_y_preds))
-     metric_logger.meters['frame_balanced_acc'].update(frame_level_balanced_acc(frame_labels, frame_y_preds))
-     metric_logger.meters['frame_auc'].update(frame_level_auc(frame_labels, frame_preds))
-     metric_logger.meters['frame_eer'].update(frame_level_eer(frame_labels, frame_preds))
-
-     print('*[------FRAME-LEVEL------] \n'
-           'Acc {frame_acc.global_avg:.3f} Balanced_Acc {frame_balanced_acc.global_avg:.3f} '
-           'Auc {frame_auc.global_avg:.3f} EER {frame_eer.global_avg:.3f} loss {losses.global_avg:.3f}'
-           .format(frame_acc=metric_logger.frame_acc, frame_balanced_acc=metric_logger.frame_balanced_acc,
-                   frame_auc=metric_logger.frame_auc, frame_eer=metric_logger.frame_eer, losses=metric_logger.loss))

-     return


@torch.no_grad()
- def
    criterion = torch.nn.CrossEntropyLoss()

-     metric_logger = misc.MetricLogger(delimiter="  ")
-     header = 'Test:'
    # switch to evaluation mode
    model.eval()

    frame_labels = np.array([])  # int label
-     frame_preds = np.
    frame_y_preds = np.array([])  # pred int
    video_names_list = list()

-     # for batch in metric_logger.log_every(data_loader, print_freq=len(data_loader), header=header):
    for batch in data_loader:
        images = batch[0]  # torch.Size([BS, C, H, W])
        target = batch[1]  # torch.Size([BS])
        video_name = batch[-1]  # list[BS]
        images = images.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

-         # with torch.cuda.amp.autocast():
-         #     output = model(images)
-         output = model(images).to(device, non_blocking=True)  # modified
        loss = criterion(output, target)

-         frame_pred =
-         frame_preds = np.append(frame_preds, frame_pred)
        frame_y_pred = np.argmax(output.detach().cpu().numpy(), axis=1)
        frame_y_preds = np.append(frame_y_preds, frame_y_pred)

-         frame_label =
        frame_labels = np.append(frame_labels, frame_label)
        video_names_list.extend(list(video_name))

-         metric_logger.update(loss=loss.item())
-
-     # gather the stats from all processes
-     # metric_logger.synchronize_between_processes()
-     # metric_logger.meters['frame_acc'].update(frame_level_acc(frame_labels, frame_y_preds))
-     # metric_logger.meters['frame_balanced_acc'].update(frame_level_balanced_acc(frame_labels, frame_y_preds))
-     # metric_logger.meters['frame_auc'].update(frame_level_auc(frame_labels, frame_preds))
-     # metric_logger.meters['frame_eer'].update(frame_level_eer(frame_labels, frame_preds))
-
-     # print('*[------FRAME-LEVEL------] \n'
-     #       'Acc {frame_acc.global_avg:.3f} Balanced_Acc {frame_balanced_acc.global_avg:.3f} '
-     #       'Auc {frame_auc.global_avg:.3f} EER {frame_eer.global_avg:.3f} loss {losses.global_avg:.3f}'
-     #       .format(frame_acc=metric_logger.frame_acc, frame_balanced_acc=metric_logger.frame_balanced_acc,
-     #               frame_auc=metric_logger.frame_auc, frame_eer=metric_logger.frame_eer, losses=metric_logger.loss))

    # video-level metrics:
    frame_labels_list = frame_labels.tolist()
    frame_preds_list = frame_preds.tolist()

-     # print(len(video_label_list), len(video_pred_list), len(video_y_pred_list))
-     # metric_logger.meters['video_acc'].update(video_level_acc(video_label_list, video_y_pred_list))
-     # metric_logger.meters['video_balanced_acc'].update(video_level_balanced_acc(video_label_list, video_y_pred_list))
-     # metric_logger.meters['video_auc'].update(video_level_auc(video_label_list, video_pred_list))
-     # metric_logger.meters['video_eer'].update(frame_level_eer(video_label_list, video_pred_list))

-     # print('*[------VIDEO-LEVEL------] \n'
-     #       'Acc {video_acc.global_avg:.3f} Balanced_Acc {video_balanced_acc.global_avg:.3f} '
-     #       'Auc {video_auc.global_avg:.3f} EER {video_eer.global_avg:.3f}'
-     #       .format(video_acc=metric_logger.video_acc, video_balanced_acc=metric_logger.video_balanced_acc,
-     #               video_auc=metric_logger.video_auc, video_eer=metric_logger.video_eer))

# -*- coding: utf-8 -*-
+ # Author: Gaojian Wang@ZJUICSR; TongWu@ZJUICSR
# --------------------------------------------------------
# This source code is licensed under the Attribution-NonCommercial 4.0 International License.
# You can find the license in the LICENSE file in the root directory of this source tree.
# --------------------------------------------------------

import numpy as np
import torch
+ import torch.nn.functional as F

import util.misc as misc
from util.metrics import *


@torch.no_grad()
+ def test_two_class(data_loader, model, device):
    criterion = torch.nn.CrossEntropyLoss()

    # switch to evaluation mode
    model.eval()

    frame_labels = np.array([])  # int label
    frame_preds = np.array([])  # pred logit
    frame_y_preds = np.array([])  # pred int
+     video_names_list = list()

    for batch in data_loader:
        images = batch[0]  # torch.Size([BS, C, H, W])
        target = batch[1]  # torch.Size([BS])
+         video_name = batch[-1]  # list[BS]
        images = images.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

+         output = model(images).to(device, non_blocking=True)  # modified
+         loss = criterion(output, target)

        frame_pred = (F.softmax(output, dim=1)[:, 1].detach().cpu().numpy())
        frame_preds = np.append(frame_preds, frame_pred)
        frame_y_pred = np.argmax(output.detach().cpu().numpy(), axis=1)
        frame_y_preds = np.append(frame_y_preds, frame_y_pred)

        frame_label = (target.detach().cpu().numpy())
        frame_labels = np.append(frame_labels, frame_label)
+         video_names_list.extend(list(video_name))

+     # video-level metrics:
+     frame_labels_list = frame_labels.tolist()
+     frame_preds_list = frame_preds.tolist()
+     video_label_list, video_pred_list, video_y_pred_list = get_video_level_label_pred(frame_labels_list, video_names_list, frame_preds_list)

+     return frame_preds_list, video_pred_list

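test_two_class relies on get_video_level_label_pred from util/metrics.py, which this commit also changes (+29 -1) but which is not shown in this diff. A minimal sketch of the behaviour it is assumed to provide — frame scores grouped by video name, averaged into one score per video, and thresholded at 0.5 for the integer prediction — is given below; the function name and return order are taken from the call site above, the internals are an assumption.

    # Sketch only: assumed shape of get_video_level_label_pred (util/metrics.py is not in this diff).
    from collections import defaultdict

    def get_video_level_label_pred_sketch(frame_labels, video_names, frame_preds):
        labels_by_video, preds_by_video = defaultdict(list), defaultdict(list)
        for label, name, pred in zip(frame_labels, video_names, frame_preds):
            labels_by_video[name].append(label)
            preds_by_video[name].append(pred)
        video_label_list = [int(labels_by_video[name][0]) for name in preds_by_video]
        video_pred_list = [sum(p) / len(p) for p in preds_by_video.values()]
        video_y_pred_list = [int(p > 0.5) for p in video_pred_list]
        return video_label_list, video_pred_list, video_y_pred_list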

@torch.no_grad()
+ def test_multi_class(data_loader, model, device):
    criterion = torch.nn.CrossEntropyLoss()

    # switch to evaluation mode
    model.eval()

    frame_labels = np.array([])  # int label
+     frame_preds = np.empty((0, 4))  # pred logit, initialize as 2D array with 4 columns for 4 classes
    frame_y_preds = np.array([])  # pred int
    video_names_list = list()

    for batch in data_loader:
        images = batch[0]  # torch.Size([BS, C, H, W])
        target = batch[1]  # torch.Size([BS])
        video_name = batch[-1]  # list[BS]
        images = images.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

+         output = model(images).to(device, non_blocking=True)
        loss = criterion(output, target)

+         frame_pred = F.softmax(output, dim=1).detach().cpu().numpy()
+         frame_preds = np.append(frame_preds, frame_pred, axis=0)
        frame_y_pred = np.argmax(output.detach().cpu().numpy(), axis=1)
        frame_y_preds = np.append(frame_y_preds, frame_y_pred)

+         frame_label = target.detach().cpu().numpy()
        frame_labels = np.append(frame_labels, frame_label)
        video_names_list.extend(list(video_name))

    # video-level metrics:
    frame_labels_list = frame_labels.tolist()
    frame_preds_list = frame_preds.tolist()
+     video_label_list, video_pred_list, video_y_pred_list = get_video_level_label_pred_multi_class(frame_labels_list, video_names_list, frame_preds_list)

+     return frame_preds_list, video_pred_list


+ # @torch.no_grad()
+ # def test_multi_class(data_loader, model, device):
+ #     criterion = torch.nn.CrossEntropyLoss()
+ #
+ #     # switch to evaluation mode
+ #     model.eval()
+ #
+ #     frame_labels = np.array([])  # int label
+ #     frame_preds = np.array([])  # pred logit
+ #     frame_y_preds = np.array([])  # pred int
+ #     video_names_list = list()
+ #
+ #     for batch in data_loader:
+ #         images = batch[0]  # torch.Size([BS, C, H, W])
+ #         target = batch[1]  # torch.Size([BS])
+ #         video_name = batch[-1]  # list[BS]
+ #         images = images.to(device, non_blocking=True)
+ #         target = target.to(device, non_blocking=True)
+ #
+ #         output = model(images).to(device, non_blocking=True)
+ #         loss = criterion(output, target)
+ #
+ #         frame_pred = F.softmax(output, dim=1).detach().cpu().numpy()
+ #         frame_preds = np.append(frame_preds, frame_pred, axis=0)
+ #         frame_y_pred = np.argmax(output.detach().cpu().numpy(), axis=1)
+ #         frame_y_preds = np.append(frame_y_preds, frame_y_pred)
+ #
+ #         frame_label = target.detach().cpu().numpy()
+ #         frame_labels = np.append(frame_labels, frame_label)
+ #         video_names_list.extend(list(video_name))
+ #
+ #     # video-level metrics:
+ #     frame_labels_list = frame_labels.tolist()
+ #     frame_preds_list = frame_preds.tolist()
+ #     video_label_list, video_pred_list, video_y_pred_list = get_video_level_label_pred_multi_class(frame_labels_list, video_names_list, frame_preds_list)
+ #
+ #     return frame_preds_list, video_pred_list
models_vit.py
CHANGED
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Author: Gaojian Wang@ZJUICSR
+# Author: Gaojian Wang@ZJUICSR; TongWu@ZJUICSR
 # --------------------------------------------------------
 # This source code is licensed under the Attribution-NonCommercial 4.0 International License.
 # You can find the license in the LICENSE file in the root directory of this source tree.

util/crop.py
CHANGED
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Author: Gaojian Wang@ZJUICSR
+# Author: Gaojian Wang@ZJUICSR; TongWu@ZJUICSR
 # --------------------------------------------------------
 # This source code is licensed under the Attribution-NonCommercial 4.0 International License.
 # You can find the license in the LICENSE file in the root directory of this source tree.

util/datasets.py
CHANGED
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Author: Gaojian Wang@ZJUICSR
+# Author: Gaojian Wang@ZJUICSR; TongWu@ZJUICSR
 # --------------------------------------------------------
 # This source code is licensed under the Attribution-NonCommercial 4.0 International License.
 # You can find the license in the LICENSE file in the root directory of this source tree.

util/lars.py
CHANGED
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Author: Gaojian Wang@ZJUICSR
+# Author: Gaojian Wang@ZJUICSR; TongWu@ZJUICSR
 # --------------------------------------------------------
 # This source code is licensed under the Attribution-NonCommercial 4.0 International License.
 # You can find the license in the LICENSE file in the root directory of this source tree.

util/loss_contrastive.py
CHANGED
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Author: Gaojian Wang@ZJUICSR
+# Author: Gaojian Wang@ZJUICSR; TongWu@ZJUICSR
 # --------------------------------------------------------
 # This source code is licensed under the Attribution-NonCommercial 4.0 International License.
 # You can find the license in the LICENSE file in the root directory of this source tree.

util/lr_decay.py
CHANGED
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Author: Gaojian Wang@ZJUICSR
+# Author: Gaojian Wang@ZJUICSR; TongWu@ZJUICSR
 # --------------------------------------------------------
 # This source code is licensed under the Attribution-NonCommercial 4.0 International License.
 # You can find the license in the LICENSE file in the root directory of this source tree.

util/lr_sched.py
CHANGED
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Author: Gaojian Wang@ZJUICSR
+# Author: Gaojian Wang@ZJUICSR; TongWu@ZJUICSR
 # --------------------------------------------------------
 # This source code is licensed under the Attribution-NonCommercial 4.0 International License.
 # You can find the license in the LICENSE file in the root directory of this source tree.
util/metrics.py
CHANGED
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Author: Gaojian Wang@ZJUICSR
+# Author: Gaojian Wang@ZJUICSR; TongWu@ZJUICSR
 # --------------------------------------------------------
 # This source code is licensed under the Attribution-NonCommercial 4.0 International License.
 # You can find the license in the LICENSE file in the root directory of this source tree.
@@ -68,6 +68,34 @@ def get_video_level_label_pred(f_label_list, v_name_list, f_pred_list):
     return video_label_list, video_pred_list, video_y_pred_list


+def get_video_level_label_pred_multi_class(f_label_list, v_name_list, f_pred_list):
+    import numpy as np
+    """
+    Adapted for multi-class predictions.
+    """
+    video_res_dict = dict()
+    video_pred_list = list()
+    video_y_pred_list = list()
+    video_label_list = list()
+
+    # Summarize all the results for each video
+    for label, video, score in zip(f_label_list, v_name_list, f_pred_list):
+        if video not in video_res_dict.keys():
+            video_res_dict[video] = {"scores": [score], "label": label}
+        else:
+            video_res_dict[video]["scores"].append(score)
+
+    # Get the score and label for each video
+    for video, res in video_res_dict.items():
+        avg_score = np.mean(res['scores'], axis=0)
+        label = res['label']
+        video_pred_list.append(avg_score)
+        video_label_list.append(label)
+        video_y_pred_list.append(np.argmax(avg_score))
+
+    return video_label_list, video_pred_list, video_y_pred_list
+
+
 def video_level_acc(video_label_list, video_y_pred_list):
     return accuracy_score(video_label_list, video_y_pred_list) * 100.

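To make the aggregation in get_video_level_label_pred_multi_class concrete, a minimal sketch with made-up frame scores for two videos and 4 classes; the names and numbers below are illustrative only.

# Illustrative toy input (assumed values): two frames of 'vid_a', one frame of 'vid_b',
# each frame score is a 4-class softmax vector as produced by test_multi_class.
f_label_list = [0, 0, 2]
v_name_list = ['vid_a', 'vid_a', 'vid_b']
f_pred_list = [[0.70, 0.10, 0.10, 0.10],
               [0.50, 0.30, 0.10, 0.10],
               [0.05, 0.05, 0.80, 0.10]]

v_labels, v_preds, v_y_preds = get_video_level_label_pred_multi_class(f_label_list, v_name_list, f_pred_list)
# v_labels  -> [0, 2]                        (label of the first frame seen per video)
# v_preds   -> [[0.60, 0.20, 0.10, 0.10],    (mean of each video's frame score vectors)
#               [0.05, 0.05, 0.80, 0.10]]
# v_y_preds -> [0, 2]                        (argmax of the averaged vector)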
util/misc.py
CHANGED
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Author: Gaojian Wang@ZJUICSR
+# Author: Gaojian Wang@ZJUICSR; TongWu@ZJUICSR
 # --------------------------------------------------------
 # This source code is licensed under the Attribution-NonCommercial 4.0 International License.
 # You can find the license in the LICENSE file in the root directory of this source tree.

util/pos_embed.py
CHANGED
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Author: Gaojian Wang@ZJUICSR
+# Author: Gaojian Wang@ZJUICSR; TongWu@ZJUICSR
 # --------------------------------------------------------
 # This source code is licensed under the Attribution-NonCommercial 4.0 International License.
 # You can find the license in the LICENSE file in the root directory of this source tree.