File size: 3,250 Bytes
158f4dc 0f734ea 158f4dc 0f734ea 158f4dc 0f734ea 158f4dc 0f734ea ea323d5 158f4dc 0f734ea 158f4dc 0f734ea 158f4dc 0f734ea 158f4dc 0f734ea 158f4dc 0f734ea 158f4dc 0f734ea 158f4dc 0f734ea 158f4dc 0f734ea 158f4dc 0f734ea 158f4dc 0f734ea 158f4dc 0f734ea 158f4dc 0f734ea 158f4dc 0f734ea 158f4dc 0f734ea 158f4dc 0f734ea 158f4dc 0f734ea 5f05cbf 158f4dc 5f05cbf 158f4dc 0f734ea |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
import os
import time
import streamlit as st
from PIL import Image
from model import ViTForImageClassification
# Configure the browser tab title, the page icon and the initial sidebar
# state. Streamlit documents set_page_config as having to be the first
# Streamlit command executed on a page, so keep it above everything else.
st.set_page_config(
page_title="Grocery Classifier",
page_icon="interface/shopping-cart.png",
initial_sidebar_state="expanded",
)
# allow_output_mutation=True: by default st.cache hashes the returned object
# on every access to detect mutation. The cached model is handed to
# retrain_from_path() elsewhere in this file, which presumably changes it in
# place, so the default check would warn about (or fail to hash) the model.
# NOTE(review): on Streamlit >= 1.18 st.cache is deprecated; the documented
# replacement for unserialisable resources such as models is
# st.cache_resource.
@st.cache(allow_output_mutation=True)
def load_model():
    """Create the ViT classifier and load its saved state from ``model/``.

    A spinner is shown while the model is being built. The result is cached
    by Streamlit, so script reruns reuse the same instance instead of
    constructing it again.

    Returns:
        The ``ViTForImageClassification`` instance after ``load("model/")``
        has been called on it.
    """
    with st.spinner("Loading model"):
        model = ViTForImageClassification("google/vit-base-patch16-224")
        model.load("model/")
    return model
# Shared classifier instance used by predict() and retrain_from_feedback().
# load_model() is cached by Streamlit, so reruns of this script reuse it.
model = load_model()
# Root folder for user-corrected images; submit_feedback() writes one
# sub-folder per label beneath it.
feedback_path = "feedback"
def predict(image):
    """Classify an image with the module-level ``model``.

    Args:
        image: Whatever ``PIL.Image.open`` accepts; ``main()`` passes the
            object returned by the Streamlit file uploader.

    Returns:
        A ``(result, pil_image)`` tuple. ``result`` is a dict holding the
        first predicted label under ``"prediction"`` and the first
        confidence, rounded to three decimals, under ``"confidence"``.
        ``pil_image`` is the opened PIL image.
    """
    print("Predicting...")
    pil_image = Image.open(image)
    # model.predict yields two indexable results; only element 0 of each
    # is reported.
    predicted_labels, confidences = model.predict(pil_image)
    result = {
        "prediction": predicted_labels[0],
        "confidence": round(confidences[0], 3),
    }
    return result, pil_image
def submit_feedback(correct_label, image):
    """Store ``image`` as a PNG inside the feedback folder of its label.

    The file is written to
    ``<feedback_path>/<correct_label>/<correct_label>_<unix seconds>.png``;
    the label folder is created when it does not exist yet.

    Args:
        correct_label: Label chosen by the user; must be a string, since it
            is joined into the path.
        image: Object exposing ``save(path)`` (the app passes a PIL image).

    Raises:
        TypeError: If ``correct_label`` is not a string.
    """
    # Trailing "" keeps the closing "/" on the directory path.
    label_dir = "/".join((feedback_path, correct_label, ""))
    os.makedirs(label_dir, exist_ok=True)
    # Whole-second timestamp: two saves for one label within the same
    # second end up with the same file name.
    stamp = str(int(time.time()))
    image.save("".join((label_dir, correct_label, "_", stamp, ".png")))
def retrain_from_feedback():
    """Retrain the shared ``model`` from the images under ``feedback_path``.

    Delegates to ``model.retrain_from_path`` with ``remove_path=True``.
    NOTE(review): the flag presumably deletes the feedback folder once it
    has been consumed; confirm against the model implementation.
    """
    model.retrain_from_path(
        feedback_path,
        remove_path=True,
    )
def main():
    """Render the Grocery Classifier page.

    Page flow: list the labels known to the model, let the user upload an
    image, classify it when "Predict" is pressed, then offer a form to
    submit the correct label and a button to retrain from that feedback.

    The prediction and the opened image are kept in ``st.session_state``
    because Streamlit reruns this script on every widget interaction. The
    stored prediction is tied to the upload it was computed for, so it is
    not displayed (or saved as feedback) for a different uploaded file.
    """
    labels = set(model.label_encoder.classes_)

    # NOTE(review): the non-ASCII characters in this title look like
    # mis-decoded emoji; the literal is kept exactly as found. Confirm
    # against the original file before changing it.
    st.title("π Grocery Classifier π₯")

    # A set is never None, so the warning is keyed on an empty label set.
    if not labels:
        st.warning("Received error from server, labels could not be retrieved")
    else:
        st.write("Labels:", labels)

    image_file = st.file_uploader(
        "Choose an image...", type=["jpg", "jpeg", "png"]
    )
    if image_file is None:
        return

    st.image(image_file)
    st.subheader("Classification")

    # Cheap fingerprint of the current upload, used to recognise a
    # prediction that belongs to a previously uploaded file.
    upload_key = (
        getattr(image_file, "name", None),
        getattr(image_file, "size", None),
    )

    if st.button("Predict"):
        st.session_state["response_json"], st.session_state["image"] = predict(
            image_file
        )
        st.session_state["predicted_upload"] = upload_key

    response = st.session_state.get("response_json")
    is_stale = st.session_state.get("predicted_upload") != upload_key
    if response is None or is_stale:
        return

    # Show the result
    st.markdown(f"**Prediction:** {response['prediction']}")
    st.markdown(f"**Confidence:** {response['confidence']}")

    # User feedback
    st.subheader("User Feedback")
    st.markdown(
        "If this prediction was incorrect, please select below the correct label"
    )
    # Filtering (instead of set.remove) cannot raise when the stored
    # prediction is missing from the current labels; sorting gives the
    # select box a stable order.
    correct_labels = sorted(
        (label for label in labels if label != response["prediction"]),
        key=str,
    )
    correct_label = st.selectbox("Correct label", correct_labels)

    if st.button("Submit"):
        # Save feedback; any failure is reported in the page rather than
        # crashing the app.
        try:
            submit_feedback(correct_label, st.session_state["image"])
        except Exception as e:
            st.error(f"Feedback could not be submitted. Error: {e}")
        else:
            st.success("Feedback submitted")

    # Retrain from feedback
    # NOTE(review): indentation was lost in the copy this was restored
    # from; placing this button next to "Submit" is inferred.
    if st.button("Retrain from feedback"):
        try:
            with st.spinner("Retraining..."):
                retrain_from_feedback()
        except Exception as e:
            st.warning(f"Model could not be retrained. Error: {e}")
        else:
            st.success("Model retrained")
            st.balloons()
# Streamlit executes this script from top to bottom on every rerun, so the
# page is rendered by calling main() directly at module level.
main()
|