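# Page header captured together with the file (not part of the program):
#   committer avatar: etemkocaaslan
#   commit message:   Update app.py
#   commit:           2ee95b8 (verified)
#   file size:        4.84 kB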
import torch
import torchvision.models as models
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torchvision.transforms import Compose
import requests
import random
import gradio as gr
# Registry of selectable classifiers: lower-case key -> torchvision constructor.
# The keys are what load_pretrained_model looks up (after lower-casing) and
# what get_model_names capitalises for display.
image_prediction_models = {
    # Classic convolutional networks
    'resnet': models.resnet50,
    'alexnet': models.alexnet,
    'vgg': models.vgg16,
    'squeezenet': models.squeezenet1_0,
    'densenet': models.densenet161,
    'inception': models.inception_v3,
    'googlenet': models.googlenet,
    # Lightweight / mobile-oriented networks
    'shufflenet': models.shufflenet_v2_x1_0,
    'mobilenet': models.mobilenet_v2,
    # ResNet variants
    'resnext': models.resnext50_32x4d,
    'wide_resnet': models.wide_resnet50_2,
    # Searched / scaled architectures
    'mnasnet': models.mnasnet1_0,
    'efficientnet': models.efficientnet_b0,
    'regnet': models.regnet_y_400mf,
    # Transformer and modernised ConvNet
    'vit': models.vit_b_16,
    'convnext': models.convnext_tiny,
}
def load_pretrained_model(model_name):
    """Instantiate the model registered under *model_name*.

    The lookup is case-insensitive: the name is lower-cased before it is
    matched against the keys of ``image_prediction_models``.

    Args:
        model_name: Registry key in any letter case, e.g. ``"Resnet"``.

    Returns:
        The object produced by calling the registered constructor with
        ``pretrained=True``.

    Raises:
        ValueError: If no constructor is registered under that name.
    """
    constructor = image_prediction_models.get(model_name.lower())
    if constructor is None:
        raise ValueError(f"Model {model_name} is not available for image prediction in torchvision.models")
    # NOTE(review): recent torchvision releases deprecate ``pretrained=`` in
    # favour of ``weights=`` -- confirm against the pinned torchvision version.
    return constructor(pretrained=True)
def get_model_names(models_dict):
    """Return the keys of *models_dict* as display names.

    Each key is passed through ``str.capitalize`` (first character
    upper-cased, the rest lower-cased); the original key order is kept.
    """
    display_names = [key.capitalize() for key in models_dict]
    return display_names
# Per-channel mean/std applied after ToTensor() in every preprocessing
# pipeline. NOTE(review): these values match the commonly published ImageNet
# channel statistics -- confirm they suit every registered model.
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Capitalised registry keys, used as the choices of both model dropdowns.
model_list = get_model_names(image_prediction_models)
def preprocess(model_name):
    """Build the image-to-tensor transform pipeline for a model.

    Args:
        model_name: Name of the model the image is destined for, in any
            letter case. Only ``"inception"`` changes the pipeline; any
            other value (including non-strings) yields the default one.

    Returns:
        A ``transforms.Compose`` that resizes, center-crops, converts to a
        tensor and normalizes. The crop is 299x299 for Inception and
        224x224 otherwise.
    """
    # Callers pass the capitalised dropdown value ("Inception"), so the
    # comparison must be case-insensitive or the 299 branch never triggers.
    is_inception = isinstance(model_name, str) and model_name.lower() == 'inception'
    if is_inception:
        # The resize target must not be smaller than the crop: CenterCrop
        # zero-pads images that are smaller than the requested size.
        resize_size, input_size = 342, 299
    else:
        resize_size, input_size = 256, 224
    return transforms.Compose([
        transforms.Resize(resize_size),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        normalize,
    ])
# Download the class-name list used to label predictions: one name per line,
# where the line index is the class id (see the postprocess_* functions).
# The timeout keeps a stalled connection from hanging start-up forever, and
# raise_for_status() stops an HTTP error page from being used as labels.
response = requests.get("https://git.io/JJkYN", timeout=30)
response.raise_for_status()
# splitlines() handles "\r\n" endings and does not leave a trailing empty
# entry the way split("\n") does.
labels = response.text.splitlines()
def postprocess_default(output, class_names=None, top_k=5):
    """Turn a batch of logits into ``{label: probability}`` for the first sample.

    Args:
        output: Tensor of raw scores whose first element ``output[0]`` is the
            1-D score vector of the sample to report.
        class_names: Sequence mapping class index to label. Defaults to the
            module-level ``labels`` list.
        top_k: Maximum number of entries to return. It is clamped to the
            number of classes so small score vectors do not make
            ``torch.topk`` raise.

    Returns:
        Dict from label to softmax probability (a Python float), in
        descending order of probability.
    """
    if class_names is None:
        class_names = labels
    probabilities = torch.nn.functional.softmax(output[0], dim=0)
    k = min(top_k, probabilities.size(0))
    top_prob, top_catid = torch.topk(probabilities, k)
    return {
        class_names[class_id]: prob
        for class_id, prob in zip(top_catid.tolist(), top_prob.tolist())
    }
def postprocess_inception(output, class_names=None, top_k=5):
    """Turn Inception output into ``{label: probability}`` for the first sample.

    Args:
        output: Either a tensor of raw scores whose first element is the
            score vector of the sample to report, or a tuple whose first
            element is such a tensor.
        class_names: Sequence mapping class index to label. Defaults to the
            module-level ``labels`` list.
        top_k: Maximum number of entries to return, clamped to the number of
            classes.

    Returns:
        Dict from label to softmax probability (a Python float), in
        descending order of probability.
    """
    if class_names is None:
        class_names = labels
    # torchvision's Inception v3 returns a (logits, aux_logits) named tuple
    # only in training mode with aux heads enabled; in eval mode it returns
    # the logits tensor itself. Indexing [1] was wrong in both cases: it is
    # out of range for a single-image batch, and on the tuple it selects the
    # auxiliary head and then softmaxes across the batch dimension.
    if isinstance(output, tuple):
        output = output[0]
    probabilities = torch.nn.functional.softmax(output[0], dim=0)
    k = min(top_k, probabilities.size(0))
    top_prob, top_catid = torch.topk(probabilities, k)
    return {
        class_names[class_id]: prob
        for class_id, prob in zip(top_catid.tolist(), top_prob.tolist())
    }
def classify_image(input_image, selected_model):
    """Classify an image with the selected pretrained model.

    Args:
        input_image: Image to classify, in a form accepted by the pipeline
            returned by ``preprocess`` (the UI supplies PIL images).
        selected_model: Model name as shown in the dropdown; matched
            case-insensitively.

    Returns:
        Dict from label to probability for the top predictions, as produced
        by ``postprocess_inception`` or ``postprocess_default``.

    Raises:
        ValueError: If no image or no model was supplied, or the model name
            is not registered.
    """
    # The UI hands over None when nothing was uploaded or selected; fail with
    # a clear message instead of an AttributeError deep inside the pipeline.
    if input_image is None:
        raise ValueError("No image provided; upload or generate an image before classifying.")
    if not selected_model:
        raise ValueError("No model selected; choose a model before classifying.")

    # Dropdown values are capitalised, so normalise once and use the same key
    # for both the preprocessing choice and the postprocessing choice.
    model_key = selected_model.lower()

    preprocess_input = preprocess(model_name=model_key)
    input_batch = preprocess_input(input_image).unsqueeze(0)  # add batch axis

    model = load_pretrained_model(selected_model)
    if torch.cuda.is_available():
        input_batch = input_batch.to('cuda')
        model.to('cuda')

    # eval() switches layers such as dropout/batch-norm to inference mode.
    model.eval()
    with torch.no_grad():
        output = model(input_batch)

    if model_key == 'inception':
        return postprocess_inception(output)
    return postprocess_default(output)
# CIFAR-10 test split, loaded on first use and then reused for every request.
_cifar10_test_set = None


def _load_cifar10_test_set():
    """Return the CIFAR-10 test split, loading it once per process."""
    global _cifar10_test_set
    if _cifar10_test_set is None:
        # No transform is set, so indexing the dataset yields PIL images and
        # the former ToTensor()/ToPILImage() round trip is unnecessary.
        _cifar10_test_set = datasets.CIFAR10(root='./data', train=False, download=True)
    return _cifar10_test_set


def get_random_image():
    """Return a uniformly random image from the CIFAR-10 test split.

    The dataset is stored under ``./data`` and is downloaded on first use if
    it is not already present.

    Returns:
        A PIL image; the class label that accompanies it is discarded.
    """
    cifar10 = _load_cifar10_test_set()
    image, _ = cifar10[random.randrange(len(cifar10))]
    return image
def generate_random_image():
    """Callback for the "Generate Random Image" button.

    Delegates to ``get_random_image`` and returns its result unchanged.
    """
    return get_random_image()
def classify_generated_image(image, model):
    """Callback for the "Classify" button on the random-image tab.

    Forwards *image* and *model* to ``classify_image`` and returns its
    result unchanged.
    """
    predictions = classify_image(image, model)
    return predictions
# Gradio user interface: two tabs that share the same classification backend.
# NOTE(review): the source lost its indentation, so the nesting of rows and
# columns below is reconstructed from the order of the statements -- confirm
# the intended layout.
with gr.Blocks() as demo:
    with gr.Tabs():
        # Tab 1: classify an image supplied by the user.
        with gr.TabItem("Upload Image"):
            with gr.Row():
                with gr.Column():
                    upload_image = gr.Image(type='pil', label="Upload Image")
                    model_dropdown_upload = gr.Dropdown(model_list, label="Select Model")
                    classify_button_upload = gr.Button("Classify")
                with gr.Column():
                    output_label_upload = gr.Label(num_top_classes=5)
            # Run the classifier on the uploaded image with the chosen model.
            classify_button_upload.click(classify_image, inputs=[upload_image, model_dropdown_upload], outputs=output_label_upload)
        # Tab 2: fetch a random CIFAR-10 image, then classify it.
        with gr.TabItem("Generate Random Image"):
            with gr.Row():
                with gr.Column():
                    generate_button = gr.Button("Generate Random Image")
                    random_image_output = gr.Image(type='pil', label="Random CIFAR-10 Image")
                with gr.Column():
                    model_dropdown_random = gr.Dropdown(model_list, label="Select Model")
                    classify_button_random = gr.Button("Classify")
                    output_label_random = gr.Label(num_top_classes=5)
            # First button fills the image component; second classifies whatever it holds.
            generate_button.click(generate_random_image, inputs=[], outputs=random_image_output)
            classify_button_random.click(classify_generated_image, inputs=[random_image_output, model_dropdown_random], outputs=output_label_random)
# Start the web server for the interface defined above.
demo.launch()