File size: 4,912 Bytes
1c72248
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import os
import shutil
from collections import OrderedDict
import gc
from typing import List

import torch
from tqdm import tqdm

from .tools.dataset_tools_config_modules import DatasetSyncCollectionConfig, RAW_DIR, NEW_DIR
from .tools.sync_tools import get_unsplash_images, get_pexels_images, get_local_image_file_names, download_image, \
    get_img_paths
from jobs.process import BaseExtensionProcess


def flush():
    """Release unused GPU memory.

    Runs Python garbage collection FIRST so that any CUDA tensors that are
    only kept alive by collectable reference cycles are actually freed, then
    returns the now-unoccupied cached blocks to the driver. (The original
    order — empty_cache before collect — misses memory freed by the GC.)
    Safe to call on CPU-only machines: ``empty_cache`` is a no-op when CUDA
    is not initialized.
    """
    gc.collect()
    torch.cuda.empty_cache()


class SyncFromCollection(BaseExtensionProcess):
    """Sync images from remote photo collections (Unsplash / Pexels) into
    local dataset directories.

    For each configured dataset, images not already present locally are
    downloaded into a NEW staging directory; after all datasets are
    processed, staged images are promoted into the RAW directory.
    """

    def __init__(self, process_id: int, job, config: OrderedDict):
        super().__init__(process_id, job, config)

        # Global minimum dimensions; act as fallbacks for dataset configs
        # that do not specify their own.
        self.min_width = config.get('min_width', 1024)
        self.min_height = config.get('min_height', 1024)

        # Propagate the global minimums into every dataset config that does
        # not override them, so each DatasetSyncCollectionConfig carries its
        # own effective min_width / min_height.
        for dataset_config in config.get('dataset_sync', []):
            dataset_config.setdefault('min_width', self.min_width)
            dataset_config.setdefault('min_height', self.min_height)

        self.dataset_configs: List[DatasetSyncCollectionConfig] = [
            DatasetSyncCollectionConfig(**dataset_config)
            for dataset_config in config.get('dataset_sync', [])
        ]
        print(f"Found {len(self.dataset_configs)} dataset configs")

    def move_new_images(self, root_dir: str):
        """Promote staged images from the NEW dir into the RAW dir, then
        remove the NEW dir.

        Safe to call when the NEW dir does not exist (e.g. the sync for this
        dataset failed before anything was downloaded) — the original crashed
        in that case.
        """
        raw_dir = os.path.join(root_dir, RAW_DIR)
        new_dir = os.path.join(root_dir, NEW_DIR)

        if not os.path.isdir(new_dir):
            # Nothing was staged for this dataset; nothing to move.
            return

        # Ensure the destination exists before moving anything into it.
        os.makedirs(raw_dir, exist_ok=True)

        for img_path in get_img_paths(new_dir):
            new_path = os.path.join(raw_dir, os.path.basename(img_path))
            shutil.move(img_path, new_path)

        # Staging dir is no longer needed.
        shutil.rmtree(new_dir)

    def sync_dataset(self, config: DatasetSyncCollectionConfig):
        """Download all missing images for one dataset collection.

        Returns a dict with counts: num_downloaded, num_skipped, bad, total.
        NOTE: 'total' counts successfully processed photos only (downloaded
        + skipped); failed photos are counted under 'bad', preserving the
        original accounting.
        """
        if config.host == 'unsplash':
            get_images = get_unsplash_images
        elif config.host == 'pexels':
            get_images = get_pexels_images
        else:
            raise ValueError(f"Unknown host: {config.host}")

        results = {
            'num_downloaded': 0,
            'num_skipped': 0,
            'bad': 0,
            'total': 0,
        }

        photos = get_images(config)
        raw_dir = os.path.join(config.directory, RAW_DIR)
        new_dir = os.path.join(config.directory, NEW_DIR)
        raw_images = get_local_image_file_names(raw_dir)
        new_images = get_local_image_file_names(new_dir)

        # Honor per-dataset minimum dimensions (set up in __init__); fall
        # back to the global values. The original always used the globals,
        # silently ignoring per-dataset overrides.
        min_width = getattr(config, 'min_width', self.min_width)
        min_height = getattr(config, 'min_height', self.min_height)

        for photo in tqdm(photos, desc=f"{config.host}-{config.collection_id}"):
            try:
                if photo.filename not in raw_images and photo.filename not in new_images:
                    download_image(photo, new_dir, min_width=min_width, min_height=min_height)
                    results['num_downloaded'] += 1
                else:
                    results['num_skipped'] += 1
            except Exception as e:
                print(f" - BAD({photo.id}): {e}")
                results['bad'] += 1
                continue
            results['total'] += 1

        return results

    def print_results(self, results):
        """Print a one-line summary of a results-count dict."""
        print(
            f" - new:{results['num_downloaded']}, old:{results['num_skipped']}, bad:{results['bad']} total:{results['total']}")

    def run(self):
        """Sync every configured dataset, then promote staged images.

        Datasets that raise are recorded and reported at the end; one failed
        dataset does not abort the others.
        """
        super().run()
        print(f"Syncing {len(self.dataset_configs)} datasets")
        all_results = None
        failed_datasets = []
        for dataset_config in tqdm(self.dataset_configs, desc="Syncing datasets", leave=True):
            try:
                results = self.sync_dataset(dataset_config)
                if all_results is None:
                    all_results = {**results}
                else:
                    for key, value in results.items():
                        all_results[key] += value

                self.print_results(results)
            except Exception as e:
                print(f" - FAILED: {e}")
                # HTTP client errors carry a .response with details.
                response = getattr(e, 'response', None)
                if response is not None:
                    error = f"{response.status_code}: {response.text}"
                    print(f"   - {error}")
                    failed_datasets.append({'dataset': dataset_config, 'error': error})
                else:
                    failed_datasets.append({'dataset': dataset_config, 'error': str(e)})
                continue

        print("Moving new images to raw")
        for dataset_config in self.dataset_configs:
            self.move_new_images(dataset_config.directory)

        print("Done syncing datasets")
        # Guard against every dataset having failed (original crashed here
        # calling print_results(None)).
        if all_results is not None:
            self.print_results(all_results)
        else:
            print(" - no datasets were synced successfully")

        if len(failed_datasets) > 0:
            print(f"Failed to sync {len(failed_datasets)} datasets")
            for failed in failed_datasets:
                print(f" - {failed['dataset'].host}-{failed['dataset'].collection_id}")
                print(f"   - ERR: {failed['error']}")