Kaushik066 commited on
Commit
b7d97a8
·
verified ·
1 Parent(s): 0b05083

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -10
app.py CHANGED
@@ -147,10 +147,29 @@ class CreateDatasetProd():
147
  del img
148
  gc.collect()
149
  return pixel_values
 
 
 
 
 
 
 
 
150
 
151
- def create_dataset(self, image_paths):
152
 
153
- pixel_values = torch.stack(self.get_pixels(image_paths))
 
 
 
 
 
 
 
 
 
 
 
154
  return CustomDatasetProd(pixel_values=pixel_values)
155
 
156
  # Read images from directory
@@ -163,7 +182,7 @@ image_paths.extend(image_file)
163
 
164
  # Create DataLoader for Employees image
165
  dataset_prod_obj = CreateDatasetProd(image_processor_prod)
166
- prod_ds = dataset_prod_obj.create_dataset(image_paths)
167
  prod_dl = DataLoader(prod_ds, batch_size=BATCH_SIZE)
168
 
169
  # Testing the dataloader
@@ -175,19 +194,19 @@ enable = st.checkbox("Enable camera")
175
  picture = st.camera_input("Take a picture", disabled=not enable)
176
  if picture is not None:
177
  #img = Image.open(picture)
178
- picture.save(webcam_path, "JPEG")
179
- st.write('Image saved as:',webcam_path)
180
 
181
- # Create DataLoader for Webcam Image
182
- webcam_ds = dataset_prod_obj.create_dataset(webcam_path)
183
  webcam_dl = DataLoader(webcam_ds, batch_size=BATCH_SIZE)
184
 
185
- # Run the predictions
186
  prediction = prod_function(model_pretrained, prod_dl, webcam_dl)
187
- predictions = torch.cat(prediction, 0).to('cpu')
188
  match_idx = torch.argmin(predictions)
189
 
190
- # Display the results
191
  if predictions[match_idx] <= 0.3:
192
  st.write('Welcome: ',image_paths[match_idx].split('/')[-1].split('.')[0])
193
  else:
 
147
  del img
148
  gc.collect()
149
  return pixel_values
150
+
151
+ def get_pixel(self, img_path):
152
+ # Read and process Images
153
+ img = Image.open(img_path)
154
+ img = self.transform_prod(img)
155
+
156
+ # Scaling the video to ML model's desired format
157
+ img = self.image_processor(img, return_tensors='pt') #, input_data_format='channels_first')
158
 
159
+ pixel_values = img['pixel_values'].squeeze(0)
160
 
161
+ # Force garbage collection
162
+ del img
163
+ gc.collect()
164
+
165
+ return pixel_values
166
+
167
+ def create_dataset(self, image_paths, webcam=False):
168
+ if webcam == True:
169
+ pixel_values = torch.stack(self.get_pixel(image_paths))
170
+ else:
171
+ pixel_values = torch.stack(self.get_pixels(image_paths))
172
+
173
  return CustomDatasetProd(pixel_values=pixel_values)
174
 
175
  # Read images from directory
 
182
 
183
  # Create DataLoader for Employees image
184
  dataset_prod_obj = CreateDatasetProd(image_processor_prod)
185
+ prod_ds = dataset_prod_obj.create_dataset(image_paths, webcam=False)
186
  prod_dl = DataLoader(prod_ds, batch_size=BATCH_SIZE)
187
 
188
  # Testing the dataloader
 
194
  picture = st.camera_input("Take a picture", disabled=not enable)
195
  if picture is not None:
196
  #img = Image.open(picture)
197
+ #picture.save(webcam_path, "JPEG")
198
+ #st.write('Image saved as:',webcam_path)
199
 
200
+ ## Create DataLoader for Webcam Image
201
+ webcam_ds = dataset_prod_obj.create_dataset(webcam_path, webcam=True)
202
  webcam_dl = DataLoader(webcam_ds, batch_size=BATCH_SIZE)
203
 
204
+ ## Run the predictions
205
  prediction = prod_function(model_pretrained, prod_dl, webcam_dl)
206
+ predictions = torch.cat(prediction, 0).to(device)
207
  match_idx = torch.argmin(predictions)
208
 
209
+ ## Display the results
210
  if predictions[match_idx] <= 0.3:
211
  st.write('Welcome: ',image_paths[match_idx].split('/')[-1].split('.')[0])
212
  else: