ugaray96's picture
Refactor and improve model, app, and training components
0f734ea unverified
raw
history blame
3.26 kB
import os
import time
import streamlit as st
from PIL import Image
from model import ViTForImageClassification
# Browser-tab metadata for the app. Kept ahead of every other st.* call
# (older Streamlit releases reject set_page_config otherwise).
st.set_page_config(
    initial_sidebar_state="expanded",
    page_icon="interface/shopping-cart.png",
    page_title="Grocery Classifier",
)
@st.cache_resource()
def load_model():
    """Construct the ViT classifier and restore its saved state from ``model/``.

    Wrapped in ``st.cache_resource`` so the object is built once and then
    reused on later reruns instead of being reconstructed on each interaction.

    Returns:
        The ``ViTForImageClassification`` instance after ``.load("model/")``.
    """
    checkpoint = "google/vit-base-patch16-224"
    weights_dir = "model/"
    with st.spinner("Loading model"):
        classifier = ViTForImageClassification(checkpoint)
        classifier.load(weights_dir)
    return classifier
# Root directory for user-corrected images (one sub-folder per label).
feedback_path = "feedback"
# Shared classifier instance; load_model() is cached by Streamlit.
model = load_model()
def predict(image):
    """Classify an uploaded image with the shared model.

    Args:
        image: anything ``PIL.Image.open`` accepts (here, the Streamlit upload).

    Returns:
        ``(result, pil_image)`` where ``result`` is
        ``{"prediction": <first predicted label>, "confidence": <first score rounded to 3 places>}``
        and ``pil_image`` is the opened PIL image, handed back so the caller
        can keep it for the feedback step.
    """
    print("Predicting...")
    pil_image = Image.open(image)
    labels_out, scores_out = model.predict(pil_image)
    result = {
        "prediction": labels_out[0],
        "confidence": round(scores_out[0], 3),
    }
    return result, pil_image
def submit_feedback(correct_label, image, feedback_dir=None):
    """Store *image* as a PNG example under its user-corrected label.

    The file is written to
    ``<feedback_dir>/<correct_label>/<correct_label>_<nanosecond stamp>.png``;
    the label folder is created if it does not exist yet.

    Args:
        correct_label: label chosen by the user; used as the folder name and
            the file-name prefix.
        image: PIL image to save.
        feedback_dir: root folder for feedback images. Defaults to the
            module-level ``feedback_path``.

    Returns:
        The path of the written file.
    """
    root = feedback_path if feedback_dir is None else feedback_dir
    label = str(correct_label)
    folder = os.path.join(root, label)
    os.makedirs(folder, exist_ok=True)

    # PNG cannot hold every PIL mode (e.g. CMYK JPEG uploads); saving those
    # directly raises, so fall back to RGB for anything PNG can't store.
    if image.mode not in {"1", "L", "LA", "I", "I;16", "I;16B", "P", "RGB", "RGBA"}:
        image = image.convert("RGB")

    # A per-second timestamp let two submissions for the same label overwrite
    # each other; use nanoseconds and bump on the (unlikely) collision.
    stamp = time.time_ns()
    path = os.path.join(folder, f"{label}_{stamp}.png")
    while os.path.exists(path):
        stamp += 1
        path = os.path.join(folder, f"{label}_{stamp}.png")
    image.save(path)
    return path
def retrain_from_feedback():
    """Retrain the shared model on the images collected under ``feedback_path``.

    ``remove_path=True`` is forwarded to ``model.retrain_from_path``; by its
    name this presumably deletes the feedback folder afterwards — TODO confirm
    in model.py.
    """
    source_dir = feedback_path
    model.retrain_from_path(source_dir, remove_path=True)
def _get_labels():
    """Return the model's class labels as a sorted list, or None if unavailable.

    Reads ``model.label_encoder.classes_``; a missing encoder or an unfitted one
    (no ``classes_``) yields None instead of crashing the page.
    """
    try:
        classes = model.label_encoder.classes_
    except AttributeError:
        return None
    # Sorted for a stable, readable order in the label list and selectbox.
    return sorted(set(classes))


def _clear_prediction():
    """Forget the stored prediction and image so stale results are not shown."""
    st.session_state.pop("response_json", None)
    st.session_state.pop("image", None)


def main():
    """Render the Grocery Classifier page.

    Shows the model's labels, lets the user upload an image and predict its
    class, collects a corrected label as feedback, and can retrain the model
    from the collected feedback. The latest prediction and the PIL image it
    was made on live in ``st.session_state`` so they survive the reruns that
    Streamlit triggers on every widget interaction.
    """
    st.title("🍇 Grocery Classifier 🥑")

    labels = _get_labels()
    if labels is None:
        st.warning("Labels could not be retrieved from the model")
        return
    st.write("Labels:", labels)

    image_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if image_file is None:
        return

    # Session state outlives uploads: without this, the previous image's
    # prediction would be displayed (and submitted as feedback) for a new one.
    upload_key = (image_file.name, image_file.size)
    if st.session_state.get("upload_key") != upload_key:
        _clear_prediction()
        st.session_state["upload_key"] = upload_key

    st.image(image_file)
    st.subheader("Classification")
    if st.button("Predict"):
        st.session_state["response_json"], st.session_state["image"] = predict(
            image_file
        )

    response = st.session_state.get("response_json")
    if response is None:
        return

    # Show the result
    st.markdown(f"**Prediction:** {response['prediction']}")
    st.markdown(f"**Confidence:** {response['confidence']}")

    # User feedback
    st.subheader("User Feedback")
    st.markdown(
        "If this prediction was incorrect, please select below the correct label"
    )
    # Filter instead of set.remove(): after a retrain the stored prediction may
    # no longer be one of the current labels, which used to raise KeyError.
    correct_labels = [label for label in labels if label != response["prediction"]]
    correct_label = st.selectbox("Correct label", correct_labels)
    # With a single known label there is nothing to choose (selectbox -> None).
    if st.button("Submit", disabled=correct_label is None):
        try:
            submit_feedback(correct_label, st.session_state["image"])
            st.success("Feedback submitted")
        except Exception as e:
            st.error("Feedback could not be submitted. Error: {}".format(e))

    # Retrain from feedback
    if st.button("Retrain from feedback"):
        try:
            with st.spinner("Retraining..."):
                retrain_from_feedback()
            st.success("Model retrained")
            st.balloons()
        except Exception as e:
            st.warning("Model could not be retrained. Error: {}".format(e))
# `streamlit run` executes this script as ``__main__``, so the page still
# renders there; a plain import of the module no longer draws the UI.
if __name__ == "__main__":
    main()