File size: 2,121 Bytes
f670afc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import argparse
import os
import sys
import tarfile
sys.path.append('.')
from imaginaire.utils.io import download_file_from_google_drive  # noqa: E402

# Google Drive file IDs of the sample test-data archives, keyed by model name.
# The key is also used as the directory under ``projects/`` that receives the
# data, so keys containing '/' (e.g. 'wc_vid2vid/cityscapes') map to nested
# project folders.
URLS = {
    'pix2pixhd': '1Xg9m184zkuG8H0LHdBtSzt2VbMi3SWwR',
    'spade': '1ESm-gHWu_aMHnKF42qkGc8qf1SBECsgf',
    'funit': '1a-EE_6RsYPUoKxEl5oXrpRmKYUltqaD-',
    'coco_funit': '1JYVYB0Q1VStDLOb0SBJbN1vkaf6KrGDh',
    'unit': '17BbwnCG7qF7FI-t9VkORv2XCKqlrY1CO',
    'munit': '1VPgHGuQfmm1N1Vh56wr34wtAwaXzjXtH',
    'vid2vid': '1SHvGPMq-55GDUQ0Ac2Ng0eyG5xCPeKhc',
    'fs_vid2vid': '1fTj0HHjzcitgsSeG5O_aWMF8yvCQUQkN',
    'wc_vid2vid/cityscapes': '1KKzrTHfbpBY9xtLqK8e3QvX8psSdrFcD',
    'wc_vid2vid/mannequin': '1mafZf9KJrwUGGI1kBTvwgehHSqP5iaA0',
    'gancraft': '1m6q7ZtYJjxFL0SQ_WzMbvoLZxXmI5_vJ',
}


def parse_args():
    """Read the downloader's command-line options from ``sys.argv``.

    Returns:
        argparse.Namespace: carries the required ``model_name`` string.
    """
    cli = argparse.ArgumentParser(description='Download test data.')
    cli.add_argument(
        '--model_name',
        required=True,
        help='Name of the model.',
    )
    return cli.parse_args()


def _extract_tarball(compressed_path, dest_dir):
    """Extract the tar archive at ``compressed_path`` into ``dest_dir``.

    When the running Python provides tarfile extraction filters (3.12+ and
    the security backports to older releases), the ``'data'`` filter is
    applied. That filter rejects members that would land outside
    ``dest_dir`` (absolute paths, ``..`` components, escaping links) and
    strips unsafe mode bits. This matters because the archive comes from
    the network.
    """
    print('Extracting test data to', dest_dir)
    with tarfile.open(compressed_path) as tar:
        if hasattr(tarfile, 'data_filter'):
            tar.extractall(path=dest_dir, filter='data')
        else:
            tar.extractall(path=dest_dir)


def main():
    """Download (if needed) and extract the sample test data for a model.

    The archive is cached at ``projects/<model_name>/test_data.tar.gz`` and
    extracted into ``projects/<model_name>/test_data``. The archive is
    fetched from Google Drive only when it is not already on disk. It is
    always (re-)extracted.

    Exits with a non-zero status and an explanatory message when
    ``--model_name`` has no entry in ``URLS``.
    """
    args = parse_args()
    # Validate before touching the filesystem. An ``assert`` would be
    # stripped under ``python -O``.
    if args.model_name not in URLS:
        sys.exit('No sample test data available for %r; choose one of: %s'
                 % (args.model_name, ', '.join(sorted(URLS))))
    url = URLS[args.model_name]

    test_data_dir = os.path.join('projects', args.model_name, 'test_data')
    compressed_path = test_data_dir + '.tar.gz'
    print(test_data_dir)

    if os.path.exists(test_data_dir):
        print('Test data exists at', test_data_dir)
    # Also creates the parent ``projects/<model_name>`` directory that the
    # archive is downloaded into.
    os.makedirs(test_data_dir, exist_ok=True)

    # Download whenever the archive is missing, even if the output directory
    # already exists. A previous run may have created the directory and then
    # failed before the archive was written.
    # NOTE(review): a truncated archive left by an interrupted download is
    # not detected here; delete it and re-run in that case.
    if not os.path.exists(compressed_path):
        print('Downloading test data to', compressed_path)
        download_file_from_google_drive(url, compressed_path)

    _extract_tarball(compressed_path, test_data_dir)


if __name__ == "__main__":
    main()