FlameF0X committed
Commit c2f4964 · verified · 1 Parent(s): 5333932

Update app.py

Files changed (1):
  app.py  +60 -20
app.py CHANGED
@@ -1,44 +1,84 @@
-from fastapi import FastAPI, Request
+from fastapi import FastAPI, Request, Response, status
 from pydantic import BaseModel
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
 import os
+import time
 
 app = FastAPI()
 
 # Create cache directory
 os.makedirs("./model_cache", exist_ok=True)
 
+# Track app status
+app_status = {
+    "status": "initializing",
+    "model_name": "distilgpt2",
+    "model_loaded": False,
+    "tokenizer_loaded": False,
+    "startup_time": time.time(),
+    "errors": []
+}
+
 # Load model and tokenizer once at startup
 model_name = "distilgpt2" # change this to your own model
 try:
-    # Try to load from local cache first
+    # Try to load tokenizer
+    app_status["status"] = "loading_tokenizer"
     tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="./model_cache", local_files_only=False)
+    app_status["tokenizer_loaded"] = True
+
+    # Try to load model
+    app_status["status"] = "loading_model"
     model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir="./model_cache", local_files_only=False)
-except OSError as e:
-    print(f"Error loading model: {e}")
-    print("Attempting to download model directly...")
-    # If that fails, try downloading explicitly
-    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="./model_cache")
-    model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir="./model_cache")
+    app_status["model_loaded"] = True
+
+    app_status["status"] = "ready"
+except Exception as e:
+    error_msg = f"Error loading model or tokenizer: {str(e)}"
+    app_status["status"] = "error"
+    app_status["errors"].append(error_msg)
+    print(error_msg)
 
 class PromptRequest(BaseModel):
     prompt: str
     max_new_tokens: int = 50
 
 @app.post("/generate")
-async def generate_text(req: PromptRequest):
-    inputs = tokenizer(req.prompt, return_tensors="pt")
-    outputs = model.generate(
-        **inputs,
-        max_new_tokens=req.max_new_tokens,
-        do_sample=True,
-        temperature=0.8,
-        top_p=0.95,
-    )
-    generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    return {"generated_text": generated}
+async def generate_text(req: PromptRequest, response: Response):
+    if app_status["status"] != "ready":
+        response.status_code = status.HTTP_503_SERVICE_UNAVAILABLE
+        return {"error": "Model not ready", "status": app_status["status"]}
+
+    try:
+        inputs = tokenizer(req.prompt, return_tensors="pt")
+        outputs = model.generate(
+            **inputs,
+            max_new_tokens=req.max_new_tokens,
+            do_sample=True,
+            temperature=0.8,
+            top_p=0.95,
+        )
+        generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
+        return {"generated_text": generated}
+    except Exception as e:
+        response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
+        return {"error": str(e)}
 
 @app.get("/")
 async def root():
-    return {"status": "API is running", "model": model_name}
+    return {"message": "API is running", "status": app_status["status"]}
+
+@app.get("/status")
+async def get_status():
+    # Calculate uptime
+    uptime = time.time() - app_status["startup_time"]
+
+    return {
+        "status": app_status["status"],
+        "model_name": app_status["model_name"],
+        "model_loaded": app_status["model_loaded"],
+        "tokenizer_loaded": app_status["tokenizer_loaded"],
+        "uptime_seconds": uptime,
+        "errors": app_status["errors"]
+    }
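
For quick verification of the new behaviour, a minimal client sketch against the updated endpoints (not part of the commit; the base URL, port, and use of the requests library are assumptions, since app.py does not say how it is served):

import requests

BASE_URL = "http://localhost:8000"  # assumption: app served locally, e.g. with `uvicorn app:app`

# The new /status endpoint reports load progress, uptime, and any startup errors
status_info = requests.get(f"{BASE_URL}/status").json()
print(status_info["status"], status_info["model_loaded"], status_info["uptime_seconds"])

# /generate now returns 503 while the model is still loading and 500 on generation errors
resp = requests.post(
    f"{BASE_URL}/generate",
    json={"prompt": "Hello, world", "max_new_tokens": 30},
)
if resp.status_code == 200:
    print(resp.json()["generated_text"])
else:
    print("Not ready or failed:", resp.status_code, resp.json())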