import os
import joblib
import pandas as pd
from huggingface_hub import hf_hub_download, HfApi
from is_click_predictor.model_trainer import train_models
from is_click_predictor.model_manager import save_models, load_models
from is_click_predictor.model_predictor import predict
from is_click_predictor.config import MODEL_DIR  # Ensure consistency

# Hugging Face Model & Dataset Information
MODEL_REPO = "taimax13/is_click_predictor"
MODEL_FILENAME = "rf_model.pkl"
DATA_REPO = "taimax13/is_click_data"
LOCAL_MODEL_PATH = f"{MODEL_DIR}/{MODEL_FILENAME}"  # Use config path

# Hugging Face API
api = HfApi()
class ModelConnector:
    def __init__(self):
        """Initialize model connector and check if model exists."""
        os.makedirs(MODEL_DIR, exist_ok=True)  # Ensure directory exists
        self.model = self.load_model()

    def check_model_exists(self):
        """Check if the model exists on Hugging Face."""
        try:
            hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
            return True
        except Exception:
            return False

    def load_model(self):
        """Download and load the model from Hugging Face using is_click_predictor."""
        if self.check_model_exists():
            model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
            return joblib.load(model_path)
        return None

    def train_model(self):
        """Train a new model using is_click_predictor and upload it to Hugging Face."""
        try:
            # Load dataset
            train_data_path = hf_hub_download(repo_id=DATA_REPO, filename="train_dataset_full.csv")
            train_data = pd.read_csv(train_data_path)
            X_train = train_data.drop(columns=["is_click"])
            y_train = train_data["is_click"]

            # Train model using `is_click_predictor`
            models = train_models(X_train, y_train)  # Uses RandomForest, CatBoost, XGBoost
            rf_model = models["RandomForest"]  # Use RF as default

            # Save locally using `is_click_predictor`
            save_models(models)

            # Upload to Hugging Face
            api.upload_file(
                path_or_fileobj=LOCAL_MODEL_PATH,
                path_in_repo=MODEL_FILENAME,
                repo_id=MODEL_REPO,
            )

            self.model = rf_model  # Update instance with trained model
            return "Model trained and uploaded successfully!"
        except Exception as e:
            return f"Error during training: {str(e)}"

    def retrain_model(self):
        """Retrain the existing model with new data using is_click_predictor."""
        try:
            # Load dataset
            train_data_path = hf_hub_download(repo_id=DATA_REPO, filename="train_dataset_full.csv")
            train_data = pd.read_csv(train_data_path)
            X_train = train_data.drop(columns=["is_click"])
            y_train = train_data["is_click"]

            if self.model is None:
                return "No existing model found. Train a new model first."

            # Retrain using is_click_predictor
            self.model.fit(X_train, y_train)

            # Save and upload
            save_models({"RandomForest": self.model})
            api.upload_file(
                path_or_fileobj=LOCAL_MODEL_PATH,
                path_in_repo=MODEL_FILENAME,
                repo_id=MODEL_REPO,
            )

            return "Model retrained and uploaded successfully!"
        except Exception as e:
            return f"Error during retraining: {str(e)}"

    def predict(self, input_data):
        """Make predictions using is_click_predictor."""
        if self.model is None:
            return "No model found. Train the model first."

        input_df = pd.DataFrame([input_data])
        prediction = predict({"RandomForest": self.model}, input_df)  # Use predict function
        return int(prediction[0])
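

# Illustrative usage sketch (not part of the original module). It assumes the
# Hugging Face repos above exist and that credentials are available via
# `huggingface-cli login` or the HF_TOKEN environment variable. The keys in
# `sample_input` are hypothetical placeholders; real inputs must match the
# columns of train_dataset_full.csv (minus "is_click").
if __name__ == "__main__":
    connector = ModelConnector()

    # Train and upload a fresh model if none was found on the Hub.
    if connector.model is None:
        print(connector.train_model())

    # Score a single example; feature names shown here are placeholders.
    sample_input = {"feature_1": 0.5, "feature_2": 3}
    print(connector.predict(sample_input))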