"""Gradio demo: classify an uploaded picture with ResNet-50 and say whether it is a bird."""
import glob
import os
import random
import time

import gradio as gr
import numpy as np
import requests
import torch
from PIL import Image
from torchvision import models, transforms

# Load the pre-trained ResNet-50 once at start-up.
# IMAGENET1K_V1 is the checkpoint the legacy ``pretrained=True`` flag selected,
# so predictions are unchanged; older torchvision (< 0.13) has no weights enums.
if hasattr(models, "ResNet50_Weights"):
    model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
else:
    model = models.resnet50(pretrained=True)
model.eval()

# Standard ImageNet preprocessing: resize, center-crop to 224x224, normalize.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# ImageNet-1k output indices that are birds. They form three contiguous runs in
# the class ordering the model head uses (the same order as imagenet_classes.txt):
#   7-24    cock ... great grey owl   (25 is European fire salamander, not a bird)
#   80-100  black grouse ... black swan
#   127-146 white stork ... albatross
# Testing the index (not the label text) avoids false positives such as
# "cockroach", "birdhouse", "hen-of-the-woods" and the construction "crane",
# which shares its label with the bird crane.
BIRD_CLASS_INDICES = frozenset([*range(7, 25), *range(80, 101), *range(127, 147)])


def download_imagenet_classes():
    """Download the ImageNet class-name list to ``imagenet_classes.txt``.

    Raises:
        RuntimeError: if the server does not answer with HTTP 200. The label
            file is required below, so failing here gives a clearer error than
            the ``FileNotFoundError`` that would otherwise follow.
        requests.RequestException: on network errors or timeout.
    """
    url = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
    # A timeout keeps start-up from hanging forever on a stalled connection.
    response = requests.get(url, timeout=30)
    if response.status_code != 200:
        print("Failed to download imagenet_classes.txt")
        raise RuntimeError(
            f"Failed to download imagenet_classes.txt (HTTP {response.status_code})"
        )
    with open("imagenet_classes.txt", "wb") as f:
        f.write(response.content)
    print("imagenet_classes.txt downloaded successfully.")


# Fetch the label file on first run only.
if not os.path.exists("imagenet_classes.txt"):
    download_imagenet_classes()

# Load class labels; line i is the name of model output index i.
with open("imagenet_classes.txt", "r", encoding="utf-8") as f:
    labels = [line.strip() for line in f]


def classify_image(image):
    """Classify *image* with ResNet-50 and report whether it shows a bird.

    Args:
        image: the picture from the Gradio ``"image"`` input, presumably a
            NumPy array (Gradio's default), or ``None`` when the user
            submits without uploading anything.

    Returns:
        A sentence naming the top-1 ImageNet label, whether it is a bird,
        and the model's softmax confidence for that label.
    """
    if image is None:
        return "Please upload an image first."

    print("Classifying image...")

    # Preprocess and add a batch dimension -> (1, 3, 224, 224).
    img = Image.fromarray(image).convert('RGB')
    batch_t = transform(img).unsqueeze(0)

    with torch.no_grad():
        output = model(batch_t)

    # Top-1 class and its probability. argmax of softmax == argmax of logits.
    probabilities = torch.nn.functional.softmax(output[0], dim=0)
    confidence_score, predicted = torch.max(probabilities, dim=0)
    class_index = predicted.item()
    classification = labels[class_index]
    confidence_percentage = f"{confidence_score.item():.2%}"

    if class_index in BIRD_CLASS_INDICES:
        return f"This is a bird! Specifically, it looks like a {classification}. Model confidence: {confidence_percentage}"
    return f"This is not a bird. It appears to be a {classification}. Model confidence: {confidence_percentage}"


# Dynamically create the list of example images.
example_files = sorted(glob.glob("examples/*.png"))
examples = [[file] for file in example_files]

# Create the Gradio interface.
demo = gr.Interface(
    fn=classify_image,           # The function to run
    inputs="image",              # The input type is an image
    outputs="text",              # The output type is text
    examples=examples or None,   # Example images; None when the folder is empty
    title="Is this a picture of a bird?",
    description="Uses the latest in machine learning LLM Diffusion models to analyze every pixel (twice) and to determine conclusively if it is a picture of a bird",
)

# Launch the app when run as a script (e.g. `python app.py`).
if __name__ == "__main__":
    demo.launch()