from fastapi import FastAPI, Request
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
import torch
import asyncio
# FastAPI app
app = FastAPI()

# CORS Middleware (so JS from browser can access it)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Change "*" to your frontend URL for better security
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Request body model
class Question(BaseModel):
    question: str
# Load the model and tokenizer
model_name = "Qwen/Qwen2.5-7B-Instruct"  # Use Qwen2.5-7B-Instruct (adjust for VL if needed)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,   # Use float16 for GPU memory efficiency
    device_map="auto",           # Automatically map to GPU/CPU
    trust_remote_code=True
)
async def generate_response_chunks(prompt: str):
    try:
        # Build the chat prompt using the model's chat template
        messages = [
            {"role": "system", "content": "You are Orion AI assistant created by Abdullah Ali, who is very intelligent, 13 years old, and lives in Lahore."},
            {"role": "user", "content": prompt}
        ]
        inputs = tokenizer.apply_chat_template(
            messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
        ).to(model.device)

        # model.generate() has no `streaming` argument; instead, a TextIteratorStreamer
        # receives decoded text as it is produced. Generation runs in a background
        # thread while this coroutine reads chunks from the streamer.
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = dict(
            input_ids=inputs,
            max_new_tokens=512,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            streamer=streamer,
        )
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        # Yield text chunks to the client as soon as the streamer produces them
        for token_text in streamer:
            if token_text:
                yield token_text
                await asyncio.sleep(0.01)  # Small delay to control streaming speed
    except Exception as e:
        yield f"Error occurred: {e}"
# POST endpoint; the "/ask" path is assumed here, adjust it to match your frontend
@app.post("/ask")
async def ask(question: Question):
    return StreamingResponse(
        generate_response_chunks(question.question),
        media_type="text/plain"
    )
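# Usage sketch (not part of the original snippet): run the server with uvicorn and
# stream a reply from the endpoint above. The module name "main" and port 7860 are
# assumptions; replace them with your actual filename and port.
#
#   uvicorn main:app --host 0.0.0.0 --port 7860
#
#   curl -N -X POST http://localhost:7860/ask \
#        -H "Content-Type: application/json" \
#        -d '{"question": "Hello, who are you?"}'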