mohli commited on
Commit
6709385
·
verified ·
1 Parent(s): ed96a49

Upload RandomForestClass.py

Browse files
Files changed (1) hide show
  1. RandomForest/RandomForestClass.py +101 -0
RandomForest/RandomForestClass.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import warnings
3
+ import pandas as pd
4
+ import numpy as np
5
+ from sklearn.ensemble import RandomForestClassifier
6
+ from sklearn.metrics import accuracy_score, classification_report
7
+ from sklearn.model_selection import train_test_split
8
+ from sklearn.preprocessing import MinMaxScaler
9
+ import matplotlib.pyplot as plt
10
+ import gradio as gr
11
+
12
+
13
+
14
+ class My_RandomForest:
15
+ def __init__(self):
16
+ self.target_column = "Experience_Level" # Change to suit your classification target
17
+ self.models = {
18
+ "Male": None,
19
+ "Female": None,
20
+ "Unspecified": None
21
+ }
22
+
23
+ # Default parameters
24
+ self.n_estimators = 10000 # Number of trees
25
+ self.max_depth = 4 # Maximum tree depth
26
+ self.max_features = 'sqrt'
27
+ self.criterion = 'gini'
28
+
29
+ self.accuracies = {"Male": None, "Female": None, "Unspecified": None} # Store accuracies
30
+
31
+ self.selected_features = {
32
+ "Male": ["Workout_Frequency (days/week)", "Session_Duration (hours)", "Water_Intake (liters)"],
33
+ "Female": ["Workout_Frequency (days/week)", "Session_Duration (hours)", "Water_Intake (liters)"],
34
+ "Unspecified": ["Workout_Frequency (days/week)", "Session_Duration (hours)", "Water_Intake (liters)"]
35
+ }
36
+
37
+ self.scaler = MinMaxScaler() # Initialize the scaler
38
+ self.init_dataset()
39
+
40
+ def init_dataset(self):
41
+ # Load the dataset
42
+ csv_file = os.path.join("app", "data", "gym_members_exercise_tracking.csv")
43
+ df_original = pd.read_csv(csv_file)
44
+ self.df_original = df_original
45
+
46
+ def train_model(self, gender="Unspecified"):
47
+ if gender not in self.models:
48
+ raise ValueError("Invalid gender specified. Choose from 'Male', 'Female', or 'Unspecified'.")
49
+
50
+ # Filter data by gender for training specific models
51
+ if gender == "Male":
52
+ df_filtered = self.df_original[self.df_original["Gender"] == "Male"]
53
+ elif gender == "Female":
54
+ df_filtered = self.df_original[self.df_original["Gender"] == "Female"]
55
+ else:
56
+ df_filtered = self.df_original # Use all data for Unspecified
57
+
58
+ features = self.selected_features[gender]
59
+ X = df_filtered[features]
60
+ y = df_filtered[self.target_column]
61
+
62
+ # Split the data into training and testing sets
63
+ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
64
+
65
+ # Fit the scaler on the training data and transform both sets
66
+ self.scaler.fit(X_train)
67
+ X_train = self.scaler.transform(X_train)
68
+ X_test = self.scaler.transform(X_test)
69
+
70
+ # Initialize and train the Random Forest model
71
+ model = RandomForestClassifier(
72
+ n_estimators=self.n_estimators,
73
+ max_depth=self.max_depth,
74
+ max_features=self.max_features,
75
+ criterion=self.criterion,
76
+ random_state=42
77
+ )
78
+ model.fit(X_train, y_train)
79
+
80
+ # Evaluate the model
81
+ y_pred = model.predict(X_test)
82
+ accuracy = accuracy_score(y_test, y_pred)
83
+
84
+ #print(f"{gender} Model Accuracy: {accuracy:.4f}")
85
+ #print(f"{gender} Model Classification Report:")
86
+ #print(classification_report(y_test, y_pred))
87
+
88
+ self.models[gender] = model
89
+ self.accuracies[gender] = accuracy # Store the accuracy
90
+
91
+ def predict(self, input_data: pd.DataFrame, gender="Unspecified"):
92
+ if gender not in self.models or self.models[gender] is None:
93
+ raise ValueError(f"Model for {gender} is not trained yet.")
94
+
95
+ features = self.selected_features[gender]
96
+ scaled_input = self.scaler.transform(input_data[features])
97
+ prediction = self.models[gender].predict(scaled_input)
98
+ return prediction
99
+
100
+
101
+