|
import os |
|
import joblib |
|
import pandas as pd |
|
from huggingface_hub import hf_hub_download, HfApi |
|
from is_click_predictor.model_trainer import train_models |
|
from is_click_predictor.model_manager import save_models, load_models |
|
from is_click_predictor.model_predictor import predict |
|
from is_click_predictor.config import MODEL_DIR |
|
|
|
|
|
# Hugging Face Hub repositories used by ModelConnector: one holds the trained
# model artifact, the other holds the training CSV.
MODEL_REPO = "taimax13/is_click_predictor"
DATA_REPO = "taimax13/is_click_data"

# Name of the serialized model, used both as the path inside MODEL_REPO and as
# the file name under MODEL_DIR.
MODEL_FILENAME = "rf_model.pkl"

# Local file that gets uploaded to MODEL_REPO after (re)training.
# NOTE(review): presumably this is where save_models() writes the RandomForest
# model; confirm against is_click_predictor.model_manager.
LOCAL_MODEL_PATH = "{}/{}".format(MODEL_DIR, MODEL_FILENAME)

# Module-wide Hub client, created at import time and used for uploads.
api = HfApi()
|
|
|
|
|
class ModelConnector:
    """Manage the click-prediction model stored on the Hugging Face Hub.

    On construction the connector creates MODEL_DIR and tries to download and
    deserialize MODEL_FILENAME from MODEL_REPO; ``self.model`` is ``None`` when
    that download fails. ``train_model`` / ``retrain_model`` fit on the training
    CSV from DATA_REPO and upload the resulting artifact back to MODEL_REPO.
    Those methods report success or failure as a human-readable string rather
    than raising.
    """

    # Training CSV inside DATA_REPO and the label column it must contain.
    TRAIN_DATA_FILENAME = "train_dataset_full.csv"
    TARGET_COLUMN = "is_click"

    def __init__(self):
        """Ensure MODEL_DIR exists and load the current model from the Hub (or None)."""
        os.makedirs(MODEL_DIR, exist_ok=True)
        self.model = self.load_model()

    def check_model_exists(self):
        """Return True if MODEL_FILENAME can be downloaded from MODEL_REPO.

        Best effort: any exception (missing repo/file, network, auth, ...) is
        treated as "does not exist" and yields False. Note that a successful
        check downloads the file into the local Hugging Face cache.
        """
        try:
            hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
        except Exception:
            return False
        return True

    def load_model(self):
        """Download the model from MODEL_REPO and deserialize it with joblib.

        Returns:
            The object stored in MODEL_FILENAME, or None if it could not be
            downloaded. Errors raised by ``joblib.load`` itself propagate.
        """
        # One guarded download both checks existence and fetches the file;
        # a separate check followed by a second download hits the Hub twice.
        try:
            model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
        except Exception:
            return None
        return joblib.load(model_path)

    def _load_training_data(self):
        """Download the training CSV from DATA_REPO and split features from the label.

        Returns:
            (X_train, y_train): all columns except TARGET_COLUMN, and TARGET_COLUMN.
        """
        # NOTE(review): hf_hub_download defaults to repo_type="model"; if DATA_REPO
        # is a Hub *dataset* repository it needs repo_type="dataset" — confirm.
        train_data_path = hf_hub_download(
            repo_id=DATA_REPO, filename=self.TRAIN_DATA_FILENAME
        )
        train_data = pd.read_csv(train_data_path)
        X_train = train_data.drop(columns=[self.TARGET_COLUMN])
        y_train = train_data[self.TARGET_COLUMN]
        return X_train, y_train

    def _save_and_upload(self, models):
        """Persist ``models`` locally, then upload LOCAL_MODEL_PATH to MODEL_REPO."""
        save_models(models)
        # NOTE(review): assumes save_models() writes the RandomForest model to
        # LOCAL_MODEL_PATH in a joblib-loadable format (load_model reads it back
        # with joblib.load) — confirm against is_click_predictor.model_manager.
        api.upload_file(
            path_or_fileobj=LOCAL_MODEL_PATH,
            path_in_repo=MODEL_FILENAME,
            repo_id=MODEL_REPO,
        )

    def train_model(self):
        """Train new models, upload the RandomForest one, and adopt it.

        Returns:
            A success message, or "Error during training: <reason>" if any
            step fails. ``self.model`` is only replaced after a successful upload.
        """
        try:
            X_train, y_train = self._load_training_data()
            models = train_models(X_train, y_train)
            rf_model = models["RandomForest"]
            self._save_and_upload(models)
            self.model = rf_model
            return "Model trained and uploaded successfully!"
        except Exception as e:
            return f"Error during training: {e}"

    def retrain_model(self):
        """Refit the currently loaded model on the training data and re-upload it.

        Returns:
            A message when no model is loaded, a success message, or
            "Error during retraining: <reason>" if any step fails.
        """
        # Check first: without a model there is nothing to refit, so don't
        # download/parse the training data (or report its errors) needlessly.
        if self.model is None:
            return "No existing model found. Train a new model first."
        try:
            X_train, y_train = self._load_training_data()
            self.model.fit(X_train, y_train)
            self._save_and_upload({"RandomForest": self.model})
            return "Model retrained and uploaded successfully!"
        except Exception as e:
            return f"Error during retraining: {e}"

    def predict(self, input_data):
        """Predict for a single record.

        Args:
            input_data: one record; it is wrapped as a one-row DataFrame, so
                presumably a mapping of feature name -> value (confirm with callers).

        Returns:
            The first prediction as an int, or a message string if no model is loaded.
        """
        if self.model is None:
            return "No model found. Train the model first."
        input_df = pd.DataFrame([input_data])
        # `predict` here is the module-level is_click_predictor function, not
        # this method: class scope is not visible from inside method bodies.
        prediction = predict({"RandomForest": self.model}, input_df)
        return int(prediction[0])
|
|