import gc
import os
import shutil
from collections import OrderedDict
from typing import List

import torch
from tqdm import tqdm

from jobs.process import BaseExtensionProcess
from .tools.dataset_tools_config_modules import DatasetSyncCollectionConfig, RAW_DIR, NEW_DIR
from .tools.sync_tools import (
    get_unsplash_images,
    get_pexels_images,
    get_local_image_file_names,
    download_image,
    get_img_paths,
)
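
# The `dataset_sync` entries are parsed into DatasetSyncCollectionConfig objects.
# A hypothetical config sketch (only host/collection_id/directory are referenced
# in this file; any further keys, such as API credentials, depend on
# DatasetSyncCollectionConfig, and all values below are illustrative):
#
#   config = OrderedDict({
#       'min_width': 1024,
#       'min_height': 1024,
#       'dataset_sync': [{
#           'host': 'unsplash',               # or 'pexels'
#           'collection_id': '123456',        # hypothetical collection id
#           'directory': '/datasets/my_set',  # gets RAW_DIR and NEW_DIR subdirs
#       }],
#   })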


def flush():
    # release cached GPU memory, then run garbage collection
    torch.cuda.empty_cache()
    gc.collect()


class SyncFromCollection(BaseExtensionProcess):
    def __init__(self, process_id: int, job, config: OrderedDict):
        super().__init__(process_id, job, config)
        self.min_width = config.get('min_width', 1024)
        self.min_height = config.get('min_height', 1024)

        # add our min_width and min_height to each dataset config if they don't exist
        for dataset_config in config.get('dataset_sync', []):
            if 'min_width' not in dataset_config:
                dataset_config['min_width'] = self.min_width
            if 'min_height' not in dataset_config:
                dataset_config['min_height'] = self.min_height

        self.dataset_configs: List[DatasetSyncCollectionConfig] = [
            DatasetSyncCollectionConfig(**dataset_config)
            for dataset_config in config.get('dataset_sync', [])
        ]
        print(f"Found {len(self.dataset_configs)} dataset configs")

    def move_new_images(self, root_dir: str):
        raw_dir = os.path.join(root_dir, RAW_DIR)
        new_dir = os.path.join(root_dir, NEW_DIR)
        # nothing to move if the new dir was never created (e.g. the sync failed early)
        if not os.path.isdir(new_dir):
            return
        os.makedirs(raw_dir, exist_ok=True)
        new_images = get_img_paths(new_dir)
        for img_path in new_images:
            # move each freshly downloaded image into the raw dir
            new_path = os.path.join(raw_dir, os.path.basename(img_path))
            shutil.move(img_path, new_path)
        # remove the now-empty new dir
        shutil.rmtree(new_dir)

    def sync_dataset(self, config: DatasetSyncCollectionConfig):
        if config.host == 'unsplash':
            get_images = get_unsplash_images
        elif config.host == 'pexels':
            get_images = get_pexels_images
        else:
            raise ValueError(f"Unknown host: {config.host}")

        results = {
            'num_downloaded': 0,
            'num_skipped': 0,
            'bad': 0,
            'total': 0,
        }

        photos = get_images(config)
        raw_dir = os.path.join(config.directory, RAW_DIR)
        new_dir = os.path.join(config.directory, NEW_DIR)
        raw_images = get_local_image_file_names(raw_dir)
        new_images = get_local_image_file_names(new_dir)

        for photo in tqdm(photos, desc=f"{config.host}-{config.collection_id}"):
            try:
                # only download images we don't already have locally
                if photo.filename not in raw_images and photo.filename not in new_images:
                    # use the per-dataset minimums injected in __init__, not the
                    # global defaults, so per-dataset overrides take effect
                    download_image(photo, new_dir, min_width=config.min_width, min_height=config.min_height)
                    results['num_downloaded'] += 1
                else:
                    results['num_skipped'] += 1
            except Exception as e:
                print(f" - BAD({photo.id}): {e}")
                results['bad'] += 1
                continue
            # total counts successfully processed photos; bad ones are excluded
            results['total'] += 1
        return results

    def print_results(self, results):
        print(
            f" - new:{results['num_downloaded']}, old:{results['num_skipped']}, "
            f"bad:{results['bad']}, total:{results['total']}"
        )

    def run(self):
        super().run()
        print(f"Syncing {len(self.dataset_configs)} datasets")
        all_results = None
        failed_datasets = []

        for dataset_config in tqdm(self.dataset_configs, desc="Syncing datasets", leave=True):
            try:
                results = self.sync_dataset(dataset_config)
                if all_results is None:
                    all_results = {**results}
                else:
                    for key, value in results.items():
                        all_results[key] += value
                self.print_results(results)
            except Exception as e:
                print(f" - FAILED: {e}")
                # HTTP errors (e.g. from requests) carry a response with more detail
                if hasattr(e, 'response'):
                    error = f"{e.response.status_code}: {e.response.text}"
                    print(f" - {error}")
                    failed_datasets.append({'dataset': dataset_config, 'error': error})
                else:
                    failed_datasets.append({'dataset': dataset_config, 'error': str(e)})
                continue

        print("Moving new images to raw")
        for dataset_config in self.dataset_configs:
            self.move_new_images(dataset_config.directory)

        print("Done syncing datasets")
        # all_results is None only if every dataset failed to sync
        if all_results is not None:
            self.print_results(all_results)
        if len(failed_datasets) > 0:
            print(f"Failed to sync {len(failed_datasets)} datasets")
            for failed in failed_datasets:
                print(f" - {failed['dataset'].host}-{failed['dataset'].collection_id}")
                print(f" - ERR: {failed['error']}")