Spaces:
Sleeping
Sleeping
Commit
·
8bf15ec
1
Parent(s):
9fdfa21
updated app.py
Browse files
app.py
CHANGED
@@ -25,7 +25,7 @@ def predict(image_name):
|
|
25 |
inputs = inputs.to(device)
|
26 |
with torch.no_grad():
|
27 |
outputs = model(inputs.unsqueeze(0))
|
28 |
-
values, indices = torch.topk(outputs, k=
|
29 |
print(values,indices)
|
30 |
return {i.item(): v.item() for i, v in zip(indices[0], values.detach()[0])}
|
31 |
def preprocess(image_name):
|
@@ -57,12 +57,15 @@ def run_gradio():
|
|
57 |
theme="huggingface",
|
58 |
).launch(debug=True, enable_queue=True)
|
59 |
|
60 |
-
model = AQC_NET(pretrain=True,num_label=
|
61 |
if not os.path.exists('weight.pth'):
|
62 |
print("weight.pth does not exist. Downloading...")
|
63 |
get_file("https://github.com/Kaldr4/EEE-199/releases/download/v1/weight.pth", 'weight.pth',"weight.pth")
|
64 |
print("weight.pth downloaded")
|
65 |
else:
|
66 |
print('Specified file (weight.pth) already downloaded. Skipping this step.')
|
67 |
-
torch.
|
|
|
|
|
|
|
68 |
run_gradio()
|
|
|
25 |
inputs = inputs.to(device)
|
26 |
with torch.no_grad():
|
27 |
outputs = model(inputs.unsqueeze(0))
|
28 |
+
values, indices = torch.topk(outputs, k=2)
|
29 |
print(values,indices)
|
30 |
return {i.item(): v.item() for i, v in zip(indices[0], values.detach()[0])}
|
31 |
def preprocess(image_name):
|
|
|
57 |
theme="huggingface",
|
58 |
).launch(debug=True, enable_queue=True)
|
59 |
|
60 |
+
model = AQC_NET(pretrain=True, num_label=2)
|
61 |
if not os.path.exists('weight.pth'):
|
62 |
print("weight.pth does not exist. Downloading...")
|
63 |
get_file("https://github.com/Kaldr4/EEE-199/releases/download/v1/weight.pth", 'weight.pth',"weight.pth")
|
64 |
print("weight.pth downloaded")
|
65 |
else:
|
66 |
print('Specified file (weight.pth) already downloaded. Skipping this step.')
|
67 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
68 |
+
state_dict = torch.load("weight.pth")
|
69 |
+
model.load_state_dict(state_dict)
|
70 |
+
|
71 |
run_gradio()
|