Upload 4 files
- llama_inference.py +204 -0
- single_video.py +70 -0
- utils_asr.py +95 -0
- vidchapters.py +107 -0
llama_inference.py
ADDED
@@ -0,0 +1,204 @@
from pathlib import Path

import torch
from llama_cookbook.inference.model_utils import load_model as load_model_llamarecipes
from llama_cookbook.inference.model_utils import load_peft_model
from transformers import AutoTokenizer

from src.utils import RankedLogger

log = RankedLogger(__name__, rank_zero_only=True)


def load_model(
    ckpt_path, quantization=None, use_fast_kernels=False, peft_model=False, **kwargs
):
    model = load_model_llamarecipes(
        model_name=ckpt_path,
        quantization=quantization,
        use_fast_kernels=use_fast_kernels,
        device_map="auto",
        **kwargs,
    )
    if peft_model:
        model = load_peft_model(model, peft_model)

    tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
    tokenizer.pad_token = tokenizer.eos_token
    # special_tokens = {"additional_special_tokens": ["<image>"]}
    # tokenizer.add_special_tokens(special_tokens)

    return model, tokenizer


@torch.no_grad()
def inference(
    model,
    tokenizer: AutoTokenizer,
    prompt: str,
    add_special_tokens: bool = True,
    temperature: float = 1.0,
    max_new_tokens=1024,
    top_p: float = 1.0,
    top_k: int = 50,
    use_cache: bool = True,
    max_padding_length: int = None,
    do_sample: bool = False,
    min_length: int = None,
    repetition_penalty: float = 1.0,
    length_penalty: int = 1,
    max_prompt_tokens: int = 35_000,
    **kwargs,
):
    """
    temperature: float, optional (default=1.0) The value used to modulate the next token probabilities.
    max_new_tokens: int, optional (default=1024) The maximum number of tokens to generate.
    top_p: float, optional (default=1.0) If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
    top_k: int, optional (default=50) The number of highest probability vocabulary tokens to keep for top-k filtering.
    use_cache: bool, optional (default=True) Whether or not the model should use the past key/value attentions (if applicable to the model) to speed up decoding.
    max_padding_length: int, optional (default=None) The max padding length used by the tokenizer when padding the prompts.
    do_sample: bool, optional (default=False) Whether or not to use sampling; uses greedy decoding otherwise.
    min_length: int, optional (default=None) The minimum length of the sequence to be generated (input prompt + min_new_tokens).
    repetition_penalty: float, optional (default=1.0) The parameter for repetition penalty. 1.0 means no penalty.
    length_penalty: int, optional (default=1) Exponential penalty to the length that is used with beam-based generation.
    """
    if add_special_tokens:
        prompt = f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
        # prompt = f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>"

    batch = tokenizer(
        prompt,
        truncation=True,
        max_length=max_padding_length,
        return_tensors="pt",
    )

    # If the input is too long, return the length of the input
    n_tokens = len(batch["input_ids"][0])
    if max_prompt_tokens is not None and n_tokens > max_prompt_tokens:
        return n_tokens

    batch = {k: v.to("cuda") for k, v in batch.items()}

    terminators = [
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<|eot_id|>"),
    ]

    try:
        outputs = model.generate(
            **batch,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            top_p=top_p,
            temperature=temperature,
            min_length=min_length,
            use_cache=use_cache,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            eos_token_id=terminators,
            pad_token_id=tokenizer.eos_token_id,
            **kwargs,
        )
        output_text = tokenizer.decode(outputs[0], skip_special_tokens=False)

        output = output_text.split("<|start_header_id|>assistant<|end_header_id|>")[1]
        output = output.strip()
        output = output.removesuffix("<|eot_id|>")

    except torch.cuda.OutOfMemoryError as e:
        log.error(f"CUDA out of memory error: {e}")
        torch.cuda.empty_cache()
        return n_tokens

    return output


class LlamaInference:
    def __init__(
        self,
        ckpt_path,
        quantization=None,
        use_fast_kernels=False,
        peft_model=False,
        add_special_tokens: bool = True,
        temperature: float = 1.0,
        max_new_tokens: int = 1024,
        top_p: float = 1.0,
        top_k: int = 50,
        use_cache: bool = True,
        max_padding_length: int = None,
        do_sample: bool = False,
        min_length: int = None,
        repetition_penalty: float = 1.0,
        length_penalty: int = 1,
        max_prompt_tokens: int = 35_000,
        **kwargs,
    ):
        # Check if the LLaMA model exists
        # if not Path(ckpt_path).exists():
        #     log.warning(f"Model checkpoint does not exist at {ckpt_path}")
        #     return None

        # If a PEFT model is specified, check that it exists
        if peft_model and not Path(peft_model).exists():
            log.warning(f"PEFT model does not exist at {peft_model}")
            return None
        if peft_model:
            log.info(f"PEFT model found at {peft_model}")

        model = load_model_llamarecipes(
            model_name=ckpt_path,
            quantization=quantization,
            use_fast_kernels=use_fast_kernels,
            device_map="auto",
            **kwargs,
        )
        if peft_model:
            model = load_peft_model(model, peft_model)

        model.eval()

        tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
        tokenizer.pad_token = tokenizer.eos_token

        self.model = model
        self.tokenizer = tokenizer
        self.add_special_tokens = add_special_tokens
        self.temperature = temperature
        self.max_new_tokens = max_new_tokens
        self.top_p = top_p
        self.top_k = top_k
        self.use_cache = use_cache
        self.max_padding_length = max_padding_length
        self.do_sample = do_sample
        self.min_length = min_length
        self.repetition_penalty = repetition_penalty
        self.length_penalty = length_penalty
        self.max_prompt_tokens = max_prompt_tokens

    def __call__(self, prompt: str, **kwargs):
        # Create a dict of default parameters from instance attributes
        params = {
            "model": self.model,
            "tokenizer": self.tokenizer,
            "prompt": prompt,
            "add_special_tokens": self.add_special_tokens,
            "temperature": self.temperature,
            "max_new_tokens": self.max_new_tokens,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "use_cache": self.use_cache,
            "max_padding_length": self.max_padding_length,
            "do_sample": self.do_sample,
            "min_length": self.min_length,
            "repetition_penalty": self.repetition_penalty,
            "length_penalty": self.length_penalty,
            "max_prompt_tokens": self.max_prompt_tokens,
        }

        # Update with any overrides passed in kwargs
        params.update(kwargs)

        return inference(**params)
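A minimal usage sketch of LlamaInference (not part of the commit): the checkpoint name, adapter path, and prompt below are placeholders chosen for illustration.

# Hypothetical usage of llama_inference.LlamaInference; paths are placeholders.
from llama_inference import LlamaInference

llm = LlamaInference(
    ckpt_path="meta-llama/Llama-3.1-8B-Instruct",  # assumed base checkpoint
    peft_model="outputs/chapterize_lora",          # assumed LoRA adapter directory
    max_new_tokens=1024,
)

result = llm("00:00:00: welcome to the channel\n00:00:12: today we cover ...")
if isinstance(result, int):
    # inference() returns the prompt length in tokens when the prompt exceeds
    # max_prompt_tokens or generation runs out of GPU memory
    print(f"Prompt too long: {result} tokens")
else:
    print(result)

Note that the caller is expected to check the return type, since an over-long prompt is reported as an int rather than raising.
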
single_video.py
ADDED
@@ -0,0 +1,70 @@
from pathlib import Path

from lutils import openf, writef

from src.data.chapters import sec_to_hms
from tools.extract.asr import ASRProcessor


class SingleVideo:
    """
    A simplified implementation of the src.data.chapters.Chapters interface for single-video inference.

    This class mimics the behavior of the ChaptersASR class but is designed to work with
    a single video file rather than a dataset. It provides the necessary methods
    required by the PromptASR class for generating chapter timestamps and titles.

    Note: This class is intended for inference only and should not be used for
    training or evaluation purposes.
    """

    def __init__(self, video_path: Path):
        self.video_path = video_path
        self.video_ids = [video_path.stem]
        assert video_path.exists(), f"Video file {video_path} not found"
        self.asr, self.duration = get_asr(video_path, overwrite=True)

    def __len__(self):
        return len(self.video_ids)

    def __iter__(self):
        return iter(self.video_ids)

    def __contains__(self, vid_id):
        return vid_id in self.video_ids

    def get_duration(self, vid_id, hms=False):
        assert vid_id == self.video_ids[0], f"Invalid video ID: {vid_id}"
        if hms:
            return sec_to_hms(self.duration)
        return self.duration

    def get_asr(self, vid_id):
        assert vid_id == self.video_ids[0], f"Invalid video ID: {vid_id}"
        return self.asr


def get_asr(video_path: Path, overwrite=False):
    output_dir = Path(f"outputs/inference/{video_path.stem}")
    asr_output = output_dir / "asr.txt"
    duration_output = output_dir / "duration.txt"
    if asr_output.exists() and duration_output.exists() and not overwrite:
        asr = openf(asr_output)
        asr = "\n".join(asr) + "\n"

        duration = openf(duration_output)
        assert isinstance(duration, list) and len(duration) == 1, (
            f"Duration is not a list of length 1: {duration}"
        )
        duration = float(duration[0])
        assert duration > 0, f"Duration is not positive: {duration}"
        return asr, duration

    print(f"\n=== 🎙️ Processing ASR for {video_path} ===")
    asr_processor = ASRProcessor()
    asr, duration = asr_processor.get_asr(video_path)
    print(f"=== ✅ ASR processing complete for {video_path} ===\n")
    output_dir.mkdir(parents=True, exist_ok=True)
    writef(asr_output, asr)
    writef(duration_output, str(duration))
    return asr, duration
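A short sketch of how SingleVideo stands in for the dataset classes at inference time (the video path below is a placeholder, and constructing the object triggers ASR extraction and caching under outputs/inference/):

# Hypothetical usage of single_video.SingleVideo; the path is a placeholder.
from pathlib import Path
from single_video import SingleVideo

video = SingleVideo(Path("videos/demo.mp4"))
vid_id = video.video_ids[0]
print(video.get_duration(vid_id, hms=True))  # duration as "HH:MM:SS"
print(video.get_asr(vid_id)[:200])           # first ASR lines, "HH:MM:SS: text" per line
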
utils_asr.py
ADDED
@@ -0,0 +1,95 @@
from lutils import openf, writef

from src.data.chapters import Chapters, sec_to_hms
from src.data.prompt import Prompt
from src.utils import RankedLogger

log = RankedLogger(__name__, rank_zero_only=True)


class ChaptersASR(Chapters):
    def __init__(self, vidc_dir: str = "dataset/", subset=""):
        super().__init__(vidc_dir=vidc_dir, subset=subset)

        self._asrs = None

    @property
    def asrs(self):
        if self._asrs is None:
            self.load_asr_data()
        return self._asrs

    def load_asr_data(self):
        if self._asrs is not None:
            return

        if self.subset:
            asr_pth = self.vidc_dir / f"docs/subset_data/asrs/asrs_{self.subset}.json"
            if asr_pth.exists():
                self._asrs = openf(asr_pth)
            else:
                log.info(f"ASR data not found for subset {self.subset}.")
                asr_val_pth = self.vidc_dir / "docs/subset_data/asrs/asrs_val.json"
                asr_train_pth = self.vidc_dir / "docs/subset_data/asrs/asrs_train.json"
                if "val" in self.subset and asr_val_pth.exists():
                    log.info("Loading from ASR validation file.")
                    asrs = openf(asr_val_pth)
                elif "train" in self.subset and asr_train_pth.exists():
                    log.info("Loading from ASR training file.")
                    asrs = openf(asr_train_pth)
                else:
                    log.info("Loading from ASR file.")
                    asrs = openf(self.vidc_dir / "docs/asrs.json")
                video_ids = set(self.video_ids) & set(asrs.keys())
                self._asrs = {vid_id: asrs[vid_id] for vid_id in video_ids}
                asr_pth.parent.mkdir(exist_ok=True)
                writef(asr_pth, self._asrs)
        else:
            self._asrs = openf(self.vidc_dir / "docs/asrs.json")

    def get_asr(self, video_id, add_end=False):
        if video_id not in self.asrs:
            return None

        asr = self.asrs[video_id]
        asr_clean = []
        for t, s, e in zip(asr["text"], asr["start"], asr["end"]):
            t = t.strip()
            s = sec_to_hms(s)
            e = sec_to_hms(e)
            if add_end:
                asr_clean.append(f"{s} - {e}: {t}")
            else:
                asr_clean.append(f"{s}: {t}")

        return "\n".join(asr_clean) + "\n"

    def __contains__(self, vid_id):
        return vid_id in self.asrs


class PromptASR(Prompt):
    def __init__(self, chapters: ChaptersASR, add_end=False):
        super().__init__(chapters=chapters)
        self.add_end = add_end

    def get_task_prompt(self):
        return "segment the text into distinct chapters based on thematic shifts or changes in topics.\n"

    def get_transcript(self, vid_id):
        vid_asr = self.chapters.get_asr(vid_id, add_end=self.add_end)
        assert vid_asr is not None, f"ASR not found for video ID: {vid_id}"
        return vid_asr

    def __contains__(self, vid_id):
        return vid_id in self.chapters


if __name__ == "__main__":
    chapters = ChaptersASR(subset="s10k_train")
    vid_id = chapters.sample()

    prompt = PromptASR(chapters=chapters)
    print(prompt.get_prompt_train(vid_id))
    print(prompt.get_transcript(vid_id))
    print(prompt.get_output(vid_id))
vidchapters.py
ADDED
@@ -0,0 +1,107 @@
from pathlib import Path

from lutils import writef
from tqdm import tqdm

from src.test.utils_chapters import extract_chapters, filter_chapters
from src.utils import RankedLogger

log = RankedLogger(__name__, rank_zero_only=True)


def get_chapters(
    inference,
    prompt,
    max_new_tokens,
    do_sample=False,
    vid_duration=None,
    use_cache=True,
    vid_id="",
):
    output_text = inference(
        prompt=prompt,
        max_new_tokens=max_new_tokens,
        add_special_tokens=True,
        do_sample=do_sample,
        use_cache=use_cache,
    )

    if isinstance(output_text, int):
        # The input is too long; return the length of the input
        return output_text, None

    chapters = extract_chapters(output_text)
    chapters = filter_chapters(chapters, vid_duration=vid_duration)

    if not chapters and not do_sample:
        log.info(f"No chapters found for {vid_id}, trying again with sampling")
        return get_chapters(
            inference,
            prompt,
            max_new_tokens,
            do_sample=True,
            vid_duration=vid_duration,
        )

    return output_text, chapters


class VidChaptersTester:
    def __init__(self, save_dir: str, do_sample=False, **kwargs):
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(exist_ok=True)
        self.do_sample = do_sample

    def __call__(
        self,
        inference,
        test_dataloader,
        max_new_tokens=1024,
    ):
        pbar = tqdm(
            total=len(test_dataloader),
            desc="Evaluating chapters",
        )

        for batch in test_dataloader:
            vid_id = batch["vid_id"][0]
            prompt = batch["prompt"][0]
            transcript = batch["transcript"][0]
            vid_duration = batch["vid_duration"][0]
            prompt += transcript

            chapters_pth = self.save_dir / f"{vid_id[:2]}" / f"{vid_id}.json"
            chapters_pth.parent.mkdir(exist_ok=True)

            if chapters_pth.exists():
                pbar.update(1)
                continue

            pbar.set_description(f"vid_id: {vid_id}")

            output_text, chapters = get_chapters(
                inference,
                prompt,
                max_new_tokens,
                do_sample=self.do_sample,
                vid_duration=vid_duration,
                vid_id=vid_id,
            )

            if chapters is None:
                log.info(f"Input too long for {vid_id}, {output_text} tokens")
                error_pth = chapters_pth.with_suffix(".txt")
                writef(error_pth, [output_text])
                pbar.update(1)
                continue

            if chapters:
                vid_data = {
                    "chapters": chapters,
                    "output": output_text,
                }
                writef(chapters_pth, vid_data)

            pbar.update(1)

        pbar.close()
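For orientation, a rough end-to-end sketch tying the four files together for one video. The video path, checkpoint name, and the prompt-building method name are assumptions; the Prompt base class in src.data.prompt is not part of this upload, so its exact API may differ.

# Hypothetical end-to-end wiring of the four uploaded modules.
from pathlib import Path

from llama_inference import LlamaInference
from single_video import SingleVideo
from utils_asr import PromptASR
from vidchapters import get_chapters

video = SingleVideo(Path("videos/demo.mp4"))      # placeholder path
vid_id = video.video_ids[0]

prompts = PromptASR(chapters=video)               # SingleVideo duck-types ChaptersASR
prompt = prompts.get_prompt_test(vid_id)          # assumed method of the Prompt base class
prompt += prompts.get_transcript(vid_id)

llm = LlamaInference(ckpt_path="meta-llama/Llama-3.1-8B-Instruct")  # assumed checkpoint
output_text, chapters = get_chapters(
    llm,
    prompt,
    max_new_tokens=1024,
    vid_duration=video.get_duration(vid_id),
    vid_id=vid_id,
)
print(chapters)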