|
import os |
|
import re |
|
import shutil |
|
import sys |
|
import tempfile |
|
import torch |
|
import uuid |
|
import warnings |
|
import hashlib |
|
from urllib.parse import urlparse |
|
from urllib.error import HTTPError, URLError |
|
from urllib.request import urlopen, Request |
|
from typing import Optional |
|
|
|
class _Faketqdm: |
|
|
|
def __init__(self, total=None, disable=False, |
|
unit=None, *args, **kwargs): |
|
self.total = total |
|
self.disable = disable |
|
self.n = 0 |
|
|
|
|
|
def update(self, n): |
|
if self.disable: |
|
return |
|
|
|
self.n += n |
|
if self.total is None: |
|
sys.stderr.write(f"\r{self.n:.1f} bytes") |
|
else: |
|
sys.stderr.write(f"\r{100 * self.n / float(self.total):.1f}%") |
|
sys.stderr.flush() |
|
|
|
|
|
def set_description(self, *args, **kwargs): |
|
pass |
|
|
|
def write(self, s): |
|
sys.stderr.write(f"{s}\n") |
|
|
|
def close(self): |
|
self.disable = True |
|
|
|
def __enter__(self): |
|
return self |
|
|
|
def __exit__(self, exc_type, exc_val, exc_tb): |
|
if self.disable: |
|
return |
|
|
|
sys.stderr.write('\n') |
|
|
|
# Use the real tqdm progress bar when the package is installed; otherwise fall
# back to the minimal stderr-based _Faketqdm stand-in defined above, so callers
# can always use `tqdm(...)` with the same keyword arguments.
try:
    from tqdm import tqdm
except ImportError:
    tqdm = _Faketqdm
|
|
|
def load_file_from_url(
    url: str,
    *,
    model_dir: str,
    progress: bool = True,
    file_name: Optional[str] = None,
) -> str:
    """Download a file from `url` into `model_dir`, using the file present if possible.

    Args:
        url: Location of the file to fetch.
        model_dir: Directory the file is stored in; created if it does not exist.
        progress: Whether to show a download progress bar on stderr.
        file_name: Name to save the file under. Defaults to the last path
            component of ``url``.

    Returns:
        The absolute path to the downloaded (or previously present) file.

    Raises:
        ValueError: If ``file_name`` is not given and none can be derived from
            ``url`` (e.g. the URL path is empty or ends with ``/``).
    """
    print(f'prepare load {url} to {model_dir}')
    os.makedirs(model_dir, exist_ok=True)
    print(f'created model_dir : {model_dir}')
    if not file_name:
        parts = urlparse(url)
        print(f'URL parts : {parts}')
        file_name = os.path.basename(parts.path)
        if not file_name:
            # Without this guard the "cached file" would be model_dir itself, which
            # always exists, so the directory path would be returned as the file.
            raise ValueError(
                f'cannot derive a file name from URL {url!r}; pass file_name explicitly')
    print(f'file_name : {file_name}')
    cached_file = os.path.abspath(os.path.join(model_dir, file_name))
    if not os.path.exists(cached_file):
        print(f'Downloading: "{url}" to {cached_file}\n')
        proxy_download_url_to_file(url, cached_file, progress=progress)
        print('DOWNLOADED FILE: ', url)
    print(f'Using cached file: {cached_file}')
    return cached_file
|
|
|
|
|
def _content_length(response) -> Optional[int]:
    """Return the Content-Length advertised by ``response``, or None if absent or malformed."""
    values = response.info().get_all("Content-Length")
    if not values:
        return None
    try:
        return int(values[0])
    except ValueError:
        # A bogus header only costs us the progress-bar total, not the download.
        return None


def _open_partial_file(dst: str):
    """Exclusively create and open a fresh ``<dst>.<hex>.partial`` file for binary writing."""
    import errno

    for _ in range(tempfile.TMP_MAX):
        tmp_dst = dst + '.' + uuid.uuid4().hex + '.partial'
        try:
            # 'x' mode raises FileExistsError instead of silently truncating an
            # existing file, which is what makes this retry loop meaningful.
            return open(tmp_dst, 'x+b')
        except FileExistsError:
            continue
    raise FileExistsError(errno.EEXIST, 'No usable temporary file name found')


def proxy_download_url_to_file(url: str, dst: str, hash_prefix: Optional[str] = None,
                               progress: bool = True) -> None:
    r"""Download object at the given URL to a local path.

    Same signature as ``torch.hub.download_url_to_file``. The payload is streamed
    into a uniquely named ``<dst>.<hex>.partial`` file next to ``dst`` and only
    moved to ``dst`` after the download (and optional hash check) succeeded; the
    partial file is removed on any failure.

    Args:
        url (str): URL of the object to download
        dst (str): Full path where object will be saved, e.g. ``/tmp/temporary_file``.
            A leading ``~`` is expanded.
        hash_prefix (str, optional): If not None, the SHA256 downloaded file should start with ``hash_prefix``.
            Default: None
        progress (bool, optional): whether or not to display a progress bar to stderr
            Default: True

    Raises:
        RuntimeError: If ``hash_prefix`` is given and the SHA256 digest does not start with it.
        FileExistsError: If no unused temporary file name could be created.
        urllib.error.URLError: If the request itself fails (``HTTPError`` for HTTP error statuses).

    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB)
        >>> # xdoctest: +REQUIRES(POSIX)
        >>> proxy_download_url_to_file('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', '/tmp/temporary_file')

    """
    print('PROXY DOWNLOAD')
    req = Request(url, headers={"User-Agent": "torch.hub"})
    # The response is closed on every path, including failures below.
    with urlopen(req) as u:
        file_size = _content_length(u)

        dst = os.path.expanduser(dst)
        print(f'PROXY DOWNLOAD: {dst}')
        f = _open_partial_file(dst)

        try:
            sha256 = hashlib.sha256() if hash_prefix is not None else None
            with tqdm(total=file_size, disable=not progress,
                      unit='B', unit_scale=True, unit_divisor=1024) as pbar:
                while True:
                    buffer = u.read(8192)
                    if not buffer:
                        break
                    f.write(buffer)
                    if sha256 is not None:
                        sha256.update(buffer)
                    pbar.update(len(buffer))

            f.close()
            print(f'PROXY DOWNLOAD: closed file {f.name}')
            if sha256 is not None:
                digest = sha256.hexdigest()
                if digest[:len(hash_prefix)] != hash_prefix:
                    raise RuntimeError(f'invalid hash value (expected "{hash_prefix}", got "{digest}")')
            shutil.move(f.name, dst)
            print(f'PROXY DOWNLOAD: moved file {f.name} to {dst}')
        finally:
            print(f'PROXY DOWNLOAD: finally closing {f.name}')
            f.close()
            # After a successful move the partial file is gone; otherwise discard it.
            if os.path.exists(f.name):
                os.remove(f.name)
|
|