nao_happy_sad / app.py
brightlembo's picture
Update app.py
f6f0940 verified
raw
history blame contribute delete
2.79 kB
import torch
from torchvision import transforms
from PIL import Image
import gradio as gr
import json
from torchvision.models import efficientnet_b7, EfficientNet_B7_Weights
import torch.nn as nn
# Charger les noms des classes
with open("class_names.json", "r") as f:
class_names = json.load(f)
# Charger l'architecture et les poids du modèle
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Charger EfficientNet-B7 avec des poids pré-entraînés
weights = EfficientNet_B7_Weights.DEFAULT
base_model = efficientnet_b7(weights=weights)
# Adapter le modèle pour la classification (ajout d'une couche FC finale)
class CustomEfficientNet(nn.Module):
def __init__(self, base_model, num_classes):
super(CustomEfficientNet, self).__init__()
self.base = nn.Sequential(*list(base_model.children())[:-2]) # Couper la partie classification
self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc1 = nn.Linear(2560, 512) # Taille de sortie du dernier bloc
self.relu = nn.ReLU()
self.fc2 = nn.Linear(512, num_classes) # Nombre de classes pour la classification
def forward(self, x):
x = self.base(x)
x = self.global_avg_pool(x)
x = x.view(x.size(0), -1)
x = self.relu(self.fc1(x))
x = self.fc2(x)
return x
# Initialiser le modèle avec 3 classes (ajuste ce nombre selon ton cas)
num_classes = len(class_names) # Nombre de classes dans le fichier JSON
model = CustomEfficientNet(base_model, num_classes).to(device)
# Charger les poids dans le modèle
model.load_state_dict(torch.load("efficientnet_b7_bestv1.pth", map_location=device))
model.eval() # Passer le modèle en mode évaluation
# Définir la taille de l'image
image_size = (224, 224)
# Transformation pour l'image
class GrayscaleToRGB:
def __call__(self, img):
return img.convert("RGB")
valid_test_transforms = transforms.Compose([
transforms.Grayscale(num_output_channels=1),
transforms.Resize(image_size),
GrayscaleToRGB(), # Conversion en RGB
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
# Fonction de prédiction
def predict_image(image):
image_tensor = valid_test_transforms(image).unsqueeze(0).to(device)
with torch.no_grad():
outputs = model(image_tensor)
_, predicted_class = torch.max(outputs, 1)
predicted_label = class_names[predicted_class.item()]
return predicted_label
# Interface Gradio
interface = gr.Interface(
fn=predict_image,
inputs=gr.Image(type="pil"),
outputs="text",
title="Prédiction d'images avec PyTorch",
description="Chargez une image pour obtenir une prédiction de classe."
)
if __name__ == "__main__":
interface.launch()