# aiapp/main.py
from fastapi import FastAPI
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
import torch
import asyncio
# FastAPI app
app = FastAPI()

# CORS middleware (so JS from the browser can access the API)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Change "*" to your frontend URL for better security
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Request body model
class Question(BaseModel):
    question: str
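# Example request body (illustrative only): {"question": "Who created you?"}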
# Load the model and tokenizer
model_name = "Qwen/Qwen2.5-7B-Instruct"  # Adjust to a VL variant if vision input is needed
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,  # float16 for GPU memory efficiency
    device_map="auto",          # Automatically map layers to GPU/CPU
    trust_remote_code=True
)
async def generate_response_chunks(prompt: str):
    try:
        # Build the chat prompt
        messages = [
            {"role": "system", "content": "You are Orion AI assistant created by Abdullah Ali, who is very intelligent, 13 years old, and lives in Lahore."},
            {"role": "user", "content": prompt}
        ]
        inputs = tokenizer.apply_chat_template(
            messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
        ).to(model.device)

        # model.generate() has no `streaming` argument and does not yield tokens one by one.
        # Instead, TextIteratorStreamer receives decoded text pieces as generation runs
        # in a background thread, and we drain the streamer here.
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = dict(
            inputs=inputs,
            max_new_tokens=512,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            streamer=streamer,
        )
        Thread(target=model.generate, kwargs=generation_kwargs, daemon=True).start()

        # Yield each decoded chunk as it becomes available
        for token_text in streamer:
            if token_text:
                yield token_text
                await asyncio.sleep(0.01)  # Small delay to control streaming speed
    except Exception as e:
        yield f"Error occurred: {e}"
@app.post("/ask")
async def ask(question: Question):
return StreamingResponse(
generate_response_chunks(question.question),
media_type="text/plain"
)
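
# --- Example usage (a minimal sketch; httpx and the localhost URL/port are assumptions,
# not part of this app) ---
# Start the server:   uvicorn main:app --host 0.0.0.0 --port 8000
# Then stream a reply from a separate Python client:
#
#   import httpx
#
#   with httpx.stream("POST", "http://localhost:8000/ask",
#                     json={"question": "Who made you?"}, timeout=None) as r:
#       for chunk in r.iter_text():
#           print(chunk, end="", flush=True)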