# yeliudev's picture
# Upload folder using huggingface_hub
# 6073e55 verified
# raw
# history blame contribute delete
# 5.52 kB
# Copyright (c) 2025 Ye Liu. Licensed under the BSD-3-Clause License.
import copy
import nncore
import torch
from nncore.ops import temporal_iou
from torch.utils.data import Dataset
from videomind.constants import VERIFIER_PROMPT
from videomind.dataset.hybrid import DATASETS
from videomind.utils.parser import parse_span
class VerifyingDataset(Dataset):
    """Base dataset that turns predicted spans into yes/no verification conversations.

    This class reads two class attributes that it does not declare itself, so
    concrete subclasses have to provide them:

    * ``ANNO_PATH`` - path of the raw annotations. When it is a directory, every
      ``json`` file inside it is loaded and the results are passed through
      ``nncore.flatten``; otherwise the path is loaded directly.
    * ``SOURCE`` - name stored in each sample and printed in the statistics line.
    """

    def __init__(self, processor, model_args, data_args, training_args):
        """Load the annotations and keep those passing the query / video length filters.

        Args:
            processor: Stored as ``self.processor``; not used inside this class.
            model_args: Stored as ``self.model_args``; not used inside this class.
            data_args: Must expose ``min_num_words``, ``max_num_words``,
                ``min_video_len`` and ``max_video_len``. A negative value
                disables the corresponding bound.
            training_args: Stored as ``self.training_args``; not used inside this class.
        """
        super().__init__()

        raw_annos = self.load_annos()

        annos = []
        for anno in raw_annos:
            # word count is a plain split on single spaces
            num_words = len(anno['query'].split(' '))
            if data_args.min_num_words >= 0 and num_words < data_args.min_num_words:
                continue
            if data_args.max_num_words >= 0 and num_words > data_args.max_num_words:
                continue
            # the defaults (inf / 0) make a sample without 'duration' pass both video length checks
            if data_args.min_video_len >= 0 and anno.get('duration', float('inf')) < data_args.min_video_len:
                continue
            if data_args.max_video_len >= 0 and anno.get('duration', 0) > data_args.max_video_len:
                continue
            annos.append(anno)

        self.annos = annos
        # number of samples before filtering
        self.raw_length = len(raw_annos)

        self.processor = processor
        self.model_args = model_args
        self.data_args = data_args
        self.training_args = training_args

    def __len__(self):
        """Return the number of samples that survived filtering."""
        return len(self.annos)

    @classmethod
    def load_annos(cls, split='train'):
        """Load the raw annotations of ``cls.ANNO_PATH`` and expand them into samples.

        Args:
            split: Only ``'train'`` is accepted (checked with an ``assert``).

        Returns:
            list[dict]: One sample per (annotation, prediction) pair, see
            :meth:`_expand_annos`.
        """
        assert split == 'train'

        if nncore.is_dir(cls.ANNO_PATH):
            raw_paths = nncore.ls(cls.ANNO_PATH, ext='json', join_path=True, sort=True)
            raw_annos = nncore.flatten([nncore.load(p) for p in raw_paths])
        else:
            raw_annos = nncore.load(cls.ANNO_PATH)

        return cls._expand_annos(raw_annos)

    @classmethod
    def _expand_annos(cls, raw_annos):
        """Create one labelled sample for each of the first five predictions of every annotation.

        A sample is positive when the best IoU between the prediction and
        ``raw_anno['span']`` is at least 0.5. A ``pos / neg`` summary is printed
        once all samples are built.

        Args:
            raw_annos: Iterable of dicts with the keys ``pred``, ``span``,
                ``video_path``, ``duration``, ``query`` and optionally ``task``.

        Returns:
            list[dict]: The expanded samples (possibly empty).
        """
        annos = []
        for raw_anno in raw_annos:
            # using top-5 predictions
            for pred in raw_anno['pred'][:5]:
                # NOTE(review): only `pred` is wrapped in a list here, so `span` is presumably
                # already a list of [start, end] pairs - confirm against the annotation files
                iou = temporal_iou(torch.Tensor([pred]), torch.Tensor(raw_anno['span']))
                # non-finite IoU values are counted as no overlap
                iou = torch.where(iou.isfinite(), iou, 0)
                iou = iou.max().item()

                positive = iou >= 0.5

                anno = dict(
                    source=cls.SOURCE,
                    data_type='multimodal',
                    video_path=raw_anno['video_path'],
                    duration=raw_anno['duration'],
                    query=raw_anno['query'],
                    span=raw_anno['span'],
                    pred=pred,
                    positive=positive,
                    task=raw_anno.get('task', 'unknown'))

                annos.append(anno)

        pos_inds = [i for i, a in enumerate(annos) if a['positive']]
        neg_inds = [i for i, a in enumerate(annos) if not a['positive']]

        num_pos = len(pos_inds)
        num_neg = len(neg_inds)

        # without any positive sample the ratio is undefined; report it instead of
        # failing with a ZeroDivisionError
        ratio = num_neg / num_pos if num_pos > 0 else 'n/a'
        print(f'[{cls.SOURCE}] pos: {num_pos} neg: {num_neg} n/p ratio: {ratio}')

        # NOTE: sub-sampling of negatives is disabled; the code is kept for reference
        # filter negative samples
        # if num_neg > num_pos * 3:
        #     neg_inds = random.sample(neg_inds, int(num_pos * 3))

        # inds = pos_inds + neg_inds
        # random.shuffle(inds)
        # inds = comm.broadcast(inds)
        # annos = [annos[i] for i in inds]

        return annos

    def __getitem__(self, idx):
        """Build the verification conversation of sample ``idx``.

        The predicted span is extended by half of its own length on both sides
        and the video is cropped to that extended window.

        Args:
            idx: Index into ``self.annos``.

        Returns:
            dict: ``messages`` (a user turn holding the video window and the
            verifier prompt, followed by the ``'Yes.'`` / ``'No.'`` assistant
            turn), plus ``ss`` and ``se``, the relative start and end of the
            predicted span inside the extended window.
        """
        # work on a copy so the cached annotation is never modified
        anno = copy.deepcopy(self.annos[idx])

        video_path, duration, query, positive = anno['video_path'], anno['duration'], anno['query'], anno['positive']

        # NOTE(review): the third argument is presumably a minimum span length - confirm in parse_span
        s0, e0 = parse_span(anno['pred'], duration, 2)

        # extend the span by half of its length on each side
        offset = (e0 - s0) / 2
        s1, e1 = parse_span([s0 - offset, e0 + offset], duration)

        # percentage of s0, e0 within s1, e1
        # NOTE(review): assumes parse_span never returns e1 == s1 - confirm
        s = (s0 - s1) / (e1 - s1)
        e = (e0 - s1) / (e1 - s1)

        messages = [{
            'role':
            'user',
            'content': [{
                'type': 'video',
                'video': video_path,
                'video_start': s1,
                'video_end': e1,
                'min_pixels': 36 * 28 * 28,
                'max_pixels': 64 * 28 * 28,
                'max_frames': 64,
                'fps': 2.0
            }, {
                'type': 'text',
                'text': VERIFIER_PROMPT.format(query)
            }]
        }]

        # the supervised answer is the binary verification label
        messages = messages + [{'role': 'assistant', 'content': 'Yes.' if positive else 'No.'}]

        meta = dict(messages=messages, ss=s, se=e)
        return meta
@DATASETS.register(name='qvhighlights_verify_2b')
class QVHighlightsVerify2BDataset(VerifyingDataset):
    """Verifying dataset registered in ``DATASETS`` as ``'qvhighlights_verify_2b'``."""

    # name stored in every sample and printed in the pos/neg statistics line
    SOURCE = 'qvhighlights_verify_2b'

    # raw annotations read by ``VerifyingDataset.load_annos``
    ANNO_PATH = 'data/verifying/verifying_qvhighlights_2b.json'
@DATASETS.register(name='didemo_verify_2b')
class DiDeMoVerify2BDataset(VerifyingDataset):
    """Verifying dataset registered in ``DATASETS`` as ``'didemo_verify_2b'``."""

    # name stored in every sample and printed in the pos/neg statistics line
    SOURCE = 'didemo_verify_2b'

    # raw annotations read by ``VerifyingDataset.load_annos``
    ANNO_PATH = 'data/verifying/verifying_didemo_2b.json'
@DATASETS.register(name='tacos_verify_2b')
class TACoSVerify2BDataset(VerifyingDataset):
    """Verifying dataset registered in ``DATASETS`` as ``'tacos_verify_2b'``."""

    # name stored in every sample and printed in the pos/neg statistics line
    SOURCE = 'tacos_verify_2b'

    # raw annotations read by ``VerifyingDataset.load_annos``
    ANNO_PATH = 'data/verifying/verifying_tacos_2b.json'
@DATASETS.register(name='qvhighlights_verify_7b')
class QVHighlightsVerify7BDataset(VerifyingDataset):
    """Verifying dataset registered in ``DATASETS`` as ``'qvhighlights_verify_7b'``."""

    # name stored in every sample and printed in the pos/neg statistics line
    SOURCE = 'qvhighlights_verify_7b'

    # raw annotations read by ``VerifyingDataset.load_annos``
    ANNO_PATH = 'data/verifying/verifying_qvhighlights_7b.json'
@DATASETS.register(name='didemo_verify_7b')
class DiDeMoVerify7BDataset(VerifyingDataset):
    """Verifying dataset registered in ``DATASETS`` as ``'didemo_verify_7b'``."""

    # name stored in every sample and printed in the pos/neg statistics line
    SOURCE = 'didemo_verify_7b'

    # raw annotations read by ``VerifyingDataset.load_annos``
    ANNO_PATH = 'data/verifying/verifying_didemo_7b.json'
@DATASETS.register(name='tacos_verify_7b')
class TACoSVerify7BDataset(VerifyingDataset):
    """Verifying dataset registered in ``DATASETS`` as ``'tacos_verify_7b'``."""

    # name stored in every sample and printed in the pos/neg statistics line
    SOURCE = 'tacos_verify_7b'

    # raw annotations read by ``VerifyingDataset.load_annos``
    ANNO_PATH = 'data/verifying/verifying_tacos_7b.json'