"""HTTP inference service for the Octo robot policy model.

Exposes three routes:

* ``GET /health``  - liveness probe.
* ``POST /predict`` - decode a sequence of base64 images, run the Octo model
  and return the sampled actions as nested lists.
* ``GET /docs``    - an additional Swagger UI page (the default one is served
  at ``/``).
"""
import os

# Must happen before jax (and octo, which is expected to import jax) is
# imported: JAX picks up its JAX_* environment configuration when the package
# is first imported, so assigning this afterwards does not affect platform
# selection. Adjust if GPU is needed.
os.environ["JAX_PLATFORMS"] = "cpu"

import base64  # noqa: E402
import io  # noqa: E402
from typing import List  # noqa: E402

import jax  # noqa: E402
import numpy as np  # noqa: E402
from fastapi import FastAPI, HTTPException  # noqa: E402
from fastapi.openapi.docs import get_swagger_ui_html  # noqa: E402
from octo.model.octo_model import OctoModel  # noqa: E402
from PIL import Image  # noqa: E402
from pydantic import BaseModel  # noqa: E402

# Every frame is resized to this (width, height) before being stacked.
IMAGE_SIZE = (256, 256)

# Load the model once globally so each request reuses the same weights.
model = OctoModel.load_pretrained("hf://rail-berkeley/octo-base-1.5")

# Initialize FastAPI app
app = FastAPI(
    title="Octo Model Inference API",
    docs_url="/",  # Swagger UI at root
)


class InferenceRequest(BaseModel):
    """Request body accepted by ``POST /predict``."""

    # Base64-encoded images in time order; a ``data:image...`` URI prefix is
    # accepted and stripped before decoding.
    image_base64: List[str]
    # Natural-language instruction handed to ``model.create_tasks``.
    task: str = "pick up the fork"
    # Minimum number of images the request must contain.
    # NOTE(review): this value is only used for validation; all supplied
    # frames are forwarded to the model - confirm whether the sequence
    # should be trimmed to the last ``window_size`` frames.
    window_size: int = 2


def _decode_image(encoded: str) -> np.ndarray:
    """Decode one base64 image into a ``(256, 256, 3)`` RGB array.

    Args:
        encoded: Base64 text, optionally prefixed with a ``data:image...,``
            URI header.

    Returns:
        The image converted to RGB and resized to ``IMAGE_SIZE``.

    Raises:
        ValueError: If the payload is not valid base64 (``binascii.Error`` is
            a ``ValueError`` subclass).
        OSError: If the decoded bytes are not an image PIL can read.
    """
    if encoded.startswith("data:image"):
        # Keep only the payload after the first comma of the data URI.
        _, _, encoded = encoded.partition(",")
    raw = base64.b64decode(encoded)
    with Image.open(io.BytesIO(raw)) as img:
        # Force three channels so RGBA / greyscale / palette inputs still
        # stack into a (T, 256, 256, 3) array.
        return np.array(img.convert("RGB").resize(IMAGE_SIZE))


# Health check endpoint
@app.get("/health")
async def health_check():
    """Return a static payload indicating the service is up."""
    return {"status": "healthy"}


# Inference endpoint
@app.post("/predict")
async def predict(request: InferenceRequest, dataset_name: str = "bridge_dataset"):
    """Sample actions from the Octo model for a sequence of images.

    Args:
        request: Images, task text and minimum window size.
        dataset_name: Key into ``model.dataset_statistics`` whose ``"action"``
            entry is used to un-normalize the sampled actions.

    Returns:
        ``{"actions": [...]}`` where the value is the first batch element of
        the model output converted to nested lists.

    Raises:
        HTTPException: 400 for invalid client input (window size, too few
            images, undecodable image, unknown dataset name); 500 for any
            other failure while processing the request.
    """
    # NOTE(review): model inference below is synchronous and runs on the
    # event loop; other requests wait while it executes.
    try:
        # Validate input
        if request.window_size < 1:
            raise HTTPException(
                status_code=400,
                detail="window_size must be at least 1",
            )
        if len(request.image_base64) < request.window_size:
            raise HTTPException(
                status_code=400,
                detail=f"At least {request.window_size} images required for the specified window size",
            )

        try:
            action_stats = model.dataset_statistics[dataset_name]["action"]
        except KeyError:
            raise HTTPException(
                status_code=400,
                detail=f"Unknown dataset_name: {dataset_name!r}",
            ) from None

        # Process images
        try:
            frames = [_decode_image(item) for item in request.image_base64]
        except (ValueError, OSError) as exc:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid image data: {exc}",
            ) from exc

        # Stack all images and add batch dimension. Shape: (1, T, 256, 256, 3)
        img_array = np.stack(frames)[np.newaxis, ...]
        observation = {
            "image_primary": img_array,
            # Every timestep is marked as real data. Shape: (1, T)
            "timestep_pad_mask": np.full((1, len(frames)), True, dtype=bool),
        }

        # Create task and predict actions
        task_obj = model.create_tasks(texts=[request.task])
        actions = model.sample_actions(
            observation,
            task_obj,
            unnormalization_statistics=action_stats,
            # Fixed key: identical inputs use the same random seed each call.
            rng=jax.random.PRNGKey(0),
        )
        actions = actions[0]  # Remove batch dimension

        # Convert to list for JSON response
        return {"actions": actions.tolist()}
    except HTTPException:
        # Preserve deliberate 4xx responses instead of rewrapping them as 500.
        raise
    except Exception as exc:
        raise HTTPException(
            status_code=500,
            detail=f"Error processing request: {str(exc)}",
        ) from exc


# Custom Swagger UI route (optional)
@app.get("/docs", include_in_schema=False)
async def custom_swagger_ui_html():
    """Serve a Swagger UI page at ``/docs`` in addition to the one at ``/``."""
    return get_swagger_ui_html(
        openapi_url=app.openapi_url,
        title=app.title + " - Swagger UI",
        oauth2_redirect_url=app.swagger_ui_oauth2_redirect_url,
    )