File size: 2,136 Bytes
0558e79
 
 
 
 
 
 
 
 
89208ac
0558e79
f8cb635
0558e79
 
f8cb635
0558e79
 
 
f8cb635
0558e79
f8cb635
0558e79
f8cb635
0558e79
 
 
 
 
 
 
 
 
89208ac
0558e79
f8cb635
89208ac
 
 
 
 
 
 
 
 
f8cb635
89208ac
0558e79
89208ac
f8cb635
0558e79
 
 
 
 
89208ac
 
 
0558e79
 
f8cb635
0558e79
f8cb635
0558e79
f8cb635
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from octo.model.octo_model import OctoModel
from PIL import Image
import numpy as np
import jax
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import os
import io
import base64
from typing import List

# Force JAX onto the CPU backend (change "cpu" to a GPU platform if one is
# available). The environment variable alone is too late at this point: `jax`
# has already been imported above, and JAX picks up JAX_PLATFORMS when it is
# imported. It is still set so that child processes inherit it;
# jax.config.update applies the same choice to the already-imported library,
# which works as long as no backend has been initialised yet.
os.environ['JAX_PLATFORMS'] = 'cpu'
jax.config.update("jax_platforms", "cpu")

# Load the Octo 1.5 model once at import time; all requests share this instance.
model = OctoModel.load_pretrained("hf://rail-berkeley/octo-small-1.5")

# Initialize FastAPI app
app = FastAPI(title="Octo 1.5 Inference API")

# Request body model for the POST /predict endpoint.
# NOTE: documented with `#` comments on purpose; a class docstring would be
# picked up by pydantic and change the generated schema description.
class InferenceRequest(BaseModel):
    # One entry per frame, each a base64-encoded image. The endpoint strips a
    # leading "data:image...," data-URL prefix before decoding.
    image_base64: List[str]  # List of base64-encoded images
    # Instruction text passed to the model; used as-is when the client omits it.
    task: str = "pick up the fork"  # Default task

# Liveness probe: replies with a constant payload and touches nothing else.
@app.get("/health")
async def health_check():
    payload = {"status": "healthy"}
    return payload

def _decode_image(img_base64: str) -> np.ndarray:
    """Decode one base64 (optionally data-URL) image into a 256x256 RGB array.

    Raises:
        ValueError: if the payload is not valid base64 or is not an image
            that PIL can read.
    """
    # Drop an optional "data:image/...;base64," prefix. partition() never
    # raises, so a prefix without a comma just yields an empty payload that is
    # rejected below instead of surfacing as an IndexError.
    if img_base64.startswith("data:image"):
        _, _, img_base64 = img_base64.partition(",")
    try:
        img_data = base64.b64decode(img_base64)
        with Image.open(io.BytesIO(img_data)) as img:
            # Force three channels: RGBA / grayscale / palette uploads would
            # otherwise give arrays whose last axis is not 3, which breaks
            # np.stack and the model input.
            rgb = img.convert("RGB").resize((256, 256))
            return np.array(rgb)
    except (ValueError, OSError) as exc:
        # binascii.Error is a ValueError; PIL's unreadable-image error is an OSError.
        raise ValueError(f"could not decode image: {exc}") from exc


# Inference endpoint
@app.post("/predict")
async def predict(request: InferenceRequest, dataset_name: str = "bridge_dataset"):
    """Sample actions from the Octo model for a sequence of images and a task.

    Args:
        request: Body holding the base64-encoded frames and the task text.
        dataset_name: Key into ``model.dataset_statistics`` whose ``"action"``
            entry is used to un-normalize the sampled actions.

    Returns:
        ``{"actions": [...]}`` with the sampled actions for the single request
        (leading batch axis removed) as nested lists.

    Raises:
        HTTPException: 400 for invalid input (no images, undecodable image,
            unknown ``dataset_name``); 500 for failures during inference.
    """
    # Reject bad input up front with 4xx so client mistakes are not reported
    # as server errors.
    if not request.image_base64:
        raise HTTPException(
            status_code=400,
            detail="Error: image_base64 must contain at least one image",
        )

    try:
        action_stats = model.dataset_statistics[dataset_name]["action"]
    except KeyError as e:
        raise HTTPException(
            status_code=400,
            detail=f"Error: no action statistics for dataset_name {dataset_name!r}",
        ) from e

    try:
        images = [_decode_image(encoded) for encoded in request.image_base64]
    except ValueError as e:
        raise HTTPException(status_code=400, detail=f"Error: {e}") from e

    try:
        # Stack frames along a time axis and add a batch axis of size 1.
        img_array = np.stack(images)[np.newaxis, ...]  # Shape: (1, T, 256, 256, 3)
        observation = {
            "image_primary": img_array,
            "timestep_pad_mask": np.ones((1, len(images)), dtype=bool),  # Shape: (1, T)
        }

        # Create task and predict actions.
        # NOTE(review): the PRNG key is fixed, so sampling is identical for
        # identical inputs; confirm this determinism is intended.
        # NOTE(review): this call is synchronous and runs inside an async
        # handler, so it occupies the event loop while the model runs.
        task_obj = model.create_tasks(texts=[request.task])
        actions = model.sample_actions(
            observation,
            task_obj,
            unnormalization_statistics=action_stats,
            rng=jax.random.PRNGKey(0),
        )
        actions = actions[0]  # Remove batch dimension

        return {"actions": actions.tolist()}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error: {str(e)}") from e