File size: 3,923 Bytes
fa91fe4
 
 
 
6c67532
 
 
 
fa91fe4
 
 
 
 
6c67532
fa91fe4
 
 
 
 
 
 
 
6c67532
fa91fe4
 
 
 
 
 
 
 
 
 
 
6c67532
fa91fe4
 
 
 
 
 
6c67532
fa91fe4
 
 
 
 
 
 
 
6c67532
 
 
fa91fe4
6c67532
 
fa91fe4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6c67532
fa91fe4
 
 
 
 
 
 
 
 
 
 
6c67532
fa91fe4
 
6c67532
 
fa91fe4
 
 
 
 
 
 
 
 
 
 
 
6c67532
fa91fe4
 
 
 
6c67532
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import os
import joblib
import pandas as pd
from huggingface_hub import hf_hub_download, HfApi
from is_click_predictor.model_trainer import train_models
from is_click_predictor.model_manager import save_models, load_models
from is_click_predictor.model_predictor import predict
from is_click_predictor.config import MODEL_DIR  # Ensure consistency

# Hugging Face Hub coordinates: the repository and file name of the model
# artifact, and the repository that hosts the training data.
MODEL_REPO = "taimax13/is_click_predictor"
MODEL_FILENAME = "rf_model.pkl"
DATA_REPO = "taimax13/is_click_data"
# Local path of the model file inside the configured model directory.
# os.path.join avoids a doubled separator when MODEL_DIR already ends with
# one and uses the platform's separator instead of a hard-coded "/".
LOCAL_MODEL_PATH = os.path.join(MODEL_DIR, MODEL_FILENAME)

# Shared Hugging Face API client (created once at import time).
api = HfApi()


class ModelConnector:
    """Load, train, retrain and query the click model stored on Hugging Face.

    Attributes:
        model: The object returned by ``joblib.load`` for the downloaded
            model file, or ``None`` when the file could not be downloaded.
    """

    def __init__(self):
        """Create the local model directory and try to load the hub model."""
        os.makedirs(MODEL_DIR, exist_ok=True)  # Ensure directory exists
        self.model = self.load_model()

    def check_model_exists(self):
        """Return True if the model file can be downloaded from Hugging Face.

        Any exception raised by the download attempt (missing file, missing
        repository, network failure, ...) is reported as ``False``.
        """
        try:
            hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
        except Exception:
            # Deliberate best-effort probe: every failure means "not available".
            return False
        return True

    def load_model(self):
        """Download the model file from Hugging Face and deserialize it.

        Returns:
            The object loaded with ``joblib.load``, or ``None`` when the
            download fails for any reason.
        """
        # A single download attempt doubles as the existence check, so the
        # hub is not contacted twice for the same file.
        try:
            model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
        except Exception:
            return None
        # SECURITY NOTE: joblib.load unpickles the file and can execute
        # arbitrary code; only point MODEL_REPO at a repository you trust.
        return joblib.load(model_path)

    @staticmethod
    def _load_training_data():
        """Download the training CSV and split it into (features, target)."""
        # NOTE(review): no repo_type is passed, so hf_hub_download uses its
        # default repository type -- confirm DATA_REPO is reachable that way.
        csv_path = hf_hub_download(repo_id=DATA_REPO, filename="train_dataset_full.csv")
        frame = pd.read_csv(csv_path)
        return frame.drop(columns=["is_click"]), frame["is_click"]

    @staticmethod
    def _upload_model():
        """Upload the local model file to the model repository."""
        # NOTE(review): assumes save_models() has written LOCAL_MODEL_PATH --
        # verify against is_click_predictor.model_manager.
        api.upload_file(
            path_or_fileobj=LOCAL_MODEL_PATH,
            path_in_repo=MODEL_FILENAME,
            repo_id=MODEL_REPO,
        )

    def train_model(self):
        """Train new models, save them, and upload the model file.

        Returns:
            A success message, or ``"Error during training: ..."`` carrying
            the text of whatever exception was raised along the way.
        """
        try:
            X_train, y_train = self._load_training_data()

            # train_models returns a mapping keyed by model name; the
            # "RandomForest" entry becomes this connector's active model.
            models = train_models(X_train, y_train)
            rf_model = models["RandomForest"]

            # Persist locally, then publish the saved file.
            save_models(models)
            self._upload_model()

            self.model = rf_model  # Update instance with trained model
            return "Model trained and uploaded successfully!"

        except Exception as e:
            return f"Error during training: {e}"

    def retrain_model(self):
        """Fit the currently loaded model again on the training data.

        Returns:
            A success message; a "no existing model" message when
            ``self.model`` is ``None``; or ``"Error during retraining: ..."``
            carrying the text of whatever exception was raised.
        """
        # Check for a model before any network or disk work, so a missing
        # model is always reported as such and never masked by a download
        # error.
        if self.model is None:
            return "No existing model found. Train a new model first."

        try:
            X_train, y_train = self._load_training_data()

            self.model.fit(X_train, y_train)

            # Save and upload
            save_models({"RandomForest": self.model})
            self._upload_model()

            return "Model retrained and uploaded successfully!"

        except Exception as e:
            return f"Error during retraining: {e}"

    def predict(self, input_data):
        """Predict for a single input record.

        Args:
            input_data: One record; it is wrapped in a one-element list and
                turned into a single-row ``DataFrame`` (presumably a mapping
                of column name to value -- confirm with callers).

        Returns:
            The first prediction converted to ``int``, or a message string
            when no model is loaded.
        """
        if self.model is None:
            return "No model found. Train the model first."

        input_df = pd.DataFrame([input_data])
        # Module-level predict() from is_click_predictor, not this method.
        prediction = predict({"RandomForest": self.model}, input_df)
        return int(prediction[0])