# Hugging Face Space — status: Running
"""Streamlit app that predicts a person's age group from an uploaded photo.

Uses the ``nateraw/vit-age-classifier`` ViT checkpoint from the Hugging Face
Hub: the uploaded image is preprocessed by the checkpoint's feature extractor,
classified, and the highest-scoring label from ``model.config.id2label`` is
shown to the user.
"""
# Import required libraries
import streamlit as st
from transformers import ViTForImageClassification, ViTFeatureExtractor
from PIL import Image
import torch

# Hugging Face Hub checkpoint used for both the model and its preprocessor.
model_name = "nateraw/vit-age-classifier"

# st.set_page_config must be the first Streamlit command executed in the
# script, so it runs before anything else renders (including the spinner the
# cached loader below shows on first load).
st.set_page_config(page_title="Age Classifier", page_icon="👶")
st.title("Age Classification using AI")
st.write("Upload an image of a person, and the model will predict their age group.")


@st.cache_resource
def load_model(name: str):
    """Load the ViT classifier and its feature extractor for checkpoint *name*.

    Streamlit re-executes this whole script on every widget interaction;
    ``st.cache_resource`` keeps a single loaded copy per *name* across reruns
    instead of rebuilding the model each time.

    Returns:
        A ``(model, feature_extractor)`` tuple.
    """
    loaded_model = ViTForImageClassification.from_pretrained(name)
    extractor = ViTFeatureExtractor.from_pretrained(name)
    return loaded_model, extractor


# Load the pre-trained model and feature extractor (cached across reruns).
model, feature_extractor = load_model(model_name)

# Upload image
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
    try:
        # convert("RGB") normalizes RGBA / palette / grayscale uploads to three
        # channels, matching the extractor's per-channel mean/std normalization.
        # It also forces the (lazy) decode here, so truncated files fail inside
        # this try. UnidentifiedImageError is a subclass of OSError.
        image = Image.open(uploaded_file).convert("RGB")
    except OSError:
        st.error("Could not read the uploaded file as an image. Please upload a valid JPG or PNG.")
    else:
        st.image(image, caption="Uploaded Image", use_column_width=True)

        # Preprocess the image into model-ready PyTorch tensors.
        inputs = feature_extractor(images=image, return_tensors="pt")

        # Perform inference without tracking gradients (prediction only).
        with torch.no_grad():
            logits = model(**inputs).logits

        # Pick the highest-scoring class and map it to its human-readable label.
        predicted_class_idx = logits.argmax(-1).item()
        predicted_age_group = model.config.id2label[predicted_class_idx]

        # Display the result
        st.write(f"**Predicted Age Group:** {predicted_age_group}")