from octo.model.octo_model import OctoModel
from PIL import Image
import numpy as np
import jax
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import os
import io
import base64
from typing import List
from fastapi.openapi.docs import get_swagger_ui_html

# Set JAX to use CPU platform (adjust if GPU is needed)
os.environ['JAX_PLATFORMS'] = 'cpu'

# Load the model once globally
model = OctoModel.load_pretrained("hf://rail-berkeley/octo-base-1.5")

# Initialize FastAPI app
app = FastAPI(
    title="Octo Model Inference API",
    docs_url="/"  # Swagger UI at root
)

# Define request body model
class InferenceRequest(BaseModel):
    image_base64: List[str]  # List of base64-encoded images in time sequence
    task: str = "pick up the fork"  # Default task
    window_size: int = 2  # Default window size, configurable
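
# Example request body for /predict (illustrative values, not from the original file):
# {
#   "image_base64": ["<base64 frame t-1>", "<base64 frame t>"],
#   "task": "pick up the fork",
#   "window_size": 2
# }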

# Health check endpoint
@app.get("/health")
async def health_check():
    return {"status": "healthy"}

# Inference endpoint
@app.post("/predict")
async def predict(request: InferenceRequest, dataset_name: str = "bridge_dataset"):
    try:
        # Validate input
        if len(request.image_base64) < request.window_size:
            raise HTTPException(
                status_code=400,
                detail=f"At least {request.window_size} images required for the specified window size"
            )

        # Process images
        images = []
        for img_base64 in request.image_base64:
            # Strip an optional data-URL prefix before decoding
            if img_base64.startswith("data:image"):
                img_base64 = img_base64.split(",")[1]
            img_data = base64.b64decode(img_base64)
            # Convert to RGB so PNGs with an alpha channel still yield (256, 256, 3)
            img = Image.open(io.BytesIO(img_data)).convert("RGB").resize((256, 256))
            images.append(np.array(img))

        # Use only the most recent window_size frames as the observation history
        images = images[-request.window_size:]

        # Stack the frames and add a batch dimension
        img_array = np.stack(images)[np.newaxis, ...]  # Shape: (1, window_size, 256, 256, 3)
        observation = {
            "image_primary": img_array,
            "timestep_pad_mask": np.full((1, len(images)), True, dtype=bool)  # Shape: (1, T)
        }

        # Create task and predict actions
        task_obj = model.create_tasks(texts=[request.task])
        actions = model.sample_actions(
            observation,
            task_obj,
            unnormalization_statistics=model.dataset_statistics[dataset_name]["action"],
            rng=jax.random.PRNGKey(0)
        )
        actions = actions[0]  # Remove batch dimension, Shape: (horizon, action_dim)

        # Convert to list for JSON response
        actions_list = actions.tolist()

        return {"actions": actions_list}
    except HTTPException:
        # Re-raise validation errors (e.g. the 400 above) instead of masking them as 500s
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}")

# Custom Swagger UI route (optional)
@app.get("/docs", include_in_schema=False)
async def custom_swagger_ui_html():
    return get_swagger_ui_html(
        openapi_url=app.openapi_url,
        title=app.title + " - Swagger UI",
        oauth2_redirect_url=app.swagger_ui_oauth2_redirect_url,
    )
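
# A minimal sketch of serving and querying this app. Assumptions (not part of
# the original file): uvicorn is installed, the module is named main.py, and
# the server listens on localhost:8000.
#
#   uvicorn main:app --host 0.0.0.0 --port 8000
#
# Hypothetical client snippet:
#
#   import base64, requests
#   with open("frame0.png", "rb") as f0, open("frame1.png", "rb") as f1:
#       frames = [base64.b64encode(f.read()).decode() for f in (f0, f1)]
#   resp = requests.post(
#       "http://localhost:8000/predict?dataset_name=bridge_dataset",
#       json={"image_base64": frames, "task": "pick up the fork", "window_size": 2},
#   )
#   print(resp.json()["actions"])
if __name__ == "__main__":
    # Run directly with `python main.py` (host/port values are assumptions)
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)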