import os
import tempfile
from io import BytesIO

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import requests
import torch
from PIL import Image
from torchvision import transforms

# Load the YOLO model
from models.common import DetectMultiBackend

weights_path = "./last.pt"
device = torch.device("cpu")
model = DetectMultiBackend(weights_path, device=device)  # Load YOLOv5 model
model.eval()

# Network input size as (height, width) -- the order transforms.Resize expects.
MODEL_INPUT_SIZE = (512, 640)

# Default thresholds used by detect_objects_in_image.
CONF_THRESHOLD = 0.25
NMS_IOU_THRESHOLD = 0.5

# ToPILImage only accepts ndarrays/tensors, so this pipeline must be fed an
# array (detect_objects_in_image passes np.asarray(rgb_image)).
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(MODEL_INPUT_SIZE),
    transforms.ToTensor(),
])

OBJECT_NAMES = ['enemies']


def _to_rgb_pil(image):
    """Normalize a tensor, NumPy array or PIL image to an RGB ``PIL.Image``.

    Raises:
        TypeError: If *image* is none of the supported types.
    """
    if isinstance(image, torch.Tensor):
        image = transforms.ToPILImage()(image)
    elif isinstance(image, np.ndarray):
        # e.g. what gr.Image delivers when its type is "numpy".
        image = Image.fromarray(image)
    if not isinstance(image, Image.Image):
        raise TypeError(f"Expected a PIL Image but got {type(image)}")
    # Collapse grayscale/palette/RGBA inputs to 3 channels for the network.
    return image.convert("RGB")


def _predictions_to_array(raw):
    """Flatten raw model output (tensor, array or list of them) into a 2-D array.

    Every leading dimension is folded into the row axis, so a batched
    ``(1, N, C)`` output and a list of ``(N, C)`` outputs both become ``(rows, C)``.
    """
    if isinstance(raw, (list, tuple)):
        return np.concatenate([_predictions_to_array(part) for part in raw], axis=0)
    if isinstance(raw, torch.Tensor):
        raw = raw.detach().cpu().numpy()
    raw = np.asarray(raw)
    return raw.reshape(-1, raw.shape[-1])


def _center_boxes_to_corners(boxes, scale_x, scale_y, max_x, max_y):
    """Convert ``(cx, cy, w, h)`` rows to ``(x1, y1, x2, y2)`` rescaled and clipped to the image."""
    cx, cy, w, h = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
    corners = np.stack([
        (cx - w / 2) * scale_x,
        (cy - h / 2) * scale_y,
        (cx + w / 2) * scale_x,
        (cy + h / 2) * scale_y,
    ], axis=1)
    return np.clip(corners, 0, [max_x, max_y, max_x, max_y])


def detect_objects_in_image(image, *, conf_thres=CONF_THRESHOLD, iou_thres=NMS_IOU_THRESHOLD):
    """Detect objects in *image*, draw them, and chart the per-class counts.

    Args:
        image: A ``PIL.Image``, an ``HxWxC`` NumPy array, or an image tensor
            accepted by ``transforms.ToPILImage``.
        conf_thres: Predictions whose score (column 4 of the model output,
            presumably YOLOv5 objectness) is not above this are discarded.
        iou_thres: Overlap threshold passed to ``cv2.dnn.NMSBoxes``.

    Returns:
        ``(annotated, graph)``: ``annotated`` is a ``PIL.Image`` with green
        boxes and ``"name: score"`` labels; ``graph`` is the bar chart from
        ``generate_vehicle_count_graph``, or ``None`` when no prediction
        passes ``conf_thres``.

    Raises:
        TypeError: If *image* is not one of the supported types.
    """
    image = _to_rgb_pil(image)
    orig_w, orig_h = image.size  # PIL size is (width, height)
    input_h, input_w = MODEL_INPUT_SIZE

    # transform starts with ToPILImage, which rejects PIL images: pass an array.
    img_tensor = transform(np.asarray(image)).unsqueeze(0).to(device)
    with torch.no_grad():
        pred = _predictions_to_array(model(img_tensor)[0])

    pred = pred[pred[:, 4] > conf_thres]
    if len(pred) == 0:
        return Image.fromarray(np.array(image)), None  # Return only image and None for graph

    scores = pred[:, 4]
    class_ids = np.argmax(pred[:, 5:], axis=1)
    # The network saw an input_w x input_h image; scale each axis back on its own.
    boxes = _center_boxes_to_corners(
        pred[:, :4], orig_w / input_w, orig_h / input_h, orig_w, orig_h
    )

    # NMSBoxes works on (x, y, width, height) rectangles, not corner pairs.
    nms_rects = np.concatenate([boxes[:, :2], boxes[:, 2:] - boxes[:, :2]], axis=1)
    indices = cv2.dnn.NMSBoxes(nms_rects.tolist(), scores.tolist(), conf_thres, iou_thres)
    # Normalize whatever NMSBoxes returns (tuple or array of indices) to a flat int array.
    keep = np.asarray(indices, dtype=int).reshape(-1)

    object_counts = {name: 0 for name in OBJECT_NAMES}
    img_array = np.array(image)
    for i in keep:
        x1, y1, x2, y2 = map(int, boxes[i])
        cls = int(class_ids[i])
        object_name = OBJECT_NAMES[cls] if cls < len(OBJECT_NAMES) else f"Unknown ({cls})"
        if object_name in object_counts:
            object_counts[object_name] += 1
        cv2.rectangle(img_array, (x1, y1), (x2, y2), (0, 255, 0), 2)
        # Keep the label inside the frame when the box touches the top edge.
        label_y = max(y1 - 10, 10)
        cv2.putText(img_array, f"{object_name}: {scores[i]:.2f}", (x1, label_y),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

    graph_image = generate_vehicle_count_graph(object_counts)
    return Image.fromarray(img_array), graph_image
def generate_vehicle_count_graph(object_counts):
    """Render *object_counts* as a bar chart and return it as a PIL image.

    Args:
        object_counts: Mapping of category name -> count; keys become the bar
            labels in insertion order.

    Returns:
        A ``PIL.Image`` decoded from the PNG-rendered chart.
    """
    color_palette = ['#4C9ACD', '#88B8A3', '#7F9C9C', '#D1A3B5', '#A1C6EA',
                     '#FFB6C1', '#F0E68C', '#D3B0D8', '#F8A5D1', '#B8B8D1']
    labels = list(object_counts.keys())
    values = list(object_counts.values())

    fig, ax = plt.subplots(figsize=(8, 5))
    try:
        ax.bar(labels, values, color=color_palette[:len(labels)])
        # NOTE(review): the axis/title text says "Vehicles", but OBJECT_NAMES in
        # this app is ['enemies'] -- confirm the intended wording.
        ax.set_xlabel("Vehicle Categories", fontsize=12, fontweight='bold')
        ax.set_ylabel("Number of Vehicles", fontsize=12, fontweight='bold')
        ax.set_title("Detected Vehicles in Image", fontsize=14, fontweight='bold')
        # Style this figure's tick labels explicitly instead of going through
        # pyplot's global "current axes" state.
        plt.setp(ax.get_xticklabels(), rotation=45, ha='right', fontsize=10)
        plt.setp(ax.get_yticklabels(), fontsize=10)
        fig.tight_layout()
        buf = BytesIO()
        fig.savefig(buf, format='png')
    finally:
        # Always free the figure, even when rendering raises.
        plt.close(fig)
    buf.seek(0)
    chart = Image.open(buf)
    chart.load()  # decode eagerly instead of lazily reading from buf later
    return chart


# Used when the container reports no usable frame rate.
_DEFAULT_VIDEO_FPS = 30.0


def detect_objects_in_video(video_input):
    """Run ``detect_objects_in_image`` on every frame and write an annotated video.

    Args:
        video_input: Anything ``cv2.VideoCapture`` accepts, typically a file path.

    Returns:
        ``(output_path, graph)``: the path of an annotated ``.mp4`` temporary
        file, plus the count chart for the last frame (``None`` if that frame
        had no detections or the video had no frames). If the input cannot be
        opened, returns ``("Error: Cannot open video file.", None)``.
    """
    if video_input is None:
        return "Error: Cannot open video file.", None
    cap = cv2.VideoCapture(video_input)
    if not cap.isOpened():
        cap.release()
        return "Error: Cannot open video file.", None  # second value matches expected outputs

    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # Keep fractional rates such as 29.97 (int() would truncate them). The
    # property can come back as 0/NaN for some inputs; fall back in that case.
    fps = cap.get(cv2.CAP_PROP_FPS)
    if not (fps > 0):
        fps = _DEFAULT_VIDEO_FPS

    # Only the path is needed; close the handle now (the file is kept on disk).
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        temp_video_output = tmp.name
    out = cv2.VideoWriter(temp_video_output, cv2.VideoWriter_fourcc(*'mp4v'),
                          fps, (frame_width, frame_height))

    graph_image = None
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            frame_with_boxes, graph_image = detect_objects_in_image(image)
            # Convert back to OpenCV's BGR layout for the writer.
            out.write(cv2.cvtColor(np.array(frame_with_boxes), cv2.COLOR_RGB2BGR))
    finally:
        cap.release()
        out.release()
    return temp_video_output, graph_image


def greet(name):
    """Return a greeting for *name* (not wired into the detection interface)."""
    return "Hello " + name + "!!"


from urllib.request import urlretrieve

# Example images from GitHub; use the image's raw address ("?raw=true"), as
# obtained via "Copy image address" on GitHub.
EXAMPLE_IMAGE_URLS = {
    "clip2_-1450-_jpg.jpg": "https://github.com/SamDaaLamb/ValorantTracker/blob/main/clip2_-1450-_jpg.jpg?raw=true",
    "clip2_-539-_jpg.jpg": "https://github.com/SamDaaLamb/ValorantTracker/blob/main/clip2_-539-_jpg.jpg?raw=true",
}
for example_file, example_url in EXAMPLE_IMAGE_URLS.items():
    # Skip the download on restarts when the file is already present.
    if not os.path.exists(example_file):
        urlretrieve(example_url, example_file)

# Example outputs are cached (cache_examples=True below): delete the cache
# manually whenever examples are added.
examples = [[example_file] for example_file in EXAMPLE_IMAGE_URLS]

# define app features and run
title = "SpecLab Demo"
# NOTE(review): this text describes an ASPP/SpecLab model, but the app runs the
# YOLO detector loaded from ./last.pt -- confirm the intended wording.
description = (
    "Gradio demo for an ASPP model architecture trained on the SpecLab dataset. "
    "To use it, simply add your image, or click one of the examples to load them. "
    "Since this demo is run on CPU only, please allow additional time for processing."
)
# NOTE(review): the original article text looks like a link whose markup was
# lost (presumably to the ValorantTracker repo the examples come from) -- confirm.
article = "Github Repo"
# elem_id values must be strings, and a CSS ID selector cannot start with a
# digit (the old "#0"/"#1" rules never matched), so use descriptive IDs.
css = "#input-image {object-fit: contain;} #output-image {object-fit: contain;} #count-graph {object-fit: contain;}"

demo = gr.Interface(
    fn=detect_objects_in_image,
    title=title,
    description=description,
    article=article,
    # type="pil" hands the function the PIL image it is written for.
    inputs=gr.Image(type="pil", elem_id="input-image", show_label=False),
    # detect_objects_in_image returns (annotated image, count chart), so two
    # output components are required.
    outputs=[
        gr.Image(elem_id="output-image", show_label=False),
        gr.Image(elem_id="count-graph", show_label=False),
    ],
    css=css,
    examples=examples,
    cache_examples=True,
    allow_flagging='never',
)

if __name__ == "__main__":
    demo.launch()