Update app.py
app.py
CHANGED
@@ -8,12 +8,13 @@ import time
 app = FastAPI()

 # Create cache directory
-
+cache_dir = "./model_cache"
+os.makedirs(cache_dir, exist_ok=True)

 # Track app status
 app_status = {
     "status": "initializing",
-    "model_name": "
+    "model_name": "distilgpt2",
     "model_loaded": False,
     "tokenizer_loaded": False,
     "startup_time": time.time(),
@@ -21,22 +22,22 @@ app_status = {
 }

 # Load model and tokenizer once at startup
-model_name = "distilgpt2"
+model_name = "distilgpt2"
 try:
     # Try to load tokenizer
     app_status["status"] = "loading_tokenizer"
-    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=
+    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
     app_status["tokenizer_loaded"] = True

-    # Try to load model
+    # Try to load model - with from_tf=True
     app_status["status"] = "loading_model"
-    model = AutoModelForCausalLM.from_pretrained(model_name,
+    model = AutoModelForCausalLM.from_pretrained(model_name, from_tf=True, cache_dir=cache_dir)
     app_status["model_loaded"] = True

     app_status["status"] = "ready"
 except Exception as e:
     error_msg = f"Error loading model or tokenizer: {str(e)}"
-    app_status["status"] = "
+    app_status["status"] = "limited_functionality"
     app_status["errors"].append(error_msg)
     print(error_msg)

@@ -48,7 +49,7 @@ class PromptRequest(BaseModel):
 async def generate_text(req: PromptRequest, response: Response):
     if app_status["status"] != "ready":
         response.status_code = status.HTTP_503_SERVICE_UNAVAILABLE
-        return {"error": "Model not ready", "status": app_status["status"]}
+        return {"error": "Model not ready", "status": app_status["status"], "details": app_status["errors"]}

     try:
         inputs = tokenizer(req.prompt, return_tensors="pt")
@@ -67,7 +68,7 @@ async def generate_text(req: PromptRequest, response: Response):

 @app.get("/")
 async def root():
-    return {"message": "API is
+    return {"message": "API is responding", "status": app_status["status"]}

 @app.get("/status")
 async def get_status():