Pavan2k4 commited on
Commit
99c5651
·
verified ·
1 Parent(s): 88871b9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -51
app.py CHANGED
@@ -1,40 +1,45 @@
1
  import streamlit as st
2
  import sys
 
 
 
 
 
 
 
 
 
 
 
3
  sys.path.append('Utils')
4
  sys.path.append('model')
5
- import torch
6
  from model.CBAM.reunet_cbam import reunet_cbam
7
- import cv2
8
- from PIL import Image
9
  from model.transform import transforms
10
- import numpy as np
11
  from model.unet import UNET
12
  from Utils.area import pixel_to_sqft, process_and_overlay_image
13
- import matplotlib.pyplot as plt
14
- import time
15
- import os
16
- import csv
17
- from datetime import datetime
18
  from split_merge import split, merge
19
- from Utils.convert import read_pansharpened_rgb
20
- import shutil
21
-
22
 
23
- # Define directories
24
- UPLOAD_DIR = "uploaded_images/"
25
- MASK_DIR = "generated_masks/"
26
- patches_folder = "Patches/"
27
- pred_patches = "Patch_pred/"
28
- CSV_LOG_PATH = "image_log.csv"
29
 
30
  # Create directories
31
- for directory in [UPLOAD_DIR, MASK_DIR, patches_folder, pred_patches]:
32
  os.makedirs(directory, exist_ok=True)
33
 
34
  # Load model
35
- model = reunet_cbam()
36
- model.load_state_dict(torch.load('latest.pth', map_location='cpu')['model_state_dict'])
37
- model.eval()
 
 
 
 
 
38
 
39
  def predict(image):
40
  with torch.no_grad():
@@ -63,27 +68,6 @@ def log_image_details(image_id, image_filename, mask_filename):
63
 
64
  writer.writerow([sno, date, time, image_id, image_filename, mask_filename])
65
 
66
- def overlay_mask(image, mask, alpha=0.5, rgb=[255, 0, 0]):
67
- # Ensure image is 3-channel
68
- if len(image.shape) == 2:
69
- image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
70
-
71
- # Ensure mask is binary and same shape as image
72
- mask = mask.astype(bool)
73
- if mask.shape[:2] != image.shape[:2]:
74
- raise ValueError("Mask and image must have the same dimensions")
75
-
76
- # Create color overlay
77
- color_mask = np.zeros_like(image)
78
- color_mask[mask] = rgb
79
-
80
- # Blend the image and color mask
81
- output = cv2.addWeighted(image, 1, color_mask, alpha, 0)
82
-
83
- return output
84
-
85
- import shutil # Add this import at the top of your file
86
-
87
  def upload_page():
88
  if 'file_uploaded' not in st.session_state:
89
  st.session_state.file_uploaded = False
@@ -155,8 +139,9 @@ def upload_page():
155
  Image.fromarray(mask).save(mask_filepath)
156
 
157
  # Merge predicted patches
158
- merged_mask_filename = f"generated_masks/mask_{timestamp}.png"
159
- merge(pred_patches, merged_mask_filename, img_array.shape)
 
160
 
161
  # Save merged mask
162
  st.session_state.mask_filename = merged_mask_filename
@@ -176,7 +161,7 @@ def upload_page():
176
  mask_filename = f"mask_{timestamp}.png"
177
  mask_filepath = os.path.join(MASK_DIR, mask_filename)
178
  Image.fromarray(mask).save(mask_filepath)
179
- st.session_state.mask_filename = mask_filepath
180
 
181
  st.session_state.file_uploaded = True
182
 
@@ -187,6 +172,7 @@ def upload_page():
187
  st.success('Image analyzed')
188
  st.session_state.page = 'result'
189
  st.rerun()
 
190
  def result_page():
191
  st.title('Analysis Result')
192
 
@@ -211,7 +197,7 @@ def result_page():
211
  col1.error(f"Original image file not found: {original_img_path}")
212
 
213
  # Display predicted mask
214
- mask_path = st.session_state.mask_filename
215
  if os.path.exists(mask_path):
216
  mask = Image.open(mask_path)
217
  col2.image(mask, caption='Predicted Mask', use_column_width=True)
@@ -235,9 +221,6 @@ def result_page():
235
  # Process and overlay image
236
  overlay_img = process_and_overlay_image(original_np, mask_np, 'output.png')
237
 
238
- # Convert BGR to RGB for displaying with st.image
239
- # overlay_rgb = cv2.cvtColor(overlay_img, cv2.COLOR_BGR2RGB)
240
-
241
  st.image(overlay_img, caption='Overlay Image', use_column_width=True)
242
  else:
243
  st.error("Image or mask file not found for overlay.")
@@ -263,4 +246,4 @@ def main():
263
  result_page()
264
 
265
  if __name__ == '__main__':
266
- main()
 
1
  import streamlit as st
2
  import sys
3
+ import os
4
+ import shutil
5
+ import time
6
+ from datetime import datetime
7
+ import csv
8
+ import cv2
9
+ import numpy as np
10
+ from PIL import Image
11
+ import torch
12
+
13
+ # Adjust import paths as needed
14
  sys.path.append('Utils')
15
  sys.path.append('model')
 
16
  from model.CBAM.reunet_cbam import reunet_cbam
 
 
17
  from model.transform import transforms
 
18
  from model.unet import UNET
19
  from Utils.area import pixel_to_sqft, process_and_overlay_image
 
 
 
 
 
20
  from split_merge import split, merge
21
+ from Utils.convert import read_pansharpened_rgb
 
 
22
 
23
+ # Define directories for Hugging Face Spaces
24
+ UPLOAD_DIR = "/tmp/uploaded_images/"
25
+ MASK_DIR = "/tmp/generated_masks/"
26
+ patches_folder = "/tmp/Patches/"
27
+ pred_patches = "/tmp/Patch_pred/"
28
+ CSV_LOG_PATH = "outputs/image_log.csv"
29
 
30
  # Create directories
31
+ for directory in [UPLOAD_DIR, MASK_DIR, patches_folder, pred_patches, "outputs"]:
32
  os.makedirs(directory, exist_ok=True)
33
 
34
  # Load model
35
+ @st.cache_resource
36
+ def load_model():
37
+ model = reunet_cbam()
38
+ model.load_state_dict(torch.load('latest.pth', map_location='cpu')['model_state_dict'])
39
+ model.eval()
40
+ return model
41
+
42
+ model = load_model()
43
 
44
  def predict(image):
45
  with torch.no_grad():
 
68
 
69
  writer.writerow([sno, date, time, image_id, image_filename, mask_filename])
70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  def upload_page():
72
  if 'file_uploaded' not in st.session_state:
73
  st.session_state.file_uploaded = False
 
139
  Image.fromarray(mask).save(mask_filepath)
140
 
141
  # Merge predicted patches
142
+ merged_mask_filename = f"mask_{timestamp}.png"
143
+ merged_mask_path = os.path.join(MASK_DIR, merged_mask_filename)
144
+ merge(pred_patches, merged_mask_path, img_array.shape)
145
 
146
  # Save merged mask
147
  st.session_state.mask_filename = merged_mask_filename
 
161
  mask_filename = f"mask_{timestamp}.png"
162
  mask_filepath = os.path.join(MASK_DIR, mask_filename)
163
  Image.fromarray(mask).save(mask_filepath)
164
+ st.session_state.mask_filename = mask_filename
165
 
166
  st.session_state.file_uploaded = True
167
 
 
172
  st.success('Image analyzed')
173
  st.session_state.page = 'result'
174
  st.rerun()
175
+
176
  def result_page():
177
  st.title('Analysis Result')
178
 
 
197
  col1.error(f"Original image file not found: {original_img_path}")
198
 
199
  # Display predicted mask
200
+ mask_path = os.path.join(MASK_DIR, st.session_state.mask_filename)
201
  if os.path.exists(mask_path):
202
  mask = Image.open(mask_path)
203
  col2.image(mask, caption='Predicted Mask', use_column_width=True)
 
221
  # Process and overlay image
222
  overlay_img = process_and_overlay_image(original_np, mask_np, 'output.png')
223
 
 
 
 
224
  st.image(overlay_img, caption='Overlay Image', use_column_width=True)
225
  else:
226
  st.error("Image or mask file not found for overlay.")
 
246
  result_page()
247
 
248
  if __name__ == '__main__':
249
+ main()