wolo-wolo committed
Commit 4d10ed1 · 1 Parent(s): e46e042
app.py CHANGED
@@ -1,220 +1,114 @@
1
  # -*- coding: utf-8 -*-
2
- # Author: Gaojian Wang@ZJUICSR
3
  # --------------------------------------------------------
4
  # This source code is licensed under the Attribution-NonCommercial 4.0 International License.
5
  # You can find the license in the LICENSE file in the root directory of this source tree.
6
  # --------------------------------------------------------
7
- # pip uninstall nvidia_cublas_cu11
8
 
9
  import sys
10
-
11
- sys.path.append('..')
12
  import os
13
-
14
  os.system(f'pip install dlib')
15
- import torch
 
16
  import numpy as np
17
  from PIL import Image
18
- from torch.nn import functional as F
19
-
 
20
  import gradio as gr
21
 
22
  import models_vit
23
  from util.datasets import build_dataset
24
- import argparse
25
- from engine_finetune import test_all
26
- import dlib
27
- from huggingface_hub import hf_hub_download
28
-
29
- P = os.path.abspath(__file__)
30
- FRAME_SAVE_PATH = os.path.join(P[:-6], 'frame')
31
- CKPT_SAVE_PATH = os.path.join(P[:-6], 'checkpoints')
32
- CKPT_LIST = ['DfD-Checkpoint_Fine-tuned_on_FF++',
33
- 'FAS-Checkpoint_Fine-tuned_on_MCIO']
34
- CKPT_NAME = {'DfD-Checkpoint_Fine-tuned_on_FF++': 'finetuned_models/FF++_c23_32frames/checkpoint-min_val_loss.pth',
35
- 'FAS-Checkpoint_Fine-tuned_on_MCIO': 'finetuned_models/MCIO_protocol/Both_MCIO/checkpoint-min_val_loss.pth'}
36
- os.makedirs(FRAME_SAVE_PATH, exist_ok=True)
37
- os.makedirs(CKPT_SAVE_PATH, exist_ok=True)
38
 
39
 
40
  def get_args_parser():
41
- parser = argparse.ArgumentParser('MAE fine-tuning for image classification', add_help=False)
42
- parser.add_argument('--batch_size', default=64, type=int,
43
- help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus')
44
  parser.add_argument('--epochs', default=50, type=int)
45
- parser.add_argument('--accum_iter', default=1, type=int,
46
- help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
47
-
48
- # Model parameters
49
- parser.add_argument('--model', default='vit_large_patch16', type=str, metavar='MODEL',
50
- help='Name of model to train')
51
-
52
- parser.add_argument('--input_size', default=224, type=int,
53
- help='images input size')
54
- parser.add_argument('--normalize_from_IMN', action='store_true',
55
- help='cal mean and std from imagenet, else from pretrain datasets')
56
  parser.set_defaults(normalize_from_IMN=True)
57
- parser.add_argument('--apply_simple_augment', action='store_true',
58
- help='apply simple data augment')
59
-
60
- parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT',
61
- help='Drop path rate (default: 0.1)')
62
-
63
- # Optimizer parameters
64
- parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
65
- help='Clip gradient norm (default: None, no clipping)')
66
- parser.add_argument('--weight_decay', type=float, default=0.05,
67
- help='weight decay (default: 0.05)')
68
-
69
- parser.add_argument('--lr', type=float, default=None, metavar='LR',
70
- help='learning rate (absolute lr)')
71
- parser.add_argument('--blr', type=float, default=1e-3, metavar='LR',
72
- help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
73
- parser.add_argument('--layer_decay', type=float, default=0.75,
74
- help='layer-wise lr decay from ELECTRA/BEiT')
75
-
76
- parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',
77
- help='lower lr bound for cyclic schedulers that hit 0')
78
-
79
- parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N',
80
- help='epochs to warmup LR')
81
-
82
- # Augmentation parameters
83
- parser.add_argument('--color_jitter', type=float, default=None, metavar='PCT',
84
- help='Color jitter factor (enabled only when not using Auto/RandAug)')
85
- parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
86
- help='Use AutoAugment policy. "v0" or "original". " + "(default: rand-m9-mstd0.5-inc1)'),
87
- parser.add_argument('--smoothing', type=float, default=0.1,
88
- help='Label smoothing (default: 0.1)')
89
-
90
- # * Random Erase params
91
- parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
92
- help='Random erase prob (default: 0.25)')
93
- parser.add_argument('--remode', type=str, default='pixel',
94
- help='Random erase mode (default: "pixel")')
95
- parser.add_argument('--recount', type=int, default=1,
96
- help='Random erase count (default: 1)')
97
- parser.add_argument('--resplit', action='store_true', default=False,
98
- help='Do not random erase first (clean) augmentation split')
99
-
100
- # * Mixup params
101
- parser.add_argument('--mixup', type=float, default=0,
102
- help='mixup alpha, mixup enabled if > 0.')
103
- parser.add_argument('--cutmix', type=float, default=0,
104
- help='cutmix alpha, cutmix enabled if > 0.')
105
- parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None,
106
- help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
107
- parser.add_argument('--mixup_prob', type=float, default=1.0,
108
- help='Probability of performing mixup or cutmix when either/both is enabled')
109
- parser.add_argument('--mixup_switch_prob', type=float, default=0.5,
110
- help='Probability of switching to cutmix when both mixup and cutmix enabled')
111
- parser.add_argument('--mixup_mode', type=str, default='batch',
112
- help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
113
-
114
- # * Finetuning params
115
- parser.add_argument('--finetune', default='',
116
- help='finetune from checkpoint')
117
  parser.add_argument('--global_pool', action='store_true')
118
  parser.set_defaults(global_pool=True)
119
- parser.add_argument('--cls_token', action='store_false', dest='global_pool',
120
- help='Use class token instead of global pool for classification')
121
-
122
- # Dataset parameters
123
- parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str,
124
- help='dataset path')
125
- parser.add_argument('--nb_classes', default=1000, type=int,
126
- help='number of the classification types')
127
-
128
- parser.add_argument('--output_dir', default='',
129
- help='path where to save, empty for no saving')
130
- parser.add_argument('--log_dir', default='',
131
- help='path where to tensorboard log')
132
- parser.add_argument('--device', default='cuda',
133
- help='device to use for training / testing')
134
  parser.add_argument('--seed', default=0, type=int)
135
- parser.add_argument('--resume', default='',
136
- help='resume from checkpoint')
137
-
138
- parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
139
- help='start epoch')
140
- parser.add_argument('--eval', action='store_true',
141
- help='Perform evaluation only')
142
  parser.set_defaults(eval=True)
143
- parser.add_argument('--dist_eval', action='store_true', default=False,
144
- help='Enabling distributed evaluation (recommended during training for faster monitor')
145
  parser.add_argument('--num_workers', default=10, type=int)
146
- parser.add_argument('--pin_mem', action='store_true',
147
- help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
148
  parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
149
  parser.set_defaults(pin_mem=True)
150
-
151
- # distributed training parameters
152
- parser.add_argument('--world_size', default=1, type=int,
153
- help='number of distributed processes')
154
  parser.add_argument('--local_rank', default=-1, type=int)
155
  parser.add_argument('--dist_on_itp', action='store_true')
156
- parser.add_argument('--dist_url', default='env://',
157
- help='url used to set up distributed training')
158
-
159
  return parser
160
 
161
 
162
- args = get_args_parser()
163
- args = args.parse_args()
164
- args.nb_classes = 2
165
-
166
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
167
-
168
- model = models_vit.__dict__['vit_base_patch16'](
169
- num_classes=args.nb_classes,
170
- drop_path_rate=args.drop_path,
171
- global_pool=args.global_pool,
172
- ).to(device)
173
-
174
-
175
- def load_model(ckpt):
176
- if ckpt == 'choose from here' or 'continuously updating...':
177
- return gr.update()
178
- args.resume = os.path.join(CKPT_SAVE_PATH, CKPT_NAME[ckpt])
179
- if os.path.isfile(args.resume) == False:
180
- hf_hub_download(local_dir=CKPT_SAVE_PATH,
181
- repo_id='Wolowolo/fsfm-3c/' + CKPT_NAME[ckpt],
182
- filename=ckpt)
183
- checkpoint = torch.load(args.resume, map_location='cpu')
184
- model.load_state_dict(checkpoint['model'])
185
  model.eval()
186
- return gr.update()
187
 
188
 
189
  def get_boundingbox(face, width, height, minsize=None):
190
- """
191
- From FF++:
192
- https://github.com/ondyari/FaceForensics/blob/master/classification/detect_from_video.py
193
- Expects a dlib face to generate a quadratic bounding box.
194
- :param face: dlib face class
195
- :param width: frame width
196
- :param height: frame height
197
- :param cfg.face_scale: bounding box size multiplier to get a bigger face region
198
- :param minsize: set minimum bounding box size
199
- :return: x, y, bounding_box_size in opencv form
200
- """
201
- x1 = face.left()
202
- y1 = face.top()
203
- x2 = face.right()
204
- y2 = face.bottom()
205
  size_bb = int(max(x2 - x1, y2 - y1) * 1.3)
206
- if minsize:
207
- if size_bb < minsize:
208
- size_bb = minsize
209
  center_x, center_y = (x1 + x2) // 2, (y1 + y2) // 2
210
-
211
- # Check for out of bounds, x-y top left corner
212
- x1 = max(int(center_x - size_bb // 2), 0)
213
- y1 = max(int(center_y - size_bb // 2), 0)
214
- # Check for too big bb size for given x, y
215
  size_bb = min(width - x1, size_bb)
216
  size_bb = min(height - y1, size_bb)
217
-
218
  return x1, y1, size_bb
219
 
220
 
@@ -222,200 +116,226 @@ def extract_face(frame):
222
  face_detector = dlib.get_frontal_face_detector()
223
  image = np.array(frame.convert('RGB'))
224
  faces = face_detector(image, 1)
225
- if len(faces) > 0:
226
- # For now only take the biggest face
227
  face = faces[0]
228
- # Face crop and rescale(follow FF++)
229
  x, y, size = get_boundingbox(face, image.shape[1], image.shape[0])
230
- # Get the landmarks/parts for the face in box d only with the five key points
231
  cropped_face = image[y:y + size, x:x + size]
232
- # cropped_face = cv2.resize(cropped_face, (224, 224), interpolation=cv2.INTER_CUBIC)
233
  return Image.fromarray(cropped_face)
234
- else:
235
- return None
236
 
237
 
238
  def get_frame_index_uniform_sample(total_frame_num, extract_frame_num):
239
- interval = np.linspace(0, total_frame_num - 1, num=extract_frame_num, dtype=int)
240
- return interval.tolist()
241
-
242
-
243
- import cv2
244
 
245
 
246
- def extract_face_from_fixed_num_frames(src_video, dst_path, num_frames=None, device='cpu'):
247
- """
248
- 1) extract specific num of frames from videos in [1st(index 0) frame, last frame] with uniform sample interval
249
- 2) extract face from frame with specific enlarge size
250
- """
251
  video_capture = cv2.VideoCapture(src_video)
252
- total_frames = video_capture.get(7)
253
-
254
- # extract from the 1st(index 0) frame
255
- if num_frames is not None:
256
- frame_indices = get_frame_index_uniform_sample(total_frames, num_frames)
257
- else:
258
- frame_indices = range(int(total_frames))
259
-
260
  for frame_index in frame_indices:
261
  video_capture.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
262
  ret, frame = video_capture.read()
263
- image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
264
- img = extract_face(image)
265
- if img == None:
266
- continue
267
- img = img.resize((224, 224), Image.BICUBIC)
268
  if not ret:
269
  continue
270
- save_img_name = f"frame_{frame_index}.png"
271
-
272
- img.save(os.path.join(dst_path, '0', save_img_name))
273
- # cv2.imwrite(os.path.join(dst_path, '0', save_img_name), frame)
274
-
 
275
  video_capture.release()
276
- # cv2.destroyAllWindows()
277
  return frame_indices
278
 
279
 
280
- def FSFM3C_video_detection(video, ckpt_select_dropdown):
281
- # extract frames
282
- num_frames = 32
283
-
284
- files = os.listdir(FRAME_SAVE_PATH)
285
- num_files = len(files)
286
- frame_path = os.path.join(FRAME_SAVE_PATH, str(num_files))
287
  os.makedirs(frame_path, exist_ok=True)
288
  os.makedirs(os.path.join(frame_path, '0'), exist_ok=True)
289
- frame_indices = extract_face_from_fixed_num_frames(video, frame_path, num_frames=num_frames, device=device)
290
-
291
- args.data_path = frame_path
292
- args.batch_size = 32
293
- dataset_val = build_dataset(is_train=False, args=args)
294
- sampler_val = torch.utils.data.SequentialSampler(dataset_val)
295
- data_loader_val = torch.utils.data.DataLoader(
296
- dataset_val, sampler=sampler_val,
297
- batch_size=args.batch_size,
298
- num_workers=args.num_workers,
299
- pin_memory=args.pin_mem,
300
- drop_last=False
301
- )
302
-
303
- frame_preds_list, video_pred_list = test_all(data_loader_val, model, device)
304
-
305
- real_prob_frames = [round(1. - fake_score, 2) for fake_score in video_pred_list]
306
- frame_results = {f"frame_{frame}": f"{int(real_prob_frames[i] * 100)}%" for i, frame in enumerate(frame_indices)}
307
-
308
- real_prob_video = int(round(1. - (sum(video_pred_list) / len(video_pred_list)), 2) * 100)
309
- if real_prob_video > 50:
310
- result_message = "real" if 'FAS' not in ckpt_select_dropdown else 'spoof'
311
- else:
312
- result_message = "fake" if 'FAS' not in ckpt_select_dropdown else 'real'
313
- prob = 1 - real_prob_image if real_prob_video <= 50 else real_prob_video
314
- image_results = (f"The face in this image may be {result_message} with probability is {real_prob_image}%")
315
-
316
- video_results = (f"The face in this video may be {result_message} with probability {prob}")
317
-
318
- return video_results
319
-
320
-
321
- def FSFM3C_image_detection(image, ckpt_select_dropdown):
322
- files = os.listdir(FRAME_SAVE_PATH)
323
- num_files = len(files)
324
- frame_path = os.path.join(FRAME_SAVE_PATH, str(num_files))
325
- os.makedirs(frame_path, exist_ok=True)
326
- os.makedirs(os.path.join(frame_path, '0'), exist_ok=True)
327
-
328
- save_img_name = f"frame_0.png"
329
  img = extract_face(image)
330
  if img is None:
331
- return ['Invalid Input']
332
  img = img.resize((224, 224), Image.BICUBIC)
333
- img.save(os.path.join(frame_path, '0', save_img_name))
334
-
335
  args.data_path = frame_path
336
  args.batch_size = 1
337
  dataset_val = build_dataset(is_train=False, args=args)
338
  sampler_val = torch.utils.data.SequentialSampler(dataset_val)
339
- data_loader_val = torch.utils.data.DataLoader(
340
- dataset_val, sampler=sampler_val,
341
- batch_size=args.batch_size,
342
- num_workers=args.num_workers,
343
- pin_memory=args.pin_mem,
344
- drop_last=False
345
- )
346
-
347
- frame_preds_list, video_pred_list = test_all(data_loader_val, model, device)
348
-
349
- real_prob_image = int(round(1. - (sum(video_pred_list) / len(video_pred_list)), 2) * 100)
350
- if real_prob_image > 50:
351
- result_message = "real" if 'FAS' not in ckpt_select_dropdown else 'spoof'
352
- else:
353
- result_message = "fake" if 'FAS' not in ckpt_select_dropdown else 'real'
354
- prob = 1 - real_prob_image if real_prob_image <= 50 else real_prob_image
355
- image_results = (f"The face in this image may be {result_message} with probability is {real_prob_image}%")
356
-
357
- return image_results
358
-
359
-
360
- # WebUI
361
- with gr.Blocks() as demo:
362
- gr.HTML(
363
- "<h1 style='text-align: center;'>🦱 Real Facial Image&Video Detection <br> Against Face Forgery and Spoofing (Deepfake/Diffusion/Presentation-attacks)</h1>")
364
- gr.Markdown("### ---Powered by the fine-tuned model that is pre-trained from [FSFM-3C](https://fsfm-3c.github.io/)")
365
-
366
- gr.Markdown("### Release:")
367
-
368
- gr.Markdown("- <b>V1.0 [2024-12] (Current):</b> "
369
- "Create this page with basic detectors (simply fine-tuned models) that follow the paper implementation. "
370
- "<b>Notes:</b> Performance is limited because no any optimization of data, models, hyperparameters, etc. is done for downstream tasks. <br> "
371
- "<b>[TODO]: </b> Update practical models, and optimized interfaces, and provide more functions such as visualizations, a unified detector, and multi-modal diagnosis.")
372
-
373
- gr.Markdown(
374
- "> Please provide an <b>image</b> or a <b>video (<100s </b>, default to uniform sampling 32 frames)</b> and <b>select the model</b> for detection. <br>"
375
- "- <b>DfD-Checkpoint_Fine-tuned_on_FF++</b> for deepfake detection, FSFM VIT-B fine-tuned on the FF++_c23 dataset (train&val sets of 4 manipulations, 32 frames per video) <br>"
376
- "- <b>FAS-Checkpoint_Fine-tuned_on_MCIO</b> for face anti-spoofing, FSFM VIT-B fine-tuned on the MCIO datasets (2 frames per video) ")
377
-
378
- with gr.Column():
379
  ckpt_select_dropdown = gr.Dropdown(
380
- label="Select the Model Checkpoint for Detection (🖱️ below)",
381
- choices=['choose from here'] + CKPT_LIST + ['continuously updating...'],
 
382
  multiselect=False,
383
- value='choose from here',
384
  interactive=True,
385
  )
386
- with gr.Row(elem_classes="center-align"):
387
- with gr.Column(scale=5):
388
- gr.Markdown(
389
- "## Image Detection"
390
- )
391
- image = gr.Image(label="Upload/Capture/Paste your image", type="pil")
392
- image_submit_btn = gr.Button("Submit")
393
- output_results_image = gr.Textbox(label="Detection Result")
394
- with gr.Column(scale=5):
395
- gr.Markdown(
396
- "## Video Detection"
397
- )
398
- video = gr.Video(label="Upload/Capture your video")
399
- video_submit_btn = gr.Button("Submit")
400
- output_results_video = gr.Textbox(label="Detection Result")
401
 
 
 
 
 
 
402
  image_submit_btn.click(
403
  fn=FSFM3C_image_detection,
404
- inputs=[image, ckpt_select_dropdown],
405
  outputs=[output_results_image],
406
  )
407
  video_submit_btn.click(
408
  fn=FSFM3C_video_detection,
409
- inputs=[video, ckpt_select_dropdown],
410
  outputs=[output_results_video],
411
  )
412
- ckpt_select_dropdown.change(
413
- fn=load_model,
414
- inputs=[ckpt_select_dropdown],
415
- outputs=[ckpt_select_dropdown],
416
- )
417
 
418
  if __name__ == "__main__":
419
  gr.close_all()
420
  demo.queue()
421
- demo.launch()
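
One detail worth flagging in the removed load_model above: the guard "if ckpt == 'choose from here' or 'continuously updating...':" is always truthy, because the second operand is a bare non-empty string rather than a comparison, so the V0.1 function returned before ever loading a checkpoint. A membership test avoids this; the minimal sketch below illustrates the idea only (the new app.py further down instead checks "select_skpt not in CKPT_NAME").

# Minimal illustration (not the commit's code): reject placeholder dropdown
# entries with a membership test instead of the always-true `==` / `or` chain.
PLACEHOLDERS = ('choose from here', 'continuously updating...')

def should_skip_loading(ckpt_choice: str) -> bool:
    # True only when the dropdown still shows a placeholder entry
    return ckpt_choice in PLACEHOLDERS

print(should_skip_loading('choose from here'))                   # True
print(should_skip_loading('DfD-Checkpoint_Fine-tuned_on_FF++'))  # False
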
 
 
1
  # -*- coding: utf-8 -*-
2
+ # Author: Gaojian Wang@ZJUICSR; TongWu@ZJUICSR
3
  # --------------------------------------------------------
4
  # This source code is licensed under the Attribution-NonCommercial 4.0 International License.
5
  # You can find the license in the LICENSE file in the root directory of this source tree.
6
  # --------------------------------------------------------
 
7
 
8
  import sys
 
 
9
  import os
 
10
  os.system(f'pip install dlib')
11
+ import dlib
12
+ import argparse
13
  import numpy as np
14
  from PIL import Image
15
+ import cv2
16
+ import torch
17
+ from huggingface_hub import hf_hub_download
18
  import gradio as gr
19
 
20
  import models_vit
21
  from util.datasets import build_dataset
22
+ from engine_finetune import test_two_class, test_multi_class
23
 
24
 
25
  def get_args_parser():
26
+ parser = argparse.ArgumentParser('FSFM3C fine-tuning&Testing for image classification', add_help=False)
27
+ parser.add_argument('--batch_size', default=64, type=int, help='Batch size per GPU')
 
28
  parser.add_argument('--epochs', default=50, type=int)
29
+ parser.add_argument('--accum_iter', default=1, type=int, help='Accumulate gradient iterations')
30
+ parser.add_argument('--model', default='vit_large_patch16', type=str, metavar='MODEL', help='Name of model to train')
31
+ parser.add_argument('--input_size', default=224, type=int, help='images input size')
32
+ parser.add_argument('--normalize_from_IMN', action='store_true', help='cal mean and std from imagenet')
33
  parser.set_defaults(normalize_from_IMN=True)
34
+ parser.add_argument('--apply_simple_augment', action='store_true', help='apply simple data augment')
35
+ parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT', help='Drop path rate')
36
+ parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM', help='Clip gradient norm')
37
+ parser.add_argument('--weight_decay', type=float, default=0.05, help='weight decay')
38
+ parser.add_argument('--lr', type=float, default=None, metavar='LR', help='learning rate')
39
+ parser.add_argument('--blr', type=float, default=1e-3, metavar='LR', help='base learning rate')
40
+ parser.add_argument('--layer_decay', type=float, default=0.75, help='layer-wise lr decay')
41
+ parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR', help='lower lr bound')
42
+ parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N', help='epochs to warmup LR')
43
+ parser.add_argument('--color_jitter', type=float, default=None, metavar='PCT', help='Color jitter factor')
44
+ parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', help='Use AutoAugment policy')
45
+ parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing')
46
+ parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', help='Random erase prob')
47
+ parser.add_argument('--remode', type=str, default='pixel', help='Random erase mode')
48
+ parser.add_argument('--recount', type=int, default=1, help='Random erase count')
49
+ parser.add_argument('--resplit', action='store_true', default=False, help='Do not random erase first augmentation split')
50
+ parser.add_argument('--mixup', type=float, default=0, help='mixup alpha')
51
+ parser.add_argument('--cutmix', type=float, default=0, help='cutmix alpha')
52
+ parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None, help='cutmix min/max ratio')
53
+ parser.add_argument('--mixup_prob', type=float, default=1.0, help='Probability of performing mixup or cutmix')
54
+ parser.add_argument('--mixup_switch_prob', type=float, default=0.5, help='Probability of switching to cutmix')
55
+ parser.add_argument('--mixup_mode', type=str, default='batch', help='How to apply mixup/cutmix params')
56
+ parser.add_argument('--finetune', default='', help='finetune from checkpoint')
57
  parser.add_argument('--global_pool', action='store_true')
58
  parser.set_defaults(global_pool=True)
59
+ parser.add_argument('--cls_token', action='store_false', dest='global_pool', help='Use class token for classification')
60
+ parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str, help='dataset path')
61
+ parser.add_argument('--nb_classes', default=1000, type=int, help='number of the classification types')
62
+ parser.add_argument('--output_dir', default='', help='path where to save')
63
+ parser.add_argument('--log_dir', default='', help='path where to tensorboard log')
64
+ parser.add_argument('--device', default='cuda', help='device to use for training / testing')
65
  parser.add_argument('--seed', default=0, type=int)
66
+ parser.add_argument('--resume', default='', help='resume from checkpoint')
67
+ parser.add_argument('--start_epoch', default=0, type=int, metavar='N', help='start epoch')
68
+ parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
69
  parser.set_defaults(eval=True)
70
+ parser.add_argument('--dist_eval', action='store_true', default=False, help='Enabling distributed evaluation')
 
71
  parser.add_argument('--num_workers', default=10, type=int)
72
+ parser.add_argument('--pin_mem', action='store_true', help='Pin CPU memory in DataLoader')
 
73
  parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
74
  parser.set_defaults(pin_mem=True)
75
+ parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
76
  parser.add_argument('--local_rank', default=-1, type=int)
77
  parser.add_argument('--dist_on_itp', action='store_true')
78
+ parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
79
  return parser
80
 
81
 
82
+ def load_model(select_skpt):
83
+ global ckpt, device, model, checkpoint
84
+ if select_skpt not in CKPT_NAME:
85
+ return gr.update(), "Select a correct model"
86
+ ckpt = select_skpt
87
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
88
+ args.nb_classes = CKPT_CLASS[ckpt]
89
+ model = models_vit.__dict__[CKPT_MODEL[ckpt]](
90
+ num_classes=args.nb_classes,
91
+ drop_path_rate=args.drop_path,
92
+ global_pool=args.global_pool,
93
+ ).to(device)
94
+
95
+ args.resume = os.path.join(CKPT_SAVE_PATH, ckpt)
96
+ args.resume = CKPT_PATH[ckpt]
97
+ checkpoint = torch.load(args.resume, map_location=device)
98
+ model.load_state_dict(checkpoint['model'], strict=False)
99
  model.eval()
100
+ return gr.update(), f"[Loaded Model Successfully: {args.resume}]"
101
 
102
 
103
  def get_boundingbox(face, width, height, minsize=None):
104
+ x1, y1, x2, y2 = face.left(), face.top(), face.right(), face.bottom()
105
  size_bb = int(max(x2 - x1, y2 - y1) * 1.3)
106
+ if minsize and size_bb < minsize:
107
+ size_bb = minsize
 
108
  center_x, center_y = (x1 + x2) // 2, (y1 + y2) // 2
109
+ x1, y1 = max(int(center_x - size_bb // 2), 0), max(int(center_y - size_bb // 2), 0)
110
  size_bb = min(width - x1, size_bb)
111
  size_bb = min(height - y1, size_bb)
 
112
  return x1, y1, size_bb
113
 
114
 
 
116
  face_detector = dlib.get_frontal_face_detector()
117
  image = np.array(frame.convert('RGB'))
118
  faces = face_detector(image, 1)
119
+ if faces:
 
120
  face = faces[0]
 
121
  x, y, size = get_boundingbox(face, image.shape[1], image.shape[0])
 
122
  cropped_face = image[y:y + size, x:x + size]
 
123
  return Image.fromarray(cropped_face)
124
+ return None
 
125
 
126
 
127
  def get_frame_index_uniform_sample(total_frame_num, extract_frame_num):
128
+ return np.linspace(0, total_frame_num - 1, num=extract_frame_num, dtype=int).tolist()
129
 
130
 
131
+ def extract_face_from_fixed_num_frames(src_video, dst_path, num_frames=None):
132
  video_capture = cv2.VideoCapture(src_video)
133
+ total_frames = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))
134
+ frame_indices = get_frame_index_uniform_sample(total_frames, num_frames) if num_frames else range(total_frames)
135
  for frame_index in frame_indices:
136
  video_capture.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
137
  ret, frame = video_capture.read()
138
  if not ret:
139
  continue
140
+ image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
141
+ img = extract_face(image)
142
+ if img:
143
+ img = img.resize((224, 224), Image.BICUBIC)
144
+ save_img_name = f"frame_{frame_index}.png"
145
+ img.save(os.path.join(dst_path, '0', save_img_name))
146
  video_capture.release()
 
147
  return frame_indices
148
 
149
 
150
+ def FSFM3C_image_detection(image):
151
+ frame_path = os.path.join(FRAME_SAVE_PATH, str(len(os.listdir(FRAME_SAVE_PATH))))
152
  os.makedirs(frame_path, exist_ok=True)
153
  os.makedirs(os.path.join(frame_path, '0'), exist_ok=True)
154
  img = extract_face(image)
155
  if img is None:
156
+ return 'No face detected, please upload a clear face!'
157
  img = img.resize((224, 224), Image.BICUBIC)
158
+ img.save(os.path.join(frame_path, '0', "frame_0.png"))
 
159
  args.data_path = frame_path
160
  args.batch_size = 1
161
  dataset_val = build_dataset(is_train=False, args=args)
162
  sampler_val = torch.utils.data.SequentialSampler(dataset_val)
163
+ data_loader_val = torch.utils.data.DataLoader(dataset_val, sampler=sampler_val, batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=args.pin_mem, drop_last=False)
164
+
165
+ if CKPT_CLASS[ckpt] > 2:
166
+ frame_preds_list, video_pred_list = test_multi_class(data_loader_val, model, device)
167
+ class_names = ['Real or Bonafide', 'Deepfake', 'Diffusion or AIGC generated', 'Spoofing or Presentation-attack']
168
+ avg_video_pred = np.mean(video_pred_list, axis=0)
169
+ max_prob_index = np.argmax(avg_video_pred)
170
+ max_prob_class = class_names[max_prob_index]
171
+ probabilities = [f"{class_names[i]}: {prob * 100:.1f}%" for i, prob in enumerate(avg_video_pred)]
172
+ image_results = f"The largest face in this image may be {max_prob_class} with probability: \n [{', '.join(probabilities)}]"
173
+ return image_results
174
+
175
+ if CKPT_CLASS[ckpt] == 2:
176
+ frame_preds_list, video_pred_list = test_two_class(data_loader_val, model, device)
177
+ if ckpt == 'DfD-Checkpoint_Fine-tuned_on_FF++':
178
+ prob = sum(video_pred_list) / len(video_pred_list)
179
+ label = "Deepfake" if prob <= 0.5 else "Real"
180
+ prob = prob if label == "Real" else 1 - prob
181
+ if ckpt == 'FAS-Checkpoint_Fine-tuned_on_MCIO':
182
+ prob = sum(video_pred_list) / len(video_pred_list)
183
+ label = "Spoofing" if prob <= 0.5 else "Bonafide"
184
+ prob = prob if label == "Bonafide" else 1 - prob
185
+ image_results = f"The largest face in this image may be {label} with probability {prob * 100:.1f}%"
186
+ return image_results
187
+
188
+
189
+ def FSFM3C_video_detection(video, num_frames):
190
+ try:
191
+ frame_path = os.path.join(FRAME_SAVE_PATH, str(len(os.listdir(FRAME_SAVE_PATH))))
192
+ os.makedirs(frame_path, exist_ok=True)
193
+ os.makedirs(os.path.join(frame_path, '0'), exist_ok=True)
194
+ frame_indices = extract_face_from_fixed_num_frames(video, frame_path, num_frames=num_frames)
195
+ args.data_path = frame_path
196
+ args.batch_size = num_frames
197
+ dataset_val = build_dataset(is_train=False, args=args)
198
+ sampler_val = torch.utils.data.SequentialSampler(dataset_val)
199
+ data_loader_val = torch.utils.data.DataLoader(dataset_val, sampler=sampler_val, batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=args.pin_mem, drop_last=False)
200
+
201
+ if CKPT_CLASS[ckpt] > 2:
202
+ frame_preds_list, video_pred_list = test_multi_class(data_loader_val, model, device)
203
+ class_names = ['Real or Bonafide', 'Deepfake', 'Diffusion or AIGC generated', 'Spoofing or Presentation-attack']
204
+ avg_video_pred = np.mean(video_pred_list, axis=0)
205
+ max_prob_index = np.argmax(avg_video_pred)
206
+ max_prob_class = class_names[max_prob_index]
207
+ probabilities = [f"{class_names[i]}: {prob * 100:.1f}%" for i, prob in enumerate(avg_video_pred)]
208
+
209
+ frame_results = {f"frame_{frame_indices[i]}": [f"{class_names[j]}: {prob * 100:.1f}%" for j, prob in enumerate(frame_preds_list[i])] for i in range(len(frame_indices))}
210
+ video_results = (f"The largest face in this video may be {max_prob_class} with probability: \n [{', '.join(probabilities)}]\n \n"
211
+ f"The frame-level detection results ['frame_index': 'probabilities']: \n{frame_results}")
212
+ return video_results
213
+
214
+ if CKPT_CLASS[ckpt] == 2:
215
+ frame_preds_list, video_pred_list = test_two_class(data_loader_val, model, device)
216
+ if ckpt == 'DfD-Checkpoint_Fine-tuned_on_FF++':
217
+ prob = sum(video_pred_list) / len(video_pred_list)
218
+ label = "Deepfake" if prob <= 0.5 else "Real"
219
+ prob = prob if label == "Real" else 1 - prob
220
+ frame_results = {f"frame_{frame_indices[i]}": f"{(frame_preds_list[i]) * 100:.1f}%" for i in
221
+ range(len(frame_indices))} if label == "Real" else {f"frame_{frame_indices[i]}": f"{(1 - frame_preds_list[i]) * 100:.1f}%" for i in
222
+ range(len(frame_indices))}
223
+
224
+ if ckpt == 'FAS-Checkpoint_Fine-tuned_on_MCIO':
225
+ prob = sum(video_pred_list) / len(video_pred_list)
226
+ label = "Spoofing" if prob <= 0.5 else "Bonafide"
227
+ prob = prob if label == "Bonafide" else 1 - prob
228
+ frame_results = {f"frame_{frame_indices[i]}": f"{(frame_preds_list[i]) * 100:.1f}%" for i in
229
+ range(len(frame_indices))} if label == "Bonafide" else {f"frame_{frame_indices[i]}": f"{(1 - frame_preds_list[i]) * 100:.1f}%" for i in
230
+ range(len(frame_indices))}
231
+
232
+ video_results = (f"The largest face in this video may be {label} with probability {prob * 100:.1f}%\n \n"
233
+ f"The frame-level detection results ['frame_index': 'real_face_probability']: \n{frame_results}")
234
+ return video_results
235
+ except Exception as e:
236
+ return f"Error occurred. Please provide a clear face video or reduce the number of frames."
237
+
238
+ # Paths and Constants
239
+ P = os.path.abspath(__file__)
240
+ FRAME_SAVE_PATH = os.path.join(os.path.dirname(P), 'frame')
241
+ CKPT_SAVE_PATH = os.path.join(os.path.dirname(P), 'checkpoints')
242
+ os.makedirs(FRAME_SAVE_PATH, exist_ok=True)
243
+ os.makedirs(CKPT_SAVE_PATH, exist_ok=True)
244
+ CKPT_NAME = [
245
+ '✨Unified-detector_v1_Fine-tuned_on_4_classes',
246
+ 'DfD-Checkpoint_Fine-tuned_on_FF++',
247
+ 'FAS-Checkpoint_Fine-tuned_on_MCIO',
248
+ ]
249
+ # CKPT_PATH = {
250
+ # '✨Unified-detector_v1_Fine-tuned_on_4_classes': 'finetuned_models/Unified-detector/v1_Fine-tuned_on_4_classes/checkpoint-min_val_loss.pth',
251
+ # 'DfD-Checkpoint_Fine-tuned_on_FF++': 'finetuned_models/FF++_c23_32frames/checkpoint-min_val_loss.pth',
252
+ # 'FAS-Checkpoint_Fine-tuned_on_MCIO': 'finetuned_models/MCIO_protocol/Both_MCIO/checkpoint-min_val_loss.pth',
253
+ # }
254
+ CKPT_PATH = {
255
+ '✨Unified-detector_v1_Fine-tuned_on_4_classes': './checkpoints/checkpoint-min_train_loss.pth',
256
+ 'DfD-Checkpoint_Fine-tuned_on_FF++': '/mnt/localDisk2/wgj/FSFM/released/FSFM-main/fsfm-3c/finuetune/cross_dataset_DfD/checkpoint/finetuned_models/ft_on_FF++_c23_32frames/pt_from_VF2_ViT-B_epoch600/checkpoint-min_val_loss.pth',
257
+ 'FAS-Checkpoint_Fine-tuned_on_MCIO': '/mnt/localDisk2/wgj/FSFM/FSFM-3C/codespace/fsfm-3c/finuetune/cross_dataset_DfD/finetuned_models/FAS_MCIO/checkpoint-199.pth',
258
+ }
259
+ CKPT_CLASS = {
260
+ '✨Unified-detector_v1_Fine-tuned_on_4_classes': 4,
261
+ 'DfD-Checkpoint_Fine-tuned_on_FF++': 2,
262
+ 'FAS-Checkpoint_Fine-tuned_on_MCIO': 2
263
+ }
264
+ CKPT_MODEL = {
265
+ '✨Unified-detector_v1_Fine-tuned_on_4_classes': 'vit_base_patch16',
266
+ 'DfD-Checkpoint_Fine-tuned_on_FF++': 'vit_base_patch16',
267
+ 'FAS-Checkpoint_Fine-tuned_on_MCIO': 'vit_base_patch16',
268
+ }
269
+
270
+
271
+ with gr.Blocks(css=".custom-label { font-weight: bold !important; font-size: 16px !important; }") as demo:
272
+ gr.HTML("<h1 style='text-align: center;'>🦱 Real Facial Image&Video Detection <br> Against Face Forgery (Deepfake/Diffusion) and Spoofing (Presentation-attacks)</h1>")
273
+ gr.Markdown("<b>☉ Powered by fine-tuned models pre-trained with [FSFM-3C](https://fsfm-3c.github.io/)</b> <br> "
274
+ "<b>☉ Release (continuously updating)</b> <br> <b>[V1.0]</b> 2025/02/22-Current 🎉: "
275
+ "1) Added <b>[✨Unified-detector_v1] for Unified Physical-Digital Face Attack&Forgery Detection, a vanilla ViT-B/16-224 (FSFM pre-trained) that can identify Real&Bonafide, Deepfake, Diffusion&AIGC-generated, and Spoofing&Presentation-attack facial images or videos</b>; 2) Added selection of the number of video frames (uniform sampling; more frames take longer, and we would be grateful for support to enable paid GPU acceleration); 3) Fixed the V0.1 errors in model loading and prediction. <br>"
276
+ "<b>[V0.1]</b> 2024/12-2025/02/21: "
277
+ "Created this page with basic detectors [DfD-Checkpoint_Fine-tuned_on_FF++, FAS-Checkpoint_Fine-tuned_on_MCIO] that follow the paper implementation. <br> ")
278
+ gr.Markdown("- Please <b>provide a facial image or video (<100s)</b> and <b>select the model</b> for detection: <br> <b>[suggested] [✨Unified-detector_v1_Fine-tuned_on_4_classes]</b>: <b>an FSFM pre-trained ViT-B/16-224 that detects Real/Deepfake/Diffusion/Spoofing facial images&videos</b> <br> <b>[DfD-Checkpoint_Fine-tuned_on_FF++]</b> for deepfake detection, FSFM ViT-B/16-224 fine-tuned on the FF++_c23 train&val sets (4 manipulations, 32 frames per video) <br> <b>[FAS-Checkpoint_Fine-tuned_on_MCIO]</b> for face anti-spoofing, FSFM ViT-B/16-224 fine-tuned on the MCIO datasets (2 frames per video)")
279
+
280
+
281
+ with gr.Row():
282
  ckpt_select_dropdown = gr.Dropdown(
283
+ label="Select the Model for Detection ⬇️",
284
+ elem_classes="custom-label",
285
+ choices=['Choose Model Here 🖱️'] + CKPT_NAME + ['continuously updating...'],
286
  multiselect=False,
287
+ value='Choose Model Here 🖱️',
288
  interactive=True,
289
  )
290
+ model_loading_status = gr.Textbox(label="Model Loading Status")
291
+ with gr.Row():
292
+ with gr.Column(scale=5):
293
+ gr.Markdown("### Image Detection")
294
+ image = gr.Image(label="Upload/Capture/Paste your image", type="pil")
295
+ image_submit_btn = gr.Button("Submit")
296
+ output_results_image = gr.Textbox(label="Detection Result")
297
+ with gr.Column(scale=5):
298
+ gr.Markdown("### Video Detection")
299
+ video = gr.Video(label="Upload/Capture your video")
300
+ frame_slider = gr.Slider(minimum=1, maximum=32, step=1, value=32, label="Number of Frames for Detection")
301
+ video_submit_btn = gr.Button("Submit")
302
+ output_results_video = gr.Textbox(label="Detection Result")
 
 
303
 
304
+ ckpt_select_dropdown.change(
305
+ fn=load_model,
306
+ inputs=[ckpt_select_dropdown],
307
+ outputs=[ckpt_select_dropdown, model_loading_status],
308
+ )
309
  image_submit_btn.click(
310
  fn=FSFM3C_image_detection,
311
+ inputs=[image],
312
  outputs=[output_results_image],
313
  )
314
  video_submit_btn.click(
315
  fn=FSFM3C_video_detection,
316
+ inputs=[video, frame_slider],
317
  outputs=[output_results_video],
318
  )
319
+
320
 
321
  if __name__ == "__main__":
322
+ args = get_args_parser()
323
+ args = args.parse_args()
324
+ ckpt = 'DfD-Checkpoint_Fine-tuned_on_FF++'
325
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
326
+ args.nb_classes = CKPT_CLASS[ckpt]
327
+ model = models_vit.__dict__[CKPT_MODEL[ckpt]](
328
+ num_classes=args.nb_classes,
329
+ drop_path_rate=args.drop_path,
330
+ global_pool=args.global_pool,
331
+ ).to(device)
332
+ args.resume = os.path.join(CKPT_SAVE_PATH, ckpt)
333
+ args.resume = CKPT_PATH[ckpt]
334
+ checkpoint = torch.load(args.resume, map_location=device)
335
+ model.load_state_dict(checkpoint['model'], strict=False)
336
+ model.eval()
337
+
338
  gr.close_all()
339
  demo.queue()
340
+ # demo.launch()
341
+ demo.launch(server_name="0.0.0.0", server_port=8888)
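
Before moving to engine_finetune.py, here is a minimal, self-contained sketch (not part of the commit) of the two numeric helpers the new app.py relies on: the uniform frame sampling of get_frame_index_uniform_sample and the 1.3x square face-box enlargement of get_boundingbox. The dlib rectangle is mocked with a namedtuple whose fields are plain attributes rather than dlib's left()/top()/right()/bottom() methods, so treat it purely as an illustration of the arithmetic, not the detection pipeline.

from collections import namedtuple
import numpy as np

# Stand-in for a dlib rectangle; the real detector returns objects with
# left()/top()/right()/bottom() methods instead of plain attributes.
FaceRect = namedtuple('FaceRect', 'left top right bottom')


def uniform_frame_indices(total_frame_num, extract_frame_num):
    # Same idea as get_frame_index_uniform_sample: evenly spaced indices
    # from the first frame (index 0) to the last frame, inclusive.
    return np.linspace(0, total_frame_num - 1, num=extract_frame_num, dtype=int).tolist()


def enlarged_face_box(face, width, height, minsize=None):
    # Same idea as get_boundingbox: a square box 1.3x the larger side of the
    # detection, re-centred on the face, then clamped to the frame borders.
    x1, y1, x2, y2 = face.left, face.top, face.right, face.bottom
    size_bb = int(max(x2 - x1, y2 - y1) * 1.3)
    if minsize and size_bb < minsize:
        size_bb = minsize
    center_x, center_y = (x1 + x2) // 2, (y1 + y2) // 2
    x1 = max(int(center_x - size_bb // 2), 0)
    y1 = max(int(center_y - size_bb // 2), 0)
    size_bb = min(width - x1, size_bb)
    size_bb = min(height - y1, size_bb)
    return x1, y1, size_bb


if __name__ == '__main__':
    # 8 frames sampled from a 300-frame clip: [0, 42, 85, 128, 170, 213, 256, 299]
    print(uniform_frame_indices(total_frame_num=300, extract_frame_num=8))
    # A 120x140 detection becomes a 182-px square starting at (69, 99)
    print(enlarged_face_box(FaceRect(100, 120, 220, 260), width=640, height=480))
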
engine_finetune.py CHANGED
@@ -1,323 +1,130 @@
1
  # -*- coding: utf-8 -*-
2
- # Author: Gaojian Wang@ZJUICSR
3
  # --------------------------------------------------------
4
  # This source code is licensed under the Attribution-NonCommercial 4.0 International License.
5
  # You can find the license in the LICENSE file in the root directory of this source tree.
6
  # --------------------------------------------------------
7
 
8
- import math
9
- import sys
10
- from typing import Iterable, Optional
11
-
12
  import numpy as np
13
  import torch
14
-
15
- from timm.data import Mixup
16
- from timm.utils import accuracy
17
 
18
  import util.misc as misc
19
- import util.lr_sched as lr_sched
20
  from util.metrics import *
21
 
22
- import torch.nn.functional as F
23
-
24
-
25
- def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
26
- data_loader: Iterable, optimizer: torch.optim.Optimizer,
27
- device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
28
- mixup_fn: Optional[Mixup] = None, log_writer=None,
29
- args=None):
30
- model.train(True)
31
- metric_logger = misc.MetricLogger(delimiter=" ")
32
- metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
33
- header = 'Epoch: [{}]'.format(epoch)
34
- print_freq = 20
35
-
36
- accum_iter = args.accum_iter
37
-
38
- optimizer.zero_grad()
39
-
40
- if log_writer is not None:
41
- print('log_dir: {}'.format(log_writer.log_dir))
42
-
43
- for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
44
-
45
- # we use a per iteration (instead of per epoch) lr scheduler
46
- if data_iter_step % accum_iter == 0:
47
- lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)
48
-
49
- samples = samples.to(device, non_blocking=True)
50
- targets = targets.to(device, non_blocking=True)
51
-
52
- if mixup_fn is not None:
53
- samples, targets = mixup_fn(samples, targets)
54
-
55
- with torch.cuda.amp.autocast():
56
- # outputs = model(samples)
57
- outputs = model(samples).to(device, non_blocking=True) # modified
58
- loss = criterion(outputs, targets)
59
-
60
- loss_value = loss.item()
61
-
62
- if not math.isfinite(loss_value):
63
- print("Loss is {}, stopping training".format(loss_value))
64
- sys.exit(1)
65
-
66
- loss /= accum_iter
67
- loss_scaler(loss, optimizer, clip_grad=max_norm,
68
- parameters=model.parameters(), create_graph=False,
69
- update_grad=(data_iter_step + 1) % accum_iter == 0)
70
- if (data_iter_step + 1) % accum_iter == 0:
71
- optimizer.zero_grad()
72
-
73
- torch.cuda.synchronize()
74
-
75
- metric_logger.update(loss=loss_value)
76
- min_lr = 10.
77
- max_lr = 0.
78
- for group in optimizer.param_groups:
79
- min_lr = min(min_lr, group["lr"])
80
- max_lr = max(max_lr, group["lr"])
81
-
82
- metric_logger.update(lr=max_lr)
83
-
84
- loss_value_reduce = misc.all_reduce_mean(loss_value)
85
- if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
86
- """ We use epoch_1000x as the x-axis in tensorboard.
87
- This calibrates different curves when batch size changes.
88
- """
89
- epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
90
- log_writer.add_scalar('loss', loss_value_reduce, epoch_1000x)
91
- log_writer.add_scalar('lr', max_lr, epoch_1000x)
92
-
93
- # gather the stats from all processes
94
- metric_logger.synchronize_between_processes()
95
- print("Averaged stats:", metric_logger)
96
- return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
97
-
98
-
99
- @torch.no_grad()
100
- def evaluate(data_loader, model, device):
101
- criterion = torch.nn.CrossEntropyLoss()
102
-
103
- metric_logger = misc.MetricLogger(delimiter=" ")
104
- header = 'Test:'
105
-
106
- # switch to evaluation mode
107
- model.eval()
108
-
109
- for batch in metric_logger.log_every(data_loader, 10, header):
110
- images = batch[0]
111
- target = batch[-1]
112
- images = images.to(device, non_blocking=True)
113
- target = target.to(device, non_blocking=True)
114
-
115
- # compute output
116
- with torch.cuda.amp.autocast():
117
- # output = model(images)
118
- output = model(images).to(device, non_blocking=True) # modified
119
- loss = criterion(output, target)
120
-
121
- # acc1, acc5 = accuracy(output, target, topk=(1, 5))
122
- acc = float(accuracy(output, target, topk=(1,))[0])
123
- preds = (F.softmax(output, dim=1)[:, 1].detach().cpu().numpy())
124
- trues = (target.detach().cpu().numpy())
125
- auc_score = roc_auc_score(trues, preds) * 100.
126
-
127
- batch_size = images.shape[0]
128
- metric_logger.update(loss=loss.item())
129
- # metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
130
- # metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
131
- metric_logger.meters['acc'].update(acc, n=batch_size)
132
- metric_logger.meters['auc'].update(auc_score, n=batch_size)
133
-
134
- # gather the stats from all processes
135
- metric_logger.synchronize_between_processes()
136
- # print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
137
- # .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
138
- print('* Acc {acc.global_avg:.3f} Auc {auc.global_avg:.3f} loss {losses.global_avg:.3f}'
139
- .format(acc=metric_logger.acc, auc=metric_logger.auc, losses=metric_logger.loss))
140
-
141
- return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
142
-
143
-
144
- @torch.no_grad()
145
- def test_ori(data_loader, model, device):
146
- criterion = torch.nn.CrossEntropyLoss()
147
-
148
- metric_logger = misc.MetricLogger(delimiter=" ")
149
- header = 'Test:'
150
-
151
- # switch to evaluation mode
152
- model.eval()
153
-
154
- labels = np.array([])
155
- preds = np.array([])
156
-
157
- for batch in metric_logger.log_every(data_loader, 10, header):
158
- images = batch[0]
159
- target = batch[-1]
160
- images = images.to(device, non_blocking=True)
161
- target = target.to(device, non_blocking=True)
162
-
163
- # compute output
164
- with torch.cuda.amp.autocast():
165
- # output = model(images)
166
- output = model(images).to(device, non_blocking=True) # modified
167
- loss = criterion(output, target)
168
-
169
- # acc1, acc5 = accuracy(output, target, topk=(1, 5))
170
- acc = float(accuracy(output, target, topk=(1,))[0])
171
- pred = (F.softmax(output, dim=1)[:, 1].detach().cpu().numpy())
172
- preds = np.append(preds, pred)
173
- label = (target.detach().cpu().numpy())
174
- labels = np.append(labels, label)
175
-
176
- batch_size = images.shape[0]
177
- metric_logger.update(loss=loss.item())
178
- # metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
179
- # metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
180
- metric_logger.meters['acc'].update(acc, n=batch_size)
181
-
182
- # gather the stats from all processes
183
- metric_logger.synchronize_between_processes()
184
- # print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
185
- # .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
186
- auc_score = roc_auc_score(labels, preds) * 100.
187
- metric_logger.meters['auc'].update(auc_score)
188
- print('* Acc {acc.global_avg:.3f} Auc {auc.global_avg:.3f} loss {losses.global_avg:.3f}'
189
- .format(acc=metric_logger.acc, auc=metric_logger.auc, losses=metric_logger.loss))
190
-
191
- return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
192
-
193
 
194
  @torch.no_grad()
195
- def test(data_loader, model, device):
196
  criterion = torch.nn.CrossEntropyLoss()
197
 
198
- metric_logger = misc.MetricLogger(delimiter=" ")
199
- header = 'Test:'
200
-
201
  # switch to evaluation mode
202
  model.eval()
203
 
204
  frame_labels = np.array([]) # int label
205
  frame_preds = np.array([]) # pred logit
206
  frame_y_preds = np.array([]) # pred int
 
207
 
208
- # for batch in metric_logger.log_every(data_loader, print_freq=len(data_loader), header=header):
209
  for batch in data_loader:
210
  images = batch[0] # torch.Size([BS, C, H, W])
211
  target = batch[1] # torch.Size([BS])
212
-
213
  images = images.to(device, non_blocking=True)
214
  target = target.to(device, non_blocking=True)
215
 
216
- # compute output
217
- with torch.cuda.amp.autocast():
218
- # output = model(images)
219
- output = model(images).to(device, non_blocking=True) # modified
220
- loss = criterion(output, target)
221
 
222
  frame_pred = (F.softmax(output, dim=1)[:, 1].detach().cpu().numpy())
223
  frame_preds = np.append(frame_preds, frame_pred)
224
-
225
  frame_y_pred = np.argmax(output.detach().cpu().numpy(), axis=1)
226
  frame_y_preds = np.append(frame_y_preds, frame_y_pred)
227
 
228
  frame_label = (target.detach().cpu().numpy())
229
  frame_labels = np.append(frame_labels, frame_label)
 
230
 
231
- metric_logger.update(loss=loss.item())
232
-
233
- # gather the stats from all processes
234
- metric_logger.synchronize_between_processes()
235
- metric_logger.meters['frame_acc'].update(frame_level_acc(frame_labels, frame_y_preds))
236
- metric_logger.meters['frame_balanced_acc'].update(frame_level_balanced_acc(frame_labels, frame_y_preds))
237
- metric_logger.meters['frame_auc'].update(frame_level_auc(frame_labels, frame_preds))
238
- metric_logger.meters['frame_eer'].update(frame_level_eer(frame_labels, frame_preds))
239
-
240
- print('*[------FRAME-LEVEL------] \n'
241
- 'Acc {frame_acc.global_avg:.3f} Balanced_Acc {frame_balanced_acc.global_avg:.3f} '
242
- 'Auc {frame_auc.global_avg:.3f} EER {frame_eer.global_avg:.3f} loss {losses.global_avg:.3f}'
243
- .format(frame_acc=metric_logger.frame_acc, frame_balanced_acc=metric_logger.frame_balanced_acc,
244
- frame_auc=metric_logger.frame_auc, frame_eer=metric_logger.frame_eer, losses=metric_logger.loss))
245
 
246
- return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
247
 
248
 
249
  @torch.no_grad()
250
- def test_all(data_loader, model, device):
251
  criterion = torch.nn.CrossEntropyLoss()
252
 
253
- metric_logger = misc.MetricLogger(delimiter=" ")
254
- header = 'Test:'
255
-
256
  # switch to evaluation mode
257
  model.eval()
258
 
259
  frame_labels = np.array([]) # int label
260
- frame_preds = np.array([]) # pred logit
261
  frame_y_preds = np.array([]) # pred int
262
  video_names_list = list()
263
 
264
- # for batch in metric_logger.log_every(data_loader, print_freq=len(data_loader), header=header):
265
  for batch in data_loader:
266
  images = batch[0] # torch.Size([BS, C, H, W])
267
  target = batch[1] # torch.Size([BS])
268
  video_name = batch[-1] # list[BS]
269
-
270
  images = images.to(device, non_blocking=True)
271
  target = target.to(device, non_blocking=True)
272
 
273
- # compute output
274
- # with torch.cuda.amp.autocast():
275
- # output = model(images)
276
- output = model(images).to(device, non_blocking=True) # modified
277
  loss = criterion(output, target)
278
 
279
- frame_pred = (F.softmax(output, dim=1)[:, 1].detach().cpu().numpy())
280
- frame_preds = np.append(frame_preds, frame_pred)
281
-
282
  frame_y_pred = np.argmax(output.detach().cpu().numpy(), axis=1)
283
  frame_y_preds = np.append(frame_y_preds, frame_y_pred)
284
 
285
- frame_label = (target.detach().cpu().numpy())
286
  frame_labels = np.append(frame_labels, frame_label)
287
-
288
  video_names_list.extend(list(video_name))
289
 
290
- metric_logger.update(loss=loss.item())
291
-
292
- # gather the stats from all processes
293
- # metric_logger.synchronize_between_processes()
294
- # metric_logger.meters['frame_acc'].update(frame_level_acc(frame_labels, frame_y_preds))
295
- # metric_logger.meters['frame_balanced_acc'].update(frame_level_balanced_acc(frame_labels, frame_y_preds))
296
- # metric_logger.meters['frame_auc'].update(frame_level_auc(frame_labels, frame_preds))
297
- # metric_logger.meters['frame_eer'].update(frame_level_eer(frame_labels, frame_preds))
298
-
299
- # print('*[------FRAME-LEVEL------] \n'
300
- # 'Acc {frame_acc.global_avg:.3f} Balanced_Acc {frame_balanced_acc.global_avg:.3f} '
301
- # 'Auc {frame_auc.global_avg:.3f} EER {frame_eer.global_avg:.3f} loss {losses.global_avg:.3f}'
302
- # .format(frame_acc=metric_logger.frame_acc, frame_balanced_acc=metric_logger.frame_balanced_acc,
303
- # frame_auc=metric_logger.frame_auc, frame_eer=metric_logger.frame_eer, losses=metric_logger.loss))
304
-
305
  # video-level metrics:
306
  frame_labels_list = frame_labels.tolist()
307
  frame_preds_list = frame_preds.tolist()
 
308
 
309
- video_label_list, video_pred_list, video_y_pred_list = get_video_level_label_pred(frame_labels_list, video_names_list, frame_preds_list)
310
- # print(len(video_label_list), len(video_pred_list), len(video_y_pred_list))
311
- # metric_logger.meters['video_acc'].update(video_level_acc(video_label_list, video_y_pred_list))
312
- # metric_logger.meters['video_balanced_acc'].update(video_level_balanced_acc(video_label_list, video_y_pred_list))
313
- # metric_logger.meters['video_auc'].update(video_level_auc(video_label_list, video_pred_list))
314
- # metric_logger.meters['video_eer'].update(frame_level_eer(video_label_list, video_pred_list))
315
 
316
- # print('*[------VIDEO-LEVEL------] \n'
317
- # 'Acc {video_acc.global_avg:.3f} Balanced_Acc {video_balanced_acc.global_avg:.3f} '
318
- # 'Auc {video_auc.global_avg:.3f} EER {video_eer.global_avg:.3f}'
319
- # .format(video_acc=metric_logger.video_acc, video_balanced_acc=metric_logger.video_balanced_acc,
320
- # video_auc=metric_logger.video_auc, video_eer=metric_logger.video_eer))
321
 
322
- # return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
323
- return frame_preds_list, video_pred_list
1
  # -*- coding: utf-8 -*-
2
+ # Author: Gaojian Wang@ZJUICSR; TongWu@ZJUICSR
3
  # --------------------------------------------------------
4
  # This source code is licensed under the Attribution-NonCommercial 4.0 International License.
5
  # You can find the license in the LICENSE file in the root directory of this source tree.
6
  # --------------------------------------------------------
7
 
8
  import numpy as np
9
  import torch
10
+ import torch.nn.functional as F
 
 
11
 
12
  import util.misc as misc
 
13
  from util.metrics import *
14
 
15
 
16
  @torch.no_grad()
17
+ def test_two_class(data_loader, model, device):
18
  criterion = torch.nn.CrossEntropyLoss()
19
 
20
  # switch to evaluation mode
21
  model.eval()
22
 
23
  frame_labels = np.array([]) # int label
24
  frame_preds = np.array([]) # pred logit
25
  frame_y_preds = np.array([]) # pred int
26
+ video_names_list = list()
27
 
 
28
  for batch in data_loader:
29
  images = batch[0] # torch.Size([BS, C, H, W])
30
  target = batch[1] # torch.Size([BS])
31
+ video_name = batch[-1] # list[BS]
32
  images = images.to(device, non_blocking=True)
33
  target = target.to(device, non_blocking=True)
34
 
35
+ output = model(images).to(device, non_blocking=True) # modified
36
+ loss = criterion(output, target)
37
 
38
  frame_pred = (F.softmax(output, dim=1)[:, 1].detach().cpu().numpy())
39
  frame_preds = np.append(frame_preds, frame_pred)
 
40
  frame_y_pred = np.argmax(output.detach().cpu().numpy(), axis=1)
41
  frame_y_preds = np.append(frame_y_preds, frame_y_pred)
42
 
43
  frame_label = (target.detach().cpu().numpy())
44
  frame_labels = np.append(frame_labels, frame_label)
45
+ video_names_list.extend(list(video_name))
46
 
47
+ # video-level metrics:
48
+ frame_labels_list = frame_labels.tolist()
49
+ frame_preds_list = frame_preds.tolist()
50
+ video_label_list, video_pred_list, video_y_pred_list = get_video_level_label_pred(frame_labels_list, video_names_list, frame_preds_list)
51
 
52
+ return frame_preds_list, video_pred_list
53
 
54
 
55
  @torch.no_grad()
56
+ def test_multi_class(data_loader, model, device):
57
  criterion = torch.nn.CrossEntropyLoss()
58
 
 
 
 
59
  # switch to evaluation mode
60
  model.eval()
61
 
62
  frame_labels = np.array([]) # int label
63
+ frame_preds = np.empty((0, 4)) # pred logit, initialize as 2D array with 4 columns for 4 classes
64
  frame_y_preds = np.array([]) # pred int
65
  video_names_list = list()
66
 
 
67
  for batch in data_loader:
68
  images = batch[0] # torch.Size([BS, C, H, W])
69
  target = batch[1] # torch.Size([BS])
70
  video_name = batch[-1] # list[BS]
 
71
  images = images.to(device, non_blocking=True)
72
  target = target.to(device, non_blocking=True)
73
 
74
+ output = model(images).to(device, non_blocking=True)
75
  loss = criterion(output, target)
76
 
77
+ frame_pred = F.softmax(output, dim=1).detach().cpu().numpy()
78
+ frame_preds = np.append(frame_preds, frame_pred, axis=0)
 
79
  frame_y_pred = np.argmax(output.detach().cpu().numpy(), axis=1)
80
  frame_y_preds = np.append(frame_y_preds, frame_y_pred)
81
 
82
+ frame_label = target.detach().cpu().numpy()
83
  frame_labels = np.append(frame_labels, frame_label)
 
84
  video_names_list.extend(list(video_name))
85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  # video-level metrics:
87
  frame_labels_list = frame_labels.tolist()
88
  frame_preds_list = frame_preds.tolist()
89
+ video_label_list, video_pred_list, video_y_pred_list = get_video_level_label_pred_multi_class(frame_labels_list, video_names_list, frame_preds_list)
90
 
91
+ return frame_preds_list, video_pred_list
92
 
93
 
94
+ # @torch.no_grad()
95
+ # def test_multi_class(data_loader, model, device):
96
+ # criterion = torch.nn.CrossEntropyLoss()
97
+ #
98
+ # # switch to evaluation mode
99
+ # model.eval()
100
+ #
101
+ # frame_labels = np.array([]) # int label
102
+ # frame_preds = np.array([]) # pred logit
103
+ # frame_y_preds = np.array([]) # pred int
104
+ # video_names_list = list()
105
+ #
106
+ # for batch in data_loader:
107
+ # images = batch[0] # torch.Size([BS, C, H, W])
108
+ # target = batch[1] # torch.Size([BS])
109
+ # video_name = batch[-1] # list[BS]
110
+ # images = images.to(device, non_blocking=True)
111
+ # target = target.to(device, non_blocking=True)
112
+ #
113
+ # output = model(images).to(device, non_blocking=True)
114
+ # loss = criterion(output, target)
115
+ #
116
+ # frame_pred = F.softmax(output, dim=1).detach().cpu().numpy()
117
+ # frame_preds = np.append(frame_preds, frame_pred, axis=0)
118
+ # frame_y_pred = np.argmax(output.detach().cpu().numpy(), axis=1)
119
+ # frame_y_preds = np.append(frame_y_preds, frame_y_pred)
120
+ #
121
+ # frame_label = target.detach().cpu().numpy()
122
+ # frame_labels = np.append(frame_labels, frame_label)
123
+ # video_names_list.extend(list(video_name))
124
+ #
125
+ # # video-level metrics:
126
+ # frame_labels_list = frame_labels.tolist()
127
+ # frame_preds_list = frame_preds.tolist()
128
+ # video_label_list, video_pred_list, video_y_pred_list = get_video_level_label_pred_multi_class(frame_labels_list, video_names_list, frame_preds_list)
129
+ #
130
+ # return frame_preds_list, video_pred_list
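A minimal smoke-test sketch of how the new test_two_class helper can be driven. It relies only on the batch layout read in the loop above (images as batch[0], integer labels as batch[1], per-frame video names as batch[-1]); the DummyFrameDataset and the tiny linear model are hypothetical stand-ins, and the call assumes test_two_class and get_video_level_label_pred from util.metrics are in scope, as in app.py.

# Hypothetical smoke test for test_two_class; only the (image, label, video_name)
# batch layout matters, the dataset and model here are stand-ins.
import torch
from torch.utils.data import Dataset, DataLoader

class DummyFrameDataset(Dataset):
    """Yields (image, label, video_name) tuples like the real frame datasets."""
    def __init__(self, n_videos=4, frames_per_video=8):
        self.items = [(torch.randn(3, 224, 224),  # fake image tensor
                       v % 2,                     # 0 = real, 1 = fake
                       f"video_{v:03d}")          # frame -> video name
                      for v in range(n_videos)
                      for _ in range(frames_per_video)]

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]

loader = DataLoader(DummyFrameDataset(), batch_size=16, shuffle=False)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Stand-in 2-class classifier over flattened 224x224 RGB frames.
model = torch.nn.Sequential(torch.nn.Flatten(),
                            torch.nn.Linear(3 * 224 * 224, 2)).to(device)

frame_scores, video_scores = test_two_class(loader, model, device)
print(len(frame_scores), len(video_scores))  # e.g. 32 frame scores, 4 video-level scores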
models_vit.py CHANGED
@@ -1,5 +1,5 @@
  # -*- coding: utf-8 -*-
- # Author: Gaojian Wang@ZJUICSR
+ # Author: Gaojian Wang@ZJUICSR; TongWu@ZJUICSR
  # --------------------------------------------------------
  # This source code is licensed under the Attribution-NonCommercial 4.0 International License.
  # You can find the license in the LICENSE file in the root directory of this source tree.
util/crop.py CHANGED
@@ -1,5 +1,5 @@
  # -*- coding: utf-8 -*-
- # Author: Gaojian Wang@ZJUICSR
+ # Author: Gaojian Wang@ZJUICSR; TongWu@ZJUICSR
  # --------------------------------------------------------
  # This source code is licensed under the Attribution-NonCommercial 4.0 International License.
  # You can find the license in the LICENSE file in the root directory of this source tree.
util/datasets.py CHANGED
@@ -1,5 +1,5 @@
  # -*- coding: utf-8 -*-
- # Author: Gaojian Wang@ZJUICSR
+ # Author: Gaojian Wang@ZJUICSR; TongWu@ZJUICSR
  # --------------------------------------------------------
  # This source code is licensed under the Attribution-NonCommercial 4.0 International License.
  # You can find the license in the LICENSE file in the root directory of this source tree.
util/lars.py CHANGED
@@ -1,5 +1,5 @@
  # -*- coding: utf-8 -*-
- # Author: Gaojian Wang@ZJUICSR
+ # Author: Gaojian Wang@ZJUICSR; TongWu@ZJUICSR
  # --------------------------------------------------------
  # This source code is licensed under the Attribution-NonCommercial 4.0 International License.
  # You can find the license in the LICENSE file in the root directory of this source tree.
util/loss_contrastive.py CHANGED
@@ -1,5 +1,5 @@
  # -*- coding: utf-8 -*-
- # Author: Gaojian Wang@ZJUICSR
+ # Author: Gaojian Wang@ZJUICSR; TongWu@ZJUICSR
  # --------------------------------------------------------
  # This source code is licensed under the Attribution-NonCommercial 4.0 International License.
  # You can find the license in the LICENSE file in the root directory of this source tree.
util/lr_decay.py CHANGED
@@ -1,5 +1,5 @@
  # -*- coding: utf-8 -*-
- # Author: Gaojian Wang@ZJUICSR
+ # Author: Gaojian Wang@ZJUICSR; TongWu@ZJUICSR
  # --------------------------------------------------------
  # This source code is licensed under the Attribution-NonCommercial 4.0 International License.
  # You can find the license in the LICENSE file in the root directory of this source tree.
util/lr_sched.py CHANGED
@@ -1,5 +1,5 @@
  # -*- coding: utf-8 -*-
- # Author: Gaojian Wang@ZJUICSR
+ # Author: Gaojian Wang@ZJUICSR; TongWu@ZJUICSR
  # --------------------------------------------------------
  # This source code is licensed under the Attribution-NonCommercial 4.0 International License.
  # You can find the license in the LICENSE file in the root directory of this source tree.
util/metrics.py CHANGED
@@ -1,5 +1,5 @@
  # -*- coding: utf-8 -*-
- # Author: Gaojian Wang@ZJUICSR
+ # Author: Gaojian Wang@ZJUICSR; TongWu@ZJUICSR
  # --------------------------------------------------------
  # This source code is licensed under the Attribution-NonCommercial 4.0 International License.
  # You can find the license in the LICENSE file in the root directory of this source tree.
@@ -68,6 +68,34 @@ def get_video_level_label_pred(f_label_list, v_name_list, f_pred_list):
      return video_label_list, video_pred_list, video_y_pred_list


+ def get_video_level_label_pred_multi_class(f_label_list, v_name_list, f_pred_list):
+     """
+     Adapted for multi-class predictions: each frame score is a per-class
+     probability vector; a video's frame scores are averaged class-wise and the
+     argmax of the mean vector gives the video-level prediction.
+     """
+     import numpy as np
+     video_res_dict = dict()
+     video_pred_list = list()
+     video_y_pred_list = list()
+     video_label_list = list()
+
+     # Summarize all the results for each video
+     for label, video, score in zip(f_label_list, v_name_list, f_pred_list):
+         if video not in video_res_dict.keys():
+             video_res_dict[video] = {"scores": [score], "label": label}
+         else:
+             video_res_dict[video]["scores"].append(score)
+
+     # Get the score and label for each video
+     for video, res in video_res_dict.items():
+         avg_score = np.mean(res['scores'], axis=0)
+         label = res['label']
+         video_pred_list.append(avg_score)
+         video_label_list.append(label)
+         video_y_pred_list.append(np.argmax(avg_score))
+
+     return video_label_list, video_pred_list, video_y_pred_list
+
+
  def video_level_acc(video_label_list, video_y_pred_list):
      return accuracy_score(video_label_list, video_y_pred_list) * 100.
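A toy worked example of the frame-to-video aggregation performed by the new get_video_level_label_pred_multi_class (made-up scores, four classes): the two frames of vid_a are averaged class-wise before the argmax, while vid_b contributes a single frame.

# Toy illustration of the per-video averaging; all numbers are made up.
f_label_list = [2, 2, 0]                    # frame-level integer labels
v_name_list  = ['vid_a', 'vid_a', 'vid_b']  # frame -> video mapping
f_pred_list  = [[0.1, 0.2, 0.6, 0.1],       # per-frame softmax scores (4 classes)
                [0.2, 0.1, 0.5, 0.2],
                [0.7, 0.1, 0.1, 0.1]]

labels, preds, y_preds = get_video_level_label_pred_multi_class(
    f_label_list, v_name_list, f_pred_list)
# labels  -> [2, 0]
# preds   -> [array([0.15, 0.15, 0.55, 0.15]), array([0.7, 0.1, 0.1, 0.1])]
# y_preds -> [2, 0]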
util/misc.py CHANGED
@@ -1,5 +1,5 @@
  # -*- coding: utf-8 -*-
- # Author: Gaojian Wang@ZJUICSR
+ # Author: Gaojian Wang@ZJUICSR; TongWu@ZJUICSR
  # --------------------------------------------------------
  # This source code is licensed under the Attribution-NonCommercial 4.0 International License.
  # You can find the license in the LICENSE file in the root directory of this source tree.
util/pos_embed.py CHANGED
@@ -1,5 +1,5 @@
  # -*- coding: utf-8 -*-
- # Author: Gaojian Wang@ZJUICSR
+ # Author: Gaojian Wang@ZJUICSR; TongWu@ZJUICSR
  # --------------------------------------------------------
  # This source code is licensed under the Attribution-NonCommercial 4.0 International License.
  # You can find the license in the LICENSE file in the root directory of this source tree.