nao_streamlit / app.py
brightlembo's picture
Update app.py
09d9434 verified
raw
history blame contribute delete
2.73 kB
import streamlit as st
import torch
from torchvision import transforms, models
from PIL import Image
import requests
import json
import os
# URL du modèle hébergé sur Hugging Face
MODEL_URL = "https://huggingface.co/brightlembo/nao_sad_happy/blob/main/efficientnet_b7_best.pth"
MODEL_PATH = "efficientnet_b7_best.pth"
CLASS_NAMES_PATH = "class_names.json"
# Télécharger le modèle s'il n'existe pas localement
if not os.path.exists(MODEL_PATH):
st.info("Téléchargement du modèle depuis Hugging Face...")
response = requests.get(MODEL_URL, stream=True)
response.raise_for_status()
with open(MODEL_PATH, "wb") as f:
f.write(response.content)
st.success("Modèle téléchargé avec succès.")
# Charger les noms des classes
if not os.path.exists(CLASS_NAMES_PATH):
st.error(f"Le fichier {CLASS_NAMES_PATH} est introuvable. Veuillez le charger.")
st.stop()
with open(CLASS_NAMES_PATH, "r") as f:
class_names = json.load(f)
# Charger le modèle
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
try:
# Créer le modèle EfficientNet
model = models.efficientnet_b7(pretrained=False) # Initialisation du modèle sans poids pré-entrainés
# Charger les poids du modèle
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
model.eval() # Passer en mode évaluation
except Exception as e:
st.error(f"Erreur lors du chargement du modèle : {e}")
st.stop()
# Transformation pour les images
image_size = (224, 224)
class GrayscaleToRGB:
def __call__(self, img):
return img.convert("RGB")
transform = transforms.Compose([
transforms.Grayscale(num_output_channels=1),
transforms.Resize(image_size),
GrayscaleToRGB(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
# Interface utilisateur Streamlit
st.title("Prédiction d'Images avec PyTorch")
st.write("Chargez une image pour obtenir une prédiction de classe.")
uploaded_file = st.file_uploader("Choisissez une image...", type=["jpg", "png", "jpeg"])
if uploaded_file is not None:
try:
# Charger et afficher l'image
image = Image.open(uploaded_file)
st.image(image, caption="Image chargée", use_column_width=True)
# Transformation et prédiction
image_tensor = transform(image).unsqueeze(0).to(device)
with torch.no_grad():
outputs = model(image_tensor)
_, predicted_class = torch.max(outputs, 1)
predicted_label = class_names[predicted_class.item()]
st.success(f"Classe prédite : {predicted_label}")
except Exception as e:
st.error(f"Erreur lors de la prédiction : {e}")