File size: 5,435 Bytes
8ae5b50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7ebbf5a
 
 
8ae5b50
 
 
 
 
 
 
 
 
238743c
8ae5b50
 
 
 
 
 
 
 
eb81ea9
8ae5b50
eb81ea9
 
 
 
 
 
8ae5b50
 
 
 
 
 
 
eb81ea9
 
 
03851c2
 
eb81ea9
 
 
 
 
 
 
8ae5b50
eb81ea9
 
 
 
 
 
 
 
8c6ecf6
eb81ea9
 
 
 
 
 
 
 
 
 
 
7fd2205
 
eb81ea9
 
 
 
 
 
8ae5b50
 
 
eb81ea9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import os
import numpy as np
import pandas as pd
from matplotlib import rcParams
import matplotlib.pyplot as plt
from requests import get
import streamlit as st
import cv2
from ultralytics import YOLO
import shutil


# Output folder for YOLO predictions (annotated image + label .txt files).
PREDICTION_PATH = os.path.join(os.curdir, "predictions")
# Folder holding the single preprocessed image that is fed to the model.
UPLOAD_PATH = os.path.join(os.curdir, "uploads")

@st.cache_resource
def load_od_model():
    """Return the fine-tuned YOLO licence-plate detector.

    ``st.cache_resource`` keeps the returned model object cached, so the
    weights file is only loaded on the first call.
    """
    return YOLO('license_plate_detection_best.pt')


def decode_license_number(names, labels_path=None):
    """Rebuild the plate text from the label file of the most recent prediction.

    Each line of the label file is read as ``class x y w h`` (five
    space-separated columns, with ``x``/``y`` the box centre; assumed to be
    ultralytics' ``save_txt`` format without confidences). Characters are
    ordered left to right by ``x``. Boxes may span two text rows: this is
    detected when the spread of the box bottoms (``y + h/2``) is at least
    90% of the tallest box. In that case every top-row character is emitted
    before the bottom row.

    Args:
        names: Mapping from class id to character (``inference`` passes the
            model's ``names``).
        labels_path: Label file to decode. Defaults to
            ``PREDICTION_PATH/predict/labels/input.txt``.

    Returns:
        The decoded plate string, or ``''`` if the label file is missing,
        empty (presumably nothing was detected) or cannot be decoded.
    """
    if labels_path is None:
        labels_path = os.path.join(PREDICTION_PATH, 'predict', 'labels', 'input.txt')
    try:
        df = pd.read_csv(labels_path, sep=' ', names=['cls', 'x', 'y', 'w', 'h'])
        if df.empty:
            return ''
        # Stable sort keeps file order for characters that share the same x.
        df = df.sort_values(by='x', kind='stable')

        bottoms = df['y'] + df['h'] / 2
        # The highest box bottom separates the first row from any second row.
        boundary = bottoms.min()
        lowest = bottoms.max()

        if lowest - boundary >= 0.9 * df['h'].max():
            top = df['y'] <= boundary
            # Use the complementary mask (not ``y >= boundary``) so a box centred
            # exactly on the boundary is emitted once, not in both rows.
            ordered = df.loc[top, 'cls'].tolist() + df.loc[~top, 'cls'].tolist()
        else:
            ordered = df['cls'].tolist()

        return ''.join(names[c] for c in ordered)
    except (OSError, ValueError, TypeError, KeyError, IndexError):
        # Best effort: missing/empty/malformed label files or unknown class ids
        # yield an empty result instead of crashing the app. (pandas'
        # EmptyDataError and ParserError are ValueError subclasses.)
        return ''


def inference(input_image_path: str):
    """Run plate detection on one image and render the result in the app.

    The model is called with ``save=True``, ``save_txt=True`` and
    ``project=PREDICTION_PATH``. The annotated image and the per-character
    label file are therefore written under ``PREDICTION_PATH/predict``.
    ``decode_license_number`` then turns that label file into text.

    Args:
        input_image_path: Path of the (already preprocessed) image to analyse.
    """
    finetuned_model = load_od_model()
    # Only the files written to disk are used afterwards, so the returned
    # Results object is not kept.
    finetuned_model.predict(input_image_path,
                            show=False,
                            save=True,
                            save_crop=False,
                            imgsz=640,
                            conf=0.6,
                            save_txt=True,
                            project=PREDICTION_PATH,
                            show_labels=True,
                            show_conf=False,
                            line_width=2,
                            exist_ok=True)

    lic_num = decode_license_number(finetuned_model.names)

    st.markdown(f"<h5>Detected Number: {lic_num}</h5>", unsafe_allow_html=True)
    # The annotated copy is expected to keep the source image's file name
    # (input.jpg for the paths produced by get_upload_path).
    annotated_path = os.path.join(PREDICTION_PATH, 'predict', os.path.basename(input_image_path))
    st.image(annotated_path)
        

def files_cleanup(path_: str):
    """Delete the saved upload and the predictions folder, skipping anything absent.

    Args:
        path_: The uploaded image file to delete.
    """
    # The single upload file goes first, then the whole PREDICTION_PATH tree.
    for target, remover in ((path_, os.remove), (PREDICTION_PATH, shutil.rmtree)):
        if os.path.exists(target):
            remover(target)


def get_upload_path():
    """Return ``UPLOAD_PATH/input.jpg``, creating ``UPLOAD_PATH`` if it is missing.

    The file name never changes, so each newly processed image overwrites
    the previous one.
    """
    directory_missing = not os.path.exists(UPLOAD_PATH)
    if directory_missing:
        os.makedirs(UPLOAD_PATH)
    return os.path.join(UPLOAD_PATH, "input.jpg")


def process_input_image(input_, timeout=10):
    """Fetch or decode the input image, resize it to 640x640 and save it.

    Args:
        input_: Either the image URL typed by the user (``str``) or the raw
            encoded image bytes captured by the camera widget.
        timeout: Seconds to wait for the remote server when ``input_`` is a
            URL (forwarded to ``requests.get``).

    Returns:
        Path of the saved image (as returned by ``get_upload_path``).

    Raises:
        requests.HTTPError: if downloading the URL returns an error status.
        ValueError: if the payload cannot be decoded as an image.
    """
    upload_file_path = get_upload_path()
    # Dispatch on the argument itself rather than on module-level UI state.
    if isinstance(input_, str):
        # NOTE(review): this fetches an arbitrary user-supplied URL from the
        # server; consider restricting hosts if deployed near internal services.
        headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/102.0.0.0 Safari/537.36'}
        response = get(input_, headers=headers, timeout=timeout)
        response.raise_for_status()
        payload = response.content
    else:
        payload = input_

    # IMREAD_COLOR always yields a 3-channel BGR image, so PNGs with an alpha
    # channel and grayscale images are handled instead of crashing later.
    image = cv2.imdecode(np.frombuffer(payload, np.uint8), cv2.IMREAD_COLOR)
    if image is None:
        raise ValueError('Could not decode the input as an image.')

    # 640x640 matches imgsz=640 used for prediction. Resizing works per
    # channel, so no RGB round trip is needed before writing BGR to disk.
    image = cv2.resize(image, (640, 640))
    cv2.imwrite(upload_file_path, image)
    return upload_file_path


# Page title and a note about the training data. Streamlit renders elements
# in the order these calls run.
st.markdown("<h3>Vehicle Registration Number Detection</h3>", unsafe_allow_html=True)
desc = '''Dataset used to fine-tune YOLOv8 
can be found <a href="https://universe.roboflow.com/my-workspace-cut7i/lpr_ocr/dataset/1" target="_blank">
here</a><br>See that the input images are similar to these for better detection.<br>Since the training data consists of images from the internet, 
this application may only properly detect images from the internet. 
'''
# unsafe_allow_html is required for the <a>/<br> tags in ``desc``.
st.markdown(desc, unsafe_allow_html=True)
# Input mode; the main flow below branches on this value to show either a
# URL text box or the camera widget.
input_type = st.radio('Select an option:', ['Paste image URL', 'Capture using camera'], 
                      captions=['Recommended for laptops/desktops', 'Recommended for mobile devices'],
                      horizontal=True)


# Main flow: collect an image for the selected mode, run detection, and
# remove the temporary upload/prediction files afterwards.
try:
    if input_type == 'Paste image URL':
        img_url = st.text_input("Paste the image URL of the vehicle license/registration plate:", "")
        image_source = img_url or None
    else:
        st.markdown("<h4>Capture the image in landscape mode</h4>", unsafe_allow_html=True)
        # The camera widget sits in the wide centre column; the side columns
        # and the blank writes only add spacing.
        col1, col2, col3 = st.columns([0.2, 0.6, 0.2])
        with col1:
            st.write(' ')
        with col2:
            img_file_buffer = st.camera_input("Take the number plate's picture")
            for _ in range(4):
                st.write(' ')
        with col3:
            st.write(' ')
        image_source = img_file_buffer.getvalue() if img_file_buffer is not None else None

    if image_source is not None:
        img_path = process_input_image(image_source)
        try:
            inference(img_path)
        finally:
            # Clean up on every exit path, not only on ``Exception``.
            # decode_license_number reads a fixed label path, so a leftover
            # file from an interrupted run would be decoded again on the next
            # run. (NOTE(review): Streamlit's stop/rerun signals are believed
            # not to be ``Exception`` subclasses.)
            files_cleanup(img_path)

except Exception as e:
    upload_file_path = get_upload_path()
    files_cleanup(upload_file_path)
    st.error(f'An unexpected error occurred:  \n{e}')