mac9087 commited on
Commit
ab52342
·
verified ·
1 Parent(s): 5eb7b3a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -4
app.py CHANGED
@@ -17,6 +17,17 @@ logger = logging.getLogger(__name__)
17
  os.environ['MPLCONFIGDIR'] = '/tmp/matplotlib'
18
  os.environ['HOME'] = '/tmp'
19
 
 
 
 
 
 
 
 
 
 
 
 
20
  app = Flask(__name__)
21
  CORS(app)
22
 
@@ -25,9 +36,6 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
25
  logger.info(f"Using device: {device}")
26
 
27
  # Set up model directory for downloads
28
- os.makedirs("/tmp/point_e_models", exist_ok=True)
29
-
30
- # Download model weights or use cached version
31
  model_path = "/tmp/point_e_models/base40M-textvec.pt"
32
  if not os.path.exists(model_path):
33
  logger.info("Model weights not found. Downloading...")
@@ -45,7 +53,11 @@ if not os.path.exists(model_path):
45
  # Load the model
46
  logger.info("Loading base model...")
47
  try:
48
- base_model = model_from_config(MODEL_CONFIGS["base40M-textvec"], device=device)
 
 
 
 
49
  base_model.load_state_dict(torch.load(model_path, map_location=device))
50
  base_model.eval()
51
  logger.info("Base model loaded successfully")
 
17
  os.environ['MPLCONFIGDIR'] = '/tmp/matplotlib'
18
  os.environ['HOME'] = '/tmp'
19
 
20
+ # Set cache directories explicitly
21
+ cache_dir = os.environ.get('POINT_E_CACHE_DIR', '/tmp/point_e_model_cache')
22
+ clip_cache_dir = os.environ.get('CLIP_MODEL_DIR', '/tmp/clip_models')
23
+
24
+ # Ensure cache directories exist and are writable
25
+ for directory in [cache_dir, clip_cache_dir, '/tmp/point_e_models']:
26
+ if not os.path.exists(directory):
27
+ os.makedirs(directory, exist_ok=True)
28
+ # Make sure permissions are correct
29
+ os.chmod(directory, 0o777)
30
+
31
  app = Flask(__name__)
32
  CORS(app)
33
 
 
36
  logger.info(f"Using device: {device}")
37
 
38
  # Set up model directory for downloads
 
 
 
39
  model_path = "/tmp/point_e_models/base40M-textvec.pt"
40
  if not os.path.exists(model_path):
41
  logger.info("Model weights not found. Downloading...")
 
53
  # Load the model
54
  logger.info("Loading base model...")
55
  try:
56
+ base_model = model_from_config(
57
+ MODEL_CONFIGS["base40M-textvec"],
58
+ device=device,
59
+ cache_dir=cache_dir # Pass cache directory explicitly
60
+ )
61
  base_model.load_state_dict(torch.load(model_path, map_location=device))
62
  base_model.eval()
63
  logger.info("Base model loaded successfully")