tharms commited on
Commit
98e0e3d
·
1 Parent(s): 23924fc

added log output for custom launch

Browse files
Files changed (1) hide show
  1. modules/model_loader.py +83 -1
modules/model_loader.py CHANGED
@@ -1,5 +1,15 @@
1
  import os
 
 
 
 
 
 
 
 
2
  from urllib.parse import urlparse
 
 
3
  from typing import Optional
4
 
5
 
@@ -26,7 +36,79 @@ def load_file_from_url(
26
  if not os.path.exists(cached_file):
27
  print(f'Downloading: "{url}" to {cached_file}\n')
28
  from torch.hub import download_url_to_file
29
- download_url_to_file(url, cached_file, progress=progress)
30
  print ('DOWNLOADED FILE: ', url)
31
  print(f'Using cached file: {cached_file}')
32
  return cached_file
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ import re
3
+ import shutil
4
+ import sys
5
+ import tempfile
6
+ import torch
7
+ import uuid
8
+ import warnings
9
+ import hashlib
10
  from urllib.parse import urlparse
11
+ from urllib.error import HTTPError, URLError
12
+ from urllib.request import urlopen, Request
13
  from typing import Optional
14
 
15
 
 
36
  if not os.path.exists(cached_file):
37
  print(f'Downloading: "{url}" to {cached_file}\n')
38
  from torch.hub import download_url_to_file
39
+ proxy_download_url_to_file(url, cached_file, progress=progress)
40
  print ('DOWNLOADED FILE: ', url)
41
  print(f'Using cached file: {cached_file}')
42
  return cached_file
43
+
44
+
45
+ def proxy_download_url_to_file(url: str, dst: str, hash_prefix: Optional[str] = None,
46
+ progress: bool = True) -> None:
47
+ r"""Download object at the given URL to a local path.
48
+
49
+ Args:
50
+ url (str): URL of the object to download
51
+ dst (str): Full path where object will be saved, e.g. ``/tmp/temporary_file``
52
+ hash_prefix (str, optional): If not None, the SHA256 downloaded file should start with ``hash_prefix``.
53
+ Default: None
54
+ progress (bool, optional): whether or not to display a progress bar to stderr
55
+ Default: True
56
+
57
+ Example:
58
+ >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB)
59
+ >>> # xdoctest: +REQUIRES(POSIX)
60
+ >>> torch.hub.download_url_to_file('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', '/tmp/temporary_file')
61
+
62
+ """
63
+ print('PROXY DOWNLOAD')
64
+ file_size = None
65
+ req = Request(url, headers={"User-Agent": "torch.hub"})
66
+ u = urlopen(req)
67
+ meta = u.info()
68
+ if hasattr(meta, 'getheaders'):
69
+ content_length = meta.getheaders("Content-Length")
70
+ else:
71
+ content_length = meta.get_all("Content-Length")
72
+ if content_length is not None and len(content_length) > 0:
73
+ file_size = int(content_length[0])
74
+
75
+ # We deliberately save it in a temp file and move it after
76
+ # download is complete. This prevents a local working checkpoint
77
+ # being overridden by a broken download.
78
+ # We deliberately do not use NamedTemporaryFile to avoid restrictive
79
+ # file permissions being applied to the downloaded file.
80
+ dst = os.path.expanduser(dst)
81
+ for seq in range(tempfile.TMP_MAX):
82
+ tmp_dst = dst + '.' + uuid.uuid4().hex + '.partial'
83
+ try:
84
+ f = open(tmp_dst, 'w+b')
85
+ except FileExistsError:
86
+ continue
87
+ break
88
+ else:
89
+ raise FileExistsError(errno.EEXIST, 'No usable temporary file name found')
90
+
91
+ try:
92
+ if hash_prefix is not None:
93
+ sha256 = hashlib.sha256()
94
+ with tqdm(total=file_size, disable=not progress,
95
+ unit='B', unit_scale=True, unit_divisor=1024) as pbar:
96
+ while True:
97
+ buffer = u.read(8192)
98
+ if len(buffer) == 0:
99
+ break
100
+ f.write(buffer)
101
+ if hash_prefix is not None:
102
+ sha256.update(buffer)
103
+ pbar.update(len(buffer))
104
+
105
+ f.close()
106
+ if hash_prefix is not None:
107
+ digest = sha256.hexdigest()
108
+ if digest[:len(hash_prefix)] != hash_prefix:
109
+ raise RuntimeError(f'invalid hash value (expected "{hash_prefix}", got "{digest}")')
110
+ shutil.move(f.name, dst)
111
+ finally:
112
+ f.close()
113
+ if os.path.exists(f.name):
114
+ os.remove(f.name)