File size: 2,441 Bytes
d1e0895
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import streamlit as st
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
import cv2
from PIL import Image
import time

# Set page title and favicon
st.set_page_config(page_title="Cat and Dog Classifier", page_icon="🐱🐶")

# Classifier configuration: TF Hub MobileNetV2 feature-vector backbone + 2-way dense head.
mobilenet_model = 'https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4'
num_of_classes = 2


@st.cache_resource
def load_model():
    """Build the cat/dog classifier and load its trained weights.

    Streamlit re-runs this entire script on every widget interaction (a new
    upload, a "Classify" click, ...). Building the model at module level meant
    the TF Hub backbone was re-created and ``cat_dog_classifier.h5`` re-read on
    every rerun; ``st.cache_resource`` builds it once and hands the same object
    back on later reruns and sessions.

    Returns:
        tf.keras.Sequential: the frozen backbone (input shape 224x224x3)
        followed by a ``Dense(num_of_classes)`` layer with no activation, so
        the model outputs logits.
    """
    pretrained_model = hub.KerasLayer(mobilenet_model, input_shape=(224, 224, 3), trainable=False)
    classifier = tf.keras.Sequential([
        pretrained_model,
        tf.keras.layers.Dense(num_of_classes),
    ])
    # Not required for inference via predict(); kept for parity with how the model was defined.
    classifier.compile(
        optimizer='adam',
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['acc'],
    )
    classifier.load_weights("cat_dog_classifier.h5")
    return classifier


model = load_model()

# Define functions for image resizing and classification
def preprocess_image(image):
    """Turn a decoded image array into a model-ready batch containing one image.

    Steps:
    1. Expand single-channel (grayscale) input to 3 channels; otherwise swap
       the first and third channels.
    2. Resize to 224x224.
    3. Scale 0-255 pixel values to [0, 1].
    4. Add a leading batch axis.

    Args:
        image: ``np.ndarray`` of shape (H, W), (H, W, 1), (H, W, 3) or
            (H, W, 4); presumably uint8 pixels in 0-255 (verify against caller).

    Returns:
        ``np.ndarray`` of shape (1, 224, 224, 3) (float64 for uint8 input,
        with values in [0, 1]).
    """
    if image.ndim == 2 or image.shape[2] == 1:
        # Grayscale: COLOR_BGR2RGB rejects single-channel input, so replicate the
        # channel instead (all three channels are equal, so their order is moot).
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    else:
        # Swaps channels 0 and 2; a 4-channel input also drops its 4th channel.
        # NOTE(review): callers pass np.array(<PIL image>), which is RGB, so this
        # actually produces BGR. Presumably intentional, to match a training
        # pipeline that read images with cv2.imread (BGR) — confirm before changing.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, (224, 224))
    image = image / 255.0
    return np.expand_dims(image, axis=0)

def classify_image(image):
    """Return the predicted class index for a single image array.

    The array is prepared by ``preprocess_image``, scored by the module-level
    ``model``, and the position of the highest output is returned.
    """
    return np.argmax(model.predict(preprocess_image(image)))

# Sidebar
st.sidebar.header("Cat and Dog Classifier")
uploaded_image = st.sidebar.file_uploader("Upload an image", type=["jpg", "png", "jpeg"])

# Main content
st.title("Cat/Dog Image Classification")

if uploaded_image:
    # file_uploader already hands back the uploaded bytes in memory, so the spinner
    # only needs to cover opening the image. (A simulated delay here would also be
    # repeated on every rerun, e.g. each time "Classify" is clicked.)
    with st.spinner("Uploading image..."):
        image = Image.open(uploaded_image)

    st.success("Image upload complete!")
    st.image(image, caption="Uploaded Image", use_column_width=True)

    if st.button("Classify"):
        # Normalize to 3-channel RGB first: grayscale/palette images would
        # otherwise become 2-D arrays, and CMYK images 4 non-RGB channels, which
        # preprocess_image cannot handle correctly. RGB and RGBA results are unchanged.
        image_array = np.array(image.convert("RGB"))
        pred_label = classify_image(image_array)

        # NOTE(review): assumes class index 0 = cat and 1 = dog, as in training — confirm.
        # NOTE(review): white text is only legible on a dark theme — confirm the app's theme.
        if pred_label == 0:
            st.write('<div style="font-size: 24px; color: white;">Prediction: It\'s a Cat 😺</div>', unsafe_allow_html=True)
        else:
            st.write('<div style="font-size: 24px; color: white;">Prediction: It\'s a Dog 🐶</div>', unsafe_allow_html=True)

# Footer: a fixed-position box in the bottom-right corner holding a GitHub link and
# an author credit. Raw HTML is needed for the positioning, hence unsafe_allow_html.
_FOOTER_HTML = """
    <div style="position: fixed; bottom: 0; right: 10px; padding: 10px; color: white;">
        <a href="https://github.com/sg-sparsh-goyal" target="_blank" style="color: white; text-decoration: none;">
            ✨ Github
        </a><br>
        By Sparsh
    </div>
    """
st.markdown(_FOOTER_HTML, unsafe_allow_html=True)