# aiapp/main.py
from fastapi import FastAPI
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
import torch
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Load model and tokenizer
model_name = "Qwen/Qwen2.5-7B-Instruct-1M"
tokenizer = AutoTokenizer.from_pretrained(model_name)
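# Note: device_map="auto" relies on the accelerate package to place the model
# weights across the available devices; torch.float16 halves the memory footprint.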
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto"
)
class Question(BaseModel):
    question: str
def generate_response_chunks(prompt: str):
    try:
        # Build the chat prompt
        messages = [
            {"role": "system", "content": "You are Orion AI assistant..."},
            {"role": "user", "content": prompt}
        ]
        inputs = tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt"
        ).to(model.device)

        # TextIteratorStreamer exposes generated text as an iterator;
        # TextStreamer only prints to stdout and cannot be iterated over.
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

        # Run generation in a background thread so chunks can be yielded as they arrive
        generation_kwargs = dict(
            input_ids=inputs,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            streamer=streamer
        )
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        # Yield decoded text chunks from the streamer as they become available
        for text in streamer:
            if text:
                yield text
    except Exception as e:
        yield f"Error occurred: {str(e)}"
@app.post("/ask")
async def ask(question: Question):
return StreamingResponse(
generate_response_chunks(question.question),
media_type="text/plain"
)