foc_intrusion_detector / inference.py
ebelfrank's picture
Upload inference.py with huggingface_hub
b65c5de verified
raw
history blame contribute delete
682 Bytes
import xgboost as xgb
import json
import numpy as np
def model_fn(model_dir):
    """Load the XGBoost booster stored in ``model_dir``.

    The model is read from the local file ``xgboost_model.json`` inside
    ``model_dir``; nothing is downloaded by this function.

    Args:
        model_dir: Directory (``str`` or path-like) that contains
            ``xgboost_model.json``.

    Returns:
        An ``xgb.Booster`` with the saved model loaded into it.

    Raises:
        Any error raised by ``Booster.load_model`` (for example when the
        file is missing or malformed) propagates unchanged.
    """
    # Function-scope import keeps this block self-contained.
    import os

    # os.path.join avoids hand-assembling the separator and also accepts
    # path-like objects for model_dir.
    model_path = os.path.join(model_dir, "xgboost_model.json")

    model = xgb.Booster()
    model.load_model(model_path)
    return model
def predict_fn(data, model):
    """Run ``model`` on the feature rows found under ``data['inputs']``.

    Args:
        data: Mapping with an ``'inputs'`` entry that ``np.array`` can
            convert (presumably a list of per-sample feature lists, as in
            the local example at the bottom of this file -- confirm
            against the real caller).
        model: Object exposing ``predict(DMatrix)``; presumably the
            booster returned by ``model_fn``.

    Returns:
        The model's prediction converted to plain Python lists via
        ``.tolist()``.
    """
    # The payload is turned into a NumPy array first, then wrapped in
    # XGBoost's DMatrix container before being handed to the model.
    features = np.array(data['inputs'])
    scores = model.predict(xgb.DMatrix(features))
    return scores.tolist()
if __name__ == "__main__":
    # Local smoke test: load the model file from the current directory and
    # score a small placeholder batch, printing the result.
    booster = model_fn(".")
    # Placeholder rows -- replace with your input features.
    example_payload = {"inputs": [[1, 2, 3], [4, 5, 6]]}
    print(predict_fn(example_payload, booster))