svjack commited on
Commit
059ce64
·
verified ·
1 Parent(s): 6d2843d

Upload run_video_ccip.py

Browse files
Files changed (1) hide show
  1. run_video_ccip.py +130 -0
run_video_ccip.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ python run_video_ccip.py Beyond_the_Boundary_Videos_sm Beyond_the_Boundary_Videos_sm_named --image_dir named_image_dir
3
+
4
+ import pandas as pd
5
+ import pathlib
6
+ import json
7
+ def read_j(x):
8
+ with open(x, "r") as f:
9
+ return json.load(f)
10
+
11
+ path_s = pd.Series(list(pathlib.Path("Beyond_the_Boundary_Videos_sm_named/").rglob("*.json"))).map(str)
12
+ df = pd.DataFrame(path_s.head(int(1e10)).map(
13
+ lambda x: (x, read_j(x))
14
+ ).values.tolist()
15
+ ).explode(1).applymap(
16
+ lambda x: x["results"] if type(x) == type({}) else x
17
+ ).explode(1)
18
+ df
19
+ right_df = pd.json_normalize(df[1])
20
+ df = pd.concat([df.reset_index().iloc[:, 1:], right_df.reset_index().iloc[:,1:]], axis = 1)
21
+ df = df[
22
+ df["prediction"] == "Same"
23
+ ]
24
+ ###df[0].sort_values().drop_duplicates()
25
+ df
26
+ '''
27
+
28
+ import os
29
+ import json
30
+ from tqdm import tqdm
31
+ from PIL import Image
32
+ from ccip import _VALID_MODEL_NAMES, _DEFAULT_MODEL_NAMES, ccip_difference, ccip_default_threshold
33
+ import pathlib
34
+ import argparse
35
+ from moviepy.editor import VideoFileClip
36
+
37
+ def load_images_from_directory(image_dir):
38
+ """
39
+ 从指定目录加载图片,构建字典。
40
+ 键为图片的文件名(不含扩展名),值为图片的 PIL.Image 对象。
41
+ """
42
+ name_image_dict = {}
43
+ image_paths = list(pathlib.Path(image_dir).rglob("*.png")) + list(pathlib.Path(image_dir).rglob("*.jpg")) + list(pathlib.Path(image_dir).rglob("*.jpeg"))
44
+
45
+ for image_path in tqdm(image_paths, desc="Loading images"):
46
+ image = Image.open(image_path)
47
+ name = os.path.splitext(os.path.basename(image_path))[0] # 去掉扩展名
48
+ name_image_dict[name] = image
49
+
50
+ return name_image_dict
51
+
52
+ def _compare_with_dataset(imagex, model_name, name_image_dict):
53
+ threshold = ccip_default_threshold(model_name)
54
+ results = []
55
+
56
+ for name, imagey in name_image_dict.items():
57
+ diff = ccip_difference(imagex, imagey)
58
+ result = {
59
+ "difference": diff,
60
+ "prediction": 'Same' if diff <= threshold else 'Not Same',
61
+ "name": name
62
+ }
63
+ results.append(result)
64
+
65
+ # 按照 diff 值进行排序
66
+ results.sort(key=lambda x: x["difference"])
67
+
68
+ return results
69
+
70
+ def process_video(video_path, model_name, output_dir, max_frames, name_image_dict):
71
+ # 打开视频文件
72
+ clip = VideoFileClip(video_path)
73
+ duration = clip.duration
74
+ fps = clip.fps
75
+ total_frames = int(duration * fps)
76
+
77
+ # 计算帧间隔
78
+ frame_interval = max(1, total_frames // max_frames)
79
+
80
+ # 生成输出文件名
81
+ video_name = os.path.splitext(os.path.basename(video_path))[0]
82
+ output_file = os.path.join(output_dir, f"{video_name}.json")
83
+
84
+ results = []
85
+
86
+ # 采样帧并处理
87
+ for i in tqdm(range(0, total_frames, frame_interval), desc="Processing frames"):
88
+ frame = clip.get_frame(i / fps)
89
+ image = Image.fromarray(frame)
90
+ frame_results = _compare_with_dataset(image, model_name, name_image_dict)
91
+ results.append({
92
+ "frame_time": i / fps,
93
+ "results": frame_results
94
+ })
95
+
96
+ # 保存结果到 JSON 文件
97
+ with open(output_file, 'w') as f:
98
+ json.dump(results, f, indent=4)
99
+
100
+ def main():
101
+ parser = argparse.ArgumentParser(description="Compare videos with a dataset and save results as JSON.")
102
+ parser.add_argument("input_path", type=str, help="Path to the input video or directory containing videos.")
103
+ parser.add_argument("output_dir", type=str, help="Directory to save the output JSON files.")
104
+ parser.add_argument("--image_dir", type=str, required=True, help="Directory containing images to compare with.")
105
+ parser.add_argument("--model", type=str, default=_DEFAULT_MODEL_NAMES, choices=_VALID_MODEL_NAMES, help="Model to use for comparison.")
106
+ parser.add_argument("--max_frames", type=int, default=3, help="Maximum number of frames to process per video.")
107
+
108
+ args = parser.parse_args()
109
+
110
+ # 确保输出目录存在
111
+ os.makedirs(args.output_dir, exist_ok=True)
112
+
113
+ # 加载图片数据集
114
+ name_image_dict = load_images_from_directory(args.image_dir)
115
+
116
+ # 判断输入路径是文件还是目录
117
+ if os.path.isfile(args.input_path):
118
+ video_paths = [args.input_path]
119
+ elif os.path.isdir(args.input_path):
120
+ video_paths = list(pathlib.Path(args.input_path).rglob("*.mp4")) + list(pathlib.Path(args.input_path).rglob("*.avi"))
121
+ else:
122
+ raise ValueError("Input path must be a valid file or directory.")
123
+ video_paths = list(map(str, video_paths))
124
+
125
+ # 处理每个视频
126
+ for video_path in tqdm(video_paths, desc="Processing videos"):
127
+ process_video(video_path, args.model, args.output_dir, args.max_frames, name_image_dict)
128
+
129
+ if __name__ == '__main__':
130
+ main()