from fastapi import FastAPI
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from queue import Queue
from threading import Thread

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load model and tokenizer
model_name = "Qwen/Qwen2.5-7B-Instruct-1M"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto",
)
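# Note: device_map="auto" relies on the `accelerate` package to place the weights
# across the available GPU(s)/CPU; in float16 the 7B parameters alone take
# roughly 14 GB of memory.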


class Question(BaseModel):
    question: str


class CustomTextStreamer:
    """Minimal streamer compatible with model.generate(): collects decoded text in a queue."""

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.queue = Queue()
        self.skip_prompt = True
        self.skip_special_tokens = True
        self.next_tokens_are_prompt = True  # generate() sends the prompt ids on its first put() call

    def put(self, value):
        # Handle token IDs (value is a tensor of token IDs)
        if isinstance(value, torch.Tensor):
            if value.dim() > 1:
                value = value.squeeze(0)  # Remove batch dimension if present
            if self.skip_prompt and self.is_prompt(value):
                return
            text = self.tokenizer.decode(value, skip_special_tokens=self.skip_special_tokens)
            if text:
                self.queue.put(text)

    def end(self):
        self.queue.put(None)  # Signal end of generation

    def is_prompt(self, value):
        # model.generate() passes the full prompt on its first put() call;
        # treat that call as the prompt and every later call as response tokens.
        if self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return True
        return False

    def __iter__(self):
        while True:
            item = self.queue.get()
            if item is None:
                break
            yield item
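
# Note: recent transformers versions ship a built-in equivalent of the class above,
# which additionally buffers partial tokens so multi-byte characters are not split
# across chunks. A drop-in swap (assuming a transformers version that provides
# TextIteratorStreamer) would look like:
#
#   from transformers import TextIteratorStreamer
#   streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)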


def generate_response_chunks(prompt: str):
    try:
        # Prepare input using the model's chat template
        messages = [
            {"role": "system", "content": "You are Orion AI assistant..."},
            {"role": "user", "content": prompt},
        ]
        inputs = tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
        ).to(model.device)

        # Set up custom streamer
        streamer = CustomTextStreamer(tokenizer)

        # Run generation in a separate thread so this generator can yield tokens as they arrive
        def generate():
            try:
                with torch.no_grad():
                    model.generate(
                        inputs,
                        max_new_tokens=512,
                        do_sample=True,
                        temperature=0.7,
                        top_p=0.9,
                        streamer=streamer,
                    )
            except Exception as e:
                # Surface generation errors to the client instead of leaving the stream hanging
                streamer.queue.put(f"Error occurred: {e}")
            finally:
                streamer.end()  # generate() calls end() on success; calling it again is harmless

        # Start generation in a thread
        thread = Thread(target=generate)
        thread.start()

        # Yield chunks from the streamer as they become available
        for text in streamer:
            yield text
        thread.join()
    except Exception as e:
        yield f"Error occurred: {str(e)}"


@app.post("/ask")
async def ask(question: Question):
    return StreamingResponse(
        generate_response_chunks(question.question),
        media_type="text/plain",
    )
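
# Example client call (assumes the app is served locally, e.g. `uvicorn main:app --port 8000`);
# curl's -N flag disables output buffering so chunks appear as they are streamed:
#
#   curl -N -X POST http://localhost:8000/ask \
#        -H "Content-Type: application/json" \
#        -d '{"question": "Hello, who are you?"}'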