Eric P. Nusbaum commited on
Commit
7a9ce47
·
1 Parent(s): 123d436

Updated app.py

Browse files
Files changed (1) hide show
  1. app.py +120 -13
app.py CHANGED
@@ -2,36 +2,143 @@
2
 
3
  import gradio as gr
4
  import tensorflow as tf
 
 
5
  import os
6
 
7
  # Suppress TensorFlow logging for cleaner logs
8
  os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
9
 
10
- def list_tensor_names():
11
- # Load the frozen TensorFlow graph
12
- MODEL_DIR = 'tensorflow/'
13
- MODEL_PATH = os.path.join(MODEL_DIR, 'model.pb')
14
-
15
- with tf.io.gfile.GFile(MODEL_PATH, 'rb') as f:
 
16
  graph_def = tf.compat.v1.GraphDef()
17
  graph_def.ParseFromString(f.read())
18
 
19
  with tf.Graph().as_default() as graph:
20
  tf.import_graph_def(graph_def, name='')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
- tensor_names = [op.name for op in graph.get_operations()]
23
- return "\n".join(tensor_names)
24
 
25
  # Define Gradio interface
26
- title = "JunkWaxHero 🦸‍♂️ - Tensor Names Inspector"
27
- description = "This interface lists all tensor names in the TensorFlow model to help identify the correct input and output tensors."
28
 
29
  iface = gr.Interface(
30
- fn=list_tensor_names,
31
- inputs=None,
32
- outputs="text",
33
  title=title,
34
  description=description,
 
 
 
 
 
35
  allow_flagging="never"
36
  )
37
 
 
2
 
3
  import gradio as gr
4
  import tensorflow as tf
5
+ import numpy as np
6
+ from PIL import Image, ImageDraw, ImageFont
7
  import os
8
 
9
  # Suppress TensorFlow logging for cleaner logs
10
  os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
11
 
12
+ # Load labels
13
+ with open('tensorflow/labels.txt', 'r') as f:
14
+ labels = f.read().splitlines()
15
+
16
+ # Function to load the frozen TensorFlow graph
17
+ def load_frozen_graph(pb_file_path):
18
+ with tf.io.gfile.GFile(pb_file_path, 'rb') as f:
19
  graph_def = tf.compat.v1.GraphDef()
20
  graph_def.ParseFromString(f.read())
21
 
22
  with tf.Graph().as_default() as graph:
23
  tf.import_graph_def(graph_def, name='')
24
+ return graph
25
+
26
+ # Load the TensorFlow model
27
+ MODEL_DIR = 'tensorflow/'
28
+ MODEL_PATH = os.path.join(MODEL_DIR, 'model.pb')
29
+ graph = load_frozen_graph(MODEL_PATH)
30
+ sess = tf.compat.v1.Session(graph=graph)
31
+
32
+ # Define tensor names based on your model's outputs
33
+ input_tensor = graph.get_tensor_by_name('image_tensor:0')
34
+ detected_boxes = graph.get_tensor_by_name('detected_boxes:0')
35
+ detected_classes = graph.get_tensor_by_name('detected_classes:0')
36
+ detected_scores = graph.get_tensor_by_name('detected_scores:0')
37
+
38
+ # Define the target size based on your model's expected input
39
+ # You may need to adjust these values based on your model's requirements
40
+ TARGET_WIDTH = 224
41
+ TARGET_HEIGHT = 224
42
+
43
+ def preprocess_image(image):
44
+ """
45
+ Preprocess the input image:
46
+ - Resize to target dimensions
47
+ - Convert to numpy array
48
+ - Normalize pixel values
49
+ - Convert RGB to BGR if required by the model
50
+ """
51
+ image = image.resize((TARGET_WIDTH, TARGET_HEIGHT))
52
+ image_np = np.array(image).astype(np.float32)
53
+ image_np = image_np / 255.0 # Normalize to [0,1]
54
+ image_np = image_np[:, :, (2, 1, 0)] # Convert RGB to BGR if required
55
+ image_np = np.expand_dims(image_np, axis=0) # Add batch dimension
56
+ return image_np
57
+
58
+ def draw_boxes(image, boxes, classes, scores, threshold=0.5):
59
+ """
60
+ Draw bounding boxes and labels on the image.
61
+ Args:
62
+ image (PIL.Image): The original image.
63
+ boxes (np.array): Array of bounding boxes.
64
+ classes (np.array): Array of class IDs.
65
+ scores (np.array): Array of confidence scores.
66
+ threshold (float): Confidence threshold to filter detections.
67
+ Returns:
68
+ PIL.Image: Annotated image.
69
+ """
70
+ draw = ImageDraw.Draw(image)
71
+ try:
72
+ font = ImageFont.truetype("arial.ttf", 15)
73
+ except IOError:
74
+ font = ImageFont.load_default()
75
+
76
+ for box, cls, score in zip(boxes[0], classes[0], scores[0]):
77
+ if score < threshold:
78
+ continue
79
+ # Convert box coordinates from normalized to absolute
80
+ ymin, xmin, ymax, xmax = box
81
+ left = xmin * image.width
82
+ right = xmax * image.width
83
+ top = ymin * image.height
84
+ bottom = ymax * image.height
85
+
86
+ # Draw rectangle
87
+ draw.rectangle([(left, top), (right, bottom)], outline="red", width=2)
88
+
89
+ # Prepare label
90
+ label = f"{labels[int(cls) - 1]}: {score:.2f}"
91
+
92
+ # Draw label background
93
+ text_size = draw.textsize(label, font=font)
94
+ draw.rectangle([(left, top - text_size[1]), (left + text_size[0], top)], fill="red")
95
+
96
+ # Draw text
97
+ draw.text((left, top - text_size[1]), label, fill="white", font=font)
98
+
99
+ return image
100
+
101
+ def predict(image):
102
+ """
103
+ Perform inference on the input image and return the annotated image.
104
+ Args:
105
+ image (PIL.Image): Uploaded image.
106
+ Returns:
107
+ PIL.Image: Annotated image with bounding boxes and labels.
108
+ """
109
+ try:
110
+ # Preprocess the image
111
+ input_array = preprocess_image(image)
112
+
113
+ # Run inference
114
+ boxes, classes, scores = sess.run(
115
+ [detected_boxes, detected_classes, detected_scores],
116
+ feed_dict={input_tensor: input_array}
117
+ )
118
+
119
+ # Annotate the image with bounding boxes and labels
120
+ annotated_image = draw_boxes(image.copy(), boxes, classes, scores, threshold=0.5)
121
+
122
+ return annotated_image
123
 
124
+ except Exception as e:
125
+ return Image.new('RGB', (500, 500), color = (255, 0, 0)).convert('RGB').copy()
126
 
127
  # Define Gradio interface
128
+ title = "JunkWaxHero 🦸‍♂️ - Baseball Card Set Identifier"
129
+ description = "Upload an image of a baseball card, and JunkWaxHero will identify the set it belongs to with high accuracy."
130
 
131
  iface = gr.Interface(
132
+ fn=predict,
133
+ inputs=gr.inputs.Image(type="pil"),
134
+ outputs=gr.outputs.Image(type="pil"),
135
  title=title,
136
  description=description,
137
+ examples=[
138
+ ["examples/card1.jpg"],
139
+ ["examples/card2.jpg"],
140
+ ["examples/card3.jpg"]
141
+ ],
142
  allow_flagging="never"
143
  )
144