|
import os |
|
import gradio as gr |
|
import glob |
|
import time |
|
import random |
|
import requests |
|
import numpy as np |
|
|
|
|
|
from torchvision import models, transforms |
|
from PIL import Image |
|
import torch |
|
|
|
|
|
# ImageNet-pretrained ResNet-50 classifier.
# torchvision 0.13+ selects pretrained weights via the ``weights=`` enum and
# deprecates ``pretrained=True`` (IMAGENET1K_V1 is its documented equivalent);
# older torchvision only understands ``pretrained``, so use whichever exists.
_resnet50_weights = getattr(models, "ResNet50_Weights", None)
if _resnet50_weights is not None:
    model = models.resnet50(weights=_resnet50_weights.IMAGENET1K_V1)
else:
    model = models.resnet50(pretrained=True)

# The model is only used for inference, so switch it to evaluation mode
# (e.g. batch-norm layers use their stored running statistics).
model.eval()
|
|
|
|
|
# Preprocessing applied to every input before it reaches the model:
# resize the shorter side to 256 px, take the central 224x224 crop,
# convert to a float tensor, then normalize each RGB channel with the
# mean/std that torchvision's ImageNet-pretrained models expect.
_preprocessing_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = transforms.Compose(_preprocessing_steps)
|
|
|
|
|
def download_imagenet_classes(
    url="https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt",
    path="imagenet_classes.txt",
    timeout=30,
):
    """Download the ImageNet class-name list to a local file.

    Failures are reported with ``print`` rather than raised, so the caller
    decides how to react to a missing file.

    Args:
        url: Location of the newline-separated class-name list.
        path: Destination file path.
        timeout: Seconds to wait for the server before giving up; without a
            timeout ``requests`` can block indefinitely.

    Returns:
        True if the file was written, False otherwise.
    """
    try:
        response = requests.get(url, timeout=timeout)
    except requests.RequestException as exc:
        print(f"Failed to download {path}: {exc}")
        return False

    if response.status_code != 200:
        print(f"Failed to download {path}")
        return False

    # Write to a temporary file and atomically rename it into place, so an
    # interrupted write never leaves a truncated file that a later run would
    # mistake for a complete download (the caller only checks existence).
    tmp_path = f"{path}.part"
    with open(tmp_path, "wb") as f:
        f.write(response.content)
    os.replace(tmp_path, path)

    print(f"{path} downloaded successfully.")
    return True
|
|
|
|
|
# Fetch the class-name list once; later runs reuse the copy on disk.
if not os.path.exists("imagenet_classes.txt"):
    download_imagenet_classes()

# One class name per line; labels[i] is used as the name for the model's
# output index i (see classify_image). An explicit encoding avoids depending
# on the platform's locale default.
with open("imagenet_classes.txt", "r", encoding="utf-8") as f:
    labels = [line.strip() for line in f]
|
|
|
# Output indices of the 59 bird classes in the standard 1000-class ImageNet
# ordering used by torchvision's pretrained classifiers:
#   7 (cock) .. 24 (great grey owl),
#   80 (black grouse) .. 100 (black swan),
#   127 (white stork) .. 146 (albatross).
BIRD_CLASS_INDICES = frozenset(
    [*range(7, 25), *range(80, 101), *range(127, 147)]
)


def classify_image(image):
    """Classify an image with ResNet-50 and report whether it shows a bird.

    Args:
        image: The picture to classify, as an array accepted by
            ``PIL.Image.fromarray`` (what Gradio's ``"image"`` input passes by
            default). May be None when the user submits without an image.

    Returns:
        A sentence naming the predicted ImageNet class, whether that class is
        a bird, and the model's softmax confidence for it.
    """
    if image is None:
        return "Please upload an image first."

    print("Classifying image...")

    # Grayscale/RGBA inputs are converted to 3-channel RGB, which is what the
    # preprocessing pipeline and the model expect.
    img = Image.fromarray(image).convert("RGB")
    batch = transform(img).unsqueeze(0)  # add a batch dimension of size 1

    with torch.no_grad():
        logits = model(batch)

    # Softmax is monotonic, so the most probable class is also the argmax of
    # the raw logits; compute it once and take both value and index.
    probabilities = torch.nn.functional.softmax(logits[0], dim=0)
    confidence, class_index = probabilities.max(dim=0)
    class_index = int(class_index)
    classification = labels[class_index]
    confidence_percentage = f"{confidence.item():.2%}"

    # Decide by class index rather than substring-matching the label text:
    # substrings misfire on names such as "beagle" ("eagle"), "soup bowl"
    # ("owl"), "cockroach" ("cock") and "birdhouse" ("bird"), and miss most
    # real birds ("goldfinch", "robin", "toucan", ...).
    if class_index in BIRD_CLASS_INDICES:
        return (
            f"This is a bird! Specifically, it looks like a {classification}. "
            f"Model confidence: {confidence_percentage}"
        )
    return (
        f"This is not a bird. It appears to be a {classification}. "
        f"Model confidence: {confidence_percentage}"
    )
|
|
|
|
|
# Every PNG in ./examples becomes a clickable example. Each example is a
# one-element list because the interface has a single input.
example_files = sorted(glob.glob("examples/*.png"))
examples = [[file] for file in example_files]

demo = gr.Interface(
    fn=classify_image,
    inputs="image",
    outputs="text",
    examples=examples,
    title="Is this a picture of a bird?",
    description="Uses the latest in machine learning LLM Diffusion models to analyze every pixel (twice) and to determine conclusively if it is a picture of a bird",
)

# Start the web server only when run as a script, so importing this module
# (e.g. from tests or another app) does not launch a server as a side effect.
if __name__ == "__main__":
    demo.launch()