Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
import torchvision.models as models | |
def load_model(model_path, *, num_classes=228):
    """Build a ResNet-50 multi-label classifier and load its weights.

    Args:
        model_path: Anything ``torch.load`` accepts (path or file-like
            object). The file must contain a state dict for the network
            built below.
        num_classes: Number of output units of the replacement ``fc`` head.
            Defaults to 228, the value this module was written for.

    Returns:
        The ResNet-50 with weights loaded, on CPU, switched to eval mode.

    Raises:
        RuntimeError: From ``load_state_dict`` (strict mode) when the
            checkpoint's keys or tensor shapes do not match the architecture.
    """
    # Build the architecture without downloading pretrained weights; every
    # parameter comes from the checkpoint instead.
    model = models.resnet50(weights=None)
    # Swap the ImageNet head for one with `num_classes` outputs; read the
    # input width from the existing layer instead of hard-coding 2048.
    model.fc = nn.Linear(model.fc.in_features, num_classes)

    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered checkpoint file cannot execute code while loading.
    state_dict = torch.load(
        model_path, map_location=torch.device("cpu"), weights_only=True
    )

    # Some checkpoints prefix keys with "predictor." (presumably saved from a
    # wrapper that held this network as `.predictor` -- confirm). Strip only
    # that leading prefix; str.replace would also rewrite matches elsewhere
    # in a key.
    prefix = "predictor."
    new_state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

    model.load_state_dict(new_state_dict)
    model.eval()
    return model
def predict_attributes(model, input_tensor, *, threshold=0.5):
    """Return the indices of the attributes the model predicts for one sample.

    Args:
        model: Callable mapping ``input_tensor`` to raw logits (a tensor).
        input_tensor: Input for a single sample. A leading batch dimension
            of size 1 is fine.
        threshold: Sigmoid probability an attribute must strictly exceed to
            be reported. Defaults to 0.5.

    Returns:
        Ascending list of ``int`` indices whose probability is > ``threshold``.

    Raises:
        ValueError: If the model output contains more than one sample.
    """
    with torch.no_grad():
        logits = model(input_tensor)
        # Drop size-1 dims (e.g. the batch dim of size 1).
        probs = torch.sigmoid(logits).squeeze()
        if probs.dim() > 1:
            raise ValueError(
                "predict_attributes expects a single sample; got model "
                f"output of shape {tuple(logits.shape)}"
            )
        # reshape(-1) keeps a single-output model as a 1-element vector
        # instead of a 0-d scalar, which could not be iterated.
        mask = probs.reshape(-1) > threshold
        # nonzero works on any device, so no .numpy()/CPU round-trip is needed.
        return torch.nonzero(mask).flatten().tolist()