Spaces:
Running
Running
# This file defines download_models.
# It runs at container build time so that the model weights are baked into the container image.
import os | |
import wget | |
import json | |
import tarfile | |
import tempfile | |
def download_models(config):
    """Download every model artifact listed in ``config`` next to this script.

    Intended to run at container build time so the weights end up inside the
    image.

    Args:
        config: Mapping with ``'u2net'``, ``'realesrgan'`` and ``'diffuser'``
            entries, each holding a ``'download_url'``. ``'u2net'`` and
            ``'realesrgan'`` also need a ``'path'`` (relative to this script's
            directory) to save to. The ``'diffuser'`` URL must point to a tar
            archive, which is extracted into ``checkpoints/`` next to this
            script.
    """
    base_dir = os.path.dirname(__file__)

    # Parser checkpoint (u2net). The SCHP checkpoint (config['schp']) is
    # intentionally not downloaded.
    _download_to(config['u2net'], base_dir)
    # Super resolution model.
    _download_to(config['realesrgan'], base_dir)

    # Diffuser model checkpoint. The archive is fetched into a throwaway
    # directory so neither the tar nor any placeholder file lingers (and
    # bloats the image) once it has been unpacked.
    checkpoints_dir = os.path.join(base_dir, 'checkpoints/')
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Use the path wget reports rather than assuming it wrote to the
        # requested name.
        tar_path = wget.download(config['diffuser']['download_url'],
                                 os.path.join(tmp_dir, 'diffuser.tar'))
        # NOTE(review): extractall trusts member paths inside the archive;
        # if the URL is not fully trusted, consider filter='data' (3.12+).
        with tarfile.open(tar_path) as tar_file:
            tar_file.extractall(checkpoints_dir)


def _download_to(model_config, base_dir):
    """Download ``model_config['download_url']`` to ``model_config['path']``.

    The destination is resolved relative to ``base_dir``. Missing parent
    directories are created first, so a fresh checkout does not fail when
    wget tries to move the finished download into place. Returns the path
    reported by ``wget.download``.
    """
    dest = os.path.join(base_dir, model_config['path'])
    parent = os.path.dirname(dest)
    # parent is '' when base_dir is '' and path has no directory component.
    if parent:
        os.makedirs(parent, exist_ok=True)
    return wget.download(model_config['download_url'], dest)
if __name__ == "__main__":
    # Resolve the config relative to this script so the build step does not
    # depend on the current working directory.
    config_file = "configs/configs.json"
    config_file = os.path.join(os.path.dirname(__file__), config_file)
    # JSON is UTF-8 by specification; don't rely on the container's locale
    # encoding when reading it.
    with open(config_file, encoding="utf-8") as fin:
        config = json.load(fin)
    download_models(config['models'])