File size: 576 Bytes
b4f3263
 
 
 
 
 
 
67bae95
 
b4f3263
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import joblib
import numpy as np
from typing import List


class RandomForestModel:
    """Single-sample inference wrapper around a pre-fitted scaler and model.

    Both artifacts are deserialized with ``joblib.load``. The scaler must expose
    ``transform`` and the model must expose ``predict``; each call scales and
    predicts exactly one sample built from the caller's feature list.
    """

    def __init__(
        self,
        scaler_path: str = "scalers/rf_scaler.joblib",
        model_path: str = "models/random_forest.joblib",
    ):
        """Load the scaler and model artifacts from disk.

        Args:
            scaler_path: Path to the joblib-serialized scaler. The default is a
                relative path, so it is resolved against the current working
                directory.
            model_path: Path to the joblib-serialized model; same resolution
                rules as ``scaler_path``.

        Raises:
            Whatever ``joblib.load`` raises for a missing or unreadable file
            (e.g. ``FileNotFoundError``).
        """
        # joblib.load unpickles the file, which can execute arbitrary code:
        # only point these paths at trusted, locally produced artifacts.
        self.scaler = joblib.load(scaler_path)
        self.model = joblib.load(model_path)

    def preprocess_input(self, secondary_model_features: List[float]) -> np.ndarray:
        """Convert a flat feature list into one scaled float32 row.

        Args:
            secondary_model_features: Feature values for a single sample,
                presumably in the column order the scaler/model were fitted on
                (not checked here — TODO confirm with the training pipeline).

        Returns:
            The scaler's ``transform`` output for a ``(1, n_features)``
            float32 input row.

        Raises:
            ValueError: If ``secondary_model_features`` is empty.
        """
        values = np.asarray(secondary_model_features, dtype=np.float32)
        # Reject empty input with a clear message instead of letting the
        # scaler fail on a (1, 0) array with an opaque error.
        if values.size == 0:
            raise ValueError("secondary_model_features must not be empty")
        # reshape(1, -1): treat the whole input as a single sample (one row).
        return self.scaler.transform(values.reshape(1, -1))

    def predict(self, secondary_model_features: List[float]):
        """Scale one sample and return the model's prediction for it.

        Args:
            secondary_model_features: Feature values for a single sample.

        Returns:
            The first (and only) element of ``self.model.predict`` on the
            scaled row.

        Raises:
            ValueError: If ``secondary_model_features`` is empty.
        """
        scaled_row = self.preprocess_input(secondary_model_features)
        return self.model.predict(scaled_row)[0]