# Spaces: Sleeping
import gradio as gr | |
import torchvision.transforms as transforms | |
from torchvision.transforms import InterpolationMode | |
import torch | |
from huggingface_hub import hf_hub_download | |
from model import Model | |
# Load Model
# Select the inference device: the GPU when CUDA is available, otherwise CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Download the checkpoint from the Hugging Face Hub; the returned value is
# the local path that is handed to torch.load below.
model_path = hf_hub_download(
    repo_id="itserr/exvoto_classifier_convnext_base_224",
    filename="model.pt",
)

model = Model('convnext_base')

# Deserialize onto the CPU so loading also works on hosts without a GPU;
# the model is moved to `device` right after its weights are restored.
# NOTE(review): torch.load unpickles the downloaded file, so the checkpoint
# repository must be trusted.
ckpt = torch.load(model_path, map_location=torch.device("cpu"))

# The weights are stored under the 'model' key of the checkpoint.
model.load_state_dict(ckpt['model'])
model.to(device)
model.eval()  # evaluation mode for inference
# Image Transformations
# Preprocessing applied to every input image before it reaches the model:
# bicubic resize to a fixed square, conversion to a tensor, then per-channel
# normalization.
_INPUT_SIZE = (224, 224)
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose(
    [
        transforms.Resize(
            size=_INPUT_SIZE,
            interpolation=InterpolationMode.BICUBIC,
        ),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD),
    ]
)
# Classification Function
def classify_img(img, threshold):
    """Classify an image as ex-voto or not and report the model's score.

    Args:
        img: PIL image to classify (the UI supplies it through
            ``gr.Image(type="pil")``), or ``None`` when no image is loaded.
        threshold: Minimum sigmoid score for the image to be labelled as an
            ex-voto.

    Returns:
        A ``(label, confidence)`` pair of strings: the verdict and a sentence
        giving the score as a percentage. When ``img`` is ``None`` the label
        asks the user to upload an image and the confidence is empty.
    """
    # Pressing "Classify" with an empty image component passes None; answer
    # with a hint instead of crashing inside the transform pipeline.
    if img is None:
        return "Please upload an image before classifying.", ""

    # The normalization step is configured for three channels, so grayscale,
    # palette or RGBA images are converted before preprocessing.
    if img.mode != "RGB":
        img = img.convert("RGB")

    # unsqueeze(0) adds the batch dimension in front of the image tensor.
    img_tensor = transform(img).unsqueeze(0).to(device)
    with torch.no_grad():
        pred = model(img_tensor)
        # .item() requires a single-element tensor.
        # NOTE(review): assumes the model emits one logit per image - confirm.
        score = torch.sigmoid(pred).item()

    # Determine Prediction
    if score >= threshold:
        label = "✅ This is an **Ex-Voto** image!"
    else:
        label = "❌ This is **NOT** an Ex-Voto image."

    # Format Confidence Score
    confidence = f"The probability that the image is an ex-voto is: {score:.2%}"
    return label, confidence
# Sample images offered in the UI. Each row pairs an image path with None in
# the second slot, i.e. no threshold value is supplied with an example.
_EXAMPLE_PATHS = (
    'examples/exvoto1.jpg',
    'examples/exvoto2.jpg',
    'examples/nonexvoto1.jpg',
    'examples/nonexvoto2.jpg',
    'examples/natural1.jpg',
    'examples/natural2.jpg',
)
example_images = [[path, None] for path in _EXAMPLE_PATHS]
# Function to Clear Outputs When a New Image is Uploaded
def clear_outputs(img):
    """Return blank updates for the two result textboxes.

    Used as the handler for the image component's ``change`` event so that
    results belonging to a previous image do not stay on screen.

    Args:
        img: The new image value delivered by the event; not used.

    Returns:
        A pair of ``gr.update`` objects, each setting its target's value to
        the empty string.
    """
    reset_prediction = gr.update(value="")
    reset_confidence = gr.update(value="")
    return reset_prediction, reset_confidence
# Gradio Interface
# Layout: a title and a short instruction, then one row with the inputs
# (image, threshold slider, button) on the left and the two read-only result
# textboxes on the right, followed by the clickable example images.
with gr.Blocks() as demo:
    gr.Markdown("## Ex-Voto Image Classifier")
    gr.Markdown("📸 **Upload an image** to check if it's an **Ex-Voto** painting!")

    with gr.Row():
        with gr.Column(scale=2):  # Left section: Image upload & slider
            img_input = gr.Image(type="pil")
            threshold_slider = gr.Slider(
                minimum=0.5,
                maximum=1.0,
                value=0.7,
                step=0.1,
                label="Classification Threshold",
            )
            submit_btn = gr.Button("Classify")

        with gr.Column(scale=1):  # Right section: Prediction & Confidence
            prediction_output = gr.Textbox(label="Prediction", interactive=False)
            confidence_output = gr.Textbox(label="Confidence Score", interactive=False)

    # Clear outputs when a new image is uploaded
    img_input.change(
        fn=clear_outputs,
        inputs=[img_input],
        outputs=[prediction_output, confidence_output],
    )

    # Submit button triggers classification
    submit_btn.click(
        fn=classify_img,
        inputs=[img_input, threshold_slider],
        outputs=[prediction_output, confidence_output],
    )

    # Example images (Only show images, no threshold value): only the image
    # component is listed as an input, so the examples fill the image alone.
    gr.Examples(
        examples=example_images,
        inputs=[img_input],
    )

# Launch App
demo.launch()