Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) 2025 Ye Liu. Licensed under the BSD-3-Clause License. | |
import copy | |
import nncore | |
import torch | |
from nncore.ops import temporal_iou | |
from torch.utils.data import Dataset | |
from videomind.constants import VERIFIER_PROMPT | |
from videomind.dataset.hybrid import DATASETS | |
from videomind.utils.parser import parse_span | |
class VerifyingDataset(Dataset): | |
def __init__(self, processor, model_args, data_args, training_args): | |
super(VerifyingDataset, self).__init__() | |
raw_annos = self.load_annos() | |
annos = [] | |
for anno in raw_annos: | |
num_words = len(anno['query'].split(' ')) | |
if data_args.min_num_words >= 0 and num_words < data_args.min_num_words: | |
continue | |
if data_args.max_num_words >= 0 and num_words > data_args.max_num_words: | |
continue | |
if data_args.min_video_len >= 0 and anno.get('duration', float('inf')) < data_args.min_video_len: | |
continue | |
if data_args.max_video_len >= 0 and anno.get('duration', 0) > data_args.max_video_len: | |
continue | |
annos.append(anno) | |
self.annos = annos | |
self.raw_length = len(raw_annos) | |
self.processor = processor | |
self.model_args = model_args | |
self.data_args = data_args | |
self.training_args = training_args | |
def __len__(self): | |
return len(self.annos) | |
def load_annos(self, split='train'): | |
assert split == 'train' | |
if nncore.is_dir(self.ANNO_PATH): | |
raw_paths = nncore.ls(self.ANNO_PATH, ext='json', join_path=True, sort=True) | |
raw_annos = nncore.flatten([nncore.load(p) for p in raw_paths]) | |
else: | |
raw_annos = nncore.load(self.ANNO_PATH) | |
annos = [] | |
for raw_anno in raw_annos: | |
# using top-5 predictions | |
for pred in raw_anno['pred'][:5]: | |
iou = temporal_iou(torch.Tensor([pred]), torch.Tensor(raw_anno['span'])) | |
iou = torch.where(iou.isfinite(), iou, 0) | |
iou = iou.max().item() | |
positive = iou >= 0.5 | |
anno = dict( | |
source=self.SOURCE, | |
data_type='multimodal', | |
video_path=raw_anno['video_path'], | |
duration=raw_anno['duration'], | |
query=raw_anno['query'], | |
span=raw_anno['span'], | |
pred=pred, | |
positive=positive, | |
task=raw_anno.get('task', 'unknown')) | |
annos.append(anno) | |
pos_inds = [i for i, a in enumerate(annos) if a['positive']] | |
neg_inds = [i for i, a in enumerate(annos) if not a['positive']] | |
num_pos = len(pos_inds) | |
num_neg = len(neg_inds) | |
print(f'[{self.SOURCE}] pos: {num_pos} neg: {num_neg} n/p ratio: {num_neg / num_pos}') | |
# filter negative samples | |
# if num_neg > num_pos * 3: | |
# neg_inds = random.sample(neg_inds, int(num_pos * 3)) | |
# inds = pos_inds + neg_inds | |
# random.shuffle(inds) | |
# inds = comm.broadcast(inds) | |
# annos = [annos[i] for i in inds] | |
return annos | |
def __getitem__(self, idx): | |
anno = copy.deepcopy(self.annos[idx]) | |
video_path, duration, query, positive = anno['video_path'], anno['duration'], anno['query'], anno['positive'] | |
s0, e0 = parse_span(anno['pred'], duration, 2) | |
offset = (e0 - s0) / 2 | |
s1, e1 = parse_span([s0 - offset, e0 + offset], duration) | |
# percentage of s0, e0 within s1, e1 | |
s = (s0 - s1) / (e1 - s1) | |
e = (e0 - s1) / (e1 - s1) | |
messages = [{ | |
'role': | |
'user', | |
'content': [{ | |
'type': 'video', | |
'video': video_path, | |
'video_start': s1, | |
'video_end': e1, | |
'min_pixels': 36 * 28 * 28, | |
'max_pixels': 64 * 28 * 28, | |
'max_frames': 64, | |
'fps': 2.0 | |
}, { | |
'type': 'text', | |
'text': VERIFIER_PROMPT.format(query) | |
}] | |
}] | |
messages = messages + [{'role': 'assistant', 'content': 'Yes.' if positive else 'No.'}] | |
meta = dict(messages=messages, ss=s, se=e) | |
return meta | |
class QVHighlightsVerify2BDataset(VerifyingDataset): | |
ANNO_PATH = 'data/verifying/verifying_qvhighlights_2b.json' | |
SOURCE = 'qvhighlights_verify_2b' | |
class DiDeMoVerify2BDataset(VerifyingDataset): | |
ANNO_PATH = 'data/verifying/verifying_didemo_2b.json' | |
SOURCE = 'didemo_verify_2b' | |
class TACoSVerify2BDataset(VerifyingDataset): | |
ANNO_PATH = 'data/verifying/verifying_tacos_2b.json' | |
SOURCE = 'tacos_verify_2b' | |
class QVHighlightsVerify7BDataset(VerifyingDataset): | |
ANNO_PATH = 'data/verifying/verifying_qvhighlights_7b.json' | |
SOURCE = 'qvhighlights_verify_7b' | |
class DiDeMoVerify7BDataset(VerifyingDataset): | |
ANNO_PATH = 'data/verifying/verifying_didemo_7b.json' | |
SOURCE = 'didemo_verify_7b' | |
class TACoSVerify7BDataset(VerifyingDataset): | |
ANNO_PATH = 'data/verifying/verifying_tacos_7b.json' | |
SOURCE = 'tacos_verify_7b' | |