ClothQuill / download.py
Bismay
Initial commit
475e066
# This file defines download_models.
# It runs at container build time so that the model weights are baked into the container image.
import os
import wget
import json
import tarfile
import tempfile
def _download_file(url, dest):
    """Download ``url`` to ``dest``, creating ``dest``'s parent directory first.

    Returns whatever ``wget.download`` returns for the download.
    """
    parent = os.path.dirname(dest)
    if parent:
        # Create the target directory up front so a checkout that lacks it
        # does not fail the build.
        os.makedirs(parent, exist_ok=True)
    return wget.download(url, dest)


def download_models(config):
    """Download every model checkpoint described by ``config``.

    ``config`` is the ``"models"`` section of ``configs/configs.json``.
    Keys read here:

    * ``config['u2net']``: ``download_url`` and ``path``.
    * ``config['realesrgan']``: ``download_url`` and ``path``.
    * ``config['diffuser']``: ``download_url`` of a tar archive.

    Each ``path`` is resolved relative to the directory containing this file.
    The diffuser archive is extracted into ``checkpoints/`` under that same
    directory.
    """
    base_dir = os.path.dirname(__file__)

    # The SCHP human-parser checkpoint download is currently disabled:
    # _download_file(config['schp']['download_url'],
    #                os.path.join(base_dir, config['schp']['path']))

    # Cloth segmentation (U2Net) checkpoint.
    _download_file(config['u2net']['download_url'],
                   os.path.join(base_dir, config['u2net']['path']))

    # Super-resolution (Real-ESRGAN) model.
    _download_file(config['realesrgan']['download_url'],
                   os.path.join(base_dir, config['realesrgan']['path']))

    # Diffuser model checkpoint, shipped as a tar archive.
    # Download it into a private temporary directory that is removed
    # afterwards. The previous mkstemp() approach leaked an open fd, an empty
    # temp file, and the downloaded archive itself.
    with tempfile.TemporaryDirectory() as tmp_dir:
        tar_path = wget.download(config['diffuser']['download_url'],
                                 os.path.join(tmp_dir, 'diffuser.tar'))
        with tarfile.open(tar_path) as tar_file:
            extract_kwargs = {}
            # The archive comes from the network. Where this Python supports
            # extraction filters, reject unsafe members (absolute paths,
            # '..' escapes, links outside the destination).
            if hasattr(tarfile, 'data_filter'):
                extract_kwargs['filter'] = 'data'
            tar_file.extractall(os.path.join(base_dir, 'checkpoints/'),
                                **extract_kwargs)
if __name__ == "__main__":
    # Resolve the config relative to this file so the script works from any
    # working directory.
    config_file = os.path.join(os.path.dirname(__file__),
                               "configs/configs.json")
    # Read the JSON as UTF-8 explicitly instead of relying on the platform's
    # default (locale-dependent) text encoding.
    with open(config_file, encoding="utf-8") as fin:
        config = json.load(fin)
    download_models(config['models'])