from threading import Thread

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load model and tokenizer once at startup.
# Note: Qwen2.5-VL is a vision-language checkpoint; if AutoModelForCausalLM
# refuses to load it in your transformers version, use the model class
# recommended on its model card instead.
model_name = "Qwen/Qwen2.5-VL-7B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto",
)


class Question(BaseModel):
    question: str


def generate_response_chunks(prompt: str):
    try:
        # Build the chat prompt and tokenize it.
        messages = [
            {"role": "system", "content": "You are Orion AI assistant..."},
            {"role": "user", "content": prompt},
        ]
        inputs = tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
        ).to(model.device)

        # model.generate() blocks and returns the full sequence at once, so
        # iterating over its result does not stream tokens. Instead, run
        # generation in a background thread and read decoded text from a
        # TextIteratorStreamer as it becomes available.
        streamer = TextIteratorStreamer(
            tokenizer,
            skip_prompt=True,           # don't re-emit the prompt text
            skip_special_tokens=True,
        )
        generation_kwargs = dict(
            input_ids=inputs,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            streamer=streamer,
        )
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        # Yield text chunks as the model produces them.
        for text in streamer:
            if text:
                yield text

        thread.join()
    except Exception as e:
        yield f"Error occurred: {e}"


@app.post("/ask")
async def ask(question: Question):
    return StreamingResponse(
        generate_response_chunks(question.question),
        media_type="text/plain",
    )
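
For reference, here is a minimal client-side sketch for consuming the streamed response, run from a separate process. It is not part of the original code: the localhost:8000 URL, the example question, and the timeout are assumptions, and it uses the requests library's streaming mode.

import requests

resp = requests.post(
    "http://localhost:8000/ask",          # assumes `uvicorn main:app` on port 8000
    json={"question": "What can you do?"},
    stream=True,                          # read the body incrementally instead of buffering it
    timeout=300,
)
resp.raise_for_status()
resp.encoding = "utf-8"                   # decode the streamed bytes as UTF-8 text

for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
    if chunk:
        print(chunk, end="", flush=True)
print()

With stream=True, requests does not buffer the whole body, and iter_content with decode_unicode=True yields text chunks roughly as the server flushes them, so the answer appears incrementally instead of all at once.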