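"""FastAPI server exposing a multimodal VLM behind an OpenAI-compatible /v1/chat/completions endpoint.

By default, one worker process is spawned per visible GPU and each request is dispatched to a worker
with a free queue slot; with `dev_mode` enabled, inference runs synchronously in the main process.
Responses can be cached on disk keyed by `request_hash`.

Example request (a sketch; assumes the server is running on localhost:8001 with the default config):

    curl http://localhost:8001/v1/chat/completions \
      -H "Content-Type: application/json" \
      -d '{"messages": [{"role": "user", "content": [{"type": "text", "text": "Describe the scene."}]}]}'
"""
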
import asyncio
import base64
import logging
import multiprocessing as mp
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any, Dict, List, Union
import random
import json
import queue
import hydra
import torch
import time
from fastapi import FastAPI, HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from uvicorn import run

from decoupled_utils import breakpoint_on_error
from demo.api_data_defs import ChatMessage, ChatRequest, ContentPart
from demo.inference_utils import (convert_request_base64_to_pil,
                                  convert_request_pil_to_base64,
                                  trim_merge_messages)
from utils import set_omega_conf_resolvers

logger = logging.getLogger("uvicorn.error")

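# CUDA state cannot be shared across forked processes, so force the 'spawn' start method
# before any GPU worker process is created.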
mp.set_start_method('spawn', force=True)

set_omega_conf_resolvers()


async def dummy_response(messages: List[Dict[str, Any]]) -> ChatRequest:
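    """Produce a canned assistant reply without invoking the model.

    Echoes each user text part and, if `static/0457_01.jpg` exists, attaches it as a base64 image;
    useful for exercising the HTTP plumbing when no GPU or checkpoint is available.
    """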
    await asyncio.sleep(0.1)
    response_content = []
    for msg in messages:
        if msg["role"] == "user":
            for item in msg["content"]:
                if item["type"] == "text":
                    response_content.append(ContentPart(type="text", text="Response to: " + item["text"]))
                elif item["type"] == "image_url":
                    response_content.append(ContentPart(type="text", text="Image received and processed."))
    
    image_path = Path("static/0457_01.jpg")  # Replace with a real image path
    if image_path.is_file():
        with image_path.open("rb") as img_file:
            base64_str = base64.b64encode(img_file.read()).decode('utf-8')
            response_content.append(ContentPart(
                type="image_url",
                image_url={"url": f"data:image/jpeg;base64,{base64_str}"}
            ))
    else:
        logger.warning(f"Image file not found at {image_path}")
    
    return ChatRequest(messages=[ChatMessage(role="assistant", content=response_content)])

def call_model(messages: List[Dict[str, Any]], inference) -> ChatRequest:
    print(f"input messages: {messages}")
    returned_messages = inference(messages)
    openai_messages = convert_request_pil_to_base64(returned_messages)
    return openai_messages

def generate_response(messages: List[Dict[str, Any]], inference, use_dummy: bool = False) -> ChatRequest:
    # The flag is named `use_dummy` so it does not shadow the dummy_response coroutine above.
    if use_dummy:
        # dummy_response is async; run it to completion here (only valid outside a running event loop).
        return asyncio.run(dummy_response(messages))
    else:
        return call_model(messages, inference)

def call(inference, request: ChatRequest):
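    """Run a single chat request through the model and wrap the final message in an
    OpenAI-style chat.completion payload, caching the result on disk when `request_hash` is set."""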
    try:
        print(f"Hash: {request.request_hash}")
        output_dir = Path(f"{Path(__file__).parent}/outputs/responses")
        filename = output_dir / f"{request.request_hash}.json"

        if request.request_hash is not None and filename.exists():
            with open(filename, "r") as f:
                generated_json = json.load(f)
                print(f"Response loaded from {filename}")
        else:
            processed_messages = convert_request_base64_to_pil(request)
            processed_messages = trim_merge_messages(processed_messages)
            generated: ChatRequest = generate_response(processed_messages, inference)
            generated_json = generated.messages[-1].model_dump()

        if request.request_hash is not None and not filename.exists():
            filename.parent.mkdir(parents=True, exist_ok=True)
            with open(filename, "w") as f:
                json.dump(generated_json, f, indent=2)
            print(f"Response saved to {filename}")

        # Wrap the result in an OpenAI-style chat.completion response
        return JSONResponse({
            "id": "cmpl-000",
            "object": "chat.completion",
            "created": int(time.time()),  # Unix timestamp, per the OpenAI response schema
            "choices": [{
                "index": 0,
                "message": generated_json,
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": 0,
                "completion_tokens": 0,
                "total_tokens": 0
            }
        })
        
    except Exception as e:
        from traceback import format_exc
        logger.error(f"Error processing request: {str(e)}")
        logger.error(format_exc())
        raise HTTPException(status_code=500, detail=str(e))


def gpu_worker(gpu_id, config, request_queue, response_queue):
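    """Worker loop pinned to one GPU: load the model once, then serve requests from
    `request_queue` until a `None` sentinel signals shutdown."""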
    torch.cuda.set_device(gpu_id)  # Pin this worker to its GPU; preferred over CUDA_VISIBLE_DEVICES, which the user may have already set.
    from demo.inference import setup
    inference = setup(config)
    print(f"GPU {gpu_id} Initialized inference")
    while True:
        # Wait for a new request (blocking call)
        print(f"GPU {gpu_id} Waiting for request")
        request_data = request_queue.get()
        print(f"GPU {gpu_id} Received request")
        if request_data is None:
            print(f"GPU {gpu_id} Received shutdown signal")
            break  # a way to shut down this worker gracefully
        try:
            # Process the request – note that this call is synchronous
            print(f"GPU {gpu_id} Processing request")
            start_time = time.time()
            result = call(inference, request_data)
            print(f"GPU {gpu_id} Finished processing request in {time.time() - start_time} seconds")
            response_queue.put(result)
            print(f"GPU {gpu_id} Put result in response queue")
        except Exception as e:
            print(f"GPU {gpu_id} Error processing request {request_data}: {e}")
            response_queue.put(e)

@asynccontextmanager
async def lifespan(app: FastAPI):
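    """Application lifespan: in dev mode load the model in-process; otherwise spawn one
    GPU worker per visible device and join them on shutdown."""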
    # Check if we're in development mode
    dev_mode = getattr(app.config, "dev_mode", False)
    app.state.dev_mode = dev_mode
    print(f"Dev mode: {dev_mode}")
    
    if dev_mode:
        # Development mode: Single synchronous GPU process
        logger.info("Starting in DEVELOPMENT mode - synchronous operation, no multiprocessing")
        from demo.inference import setup
        app.state.inference = setup(app.config)
        yield
    else:
        # Normal mode with worker processes
        app.state.worker_lock = asyncio.Lock()
        workers = []
        num_gpus = torch.cuda.device_count()
        logger.info(f"Number of GPUs: {num_gpus}")
        for gpu_id in range(num_gpus):
            req_q = mp.Queue(maxsize=1)  # enforce one request at a time
            res_q = mp.Queue()
            p = mp.Process(target=gpu_worker, args=(gpu_id, app.config, req_q, res_q))
            p.start()
            workers.append({"process": p, "req_q": req_q, "res_q": res_q})
            logger.info(f"Started worker {gpu_id}")
        
        app.state.workers = workers
        yield
        # On shutdown: signal all workers to stop and join them
        for worker in app.state.workers:
            worker["req_q"].put(None)
        for worker in app.state.workers:
            worker["process"].join()
            logger.info("Worker process joined.")


app = FastAPI(title="Multimodal VLM Endpoint", lifespan=lifespan)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # allow all origins; restrict this list for production deployments
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.state.workers = []
logger = logging.getLogger("uvicorn")


@app.post("/v1/chat/completions")
async def chat_completion(request: ChatRequest):
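    """OpenAI-compatible chat endpoint: run inline in dev mode, otherwise hand the request to the
    first worker with a free queue slot and block (in a thread) until it replies."""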
    if getattr(app.state, "dev_mode", False):
        return call(app.state.inference, request)
    
    worker = None
    async with app.state.worker_lock:
        while worker is None:
            # Shuffle workers each time to distribute load
            workers = list(enumerate(app.state.workers))
            random.shuffle(workers)
            for i, w in workers:
                print(f"Trying to assign request to worker {i}")
                try:
                    w["req_q"].put_nowait(request)
                    worker = w
                    print(f"Assigned request to worker {w['process'].name}")
                    break
                except queue.Full:  # multiprocessing queues raise the standard queue.Full when at capacity
                    print(f"Worker {w['process'].name} is full")
                    continue
            if worker is None:
                await asyncio.sleep(0.1)

    loop = asyncio.get_running_loop()
    result = await loop.run_in_executor(None, worker["res_q"].get)
    if isinstance(result, Exception):
        raise HTTPException(status_code=500, detail=str(result))
    return result

@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
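    """Log the raw body and validation errors of a malformed request so client-side issues are easier to debug."""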
    body = await request.body()
    logger.error("Request body: %s", body)
    logger.error("Validation errors: %s", exc.errors())
    logger.error("Original body: %s", exc.body)
    return JSONResponse(
        status_code=422,
        content={"detail": exc.errors(), "body": exc.body},
    )

@hydra.main(version_base=None, config_path="../configs", config_name="config")
@torch.no_grad()
def main(config):
    with breakpoint_on_error():
        app.config = config
        dev_mode = getattr(config, "dev_mode", False)
        app.state.dev_mode = dev_mode
        run(app, host="0.0.0.0", port=getattr(config, "port", 8001))
    
if __name__ == "__main__":
    main()