File size: 3,250 Bytes
158f4dc
0f734ea
 
158f4dc
 
0f734ea
158f4dc
 
 
0f734ea
 
 
158f4dc
 
0f734ea
ea323d5
158f4dc
 
0f734ea
 
158f4dc
0f734ea
 
158f4dc
 
 
0f734ea
158f4dc
 
 
 
 
 
 
0f734ea
 
158f4dc
 
 
 
 
0f734ea
 
158f4dc
 
 
0f734ea
158f4dc
 
 
 
0f734ea
158f4dc
 
 
 
 
 
 
 
 
 
0f734ea
158f4dc
0f734ea
 
 
158f4dc
0f734ea
 
 
 
158f4dc
0f734ea
 
 
 
 
 
 
158f4dc
 
0f734ea
 
 
158f4dc
0f734ea
158f4dc
 
 
 
0f734ea
158f4dc
 
 
0f734ea
158f4dc
 
 
0f734ea
5f05cbf
158f4dc
5f05cbf
158f4dc
 
0f734ea
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import hashlib
import os
import time

import streamlit as st
from PIL import Image

from model import ViTForImageClassification

# Page-level settings: browser-tab title, tab icon and initial sidebar state.
_PAGE_CONFIG = {
    "page_title": "Grocery Classifier",
    "page_icon": "interface/shopping-cart.png",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**_PAGE_CONFIG)


# The model returned by load_model is retrained in place later
# (retrain_from_feedback calls model.retrain_from_path on this same object), so
# the cache must hand back that one object without hashing or mutation checks.
# Newer Streamlit provides st.cache_resource for exactly this; older releases
# only have st.cache, where allow_output_mutation=True disables the per-rerun
# hashing of the (large) returned model.
_cache_model = (
    st.cache_resource
    if hasattr(st, "cache_resource")
    else st.cache(allow_output_mutation=True)
)


@_cache_model
def load_model():
    """Build the ViT classifier and load its saved state, once per process.

    Returns:
        A ``ViTForImageClassification`` initialised from
        ``"google/vit-base-patch16-224"`` and then loaded from ``"model/"``.
    """
    with st.spinner("Loading model"):
        vit = ViTForImageClassification("google/vit-base-patch16-224")
        # presumably restores the fine-tuned weights/labels saved under model/
        # — confirm against model.ViTForImageClassification.load
        vit.load("model/")
    return vit


# Root directory for user feedback images; submit_feedback writes one
# sub-folder per corrected label beneath it.
feedback_path = "feedback"
# Shared classifier instance, served from load_model's cache on every rerun.
model = load_model()


def predict(image):
    """Classify an uploaded image with the shared model.

    Args:
        image: anything ``PIL.Image.open`` accepts (a path or binary
            file-like object such as the Streamlit upload).

    Returns:
        A ``(result, pil_image)`` pair. ``result`` is a dict with the first
        entry of the model's predictions under ``"prediction"`` and the
        matching confidence rounded to 3 decimals under ``"confidence"``;
        ``pil_image`` is the opened PIL image, kept for later feedback.
    """
    print("Predicting...")
    pil_image = Image.open(image)

    labels, scores = model.predict(pil_image)
    result = {
        "prediction": labels[0],
        "confidence": round(scores[0], 3),
    }
    return result, pil_image


def submit_feedback(correct_label, image, root=None):
    """Store *image* as a PNG training example for *correct_label*.

    The file is written to
    ``<root>/<correct_label>/<correct_label>_<unix seconds>.png``. If that
    name already exists (e.g. two submissions within the same second), a
    ``_<n>`` counter is appended rather than overwriting the earlier example.

    Args:
        correct_label: label chosen by the user; used as the sub-folder name
            and as the filename prefix.
        image: PIL image to save.
        root: feedback directory; defaults to the module-level
            ``feedback_path``.

    Returns:
        The path the image was written to.

    Raises:
        Whatever ``os.makedirs`` or ``image.save`` raise; the caller reports
        the error to the user.
    """
    if root is None:
        root = feedback_path
    folder = os.path.join(root, correct_label)
    os.makedirs(folder, exist_ok=True)

    # PNG cannot store every PIL mode (CMYK JPEG uploads are the common case);
    # convert those first instead of letting the save fail.
    png_modes = {"1", "L", "LA", "I", "I;16", "I;16B", "P", "RGB", "RGBA"}
    if image.mode not in png_modes:
        has_alpha = image.mode.endswith(("A", "a"))
        image = image.convert("RGBA" if has_alpha else "RGB")

    stem = f"{correct_label}_{int(time.time())}"
    target = os.path.join(folder, stem + ".png")
    # Timestamps only have one-second resolution; never clobber an
    # existing feedback example.
    counter = 1
    while os.path.exists(target):
        target = os.path.join(folder, f"{stem}_{counter}.png")
        counter += 1

    image.save(target)
    return target


def retrain_from_feedback():
    """Retrain the shared model from the images under ``feedback_path``.

    ``remove_path=True`` is forwarded to ``model.retrain_from_path``;
    presumably it deletes the feedback directory once consumed — confirm
    against the model implementation.
    """
    source_dir = feedback_path
    model.retrain_from_path(source_dir, remove_path=True)


def main():
    """Render the Grocery Classifier page for one Streamlit run.

    Flow:
        1. The user uploads an image.
        2. "Predict" stores the result and the opened image in
           ``st.session_state``.
        3. The user may submit a corrected label (``submit_feedback``) and
           trigger ``retrain_from_feedback``.
    """
    # Sorted list instead of a bare set: set iteration order is arbitrary, and a
    # stable order keeps the label list and selectbox options predictable.
    labels = sorted(set(model.label_encoder.classes_))

    st.title("🍇 Grocery Classifier 🥑")

    if not labels:
        st.warning("Labels could not be retrieved from the model")
    else:
        st.write("Labels:", labels)

    image_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if image_file is None:
        return

    st.image(image_file)
    st.subheader("Classification")

    # Fingerprint of the current upload: a prediction made for an earlier
    # upload must not be displayed next to (or used as feedback for) this one.
    upload_id = hashlib.sha256(image_file.getvalue()).hexdigest()

    if st.button("Predict"):
        st.session_state["response_json"], st.session_state["image"] = predict(
            image_file
        )
        st.session_state["predicted_upload"] = upload_id

    response = st.session_state.get("response_json")
    if response is None or st.session_state.get("predicted_upload") != upload_id:
        return

    # Show the result
    st.markdown(f"**Prediction:** {response['prediction']}")
    st.markdown(f"**Confidence:** {response['confidence']}")

    # User feedback
    st.subheader("User Feedback")
    st.markdown(
        "If this prediction was incorrect, please select below the correct label"
    )
    # Filter instead of set.remove(): the predicted label may be missing from
    # the current label set (e.g. if labels changed since the prediction), which
    # previously raised KeyError and broke the page.
    correct_labels = [label for label in labels if label != response["prediction"]]
    correct_label = st.selectbox("Correct label", correct_labels)
    if st.button("Submit"):
        # UI boundary: report any failure to the user instead of crashing the page.
        try:
            submit_feedback(correct_label, st.session_state["image"])
            st.success("Feedback submitted")
        except Exception as e:
            st.error("Feedback could not be submitted. Error: {}".format(e))

    # Retrain from feedback
    if st.button("Retrain from feedback"):
        try:
            with st.spinner("Retraining..."):
                retrain_from_feedback()
            st.success("Model retrained")
            st.balloons()
        except Exception as e:
            st.error("Model could not be retrained. Error: {}".format(e))


# `streamlit run` executes this script as __main__, so the page still renders;
# importing the module elsewhere no longer draws the page body.
if __name__ == "__main__":
    main()