Update app.py
Browse files
app.py
CHANGED
@@ -54,14 +54,15 @@ def load_model():
|
|
54 |
app.logger.info(f"Device set to use {device}")
|
55 |
|
56 |
model = AutoModelForCausalLM.from_pretrained(
|
57 |
-
"gpt2
|
58 |
use_safetensors=True,
|
59 |
device_map="auto",
|
60 |
torch_dtype=dtype,
|
61 |
-
low_cpu_mem_usage=True
|
|
|
62 |
)
|
63 |
|
64 |
-
tokenizer = AutoTokenizer.from_pretrained("gpt2
|
65 |
|
66 |
# Initialize pipeline without explicit device assignment
|
67 |
generator = pipeline(
|
@@ -146,17 +147,11 @@ IEEE_TEMPLATE = """
|
|
146 |
# --------------------------------------------------
|
147 |
@app.route('/health', methods=['GET'])
|
148 |
def health_check():
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
"status": "error",
|
155 |
-
"message": f"Model failed to load: {load_error}"
|
156 |
-
}), 500
|
157 |
-
|
158 |
-
status_code = 200 if model_loaded else 503
|
159 |
-
device_info = "cuda" if torch.cuda.is_available() else "cpu"
|
160 |
|
161 |
app.logger.info(f"Health check returning status: {'ready' if model_loaded else 'loading'}, device: {device_info}")
|
162 |
return jsonify({
|
|
|
54 |
app.logger.info(f"Device set to use {device}")
|
55 |
|
56 |
model = AutoModelForCausalLM.from_pretrained(
|
57 |
+
"gpt2",
|
58 |
use_safetensors=True,
|
59 |
device_map="auto",
|
60 |
torch_dtype=dtype,
|
61 |
+
low_cpu_mem_usage=True,
|
62 |
+
offload_folder="offload"
|
63 |
)
|
64 |
|
65 |
+
tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
66 |
|
67 |
# Initialize pipeline without explicit device assignment
|
68 |
generator = pipeline(
|
|
|
147 |
# --------------------------------------------------
|
148 |
@app.route('/health', methods=['GET'])
|
149 |
def health_check():
|
150 |
+
return jsonify({
|
151 |
+
"status": "ok",
|
152 |
+
"model_loaded": model_loaded,
|
153 |
+
"device": "cuda" if torch.cuda.is_available() else "cpu"
|
154 |
+
}), 200
|
|
|
|
|
|
|
|
|
|
|
|
|
155 |
|
156 |
app.logger.info(f"Health check returning status: {'ready' if model_loaded else 'loading'}, device: {device_info}")
|
157 |
return jsonify({
|