from fastapi import FastAPI, Request, Response, status
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
import time

app = FastAPI()

# Create cache directory
cache_dir = "./model_cache"
os.makedirs(cache_dir, exist_ok=True)

# Track app status
app_status = {
    "status": "initializing",
    "model_name": "distilgpt2",
    "model_loaded": False,
    "tokenizer_loaded": False,
    "startup_time": time.time(),
    "errors": []
}

# Load model and tokenizer once at startup
model_name = "distilgpt2"
try:
    # Try to load tokenizer
    app_status["status"] = "loading_tokenizer"
    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
    app_status["tokenizer_loaded"] = True
    
    # Try to load model from the TensorFlow checkpoint (from_tf=True requires TensorFlow installed)
    app_status["status"] = "loading_model"
    model = AutoModelForCausalLM.from_pretrained(model_name, from_tf=True, cache_dir=cache_dir)
    app_status["model_loaded"] = True
    
    app_status["status"] = "ready"
except Exception as e:
    error_msg = f"Error loading model or tokenizer: {str(e)}"
    app_status["status"] = "limited_functionality"
    app_status["errors"].append(error_msg)
    print(error_msg)

class PromptRequest(BaseModel):
    prompt: str
    max_new_tokens: int = 50

@app.post("/generate")
async def generate_text(req: PromptRequest, response: Response):
    if app_status["status"] != "ready":
        response.status_code = status.HTTP_503_SERVICE_UNAVAILABLE
        return {"error": "Model not ready", "status": app_status["status"], "details": app_status["errors"]}
    
    try:
        inputs = tokenizer(req.prompt, return_tensors="pt")
        # Run inference without gradient tracking; passing pad_token_id avoids the
        # "Setting pad_token_id to eos_token_id" warning for GPT-2-style models.
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=req.max_new_tokens,
                do_sample=True,
                temperature=0.8,
                top_p=0.95,
                pad_token_id=tokenizer.eos_token_id,
            )
        generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return {"generated_text": generated}
    except Exception as e:
        response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
        return {"error": str(e)}

@app.get("/")
async def root():
    return {"message": "API is responding", "status": app_status["status"]}

@app.get("/status")
async def get_status():
    # Calculate uptime
    uptime = time.time() - app_status["startup_time"]
    
    return {
        "status": app_status["status"],
        "model_name": app_status["model_name"],
        "model_loaded": app_status["model_loaded"],
        "tokenizer_loaded": app_status["tokenizer_loaded"],
        "uptime_seconds": uptime,
        "errors": app_status["errors"]
    }
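
# Example usage (a sketch; assumes this file is saved as main.py and that
# fastapi, uvicorn, transformers, and torch are installed):
#
#   # Start the server
#   uvicorn main:app --host 0.0.0.0 --port 8000
#
#   # Generate text from a prompt
#   curl -X POST http://localhost:8000/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Once upon a time", "max_new_tokens": 40}'
#
#   # Check model/tokenizer load state and uptime
#   curl http://localhost:8000/status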