ymcmy commited on
Commit
ce88d40
·
verified ·
1 Parent(s): 48f6c67

Create preprocess.py

Browse files
Files changed (1) hide show
  1. preprocess.py +140 -0
preprocess.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # preprocess by converting images into fingerprints and save to disk
2
+ import numpy as np
3
+ import cv2
4
+ import PIL.Image
5
+ from scipy.interpolate import griddata
6
+ import h5py
7
+ from utils import azi_diff
8
+ from tqdm import tqdm
9
+ import os
10
+ import random
11
+ import pickle
12
+ import logging
13
+ import joblib
14
+
15
+ def get_image_files(directory):
16
+ image_extensions = {'.jpg', '.jpeg', '.png'}
17
+ image_files = []
18
+ for root, _, files in os.walk(directory):
19
+ for file in files:
20
+ if os.path.splitext(file)[1].lower() in image_extensions:
21
+ image_files.append(os.path.join(root, file))
22
+ return image_files
23
+
24
+ def load_image_files(class1_dirs, class2_dirs):
25
+ class1_files = []
26
+ for directory in tqdm(class1_dirs):
27
+ class1_files.extend(get_image_files(directory))
28
+
29
+ class2_files = []
30
+ for directory in tqdm(class2_dirs):
31
+ class2_files.extend(get_image_files(directory))
32
+
33
+ # Ensure equal representation
34
+ min_length = min(len(class1_files), len(class2_files))
35
+
36
+ random.shuffle(class1_files)
37
+ random.shuffle(class2_files)
38
+
39
+ class1_files = class1_files[:min_length]
40
+ class2_files = class2_files[:min_length]
41
+
42
+ print(f"Number of files: Real = {len(class1_files)}, Fake = {len(class2_files)}")
43
+
44
+ return class1_files, class2_files
45
+
46
+
47
+ def process_and_save_h5(file_label_pairs, patch_num, N, save_interval, joblib_batch_size, output_dir, start_by=0):
48
+ def process_file(file_label):
49
+ path, label = file_label
50
+ try:
51
+ result = azi_diff(path, patch_num, N)
52
+ return result, label
53
+ except Exception as e:
54
+ logging.error(f"Error processing file {path}: {str(e)}")
55
+ return None, None
56
+
57
+ num_files = len(file_label_pairs)
58
+ num_saves = (num_files - start_by + save_interval - 1) // save_interval
59
+
60
+ if not os.path.exists(output_dir):
61
+ os.makedirs(output_dir)
62
+
63
+ with tqdm(total=num_files - start_by, desc="Processing files", unit="image") as pbar:
64
+ for save_index in range(num_saves):
65
+ save_start = start_by + save_index * save_interval
66
+ save_end = min(save_start + save_interval, num_files)
67
+ batch_pairs = file_label_pairs[save_start:save_end]
68
+
69
+ all_rich = []
70
+ all_poor = []
71
+ all_labels = []
72
+ for batch_start in range(0, len(batch_pairs), joblib_batch_size):
73
+ batch_end = min(batch_start + joblib_batch_size, len(batch_pairs))
74
+ small_batch_pairs = batch_pairs[batch_start:batch_end]
75
+
76
+ processed_data = joblib.Parallel(n_jobs=-1)(
77
+ joblib.delayed(process_file)(file_label) for file_label in small_batch_pairs
78
+ )
79
+ for data, label in processed_data:
80
+ if data is not None:
81
+ all_rich.append(data['total_emb'][0])
82
+ all_poor.append(data['total_emb'][1])
83
+ all_labels.append(label)
84
+
85
+ pbar.update(len(small_batch_pairs))
86
+
87
+ next_save_start = save_end
88
+ output_filename = f"{output_dir}/processed_data_{next_save_start}.h5"
89
+ logging.info(f"Saving {output_filename}")
90
+
91
+ with h5py.File(output_filename, 'w') as h5file:
92
+ h5file.create_dataset('rich', data=np.array(all_rich))
93
+ h5file.create_dataset('rich', data=np.array(all_poor))
94
+ h5file.create_dataset('labels', data=np.array(all_labels))
95
+
96
+ logging.info(f"Successfully saved {output_filename}")
97
+
98
+ del all_rich
99
+ del all_poor
100
+ del all_labels
101
+
102
+
103
+
104
+
105
+ load=False
106
+ class1_dirs = [
107
+ "/home/archive/real/",
108
+ "/home/13k_real/",
109
+ "/home/AI_detection_dataset/Real_AI_SD_LD_Dataset/train/real/",
110
+ "/home/AI_detection_dataset/Real_AI_SD_LD_Dataset/test/real/"
111
+ ] #real 0
112
+
113
+ class2_dirs = [
114
+ "/home/archive/fakeV2/fake-v2/",
115
+ "/home/dalle3/",
116
+ "/home/AI_detection_dataset/Real_AI_SD_LD_Dataset/train/fake/",
117
+ "/home/AI_detection_dataset/Real_AI_SD_LD_Dataset/test/fake/"
118
+ ] #fake 1
119
+ output_dir = "/content/drive/MyDrive/joblibsaves"
120
+ file_paths_pickle_save_dir='/content/drive/MyDrive/aigc_file_paths.pkl'
121
+ patch_num = 128
122
+ N = 256
123
+ save_interval = 2000
124
+ joblib_batch_size = 400
125
+ start_by = 0
126
+
127
+ if load==True:
128
+ with open(file_paths_pickle_save_dir, 'rb') as file:
129
+ file_label_pairs=pickle.load(file)
130
+ print(len(file_label_pairs))
131
+ else:
132
+ class1_files, class2_files = load_image_files(class1_dirs, class2_dirs)
133
+ file_label_pairs = list(zip(class1_files, [0] * len(class1_files))) + list(zip(class2_files, [1] * len(class2_files)))
134
+ random.shuffle(file_label_pairs)
135
+ with open(file_paths_pickle_save_dir, 'wb') as file:
136
+ pickle.dump(file_label_pairs, file)
137
+ print(len(file_label_pairs))
138
+
139
+ logging.basicConfig(level=logging.INFO)
140
+ process_and_save_h5(file_label_pairs, patch_num, N, save_interval, joblib_batch_size, output_dir, start_by)