FlameF0X committed
Commit 642232a · verified · 1 Parent(s): 1261056

Update app.py

Files changed (1):
  app.py  +19 -17
app.py CHANGED
@@ -1,29 +1,31 @@
-from fastapi import FastAPI, Request
+from fastapi import FastAPI
 from pydantic import BaseModel
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
 
-app = FastAPI()
-
-# Load model and tokenizer once at startup
-model_name = "./tiny-gpt2"  # ← path to the local directory
+# Initialize the model and tokenizer (Tiny GPT-2)
+model_name = "./tiny-gpt2"  # Path to your tiny-gpt2 folder
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 model = AutoModelForCausalLM.from_pretrained(model_name)
 
+# FastAPI app
+app = FastAPI()
 
 class PromptRequest(BaseModel):
     prompt: str
-    max_new_tokens: int = 50
+    max_new_tokens: int = 50  # You can adjust the number of tokens generated
 
 @app.post("/generate")
-async def generate_text(req: PromptRequest):
-    inputs = tokenizer(req.prompt, return_tensors="pt")
-    outputs = model.generate(
-        **inputs,
-        max_new_tokens=req.max_new_tokens,
-        do_sample=True,
-        temperature=0.8,
-        top_p=0.95,
-    )
-    generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    return {"generated_text": generated}
+async def generate_text(request: PromptRequest):
+    # Encode the input prompt text
+    inputs = tokenizer.encode(request.prompt, return_tensors="pt")
+
+    # Generate the text using the model
+    with torch.no_grad():
+        outputs = model.generate(inputs, max_length=request.max_new_tokens + len(inputs[0]))
+
+    # Decode the generated text and return the response
+    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+    return {"generated_text": generated_text}
+
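For reference, a minimal client sketch (not part of this commit), assuming the updated app.py is served locally with uvicorn (for example: uvicorn app:app --port 8000); the host, port, prompt, and token count below are illustrative:

import requests

# Call the /generate endpoint defined in app.py with a sample prompt
response = requests.post(
    "http://localhost:8000/generate",
    json={"prompt": "Once upon a time", "max_new_tokens": 30},
)
print(response.json()["generated_text"])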