alibayram commited on
Commit
9cc90a0
·
1 Parent(s): e6fdb4c

Refactor sketch recognition app: simplify prediction function, update image handling, and remove OpenCV dependency

Browse files
Files changed (2) hide show
  1. app.py +15 -18
  2. requirements.txt +0 -1
app.py CHANGED
@@ -1,7 +1,6 @@
1
  import numpy as np
2
  import gradio as gr
3
  import tensorflow as tf
4
- import cv2
5
 
6
  # App title
7
  title = "Welcome to your first sketch recognition app!"
@@ -29,22 +28,20 @@ model = tf.keras.models.load_model("./sketch_recognition_numbers_model.h5")
29
 
30
  # Prediction function for sketch recognition
31
  def predict(data):
32
- # Extract the 'image' key from the input dictionary
33
- img = data['image']
34
- # Convert to NumPy array
35
- img = np.array(img)
36
- # Convert to grayscale
37
- img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
38
- # Resize image to 28x28
39
- img = cv2.resize(img, (img_size, img_size))
40
- # Normalize pixel values
41
- img = img / 255.0
42
- # Reshape image to match model input
43
- img = img.reshape(1, img_size, img_size, 1)
44
- # Model predictions
45
- preds = model.predict(img)[0]
46
- # Return the probability for each class
47
- return {label: float(pred) for label, pred in zip(labels, preds)}
48
 
49
  # Top 3 classes
50
  label = gr.Label(num_top_classes=3)
@@ -52,7 +49,7 @@ label = gr.Label(num_top_classes=3)
52
  # Open Gradio interface for sketch recognition
53
  interface = gr.Interface(
54
  fn=predict,
55
- inputs=gr.Sketchpad(),
56
  outputs=label,
57
  title=title,
58
  description=head,
 
1
  import numpy as np
2
  import gradio as gr
3
  import tensorflow as tf
 
4
 
5
  # App title
6
  title = "Welcome to your first sketch recognition app!"
 
28
 
29
  # Prediction function for sketch recognition
30
  def predict(data):
31
+ # Reshape image to 28x28
32
+ img = np.reshape(data, (1, img_size, img_size, 1))
33
+ # Make prediction
34
+ pred = model.predict(img)
35
+ # Get top class
36
+ top_class = np.argmax
37
+ # Get top 3 classes
38
+ top_3_classes = np.argsort(pred[0])[-3:][::-1]
39
+ # Get top 3 probabilities
40
+ top_3_probs = pred[0][top_3_classes]
41
+ # Get class names
42
+ class_names = [labels[i] for i in top_3_classes]
43
+ # Return class names and probabilities
44
+ return {class_names[i]: top_3_probs[i] for i in range(3)}
 
 
45
 
46
  # Top 3 classes
47
  label = gr.Label(num_top_classes=3)
 
49
  # Open Gradio interface for sketch recognition
50
  interface = gr.Interface(
51
  fn=predict,
52
+ inputs=gr.Sketchpad(crop_size=(28,28), type='numpy', image_mode='L', brush=gr.Brush()),
53
  outputs=label,
54
  title=title,
55
  description=head,
requirements.txt CHANGED
@@ -1,3 +1,2 @@
1
  tensorflow
2
- opencv-python-headless
3
  numpy
 
1
  tensorflow
 
2
  numpy