import asyncio
import json
import os
import pickle
import timeit
import warnings

import cv2
import numpy as np
import uvicorn
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, File, UploadFile, HTTPException
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.openapi.docs import get_swagger_ui_html
from fastapi.openapi.utils import get_openapi
from PIL import Image

from models.detr_model import DETR
from models.glpn_model import GLPDepth
from models.lstm_model import LSTM_Model
from models.predict_z_location_single_row_lstm import predict_z_location_single_row_lstm
from utils.processing import PROCESSING
from config import CONFIG
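
# CONFIG is assumed to provide at least 'device' and 'lstm_scaler_path';
# both keys are read during startup below.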

warnings.filterwarnings("ignore")

# Initialize FastAPI app
app = FastAPI(
    title="Real-Time WebSocket Image Processing API",
    description="API for object detection and depth estimation using WebSocket for real-time image processing.",
)

try:
    # Load models and utilities
    device = CONFIG['device']
    print("Loading models...")

    detr = DETR()  # Object detection model (DETR)
    print("DETR model loaded.")

    glpn = GLPDepth()  # Depth estimation model (GLPN)
    print("GLPDepth model loaded.")

    zlocE_LSTM = LSTM_Model()  # LSTM model for Z-location prediction
    print("LSTM model loaded.")

    # Pre-trained scaler used to normalize LSTM inputs
    with open(CONFIG['lstm_scaler_path'], 'rb') as f:
        lstm_scaler = pickle.load(f)
    print("LSTM scaler loaded.")

    processing = PROCESSING()  # Utility class for post-processing
    print("Processing utilities loaded.")

except Exception as e:
    # The API cannot serve requests without its models, so fail fast
    print(f"Failed to load models or utilities: {e}")
    raise



# Serve HTML documentation for the API
@app.get("/", response_class=HTMLResponse)
async def get_docs():
    """

    Serve HTML documentation for the WebSocket-based image processing API.

    The HTML file must be available in the same directory.

    Returns a 404 error if the documentation file is not found.

    """
    html_path = os.path.join(os.path.dirname(__file__), "api_documentation.html")
    if not os.path.exists(html_path):
        return HTMLResponse(content="api_documentation.html file not found", status_code=404)
    with open(html_path, "r") as f:
        return HTMLResponse(f.read())


@app.get("/try_page", response_class=HTMLResponse)
async def get_try_page():
    """
    Serve an interactive "try it" page for the WebSocket-based image processing API.

    The HTML file must be available in the same directory.
    Returns a 404 error if the page file is not found.
    """
    html_path = os.path.join(os.path.dirname(__file__), "try_page.html")
    if not os.path.exists(html_path):
        return HTMLResponse(content="try_page.html file not found", status_code=404)
    with open(html_path, "r") as f:
        return HTMLResponse(f.read())


# Function to decode the image received via WebSocket
async def decode_image(image_bytes):
    """

    Decodes image bytes into a PIL Image and returns the image along with its shape.



    Args:

        image_bytes (bytes): The image data received from the client.



    Returns:

        tuple: A tuple containing:

            - pil_image (PIL.Image): The decoded image.

            - img_shape (tuple): Shape of the image as (height, width).

            - decode_time (float): Time taken to decode the image in seconds.



    Raises:

        ValueError: If image decoding fails.

    """
    start = timeit.default_timer()
    nparr = np.frombuffer(image_bytes, np.uint8)
    # Offload the CPU-bound decode to a worker thread so it does not block the event loop
    frame = await asyncio.to_thread(cv2.imdecode, nparr, cv2.IMREAD_COLOR)
    if frame is None:
        raise ValueError("Failed to decode image")
    color_converted = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    pil_image = Image.fromarray(color_converted)
    img_shape = color_converted.shape[0:2]  # (height, width)
    end = timeit.default_timer()
    return pil_image, img_shape, end - start


# Function to run the DETR model for object detection
async def run_detr_model(pil_image):
    """

    Runs the DETR (DEtection TRansformer) model to perform object detection on the input image.



    Args:

        pil_image (PIL.Image): The image to be processed by the DETR model.



    Returns:

        tuple: A tuple containing:

            - detr_result (tuple): The DETR model output consisting of detections' scores and boxes.

            - detr_time (float): Time taken to run the DETR model in seconds.

    """
    start = timeit.default_timer()
    detr_result = await asyncio.to_thread(detr.detect, pil_image)
    end = timeit.default_timer()
    return detr_result, end - start


# Function to run the GLPN model for depth estimation
async def run_glpn_model(pil_image, img_shape):
    """

    Runs the GLPN (Global Local Prediction Network) model to estimate the depth of the objects in the image.



    Args:

        pil_image (PIL.Image): The image to be processed by the GLPN model.

        img_shape (tuple): The shape of the image as (height, width).



    Returns:

        tuple: A tuple containing:

            - depth_map (numpy.ndarray): The depth map for the input image.

            - glpn_time (float): Time taken to run the GLPN model in seconds.

    """
    start = timeit.default_timer()
    depth_map = await asyncio.to_thread(glpn.predict, pil_image, img_shape)
    end = timeit.default_timer()
    return depth_map, end - start


# Function to process the detections with depth map
async def process_detections(scores, boxes, depth_map):
    """

    Processes the DETR model detections and integrates depth information from the GLPN model.



    Args:

        scores (numpy.ndarray): The detection scores for the detected objects.

        boxes (numpy.ndarray): The bounding boxes for the detected objects.

        depth_map (numpy.ndarray): The depth map generated by the GLPN model.



    Returns:

        tuple: A tuple containing:

            - pdata (dict): Processed detection data including depth and bounding box information.

            - process_time (float): Time taken for processing detections in seconds.

    """
    start = timeit.default_timer()
    # Run the synchronous post-processing in a worker thread, matching the model calls above
    pdata = await asyncio.to_thread(processing.process_detections, scores, boxes, depth_map, detr)
    end = timeit.default_timer()
    return pdata, end - start


# Function to generate JSON output for LSTM predictions
async def generate_json_output(data):
    """

       Predict Z-location for each object in the data and prepare the JSON output.



       Parameters:

       - data: DataFrame with bounding box coordinates, depth information, and class type.

       - ZlocE: Pre-loaded LSTM model for Z-location prediction.

       - scaler: Scaler for normalizing input data.



       Returns:

       - JSON structure with object class, distance estimated, and relevant features.

       """
    output_json = []
    start = timeit.default_timer()

    # Iterate over each row in the data
    for _, row in data.iterrows():
        # Predict distance for each object using the single-row prediction function
        distance = predict_z_location_single_row_lstm(row, zlocE_LSTM, lstm_scaler)

        # Create object info dictionary
        object_info = {
            "class": row["class"],  # Object class (e.g., 'car', 'truck')
            "distance_estimated": float(distance),  # Convert distance to float (if necessary)
            "features": {
                "xmin": float(row["xmin"]),  # Bounding box xmin
                "ymin": float(row["ymin"]),  # Bounding box ymin
                "xmax": float(row["xmax"]),  # Bounding box xmax
                "ymax": float(row["ymax"]),  # Bounding box ymax
                "mean_depth": float(row["depth_mean"]),  # Depth mean
                "depth_mean_trim": float(row["depth_mean_trim"]),  # Depth mean trim
                "depth_median": float(row["depth_median"]),  # Depth median
                "width": float(row["width"]),  # Object width
                "height": float(row["height"])  # Object height
            }
        }

        # Append each object info to the output JSON list
        output_json.append(object_info)

    end = timeit.default_timer()

    # Return the final JSON output structure, and time
    return {"objects": output_json}, end - start


# Function to process a single frame (image) in the WebSocket stream
async def process_frame(frame_id, image_bytes):
    """

    Processes a single frame (image) from the WebSocket stream. The process includes:

    - Decoding the image.

    - Running the DETR and GLPN models concurrently.

    - Processing detections and generating the final output JSON.



    Args:

        frame_id (int): The identifier for the frame being processed.

        image_bytes (bytes): The image data received from the WebSocket.



    Returns:

        dict: A dictionary containing the output JSON and timing information for each processing step.

    """
    timings = {}
    try:
        # Step 1: Decode the image
        pil_image, img_shape, decode_time = await decode_image(image_bytes)
        timings["decode_time"] = decode_time

        # Step 2: Run DETR and GLPN models in parallel
        (detr_result, detr_time), (depth_map, glpn_time) = await asyncio.gather(
            run_detr_model(pil_image),
            run_glpn_model(pil_image, img_shape)
        )
        models_time = max(detr_time, glpn_time)  # The models run concurrently, so elapsed time is the longer of the two
        timings["models_time"] = models_time

        # Step 3: Process detections with depth map
        scores, boxes = detr_result
        pdata, process_time = await process_detections(scores, boxes, depth_map)
        timings["process_time"] = process_time

        # Step 4: Generate output JSON
        output_json, json_time = await generate_json_output(pdata)
        timings["json_time"] = json_time

        timings["total_time"] = decode_time + models_time + process_time + json_time

        # Add frame_id and timings to the JSON output
        output_json["frame_id"] = frame_id
        output_json["timings"] = timings

        return output_json

    except Exception as e:
        return {
            "error": str(e),
            "frame_id": frame_id,
            "timings": timings
        }
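
# Illustrative shape of a successful process_frame result (field names come from
# the code above; the values here are made up):
# {
#   "objects": [
#     {"class": "car", "distance_estimated": 12.3,
#      "features": {"xmin": 100.0, "ymin": 80.0, "xmax": 240.0, "ymax": 180.0,
#                   "mean_depth": 11.8, "depth_mean_trim": 11.9, "depth_median": 12.0,
#                   "width": 140.0, "height": 100.0}}
#   ],
#   "frame_id": 1,
#   "timings": {"decode_time": 0.004, "models_time": 0.21, "process_time": 0.01,
#               "json_time": 0.02, "total_time": 0.244}
# }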
    

@app.post("/api/predict", summary="Process a single image for object detection and depth estimation")
async def process_image(file: UploadFile = File(...)):
    """

    Process a single image for object detection and depth estimation.

    

    The endpoint performs:

    - Object detection using DETR model

    - Depth estimation using GLPN model

    - Z-location prediction using LSTM model

    

    Parameters:

    - file: Image file to process (JPEG, PNG)

    

    Returns:

    - JSON response with detected objects, estimated distances, and timing information

    """
    try:
        # Read image content
        image_bytes = await file.read()
        if not image_bytes:
            raise HTTPException(status_code=400, detail="Empty file")
            
        # Use the same processing pipeline as the WebSocket endpoint
        result = await process_frame(0, image_bytes)
        
        # Check if there's an error
        if "error" in result:
            raise HTTPException(status_code=500, detail=result["error"])
            
        return JSONResponse(content=result)
        
    except HTTPException:
        # Re-raise deliberate HTTP errors (e.g., the 400 above) instead of wrapping them in a 500
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
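
# Example request (assuming the server is running locally on port 8000):
#   curl -X POST -F "file=@image.jpg" http://localhost:8000/api/predict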
    


# Add custom OpenAPI documentation
@app.get("/api/docs", include_in_schema=False)
async def custom_swagger_ui_html():
    return get_swagger_ui_html(
        openapi_url="/api/openapi.json",
        title="Real-Time Image Processing API Documentation",
        swagger_js_url="https://cdnjs.cloudflare.com/ajax/libs/swagger-ui/4.18.3/swagger-ui-bundle.js",
        swagger_css_url="https://cdnjs.cloudflare.com/ajax/libs/swagger-ui/4.18.3/swagger-ui.css",
    )


@app.get("/api/openapi.json", include_in_schema=False)
async def get_open_api_endpoint():
    return get_openapi(
        title="Real-Time Image Processing API",
        version="1.0.0",
        description="API for object detection, depth estimation, and z-location prediction using computer vision models",
        routes=app.routes,
    )


@app.websocket("/ws/predict")
async def websocket_endpoint(websocket: WebSocket):
    """

    WebSocket endpoint for real-time image processing. Clients can send image frames to be processed

    and receive JSON output containing object detection, depth estimation, and predictions in real-time.



    - Handles the reception of image data over the WebSocket.

    - Calls the image processing pipeline and returns the result.



    Args:

        websocket (WebSocket): The WebSocket connection to the client.

    """
    await websocket.accept()
    frame_id = 0

    try:
        while True:
            frame_id += 1

            # Receive image bytes from the client
            image_bytes = await websocket.receive_bytes()

            # Process the frame
            result = await process_frame(frame_id, image_bytes)

            # Send the result back to the client
            await websocket.send_text(json.dumps(result))

    except WebSocketDisconnect:
        print(f"Client disconnected after processing {frame_id} frames.")
    except Exception as e:
        print(f"Unexpected error: {e}")
    finally:
        # The socket may already be closed (e.g., after a client disconnect), so ignore errors on close
        try:
            await websocket.close()
        except RuntimeError:
            pass
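

# Minimal sketch of a client for the WebSocket endpoint, assuming the server runs
# locally on port 8000 and using the third-party `websockets` package:
#
#   import asyncio, json, websockets
#
#   async def send_frame(path="image.jpg"):
#       async with websockets.connect("ws://localhost:8000/ws/predict") as ws:
#           with open(path, "rb") as f:
#               await ws.send(f.read())
#           print(json.loads(await ws.recv()))
#
#   asyncio.run(send_frame())


if __name__ == "__main__":
    # Host and port are assumptions for local use; adjust for your deployment.
    uvicorn.run(app, host="0.0.0.0", port=8000)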