from fastapi import FastAPI
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from queue import Queue
from threading import Thread
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
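# Note: wildcard origins together with allow_credentials=True is a deliberately
# permissive setup; it keeps a public demo Space easy to call from any frontend,
# but is worth tightening for anything non-public.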
# Load model and tokenizer
model_name = "Qwen/Qwen2.5-7B-Instruct-1M"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto"
)
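# float16 weights roughly halve memory versus float32, and device_map="auto"
# (which relies on the accelerate package) places the model on the available
# GPU(s). A 7B model in fp16 needs on the order of 14 GB for the weights alone,
# so this assumes a GPU-backed Space.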
class Question(BaseModel):
    question: str
class CustomTextStreamer:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.queue = Queue()
        self.skip_prompt = True
        self.skip_special_tokens = True
        self.next_tokens_are_prompt = True  # generate() pushes the prompt IDs first

    def put(self, value):
        # Handle token IDs (value is a tensor of token IDs)
        if isinstance(value, torch.Tensor):
            if value.dim() > 1:
                value = value.squeeze(0)  # Remove batch dimension if present
            if self.skip_prompt and self.is_prompt(value):
                return
            text = self.tokenizer.decode(value, skip_special_tokens=self.skip_special_tokens)
            if text:
                self.queue.put(text)

    def end(self):
        self.queue.put(None)  # Signal end of generation

    def is_prompt(self, value):
        # generate() calls put() once with the full prompt before any new tokens,
        # so the first call is the prompt and everything after it is response text
        if self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return True
        return False

    def __iter__(self):
        while True:
            item = self.queue.get()
            if item is None:
                break
            yield item
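# For reference: transformers also ships TextIteratorStreamer, which provides the
# same queue-backed iteration (e.g. TextIteratorStreamer(tokenizer, skip_prompt=True,
# skip_special_tokens=True)); the class above is a hand-rolled equivalent that
# decodes whatever token IDs generate() pushes in.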
def generate_response_chunks(prompt: str):
    try:
        # Prepare input
        messages = [
            {"role": "system", "content": "You are Orion AI assistant..."},
            {"role": "user", "content": prompt}
        ]
        inputs = tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt"
        ).to(model.device)

        # Set up custom streamer
        streamer = CustomTextStreamer(tokenizer)

        # Run generation in a separate thread to avoid blocking
        def generate():
            with torch.no_grad():
                model.generate(
                    inputs,
                    max_new_tokens=512,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                    streamer=streamer
                )

        # Start generation in a thread
        thread = Thread(target=generate)
        thread.start()

        # Yield chunks from the streamer
        for text in streamer:
            yield text
    except Exception as e:
        yield f"Error occurred: {str(e)}"
@app.post("/ask")
async def ask(question: Question):
return StreamingResponse(
generate_response_chunks(question.question),
media_type="text/plain"
) |
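
# Example client (illustrative only; the host and port below are assumptions,
# not part of this app):
#
#   import requests
#   resp = requests.post("http://localhost:7860/ask",
#                        json={"question": "Hello"}, stream=True)
#   resp.encoding = "utf-8"
#   for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
#       print(chunk, end="", flush=True)

# Local entry point, a minimal sketch: Hugging Face Spaces normally launches the
# app itself, so this guard only matters when running the file directly
# (7860 is the default Spaces port).
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)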