File size: 5,778 Bytes
4b1ee17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f6bc454
4b1ee17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f6bc454
 
 
 
 
 
 
 
 
4b1ee17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e2f9a6d
4b1ee17
f6bc454
4b1ee17
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import os
import numpy as np
from matplotlib import rcParams
import matplotlib.pyplot as plt
from tensorflow.keras.models import load_model, Model
from tensorflow.keras.utils import load_img, save_img, img_to_array
from tensorflow.keras.applications.vgg19 import preprocess_input
from tensorflow.keras.layers import GlobalAveragePooling2D
from pymilvus import connections, Collection, utility
from requests import get
from shutil import rmtree
import streamlit as st
import zipfile

# unzip vegetable images
# Extract "Vegetable Images.zip" into the current working directory. The paths
# returned from Milvus in find_similar_images are opened from disk by
# plot_images, so presumably they point into this extracted tree — TODO confirm
# against how the collection was populated.
# NOTE(review): this runs every time the module is executed; if Streamlit
# re-executes the script on each rerun, the archive is re-extracted each time.
with zipfile.ZipFile("Vegetable Images.zip", 'r') as zip_ref:
    zip_ref.extractall('.')

# Container that plot_images fills with the results; main() calls
# placeholder.empty() before each new search to clear the previous output.
placeholder = st.empty()

class ImageVectorizer:
    '''
    Turn an image file into a fixed-length embedding vector.

    The embedding is the global-average-pooled output of the ``block5_pool``
    layer of a VGG model that was fine-tuned to classify vegetable images.
    '''

    def __init__(self):
        # Truncated feature-extraction model, built once per instance.
        self.__model = self.get_model()

    @staticmethod
    def get_model():
        '''Load the fine-tuned VGG classifier and cut it down to a pooled feature extractor.'''
        base = load_model('vegetable_classification_model_vgg.h5')
        conv_features = base.get_layer('block5_pool').output
        pooled = GlobalAveragePooling2D()(conv_features)
        return Model(inputs=base.input, outputs=pooled)

    def vectorize(self, img_path: str):
        '''
        Embed the image stored at ``img_path``.

        The image is loaded as RGB, resized to 224x224, run through the VGG19
        ``preprocess_input`` transform and fed to the model as a batch of one;
        the first (only) row of the model output is returned.
        '''
        pil_image = load_img(img_path, color_mode="rgb", target_size=(224, 224))
        pixels = preprocess_input(img_to_array(pil_image))
        batch = np.array([pixels])  # add the leading batch dimension
        embeddings = self.__model(batch).numpy()
        return embeddings[0]


def get_milvus_collection():
    '''
    Connect to Milvus and return the image-embedding collection, loaded for search.

    Connection settings are read from environment variables:
      URI             -- endpoint passed to ``connections.connect``
      TOKEN           -- auth token passed to ``connections.connect``
      COLLECTION_NAME -- name of the collection holding the image vectors (required)

    Returns:
        The pymilvus ``Collection`` after ``load()`` has been called on it.

    Raises:
        ValueError: if COLLECTION_NAME is unset or empty.
    '''
    collection_name = os.environ.get("COLLECTION_NAME")
    # Validate before connecting so a misconfigured deployment fails with a clear
    # message instead of an opaque error from Collection(name=None).
    if not collection_name:
        raise ValueError("COLLECTION_NAME environment variable is not set")
    uri = os.environ.get("URI")
    token = os.environ.get("TOKEN")
    connections.connect("default", uri=uri, token=token)
    print("Connected to DB")
    collection = Collection(name=collection_name)
    # Load the collection into memory so it can be searched.
    collection.load()
    return collection


def plot_images(input_image_path: str, similar_img_paths: list):
    '''
    Render the query image and a grid of its similar images into ``placeholder``.

    Similar images are drawn in the order given on a fixed 5x3 grid. If fewer
    than 15 paths are supplied, the remaining cells are left blank. Paths
    beyond the 15th are ignored.

    Args:
        input_image_path: path of the query image on disk.
        similar_img_paths: paths of the similar images to show.
    '''
    # plotting similar images
    rows = 5 # rows in subplots
    cols = 3 # columns in subplots
    fig, ax = plt.subplots(rows, cols, figsize=(12, 20))
    try:
        for i in range(rows * cols):
            r, c = divmod(i, cols)
            cell = ax[r, c]
            cell.axis("off")
            # A search can return fewer hits than grid cells (e.g. top_n < 15 or a
            # small collection); leave those cells empty instead of raising IndexError.
            if i < len(similar_img_paths):
                sim_image = load_img(similar_img_paths[i], color_mode="rgb", target_size=(224, 224))
                cell.imshow(sim_image)
        plt.subplots_adjust(wspace=0.01, hspace=0.01)

        # display input image
        # NOTE(review): this mutates the global rcParams after ``fig`` was created,
        # so it presumably only affects figures created afterwards — confirm intent.
        rcParams.update({'figure.autolayout': True})
        input_image = load_img(input_image_path, color_mode="rgb", target_size=(224, 224))
        with placeholder.container():
            st.markdown('<p style="font-size: 20px; font-weight: bold">Input image</p>', unsafe_allow_html=True)
            st.image(input_image)

            st.write('  \n')

            # display similar images
            st.markdown('<p style="font-size: 20px; font-weight: bold">Similar images</p>', unsafe_allow_html=True)
            st.pyplot(fig)
    finally:
        # pyplot keeps every figure it creates until it is closed; without this,
        # each search leaks one figure for the lifetime of the process.
        plt.close(fig)


def find_similar_images(img_path: str, top_n: int=15):
    '''
    Find the ``top_n`` images nearest (L2) to ``img_path`` and display them.

    The query image is embedded with the module-level ``vectorizer``. The nearest
    vectors are then fetched from the module-level ``collection``, and their
    stored image paths are passed to ``plot_images`` together with the query path.
    '''
    query_vector = vectorizer.vectorize(img_path)
    search_results = collection.search(
        [query_vector],
        anns_field='image_vector',  # vector field declared in the collection schema
        param={"metric_type": "L2"},
        limit=top_n,
        # NOTE(review): presumably relaxes read consistency so the search does not
        # wait for recent writes — confirm against the pymilvus search docs.
        guarantee_timestamp=1,
        output_fields=['image_path'],  # scalar field returned with each hit
    )
    # search() returns one group of hits per query vector; flatten them all.
    matched_paths = []
    for hits in search_results:
        matched_paths.extend(hit.entity.get('image_path') for hit in hits)
    plot_images(img_path, matched_paths)


def delete_file(path_: str):
    '''
    Best-effort removal of ``path_``, which may be a directory tree or a single file.

    Missing paths are ignored and removal errors are suppressed, matching the
    ``ignore_errors=True`` used for directories. main() uses this to clean up
    the temporary uploads directory, where failure should not abort the app.
    '''
    if os.path.isdir(path_) and not os.path.islink(path_):
        rmtree(path=path_, ignore_errors=True)
    elif os.path.lexists(path_):
        # rmtree cannot remove a plain file or a symlink, and ignore_errors would
        # silently hide that, so unlink those directly.
        try:
            os.remove(path_)
        except OSError:
            pass


def process_input_image(img_url, timeout=30):
    '''
    Download the image at ``img_url`` to ``./uploads/input.jpg`` and return that path.

    Args:
        img_url: URL pasted by the user.
        timeout: seconds to wait for the server before giving up (passed to
            ``requests.get``). Without it a stalled server hangs the app forever.

    Returns:
        Path of the written file.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status.
        requests.Timeout: if the server does not respond within ``timeout``.
    '''
    headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/102.0.0.0 Safari/537.36'}
    # SECURITY NOTE(review): img_url is arbitrary user input fetched server-side
    # (SSRF: internal hosts, unbounded body size). Consider restricting
    # schemes/hosts and capping the download size.
    response = get(img_url, headers=headers, timeout=timeout)
    # Fail here with a clear HTTP error rather than saving an error page as
    # input.jpg and failing later inside load_img. Downloading before creating
    # the directory also avoids leaving an empty uploads dir behind on failure.
    response.raise_for_status()

    upload_dir = os.path.join('.', 'uploads')
    os.makedirs(upload_dir, exist_ok=True)
    upload_filename = "input.jpg"
    upload_file_path = os.path.join(upload_dir, upload_filename)
    with open(upload_file_path, "wb") as file:
        file.write(response.content)
    return upload_file_path


# Module-level objects used by find_similar_images: the VGG feature extractor
# and the loaded Milvus collection. Both are created when the module runs, so
# the model file 'vegetable_classification_model_vgg.h5' and the URI / TOKEN /
# COLLECTION_NAME environment variables are read at that point.
# NOTE(review): if Streamlit re-executes this script on every rerun, the model
# is reloaded and Milvus reconnected each time — consider caching these.
vectorizer = ImageVectorizer()
collection = get_milvus_collection()


def main():
    '''
    Streamlit entry point: render the page and run a similarity search for a pasted URL.

    Any exception raised while rendering, downloading or searching is shown to
    the user via ``st.error`` instead of propagating.
    '''
    try:
        st.markdown("<h3>Find Similar Vegetable Images</h3>", unsafe_allow_html=True)
        desc = '''<p style="font-size: 15px;">Allowed Vegetables: Broad Bean, Bitter Gourd, Bottle Gourd, 
        Green Round Brinjal, Broccoli, Cabbage, Capsicum, Carrot, Cauliflower, Cucumber, 
        Raw Papaya, Potato, Green Pumpkin, Radish, Tomato.
        </p> 
        <p style="font-size: 13px;">Image embeddings are extracted from a fine-tuned VGG model. The model is fine-tuned on <a href="https://www.kaggle.com/datasets/misrakahmed/vegetable-image-dataset" target="_blank">images</a> clicked using a mobile phone camera. 
        Embeddings of 20,000 vegetable images are stored in Milvus vector database. Embeddings of the input image are computed and 15 most similar images (based on L2 distance) are displayed.</p>
        '''
        st.markdown(desc, unsafe_allow_html=True)
        img_url = st.text_input("Paste the image URL of a vegetable and hit Enter:", "")
        if img_url:
            # Clear results from the previous search before showing new ones.
            placeholder.empty()
            img_path = process_input_image(img_url)
            try:
                find_similar_images(img_path, 15)
            finally:
                # Remove the downloaded upload even when the search fails, so a
                # failed query does not leave the file behind.
                delete_file(os.path.dirname(img_path))
    except Exception as e:
        # Top-level UI boundary: surface any failure to the user.
        st.error(f'An unexpected error occurred:  \n{e}')


# Script entry point: render the app when this file is executed as the main module.
# NOTE(review): `streamlit run` presumably executes the file with
# __name__ == "__main__", so main() runs on each script execution — confirm.
if __name__ == "__main__":
    main()