lucas-ventura committed on
Commit
7a4927a
·
verified ·
1 Parent(s): 2e23f3d

Upload 3 files

Browse files
Files changed (3) hide show
  1. chapters.py +341 -0
  2. prompt.py +93 -0
  3. single_video.py +70 -0
chapters.py ADDED
@@ -0,0 +1,341 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from pathlib import Path
3
+
4
+ from lutils import openf, writef
5
+
6
+
7
class Chapters:
    """Accessor for a chapter-annotated video dataset stored as JSON.

    ``self.data`` maps each video ID to that video's info dict. The accessors
    below read the ``chapters``, ``duration``, ``title``, ``description``,
    ``channel_id`` and ``view_count`` keys from it.
    """

    def __init__(self, vidc_dir: str = "dataset/", subset="", videos_dir="videos"):
        """Load the chapter data for ``subset`` from ``vidc_dir``.

        Args:
            vidc_dir: Root directory of the dataset; must exist.
            subset: Name of the subset to load; ``""`` loads every video.
            videos_dir: Directory (relative to ``vidc_dir``) holding the videos.
        """
        self.vidc_dir = Path(vidc_dir)
        assert self.vidc_dir.exists(), f"Directory {vidc_dir} does not exist."
        self.subset = subset

        self.data = self.load_subset_data(subset=subset)
        self.video_ids = list(self.data.keys())
        assert len(self.video_ids) == len(self.data), (
            f"len(data)= {len(self.data)} != len(ids)= {len(self.video_ids)}."
        )

        self.videos_dir = videos_dir

    def get_subset_ids(self, subset: str):
        """Return the content of ``docs/subset_data/{subset}.json``."""
        return openf(self.vidc_dir / f"docs/subset_data/{subset}.json")

    def load_subset_data(self, subset=""):
        """Return the ``{video_id: info}`` mapping for ``subset``.

        With an empty ``subset`` the full ``docs/chapters.json`` is returned.
        Otherwise a per-subset file is read if present; if not, it is built
        from the full file and the subset's ID list, then written to disk.
        """
        if subset == "":
            data_path = self.vidc_dir / "docs/chapters.json"
            assert data_path.exists(), f"Data file {data_path} does not exist."
            return openf(data_path)

        data_path = self.vidc_dir / f"docs/subset_data/chapters/chapters_{subset}.json"
        if data_path.exists():
            return openf(data_path)

        # First use of this subset: filter the full data and cache the result.
        video_ids = self.get_subset_ids(subset)
        full_data = openf(self.vidc_dir / "docs/chapters.json")
        data = {video_id: full_data[video_id] for video_id in video_ids}
        data_path.parent.mkdir(exist_ok=True)
        # NOTE(review): single_video.py calls writef(path, data); the argument
        # order here is the opposite - confirm against the lutils signature.
        writef(data, data_path)
        return data

    def __len__(self):
        return len(self.video_ids)

    def __iter__(self):
        return iter(self.video_ids)

    def __contains__(self, vid_id):
        return vid_id in self.data

    def __getitem__(self, idx):
        """Return the info dict for a positional index (int) or a video ID (str).

        Raises:
            ValueError: If ``idx`` is neither an ``int`` nor a ``str``.
        """
        if isinstance(idx, int):
            video_id = self.video_ids[idx]
            video_info = self.get_video_info(video_id)
            # This writes into the dict stored in self.data, so the
            # "video_id" key stays on the stored entry afterwards.
            video_info["video_id"] = video_id
            return video_info
        if isinstance(idx, str):
            return self.get_video_info(idx)
        raise ValueError(f"Invalid index type {type(idx)}.")

    def get_video_info(self, video_id):
        """Return the stored info dict for ``video_id``."""
        assert video_id in self.data, f"Video ID {video_id} not found in data."
        return self.data[video_id]

    def get_chapters(self, video_id, hms=False, segments=False):
        """Retrieve chapters for a specific video ID.

        Args:
            video_id: ID of the video.
            hms: If true, keys are converted with ``sec_to_hms``; otherwise
                with ``hms_to_sec``.
            segments: If true, return ``{(start, end): label}`` instead of
                ``{start: label}``; requires ``hms`` to be false.
        """
        video_info = self.get_video_info(video_id)

        vid_chapters = video_info.get("chapters", {})
        chapter_timestamps = {}
        for time, label in vid_chapters.items():
            time = sec_to_hms(time) if hms else hms_to_sec(time)
            chapter_timestamps[time] = label
        if not segments:
            return chapter_timestamps

        # If segments is True, we return the timestamps as segments
        assert not hms, "hms must be False if segments is True."
        start_times = list(chapter_timestamps.keys())
        # Each chapter ends where the next one starts; the last ends at the
        # video duration.
        end_times = start_times[1:] + [self.get_duration(video_id)]
        return {
            (start_time, end_time): chapter_timestamps[start_time]
            for start_time, end_time in zip(start_times, end_times)
        }

    def get_labels(self, video_id):
        """Retrieve a list of chapter labels for a specific video ID."""
        chapters = self.get_chapters(video_id)
        return list(chapters.values())

    def get_timestamps(
        self, video_id, zero_handling="default", duration_handling="default"
    ):
        """Retrieve a list of chapter timestamps for a specific video ID.

        Args:
            video_id: ID of the video.
            zero_handling: ``"add"`` prepends 0 when missing, ``"remove"``
                drops every 0, ``"default"`` leaves the list untouched.
            duration_handling: ``"add"`` appends the video duration when
                missing, ``"remove"`` drops a trailing duration,
                ``"default"`` leaves the list untouched.

        An empty timestamp list is returned unchanged in every mode.
        """
        assert zero_handling in [
            "default",
            "add",
            "remove",
        ], f"Invalid zero handling {zero_handling}."

        assert duration_handling in [
            "default",
            "add",
            "remove",
        ], f"Invalid duration handling {duration_handling}."

        chapters = self.get_chapters(video_id)
        timestamps = [int(time) for time in chapters]

        # Handle zero timestamps based on the flag
        if zero_handling == "add":
            if timestamps and timestamps[0] != 0:
                timestamps = [0] + timestamps
        elif zero_handling == "remove":
            timestamps = [time for time in timestamps if time != 0]

        # The `timestamps and ...` guards keep an empty list from raising
        # IndexError on timestamps[-1].
        if duration_handling == "add":
            duration = self.get_duration(video_id)
            if timestamps and timestamps[-1] != duration:
                timestamps = timestamps + [duration]
        elif duration_handling == "remove":
            duration = self.get_duration(video_id)
            if timestamps and timestamps[-1] == duration:
                timestamps = timestamps[:-1]

        return timestamps

    def get_n_timestamps(self, video_id, zero_handling="default"):
        """Retrieve the number of chapter timestamps for a specific video ID."""
        return len(self.get_timestamps(video_id, zero_handling=zero_handling))

    def get_n_chapters(self, video_id):
        """Return the number of ground-truth segments of ``video_id``."""
        return len(self.get_gt_segments(video_id))

    def get_n_labels(self, video_id):
        """Return the number of chapter labels of ``video_id``."""
        return len(self.get_labels(video_id))

    def get_duration(self, video_id, hms=False):
        """Retrieve the duration of a specific video ID.

        Returns the stored ``duration`` value, or its ``sec_to_hms`` string
        when ``hms`` is true.
        """
        video_info = self.get_video_info(video_id)
        duration = video_info.get("duration")
        if hms:
            return sec_to_hms(duration)
        return duration

    def get_hms_duration(self, video_id, string=True):
        """Retrieve the duration of a specific video ID in hours, minutes, and seconds.

        Returns ``"hh:mm:ss"`` when ``string`` is true, else an ``(h, m, s)``
        tuple.
        """
        # get_duration() returns a scalar, so it has to be split here.
        h, m, s = sec_to_hms(self.get_duration(video_id), string=False)
        if string:
            return f"{h:02d}:{m:02d}:{s:02d}"
        return h, m, s

    def get_title(self, video_id):
        """Retrieve the title of a specific video ID."""
        video_info = self.get_video_info(video_id)
        return video_info.get("title")

    def get_description(self, video_id):
        """Retrieve the description of a specific video ID."""
        video_info = self.get_video_info(video_id)
        return video_info.get("description")

    def get_channel_id(self, video_id):
        """Retrieve the channel ID of a specific video ID."""
        video_info = self.get_video_info(video_id)
        return video_info.get("channel_id")

    def get_view_count(self, video_id):
        """Retrieve the view count of a specific video ID."""
        video_info = self.get_video_info(video_id)
        return video_info.get("view_count")

    def get_video_path(self, video_id):
        """Retrieve the path to the video file for a specific video ID.

        Videos live at ``vidc_dir/videos_dir/<first two chars of id>/<id>.mp4``.
        """
        video_pth = (
            self.vidc_dir / self.videos_dir / f"{video_id[:2]}" / f"{video_id}.mp4"
        )
        assert video_pth.exists(), f"Video file {video_pth} does not exist."
        return str(video_pth)

    def sample(self, n=1):
        """Sample n video IDs; a single ID (not a list) is returned for n == 1."""
        sample = random.sample(self.video_ids, n)

        if n == 1:
            return sample[0]
        return sample

    def get_gt_segments(self, video_id, zero_handling="add"):
        """Generate ground truth segments based on video ID with options to adjust zero timestamps."""
        timestamps = self.get_timestamps(video_id, zero_handling=zero_handling)
        return boundary2seg(
            timestamps, self.get_duration(video_id), zero_handling=zero_handling
        )

    def get_segments(self, video_id, zero_handling="add"):
        """Alias of :meth:`get_gt_segments`."""
        return self.get_gt_segments(
            video_id,
            zero_handling=zero_handling,
        )

    def get_all_gt_segments(self, zero_handling="add"):
        """Generate ground truth segments for all video IDs."""
        return {
            video_id: self.get_gt_segments(video_id, zero_handling=zero_handling)
            for video_id in self.video_ids
        }

    def get_pred_segments(self, vid_id, vid_preds, zero_handling="add"):
        """Convert the predictions for one video into segments.

        Args:
            vid_id: ID of the video (used to look up its duration).
            vid_preds: Either a list of timestamps (numbers or time strings)
                or a dict keyed by chapter start time.
            zero_handling: Forwarded to ``boundary2seg`` for list input.

        Returns:
            A list of ``(start, end)`` tuples for list input, a
            ``{(start, end): value}`` dict for dict input, and ``None`` for
            any other input type.
        """
        duration = self.get_duration(vid_id)
        if isinstance(vid_preds, list):
            # vid_preds are the timestamps; only the first element is
            # inspected to decide whether they need converting to seconds.
            if vid_preds and isinstance(vid_preds[0], str):
                vid_preds = [hms_to_sec(hms) for hms in vid_preds]
            return boundary2seg(vid_preds, duration, zero_handling=zero_handling)
        if isinstance(vid_preds, dict):
            # vid_preds are the chapters with key timestamps
            start_times = list(vid_preds.keys())
            end_times = start_times[1:] + [duration]
            return {
                (hms_to_sec(start_time), hms_to_sec(end_time)): vid_preds[start_time]
                for start_time, end_time in zip(start_times, end_times)
            }
        return None

    def convert_predictions_to_segments(self, preds):
        """Apply :meth:`get_pred_segments` to every ``{video_id: preds}`` entry."""
        return {
            video_id: self.get_pred_segments(video_id, vid_preds)
            for video_id, vid_preds in preds.items()
        }

    def get_link(self, video_id):
        """Return the YouTube watch URL for ``video_id``."""
        return f"https://www.youtube.com/watch?v={video_id}"

    def get_url(self, video_id):
        """Return the YouTube watch URL for ``video_id``."""
        return f"https://www.youtube.com/watch?v={video_id}"

    @staticmethod
    def sec_to_hms(seconds, string=True, short=False):
        """Forward to the module-level ``sec_to_hms``, honouring all arguments."""
        return sec_to_hms(seconds, string=string, short=short)

    @staticmethod
    def hms_to_sec(time_str, enable_single_part=False):
        """Forward to the module-level ``hms_to_sec``."""
        return hms_to_sec(time_str, enable_single_part=enable_single_part)

    @staticmethod
    def clean_segment(segment, zero_handling="add"):
        """Forward to the module-level ``clean_segment``."""
        return clean_segment(segment, zero_handling=zero_handling)

    @staticmethod
    def clean_timestamps(timestamps, zero_handling="remove"):
        """Forward to the module-level ``clean_tiemstamps`` (sic)."""
        return clean_tiemstamps(timestamps, zero_handling=zero_handling)
265
+
266
+
267
def boundary2seg(boundaries, duration, zero_handling="add"):
    """Turn a sequence of boundary timestamps into ``(start, end)`` segments.

    Args:
        boundaries: Segment start times, in order.
        duration: End time used to close the last segment.
        zero_handling: With ``"add"``, a leading 0 is inserted when the first
            boundary is not 0; any other value leaves the boundaries as given.

    Returns:
        A list of ``(float, float)`` tuples. Consecutive boundaries form the
        segments; a final ``(last, duration)`` segment is appended unless the
        last boundary already equals ``duration``. Empty input yields ``[]``.
    """
    # Copy so that tuples and other sequences work and the caller's list is
    # never aliased.
    boundaries = list(boundaries)
    if not boundaries:
        return []

    if zero_handling == "add" and boundaries[0] != 0:
        boundaries = [0] + boundaries

    segments = [
        (float(start), float(end))
        for start, end in zip(boundaries, boundaries[1:])
    ]
    # Close the final segment unless the last boundary equals the duration.
    if boundaries[-1] != duration:
        segments.append((float(boundaries[-1]), float(duration)))
    return segments
280
+
281
+
282
def sec_to_hms(seconds, string=True, short=False):
    """Convert seconds to hours, minutes, and seconds.

    Args:
        seconds: A number of seconds (int or float; floats are truncated), a
            numeric string such as ``"75"`` or ``"75.9"``, or a
            colon-separated time string, which is first converted with
            ``hms_to_sec``.
        string: If true, return a formatted string; otherwise an
            ``(h, m, s)`` tuple.
        short: With ``string``, drop the hours field when it is zero
            (``"mm:ss"``).

    Returns:
        ``"hh:mm:ss"`` / ``"mm:ss"`` or an ``(h, m, s)`` tuple.

    Raises:
        ValueError: If ``seconds`` is a string that is neither numeric nor
            colon-separated.
    """
    if isinstance(seconds, str):
        if ":" in seconds:
            return sec_to_hms(hms_to_sec(seconds), string=string, short=short)
        # Plain digits stay exact ints; anything else (e.g. "75.9") is parsed
        # as a float and truncated below.
        seconds = int(seconds) if seconds.isdigit() else float(seconds)
    if isinstance(seconds, float):
        seconds = int(seconds)

    m, s = divmod(seconds, 60)
    h, m = divmod(m, 60)
    if not string:
        return h, m, s
    if h == 0 and short:
        return f"{m:02d}:{s:02d}"
    return f"{h:02d}:{m:02d}:{s:02d}"
295
+
296
+
297
def hms_to_sec(time_str, enable_single_part=False):
    """Convert a time value to a total number of seconds.

    Args:
        time_str: An int/float (returned unchanged), an all-digit string
            (returned as int), or a ``"hh:mm:ss"`` / ``"mm:ss"`` string whose
            seconds field may carry a decimal part.
        enable_single_part: Also accept a single colon-free field such as
            ``"12.5"``.

    Returns:
        The total seconds (int, or float when the seconds field has a
        decimal point). ``False`` is returned when the seconds field is >= 60
        or, for the three-field form, the minutes field is >= 60.

    Raises:
        ValueError: If the string has an unsupported number of fields.
    """
    # Numbers are passed straight through.
    if isinstance(time_str, (int, float)):
        return time_str
    if isinstance(time_str, str) and time_str.isdigit():
        return int(time_str)

    def _parse_seconds(text):
        # A decimal point selects float parsing; otherwise the field is an int.
        return float(text) if "." in text else int(text)

    fields = time_str.split(":")
    n_fields = len(fields)

    if n_fields == 3:
        hh, mm, ss = fields
        secs = _parse_seconds(ss)
        mins = int(mm)
        if mins >= 60 or secs >= 60:
            return False
        return int(hh) * 3600 + mins * 60 + secs

    if n_fields == 2:
        mm, ss = fields
        secs = _parse_seconds(ss)
        mins = int(mm)
        # Minutes are not range-checked in the two-field form.
        if secs >= 60:
            return False
        return mins * 60 + secs

    if n_fields == 1 and enable_single_part:
        return _parse_seconds(fields[0])

    raise ValueError("Invalid time format")
325
+
326
+
327
def clean_segment(segment, zero_handling="add"):
    """Add or remove the leading zero-start segment, in place.

    Args:
        segment: List of ``[start, end]`` pairs; it is modified in place.
        zero_handling: ``"add"`` inserts ``[0.0, first_start]`` at the front
            when the first segment does not start at 0; ``"remove"`` drops the
            first segment when it starts at 0; any other value is a no-op.

    Returns:
        The same ``segment`` list. An empty list is returned unchanged.
    """
    # Nothing to inspect or adjust; also avoids IndexError on segment[0].
    if not segment:
        return segment

    first_start = segment[0][0]
    if zero_handling == "add" and first_start != 0.0:
        segment.insert(0, [0.0, first_start])
    elif zero_handling == "remove" and first_start == 0.0:
        segment.pop(0)
    return segment
333
+
334
+
335
def clean_tiemstamps(timestamps, zero_handling="remove"):
    """Add or remove zero timestamps from a list of timestamps.

    The misspelled name is kept because ``Chapters.clean_timestamps`` calls
    it under this name.

    Args:
        timestamps: List of timestamps.
        zero_handling: ``"remove"`` drops every 0, ``"add"`` prepends 0 when
            the first timestamp is not 0, any other value returns the input
            unchanged.

    Returns:
        The adjusted list. With ``"add"``, an empty list is returned
        unchanged, matching ``Chapters.get_timestamps``.
    """
    if zero_handling == "remove":
        return [time for time in timestamps if time != 0]
    if zero_handling == "add":
        # The truthiness guard avoids IndexError on an empty list.
        if timestamps and timestamps[0] != 0:
            return [0] + timestamps
        return timestamps
    return timestamps
prompt.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from src.data.chapters import Chapters, sec_to_hms
2
+
3
+
4
class Prompt:
    """Base class that assembles chaptering prompts and reference outputs.

    Subclasses must implement ``__contains__``, ``get_task_prompt`` and
    ``get_transcript``; everything else is built on top of those and on the
    ``chapters`` object supplied at construction.
    """

    def __init__(
        self,
        chapters: "Chapters",
    ):
        """Store the ``chapters`` object used for durations and chapter lists."""
        self.chapters = chapters

    def __contains__(self, vid_id):
        raise NotImplementedError(
            "Subclasses must implement the '__contains__' method."
        )

    def get_duration_prompt(self, vid_id: str) -> str:
        """Return the opening clause stating the video duration."""
        duration = self.chapters.get_duration(vid_id, hms=True)
        return f"Given the complete transcript of a video of duration {duration}, "

    def get_task_prompt(self) -> str:
        """Return the task description; must be provided by subclasses."""
        raise NotImplementedError(
            "Subclasses must implement the 'get_task_prompt' method."
        )

    def get_format_instruction(self) -> str:
        """Return the instruction describing the expected output format."""
        return "Identify the approximate start time of each chapter in the format 'hh:mm:ss - Title'. "

    def get_new_line_instruction(self) -> str:
        """Return the instruction asking for one chapter per line."""
        return "Ensure each chapter entry is on a new line. "

    def get_focus_instruction(self) -> str:
        """Return the instruction on which topic changes merit a chapter."""
        return "Focus on significant topic changes that would merit a new chapter in a video, "

    def get_no_summaries_instruction(self) -> str:
        """Return the clause that forbids chapter summaries."""
        return "but do not provide summaries of the chapters.\n"

    def get_transcript_introduction(self) -> str:
        """Return the line that introduces the transcript."""
        return "Here is the transcript to analyze:\n"

    def get_transcript(self, vid_id: str) -> str:
        """Return the transcript for ``vid_id``; must be provided by subclasses."""
        # By default, the transcript is the same for train and test
        raise NotImplementedError(
            "Subclasses must implement the 'get_transcript' method."
        )

    def get_transcript_train(self, vid_id: str) -> str:
        """Return the training transcript (defaults to ``get_transcript``)."""
        return self.get_transcript(vid_id)

    def get_transcript_test(self, vid_id: str) -> str:
        """Return the test transcript (defaults to ``get_transcript``)."""
        return self.get_transcript(vid_id)

    def get_base_prompt(self, vid_id: str) -> str:
        """Concatenate all instruction fragments, in order, into one prompt."""
        prompt_parts = [
            self.get_duration_prompt(vid_id),
            self.get_task_prompt(),
            self.get_format_instruction(),
            self.get_new_line_instruction(),
            self.get_focus_instruction(),
            self.get_no_summaries_instruction(),
            self.get_transcript_introduction(),
        ]
        return "".join(prompt_parts)

    def get_prompt_train(self, vid_id: str) -> str:
        """Return the training prompt (defaults to ``get_base_prompt``)."""
        return self.get_base_prompt(vid_id)

    def get_prompt_test(self, vid_id: str) -> str:
        """Return the test prompt (defaults to ``get_base_prompt``)."""
        return self.get_base_prompt(vid_id)

    def get_output(self, vid_id: str) -> str:
        """Return the reference answer: one ``"<time> - <title>"`` line per chapter."""
        vid_chapters = self.chapters.get_chapters(vid_id)
        answers = [
            f"{sec_to_hms(chp_time)} - {chp_title}"
            for chp_time, chp_title in vid_chapters.items()
        ]
        return "\n".join(answers)

    def get_dialog(self, vid_id: str) -> list:
        """Return a two-message dialog for ``vid_id``.

        The first message has role ``"user"`` and contains the training prompt
        followed by the training transcript; the second has role
        ``"assistant"`` and contains the reference output.
        """
        prompt = self.get_prompt_train(vid_id)
        transcript = self.get_transcript_train(vid_id)
        output = self.get_output(vid_id)
        return [
            {
                "role": "user",
                "content": prompt + transcript,
            },
            {
                "role": "assistant",
                "content": output,
            },
        ]
single_video.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ from lutils import openf, writef
4
+
5
+ from src.data.chapters import sec_to_hms
6
+ from tools.extract.asr import ASRProcessor
7
+
8
+
9
class SingleVideo:
    """
    Minimal stand-in for the src.data.chapters.Chapters interface that serves one video.

    Per the original author's notes, it mirrors the ChaptersASR class for a
    single video file instead of a dataset and exposes the methods the
    PromptASR class needs to produce chapter timestamps and titles.

    Note: inference only - not meant for training or evaluation.
    """

    def __init__(self, video_path: Path):
        self.video_path = video_path
        # The file stem is the only video ID this object knows about.
        self.video_ids = [video_path.stem]
        assert video_path.exists(), f"Video file {video_path} not found"
        # overwrite=True: the ASR is recomputed even if cached files exist.
        transcript, length = get_asr(video_path, overwrite=True)
        self.asr = transcript
        self.duration = length

    def __len__(self):
        return len(self.video_ids)

    def __iter__(self):
        return iter(self.video_ids)

    def __contains__(self, vid_id):
        return vid_id in self.video_ids

    def _require_known(self, vid_id):
        """Assert that ``vid_id`` is the ID of the wrapped video."""
        assert vid_id == self.video_ids[0], f"Invalid video ID: {vid_id}"

    def get_duration(self, vid_id, hms=False):
        """Return the video duration, converted with ``sec_to_hms`` when ``hms`` is true."""
        self._require_known(vid_id)
        return sec_to_hms(self.duration) if hms else self.duration

    def get_asr(self, vid_id):
        """Return the ASR transcript computed at construction time."""
        self._require_known(vid_id)
        return self.asr
45
+
46
+
47
def get_asr(video_path: Path, overwrite=False):
    """Return ``(asr, duration)`` for a video, using cached files when allowed.

    Results are stored under ``outputs/inference/<video stem>/`` as
    ``asr.txt`` and ``duration.txt``. When both files exist and ``overwrite``
    is false they are read back; otherwise the ASR is computed with
    ``ASRProcessor`` and both files are (re)written.
    """
    cache_dir = Path(f"outputs/inference/{video_path.stem}")
    asr_file = cache_dir / "asr.txt"
    duration_file = cache_dir / "duration.txt"

    cache_hit = asr_file.exists() and duration_file.exists()
    if cache_hit and not overwrite:
        # presumably openf returns a list of lines for .txt files - TODO confirm
        lines = openf(asr_file)
        transcript = "\n".join(lines) + "\n"

        raw_duration = openf(duration_file)
        assert isinstance(raw_duration, list) and len(raw_duration) == 1, (
            f"Duration is not a list of length 1: {raw_duration}"
        )
        seconds = float(raw_duration[0])
        assert seconds > 0, f"Duration is not positive: {seconds}"
        return transcript, seconds

    print(f"\n=== 🎙️ Processing ASR for {video_path} ===")
    transcript, seconds = ASRProcessor().get_asr(video_path)
    print(f"=== ✅ ASR processing complete for {video_path} ===\n")

    cache_dir.mkdir(parents=True, exist_ok=True)
    # NOTE(review): chapters.py calls writef(data, path); the argument order
    # here is (path, data) - confirm against the lutils signature.
    writef(asr_file, transcript)
    writef(duration_file, str(seconds))
    return transcript, seconds