Kaushik066 commited on
Commit
66f68c7
·
verified ·
1 Parent(s): d2aa485

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -21
app.py CHANGED
@@ -20,7 +20,7 @@ import datasets
20
  from torch.utils.data import Dataset, DataLoader
21
 
22
  # For Display
23
- from tqdm.notebook import tqdm
24
 
25
  # Other Generic Libraries
26
  import torch
@@ -96,7 +96,7 @@ def prod_function(transformer_model, prod_dl, prod_data):
96
 
97
  prod_preds = []
98
 
99
- for batch in tqdm(acclerated_prod_dl):
100
  with torch.no_grad():
101
  emb = accelerated_model(**batch)
102
  distance = F.pairwise_distance(emb, emb_prod)
@@ -131,7 +131,7 @@ class CreateDatasetProd():
131
 
132
  def get_pixels(self, img_paths):
133
  pixel_values = []
134
- for path in tqdm(img_paths):
135
  # Read and process Images
136
  img = PIL.Image.open(path)
137
  img = self.transform_prod(img)
@@ -156,8 +156,8 @@ image_paths = []
156
  image_file = glob(os.path.join(data_path, '*.jpg'))
157
  #st.write(image_file)
158
  image_paths.extend(image_file)
159
- st.write('input path size:', len(image_paths))
160
- st.write(image_paths)
161
 
162
  # Create DataLoader for Employees image
163
  dataset_prod_obj = CreateDatasetProd(image_processor_prod)
@@ -168,19 +168,19 @@ prod_dl = DataLoader(prod_ds, batch_size=BATCH_SIZE)
168
  prod_inputs = next(iter(prod_dl))
169
  st.write(prod_inputs['pixel_values'].shape)
170
 
171
- ## Read image from Camera
172
- #enable = st.checkbox("Enable camera")
173
- #picture_path = st.camera_input("Take a picture", disabled=not enable)
174
- #if picture_path:
175
- # # Create DataLoader for Webcam Image
176
- # webcam_ds = dataset_prod_obj.create_dataset(picture_path)
177
- # webcam_dl = DataLoader(webcam_ds, batch_size=BATCH_SIZE)
178
- #
179
- #
180
- #prediction = prod_function(model_pretrained, prod_dl, webcam_dl)
181
- #predictions = torch.cat(prediction, 0).to('cpu')
182
- #match_idx = torch.argmin(predictions)
183
- #if predictions[match_idx] <= 0.3:
184
- # st.write('Welcome: ',image_paths[match_idx].split('/')[-1].split('.')[0])
185
- #else:
186
- # st.write("Match not found")
 
20
  from torch.utils.data import Dataset, DataLoader
21
 
22
  # For Display
23
+ #from tqdm.notebook import tqdm
24
 
25
  # Other Generic Libraries
26
  import torch
 
96
 
97
  prod_preds = []
98
 
99
+ for batch in acclerated_prod_dl:
100
  with torch.no_grad():
101
  emb = accelerated_model(**batch)
102
  distance = F.pairwise_distance(emb, emb_prod)
 
131
 
132
  def get_pixels(self, img_paths):
133
  pixel_values = []
134
+ for path in img_paths:
135
  # Read and process Images
136
  img = PIL.Image.open(path)
137
  img = self.transform_prod(img)
 
156
  image_file = glob(os.path.join(data_path, '*.jpg'))
157
  #st.write(image_file)
158
  image_paths.extend(image_file)
159
+ #st.write('input path size:', len(image_paths))
160
+ #st.write(image_paths)
161
 
162
  # Create DataLoader for Employees image
163
  dataset_prod_obj = CreateDatasetProd(image_processor_prod)
 
168
  prod_inputs = next(iter(prod_dl))
169
  st.write(prod_inputs['pixel_values'].shape)
170
 
171
+ # Read image from Camera
172
+ enable = st.checkbox("Enable camera")
173
+ picture_path = st.camera_input("Take a picture", disabled=not enable)
174
+ if picture_path:
175
+ # Create DataLoader for Webcam Image
176
+ webcam_ds = dataset_prod_obj.create_dataset(picture_path)
177
+ webcam_dl = DataLoader(webcam_ds, batch_size=BATCH_SIZE)
178
+
179
+
180
+ prediction = prod_function(model_pretrained, prod_dl, webcam_dl)
181
+ predictions = torch.cat(prediction, 0).to('cpu')
182
+ match_idx = torch.argmin(predictions)
183
+ if predictions[match_idx] <= 0.3:
184
+ st.write('Welcome: ',image_paths[match_idx].split('/')[-1].split('.')[0])
185
+ else:
186
+ st.write("Match not found")