Spaces:
Sleeping
Sleeping
Eric P. Nusbaum
commited on
Commit
·
7a9ce47
1
Parent(s):
123d436
Updated app.py
Browse files
app.py
CHANGED
@@ -2,36 +2,143 @@
|
|
2 |
|
3 |
import gradio as gr
|
4 |
import tensorflow as tf
|
|
|
|
|
5 |
import os
|
6 |
|
7 |
# Suppress TensorFlow logging for cleaner logs
|
8 |
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
|
9 |
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
|
|
16 |
graph_def = tf.compat.v1.GraphDef()
|
17 |
graph_def.ParseFromString(f.read())
|
18 |
|
19 |
with tf.Graph().as_default() as graph:
|
20 |
tf.import_graph_def(graph_def, name='')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
-
|
23 |
-
|
24 |
|
25 |
# Define Gradio interface
|
26 |
-
title = "JunkWaxHero 🦸♂️ -
|
27 |
-
description = "
|
28 |
|
29 |
iface = gr.Interface(
|
30 |
-
fn=
|
31 |
-
inputs=
|
32 |
-
outputs="
|
33 |
title=title,
|
34 |
description=description,
|
|
|
|
|
|
|
|
|
|
|
35 |
allow_flagging="never"
|
36 |
)
|
37 |
|
|
|
2 |
|
3 |
import gradio as gr
|
4 |
import tensorflow as tf
|
5 |
+
import numpy as np
|
6 |
+
from PIL import Image, ImageDraw, ImageFont
|
7 |
import os
|
8 |
|
9 |
# Suppress TensorFlow logging for cleaner logs
|
10 |
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
|
11 |
|
12 |
+
# Load labels
|
13 |
+
with open('tensorflow/labels.txt', 'r') as f:
|
14 |
+
labels = f.read().splitlines()
|
15 |
+
|
16 |
+
# Function to load the frozen TensorFlow graph
|
17 |
+
def load_frozen_graph(pb_file_path):
|
18 |
+
with tf.io.gfile.GFile(pb_file_path, 'rb') as f:
|
19 |
graph_def = tf.compat.v1.GraphDef()
|
20 |
graph_def.ParseFromString(f.read())
|
21 |
|
22 |
with tf.Graph().as_default() as graph:
|
23 |
tf.import_graph_def(graph_def, name='')
|
24 |
+
return graph
|
25 |
+
|
26 |
+
# Load the TensorFlow model
|
27 |
+
MODEL_DIR = 'tensorflow/'
|
28 |
+
MODEL_PATH = os.path.join(MODEL_DIR, 'model.pb')
|
29 |
+
graph = load_frozen_graph(MODEL_PATH)
|
30 |
+
sess = tf.compat.v1.Session(graph=graph)
|
31 |
+
|
32 |
+
# Define tensor names based on your model's outputs
|
33 |
+
input_tensor = graph.get_tensor_by_name('image_tensor:0')
|
34 |
+
detected_boxes = graph.get_tensor_by_name('detected_boxes:0')
|
35 |
+
detected_classes = graph.get_tensor_by_name('detected_classes:0')
|
36 |
+
detected_scores = graph.get_tensor_by_name('detected_scores:0')
|
37 |
+
|
38 |
+
# Define the target size based on your model's expected input
|
39 |
+
# You may need to adjust these values based on your model's requirements
|
40 |
+
TARGET_WIDTH = 224
|
41 |
+
TARGET_HEIGHT = 224
|
42 |
+
|
43 |
+
def preprocess_image(image):
|
44 |
+
"""
|
45 |
+
Preprocess the input image:
|
46 |
+
- Resize to target dimensions
|
47 |
+
- Convert to numpy array
|
48 |
+
- Normalize pixel values
|
49 |
+
- Convert RGB to BGR if required by the model
|
50 |
+
"""
|
51 |
+
image = image.resize((TARGET_WIDTH, TARGET_HEIGHT))
|
52 |
+
image_np = np.array(image).astype(np.float32)
|
53 |
+
image_np = image_np / 255.0 # Normalize to [0,1]
|
54 |
+
image_np = image_np[:, :, (2, 1, 0)] # Convert RGB to BGR if required
|
55 |
+
image_np = np.expand_dims(image_np, axis=0) # Add batch dimension
|
56 |
+
return image_np
|
57 |
+
|
58 |
+
def draw_boxes(image, boxes, classes, scores, threshold=0.5):
|
59 |
+
"""
|
60 |
+
Draw bounding boxes and labels on the image.
|
61 |
+
Args:
|
62 |
+
image (PIL.Image): The original image.
|
63 |
+
boxes (np.array): Array of bounding boxes.
|
64 |
+
classes (np.array): Array of class IDs.
|
65 |
+
scores (np.array): Array of confidence scores.
|
66 |
+
threshold (float): Confidence threshold to filter detections.
|
67 |
+
Returns:
|
68 |
+
PIL.Image: Annotated image.
|
69 |
+
"""
|
70 |
+
draw = ImageDraw.Draw(image)
|
71 |
+
try:
|
72 |
+
font = ImageFont.truetype("arial.ttf", 15)
|
73 |
+
except IOError:
|
74 |
+
font = ImageFont.load_default()
|
75 |
+
|
76 |
+
for box, cls, score in zip(boxes[0], classes[0], scores[0]):
|
77 |
+
if score < threshold:
|
78 |
+
continue
|
79 |
+
# Convert box coordinates from normalized to absolute
|
80 |
+
ymin, xmin, ymax, xmax = box
|
81 |
+
left = xmin * image.width
|
82 |
+
right = xmax * image.width
|
83 |
+
top = ymin * image.height
|
84 |
+
bottom = ymax * image.height
|
85 |
+
|
86 |
+
# Draw rectangle
|
87 |
+
draw.rectangle([(left, top), (right, bottom)], outline="red", width=2)
|
88 |
+
|
89 |
+
# Prepare label
|
90 |
+
label = f"{labels[int(cls) - 1]}: {score:.2f}"
|
91 |
+
|
92 |
+
# Draw label background
|
93 |
+
text_size = draw.textsize(label, font=font)
|
94 |
+
draw.rectangle([(left, top - text_size[1]), (left + text_size[0], top)], fill="red")
|
95 |
+
|
96 |
+
# Draw text
|
97 |
+
draw.text((left, top - text_size[1]), label, fill="white", font=font)
|
98 |
+
|
99 |
+
return image
|
100 |
+
|
101 |
+
def predict(image):
|
102 |
+
"""
|
103 |
+
Perform inference on the input image and return the annotated image.
|
104 |
+
Args:
|
105 |
+
image (PIL.Image): Uploaded image.
|
106 |
+
Returns:
|
107 |
+
PIL.Image: Annotated image with bounding boxes and labels.
|
108 |
+
"""
|
109 |
+
try:
|
110 |
+
# Preprocess the image
|
111 |
+
input_array = preprocess_image(image)
|
112 |
+
|
113 |
+
# Run inference
|
114 |
+
boxes, classes, scores = sess.run(
|
115 |
+
[detected_boxes, detected_classes, detected_scores],
|
116 |
+
feed_dict={input_tensor: input_array}
|
117 |
+
)
|
118 |
+
|
119 |
+
# Annotate the image with bounding boxes and labels
|
120 |
+
annotated_image = draw_boxes(image.copy(), boxes, classes, scores, threshold=0.5)
|
121 |
+
|
122 |
+
return annotated_image
|
123 |
|
124 |
+
except Exception as e:
|
125 |
+
return Image.new('RGB', (500, 500), color = (255, 0, 0)).convert('RGB').copy()
|
126 |
|
127 |
# Define Gradio interface
|
128 |
+
title = "JunkWaxHero 🦸♂️ - Baseball Card Set Identifier"
|
129 |
+
description = "Upload an image of a baseball card, and JunkWaxHero will identify the set it belongs to with high accuracy."
|
130 |
|
131 |
iface = gr.Interface(
|
132 |
+
fn=predict,
|
133 |
+
inputs=gr.inputs.Image(type="pil"),
|
134 |
+
outputs=gr.outputs.Image(type="pil"),
|
135 |
title=title,
|
136 |
description=description,
|
137 |
+
examples=[
|
138 |
+
["examples/card1.jpg"],
|
139 |
+
["examples/card2.jpg"],
|
140 |
+
["examples/card3.jpg"]
|
141 |
+
],
|
142 |
allow_flagging="never"
|
143 |
)
|
144 |
|