# Last commit by ugaray96: ea323d5 (unverified)
# "Streamlit cache resource for cache"
import os
import time
import streamlit as st
from PIL import Image
from model import ViTForImageClassification
# App-wide page settings: tab title, icon (a local image file), and the
# sidebar's initial state. NOTE: Streamlit documents that set_page_config must
# be the first Streamlit command the script runs, so keep it above any other
# st.* call.
_PAGE_CONFIG = dict(
    page_title="Grocery Classifier",
    page_icon="interface/shopping-cart.png",
    initial_sidebar_state="expanded",
)
st.set_page_config(**_PAGE_CONFIG)
def _resource_cache():
    """Return a decorator that caches a single shared instance of a resource.

    ``st.cache`` is deprecated; ``st.cache_resource`` is Streamlit's cache for
    shared, unhashable objects such as ML models, and it does not hash or
    mutation-check the returned object (``retrain_from_feedback`` presumably
    updates the model in place). Streamlit releases that predate
    ``cache_resource`` fall back to the legacy decorator with output mutation
    explicitly allowed. The built-in spinner is disabled because
    ``load_model`` shows its own.
    """
    if hasattr(st, "cache_resource"):
        return st.cache_resource(show_spinner=False)
    return st.cache(allow_output_mutation=True, show_spinner=False)


@_resource_cache()
def load_model():
    """Build the ViT grocery classifier and load its saved state from ``model/``.

    Cached, so the model is constructed once and the same instance is
    returned on every Streamlit rerun.

    Returns:
        The loaded ``ViTForImageClassification`` instance.
    """
    with st.spinner("Loading model"):
        model = ViTForImageClassification("google/vit-base-patch16-224")
        model.load("model/")
    return model
# Root directory for user-corrected images: one subfolder per label, read back
# later when retraining from feedback.
feedback_path = "feedback"
# Runs on every script execution; the cached loader hands back the same model.
model = load_model()
def predict(image):
    """Classify an uploaded image with the cached model.

    Args:
        image: Anything ``PIL.Image.open`` accepts; in this app, the file
            object returned by ``st.file_uploader``.

    Returns:
        A ``(result, pil_image)`` pair. ``result`` maps ``"prediction"`` to the
        first predicted label and ``"confidence"`` to its score rounded to
        three decimals. ``pil_image`` is the decoded image, kept so it can be
        saved later as feedback.
    """
    print("Predicting...")
    pil_image = Image.open(image)  # decode the upload with PIL
    labels, scores = model.predict(pil_image)
    top_label = labels[0]
    top_score = round(scores[0], 3)
    result = {"prediction": top_label, "confidence": top_score}
    return result, pil_image
def submit_feedback(correct_label, image, feedback_dir=None):
    """Save ``image`` as a feedback example for ``correct_label``.

    The image is written as a PNG to ``<feedback_dir>/<correct_label>/`` and
    named ``<correct_label>_<unix seconds>.png``. If that name is already taken
    (two submissions for the same label within one second), a numeric suffix
    is appended rather than overwriting the earlier image.

    Args:
        correct_label: Label chosen by the user; used as the folder name and
            the file-name prefix.
        image: Object with a ``save(path)`` method (a PIL image in this app).
        feedback_dir: Root feedback directory. Defaults to the module-level
            ``feedback_path``.

    Returns:
        The path of the file that was written.
    """
    root = feedback_path if feedback_dir is None else feedback_dir
    folder_path = os.path.join(root, correct_label)
    os.makedirs(folder_path, exist_ok=True)

    stem = "{}_{}".format(correct_label, int(time.time()))
    file_path = os.path.join(folder_path, stem + ".png")
    suffix = 1
    # Second-resolution timestamps collide easily; never clobber earlier feedback.
    while os.path.exists(file_path):
        file_path = os.path.join(folder_path, "{}_{}.png".format(stem, suffix))
        suffix += 1
    image.save(file_path)
    return file_path
def retrain_from_feedback():
    """Retrain the cached model on the images saved under ``feedback_path``.

    ``remove_path=True`` is forwarded to ``model.retrain_from_path``; by its
    name it presumably deletes the feedback folder once training is done
    (TODO confirm in model.py).
    """
    source_dir = feedback_path
    model.retrain_from_path(source_dir, remove_path=True)
def _reset_prediction():
    """Forget the stored prediction and image when the uploaded file changes.

    Registered as the file uploader's ``on_change`` callback. Without it, a
    result computed for an earlier upload stays in session state: it is shown
    under a newly uploaded image, and "Submit" would save the OLD image as
    feedback.
    """
    st.session_state.pop("response_json", None)
    st.session_state.pop("image", None)


def _render_feedback(labels, prediction):
    """Ask the user for the correct label and save the image as feedback.

    Args:
        labels: All labels known to the model, in display order.
        prediction: The label the model predicted for the current image.
    """
    st.subheader("User Feedback")
    st.markdown(
        "If this prediction was incorrect, please select below the correct label"
    )
    # Offer every label except the prediction. Filtering (unlike set.remove)
    # cannot raise if the prediction is missing from ``labels``, for example
    # if retraining changed the label set.
    correct_labels = [label for label in labels if label != prediction]
    correct_label = st.selectbox("Correct label", correct_labels)
    if st.button("Submit"):
        # Save feedback
        try:
            submit_feedback(correct_label, st.session_state["image"])
            st.success("Feedback submitted")
        except Exception as e:
            st.error("Feedback could not be submitted. Error: {}".format(e))


def _render_retrain():
    """Offer a button that retrains the model from the saved feedback."""
    if st.button("Retrain from feedback"):
        try:
            with st.spinner("Retraining..."):
                retrain_from_feedback()
            st.success("Model retrained")
            st.balloons()
        except Exception as e:
            st.warning("Model could not be retrained. Error: {}".format(e))


def main():
    """Render the Grocery Classifier page.

    Shows the labels the model knows and lets the user upload and classify an
    image. Once a prediction exists for the current upload, it also collects
    corrective feedback and offers retraining from that feedback.
    """
    # Sorted so the label list and the correction dropdown have a stable order.
    labels = sorted(set(model.label_encoder.classes_))
    # \N{...} escapes keep the source ASCII, so the emoji cannot be
    # mis-encoded again.
    st.title("\N{GRAPES} Grocery Classifier \N{AVOCADO}")
    if not labels:
        st.warning("Received error from server, labels could not be retrieved")
    else:
        st.write("Labels:", labels)

    image_file = st.file_uploader(
        "Choose an image...",
        type=["jpg", "jpeg", "png"],
        on_change=_reset_prediction,
    )
    if image_file is None:
        return

    st.image(image_file)
    st.subheader("Classification")
    if st.button("Predict"):
        st.session_state["response_json"], st.session_state["image"] = predict(
            image_file
        )

    response = st.session_state.get("response_json")
    if response is None:
        return

    # Show the result
    st.markdown(f"**Prediction:** {response['prediction']}")
    st.markdown(f"**Confidence:** {response['confidence']}")

    _render_feedback(labels, response["prediction"])
    _render_retrain()
# `streamlit run` executes this script with __name__ == "__main__", so the app
# still renders there. A plain import of the module (e.g. from tests) no longer
# calls main().
if __name__ == "__main__":
    main()