Nirav-Madhani committed on
Commit
89208ac
·
verified ·
1 Parent(s): cddb79a

Update app.py

Browse files

Multiple frames and data stats

Files changed (1) hide show
  1. app.py +45 -22
app.py CHANGED
@@ -7,20 +7,26 @@ from pydantic import BaseModel
7
  import os
8
  import io
9
  import base64
 
 
10
 
11
  # Set JAX to use CPU platform (adjust if GPU is needed)
12
  os.environ['JAX_PLATFORMS'] = 'cpu'
13
 
14
- # Load the model once globally (assumes it's cached locally)
15
  model = OctoModel.load_pretrained("hf://rail-berkeley/octo-small-1.5")
16
 
17
  # Initialize FastAPI app
18
- app = FastAPI(title="Octo Model Inference API")
 
 
 
19
 
20
  # Define request body model
21
  class InferenceRequest(BaseModel):
22
- image_base64: str # Base64-encoded image string
23
  task: str = "pick up the fork" # Default task
 
24
 
25
  # Health check endpoint
26
  @app.get("/health")
@@ -29,37 +35,54 @@ async def health_check():
29
 
30
  # Inference endpoint
31
  @app.post("/predict")
32
- async def predict(request: InferenceRequest):
33
  try:
34
- # Decode base64 image
35
- img_base64 = request.image_base64
36
- if img_base64.startswith("data:image"):
37
- img_base64 = img_base64.split(",")[1]
38
-
39
- img_data = base64.b64decode(img_base64)
40
- img = Image.open(io.BytesIO(img_data)).resize((256, 256))
41
- img = np.array(img)
42
 
43
- # Add batch and time horizon dimensions
44
- img = img[np.newaxis, np.newaxis, ...] # Shape: (1, 1, 256, 256, 3)
 
 
 
 
 
 
 
 
 
 
45
  observation = {
46
- "image_primary": img,
47
- "timestep_pad_mask": np.array([[True]])
48
  }
49
 
50
  # Create task and predict actions
51
  task_obj = model.create_tasks(texts=[request.task])
52
  actions = model.sample_actions(
53
- observation,
54
- task_obj,
55
- unnormalization_statistics=model.dataset_statistics["bridge_dataset"]["action"],
56
  rng=jax.random.PRNGKey(0)
57
  )
58
- actions = actions[0]
59
 
60
- # Convert NumPy array to list for JSON response
61
  actions_list = actions.tolist()
62
 
63
  return {"actions": actions_list}
64
  except Exception as e:
65
- raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}")
 
 
 
 
 
 
 
 
 
 
7
  import os
8
  import io
9
  import base64
10
+ from typing import List
11
+ from fastapi.openapi.docs import get_swagger_ui_html
12
 
13
  # Set JAX to use CPU platform (adjust if GPU is needed)
14
  os.environ['JAX_PLATFORMS'] = 'cpu'
15
 
16
+ # Load the model once globally
17
  model = OctoModel.load_pretrained("hf://rail-berkeley/octo-small-1.5")
18
 
19
  # Initialize FastAPI app
20
+ app = FastAPI(
21
+ title="Octo Model Inference API",
22
+ docs_url="/" # Swagger UI at root
23
+ )
24
 
25
  # Define request body model
26
  class InferenceRequest(BaseModel):
27
+ image_base64: List[str] # List of base64-encoded images in time sequence
28
  task: str = "pick up the fork" # Default task
29
+ window_size: int = 2 # Default window size, configurable
30
 
31
  # Health check endpoint
32
  @app.get("/health")
 
35
 
36
  # Inference endpoint
37
  @app.post("/predict")
38
+ async def predict(request: InferenceRequest, dataset_name: str = "bridge_dataset"):
39
  try:
40
+ # Validate input
41
+ if len(request.image_base64) < request.window_size:
42
+ raise HTTPException(
43
+ status_code=400,
44
+ detail=f"At least {request.window_size} images required for the specified window size"
45
+ )
 
 
46
 
47
+ # Process images
48
+ images = []
49
+ for img_base64 in request.image_base64:
50
+ if img_base64.startswith("data:image"):
51
+ img_base64 = img_base64.split(",")[1]
52
+ img_data = base64.b64decode(img_base64)
53
+ img = Image.open(io.BytesIO(img_data)).resize((256, 256))
54
+ img = np.array(img)
55
+ images.append(img)
56
+
57
+ # Stack all images and add batch dimension
58
+ img_array = np.stack(images)[np.newaxis, ...] # Shape: (1, T, 256, 256, 3)
59
  observation = {
60
+ "image_primary": img_array,
61
+ "timestep_pad_mask": np.full((1, len(images)), True, dtype=bool) # Shape: (1, T)
62
  }
63
 
64
  # Create task and predict actions
65
  task_obj = model.create_tasks(texts=[request.task])
66
  actions = model.sample_actions(
67
+ observation,
68
+ task_obj,
69
+ unnormalization_statistics=model.dataset_statistics[dataset_name]["action"],
70
  rng=jax.random.PRNGKey(0)
71
  )
72
+ actions = actions[0] # Remove batch dimension, Shape: (horizon, action_dim)
73
 
74
+ # Convert to list for JSON response
75
  actions_list = actions.tolist()
76
 
77
  return {"actions": actions_list}
78
  except Exception as e:
79
+ raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}")
80
+
81
+ # Custom Swagger UI route (optional)
82
+ @app.get("/docs", include_in_schema=False)
83
+ async def custom_swagger_ui_html():
84
+ return get_swagger_ui_html(
85
+ openapi_url=app.openapi_url,
86
+ title=app.title + " - Swagger UI",
87
+ oauth2_redirect_url=app.swagger_ui_oauth2_redirect_url,
88
+ )