jervinjosh68 commited on
Commit
8bf15ec
·
1 Parent(s): 9fdfa21

updated app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -3
app.py CHANGED
@@ -25,7 +25,7 @@ def predict(image_name):
25
  inputs = inputs.to(device)
26
  with torch.no_grad():
27
  outputs = model(inputs.unsqueeze(0))
28
- values, indices = torch.topk(outputs, k=5)
29
  print(values,indices)
30
  return {i.item(): v.item() for i, v in zip(indices[0], values.detach()[0])}
31
  def preprocess(image_name):
@@ -57,12 +57,15 @@ def run_gradio():
57
  theme="huggingface",
58
  ).launch(debug=True, enable_queue=True)
59
 
60
- model = AQC_NET(pretrain=True,num_label=5)
61
  if not os.path.exists('weight.pth'):
62
  print("weight.pth does not exist. Downloading...")
63
  get_file("https://github.com/Kaldr4/EEE-199/releases/download/v1/weight.pth", 'weight.pth',"weight.pth")
64
  print("weight.pth downloaded")
65
  else:
66
  print('Specified file (weight.pth) already downloaded. Skipping this step.')
67
- torch.load("weight.pth")
 
 
 
68
  run_gradio()
 
25
  inputs = inputs.to(device)
26
  with torch.no_grad():
27
  outputs = model(inputs.unsqueeze(0))
28
+ values, indices = torch.topk(outputs, k=2)
29
  print(values,indices)
30
  return {i.item(): v.item() for i, v in zip(indices[0], values.detach()[0])}
31
  def preprocess(image_name):
 
57
  theme="huggingface",
58
  ).launch(debug=True, enable_queue=True)
59
 
60
+ model = AQC_NET(pretrain=True, num_label=2)
61
  if not os.path.exists('weight.pth'):
62
  print("weight.pth does not exist. Downloading...")
63
  get_file("https://github.com/Kaldr4/EEE-199/releases/download/v1/weight.pth", 'weight.pth',"weight.pth")
64
  print("weight.pth downloaded")
65
  else:
66
  print('Specified file (weight.pth) already downloaded. Skipping this step.')
67
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
68
+ state_dict = torch.load("weight.pth")
69
+ model.load_state_dict(state_dict)
70
+
71
  run_gradio()