"""Gradio demo app for the ex-voto image classifier.

Downloads a ConvNeXt-Base checkpoint from the Hugging Face Hub, wraps it in a
small Blocks UI (image upload + threshold slider) and reports whether the
uploaded image is classified as an ex-voto painting.
"""
import gradio as gr
import torchvision.transforms as transforms
from torchvision.transforms import InterpolationMode
import torch
from huggingface_hub import hf_hub_download
from model import Model

# Load Model
model_path = hf_hub_download(
    repo_id="itserr/exvoto_classifier_convnext_base_224",
    filename="model.pt",
)
model = Model('convnext_base')
# Deserialize onto the CPU first so the checkpoint loads on machines without
# a GPU; the model is moved to the selected device below.
# NOTE(review): torch.load unpickles arbitrary objects — only acceptable
# because the checkpoint comes from a trusted, pinned Hub repo.
ckpt = torch.load(model_path, map_location=torch.device("cpu"))  # Ensure compatibility
model.load_state_dict(ckpt['model'])
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
model.eval()

# Image Transformations
transform = transforms.Compose([
    transforms.Resize(size=(224, 224), interpolation=InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    # Per-channel (3-channel) normalization, so inputs must be RGB.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


# Classification Function
def classify_img(img, threshold):
    """Classify a single image as ex-voto / not ex-voto.

    Args:
        img: PIL image from the upload widget, or None when nothing has been
            uploaded yet.
        threshold: Minimum sigmoid probability required for the positive
            ("ex-voto") label.

    Returns:
        A ``(label, confidence)`` tuple of Markdown/plain strings for the two
        output textboxes.
    """
    # The button can be clicked before any image is uploaded; transform(None)
    # would raise and surface as a generic Gradio error.
    if img is None:
        return "⚠️ Please upload an image first.", ""

    # Normalize() above expects exactly 3 channels; RGBA, grayscale (L) or
    # palette (P) images would otherwise fail or be mis-normalized.
    img = img.convert("RGB")
    img_tensor = transform(img).unsqueeze(0).to(device)

    with torch.no_grad():
        pred = model(img_tensor)
    # .item() assumes the model emits a single logit per image — TODO confirm
    # against model.Model.
    score = torch.sigmoid(pred).item()

    # Determine Prediction
    if score >= threshold:
        label = "✅ This is an **Ex-Voto** image!"
    else:
        label = "❌ This is **NOT** an Ex-Voto image."

    # Format Confidence Score
    confidence = f"The probability that the image is an ex-voto is: {score:.2%}"
    return label, confidence


# Example images shown under the upload widget (image paths only; the
# threshold keeps whatever value the slider currently has).
example_images = [
    'examples/exvoto1.jpg',
    'examples/exvoto2.jpg',
    'examples/nonexvoto1.jpg',
    'examples/nonexvoto2.jpg',
    'examples/natural1.jpg',
    'examples/natural2.jpg',
]


# Function to Clear Outputs When a New Image is Uploaded
def clear_outputs(img):
    """Reset both output textboxes; *img* is unused but required by the event."""
    return gr.update(value=""), gr.update(value="")


# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("## Ex-Voto Image Classifier")
    gr.Markdown("📸 **Upload an image** to check if it's an **Ex-Voto** painting!")

    with gr.Row():
        with gr.Column(scale=2):  # Left section: Image upload & slider
            img_input = gr.Image(type="pil")
            threshold_slider = gr.Slider(
                minimum=0.5,
                maximum=1.0,
                value=0.7,
                step=0.1,
                label="Classification Threshold",
            )
            submit_btn = gr.Button("Classify")

        with gr.Column(scale=1):  # Right section: Prediction & Confidence
            prediction_output = gr.Textbox(label="Prediction", interactive=False)
            confidence_output = gr.Textbox(label="Confidence Score", interactive=False)

    # Clear outputs when a new image is uploaded so stale results are not
    # shown next to a different image.
    img_input.change(
        fn=clear_outputs,
        inputs=[img_input],
        outputs=[prediction_output, confidence_output],
    )

    # Submit button triggers classification
    submit_btn.click(
        fn=classify_img,
        inputs=[img_input, threshold_slider],
        outputs=[prediction_output, confidence_output],
    )

    # Example images (Only show images, no threshold value)
    gr.Examples(
        examples=example_images,
        inputs=[img_input],
    )

# Launch App
demo.launch()