|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from pathlib import Path |
|
from tqdm import tqdm |
|
import pandas as pd |
|
import numpy as np |
|
from datasets import load_dataset |
|
from typing import Dict |
|
|
|
def empty_solution(sample):
    """Build the smallest valid wireframe: two vertices joined by one edge.

    ``sample`` is accepted for interface compatibility but is not inspected.

    Returns:
        A ``(vertices, edges)`` pair where ``vertices`` is a 2x3 array of
        zeros and ``edges`` is the single-element list ``[(0, 1)]``.
    """
    vertices = np.zeros(shape=(2, 3))
    edges = [(0, 1)]
    return vertices, edges
|
|
|
class Sample(Dict):
    """Dictionary whose ``repr`` shows a compact summary of every value."""

    def pick_repr_data(self, x):
        """Return a short stand-in for ``x`` suitable for printing.

        The checks run in this order:
        * anything exposing a ``shape`` attribute -> that shape;
        * ``str`` / ``float`` / ``int`` -> the value itself;
        * ``list`` -> a one-element list holding the type of the first item
          (or an empty list when the list is empty);
        * everything else -> its type.
        """
        if hasattr(x, 'shape'):
            return x.shape

        scalar_types = (str, float, int)
        if isinstance(x, scalar_types):
            return x

        if not isinstance(x, list):
            return type(x)

        if len(x) == 0:
            return []
        return [type(x[0])]

    def __repr__(self):
        """Render the mapping with each value replaced by its summary."""
        summary = {}
        for key, value in self.items():
            summary[key] = self.pick_repr_data(value)
        return str(summary)
|
|
|
import json |
|
if __name__ == "__main__":
    print("------------ Loading dataset------------ ")

    # Read the run configuration from params.json in the working directory.
    param_path = Path('params.json')
    print(param_path)
    params = json.loads(param_path.read_text())
    print(params)

    import os

    # Debug output describing the execution environment. The shell commands
    # write their listings straight to stdout; where the call is wrapped in
    # print(), what gets printed is the return value of os.system.
    print('pwd:')
    os.system('pwd')
    print(os.system('ls -lahtr'))
    for label, listing_cmd in (
        ('/tmp/data/', 'ls -lahtr /tmp/data/'),
        ('/tmp/data/data', 'ls -lahtrR /tmp/data/data'),
    ):
        print(label)
        print(os.system(listing_cmd))

    data_path_test_server = Path('/tmp/data')
    data_path_local = Path.home() / '.cache/huggingface/datasets/usm3d___hoho25k_test_x/'

    # NOTE(review): presumably /tmp/data is pre-populated on the evaluation
    # server; when it is absent the dataset named in params is downloaded
    # into that same location.
    TEST_ENV = data_path_test_server.exists()
    if not TEST_ENV:
        # Imported lazily: only needed when the data must be fetched.
        from huggingface_hub import snapshot_download

        snapshot_download(
            repo_id=params['dataset'],
            local_dir="/tmp/data",
            repo_type="dataset",
        )

    # Either way the data is read from /tmp/data from here on.
    data_path = data_path_test_server
    print(data_path)

    # Collect the tar shards for each split by directory-name pattern.
    validation_shards = list(map(str, data_path.rglob('*public*/**/*.tar')))
    test_shards = list(map(str, data_path.rglob('*private*/**/*.tar')))
    data_files = {
        "validation": validation_shards,
        "test": test_shards,
    }
    print(data_files)

    loader_script = str(data_path / 'hoho25k_test_x.py')
    dataset = load_dataset(
        loader_script,
        data_files=data_files,
        trust_remote_code=True,
        writer_batch_size=100,
    )

    print('load with webdataset')

    print(dataset, flush=True)

    print('------------ Now you can do your solution ---------------')
    solution = []
    for subset_name in dataset:
        for sample in tqdm(dataset[subset_name]):
            # Show a compact summary of the sample's fields.
            print(Sample(sample), flush=True)
            print('------')

            pred_vertices, pred_edges = empty_solution(sample)
            record = {
                'order_id': sample['order_id'],
                'wf_vertices': pred_vertices.tolist(),
                'wf_edges': pred_edges,
            }
            solution.append(record)

    print('------------ Saving results ---------------')
    submission_columns = ["order_id", "wf_vertices", "wf_edges"]
    sub = pd.DataFrame(solution, columns=submission_columns)
    sub.to_parquet("submission.parquet")
    print("------------ Done ------------ ")