lucas-ventura committed
Commit 90559ad · verified · 1 Parent(s): 36a1678

Upload 4 files

Files changed (4)
  1. llama_inference.py +204 -0
  2. single_video.py +70 -0
  3. utils_asr.py +95 -0
  4. vidchapters.py +107 -0
llama_inference.py ADDED
@@ -0,0 +1,204 @@
from pathlib import Path

import torch
from llama_cookbook.inference.model_utils import load_model as load_model_llamarecipes
from llama_cookbook.inference.model_utils import load_peft_model
from transformers import AutoTokenizer

from src.utils import RankedLogger

log = RankedLogger(__name__, rank_zero_only=True)


def load_model(
    ckpt_path, quantization=None, use_fast_kernels=False, peft_model=False, **kwargs
):
    model = load_model_llamarecipes(
        model_name=ckpt_path,
        quantization=quantization,
        use_fast_kernels=use_fast_kernels,
        device_map="auto",
        **kwargs,
    )
    if peft_model:
        model = load_peft_model(model, peft_model)

    tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
    tokenizer.pad_token = tokenizer.eos_token
    # special_tokens = {"additional_special_tokens": ["<image>"]}
    # tokenizer.add_special_tokens(special_tokens)

    return model, tokenizer


@torch.no_grad()
def inference(
    model,
    tokenizer: AutoTokenizer,
    prompt: str,
    add_special_tokens: bool = True,
    temperature: float = 1.0,
    max_new_tokens=1024,
    top_p: float = 1.0,
    top_k: int = 50,
    use_cache: bool = True,
    max_padding_length: int = None,
    do_sample: bool = False,
    min_length: int = None,
    repetition_penalty: float = 1.0,
    length_penalty: int = 1,
    max_prompt_tokens: int = 35_000,
    **kwargs,
):
    """
    temperature: float, optional (default=1.0) The value used to modulate the next token probabilities.
    max_new_tokens: int, optional (default=1024) The maximum number of tokens to generate.
    top_p: float, optional (default=1.0) If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
    top_k: int, optional (default=50) The number of highest probability vocabulary tokens to keep for top-k filtering.
    use_cache: bool, optional (default=True) Whether or not the model should use the past key/values attentions (if applicable to the model) to speed up decoding.
    max_padding_length: int, optional (default=None) The max padding length to be used when the tokenizer pads the prompts.
    do_sample: bool, optional (default=False) Whether or not to use sampling; use greedy decoding otherwise.
    min_length: int, optional (default=None) The minimum length of the sequence to be generated (input prompt + min_new_tokens).
    repetition_penalty: float, optional (default=1.0) The parameter for repetition penalty. 1.0 means no penalty.
    length_penalty: int, optional (default=1) Exponential penalty to the length that is used with beam-based generation.
    """
    if add_special_tokens:
        prompt = f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
        # prompt = f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>"

    batch = tokenizer(
        prompt,
        truncation=True,
        max_length=max_padding_length,
        return_tensors="pt",
    )

    # if the input is too long, return the length of the input
    n_tokens = len(batch["input_ids"][0])
    if max_prompt_tokens is not None and n_tokens > max_prompt_tokens:
        return n_tokens

    batch = {k: v.to("cuda") for k, v in batch.items()}

    terminators = [
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<|eot_id|>"),
    ]

    try:
        outputs = model.generate(
            **batch,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            top_p=top_p,
            temperature=temperature,
            min_length=min_length,
            use_cache=use_cache,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            eos_token_id=terminators,
            pad_token_id=tokenizer.eos_token_id,
            **kwargs,
        )
        output_text = tokenizer.decode(outputs[0], skip_special_tokens=False)

        output = output_text.split("<|start_header_id|>assistant<|end_header_id|>")[1]
        output = output.strip()
        output = output.removesuffix("<|eot_id|>")

    except torch.cuda.OutOfMemoryError as e:
        log.error(f"CUDA out of memory error: {e}")
        torch.cuda.empty_cache()
        return n_tokens

    return output


class LlamaInference:
    def __init__(
        self,
        ckpt_path,
        quantization=None,
        use_fast_kernels=False,
        peft_model=False,
        add_special_tokens: bool = True,
        temperature: float = 1.0,
        max_new_tokens: int = 1024,
        top_p: float = 1.0,
        top_k: int = 50,
        use_cache: bool = True,
        max_padding_length: int = None,
        do_sample: bool = False,
        min_length: int = None,
        repetition_penalty: float = 1.0,
        length_penalty: int = 1,
        max_prompt_tokens: int = 35_000,
        **kwargs,
    ):
        # Check if LLaMA model exists
        # if not Path(ckpt_path).exists():
        #     log.warning(f"Model checkpoint does not exist at {ckpt_path}")
        #     return None

        # If a PEFT model is specified, check that it exists
        if peft_model and not Path(peft_model).exists():
            log.warning(f"PEFT model does not exist at {peft_model}")
            return None
        if peft_model:
            log.info(f"PEFT model found at {peft_model}")

        model = load_model_llamarecipes(
            model_name=ckpt_path,
            quantization=quantization,
            use_fast_kernels=use_fast_kernels,
            device_map="auto",
            **kwargs,
        )
        if peft_model:
            model = load_peft_model(model, peft_model)

        model.eval()

        tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
        tokenizer.pad_token = tokenizer.eos_token

        self.model = model
        self.tokenizer = tokenizer
        self.add_special_tokens = add_special_tokens
        self.temperature = temperature
        self.max_new_tokens = max_new_tokens
        self.top_p = top_p
        self.top_k = top_k
        self.use_cache = use_cache
        self.max_padding_length = max_padding_length
        self.do_sample = do_sample
        self.min_length = min_length
        self.repetition_penalty = repetition_penalty
        self.length_penalty = length_penalty
        self.max_prompt_tokens = max_prompt_tokens

    def __call__(self, prompt: str, **kwargs):
        # Create a dict of default parameters from instance attributes
        params = {
            "model": self.model,
            "tokenizer": self.tokenizer,
            "prompt": prompt,
            "add_special_tokens": self.add_special_tokens,
            "temperature": self.temperature,
            "max_new_tokens": self.max_new_tokens,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "use_cache": self.use_cache,
            "max_padding_length": self.max_padding_length,
            "do_sample": self.do_sample,
            "min_length": self.min_length,
            "repetition_penalty": self.repetition_penalty,
            "length_penalty": self.length_penalty,
            "max_prompt_tokens": self.max_prompt_tokens,
        }

        # Update with any overrides passed in kwargs
        params.update(kwargs)

        return inference(**params)
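
For reference, a minimal usage sketch of LlamaInference (not part of the uploaded files); the checkpoint path below is a placeholder, not a value from this commit:

from llama_inference import LlamaInference

# Hypothetical base checkpoint; substitute your own path or hub ID, and optionally a PEFT adapter dir.
llm = LlamaInference(
    ckpt_path="meta-llama/Llama-3.1-8B-Instruct",
    peft_model=False,
    max_new_tokens=1024,
    do_sample=False,
)

answer = llm("Summarize the following transcript: ...")
# Returns the generated text, or an int (the prompt's token count) if the prompt exceeded max_prompt_tokens.
print(answer)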
single_video.py ADDED
@@ -0,0 +1,70 @@
from pathlib import Path

from lutils import openf, writef

from src.data.chapters import sec_to_hms
from tools.extract.asr import ASRProcessor


class SingleVideo:
    """
    A simplified implementation of the src.data.chapters.Chapters interface for single video inference.

    This class mimics the behavior of the ChaptersASR class but is designed to work with
    a single video file rather than a dataset. It provides the necessary methods
    required by the PromptASR class for generating chapter timestamps and titles.

    Note: This class is intended for inference only and should not be used for
    training or evaluation purposes.
    """

    def __init__(self, video_path: Path):
        self.video_path = video_path
        self.video_ids = [video_path.stem]
        assert video_path.exists(), f"Video file {video_path} not found"
        self.asr, self.duration = get_asr(video_path, overwrite=True)

    def __len__(self):
        return len(self.video_ids)

    def __iter__(self):
        return iter(self.video_ids)

    def __contains__(self, vid_id):
        return vid_id in self.video_ids

    def get_duration(self, vid_id, hms=False):
        assert vid_id == self.video_ids[0], f"Invalid video ID: {vid_id}"
        if hms:
            return sec_to_hms(self.duration)
        return self.duration

    def get_asr(self, vid_id):
        assert vid_id == self.video_ids[0], f"Invalid video ID: {vid_id}"
        return self.asr


def get_asr(video_path: Path, overwrite=False):
    output_dir = Path(f"outputs/inference/{video_path.stem}")
    asr_output = output_dir / "asr.txt"
    duration_output = output_dir / "duration.txt"
    if asr_output.exists() and duration_output.exists() and not overwrite:
        asr = openf(asr_output)
        asr = "\n".join(asr) + "\n"

        duration = openf(duration_output)
        assert isinstance(duration, list) and len(duration) == 1, (
            f"Duration is not a list of length 1: {duration}"
        )
        duration = float(duration[0])
        assert duration > 0, f"Duration is not positive: {duration}"
        return asr, duration

    print(f"\n=== 🎙️ Processing ASR for {video_path} ===")
    asr_processor = ASRProcessor()
    asr, duration = asr_processor.get_asr(video_path)
    print(f"=== ✅ ASR processing complete for {video_path} ===\n")
    output_dir.mkdir(parents=True, exist_ok=True)
    writef(asr_output, asr)
    writef(duration_output, str(duration))
    return asr, duration
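
A minimal sketch of driving SingleVideo on its own (not part of the uploaded files); the video path is a placeholder, and ASR extraction runs on construction:

from pathlib import Path

from single_video import SingleVideo

video = SingleVideo(Path("videos/my_video.mp4"))  # hypothetical path; ASR is extracted in __init__
vid_id = video.video_ids[0]
print(video.get_duration(vid_id, hms=True))  # duration formatted as HH:MM:SS
print(video.get_asr(vid_id)[:200])           # first characters of the timestamped transcript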
utils_asr.py ADDED
@@ -0,0 +1,95 @@
from lutils import openf, writef

from src.data.chapters import Chapters, sec_to_hms
from src.data.prompt import Prompt
from src.utils import RankedLogger

log = RankedLogger(__name__, rank_zero_only=True)


class ChaptersASR(Chapters):
    def __init__(self, vidc_dir: str = "dataset/", subset=""):
        super().__init__(vidc_dir=vidc_dir, subset=subset)

        self._asrs = None

    @property
    def asrs(self):
        if self._asrs is None:
            self.load_asr_data()
        return self._asrs

    def load_asr_data(self):
        if self._asrs is not None:
            return

        if self.subset:
            asr_pth = self.vidc_dir / f"docs/subset_data/asrs/asrs_{self.subset}.json"
            if asr_pth.exists():
                self._asrs = openf(asr_pth)
            else:
                log.info(f"ASR data not found for subset {self.subset}.")
                asr_val_pth = self.vidc_dir / "docs/subset_data/asrs/asrs_val.json"
                asr_train_pth = self.vidc_dir / "docs/subset_data/asrs/asrs_train.json"
                if "val" in self.subset and asr_val_pth.exists():
                    log.info("Loading from ASR validation file.")
                    asrs = openf(asr_val_pth)
                elif "train" in self.subset and asr_train_pth.exists():
                    log.info("Loading from ASR training file.")
                    asrs = openf(asr_train_pth)
                else:
                    log.info("Loading from ASR file.")
                    asrs = openf(self.vidc_dir / "docs/asrs.json")
                video_ids = set(self.video_ids) & set(asrs.keys())
                self._asrs = {vid_id: asrs[vid_id] for vid_id in video_ids}
                asr_pth.parent.mkdir(exist_ok=True)
                writef(asr_pth, self._asrs)
        else:
            self._asrs = openf(self.vidc_dir / "docs/asrs.json")

    def get_asr(self, video_id, add_end=False):
        if video_id not in self.asrs:
            return None

        asr = self.asrs[video_id]
        asr_clean = []
        for t, s, e in zip(asr["text"], asr["start"], asr["end"]):
            t = t.strip()
            s = sec_to_hms(s)
            e = sec_to_hms(e)
            if add_end:
                asr_clean.append(f"{s} - {e}: {t}")
            else:
                asr_clean.append(f"{s}: {t}")

        return "\n".join(asr_clean) + "\n"

    def __contains__(self, vid_id):
        return vid_id in self.asrs


class PromptASR(Prompt):
    def __init__(self, chapters: ChaptersASR, add_end=False):
        super().__init__(chapters=chapters)
        self.add_end = add_end

    def get_task_prompt(self):
        return "segment the text into distinct chapters based on thematic shifts or changes in topics.\n"

    def get_transcript(self, vid_id):
        vid_asr = self.chapters.get_asr(vid_id, add_end=self.add_end)
        assert vid_asr is not None, f"ASR not found for video ID: {vid_id}"
        return vid_asr

    def __contains__(self, vid_id):
        return vid_id in self.chapters


if __name__ == "__main__":
    chapters = ChaptersASR(subset="s10k_train")
    vid_id = chapters.sample()

    prompt = PromptASR(chapters=chapters)
    print(prompt.get_prompt_train(vid_id))
    print(prompt.get_transcript(vid_id))
    print(prompt.get_output(vid_id))
vidchapters.py ADDED
@@ -0,0 +1,107 @@
from pathlib import Path

from lutils import writef
from tqdm import tqdm

from src.test.utils_chapters import extract_chapters, filter_chapters
from src.utils import RankedLogger

log = RankedLogger(__name__, rank_zero_only=True)


def get_chapters(
    inference,
    prompt,
    max_new_tokens,
    do_sample=False,
    vid_duration=None,
    use_cache=True,
    vid_id="",
):
    output_text = inference(
        prompt=prompt,
        max_new_tokens=max_new_tokens,
        add_special_tokens=True,
        do_sample=do_sample,
        use_cache=use_cache,
    )

    if isinstance(output_text, int):
        # the input is too long, return the length of the input
        return output_text, None

    chapters = extract_chapters(output_text)
    chapters = filter_chapters(chapters, vid_duration=vid_duration)

    if not chapters and not do_sample:
        log.info(f"No chapters found for {vid_id}, trying again with sampling")
        return get_chapters(
            inference,
            prompt,
            max_new_tokens,
            do_sample=True,
            vid_duration=vid_duration,
            use_cache=use_cache,
            vid_id=vid_id,
        )

    return output_text, chapters


class VidChaptersTester:
    def __init__(self, save_dir: str, do_sample=False, **kwargs):
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(exist_ok=True)
        self.do_sample = do_sample

    def __call__(
        self,
        inference,
        test_dataloader,
        max_new_tokens=1024,
    ):
        pbar = tqdm(
            total=len(test_dataloader),
            desc="Evaluating chapters",
        )

        for batch in test_dataloader:
            vid_id = batch["vid_id"][0]
            prompt = batch["prompt"][0]
            transcript = batch["transcript"][0]
            vid_duration = batch["vid_duration"][0]
            prompt += transcript

            chapters_pth = self.save_dir / f"{vid_id[:2]}" / f"{vid_id}.json"
            chapters_pth.parent.mkdir(exist_ok=True)

            if chapters_pth.exists():
                pbar.update(1)
                continue

            pbar.set_description(f"vid_id: {vid_id}")

            output_text, chapters = get_chapters(
                inference,
                prompt,
                max_new_tokens,
                do_sample=self.do_sample,
                vid_duration=vid_duration,
                vid_id=vid_id,
            )

            if chapters is None:
                log.info(f"Input too long for {vid_id}, {output_text} tokens")
                error_pth = chapters_pth.with_suffix(".txt")
                writef(error_pth, [output_text])
                pbar.update(1)
                continue

            if chapters:
                vid_data = {
                    "chapters": chapters,
                    "output": output_text,
                }
                writef(chapters_pth, vid_data)

            pbar.update(1)

        pbar.close()
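
A minimal end-to-end sketch of VidChaptersTester driven by LlamaInference (not part of the uploaded files); the checkpoint path and the single-item stand-in for the test dataloader are hypothetical, and in the repo the batches come from the actual test dataloader:

from llama_inference import LlamaInference
from vidchapters import VidChaptersTester

llm = LlamaInference(ckpt_path="meta-llama/Llama-3.1-8B-Instruct")  # hypothetical checkpoint

# Stand-in for the test dataloader: one batch of size 1, with the keys read in __call__ above.
test_dataloader = [
    {
        "vid_id": ["abc123def45"],
        "prompt": ["segment the text into distinct chapters based on thematic shifts or changes in topics.\n"],
        "transcript": ["00:00:00: intro\n00:01:30: first topic\n"],
        "vid_duration": [300.0],
    }
]

tester = VidChaptersTester(save_dir="outputs/chapters", do_sample=False)
# If chapters are extracted, writes outputs/chapters/ab/abc123def45.json; otherwise an .txt error file.
tester(llm, test_dataloader, max_new_tokens=1024)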