# This file defines download_models, which runs at container build time so
# that the model weights are baked into the container image.
import os
import json
import tarfile
import tempfile

import wget

# Absolute directory of this file; every configured path is relative to it.
# abspath keeps this correct even when __file__ is a bare relative filename
# (in which case os.path.dirname(__file__) would be "").
_BASE_DIR = os.path.dirname(os.path.abspath(__file__))


def _resolve(relative_path):
    """Return *relative_path* joined onto the directory containing this file."""
    return os.path.join(_BASE_DIR, relative_path)


def _download_file(url, relative_path):
    """Download *url* to *relative_path* under this file's directory.

    Missing parent directories are created first. Returns the path that
    wget.download reports it wrote to.
    """
    destination = _resolve(relative_path)
    parent = os.path.dirname(destination)
    if parent:
        os.makedirs(parent, exist_ok=True)
    return wget.download(url, destination)


def _download_and_extract_tar(url, relative_dir):
    """Download the tar archive at *url* and extract it into *relative_dir*.

    The archive is fetched into a private temporary directory that is
    deleted (together with the archive) once extraction finishes.
    """
    destination = _resolve(relative_dir)
    with tempfile.TemporaryDirectory() as tmp_dir:
        # wget.download may choose a different file name if the target
        # already exists, so open whatever path it returns.
        archive_path = wget.download(url, os.path.join(tmp_dir, "download.tar"))
        with tarfile.open(archive_path) as archive:
            if hasattr(tarfile, "data_filter"):
                # The archive comes from a remote URL: the "data" filter
                # rejects absolute paths, '..' traversal and links that
                # would escape the destination directory.
                archive.extractall(destination, filter="data")
            else:
                # Older Python without extraction filters.
                archive.extractall(destination)


def download_models(config):
    """Download every model artifact listed in *config*.

    Args:
        config: mapping with 'u2net', 'realesrgan' and 'diffuser' entries,
            each holding a 'download_url'. The 'u2net' and 'realesrgan'
            entries also give a 'path' (relative to this file's directory)
            to save the download to. The 'diffuser' URL is expected to be a
            tar archive, which is extracted into 'checkpoints/' relative to
            this file's directory.
    """
    # NOTE: the SCHP parser checkpoint (config['schp']) is currently not
    # downloaded. To re-enable:
    #   _download_file(config['schp']['download_url'], config['schp']['path'])

    # u2net checkpoint.
    _download_file(config["u2net"]["download_url"], config["u2net"]["path"])

    # Super-resolution model.
    _download_file(
        config["realesrgan"]["download_url"], config["realesrgan"]["path"]
    )

    # Diffuser model checkpoint, shipped as a tar archive.
    _download_and_extract_tar(config["diffuser"]["download_url"], "checkpoints/")


if __name__ == "__main__":
    with open(_resolve("configs/configs.json"), encoding="utf-8") as fin:
        config = json.load(fin)
    download_models(config["models"])