Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -34,6 +34,19 @@ def get_model(model_name, classes, device):
|
|
34 |
model.load_state_dict(torch.load('BaseLine-Model.pt', map_location=torch.device(device)))
|
35 |
|
36 |
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
|
38 |
def make_predictions(input_img, model_name):
|
39 |
classes = ['buildings','forest', 'glacier', 'mountain', 'sea', 'street']
|
|
|
34 |
model.load_state_dict(torch.load('BaseLine-Model.pt', map_location=torch.device(device)))
|
35 |
|
36 |
return model
|
37 |
+
|
38 |
+
def get_transform(input_img, device):
|
39 |
+
normalize = transforms.Normalize(
|
40 |
+
[0.485, 0.456, 0.406],
|
41 |
+
[0.229, 0.224, 0.225]
|
42 |
+
)
|
43 |
+
|
44 |
+
test_transform = transforms.Compose([
|
45 |
+
transforms.ToTensor(),
|
46 |
+
normalize,
|
47 |
+
])
|
48 |
+
input_img = test_transform(input_img).unsqueeze(0).to(device)
|
49 |
+
return input_img
|
50 |
|
51 |
def make_predictions(input_img, model_name):
|
52 |
classes = ['buildings','forest', 'glacier', 'mountain', 'sea', 'street']
|