roshcheeku commited on
Commit
268c5f7
·
verified ·
1 Parent(s): 3e44fe1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -16
app.py CHANGED
@@ -12,37 +12,35 @@ from flask_cors import CORS
12
  # Suppress sklearn warnings
13
  warnings.filterwarnings('ignore', category=UserWarning, module='sklearn')
14
 
15
- # Configure logging
16
  logging.basicConfig(level=logging.INFO)
17
 
18
- # Get model URLs from environment variables
19
  DIABETES_MODEL_URL = os.getenv("DIABETES_MODEL_URL")
20
  SCALER_URL = os.getenv("SCALER_URL")
21
  MULTI_MODEL_URL = os.getenv("MULTI_MODEL_URL")
22
 
23
- # Local paths for downloaded models
24
  MODEL_PATHS = {
25
  "DIABETES_MODEL": "finaliseddiabetes_model.zip",
26
  "SCALER": "finalisedscaler.zip",
27
  "MULTI_MODEL": "nodiabetes.zip",
28
  }
29
 
30
- # Extracted model names
31
  EXTRACTED_MODELS = {
32
  "DIABETES_MODEL": "finaliseddiabetes_model.joblib",
33
  "SCALER": "finalisedscaler.joblib",
34
  "MULTI_MODEL": "nodiabetes.joblib",
35
  }
36
 
37
- BASE_DIR = os.getcwd()
 
38
 
39
- # Flask app initialization
40
  app = Flask(__name__)
41
-
42
- # Enable CORS for all origins
43
  CORS(app, supports_credentials=True)
44
 
45
- # Root route
46
  @app.route('/')
47
  def index():
48
  return """
@@ -52,7 +50,7 @@ def index():
52
  """
53
 
54
  def download_model(url, zip_filename):
55
- zip_path = os.path.join(BASE_DIR, zip_filename)
56
  if not url:
57
  logging.error(f"URL for {zip_filename} is missing!")
58
  return False
@@ -71,8 +69,8 @@ def download_model(url, zip_filename):
71
  return False
72
 
73
  def extract_if_needed(zip_filename, extracted_filename):
74
- zip_path = os.path.join(BASE_DIR, zip_filename)
75
- extracted_path = os.path.join(BASE_DIR, extracted_filename)
76
  if os.path.exists(extracted_path):
77
  logging.info(f"{extracted_filename} already exists. Skipping extraction.")
78
  return True
@@ -81,15 +79,15 @@ def extract_if_needed(zip_filename, extracted_filename):
81
  return False
82
  try:
83
  with zipfile.ZipFile(zip_path, 'r') as zip_ref:
84
- zip_ref.extractall(BASE_DIR)
85
- logging.info(f"Extracted {zip_filename}")
86
  return True
87
  except Exception as e:
88
  logging.error(f"Error extracting {zip_filename}: {e}")
89
  return False
90
 
91
  def load_model(model_filename):
92
- model_path = os.path.join(BASE_DIR, model_filename)
93
  if not os.path.exists(model_path):
94
  logging.error(f"Model file not found: {model_path}")
95
  return None
@@ -105,7 +103,8 @@ def initialize_models():
105
  models = {}
106
  for model_key, zip_filename in MODEL_PATHS.items():
107
  extracted_filename = EXTRACTED_MODELS[model_key]
108
- if not os.path.exists(os.path.join(BASE_DIR, zip_filename)):
 
109
  download_model(globals()[f"{model_key}_URL"], zip_filename)
110
  extract_if_needed(zip_filename, extracted_filename)
111
  models[model_key] = load_model(extracted_filename)
 
12
  # Suppress sklearn warnings
13
  warnings.filterwarnings('ignore', category=UserWarning, module='sklearn')
14
 
15
+ # Logging setup
16
  logging.basicConfig(level=logging.INFO)
17
 
18
+ # Model URLs from env
19
  DIABETES_MODEL_URL = os.getenv("DIABETES_MODEL_URL")
20
  SCALER_URL = os.getenv("SCALER_URL")
21
  MULTI_MODEL_URL = os.getenv("MULTI_MODEL_URL")
22
 
23
+ # Model ZIP names
24
  MODEL_PATHS = {
25
  "DIABETES_MODEL": "finaliseddiabetes_model.zip",
26
  "SCALER": "finalisedscaler.zip",
27
  "MULTI_MODEL": "nodiabetes.zip",
28
  }
29
 
30
+ # Extracted joblib names
31
  EXTRACTED_MODELS = {
32
  "DIABETES_MODEL": "finaliseddiabetes_model.joblib",
33
  "SCALER": "finalisedscaler.joblib",
34
  "MULTI_MODEL": "nodiabetes.joblib",
35
  }
36
 
37
+ # Use writeable temp dir
38
+ TMP_DIR = "/tmp"
39
 
40
+ # Flask app init
41
  app = Flask(__name__)
 
 
42
  CORS(app, supports_credentials=True)
43
 
 
44
  @app.route('/')
45
  def index():
46
  return """
 
50
  """
51
 
52
  def download_model(url, zip_filename):
53
+ zip_path = os.path.join(TMP_DIR, zip_filename)
54
  if not url:
55
  logging.error(f"URL for {zip_filename} is missing!")
56
  return False
 
69
  return False
70
 
71
  def extract_if_needed(zip_filename, extracted_filename):
72
+ zip_path = os.path.join(TMP_DIR, zip_filename)
73
+ extracted_path = os.path.join(TMP_DIR, extracted_filename)
74
  if os.path.exists(extracted_path):
75
  logging.info(f"{extracted_filename} already exists. Skipping extraction.")
76
  return True
 
79
  return False
80
  try:
81
  with zipfile.ZipFile(zip_path, 'r') as zip_ref:
82
+ zip_ref.extractall(TMP_DIR)
83
+ logging.info(f"Extracted {zip_filename} to {TMP_DIR}")
84
  return True
85
  except Exception as e:
86
  logging.error(f"Error extracting {zip_filename}: {e}")
87
  return False
88
 
89
  def load_model(model_filename):
90
+ model_path = os.path.join(TMP_DIR, model_filename)
91
  if not os.path.exists(model_path):
92
  logging.error(f"Model file not found: {model_path}")
93
  return None
 
103
  models = {}
104
  for model_key, zip_filename in MODEL_PATHS.items():
105
  extracted_filename = EXTRACTED_MODELS[model_key]
106
+ zip_path = os.path.join(TMP_DIR, zip_filename)
107
+ if not os.path.exists(zip_path):
108
  download_model(globals()[f"{model_key}_URL"], zip_filename)
109
  extract_if_needed(zip_filename, extracted_filename)
110
  models[model_key] = load_model(extracted_filename)