# Octo-1.5-Base / app.py
# Nirav-Madhani's picture
# Update app.py
# 8a209a5 verified
from octo.model.octo_model import OctoModel
from PIL import Image
import numpy as np
import jax
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import os
import io
import base64
from typing import List
from fastapi.openapi.docs import get_swagger_ui_html
# Set JAX to use CPU platform (adjust if GPU is needed).
# NOTE: `jax` has already been imported above, and JAX reads JAX_PLATFORMS
# when it is imported, so assigning the environment variable here is too late
# to be picked up on its own. The variable is still exported (e.g. for child
# processes), and the same setting is applied through the JAX config API,
# which takes effect as long as no backend has been initialized yet.
os.environ['JAX_PLATFORMS'] = 'cpu'
jax.config.update("jax_platforms", "cpu")
# Load the model once globally so every request reuses the same weights
model = OctoModel.load_pretrained("hf://rail-berkeley/octo-base-1.5")
# Initialize FastAPI app
app = FastAPI(
    title="Octo Model Inference API",
    docs_url="/",  # Swagger UI at root
)
# Define request body model
class InferenceRequest(BaseModel):
    # JSON body accepted by the /predict endpoint.

    # Base64-encoded frames, ordered in time (oldest first per the field's
    # original description as a "time sequence").
    image_base64: List[str]

    # Natural-language instruction; this default is used when omitted.
    task: str = "pick up the fork"

    # Minimum number of frames the request must carry; defaults to 2.
    window_size: int = 2
# Health check endpoint
@app.get("/health")
async def health_check():
    # Liveness probe: always answers with the same fixed payload and does
    # not touch the model.
    status_payload = {"status": "healthy"}
    return status_payload
# Inference endpoint
def _decode_image(encoded: str) -> np.ndarray:
    """Decode one base64 (optionally data-URI) image into a 256x256 RGB array."""
    # Drop a "data:image/...;base64," prefix if the client sent a data URI.
    if encoded.startswith("data:image"):
        encoded = encoded.partition(",")[2]
    raw = base64.b64decode(encoded)
    with Image.open(io.BytesIO(raw)) as img:
        # Force three channels so RGBA / palette / grayscale uploads all
        # produce arrays of the same shape and can be stacked together.
        return np.array(img.convert("RGB").resize((256, 256)))


@app.post("/predict")
async def predict(request: InferenceRequest, dataset_name: str = "bridge_dataset"):
    """Sample actions from the Octo model for a sequence of images.

    Args:
        request: Body holding the base64 frames, the task text and the
            minimum number of frames (`window_size`).
        dataset_name: Key into `model.dataset_statistics`; its "action"
            entry is used to un-normalize the sampled actions.

    Returns:
        `{"actions": [...]}` where the value is the first batch element of
        the sampled actions converted to nested Python lists.

    Raises:
        HTTPException: 400 for invalid client input (too few images, no
            images, undecodable image data, unknown `dataset_name`);
            500 for any failure while running the model.
    """
    # --- Client-side errors: report as 400, never as 500 -----------------
    if len(request.image_base64) < request.window_size:
        raise HTTPException(
            status_code=400,
            detail=f"At least {request.window_size} images required for the specified window size"
        )
    if not request.image_base64:
        # Reachable when window_size <= 0; stacking zero frames would fail.
        raise HTTPException(status_code=400, detail="At least one image is required")

    try:
        action_stats = model.dataset_statistics[dataset_name]["action"]
    except KeyError as exc:
        raise HTTPException(
            status_code=400,
            detail=f"Unknown dataset_name: {dataset_name!r}"
        ) from exc

    try:
        images = [_decode_image(item) for item in request.image_base64]
    except (ValueError, OSError) as exc:
        # binascii.Error is a ValueError; Pillow's UnidentifiedImageError is
        # an OSError, so both malformed base64 and malformed image bytes
        # land here.
        raise HTTPException(
            status_code=400,
            detail=f"Invalid image data: {exc}"
        ) from exc

    # --- Model execution: anything failing here is a server error --------
    try:
        # Stack all frames and add a batch dimension: (1, T, 256, 256, 3)
        img_array = np.stack(images)[np.newaxis, ...]
        observation = {
            "image_primary": img_array,
            # Every supplied frame is real data, so nothing is masked out.
            "timestep_pad_mask": np.full((1, len(images)), True, dtype=bool)  # Shape: (1, T)
        }
        task_obj = model.create_tasks(texts=[request.task])
        actions = model.sample_actions(
            observation,
            task_obj,
            unnormalization_statistics=action_stats,
            # Constant seed: the same key is passed for every request.
            rng=jax.random.PRNGKey(0)
        )
        # Remove batch dimension and convert to plain lists for JSON.
        return {"actions": actions[0].tolist()}
    except HTTPException:
        # Never downgrade a deliberate HTTP error into a generic 500.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Error processing request: {str(e)}"
        ) from e
# Custom Swagger UI route (optional)
@app.get("/docs", include_in_schema=False)
async def custom_swagger_ui_html():
    # Serve Swagger UI at /docs as well, built from the app's own OpenAPI
    # URL, title and OAuth2 redirect URL. The route itself is kept out of
    # the generated schema.
    page_title = app.title + " - Swagger UI"
    swagger_kwargs = {
        "openapi_url": app.openapi_url,
        "title": page_title,
        "oauth2_redirect_url": app.swagger_ui_oauth2_redirect_url,
    }
    return get_swagger_ui_html(**swagger_kwargs)