|
import os |
|
import time |
|
|
|
import streamlit as st |
|
from PIL import Image |
|
|
|
from model import ViTForImageClassification |
|
|
|
# Page-level Streamlit settings: title, icon file and an initially expanded
# sidebar.
st.set_page_config(
    initial_sidebar_state="expanded",
    page_icon="interface/shopping-cart.png",
    page_title="Grocery Classifier",
)
|
|
|
|
|
@st.cache(allow_output_mutation=True)
def load_model():
    """Build the ViT classifier once and reuse it across script reruns.

    The classifier is constructed from the ``google/vit-base-patch16-224``
    identifier and its ``load`` method is then called with ``"model/"``.

    ``allow_output_mutation=True`` is required because the returned object is
    mutated in place later on (``retrain_from_feedback`` calls
    ``model.retrain_from_path``).  With the default setting ``st.cache``
    treats the cached return value as immutable, which does not fit a model
    that is retrained during the session.

    Returns:
        The loaded ``ViTForImageClassification`` instance.
    """
    with st.spinner("Loading model"):
        classifier = ViTForImageClassification("google/vit-base-patch16-224")
        classifier.load("model/")
    return classifier
|
|
|
|
|
# Root directory for user-corrected images (submit_feedback writes into it,
# retrain_from_feedback reads from it).
feedback_path = "feedback"

# Shared classifier instance used by predict(), retrain_from_feedback() and
# main().
model = load_model()
|
|
|
|
|
def predict(image):
    """Classify an image with the shared model.

    Args:
        image: Anything ``PIL.Image.open`` accepts; ``main`` passes the
            uploaded file object.

    Returns:
        A 2-tuple ``(result, opened_image)``.  ``result`` is a dict with the
        keys ``"prediction"`` and ``"confidence"``, each taken from the first
        element of the corresponding value returned by ``model.predict``; the
        confidence is rounded to three decimals.  ``opened_image`` is the
        object returned by ``Image.open``.
    """
    print("Predicting...")

    opened_image = Image.open(image)
    predictions, confidences = model.predict(opened_image)

    result = {
        "prediction": predictions[0],
        "confidence": round(confidences[0], 3),
    }
    return result, opened_image
|
|
|
|
|
def submit_feedback(correct_label, image):
    """Store *image* in the feedback directory, filed under its correct label.

    The image is written to ``<feedback_path>/<correct_label>/`` as
    ``<correct_label>_<unix-seconds>.png``.  The label folder is created if it
    does not exist yet.

    Args:
        correct_label: Label string chosen by the user; used both as the
            folder name and as the file-name prefix.
        image: Object exposing ``save(path)``; ``main`` passes the image
            returned by ``predict``.
    """
    # os.path.join instead of manual "/" concatenation keeps the path
    # handling portable and avoids doubled or missing separators.
    label_dir = os.path.join(feedback_path, correct_label)
    os.makedirs(label_dir, exist_ok=True)

    # Whole-second timestamp, same naming scheme as before.  Two submissions
    # for the same label within one second would share a file name.
    file_name = f"{correct_label}_{int(time.time())}.png"
    image.save(os.path.join(label_dir, file_name))
|
|
|
|
|
def retrain_from_feedback():
    """Retrain the shared model from the contents of ``feedback_path``.

    Delegates to ``model.retrain_from_path`` with ``remove_path=True``.
    NOTE(review): presumably that flag deletes the feedback directory after
    retraining -- confirm against the model implementation.
    """
    model.retrain_from_path(feedback_path, remove_path=True)
|
|
|
|
|
def main():
    """Render the page: label list, upload, prediction, feedback, retraining.

    Flow of one script run:

    1. Show the title and the set of labels known to the model.
    2. Let the user upload a jpg/jpeg/png image and display it.
    3. On "Predict", run ``predict`` and keep the result and the opened image
       in ``st.session_state`` (keys ``"response_json"`` and ``"image"``) so
       they survive the reruns triggered by the other buttons.
    4. While a prediction is stored, show it together with a feedback form
       ("Submit" saves the stored image under the selected label) and a
       "Retrain from feedback" button.
    """
    labels = set(model.label_encoder.classes_)

    # NOTE(review): the title contains mis-encoded characters (most likely
    # emoji damaged by an encoding round-trip).  The literal is kept exactly
    # as found; restore the intended characters from the original file.
    st.title("π Grocery Classifier π₯")

    # set(...) can never be None, so the meaningful failure case is "no labels".
    if not labels:
        st.warning("Received error from server, labels could not be retrieved")
    else:
        st.write("Labels:", labels)

    image_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if image_file is None:
        return

    st.image(image_file)

    st.subheader("Classification")

    if st.button("Predict"):
        st.session_state["response_json"], st.session_state["image"] = predict(
            image_file
        )

    response = st.session_state.get("response_json")
    if response is None:
        return

    # Display the label and the confidence.
    st.markdown(f"**Prediction:** {response['prediction']}")
    st.markdown(f"**Confidence:** {response['confidence']}")

    # User feedback.
    st.subheader("User Feedback")
    st.markdown(
        "If this prediction was incorrect, please select below the correct label"
    )
    # Set difference instead of set.remove(): a stored prediction that is no
    # longer among the model's labels (e.g. after retraining) must not raise
    # KeyError.  Sorting gives the select box a stable, readable order.
    correct_labels = sorted(labels - {response["prediction"]})
    correct_label = st.selectbox("Correct label", correct_labels)

    if st.button("Submit"):
        try:
            submit_feedback(correct_label, st.session_state["image"])
            st.success("Feedback submitted")
        except Exception as e:  # UI boundary: report the error, keep the page alive
            st.error(f"Feedback could not be submitted. Error: {e}")

    # Retrain from feedback.
    # NOTE(review): the source had lost its indentation; this button is
    # assumed to sit inside the "prediction available" branch -- confirm.
    if st.button("Retrain from feedback"):
        try:
            with st.spinner("Retraining..."):
                retrain_from_feedback()
            st.success("Model retrained")
            st.balloons()
        except Exception as e:  # UI boundary: report the error, keep the page alive
            st.warning(f"Model could not be retrained. Error: {e}")
|
|
|
|
|
# Script entry point.  `streamlit run` executes this file as "__main__", so the
# guard keeps the app working while ensuring that merely importing the module
# does not call main().
if __name__ == "__main__":
    main()
|
|