Spaces:
Running
Running
import gradio as gr | |
from transformers import AutoModel, AutoTokenizer | |
import torch | |
import json | |
import requests | |
from PIL import Image | |
from torchvision import transforms | |
import urllib.request | |
# Load the label-to-class mapping from the Hugging Face repository.
# Keys are class indices as strings (looked up as label_to_class[str(idx)]),
# values are human-readable class names.
label_map_url = "https://huggingface.co/Maverick98/EcommerceClassifier/resolve/main/label_to_class.json"
# A timeout keeps app startup from hanging forever on a stalled connection.
label_map_response = requests.get(label_map_url, timeout=30)
# Fail loudly on an HTTP error instead of trying to parse an error page as JSON.
label_map_response.raise_for_status()
label_to_class = label_map_response.json()
# Hugging Face Hub repositories for the classifier weights and the title tokenizer.
MODEL_REPO_ID = "Maverick98/EcommerceClassifier"
TOKENIZER_REPO_ID = "jinaai/jina-embeddings-v2-base-en"

# NOTE(review): the model is later invoked as model(image, input_ids=..., attention_mask=...),
# which is not a stock transformers signature; a custom architecture may require
# trust_remote_code=True here — confirm against the repository.
model = AutoModel.from_pretrained(MODEL_REPO_ID)
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_REPO_ID)
# Per-channel normalization constants (the standard ImageNet mean/std values).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
# Fixed (height, width); resizing to an explicit pair does not preserve aspect ratio.
IMAGE_SIZE = (224, 224)

# Image preprocessing pipeline: resize, convert the PIL image to a tensor,
# then normalize each channel.
transform = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)
def load_image(image_path_or_url, timeout=10):
    """
    Load an image from an HTTP(S) URL or a local path and preprocess it.

    Args:
        image_path_or_url: An "http://" or "https://" URL, or a local file path.
            Leading/trailing whitespace is ignored.
        timeout: Seconds to wait on network operations when downloading a URL.

    Returns:
        The RGB image passed through `transform`, with a batch dimension of
        size 1 added in front.

    Raises:
        ValueError: If the input is empty or only whitespace.
    """
    source = image_path_or_url.strip()
    if not source:
        raise ValueError("An image URL or local path is required.")

    # Match the full scheme so a local path such as "http_images/a.jpg" is not
    # mistaken for a URL (a bare startswith("http") would send it to urlopen).
    if source.lower().startswith(("http://", "https://")):
        # NOTE(review): this fetches arbitrary user-supplied URLs from the server
        # (SSRF risk); consider restricting the allowed hosts.
        with urllib.request.urlopen(source, timeout=timeout) as response:
            with Image.open(response) as img:
                image = img.convert('RGB')
    else:
        # The context manager guarantees the file handle is closed; Pillow may
        # otherwise keep it open (e.g. for multi-frame formats).
        with Image.open(source) as img:
            image = img.convert('RGB')

    image = transform(image)
    return image.unsqueeze(0)  # Add batch dimension
def predict(image_path_or_url, title, threshold=0.7, top_k=3):
    """
    Predict the top categories for the given image and title.

    Args:
        image_path_or_url: Image URL or local path, passed to `load_image`.
        title: Product title text.
        threshold: When the best class probability is below this value, an
            "Others" entry scored as ``1 - best_probability`` is put first.
        top_k: How many top classes to report (default 3). Clamped to the
            number of classes the model outputs.

    Returns:
        Dict mapping class name -> probability in ranked order, with "Others"
        first when the model is unsure. Scores are not renormalized, so they
        need not sum to 1.

    Raises:
        ValueError: If ``top_k`` is less than 1.
    """
    if top_k < 1:
        raise ValueError(f"top_k must be at least 1, got {top_k}")

    image = load_image(image_path_or_url)

    # Fixed-length title encoding. NOTE(review): max_length=32 presumably
    # mirrors the training setup — confirm before changing.
    encoding = tokenizer(title, padding='max_length', max_length=32, truncation=True, return_tensors='pt')

    model.eval()
    with torch.no_grad():
        # NOTE(review): assumes the model returns raw logits shaped
        # (batch, num_classes) — confirm against the model definition.
        logits = model(image, input_ids=encoding['input_ids'], attention_mask=encoding['attention_mask'])
        probabilities = torch.nn.functional.softmax(logits, dim=1)

    # A single image is scored per call, so only row 0 is ranked.
    k = min(top_k, probabilities.shape[1])
    top_probs, top_indices = torch.topk(probabilities[0], k)

    # label_to_class is keyed by the class index rendered as a string.
    ranked = [
        (label_to_class[str(index)], prob)
        for index, prob in zip(top_indices.tolist(), top_probs.tolist())
    ]

    results = {}
    best_probability = ranked[0][1]
    if best_probability < threshold:
        # Inserted first so "Others" leads the output when the model is unsure.
        results["Others"] = 1.0 - best_probability
    results.update(ranked)
    return results
# Define the Gradio interface.
# gr.inputs / gr.outputs were deprecated in Gradio 3 and removed in Gradio 4;
# the top-level components below work in both versions.
title_input = gr.Textbox(label="Product Title", placeholder="Enter the product title here...")
image_input = gr.Textbox(label="Image URL or Path", placeholder="Enter image URL or local path here...")
output = gr.JSON(label="Top 3 Predictions with Probabilities")

demo = gr.Interface(
    fn=predict,
    # Order matches predict(image_path_or_url, title); remaining parameters use their defaults.
    inputs=[image_input, title_input],
    outputs=output,
    title="Ecommerce Classifier",
    # Derive the category count from the loaded mapping so the text stays in sync with it.
    description=(
        f"This model classifies ecommerce products into one of {len(label_to_class)} categories. "
        "If the model is unsure, it outputs 'Others'."
    ),
)
demo.launch()