Refactor sketch recognition app: simplify prediction function, update image handling, and remove OpenCV dependency
Browse files- app.py +15 -18
- requirements.txt +0 -1
app.py
CHANGED
@@ -1,7 +1,6 @@
|
|
1 |
import numpy as np
|
2 |
import gradio as gr
|
3 |
import tensorflow as tf
|
4 |
-
import cv2
|
5 |
|
6 |
# App title
|
7 |
title = "Welcome to your first sketch recognition app!"
|
@@ -29,22 +28,20 @@ model = tf.keras.models.load_model("./sketch_recognition_numbers_model.h5")
|
|
29 |
|
30 |
# Prediction function for sketch recognition
|
31 |
def predict(data):
|
32 |
-
#
|
33 |
-
img = data
|
34 |
-
#
|
35 |
-
|
36 |
-
#
|
37 |
-
|
38 |
-
#
|
39 |
-
|
40 |
-
#
|
41 |
-
|
42 |
-
#
|
43 |
-
|
44 |
-
#
|
45 |
-
|
46 |
-
# Return the probability for each class
|
47 |
-
return {label: float(pred) for label, pred in zip(labels, preds)}
|
48 |
|
49 |
# Top 3 classes
|
50 |
label = gr.Label(num_top_classes=3)
|
@@ -52,7 +49,7 @@ label = gr.Label(num_top_classes=3)
|
|
52 |
# Open Gradio interface for sketch recognition
|
53 |
interface = gr.Interface(
|
54 |
fn=predict,
|
55 |
-
inputs=gr.Sketchpad(),
|
56 |
outputs=label,
|
57 |
title=title,
|
58 |
description=head,
|
|
|
1 |
import numpy as np
|
2 |
import gradio as gr
|
3 |
import tensorflow as tf
|
|
|
4 |
|
5 |
# App title
|
6 |
title = "Welcome to your first sketch recognition app!"
|
|
|
28 |
|
29 |
# Prediction function for sketch recognition
|
30 |
def predict(data):
|
31 |
+
# Reshape image to 28x28
|
32 |
+
img = np.reshape(data, (1, img_size, img_size, 1))
|
33 |
+
# Make prediction
|
34 |
+
pred = model.predict(img)
|
35 |
+
# Get top class
|
36 |
+
top_class = np.argmax
|
37 |
+
# Get top 3 classes
|
38 |
+
top_3_classes = np.argsort(pred[0])[-3:][::-1]
|
39 |
+
# Get top 3 probabilities
|
40 |
+
top_3_probs = pred[0][top_3_classes]
|
41 |
+
# Get class names
|
42 |
+
class_names = [labels[i] for i in top_3_classes]
|
43 |
+
# Return class names and probabilities
|
44 |
+
return {class_names[i]: top_3_probs[i] for i in range(3)}
|
|
|
|
|
45 |
|
46 |
# Top 3 classes
|
47 |
label = gr.Label(num_top_classes=3)
|
|
|
49 |
# Open Gradio interface for sketch recognition
|
50 |
interface = gr.Interface(
|
51 |
fn=predict,
|
52 |
+
inputs=gr.Sketchpad(crop_size=(28,28), type='numpy', image_mode='L', brush=gr.Brush()),
|
53 |
outputs=label,
|
54 |
title=title,
|
55 |
description=head,
|
requirements.txt
CHANGED
@@ -1,3 +1,2 @@
|
|
1 |
tensorflow
|
2 |
-
opencv-python-headless
|
3 |
numpy
|
|
|
1 |
tensorflow
|
|
|
2 |
numpy
|