import os import re import shutil import sys import tempfile import torch import uuid import warnings import hashlib from urllib.parse import urlparse from urllib.error import HTTPError, URLError from urllib.request import urlopen, Request from typing import Optional class _Faketqdm: # type: ignore[no-redef] def __init__(self, total=None, disable=False, unit=None, *args, **kwargs): self.total = total self.disable = disable self.n = 0 # Ignore all extra *args and **kwargs lest you want to reinvent tqdm def update(self, n): if self.disable: return self.n += n if self.total is None: sys.stderr.write(f"\r{self.n:.1f} bytes") else: sys.stderr.write(f"\r{100 * self.n / float(self.total):.1f}%") sys.stderr.flush() # Don't bother implementing; use real tqdm if you want def set_description(self, *args, **kwargs): pass def write(self, s): sys.stderr.write(f"{s}\n") def close(self): self.disable = True def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): if self.disable: return sys.stderr.write('\n') try: from tqdm import tqdm # If tqdm is installed use it, otherwise use the fake wrapper except ImportError: tqdm = _Faketqdm def load_file_from_url( url: str, *, model_dir: str, progress: bool = True, file_name: Optional[str] = None, ) -> str: """Download a file from `url` into `model_dir`, using the file present if possible. Returns the path to the downloaded file. """ print(f'prepare load {url} to {model_dir}') os.makedirs(model_dir, exist_ok=True) print(f'created model_dir : {model_dir}') if not file_name: parts = urlparse(url) print(f'URL parts : {parts}') file_name = os.path.basename(parts.path) print(f'file_name : {file_name}') cached_file = os.path.abspath(os.path.join(model_dir, file_name)) if not os.path.exists(cached_file): print(f'Downloading: "{url}" to {cached_file}\n') from torch.hub import download_url_to_file proxy_download_url_to_file(url, cached_file, progress=progress) print ('DOWNLOADED FILE: ', url) print(f'Using cached file: {cached_file}') return cached_file def proxy_download_url_to_file(url: str, dst: str, hash_prefix: Optional[str] = None, progress: bool = True) -> None: r"""Download object at the given URL to a local path. Args: url (str): URL of the object to download dst (str): Full path where object will be saved, e.g. ``/tmp/temporary_file`` hash_prefix (str, optional): If not None, the SHA256 downloaded file should start with ``hash_prefix``. Default: None progress (bool, optional): whether or not to display a progress bar to stderr Default: True Example: >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB) >>> # xdoctest: +REQUIRES(POSIX) >>> torch.hub.download_url_to_file('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', '/tmp/temporary_file') """ print('PROXY DOWNLOAD') file_size = None req = Request(url, headers={"User-Agent": "torch.hub"}) u = urlopen(req) meta = u.info() if hasattr(meta, 'getheaders'): content_length = meta.getheaders("Content-Length") else: content_length = meta.get_all("Content-Length") if content_length is not None and len(content_length) > 0: file_size = int(content_length[0]) # We deliberately save it in a temp file and move it after # download is complete. This prevents a local working checkpoint # being overridden by a broken download. # We deliberately do not use NamedTemporaryFile to avoid restrictive # file permissions being applied to the downloaded file. dst = os.path.expanduser(dst) print(f'PROXY DOWNLOAD: {dst}') for seq in range(tempfile.TMP_MAX): tmp_dst = dst + '.' + uuid.uuid4().hex + '.partial' try: f = open(tmp_dst, 'w+b') except FileExistsError: continue break else: raise FileExistsError(errno.EEXIST, 'No usable temporary file name found') try: if hash_prefix is not None: sha256 = hashlib.sha256() with tqdm(total=file_size, disable=not progress, unit='B', unit_scale=True, unit_divisor=1024) as pbar: while True: buffer = u.read(8192) if len(buffer) == 0: break f.write(buffer) if hash_prefix is not None: sha256.update(buffer) pbar.update(len(buffer)) f.close() print(f'PROXY DOWNLOAD: closed file {f.name}') if hash_prefix is not None: digest = sha256.hexdigest() if digest[:len(hash_prefix)] != hash_prefix: raise RuntimeError(f'invalid hash value (expected "{hash_prefix}", got "{digest}")') shutil.move(f.name, dst) print(f'PROXY DOWNLOAD: moved file {f.name} to {dst}') finally: print(f'PROXY DOWNLOAD: finally closing {f.name}') f.close() if os.path.exists(f.name): os.remove(f.name)