# aiapp/main.py
from threading import Thread

import torch
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from transformers import (
    AutoTokenizer,
    Qwen2_5_VLForConditionalGeneration,
    TextIteratorStreamer,
)

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Load the model and tokenizer once at startup.
# Qwen2.5-VL is a vision-language checkpoint, so it is loaded through its
# dedicated class; text-only chat prompts work without any image inputs.
model_name = "Qwen/Qwen2.5-VL-7B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto",
)


class Question(BaseModel):
    question: str


def generate_response_chunks(prompt: str):
    try:
        # Build the chat prompt and tokenize it.
        messages = [
            {"role": "system", "content": "You are Orion AI assistant..."},
            {"role": "user", "content": prompt},
        ]
        inputs = tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
        ).to(model.device)
        # Stream the reply as it is generated: generate() runs in a background
        # thread and pushes decoded text chunks into the streamer.
        streamer = TextIteratorStreamer(
            tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        generation_kwargs = dict(
            input_ids=inputs,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            streamer=streamer,
        )
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()
        for text in streamer:
            if text:
                yield text
        thread.join()
    except Exception as e:
        yield f"Error occurred: {e}"
@app.post("/ask")
async def ask(question: Question):
return StreamingResponse(
generate_response_chunks(question.question),
media_type="text/plain"
)