import asyncio
import base64
import json
import logging
import multiprocessing as mp
import random
import time
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any, Dict, List, Union

import hydra
import torch
from fastapi import FastAPI, HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from uvicorn import run

from decoupled_utils import breakpoint_on_error
from demo.api_data_defs import ChatMessage, ChatRequest, ContentPart
from demo.inference_utils import (convert_request_base64_to_pil,
                                  convert_request_pil_to_base64,
                                  trim_merge_messages)
from utils import set_omega_conf_resolvers

logger = logging.getLogger("uvicorn.error")

mp.set_start_method('spawn', force=True)

set_omega_conf_resolvers()
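
# Serving modes (selected by the Hydra config key `dev_mode`, read in `lifespan`):
#   * dev_mode=True : the model is loaded in-process and requests are handled
#     synchronously, with no multiprocessing (convenient for debugging).
#   * dev_mode=False: one `gpu_worker` process is spawned per visible GPU and
#     incoming requests are dispatched to whichever worker has a free queue slot.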


async def dummy_response(messages: List[Dict[str, Any]]) -> ChatRequest:
    """Return a canned assistant reply (echoed text plus a sample image) without running the model."""
    await asyncio.sleep(0.1)
    response_content = []
    for msg in messages:
        if msg["role"] == "user":
            for item in msg["content"]:
                if item["type"] == "text":
                    response_content.append(ContentPart(type="text", text="Response to: " + item["text"]))
                elif item["type"] == "image_url":
                    response_content.append(ContentPart(type="text", text="Image received and processed."))

    image_path = Path("static/0457_01.jpg")
    if image_path.is_file():
        with image_path.open("rb") as img_file:
            base64_str = base64.b64encode(img_file.read()).decode('utf-8')
            response_content.append(ContentPart(
                type="image_url",
                image_url={"url": f"data:image/jpeg;base64,{base64_str}"}
            ))
    else:
        logger.warning(f"Image file not found at {image_path}")

    return ChatRequest(messages=[ChatMessage(role="assistant", content=response_content)])
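
# Message content parts follow the OpenAI-style multimodal format used above:
#   {"type": "text", "text": "..."}
#   {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,<...>"}}
# (Illustrative shapes; the authoritative field definitions live in
# demo.api_data_defs.ContentPart / ChatMessage.)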


def call_model(messages: List[Dict[str, Any]], inference) -> ChatRequest:
    print(f"input messages: {messages}")
    returned_messages = inference(messages)
    openai_messages = convert_request_pil_to_base64(returned_messages)
    return openai_messages


def generate_response(messages: List[Dict[str, Any]], inference, use_dummy_response: bool = False) -> ChatRequest:
    if use_dummy_response:
        # `dummy_response` is a coroutine; run it to completion here (only valid
        # outside a running event loop, e.g. inside a GPU worker process).
        return asyncio.run(dummy_response(messages))
    return call_model(messages, inference)


def call(inference, request: ChatRequest):
    """Serve one request: reuse the cached response for `request_hash` if present,
    otherwise run inference, cache the result, and wrap it as an OpenAI-style
    chat.completion payload."""
    try:
        print(f"Hash: {request.request_hash}")
        output_dir = Path(f"{Path(__file__).parent}/outputs/responses")
        filename = output_dir / f"{request.request_hash}.json"

        if request.request_hash is not None and filename.exists():
            with open(filename, "r") as f:
                generated_json = json.load(f)
            print(f"Response loaded from {filename}")
        else:
            processed_messages = convert_request_base64_to_pil(request)
            processed_messages = trim_merge_messages(processed_messages)
            generated: ChatRequest = generate_response(processed_messages, inference)
            generated_json = generated.messages[-1].model_dump()

            if request.request_hash is not None and not filename.exists():
                filename.parent.mkdir(parents=True, exist_ok=True)
                with open(filename, "w") as f:
                    json.dump(generated_json, f, indent=2)
                print(f"Response saved to {filename}")

        return JSONResponse({
            "id": "cmpl-000",
            "object": "chat.completion",
            # "created" is a Unix timestamp in the OpenAI schema.
            "created": int(time.time()),
            "choices": [{
                "index": 0,
                "message": generated_json,
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": 0,
                "completion_tokens": 0,
                "total_tokens": 0
            }
        })

    except Exception as e:
        from traceback import format_exc
        logger.error(f"Error processing request: {str(e)}")
        logger.error(format_exc())
        raise HTTPException(status_code=500, detail=str(e))
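
# Cached responses are written next to this file, under
# outputs/responses/<request_hash>.json. Each file holds the model_dump() of the
# final assistant ChatMessage, roughly:
#   {"role": "assistant", "content": [{"type": "text", "text": "..."}, ...]}
# (Illustrative shape; see demo.api_data_defs.ChatMessage for the exact fields.)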


def gpu_worker(gpu_id, config, request_queue, response_queue):
    torch.cuda.set_device(gpu_id)
    from demo.inference import setup
    inference = setup(config)
    print(f"GPU {gpu_id} Initialized inference")
    while True:
        print(f"GPU {gpu_id} Waiting for request")
        request_data = request_queue.get()
        print(f"GPU {gpu_id} Received request")
        if request_data is None:
            print(f"GPU {gpu_id} Received shutdown signal")
            break
        try:
            print(f"GPU {gpu_id} Processing request")
            start_time = time.time()
            result = call(inference, request_data)
            print(f"GPU {gpu_id} Finished processing request in {time.time() - start_time} seconds")
            response_queue.put(result)
            print(f"GPU {gpu_id} Put result in response queue")
        except Exception as e:
            print(f"GPU {gpu_id} Error processing request {request_data}: {e}")
            response_queue.put(e)
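
# Worker protocol (one process per GPU, wired up in `lifespan` below):
#   * `request_queue` (maxsize=1) carries ChatRequest objects; `None` is the
#     shutdown signal.
#   * `response_queue` carries either the JSONResponse produced by `call` or the
#     Exception raised while handling the request; the endpoint turns exceptions
#     into HTTP 500 responses.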


@asynccontextmanager
async def lifespan(app: FastAPI):
    dev_mode = getattr(app.config, "dev_mode", False)
    app.state.dev_mode = dev_mode
    print(f"Dev mode: {dev_mode}")

    if dev_mode:
        logger.info("Starting in DEVELOPMENT mode - synchronous operation, no multiprocessing")
        from demo.inference import setup
        app.state.inference = setup(app.config)
        yield
    else:
        app.state.worker_lock = asyncio.Lock()
        workers = []
        num_gpus = torch.cuda.device_count()
        logger.info(f"Number of GPUs: {num_gpus}")
        for gpu_id in range(num_gpus):
            req_q = mp.Queue(maxsize=1)
            res_q = mp.Queue()
            p = mp.Process(target=gpu_worker, args=(gpu_id, app.config, req_q, res_q))
            p.start()
            workers.append({"process": p, "req_q": req_q, "res_q": res_q})
            logger.info(f"Started worker {gpu_id}")

        app.state.workers = workers
        yield

        # Shutdown: ask every worker to exit, then wait for the processes.
        for worker in app.state.workers:
            worker["req_q"].put(None)
        for worker in app.state.workers:
            worker["process"].join()
        logger.info("Worker process joined.")


app = FastAPI(title="Multimodal VLM Endpoint", lifespan=lifespan)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.state.workers = []


@app.post("/v1/chat/completions")
async def chat_completion(request: ChatRequest):
    if getattr(app.state, "dev_mode", False):
        return call(app.state.inference, request)

    worker = None
    async with app.state.worker_lock:
        while worker is None:
            workers = list(enumerate(app.state.workers))
            random.shuffle(workers)
            for i, w in workers:
                print(f"Trying to assign request to worker {i}")
                try:
                    w["req_q"].put_nowait(request)
                    worker = w
                    print(f"Assigned request to worker {w['process'].name}")
                    break
                except mp.queues.Full:
                    print(f"Worker {w['process'].name} is full")
                    continue
            if worker is None:
                await asyncio.sleep(0.1)

    loop = asyncio.get_running_loop()
    result = await loop.run_in_executor(None, worker["res_q"].get)
    if isinstance(result, Exception):
        raise HTTPException(status_code=500, detail=str(result))
    return result
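
# Example request (illustrative; the payload schema is defined by
# demo.api_data_defs.ChatRequest and the default port by `main` below):
#   curl http://localhost:8001/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"messages": [{"role": "user", "content": [{"type": "text", "text": "Describe this scene."}]}]}'
# The reply follows the OpenAI-style chat.completion shape assembled in `call`.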


@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    body = await request.body()
    logger.error("Request body: %s", body)
    logger.error("Validation errors: %s", exc.errors())
    logger.error("Original body: %s", exc.body)
    return JSONResponse(
        status_code=422,
        content={"detail": exc.errors(), "body": exc.body},
    )


@hydra.main(version_base=None, config_path="../configs", config_name="config")
@torch.no_grad()
def main(config):
    with breakpoint_on_error():
        app.config = config
        dev_mode = getattr(config, "dev_mode", False)
        app.state.dev_mode = dev_mode
        run(app, host="0.0.0.0", port=getattr(config, "port", 8001))
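
# Example launch (illustrative; Hydra resolves the config from ../configs and the
# `dev_mode` / `port` keys can be overridden on the command line):
#   python <path-to-this-file> dev_mode=true port=8001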


if __name__ == "__main__":
    main()