FlameF0X committed
Commit ef34ed3 · verified · 1 Parent(s): d04f77a

Update app.py

Files changed (1)
  1. app.py +18 -2
app.py CHANGED
@@ -2,13 +2,25 @@ from fastapi import FastAPI, Request
 from pydantic import BaseModel
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
+import os
 
 app = FastAPI()
 
+# Create cache directory
+os.makedirs("./model_cache", exist_ok=True)
+
 # Load model and tokenizer once at startup
 model_name = "distilgpt2" # change this to your own model
-tokenizer = AutoTokenizer.from_pretrained(model_name)
-model = AutoModelForCausalLM.from_pretrained(model_name)
+try:
+    # Try to load from local cache first
+    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="./model_cache", local_files_only=False)
+    model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir="./model_cache", local_files_only=False)
+except OSError as e:
+    print(f"Error loading model: {e}")
+    print("Attempting to download model directly...")
+    # If that fails, try downloading explicitly
+    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="./model_cache")
+    model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir="./model_cache")
 
 class PromptRequest(BaseModel):
     prompt: str
@@ -26,3 +38,7 @@ async def generate_text(req: PromptRequest):
     )
     generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
     return {"generated_text": generated}
+
+@app.get("/")
+async def root():
+    return {"status": "API is running", "model": model_name}
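For reference, a minimal client sketch for exercising the service after this change. It assumes the app is launched with uvicorn app:app --port 8000, and that the text-generation route is POST /generate; the route decorator falls outside the hunks shown above, so that path is an assumption, not part of the commit.

import requests

BASE_URL = "http://localhost:8000"  # assumes: uvicorn app:app --port 8000

# Health-check endpoint added in this commit
print(requests.get(f"{BASE_URL}/").json())
# expected: {"status": "API is running", "model": "distilgpt2"}

# Generation endpoint; "/generate" is an assumption, since the route
# decorator is not visible in the diff above
resp = requests.post(f"{BASE_URL}/generate", json={"prompt": "Once upon a time"})
resp.raise_for_status()
print(resp.json()["generated_text"])

The requests library is assumed to be installed (pip install requests); any HTTP client with the same call shape works equally well.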