File size: 3,100 Bytes
158f4dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import os
import streamlit as st
from PIL import Image
import requests
import io
import time
from model import ViTForImageClassification

# Page-level settings: browser-tab title, favicon image and initial sidebar state.
# Streamlit (at least in older versions) requires set_page_config to be the first
# Streamlit command the script executes, so it stays near the top of the file.
_page_settings = {
    "page_title": "Grocery Classifier",
    "page_icon": "interface/shopping-cart.png",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**_page_settings)

# Decorator used to keep a single model instance across script reruns.
# Prefer st.cache_resource (Streamlit >= 1.18), which stores the object as-is.
# The legacy st.cache is deprecated, and by default it hashes the returned object
# on every rerun to detect mutation. That is slow for a large model and warns once
# retrain_from_feedback() (presumably) updates the model in place, so the
# fallback for old Streamlit versions disables that check.
_cache_model = getattr(st, "cache_resource", None) or st.cache(allow_output_mutation=True)


@_cache_model
def load_model():
    """Build and load the grocery classifier, cached across reruns.

    Creates a ``ViTForImageClassification`` on the
    ``google/vit-base-patch16-224`` backbone and calls ``load('model/')`` on it.
    The ``model/`` directory presumably holds the fine-tuned weights; confirm
    in model.py.

    Returns:
        The loaded model instance. It is shared across reruns and, with
        ``st.cache_resource``, across sessions.
    """
    with st.spinner("Loading model"):
        model = ViTForImageClassification('google/vit-base-patch16-224')
        model.load('model/')
    return model
        
# Root directory for user feedback; images are stored in one sub-folder per label.
feedback_path = "feedback"
# Shared classifier instance. load_model is wrapped in a Streamlit cache decorator,
# so reruns reuse the same object instead of reloading it.
model = load_model()

def predict(image):
    """Classify an uploaded image with the shared model.

    Args:
        image: Anything ``PIL.Image.open`` accepts, e.g. the file object
            returned by ``st.file_uploader``.

    Returns:
        A tuple ``(result, pil_image)``:
        - ``result`` is a dict holding the first predicted label under
          ``'prediction'`` and its confidence, rounded to 3 decimals, under
          ``'confidence'``.
        - ``pil_image`` is the opened PIL image, which is later reused when
          saving feedback.
    """
    print("Predicting...")
    pil_image = Image.open(image)

    # model.predict yields parallel sequences of labels and scores; only the
    # first entry of each is used.
    labels, scores = model.predict(pil_image)
    result = {
        'prediction': labels[0],
        'confidence': round(scores[0], 3),
    }
    return result, pil_image

def submit_feedback(correct_label, image, *, root=None):
    """Save *image* as a PNG in the feedback folder for *correct_label*.

    The file is written to ``<root>/<correct_label>/<correct_label>_<unix-seconds>.png``.
    If that name is already taken, e.g. by two submissions within the same
    second, a ``_<n>`` counter is appended instead of overwriting the earlier
    file.

    Args:
        correct_label: Label chosen by the user. It names both the sub-folder
            and the filename prefix.
        image: PIL image to store. CMYK images are converted to RGB first,
            because PNG cannot store CMYK.
        root: Base feedback directory. Defaults to the module-level
            ``feedback_path``.

    Returns:
        Path of the written file.
    """
    base = feedback_path if root is None else root
    folder = os.path.join(base, correct_label)
    os.makedirs(folder, exist_ok=True)

    stamp = int(time.time())
    target = os.path.join(folder, f"{correct_label}_{stamp}.png")
    suffix = 1
    # Never clobber earlier feedback saved within the same second.
    while os.path.exists(target):
        target = os.path.join(folder, f"{correct_label}_{stamp}_{suffix}.png")
        suffix += 1

    # JPEG uploads may be CMYK, which the PNG writer rejects.
    if image.mode == "CMYK":
        image = image.convert("RGB")
    image.save(target)
    return target
    
def retrain_from_feedback():
    """Retrain the shared model on the images collected under ``feedback_path``.

    Delegates to ``model.retrain_from_path`` with ``remove_path=True``. This
    presumably deletes the consumed feedback folder afterwards; confirm in
    model.py. Exceptions are not caught here; the caller reports them.
    """
    source_dir = feedback_path
    model.retrain_from_path(source_dir, remove_path=True)

def main():
    """Render the Grocery Classifier page.

    The page does the following:
    - Lists the model's labels.
    - Lets the user upload an image and request a prediction.
    - Shows a feedback form where the user picks the correct label and the
      image is saved under it.
    - Offers a button to retrain the model from the saved feedback.

    The prediction and the opened image are kept in ``st.session_state`` so
    they survive the reruns triggered by each button click. They are cleared
    whenever a different file is uploaded.
    """
    # Sorted (deduplicated) list instead of a set, for a stable display and
    # selectbox order.
    labels = sorted(set(model.label_encoder.classes_))

    st.title("🍇 Grocery Classifier 🥑")

    if not labels:
        st.warning("Labels could not be retrieved from the model")
    else:
        st.write("Labels:", labels)

    image_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if image_file is None:
        return

    st.image(image_file)
    st.subheader("Classification")

    # Identify the current upload. Otherwise a prediction made for a previously
    # uploaded file would be shown for this one, and "Submit" would save the
    # previous image under the newly chosen label.
    upload_key = (image_file.name, len(image_file.getvalue()))
    if st.session_state.get('upload_key') != upload_key:
        st.session_state['response_json'] = None
        st.session_state['image'] = None
        st.session_state['upload_key'] = upload_key

    if st.button("Predict"):
        st.session_state['response_json'], st.session_state['image'] = predict(image_file)

    response = st.session_state.get('response_json')
    if response is None:
        return

    # Show the result
    st.markdown(f"**Prediction:** {response['prediction']}")
    st.markdown(f"**Confidence:** {response['confidence']}")

    # User feedback: offer every label except the predicted one. Filtering a
    # list, unlike set.remove, cannot raise if the prediction is not in labels.
    st.subheader("User Feedback")
    st.markdown("If this prediction was incorrect, please select below the correct label")
    correct_labels = [label for label in labels if label != response["prediction"]]
    correct_label = st.selectbox("Correct label", correct_labels)
    if st.button("Submit"):
        # UI boundary: report any failure to the user instead of crashing the page.
        try:
            submit_feedback(correct_label, st.session_state['image'])
            st.success("Feedback submitted")
        except Exception as e:
            st.error("Feedback could not be submitted. Error: {}".format(e))

    # Retrain from feedback
    if st.button("Retrain from feedback"):
        try:
            retrain_from_feedback()
            st.success("Model retrained")
        except Exception as e:
            st.warning("Model could not be retrained. Error: {}".format(e))
                    
# Script entry point. Streamlit executes this file top to bottom and re-executes
# it on widget interaction (Streamlit's execution model), which is why main()
# keeps its results in st.session_state.
main()