import gradio as gr
import matplotlib.pyplot as plt

from collections import OrderedDict
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier


def compare(number1, number2):
    # Keep the tree-range sliders consistent: if the new minimum exceeds the
    # current maximum, raise the maximum to match.
    if number1 > number2:
        number2 = number1
    return number2

def do_train(n_samples, random_state, min_estimators, max_estimators):
    # Guard against an inverted range: the `compare` callback below only
    # corrects the max slider when the min slider moves.
    max_estimators = max(max_estimators, min_estimators)

    # Generate a synthetic binary classification problem.
    X, y = make_classification(
        n_samples=n_samples,
        n_features=25,
        n_clusters_per_class=1,
        n_informative=15,
        random_state=random_state,
    )

    # warm_start=True lets each forest grow incrementally in the loop below;
    # oob_score=True records the out-of-bag accuracy after every fit.
    ensemble_clfs = [
        (
            "RandomForestClassifier, max_features='sqrt'",
            RandomForestClassifier(
                warm_start=True,
                oob_score=True,
                max_features="sqrt",
                random_state=random_state,
            ),
        ),
        (
            "RandomForestClassifier, max_features='log2'",
            RandomForestClassifier(
                warm_start=True,
                max_features="log2",
                oob_score=True,
                random_state=random_state,
            ),
        ),
        (
            "RandomForestClassifier, max_features=None",
            RandomForestClassifier(
                warm_start=True,
                max_features=None,
                oob_score=True,
                random_state=random_state,
            ),
        ),
    ]

    # Map each classifier label to a list of (n_estimators, oob_error) pairs.
    error_rate = OrderedDict((label, []) for label, _ in ensemble_clfs)

    for label, clf in ensemble_clfs:
        for i in range(min_estimators, max_estimators + 1, 5):
            clf.set_params(n_estimators=i)
            clf.fit(X, y)

            # OOB error is 1 minus the OOB accuracy reported by the forest.
            oob_error = 1 - clf.oob_score_
            error_rate[label].append((i, oob_error))

    fig, ax = plt.subplots()
    for label, clf_err in error_rate.items():
        xs, ys = zip(*clf_err)
        ax.plot(xs, ys, label=label)

    ax.set_xlim(min_estimators, max_estimators)
    ax.set_xlabel("n_estimators")
    ax.set_ylabel("OOB error rate")
    ax.legend(loc="upper right")
    return fig

model_card = """
## Description

The ``RandomForestClassifier`` is trained using bootstrap aggregation, where each new tree is fit on a bootstrap sample of the training observations $z_i = (x_i, y_i)$.
The out-of-bag (OOB) error for each $z_i$ is the average error computed using predictions from only those trees that do not contain $z_i$ in their bootstrap sample.
This allows the ``RandomForestClassifier`` to be fit and validated while it is being trained.

You can adjust the ``number of samples``, the ``random seed``, and the ``min estimators`` and ``max estimators`` sliders that bound the number of trees.
The demo shows how the OOB error can be measured as each new tree is added during training, and the resulting plot lets a practitioner approximate a suitable value of ``n_estimators`` at which the error stabilizes.
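
As a minimal, self-contained sketch (independent of this app), this is how scikit-learn exposes the OOB estimate once ``oob_score=True`` is set; the dataset settings simply mirror the ones used above:

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

X, y = make_classification(n_samples=500, n_features=25, n_informative=15, random_state=0)
clf = RandomForestClassifier(n_estimators=100, oob_score=True, random_state=0)
clf.fit(X, y)
print(f"OOB error: {1 - clf.oob_score_:.3f}")  # 1 minus the OOB accuracy
```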

## Dataset

Simulated data from ``make_classification``.
"""

with gr.Blocks() as demo:
    gr.Markdown('''
        <div>
        <h1 style='text-align: center'>Out-of-Bag (OOB) Errors for Random Forests</h1>
        </div>
    ''')
    gr.Markdown(model_card)
    gr.Markdown('Author: <a href="https://huggingface.co/bharat-raghunathan">Bharat Raghunathan</a>. Based on the example from <a href="https://scikit-learn.org/stable/auto_examples/ensemble/plot_ensemble_oob.html#sphx-glr-auto-examples-ensemble-plot-ensemble-oob-py">scikit-learn</a>.')
    n_samples = gr.Slider(minimum=500, maximum=5000, step=500, value=500, label="Number of samples")
    random_state = gr.Slider(minimum=0, maximum=2000, step=1, value=0, label="Random seed")
    min_estimators = gr.Slider(minimum=5, maximum=300, step=5, value=15, label="Minimum number of trees")
    max_estimators = gr.Slider(minimum=5, maximum=300, step=5, value=150, label="Maximum number of trees")

    # If the minimum-trees slider is dragged above the maximum, pull the
    # maximum up to match (see `compare` above) before retraining.
    min_estimators.change(compare, [min_estimators, max_estimators], max_estimators)

    with gr.Row():
        with gr.Column():
            plot = gr.Plot()

    # Retrain and redraw the OOB-error curves whenever any control changes.
    n_samples.change(fn=do_train, inputs=[n_samples, random_state, min_estimators, max_estimators], outputs=[plot])
    random_state.change(fn=do_train, inputs=[n_samples, random_state, min_estimators, max_estimators], outputs=[plot])
    min_estimators.change(fn=do_train, inputs=[n_samples, random_state, min_estimators, max_estimators], outputs=[plot])
    max_estimators.change(fn=do_train, inputs=[n_samples, random_state, min_estimators, max_estimators], outputs=[plot])
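
    # A small addition (not in the original wiring; assumes Gradio's Blocks.load
    # event): render an initial plot on page load so the app does not start empty.
    demo.load(fn=do_train, inputs=[n_samples, random_state, min_estimators, max_estimators], outputs=[plot])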

demo.queue().launch()