Spaces:
Sleeping
Sleeping
File size: 2,136 Bytes
0558e79 89208ac 0558e79 f8cb635 0558e79 f8cb635 0558e79 f8cb635 0558e79 f8cb635 0558e79 f8cb635 0558e79 89208ac 0558e79 f8cb635 89208ac f8cb635 89208ac 0558e79 89208ac f8cb635 0558e79 89208ac 0558e79 f8cb635 0558e79 f8cb635 0558e79 f8cb635 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
from octo.model.octo_model import OctoModel
from PIL import Image
import numpy as np
import jax
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import os
import io
import base64
from typing import List
# Force JAX onto the CPU backend (change "cpu" to a GPU platform if one is available).
# JAX reads the JAX_PLATFORMS environment variable when it is imported, and `jax`
# is already imported above, so assigning the variable here is too late to affect
# this process on its own. The variable is still exported (e.g. for child
# processes); jax.config.update() is what actually selects the backend here and
# must run before the first JAX computation.
os.environ['JAX_PLATFORMS'] = 'cpu'
jax.config.update("jax_platforms", "cpu")

# Load the Octo 1.5 model once at import time so every request reuses it.
model = OctoModel.load_pretrained("hf://rail-berkeley/octo-small-1.5")

# Initialize FastAPI app
app = FastAPI(title="Octo 1.5 Inference API")
# Request body model
class InferenceRequest(BaseModel):
    """JSON body accepted by the /predict endpoint."""

    # Base64-encoded images; each entry is decoded as one image.
    image_base64: List[str]
    # Language instruction for the model; this default is used when omitted.
    task: str = "pick up the fork"
# Health check endpoint
@app.get("/health")
async def health_check():
    """Return a fixed payload indicating the service is up."""
    payload = {"status": "healthy"}
    return payload
# Image decoding helper for the inference endpoint
def _decode_image(img_base64: str) -> np.ndarray:
    """Decode one base64 string (raw or data URL) into a 256x256 RGB array.

    Raises:
        ValueError: if the payload is not valid base64.
        OSError: if the decoded bytes are not an image Pillow can read.
    """
    # Strip a "data:image/...;base64," prefix. partition() never raises, so a
    # prefix without a comma yields an empty payload that fails decoding below.
    if img_base64.startswith("data:image"):
        img_base64 = img_base64.partition(",")[2]
    img_data = base64.b64decode(img_base64)
    with Image.open(io.BytesIO(img_data)) as img:
        # Normalize to 3 channels: RGBA, grayscale and palette images would
        # otherwise produce arrays of differing shapes that cannot be stacked.
        return np.array(img.convert("RGB").resize((256, 256)))


# Inference endpoint
@app.post("/predict")
async def predict(request: InferenceRequest, dataset_name: str = "bridge_dataset"):
    """Sample actions from the Octo model for a sequence of images and a task.

    Args:
        request: Body carrying the base64-encoded images and the task text.
        dataset_name: Key into ``model.dataset_statistics`` selecting the
            action statistics used to un-normalize the sampled actions.

    Returns:
        ``{"actions": [...]}`` where the value is the sampled action array for
        the single submitted sequence, converted to nested lists.

    Raises:
        HTTPException: 400 when no images are supplied, an image cannot be
            decoded, or ``dataset_name`` is unknown; 500 when inference fails.
    """
    # NOTE(review): inference below is a synchronous call inside an async
    # handler, so it blocks the event loop while it runs.
    if not request.image_base64:
        raise HTTPException(
            status_code=400,
            detail="Error: image_base64 must contain at least one image",
        )

    try:
        action_stats = model.dataset_statistics[dataset_name]["action"]
    except KeyError:
        raise HTTPException(
            status_code=400,
            detail=f"Error: unknown dataset_name {dataset_name!r}",
        ) from None

    # Bad client payloads are reported as 400 rather than as a server fault.
    try:
        images = [_decode_image(encoded) for encoded in request.image_base64]
    except (ValueError, OSError, Image.DecompressionBombError) as e:
        raise HTTPException(
            status_code=400, detail=f"Error: invalid image data: {e}"
        ) from e

    try:
        # Stack images with batch dimension
        img_array = np.stack(images)[np.newaxis, ...]  # Shape: (1, T, 256, 256, 3)
        observation = {
            "image_primary": img_array,
            "timestep_pad_mask": np.ones((1, len(images)), dtype=bool),  # Shape: (1, T)
        }
        # Create task and predict actions
        task_obj = model.create_tasks(texts=[request.task])
        actions = model.sample_actions(
            observation,
            task_obj,
            unnormalization_statistics=action_stats,
            # Fixed key: identical requests yield identical samples.
            rng=jax.random.PRNGKey(0),
        )
        actions = actions[0]  # Remove batch dimension
        return {"actions": actions.tolist()}
    except Exception as e:
        # Boundary handler: surface any inference failure as a 500 response.
        raise HTTPException(status_code=500, detail=f"Error: {str(e)}") from e