Spaces:
Running
on
Zero
Running
on
Zero
Upload 3 files
Browse files- chapters.py +341 -0
- prompt.py +93 -0
- single_video.py +70 -0
chapters.py
ADDED
@@ -0,0 +1,341 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
from pathlib import Path
|
3 |
+
|
4 |
+
from lutils import openf, writef
|
5 |
+
|
6 |
+
|
7 |
+
class Chapters:
    """Accessor for per-video chapter metadata stored under a dataset directory.

    The metadata is a mapping from video ID to an info dict, read from
    ``<vidc_dir>/docs/chapters.json`` or from a per-subset file derived from
    it. The accessors below read the ``chapters``, ``duration``, ``title``,
    ``description``, ``channel_id`` and ``view_count`` keys of an info dict.
    """

    def __init__(self, vidc_dir: str = "dataset/", subset="", videos_dir="videos"):
        """Load the metadata of ``subset`` from ``vidc_dir``.

        Args:
            vidc_dir: Root directory of the dataset; it must exist.
            subset: Name of the subset to load; ``""`` loads every video.
            videos_dir: Directory, relative to ``vidc_dir``, that holds the
                video files (see ``get_video_path``).
        """
        self.vidc_dir = Path(vidc_dir)
        assert self.vidc_dir.exists(), f"Directory {vidc_dir} does not exist."
        self.subset = subset

        self.data = self.load_subset_data(subset=subset)
        self.video_ids = list(self.data.keys())
        assert len(self.video_ids) == len(self.data), (
            f"len(data)= {len(self.data)} != len(ids)= {len(self.video_ids)}."
        )

        self.videos_dir = videos_dir

    def get_subset_ids(self, subset: str):
        """Return the content of ``docs/subset_data/<subset>.json``."""
        return openf(self.vidc_dir / f"docs/subset_data/{subset}.json")

    def load_subset_data(self, subset=""):
        """Return the ``{video_id: info}`` mapping for ``subset``.

        With an empty ``subset`` the full ``docs/chapters.json`` is returned.
        Otherwise the per-subset file
        ``docs/subset_data/chapters/chapters_<subset>.json`` is read; when it
        does not exist yet it is built from the full file and the subset's
        ID list, then written to disk.
        """
        if subset == "":
            data_path = self.vidc_dir / "docs/chapters.json"
            assert data_path.exists(), f"Data file {data_path} does not exist."
            return openf(data_path)

        data_path = self.vidc_dir / f"docs/subset_data/chapters/chapters_{subset}.json"
        if data_path.exists():
            return openf(data_path)

        video_ids = self.get_subset_ids(subset)
        all_data = openf(self.vidc_dir / "docs/chapters.json")
        data = {video_id: all_data[video_id] for video_id in video_ids}
        data_path.parent.mkdir(parents=True, exist_ok=True)
        # NOTE(review): the argument order here is (data, path) — confirm it
        # against the signature of lutils.writef.
        writef(data, data_path)
        return data

    def __len__(self):
        return len(self.video_ids)

    def __iter__(self):
        return iter(self.video_ids)

    def __contains__(self, vid_id):
        return vid_id in self.data

    def __getitem__(self, idx):
        """Return the info dict selected by position (int) or by video ID (str).

        For an integer index a ``"video_id"`` key is added to the stored info
        dict before it is returned.

        Raises:
            ValueError: If ``idx`` is neither an ``int`` nor a ``str``.
        """
        if isinstance(idx, int):
            video_id = self.video_ids[idx]
            video_info = self.get_video_info(video_id)
            video_info["video_id"] = video_id
            return video_info
        if isinstance(idx, str):
            return self.get_video_info(idx)
        raise ValueError(f"Invalid index type {type(idx)}.")

    def get_video_info(self, video_id):
        """Return the stored info dict of ``video_id`` (asserts it is known)."""
        assert video_id in self.data, f"Video ID {video_id} not found in data."
        return self.data[video_id]

    def get_chapters(self, video_id, hms=False, segments=False):
        """Retrieve chapters for a specific video ID.

        Args:
            video_id: ID of the video.
            hms: Convert each chapter key with ``sec_to_hms`` when True,
                otherwise with ``hms_to_sec``.
            segments: When True, key the result by ``(start, end)`` tuples;
                the last segment ends at the video duration. Requires
                ``hms=False``.

        Returns:
            A dict mapping each converted timestamp (or segment) to its label.
        """
        # Fail fast: segments pair consecutive timestamps with the duration.
        if segments:
            assert not hms, "hms must be False if segments is True."

        raw_chapters = self.get_video_info(video_id).get("chapters", {})
        chapter_timestamps = {
            (sec_to_hms(timestamp) if hms else hms_to_sec(timestamp)): label
            for timestamp, label in raw_chapters.items()
        }
        if not segments:
            return chapter_timestamps

        start_times = list(chapter_timestamps)
        end_times = start_times[1:] + [self.get_duration(video_id)]
        return {
            (start, end): chapter_timestamps[start]
            for start, end in zip(start_times, end_times)
        }

    def get_labels(self, video_id):
        """Retrieve a list of chapter labels for a specific video ID."""
        return list(self.get_chapters(video_id).values())

    def get_timestamps(
        self, video_id, zero_handling="default", duration_handling="default"
    ):
        """Retrieve a list of chapter timestamps for a specific video ID.

        Args:
            video_id: ID of the video.
            zero_handling: ``"add"`` prepends 0 when the first timestamp is
                not 0, ``"remove"`` drops every 0, ``"default"`` does neither.
            duration_handling: ``"add"`` appends the duration when the last
                timestamp differs from it, ``"remove"`` drops a trailing
                timestamp equal to the duration, ``"default"`` does neither.

        Returns:
            The timestamps as a list of ints. An empty list is returned
            unchanged by both the zero and the duration handling.
        """
        assert zero_handling in [
            "default",
            "add",
            "remove",
        ], f"Invalid zero handling {zero_handling}."

        assert duration_handling in [
            "default",
            "add",
            "remove",
        ], f"Invalid duration handling {duration_handling}."

        timestamps = [int(time) for time in self.get_chapters(video_id)]

        # Handle zero timestamps based on the flag
        if zero_handling == "add":
            if timestamps and timestamps[0] != 0:
                timestamps = [0] + timestamps
        elif zero_handling == "remove":
            timestamps = [time for time in timestamps if time != 0]

        # Guard on emptiness: there is no last timestamp to compare when the
        # video has no chapters (or they were all removed above).
        if timestamps and duration_handling != "default":
            duration = self.get_duration(video_id)
            ends_at_duration = timestamps[-1] == duration
            if duration_handling == "add" and not ends_at_duration:
                timestamps = timestamps + [duration]
            elif duration_handling == "remove" and ends_at_duration:
                timestamps = timestamps[:-1]

        return timestamps

    def get_n_timestamps(self, video_id, zero_handling="default"):
        """Retrieve the number of chapter timestamps for a specific video ID."""
        return len(self.get_timestamps(video_id, zero_handling=zero_handling))

    def get_n_chapters(self, video_id):
        """Return the number of ground-truth segments of ``video_id``."""
        return len(self.get_gt_segments(video_id))

    def get_n_labels(self, video_id):
        """Return the number of chapter labels of ``video_id``."""
        return len(self.get_labels(video_id))

    def get_duration(self, video_id, hms=False):
        """Retrieve the duration of a specific video ID.

        Returns the stored ``duration`` value, passed through ``sec_to_hms``
        when ``hms`` is True.
        """
        duration = self.get_video_info(video_id).get("duration")
        if hms:
            return sec_to_hms(duration)
        return duration

    def get_hms_duration(self, video_id, string=True):
        """Retrieve the duration of a specific video ID in hours, minutes, and seconds.

        Returns an ``"hh:mm:ss"`` string when ``string`` is True, otherwise
        the ``(h, m, s)`` tuple.
        """
        # get_duration() yields a single value, so split it explicitly
        # instead of unpacking it.
        h, m, s = sec_to_hms(self.get_duration(video_id), string=False)
        if string:
            return f"{h:02d}:{m:02d}:{s:02d}"
        return h, m, s

    def get_title(self, video_id):
        """Retrieve the title of a specific video ID."""
        return self.get_video_info(video_id).get("title")

    def get_description(self, video_id):
        """Retrieve the description of a specific video ID."""
        return self.get_video_info(video_id).get("description")

    def get_channel_id(self, video_id):
        """Retrieve the channel ID of a specific video ID."""
        return self.get_video_info(video_id).get("channel_id")

    def get_view_count(self, video_id):
        """Retrieve the view count of a specific video ID."""
        return self.get_video_info(video_id).get("view_count")

    def get_video_path(self, video_id):
        """Retrieve the path to the video file for a specific video ID.

        The file is looked up at
        ``<vidc_dir>/<videos_dir>/<first two characters of the ID>/<ID>.mp4``
        and must exist.
        """
        video_pth = self.vidc_dir / self.videos_dir / video_id[:2] / f"{video_id}.mp4"
        assert video_pth.exists(), f"Video file {video_pth} does not exist."
        return str(video_pth)

    def sample(self, n=1):
        """Sample n video IDs; a single ID is returned bare, not in a list."""
        sample = random.sample(self.video_ids, n)
        return sample[0] if n == 1 else sample

    def get_gt_segments(self, video_id, zero_handling="add"):
        """Generate ground truth segments based on video ID with options to adjust zero timestamps."""
        timestamps = self.get_timestamps(video_id, zero_handling=zero_handling)
        return boundary2seg(
            timestamps, self.get_duration(video_id), zero_handling=zero_handling
        )

    def get_segments(self, video_id, zero_handling="add"):
        """Alias of ``get_gt_segments``."""
        return self.get_gt_segments(video_id, zero_handling=zero_handling)

    def get_all_gt_segments(self, zero_handling="add"):
        """Generate ground truth segments for all video IDs."""
        return {
            video_id: self.get_gt_segments(video_id, zero_handling=zero_handling)
            for video_id in self.video_ids
        }

    def get_pred_segments(self, vid_id, vid_preds, zero_handling="add"):
        """Convert the predictions of one video into segments.

        Args:
            vid_id: ID of the video; its duration closes the last segment.
            vid_preds: Either a list of timestamps (string entries are
                converted with ``hms_to_sec``) or a dict keyed by timestamps.
            zero_handling: Forwarded to ``boundary2seg`` for list input; it
                is not used for dict input.

        Returns:
            For a list, the result of ``boundary2seg``. For a dict, a dict
            mapping ``(start, end)`` tuples to the original values. ``None``
            for any other input type.
        """
        duration = self.get_duration(vid_id)
        if isinstance(vid_preds, list):
            # vid_preds are the timestamps; check for emptiness before
            # peeking at the first entry.
            if vid_preds and isinstance(vid_preds[0], str):
                vid_preds = [hms_to_sec(hms) for hms in vid_preds]
            return boundary2seg(vid_preds, duration, zero_handling=zero_handling)
        if isinstance(vid_preds, dict):
            # vid_preds are the chapters with key timestamps
            start_times = list(vid_preds.keys())
            end_times = start_times[1:] + [duration]
            return {
                (hms_to_sec(start_time), hms_to_sec(end_time)): vid_preds[start_time]
                for start_time, end_time in zip(start_times, end_times)
            }
        return None

    def convert_predictions_to_segments(self, preds):
        """Apply ``get_pred_segments`` to every ``{video_id: predictions}`` entry."""
        return {
            video_id: self.get_pred_segments(video_id, vid_preds)
            for video_id, vid_preds in preds.items()
        }

    def get_link(self, video_id):
        """Return the YouTube watch URL of ``video_id``."""
        return f"https://www.youtube.com/watch?v={video_id}"

    def get_url(self, video_id):
        """Alias of ``get_link``."""
        return self.get_link(video_id)

    @staticmethod
    def sec_to_hms(seconds, string=True, short=False):
        """Forward to the module-level ``sec_to_hms`` with the given options."""
        return sec_to_hms(seconds, string=string, short=short)

    @staticmethod
    def hms_to_sec(time_str, enable_single_part=False):
        """Forward to the module-level ``hms_to_sec``."""
        return hms_to_sec(time_str, enable_single_part=enable_single_part)

    @staticmethod
    def clean_segment(segment, zero_handling="add"):
        """Forward to the module-level ``clean_segment``."""
        return clean_segment(segment, zero_handling=zero_handling)

    @staticmethod
    def clean_timestamps(timestamps, zero_handling="remove"):
        """Forward to the module-level ``clean_tiemstamps`` (sic)."""
        return clean_tiemstamps(timestamps, zero_handling=zero_handling)
|
265 |
+
|
266 |
+
|
267 |
+
def boundary2seg(boundaries, duration, zero_handling="add"):
    """Turn a sequence of boundary timestamps into ``(start, end)`` segments.

    Args:
        boundaries: Segment start times, in order. May be empty.
        duration: End time used to close the last segment.
        zero_handling: With ``"add"``, a leading 0 is inserted when the
            first boundary is not 0 (or when there are no boundaries).

    Returns:
        A list of ``(float, float)`` tuples pairing consecutive boundaries.
        A final segment up to ``duration`` is appended unless the last
        boundary already equals ``duration``. Without any boundary (after
        zero handling) the list is empty.
    """
    # Copy so that tuples are accepted and the caller's list is never touched.
    boundaries = list(boundaries)
    if zero_handling == "add" and (not boundaries or boundaries[0] != 0):
        boundaries.insert(0, 0)

    segments = [
        (float(start), float(end))
        for start, end in zip(boundaries, boundaries[1:])
    ]
    # Close the last segment only if the last boundary is not the duration.
    if boundaries and boundaries[-1] != duration:
        segments.append((float(boundaries[-1]), float(duration)))
    return segments
|
280 |
+
|
281 |
+
|
282 |
+
def sec_to_hms(seconds, string=True, short=False):
    """Split a number of seconds into hours, minutes and seconds.

    Args:
        seconds: An int, a float (truncated to int), a string of digits, or
            a string containing ``":"`` (first converted with ``hms_to_sec``).
        string: Return a formatted string when True, otherwise the
            ``(h, m, s)`` tuple.
        short: With ``string=True``, drop the hours field when it is zero.

    Returns:
        ``"hh:mm:ss"`` (or ``"mm:ss"`` in the short form), or ``(h, m, s)``.
    """
    if isinstance(seconds, str):
        if ":" in seconds:
            # Already a clock string: normalise through seconds and retry.
            return sec_to_hms(hms_to_sec(seconds), string=string, short=short)
        if seconds.isdigit():
            seconds = int(seconds)
    elif isinstance(seconds, float):
        seconds = int(seconds)

    total_minutes, secs = divmod(seconds, 60)
    hours, mins = divmod(total_minutes, 60)

    if not string:
        return hours, mins, secs
    if short and hours == 0:
        return f"{mins:02d}:{secs:02d}"
    return f"{hours:02d}:{mins:02d}:{secs:02d}"
|
295 |
+
|
296 |
+
|
297 |
+
def hms_to_sec(time_str, enable_single_part=False):
    """Convert an ``hh:mm:ss`` or ``mm:ss`` string to a number of seconds.

    Args:
        time_str: A number (returned unchanged), a string of digits
            (returned as int), or a colon-separated time string whose
            seconds field may contain a decimal point.
        enable_single_part: Also accept a string without any colon, such as
            ``"12.5"``.

    Returns:
        The total seconds (a float when the seconds field has a decimal
        point), or ``False`` when the seconds field is 60 or more, or when
        the minutes field of a three-part string is 60 or more.

    Raises:
        ValueError: If the string does not have a supported number of parts.
    """
    if isinstance(time_str, (int, float)):
        return time_str
    if isinstance(time_str, str) and time_str.isdigit():
        return int(time_str)

    def _to_number(text):
        # Only go through float when a fractional part is written out.
        return float(text) if "." in text else int(text)

    fields = time_str.split(":")
    n_fields = len(fields)

    if n_fields == 3:
        secs = _to_number(fields[2])
        mins = int(fields[1])
        if mins >= 60 or secs >= 60:
            return False
        return int(fields[0]) * 3600 + mins * 60 + secs

    if n_fields == 2:
        secs = _to_number(fields[1])
        mins = int(fields[0])
        if secs >= 60:
            return False
        return mins * 60 + secs

    if n_fields == 1 and enable_single_part:
        return _to_number(fields[0])

    raise ValueError("Invalid time format")
|
325 |
+
|
326 |
+
|
327 |
+
def clean_segment(segment, zero_handling="add"):
    """Add or remove the leading zero-start entry of a list of segments.

    The list is modified in place and also returned.

    Args:
        segment: List of ``[start, end]`` pairs. An empty list is returned
            unchanged.
        zero_handling: ``"add"`` inserts ``[0.0, first_start]`` in front
            when the first segment does not start at 0; ``"remove"`` drops
            the first segment when it starts at 0; any other value leaves
            the list as is.

    Returns:
        The same list object that was passed in.
    """
    # Nothing to inspect: indexing the first segment would raise IndexError.
    if not segment:
        return segment

    first_start = segment[0][0]
    if zero_handling == "add" and first_start != 0.0:
        segment.insert(0, [0.0, first_start])
    elif zero_handling == "remove" and first_start == 0.0:
        segment.pop(0)
    return segment
|
333 |
+
|
334 |
+
|
335 |
+
def clean_tiemstamps(timestamps, zero_handling="remove"):
    """Add or remove zero entries in a list of timestamps.

    The misspelled name is kept because it is the name callers use.

    Args:
        timestamps: List of timestamps.
        zero_handling: ``"remove"`` returns a new list without any entry
            equal to 0; ``"add"`` prepends 0 when the list is non-empty and
            does not already start with 0; any other value returns the
            list unchanged.

    Returns:
        The cleaned list of timestamps. An empty list stays empty.
    """
    if zero_handling == "remove":
        return [time for time in timestamps if time != 0]
    # Check for emptiness first: an empty list has no first entry to test.
    if zero_handling == "add" and timestamps and timestamps[0] != 0:
        return [0] + timestamps
    return timestamps
|
prompt.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from src.data.chapters import Chapters, sec_to_hms
|
2 |
+
|
3 |
+
|
4 |
+
class Prompt:
    """Base class that assembles the chaptering prompt and target for a video.

    Subclasses must implement ``__contains__``, ``get_task_prompt`` and
    ``get_transcript``; the remaining methods build on those and on the
    ``chapters`` object given at construction.
    """

    def __init__(
        self,
        chapters: "Chapters",
    ):
        """Store ``chapters``; its ``get_duration`` and ``get_chapters`` are used."""
        self.chapters = chapters

    def __contains__(self, vid_id):
        raise NotImplementedError(
            "Subclasses must implement the '__contains__' method."
        )

    def get_duration_prompt(self, vid_id: str) -> str:
        """Return the opening clause stating the video duration."""
        duration = self.chapters.get_duration(vid_id, hms=True)
        return f"Given the complete transcript of a video of duration {duration}, "

    def get_task_prompt(self) -> str:
        """Return the task description; must be provided by subclasses."""
        raise NotImplementedError(
            "Subclasses must implement the 'get_task_prompt' method."
        )

    def get_format_instruction(self):
        """Return the instruction describing the expected line format."""
        return "Identify the approximate start time of each chapter in the format 'hh:mm:ss - Title'. "

    def get_new_line_instruction(self):
        """Return the instruction asking for one chapter per line."""
        return "Ensure each chapter entry is on a new line. "

    def get_focus_instruction(self):
        """Return the first half of the focus sentence (ends with a comma)."""
        return "Focus on significant topic changes that would merit a new chapter in a video, "

    def get_no_summaries_instruction(self):
        """Return the second half of the focus sentence, ending the paragraph."""
        return "but do not provide summaries of the chapters.\n"

    def get_transcript_introduction(self):
        """Return the line that introduces the transcript."""
        return "Here is the transcript to analyze:\n"

    def get_transcript(self, vid_id: str) -> str:
        """Return the transcript of ``vid_id``; must be provided by subclasses."""
        # By default, the transcript is the same for train and test
        raise NotImplementedError(
            "Subclasses must implement the 'get_transcript' method."
        )

    def get_transcript_train(self, vid_id: str) -> str:
        """Return the transcript used for training (``get_transcript``)."""
        return self.get_transcript(vid_id)

    def get_transcript_test(self, vid_id: str) -> str:
        """Return the transcript used for testing (``get_transcript``)."""
        return self.get_transcript(vid_id)

    def get_base_prompt(self, vid_id: str) -> str:
        """Concatenate all instruction parts, in order, without a separator."""
        prompt_parts = [
            self.get_duration_prompt(vid_id),
            self.get_task_prompt(),
            self.get_format_instruction(),
            self.get_new_line_instruction(),
            self.get_focus_instruction(),
            self.get_no_summaries_instruction(),
            self.get_transcript_introduction(),
        ]
        return "".join(prompt_parts)

    def get_prompt_train(self, vid_id: str) -> str:
        """Return the prompt used for training (``get_base_prompt``)."""
        return self.get_base_prompt(vid_id)

    def get_prompt_test(self, vid_id: str) -> str:
        """Return the prompt used for testing (``get_base_prompt``)."""
        return self.get_base_prompt(vid_id)

    def get_output(self, vid_id: str) -> str:
        """Return the target text: one ``"<time> - <title>"`` line per chapter.

        Each chapter time is formatted with ``sec_to_hms``.
        """
        vid_chapters = self.chapters.get_chapters(vid_id)
        answers = [
            f"{sec_to_hms(chp_time)} - {chp_title}"
            for chp_time, chp_title in vid_chapters.items()
        ]
        return "\n".join(answers)

    def get_dialog(self, vid_id: str) -> list:
        """Return the training dialog for ``vid_id``.

        Returns:
            A two-element list of ``{"role": ..., "content": ...}`` dicts:
            the user turn holds the training prompt followed by the training
            transcript, the assistant turn holds ``get_output``.
        """
        prompt = self.get_prompt_train(vid_id)
        transcript = self.get_transcript_train(vid_id)
        output = self.get_output(vid_id)
        return [
            {
                "role": "user",
                "content": prompt + transcript,
            },
            {
                "role": "assistant",
                "content": output,
            },
        ]
|
single_video.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
|
3 |
+
from lutils import openf, writef
|
4 |
+
|
5 |
+
from src.data.chapters import sec_to_hms
|
6 |
+
from tools.extract.asr import ASRProcessor
|
7 |
+
|
8 |
+
|
9 |
+
class SingleVideo:
    """
    A simplified implementation of the src.data.chapters.Chapters interface for single video inference.

    This class mimics the behavior of the ChaptersASR class but is designed to work with
    a single video file rather than a dataset. It provides the necessary methods
    required by the PromptASR class for generating chapter timestamps and titles.

    Note: This class is intended for inference only and should not be used for
    training or evaluation purposes.
    """

    def __init__(self, video_path: Path):
        """Run ASR on ``video_path`` and keep its transcript and duration.

        Args:
            video_path: Path of an existing video file; a plain string is
                accepted as well. The file stem is used as the video ID.
        """
        # Coerce first so that string paths get the same validation and
        # error message as Path objects.
        video_path = Path(video_path)
        assert video_path.exists(), f"Video file {video_path} not found"
        self.video_path = video_path
        self.video_ids = [video_path.stem]
        # Calls the module-level get_asr(); overwrite=True makes it ignore
        # any previously cached output.
        self.asr, self.duration = get_asr(video_path, overwrite=True)

    def __len__(self):
        return len(self.video_ids)

    def __iter__(self):
        return iter(self.video_ids)

    def __contains__(self, vid_id):
        return vid_id in self.video_ids

    def get_duration(self, vid_id, hms=False):
        """Return the video duration, passed through ``sec_to_hms`` if ``hms``."""
        assert vid_id == self.video_ids[0], f"Invalid video ID: {vid_id}"
        if hms:
            return sec_to_hms(self.duration)
        return self.duration

    def get_asr(self, vid_id):
        """Return the transcript obtained at construction time."""
        assert vid_id == self.video_ids[0], f"Invalid video ID: {vid_id}"
        return self.asr
|
45 |
+
|
46 |
+
|
47 |
+
def get_asr(video_path: Path, overwrite=False):
    """Return ``(asr, duration)`` for a video, using a cache on disk.

    The cache lives in ``outputs/inference/<video stem>/`` as ``asr.txt``
    and ``duration.txt``. When both files exist and ``overwrite`` is False
    they are read back; otherwise ``ASRProcessor`` is run on the video and
    both files are (re)written.

    Args:
        video_path: Path of the video file.
        overwrite: Ignore any cached files and run ASR again.

    Returns:
        A tuple of the transcript and the duration.
    """
    cache_dir = Path(f"outputs/inference/{video_path.stem}")
    asr_file = cache_dir / "asr.txt"
    duration_file = cache_dir / "duration.txt"

    cache_usable = asr_file.exists() and duration_file.exists() and not overwrite
    if not cache_usable:
        print(f"\n=== 🎙️ Processing ASR for {video_path} ===")
        transcript, length = ASRProcessor().get_asr(video_path)
        print(f"=== ✅ ASR processing complete for {video_path} ===\n")

        cache_dir.mkdir(parents=True, exist_ok=True)
        writef(asr_file, transcript)
        writef(duration_file, str(length))
        return transcript, length

    # presumably openf returns the lines of a .txt file as a list — the
    # join below and the length-1 check on the duration rely on that.
    transcript = "\n".join(openf(asr_file)) + "\n"

    stored = openf(duration_file)
    assert isinstance(stored, list) and len(stored) == 1, (
        f"Duration is not a list of length 1: {stored}"
    )
    length = float(stored[0])
    assert length > 0, f"Duration is not positive: {length}"
    return transcript, length
|