import os
import shutil
from collections import OrderedDict
import gc
from typing import List

import torch
from tqdm import tqdm

from .tools.dataset_tools_config_modules import DatasetSyncCollectionConfig, RAW_DIR, NEW_DIR
from .tools.sync_tools import (
    get_unsplash_images,
    get_pexels_images,
    get_local_image_file_names,
    download_image,
    get_img_paths,
)
from jobs.process import BaseExtensionProcess


def flush():
    # release cached GPU memory and run garbage collection
    torch.cuda.empty_cache()
    gc.collect()


class SyncFromCollection(BaseExtensionProcess):
    def __init__(self, process_id: int, job, config: OrderedDict):
        super().__init__(process_id, job, config)

        self.min_width = config.get('min_width', 1024)
        self.min_height = config.get('min_height', 1024)

        # add our min_width and min_height to each dataset config if they don't exist
        for dataset_config in config.get('dataset_sync', []):
            if 'min_width' not in dataset_config:
                dataset_config['min_width'] = self.min_width
            if 'min_height' not in dataset_config:
                dataset_config['min_height'] = self.min_height

        self.dataset_configs: List[DatasetSyncCollectionConfig] = [
            DatasetSyncCollectionConfig(**dataset_config) for dataset_config in config.get('dataset_sync', [])
        ]
        print(f"Found {len(self.dataset_configs)} dataset configs")

    def move_new_images(self, root_dir: str):
        raw_dir = os.path.join(root_dir, RAW_DIR)
        new_dir = os.path.join(root_dir, NEW_DIR)

        new_images = get_img_paths(new_dir)

        for img_path in new_images:
            # move to raw
            new_path = os.path.join(raw_dir, os.path.basename(img_path))
            shutil.move(img_path, new_path)

        # remove new dir
        shutil.rmtree(new_dir)

    def sync_dataset(self, config: DatasetSyncCollectionConfig):
        # pick the fetcher for the configured host
        if config.host == 'unsplash':
            get_images = get_unsplash_images
        elif config.host == 'pexels':
            get_images = get_pexels_images
        else:
            raise ValueError(f"Unknown host: {config.host}")

        results = {
            'num_downloaded': 0,
            'num_skipped': 0,
            'bad': 0,
            'total': 0,
        }

        photos = get_images(config)
        raw_dir = os.path.join(config.directory, RAW_DIR)
        new_dir = os.path.join(config.directory, NEW_DIR)
        raw_images = get_local_image_file_names(raw_dir)
        new_images = get_local_image_file_names(new_dir)

        for photo in tqdm(photos, desc=f"{config.host}-{config.collection_id}"):
            try:
                # only download images we don't already have locally
                if photo.filename not in raw_images and photo.filename not in new_images:
                    download_image(photo, new_dir, min_width=self.min_width, min_height=self.min_height)
                    results['num_downloaded'] += 1
                else:
                    results['num_skipped'] += 1
            except Exception as e:
                print(f" - BAD({photo.id}): {e}")
                results['bad'] += 1
                continue

            results['total'] += 1
        return results

    def print_results(self, results):
        print(
            f" - new:{results['num_downloaded']}, old:{results['num_skipped']}, "
            f"bad:{results['bad']}, total:{results['total']}"
        )

    def run(self):
        super().run()
        print(f"Syncing {len(self.dataset_configs)} datasets")
        all_results = None
        failed_datasets = []
        for dataset_config in tqdm(self.dataset_configs, desc="Syncing datasets", leave=True):
            try:
                results = self.sync_dataset(dataset_config)
                # accumulate per-dataset counts into an overall summary
                if all_results is None:
                    all_results = {**results}
                else:
                    for key, value in results.items():
                        all_results[key] += value
                self.print_results(results)
            except Exception as e:
                print(f" - FAILED: {e}")
                if 'response' in e.__dict__:
                    error = f"{e.response.status_code}: {e.response.text}"
                    print(f" - {error}")
                    failed_datasets.append({'dataset': dataset_config, 'error': error})
                else:
                    failed_datasets.append({'dataset': dataset_config, 'error': str(e)})
                continue

        print("Moving new images to raw")
        for dataset_config in self.dataset_configs:
            self.move_new_images(dataset_config.directory)

        print("Done syncing datasets")
        # print the combined summary (all_results stays None if every dataset failed)
        if all_results is not None:
            self.print_results(all_results)

        if len(failed_datasets) > 0:
            print(f"Failed to sync {len(failed_datasets)} datasets")
            for failed in failed_datasets:
                print(f" - {failed['dataset'].host}-{failed['dataset'].collection_id}")
                print(f" - ERR: {failed['error']}")
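

# Illustrative usage sketch (an assumption, not part of the original module):
# this process is driven by a BaseExtensionProcess config dict. The
# 'dataset_sync' entries below only show fields this file actually reads
# ('host', 'collection_id', 'directory'); DatasetSyncCollectionConfig may
# require additional fields (e.g. an API key) that are not shown here, and
# the process_id/job arguments would normally be supplied by the job runner.
#
# example_config = OrderedDict({
#     'min_width': 1024,
#     'min_height': 1024,
#     'dataset_sync': [
#         {
#             'host': 'unsplash',               # or 'pexels'
#             'collection_id': '12345678',      # hypothetical collection id
#             'directory': '/path/to/dataset',  # RAW_DIR and NEW_DIR live under this path
#         },
#     ],
# })
# process = SyncFromCollection(process_id=0, job=None, config=example_config)
# process.run()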