# firststreamlit / apps.py
# Uploaded by namb0010 ("Upload 6 files", commit a4f1c8f)
import numpy as np
import pandas as pd
import streamlit as st
import plotly.graph_objects as go
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
# Iris dataset bunch: ``data`` (measurements), ``target`` (integer class
# codes) and ``target_names`` (the label for each code).
iris_data = load_iris()

# Separate the data into features (one named column per measurement) and the
# target (integer class codes that index ``iris_data.target_names``).
features = pd.DataFrame(iris_data.data, columns=iris_data.feature_names)
target = pd.Series(iris_data.target)

# Stratified 80/20 train/test split. The fixed ``random_state`` makes the split
# reproducible: Streamlit re-executes this script on every widget interaction,
# and without a seed each rerun would train on a different random subset.
x_train, x_test, y_train, y_test = train_test_split(
    features, target, test_size=0.2, stratify=target, random_state=42
)
class StreamlitApp:
    """Streamlit page that trains a random forest on Iris and shows its prediction.

    Uses the module-level ``iris_data``, ``features``, ``x_train`` and
    ``y_train`` objects defined earlier in this file.
    """

    # CSS for the ``header-style`` / ``font-style`` classes used by the HTML
    # snippets below. Kept free of leading indentation so Markdown cannot
    # render it as a code block.
    _PAGE_CSS = (
        "<style>\n"
        ".header-style {\n"
        "font-size:25px;\n"
        "font-family:sans-serif;\n"
        "}\n"
        ".font-style {\n"
        "font-size:20px;\n"
        "font-family:sans-serif;\n"
        "}\n"
        "</style>"
    )

    def __init__(self, random_state=42):
        """Create an untrained classifier.

        Args:
            random_state: Seed forwarded to ``RandomForestClassifier`` so the
                fitted forest is reproducible across script reruns. Pass
                ``None`` for the previous unseeded behaviour.
        """
        self.model = RandomForestClassifier(random_state=random_state)

    def train_data(self):
        """Fit the classifier on the module-level training split and return it."""
        self.model.fit(x_train, y_train)
        return self.model

    def construct_sidebar(self):
        """Render the sidebar title and one selectbox per feature column.

        Each selectbox offers the sorted distinct values of that column in
        ``features``.

        Returns:
            list: The selected value for each column of ``features``, in
            column order (the same order the model was trained on).
        """
        st.sidebar.markdown(
            '<p class="header-style">Iris Data Classification</p>',
            unsafe_allow_html=True,
        )
        return [
            st.sidebar.selectbox(f"Select {col}", sorted(features[col].unique()))
            for col in features.columns
        ]

    def plot_pie_chart(self, probabilities):
        """Build a pie chart of the class probabilities for the first input row.

        Args:
            probabilities: Output of ``predict_proba``; only row 0 is plotted,
                labelled with ``iris_data.target_names``.

        Returns:
            The plotly figure (``update_traces`` returns the figure itself).
        """
        fig = go.Figure(
            data=[go.Pie(labels=list(iris_data.target_names), values=probabilities[0])]
        )
        return fig.update_traces(
            hoverinfo='label+percent',
            textinfo='value',
            textfont_size=15,
        )

    def _inject_styles(self):
        """Emit the page CSS that defines the custom HTML classes."""
        st.markdown(self._PAGE_CSS, unsafe_allow_html=True)

    def construct_app(self):
        """Train the model, read the sidebar inputs and render the prediction page.

        Returns:
            StreamlitApp: ``self``.
        """
        # NOTE(review): the model is refitted on every script rerun; seeding
        # keeps it reproducible, but caching the fitted model would avoid the work.
        self.train_data()
        values = self.construct_sidebar()

        # Predict on a DataFrame carrying the training column names: the model
        # was fitted on a DataFrame, and scikit-learn warns when predict()
        # receives an array without feature names.
        values_to_predict = pd.DataFrame([values], columns=features.columns)
        prediction = self.model.predict(values_to_predict)
        prediction_str = iris_data.target_names[prediction[0]]
        probabilities = self.model.predict_proba(values_to_predict)

        self._inject_styles()
        st.markdown(
            '<p class="header-style"> Iris Data Predictions </p>',
            unsafe_allow_html=True,
        )

        column_1, column_2 = st.columns(2)
        column_1.markdown(
            '<p class="font-style" >Prediction </p>',
            unsafe_allow_html=True,
        )
        column_1.write(f"{prediction_str}")
        column_2.markdown(
            '<p class="font-style" >Probability </p>',
            unsafe_allow_html=True,
        )
        # Class labels are the integer codes 0..n-1 (they index target_names
        # above), so the label doubles as the column index into predict_proba.
        column_2.write(f"{probabilities[0][prediction[0]]}")

        fig = self.plot_pie_chart(probabilities)
        st.markdown(
            '<p class="font-style" >Probability Distribution</p>',
            unsafe_allow_html=True,
        )
        st.plotly_chart(fig, use_container_width=True)
        return self
# Build and render the page. construct_app() hands back the instance it was
# called on, so ``sa`` ends up bound to the rendered app object.
sa = StreamlitApp().construct_app()