Spaces:
Runtime error
Runtime error
Minor changes
Browse files
app.py
CHANGED
@@ -35,7 +35,7 @@ model.load_state_dict(torch.load(config.MODEL_PATH, map_location=cpu), strict=Fa
|
|
35 |
model.to(cpu)
|
36 |
# Make the model in evaluation mode
|
37 |
model.eval()
|
38 |
-
print(f"Model Device: {next(model.parameters()).device}")
|
39 |
|
40 |
|
41 |
# Load the misclassified images data
|
@@ -66,7 +66,7 @@ def get_target_layer(layer_name):
|
|
66 |
|
67 |
|
68 |
def generate_prediction(input_image, num_classes=3, show_gradcam=True, transparency=0.6, layer_name="layer3_x"):
|
69 |
-
""" "Given an input image, generate the prediction, confidence and
|
70 |
mean = list(config.CIFAR_MEAN)
|
71 |
std = list(config.CIFAR_STD)
|
72 |
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
|
@@ -74,35 +74,35 @@ def generate_prediction(input_image, num_classes=3, show_gradcam=True, transpare
|
|
74 |
with torch.no_grad():
|
75 |
orginal_img = input_image
|
76 |
input_image = transform(input_image).unsqueeze(0).to(cpu)
|
77 |
-
print(f"Input Device: {input_image.device}")
|
78 |
-
|
79 |
-
print(f"Output Device: {outputs.device}")
|
80 |
-
|
81 |
-
print(f"Output Exp Device: {o.device}")
|
82 |
|
83 |
-
|
84 |
# get indexes of probabilties in descending order
|
85 |
-
sorted_indexes = np.argsort(
|
86 |
# sort the probabilities in descending order
|
87 |
-
final_class = classes[o_np.argmax()]
|
88 |
|
89 |
confidences = {}
|
90 |
-
for
|
91 |
# set the confidence of highest class with highest probability
|
92 |
-
confidences[classes[sorted_indexes[
|
93 |
|
94 |
# Show Grad Cam
|
95 |
if show_gradcam:
|
96 |
# Get the target layer
|
97 |
target_layers = get_target_layer(layer_name)
|
98 |
cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False)
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
else:
|
103 |
-
|
104 |
|
105 |
-
return
|
106 |
|
107 |
|
108 |
def app_interface(
|
@@ -118,10 +118,8 @@ def app_interface(
|
|
118 |
):
|
119 |
"""Function which provides the Gradio interface"""
|
120 |
|
121 |
-
# Get the prediction for the input image along with confidence and
|
122 |
-
|
123 |
-
input_image, num_classes, show_gradcam, transparency, layer_name
|
124 |
-
)
|
125 |
|
126 |
if show_misclassified:
|
127 |
misclassified_fig, misclassified_axs = plot_misclassified_images(
|
@@ -149,7 +147,7 @@ def app_interface(
|
|
149 |
# del misclassified_axs
|
150 |
# del gradcam_axs
|
151 |
|
152 |
-
return
|
153 |
|
154 |
|
155 |
TITLE = "CIFAR10 Image classification using a Custom ResNet Model"
|
@@ -182,11 +180,12 @@ inference_app = gr.Interface(
|
|
182 |
gr.Slider(value=10, maximum=25, minimum=5, step=5.0, precision=0, label="#GradCAM images to show"),
|
183 |
],
|
184 |
outputs=[
|
185 |
-
gr.
|
186 |
-
gr.
|
187 |
-
|
188 |
-
|
189 |
-
gr.Plot(label="
|
|
|
190 |
],
|
191 |
title=TITLE,
|
192 |
description=DESCRIPTION,
|
|
|
35 |
model.to(cpu)
|
36 |
# Make the model in evaluation mode
|
37 |
model.eval()
|
38 |
+
# print(f"Model Device: {next(model.parameters()).device}")
|
39 |
|
40 |
|
41 |
# Load the misclassified images data
|
|
|
66 |
|
67 |
|
68 |
def generate_prediction(input_image, num_classes=3, show_gradcam=True, transparency=0.6, layer_name="layer3_x"):
|
69 |
+
""" "Given an input image, generate the prediction, confidence and display_image"""
|
70 |
mean = list(config.CIFAR_MEAN)
|
71 |
std = list(config.CIFAR_STD)
|
72 |
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
|
|
|
74 |
with torch.no_grad():
|
75 |
orginal_img = input_image
|
76 |
input_image = transform(input_image).unsqueeze(0).to(cpu)
|
77 |
+
# print(f"Input Device: {input_image.device}")
|
78 |
+
model_output = model(input_image).to(cpu)
|
79 |
+
# print(f"Output Device: {outputs.device}")
|
80 |
+
output_exp = torch.exp(model_output).to(cpu)
|
81 |
+
# print(f"Output Exp Device: {o.device}")
|
82 |
|
83 |
+
output_numpy = np.squeeze(np.asarray(output_exp.numpy()))
|
84 |
# get indexes of probabilties in descending order
|
85 |
+
sorted_indexes = np.argsort(output_numpy)[::-1]
|
86 |
# sort the probabilities in descending order
|
87 |
+
# final_class = classes[o_np.argmax()]
|
88 |
|
89 |
confidences = {}
|
90 |
+
for _ in range(int(num_classes)):
|
91 |
# set the confidence of highest class with highest probability
|
92 |
+
confidences[classes[sorted_indexes[_]]] = float(output_numpy[sorted_indexes[_]])
|
93 |
|
94 |
# Show Grad Cam
|
95 |
if show_gradcam:
|
96 |
# Get the target layer
|
97 |
target_layers = get_target_layer(layer_name)
|
98 |
cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False)
|
99 |
+
cam_generated = cam(input_tensor=input_image, targets=None)
|
100 |
+
cam_generated = cam_generated[0, :]
|
101 |
+
display_image = show_cam_on_image(orginal_img / 255, cam_generated, use_rgb=True, image_weight=transparency)
|
102 |
else:
|
103 |
+
display_image = orginal_img
|
104 |
|
105 |
+
return confidences, display_image
|
106 |
|
107 |
|
108 |
def app_interface(
|
|
|
118 |
):
|
119 |
"""Function which provides the Gradio interface"""
|
120 |
|
121 |
+
# Get the prediction for the input image along with confidence and display_image
|
122 |
+
confidences, display_image = generate_prediction(input_image, num_classes, show_gradcam, transparency, layer_name)
|
|
|
|
|
123 |
|
124 |
if show_misclassified:
|
125 |
misclassified_fig, misclassified_axs = plot_misclassified_images(
|
|
|
147 |
# del misclassified_axs
|
148 |
# del gradcam_axs
|
149 |
|
150 |
+
return confidences, display_image, misclassified_fig, gradcam_fig
|
151 |
|
152 |
|
153 |
TITLE = "CIFAR10 Image classification using a Custom ResNet Model"
|
|
|
180 |
gr.Slider(value=10, maximum=25, minimum=5, step=5.0, precision=0, label="#GradCAM images to show"),
|
181 |
],
|
182 |
outputs=[
|
183 |
+
gr.Label(label="Confidences", container=True, show_label=True),
|
184 |
+
gr.Image(shape=(32, 32), label="Grad CAM/ Input Image", container=True, show_label=True).style(
|
185 |
+
width=256, height=256
|
186 |
+
),
|
187 |
+
gr.Plot(label="Misclassified images", container=True, show_label=True),
|
188 |
+
gr.Plot(label="Grad CAM of Misclassified images", container=True, show_label=True),
|
189 |
],
|
190 |
title=TITLE,
|
191 |
description=DESCRIPTION,
|