FlameF0X committed (verified)
Commit 8dd9dba · 1 Parent(s): d948cd9

Update app.py

Files changed (1)
app.py  +18 −23
app.py CHANGED
@@ -1,33 +1,28 @@
-from fastapi import FastAPI
+from fastapi import FastAPI, Request
 from pydantic import BaseModel
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
-import os
 
-# Define the path where the model is located
-model_directory = "./tiny-gpt2"  # Make sure this is the correct relative or absolute path to your model folder
-
-# Initialize the model and tokenizer using the correct directory path
-tokenizer = AutoTokenizer.from_pretrained(model_directory)
-model = AutoModelForCausalLM.from_pretrained(model_directory)
-
-# FastAPI app
 app = FastAPI()
 
+# Load model and tokenizer once at startup
+model_name = "FlameF0X/Muffin-2.9b-1C25"  # change this to your own model
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+model = AutoModelForCausalLM.from_pretrained(model_name)
+
 class PromptRequest(BaseModel):
     prompt: str
-    max_new_tokens: int = 50  # You can adjust the number of tokens generated
+    max_new_tokens: int = 50
 
 @app.post("/generate")
-async def generate_text(request: PromptRequest):
-    # Encode the input prompt text
-    inputs = tokenizer.encode(request.prompt, return_tensors="pt")
-
-    # Generate the text using the model
-    with torch.no_grad():
-        outputs = model.generate(inputs, max_length=request.max_new_tokens + len(inputs[0]))
-
-    # Decode the generated text and return the response
-    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-    return {"generated_text": generated_text}
+async def generate_text(req: PromptRequest):
+    inputs = tokenizer(req.prompt, return_tensors="pt")
+    outputs = model.generate(
+        **inputs,
+        max_new_tokens=req.max_new_tokens,
+        do_sample=True,
+        temperature=0.8,
+        top_p=0.95,
+    )
+    generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
+    return {"generated_text": generated}