|
import argparse |
|
import json |
|
import os |
|
|
|
import numpy as np |
|
from PIL import Image |
|
|
|
# Fail fast with an actionable message when scikit-learn is missing: the
# KMeans-based train/test split below cannot run without it.
try:

    from sklearn.cluster import KMeans

except ImportError:

    # The import name is "sklearn" but the installable package is "scikit-learn".
    print("Please install scikit-learn (pip install scikit-learn) to use this script.")

    # raise SystemExit instead of the site-module exit() helper, which is not
    # guaranteed to exist (e.g. under `python -S` or in frozen applications).
    raise SystemExit(1)
|
|
|
|
|
# Dataset root: expected to contain per-frame "<name>.json" camera files next
# to "<name>.png" images. Placeholder path — point this at your real dataset.
subfolder = "/path/to/your/dataset"

# Generated Nerfstudio-style camera file, written into the dataset folder.
output_file = os.path.join(subfolder, "transforms.json")
|
|
|
|
|
# Collect one camera frame per "<name>.json" / "<name>.png" pair found in the
# dataset folder, then write them all out as transforms.json.
frames = []

for file in sorted(os.listdir(subfolder)):

    # Only per-frame metadata files; explicitly skip the generated output so
    # re-running the script never re-ingests its own transforms.json.
    if not file.endswith(".json") or file == os.path.basename(output_file):
        continue

    json_path = os.path.join(subfolder, file)
    with open(json_path, "r") as f:
        data = json.load(f)

    # splitext (rather than str.replace) so a ".json" substring elsewhere in
    # the filename (e.g. "a.json.json") cannot corrupt the derived image name.
    image_file = os.path.splitext(file)[0] + ".png"
    image_path = os.path.join(subfolder, image_file)
    if not os.path.exists(image_path):
        print(f"Image file not found for {file}, skipping...")
        continue

    # Only the dimensions are needed; the context manager closes the file.
    with Image.open(image_path) as img:
        w, h = img.size

    # Intrinsics appear to be stored normalized to [0, 1]; scale to pixels.
    # (Assumes K is a row-major 3x3 with fx, fy on the diagonal — TODO confirm.)
    K = data["K"]
    fx = K[0][0] * w
    fy = K[1][1] * h
    cx = K[0][2] * w
    cy = K[1][2] * h

    transform_matrix = np.array(data["c2w"])

    # Negate the y and z columns of the camera-to-world matrix: converts
    # between camera-axis conventions (presumably OpenCV -> OpenGL/NeRF —
    # verify against the downstream transforms.json consumer).
    transform_matrix[..., [1, 2]] *= -1

    frames.append(
        {
            "fl_x": fx,
            "fl_y": fy,
            "cx": cx,
            "cy": cy,
            "w": w,
            "h": h,
            "file_path": f"./{os.path.relpath(image_path, subfolder)}",
            "transform_matrix": transform_matrix.tolist(),
        }
    )

transforms_data = {"orientation_override": "none", "frames": frames}

with open(output_file, "w") as f:
    json.dump(transforms_data, f, indent=4)

print(f"transforms.json generated at {output_file}")
|
|
|
|
|
|
|
def create_train_test_split(frames, n, output_path, stride):
    """Write a train/test split JSON choosing ``n`` diverse training frames.

    Frames are clustered (K-means) in a 6-D space of camera position plus
    unit viewing direction; the frame nearest each cluster center becomes a
    training frame, and every ``stride``-th remaining frame a test frame.

    Args:
        frames: List of frame dicts, each with a 4x4 "transform_matrix".
        n: Number of training frames (number of K-means clusters).
        output_path: Destination path for the split JSON file.
        stride: Subsampling step over the non-training frames for the test set.
    """
    # Feature vector per frame: world-space position concatenated with the
    # normalized z column (view direction), so clustering separates cameras
    # both by where they are and by where they look.
    positions = []
    for frame in frames:
        transform_matrix = np.array(frame["transform_matrix"])
        position = transform_matrix[:3, 3]
        direction = transform_matrix[:3, 2] / np.linalg.norm(
            transform_matrix[:3, 2]
        )
        positions.append(np.concatenate([position, direction]))

    positions = np.array(positions)

    # n_init is pinned explicitly: scikit-learn changed the default to
    # "auto" in 1.4, which would make the split version-dependent.
    kmeans = KMeans(n_clusters=n, random_state=42, n_init=10)
    kmeans.fit(positions)
    centers = kmeans.cluster_centers_

    # For each center take the nearest *not-yet-chosen* frame: a plain argmin
    # can pick the same frame for two nearby centers, yielding duplicate
    # train ids and fewer than n distinct training frames.
    train_ids = []
    taken = set()
    for center in centers:
        distances = np.linalg.norm(positions - center, axis=1)
        for idx in np.argsort(distances):
            candidate = int(idx)
            if candidate not in taken:
                taken.add(candidate)
                train_ids.append(candidate)
                break

    # Test frames: every `stride`-th frame among those not used for training.
    all_indices = set(range(len(frames)))
    remaining_indices = sorted(all_indices - set(train_ids))
    test_ids = [
        int(idx) for idx in remaining_indices[::stride]
    ]

    split_data = {"train_ids": sorted(train_ids), "test_ids": test_ids}

    with open(output_path, "w") as f:
        json.dump(split_data, f, indent=4)

    print(f"Train-test split file generated at {output_path}")
|
|
|
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generate train-test split JSON file using K-means clustering."
    )
    parser.add_argument(
        "--n",
        type=int,
        required=True,
        help="Number of frames to include in the training set.",
    )
    parser.add_argument(
        "--stride",
        type=int,
        default=1,
        # The previous help text claimed stride was unused, but it subsamples
        # the non-training frames when building the test set.
        help="Stride for selecting test frames from the non-training frames.",
    )

    args = parser.parse_args()

    # Validate up front so KMeans / slicing don't fail with opaque errors.
    if args.n < 1:
        parser.error("--n must be a positive integer.")
    if args.stride < 1:
        parser.error("--stride must be a positive integer.")

    train_test_split_path = os.path.join(subfolder, f"train_test_split_{args.n}.json")
    create_train_test_split(frames, args.n, train_test_split_path, args.stride)
|
|