Kaushik066 commited on
Commit
8e55f0f
·
verified ·
1 Parent(s): 3b9211b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -13
app.py CHANGED
@@ -94,7 +94,8 @@ def prod_function(transformer_model, prod_dl, prod_data):
94
  accelerated_model.eval()
95
 
96
  # Find Embedding of the image to be evaluated
97
- emb_prod = accelerated_model(acclerated_prod_data)
 
98
 
99
  prod_preds = []
100
 
@@ -186,8 +187,8 @@ prod_ds = dataset_prod_obj.create_dataset(image_paths, webcam=False)
186
  prod_dl = DataLoader(prod_ds, batch_size=BATCH_SIZE)
187
 
188
  # Testing the dataloader
189
- #prod_inputs = next(iter(prod_dl))
190
- #st.write(prod_inputs['pixel_values'].shape)
191
 
192
  # Read image from Camera
193
  enable = st.checkbox("Enable camera")
@@ -205,13 +206,13 @@ if picture is not None:
205
  prod_inputs = next(iter(webcam_dl))
206
  st.write(prod_inputs['pixel_values'].shape)
207
 
208
- ## Run the predictions
209
- #prediction = prod_function(model_pretrained, prod_dl, webcam_dl)
210
- #predictions = torch.cat(prediction, 0).to(device)
211
- #match_idx = torch.argmin(predictions)
212
-
213
- ## Display the results
214
- #if predictions[match_idx] <= 0.3:
215
- # st.write('Welcome: ',image_paths[match_idx].split('/')[-1].split('.')[0])
216
- #else:
217
- # st.write("Match not found")
 
94
  accelerated_model.eval()
95
 
96
  # Find Embedding of the image to be evaluated
97
+ with torch.no_grad():
98
+ emb_prod = accelerated_model(**acclerated_prod_data)
99
 
100
  prod_preds = []
101
 
 
187
  prod_dl = DataLoader(prod_ds, batch_size=BATCH_SIZE)
188
 
189
  # Testing the dataloader
190
+ prod_inputs = next(iter(prod_dl))
191
+ st.write(prod_inputs['pixel_values'].shape)
192
 
193
  # Read image from Camera
194
  enable = st.checkbox("Enable camera")
 
206
  prod_inputs = next(iter(webcam_dl))
207
  st.write(prod_inputs['pixel_values'].shape)
208
 
209
+ # Run the predictions
210
+ prediction = prod_function(model_pretrained, prod_dl, webcam_dl)
211
+ predictions = torch.cat(prediction, 0).to(device)
212
+ match_idx = torch.argmin(predictions)
213
+
214
+ # Display the results
215
+ if predictions[match_idx] <= 0.3:
216
+ st.write('Welcome: ',image_paths[match_idx].split('/')[-1].split('.')[0])
217
+ else:
218
+ st.write("Match not found")