# neotestkit/gradio_image.py
import gradio as gr
import numpy as np
import cv2
from tensorflow.lite.python.interpreter import Interpreter
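
# If the full TensorFlow package is unavailable, the lightweight
# tflite_runtime wheel exposes the same Interpreter class (a hedged
# alternative import, assuming tflite_runtime is installed):
# from tflite_runtime.interpreter import Interpreter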


def tflite_detect_images(modelpath, lblpath, image_path, min_conf=0.1):
    """Run TFLite object detection on a single image and draw boxes on it."""
    # Load the label map into memory
    with open(lblpath, "r") as f:
        labels = [line.strip() for line in f.readlines()]

    # Load the TensorFlow Lite model into memory
    interpreter = Interpreter(model_path=modelpath)
    interpreter.allocate_tensors()

    # Get model details
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    height = input_details[0]["shape"][1]
    width = input_details[0]["shape"][2]
    float_input = input_details[0]["dtype"] == np.float32

    input_mean = 127.5
    input_std = 127.5

    # Load the image and resize it to the expected shape [1, H, W, 3]
    image = cv2.imread(image_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    imH, imW, _ = image.shape
    image_resized = cv2.resize(image, (width, height))
    input_data = np.expand_dims(image_resized, axis=0)

    # Normalize pixel values if using a floating model (i.e. if the model is
    # non-quantized): (x - 127.5) / 127.5 maps [0, 255] to [-1, 1]
    if float_input:
        input_data = (np.float32(input_data) - input_mean) / input_std

    # Run the model with the image as input
    interpreter.set_tensor(input_details[0]["index"], input_data)
    interpreter.invoke()

    # Retrieve detection results. The output ordering here (scores at index
    # 0, boxes at 1, classes at 3) matches this exported model; other TFLite
    # detection models may order their outputs differently.
    boxes = interpreter.get_tensor(output_details[1]["index"])[0]  # Bounding box coordinates
    classes = interpreter.get_tensor(output_details[3]["index"])[0]  # Class index of each detection
    scores = interpreter.get_tensor(output_details[0]["index"])[0]  # Confidence of each detection

    # Loop over all detections and draw a box when confidence is above the
    # minimum threshold
    for i in range(len(scores)):
        if (scores[i] > min_conf) and (scores[i] <= 1.0):
            # The interpreter can return coordinates outside the image
            # dimensions; clamp them with max() and min()
            ymin = int(max(1, (boxes[i][0] * imH)))
            xmin = int(max(1, (boxes[i][1] * imW)))
            ymax = int(min(imH, (boxes[i][2] * imH)))
            xmax = int(min(imW, (boxes[i][3] * imW)))
            cv2.rectangle(image, (xmin, ymin), (xmax, ymax), (10, 255, 0), 2)
            # Label drawing is disabled, as in the original script; the
            # rotated-text experiment has been condensed into the draw_label
            # sketch below:
            # image = draw_label(image, labels[int(classes[i])], xmin, ymin)
    return image
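

# A minimal sketch of the label drawing that is disabled above, condensed
# from the original commented-out block. draw_label is a hypothetical helper
# (not part of the original script); it assumes the labelmap indices line up
# with the model's class outputs.
def draw_label(image, label, xmin, ymin):
    label_size, base_line = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.7, 2)
    # Make sure not to draw the label too close to the top of the window
    label_ymin = max(ymin, label_size[1] + 10)
    # Draw a white box to put the label text in, then the text in black
    cv2.rectangle(
        image,
        (xmin, label_ymin - label_size[1] - 10),
        (xmin + label_size[0], label_ymin + base_line - 10),
        (255, 255, 255),
        cv2.FILLED,
    )
    cv2.putText(
        image,
        label,
        (xmin, label_ymin - 7),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.7,
        (0, 0, 0),
        2,
    )
    return image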


def show_image(img):
    PATH_TO_MODEL = "detect.tflite"  # Path to the .tflite model file
    PATH_TO_LABELS = "labelmap.txt"  # Path to the labelmap.txt file
    min_conf_threshold = 0.3  # Confidence threshold (try lowering it to 0.01 if you don't see any detections)

    # Run the inferencing function
    cv_image = tflite_detect_images(
        PATH_TO_MODEL, PATH_TO_LABELS, img, min_conf_threshold
    )
    return cv_image
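

# A quick local smoke test, kept commented out like the original imwrite
# call ("test.jpg" is a hypothetical filename, not part of the original):
# result = show_image("test.jpg")
# cv2.imwrite("output.jpg", cv2.cvtColor(result, cv2.COLOR_RGB2BGR))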


app = gr.Interface(
    fn=show_image,
    inputs=gr.Image(label="Input Image", type="filepath"),
    # show_image returns an RGB numpy array, so the output component should
    # use type="numpy", not "filepath"
    outputs=gr.Image(label="Output Image", type="numpy"),
)
app.launch(share=True)