FlameF0X committed · verified
Commit aba490b · 1 Parent(s): 5e4f440

Update app.py

Files changed (1)
  1. app.py +10 -9
app.py CHANGED
@@ -8,12 +8,13 @@ import time
 app = FastAPI()
 
 # Create cache directory
-os.makedirs("./model_cache", exist_ok=True)
+cache_dir = "./model_cache"
+os.makedirs(cache_dir, exist_ok=True)
 
 # Track app status
 app_status = {
     "status": "initializing",
-    "model_name": "SmolLM2-135M-Instruct",
+    "model_name": "distilgpt2",
     "model_loaded": False,
     "tokenizer_loaded": False,
     "startup_time": time.time(),
@@ -21,22 +22,22 @@ app_status = {
 }
 
 # Load model and tokenizer once at startup
-model_name = "distilgpt2" # change this to your own model
+model_name = "distilgpt2"
 try:
     # Try to load tokenizer
     app_status["status"] = "loading_tokenizer"
-    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="./model_cache", local_files_only=False)
+    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
     app_status["tokenizer_loaded"] = True
 
-    # Try to load model
+    # Try to load model - with from_tf=True
     app_status["status"] = "loading_model"
-    model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir="./model_cache", local_files_only=False)
+    model = AutoModelForCausalLM.from_pretrained(model_name, from_tf=True, cache_dir=cache_dir)
     app_status["model_loaded"] = True
 
     app_status["status"] = "ready"
 except Exception as e:
     error_msg = f"Error loading model or tokenizer: {str(e)}"
-    app_status["status"] = "error"
+    app_status["status"] = "limited_functionality"
     app_status["errors"].append(error_msg)
     print(error_msg)
 
@@ -48,7 +49,7 @@ class PromptRequest(BaseModel):
 async def generate_text(req: PromptRequest, response: Response):
     if app_status["status"] != "ready":
         response.status_code = status.HTTP_503_SERVICE_UNAVAILABLE
-        return {"error": "Model not ready", "status": app_status["status"]}
+        return {"error": "Model not ready", "status": app_status["status"], "details": app_status["errors"]}
 
     try:
         inputs = tokenizer(req.prompt, return_tensors="pt")
@@ -67,7 +68,7 @@ async def generate_text(req: PromptRequest, response: Response):
 
 @app.get("/")
 async def root():
-    return {"message": "API is running", "status": app_status["status"]}
+    return {"message": "API is responding", "status": app_status["status"]}
 
 @app.get("/status")
 async def get_status():
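
A note on the key change here: passing from_tf=True to AutoModelForCausalLM.from_pretrained tells transformers to convert a TensorFlow checkpoint to PyTorch at load time, which only works if TensorFlow is installed in the Space alongside the PyTorch stack. A minimal sketch of a more defensive loader, assuming you want to prefer native PyTorch weights and only fall back to TF conversion (the helper name and fallback logic are illustrative, not part of this commit):

from transformers import AutoModelForCausalLM

def load_model_with_fallback(model_name: str, cache_dir: str):
    # Illustrative helper, not in app.py: try native PyTorch weights first.
    try:
        return AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir)
    except OSError:
        # Fall back to converting a TensorFlow checkpoint; requires TensorFlow to be installed.
        return AutoModelForCausalLM.from_pretrained(model_name, from_tf=True, cache_dir=cache_dir)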
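
With this commit, a client that calls the generation endpoint before the model has loaded gets a 503 whose body now also carries the collected startup errors under "details". A hedged client-side usage sketch, assuming the Space listens on the usual port 7860 and that the generate_text handler (whose route decorator is outside this diff) is mounted as a POST endpoint at /generate:

import requests

BASE = "http://localhost:7860"  # typical Hugging Face Spaces port; adjust for your deployment

# "/" and "/status" are defined in app.py; "/generate" is an assumed path for the
# generate_text handler, since its route decorator is not visible in this diff.
print(requests.get(f"{BASE}/status").json())

resp = requests.post(f"{BASE}/generate", json={"prompt": "Hello"})
if resp.status_code == 503:
    # New in this commit: the 503 body includes "details" with app_status["errors"]
    print("Model not ready:", resp.json().get("details"))
else:
    print(resp.json())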