mac9087 commited on
Commit
5eb7b3a
·
verified ·
1 Parent(s): e16b8f7

Update download_weights.py

Browse files
Files changed (1) hide show
  1. download_weights.py +44 -21
download_weights.py CHANGED
@@ -1,8 +1,8 @@
1
- # download_weights.py - Improved with proper OpenAI Point-E model URL
2
  import os
3
  import requests
4
  import torch
5
  from tqdm import tqdm
 
6
 
7
  def download_file(url, destination):
8
  """
@@ -13,35 +13,57 @@ def download_file(url, destination):
13
  return
14
 
15
  print(f"Downloading {url} to {destination}")
16
- response = requests.get(url, stream=True)
17
- total_size_in_bytes = int(response.headers.get('content-length', 0))
18
- block_size = 1024 # 1 Kibibyte
19
- progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
20
-
21
- with open(destination, 'wb') as file:
22
- for data in response.iter_content(block_size):
23
- progress_bar.update(len(data))
24
- file.write(data)
25
-
26
- progress_bar.close()
27
- if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
28
- print("ERROR, something went wrong")
29
- else:
30
- print(f"Download complete: {destination}")
 
 
 
 
 
 
 
 
31
 
32
  def main():
33
  # Create directory for model storage
34
- os.makedirs("/tmp/point_e_models", exist_ok=True)
 
 
 
 
 
 
 
35
 
36
- # URL to download the model weights from Hugging Face Hub
37
- # Point-E model weights are available on Hugging Face
38
  model_url = "https://huggingface.co/openai/point-e/resolve/main/base40M-textvec.pt"
39
 
40
  # Destination path
41
  model_path = "/tmp/point_e_models/base40M-textvec.pt"
42
 
43
- # Download the model weights
44
- download_file(model_url, model_path)
 
 
 
 
 
 
 
 
45
 
46
  # Verify the download
47
  try:
@@ -49,6 +71,7 @@ def main():
49
  print(f"Successfully loaded model state dict with {len(state_dict)} keys")
50
  except Exception as e:
51
  print(f"Error loading model: {e}")
 
52
 
53
  if __name__ == "__main__":
54
  main()
 
 
1
  import os
2
  import requests
3
  import torch
4
  from tqdm import tqdm
5
+ import sys
6
 
7
  def download_file(url, destination):
8
  """
 
13
  return
14
 
15
  print(f"Downloading {url} to {destination}")
16
+ try:
17
+ response = requests.get(url, stream=True)
18
+ response.raise_for_status() # Raise an exception for bad status codes
19
+
20
+ total_size_in_bytes = int(response.headers.get('content-length', 0))
21
+ block_size = 1024 # 1 Kibibyte
22
+ progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
23
+
24
+ with open(destination, 'wb') as file:
25
+ for data in response.iter_content(block_size):
26
+ progress_bar.update(len(data))
27
+ file.write(data)
28
+
29
+ progress_bar.close()
30
+ if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
31
+ print("ERROR, something went wrong with the download")
32
+ return False
33
+ else:
34
+ print(f"Download complete: {destination}")
35
+ return True
36
+ except requests.exceptions.RequestException as e:
37
+ print(f"Error downloading file: {e}")
38
+ return False
39
 
40
  def main():
41
  # Create directory for model storage
42
+ model_dir = "/tmp/point_e_models"
43
+ os.makedirs(model_dir, exist_ok=True)
44
+
45
+ # Ensure directory is writable
46
+ try:
47
+ os.chmod(model_dir, 0o777)
48
+ except Exception as e:
49
+ print(f"Warning: Could not set permissions on {model_dir}: {e}")
50
 
51
+ # URL to download the model weights - use Hugging Face Hub URL
 
52
  model_url = "https://huggingface.co/openai/point-e/resolve/main/base40M-textvec.pt"
53
 
54
  # Destination path
55
  model_path = "/tmp/point_e_models/base40M-textvec.pt"
56
 
57
+ # Attempt download
58
+ success = download_file(model_url, model_path)
59
+ if not success:
60
+ print("Failed to download model weights. Trying alternative source...")
61
+ # Try an alternative source if the first one fails
62
+ alt_model_url = "https://github.com/openai/point-e/releases/download/v0.1.0/base40M-textvec.pt"
63
+ success = download_file(alt_model_url, model_path)
64
+ if not success:
65
+ print("Failed to download model weights from alternative source.")
66
+ sys.exit(1)
67
 
68
  # Verify the download
69
  try:
 
71
  print(f"Successfully loaded model state dict with {len(state_dict)} keys")
72
  except Exception as e:
73
  print(f"Error loading model: {e}")
74
+ sys.exit(1)
75
 
76
  if __name__ == "__main__":
77
  main()