Pavan2k4 commited on
Commit
359a4fd
·
verified ·
1 Parent(s): b047243

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -118
app.py CHANGED
@@ -19,20 +19,18 @@ from Utils.split_merge import split, merge
19
  from Utils.convert_raster import convert_gtiff_to_8bit
20
  import shutil
21
 
22
- patches_folder = 'data/Patches'
23
- pred_patches = 'data/Patch_pred'
24
- os.makedirs(patches_folder, exist_ok=True)
25
- os.makedirs(pred_patches, exist_ok=True)
26
-
27
- # Define the upload directories
28
  UPLOAD_DIR = "data/uploaded_images"
29
  MASK_DIR = "data/generated_masks"
 
 
30
  CSV_LOG_PATH = "image_log.csv"
31
 
32
- # Create the directories if they don't exist
33
- os.makedirs(UPLOAD_DIR, exist_ok=True)
34
- os.makedirs(MASK_DIR, exist_ok=True)
35
 
 
36
  model = reunet_cbam()
37
  model.load_state_dict(torch.load('latest.pth', map_location='cpu')['model_state_dict'])
38
  model.eval()
@@ -44,7 +42,6 @@ def predict(image):
44
 
45
  def log_image_details(image_id, image_filename, mask_filename):
46
  file_exists = os.path.exists(CSV_LOG_PATH)
47
-
48
  current_time = datetime.now()
49
  date = current_time.strftime('%Y-%m-%d')
50
  time = current_time.strftime('%H:%M:%S')
@@ -54,35 +51,9 @@ def log_image_details(image_id, image_filename, mask_filename):
54
  if not file_exists:
55
  writer.writerow(['S.No', 'Date', 'Time', 'Image ID', 'Image Filename', 'Mask Filename'])
56
 
57
- # Get the next S.No
58
- if file_exists:
59
- with open(CSV_LOG_PATH, mode='r') as f:
60
- reader = csv.reader(f)
61
- sno = sum(1 for row in reader)
62
- else:
63
- sno = 1
64
-
65
  writer.writerow([sno, date, time, image_id, image_filename, mask_filename])
66
 
67
- def overlay_mask(image, mask, alpha=0.5, rgb=[255, 0, 0]):
68
- # Ensure image is 3-channel
69
- if len(image.shape) == 2:
70
- image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
71
-
72
- # Ensure mask is binary and same shape as image
73
- mask = mask.astype(bool)
74
- if mask.shape[:2] != image.shape[:2]:
75
- raise ValueError("Mask and image must have the same dimensions")
76
-
77
- # Create color overlay
78
- color_mask = np.zeros_like(image)
79
- color_mask[mask] = rgb
80
-
81
- # Blend the image and color mask
82
- output = cv2.addWeighted(image, 1, color_mask, alpha, 0)
83
-
84
- return output
85
-
86
  def reset_state():
87
  st.session_state.file_uploaded = False
88
  st.session_state.filename = None
@@ -91,95 +62,78 @@ def reset_state():
91
  if 'page' in st.session_state:
92
  del st.session_state.page
93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  def upload_page():
95
  if 'file_uploaded' not in st.session_state:
96
  st.session_state.file_uploaded = False
97
 
98
- if 'filename' not in st.session_state:
99
- st.session_state.filename = None
100
-
101
- if 'mask_filename' not in st.session_state:
102
- st.session_state.mask_filename = None
103
-
104
  image = st.file_uploader('Choose a satellite image', type=['jpg', 'png', 'jpeg', 'tiff', 'tif'])
105
 
106
  if image is not None:
107
- reset_state() # Reset the state when a new image is uploaded
108
- bytes_data = image.getvalue()
109
-
110
  timestamp = int(time.time())
111
- original_filename = image.name
112
- file_extension = os.path.splitext(original_filename)[1].lower()
113
-
114
- if file_extension in ['.tiff', '.tif']:
115
- filename = f"image_{timestamp}.tif"
116
- else:
117
- filename = f"image_{timestamp}.png"
118
-
119
- filepath = os.path.join(UPLOAD_DIR, filename)
120
-
121
- with open(filepath, "wb") as f:
122
- f.write(bytes_data)
123
-
124
- # Check if the uploaded file is a GeoTIFF
125
- if file_extension in ['.tiff', '.tif']:
126
- st.info('Processing GeoTIFF image...')
127
- convert_gtiff_to_8bit(filepath)
128
- st.success('GeoTIFF converted to 8-bit image')
129
 
130
  img = Image.open(filepath)
131
  st.image(img, caption='Uploaded Image', use_column_width=True)
132
  st.success(f'Image saved as {filename}')
133
 
134
- # Store the full path of the uploaded image
135
  st.session_state.filename = filename
136
-
137
- # Convert image to numpy array
138
  img_array = np.array(img)
 
 
 
139
 
140
- # Check if image shape is more than 650x650
141
- if img_array.shape[0] > 650 or img_array.shape[1] > 650:
142
- # Split image into patches
143
- split(filepath, patch_size=256)
144
-
145
- # Display buffer while analyzing
146
- with st.spinner('Analyzing...'):
147
- # Predict on each patch
148
- for patch_filename in os.listdir(patches_folder):
149
- if patch_filename.endswith(".png"):
150
- patch_path = os.path.join(patches_folder, patch_filename)
151
- patch_img = Image.open(patch_path)
152
- patch_tr_img = transforms(patch_img)
153
- prediction = predict(patch_tr_img)
154
- mask = (prediction > 0.5).astype(np.uint8) * 255
155
- mask_filename = f"mask_{patch_filename}"
156
- mask_filepath = os.path.join(pred_patches, mask_filename)
157
- Image.fromarray(mask).save(mask_filepath)
158
-
159
- # Merge predicted patches
160
- merged_mask_filename = f"data/generated_masks/mask_{timestamp}.png"
161
- merge(pred_patches, merged_mask_filename, img_array.shape)
162
-
163
- # Save merged mask
164
- st.session_state.mask_filename = merged_mask_filename
165
-
166
- # Clean up temporary patch files
167
- st.info('Cleaning up temporary files...')
168
- shutil.rmtree(patches_folder)
169
- shutil.rmtree(pred_patches)
170
- os.makedirs(patches_folder) # Recreate empty folders
171
- os.makedirs(pred_patches)
172
- st.success('Temporary files cleaned up')
173
- else:
174
- # Predict on whole image
175
- st.session_state.tr_img = transforms(img)
176
- prediction = predict(st.session_state.tr_img)
177
- mask = (prediction > 0.5).astype(np.uint8) * 255
178
- mask_filename = f"mask_{timestamp}.png"
179
- mask_filepath = os.path.join(MASK_DIR, mask_filename)
180
- Image.fromarray(mask).save(mask_filepath)
181
- st.session_state.mask_filename = mask_filepath
182
-
183
  st.session_state.file_uploaded = True
184
 
185
  if st.session_state.file_uploaded and st.button('View result'):
@@ -202,16 +156,15 @@ def result_page():
202
 
203
  col1, col2 = st.columns(2)
204
 
205
- # Display original image
206
  original_img_path = os.path.join(UPLOAD_DIR, st.session_state.filename)
 
 
207
  if os.path.exists(original_img_path):
208
  original_img = Image.open(original_img_path)
209
  col1.image(original_img, caption='Original Image', use_column_width=True)
210
  else:
211
  col1.error(f"Original image file not found: {original_img_path}")
212
 
213
- # Display predicted mask
214
- mask_path = st.session_state.mask_filename
215
  if os.path.exists(mask_path):
216
  mask = Image.open(mask_path)
217
  col2.image(mask, caption='Predicted Mask', use_column_width=True)
@@ -220,19 +173,15 @@ def result_page():
220
 
221
  st.subheader("Overlay with Area of Buildings (sqft)")
222
 
223
- # Display overlayed image
224
  if os.path.exists(original_img_path) and os.path.exists(mask_path):
225
  original_np = cv2.imread(original_img_path)
226
  mask_np = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
227
 
228
- # Ensure mask is binary
229
  _, mask_np = cv2.threshold(mask_np, 127, 255, cv2.THRESH_BINARY)
230
 
231
- # Resize mask to match original image size if necessary
232
  if original_np.shape[:2] != mask_np.shape[:2]:
233
  mask_np = cv2.resize(mask_np, (original_np.shape[1], original_np.shape[0]))
234
 
235
- # Process and overlay image
236
  overlay_img = process_and_overlay_image(original_np, mask_np, 'output.png')
237
 
238
  st.image(overlay_img, caption='Overlay Image', use_column_width=True)
@@ -255,4 +204,4 @@ def main():
255
  result_page()
256
 
257
  if __name__ == '__main__':
258
- main()
 
19
  from Utils.convert_raster import convert_gtiff_to_8bit
20
  import shutil
21
 
22
+ # Define directories
 
 
 
 
 
23
  UPLOAD_DIR = "data/uploaded_images"
24
  MASK_DIR = "data/generated_masks"
25
+ PATCHES_DIR = 'data/Patches'
26
+ PRED_PATCHES_DIR = 'data/Patch_pred'
27
  CSV_LOG_PATH = "image_log.csv"
28
 
29
+ # Create directories
30
+ for directory in [UPLOAD_DIR, MASK_DIR, PATCHES_DIR, PRED_PATCHES_DIR]:
31
+ os.makedirs(directory, exist_ok=True)
32
 
33
+ # Load model
34
  model = reunet_cbam()
35
  model.load_state_dict(torch.load('latest.pth', map_location='cpu')['model_state_dict'])
36
  model.eval()
 
42
 
43
  def log_image_details(image_id, image_filename, mask_filename):
44
  file_exists = os.path.exists(CSV_LOG_PATH)
 
45
  current_time = datetime.now()
46
  date = current_time.strftime('%Y-%m-%d')
47
  time = current_time.strftime('%H:%M:%S')
 
51
  if not file_exists:
52
  writer.writerow(['S.No', 'Date', 'Time', 'Image ID', 'Image Filename', 'Mask Filename'])
53
 
54
+ sno = sum(1 for row in open(CSV_LOG_PATH)) if file_exists else 1
 
 
 
 
 
 
 
55
  writer.writerow([sno, date, time, image_id, image_filename, mask_filename])
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  def reset_state():
58
  st.session_state.file_uploaded = False
59
  st.session_state.filename = None
 
62
  if 'page' in st.session_state:
63
  del st.session_state.page
64
 
65
+ def process_image(image, timestamp):
66
+ filename = f"image_{timestamp}{os.path.splitext(image.name)[1]}"
67
+ filepath = os.path.join(UPLOAD_DIR, filename)
68
+
69
+ with open(filepath, "wb") as f:
70
+ f.write(image.getvalue())
71
+
72
+ if filename.lower().endswith(('.tiff', '.tif')):
73
+ st.info('Processing GeoTIFF image...')
74
+ convert_gtiff_to_8bit(filepath)
75
+ st.success('GeoTIFF converted to 8-bit image')
76
+
77
+ return filename, filepath
78
+
79
+ def predict_image(img_array, filename, timestamp):
80
+ if img_array.shape[0] > 650 or img_array.shape[1] > 650:
81
+ split(os.path.join(UPLOAD_DIR, filename), patch_size=256)
82
+
83
+ with st.spinner('Analyzing...'):
84
+ for patch_filename in os.listdir(PATCHES_DIR):
85
+ if patch_filename.endswith(".png"):
86
+ patch_path = os.path.join(PATCHES_DIR, patch_filename)
87
+ patch_img = Image.open(patch_path)
88
+ patch_tr_img = transforms(patch_img)
89
+ prediction = predict(patch_tr_img)
90
+ mask = (prediction > 0.5).astype(np.uint8) * 255
91
+ mask_filename = f"mask_{patch_filename}"
92
+ mask_filepath = os.path.join(PRED_PATCHES_DIR, mask_filename)
93
+ Image.fromarray(mask).save(mask_filepath)
94
+
95
+ merged_mask_filename = f"mask_{timestamp}.png"
96
+ merged_mask_filepath = os.path.join(MASK_DIR, merged_mask_filename)
97
+ merge(PRED_PATCHES_DIR, merged_mask_filepath, img_array.shape)
98
+
99
+ st.info('Cleaning up temporary files...')
100
+ for dir in [PATCHES_DIR, PRED_PATCHES_DIR]:
101
+ shutil.rmtree(dir)
102
+ os.makedirs(dir)
103
+ st.success('Temporary files cleaned up')
104
+ else:
105
+ tr_img = transforms(Image.open(os.path.join(UPLOAD_DIR, filename)))
106
+ prediction = predict(tr_img)
107
+ mask = (prediction > 0.5).astype(np.uint8) * 255
108
+ merged_mask_filename = f"mask_{timestamp}.png"
109
+ merged_mask_filepath = os.path.join(MASK_DIR, merged_mask_filename)
110
+ Image.fromarray(mask).save(merged_mask_filepath)
111
+
112
+ return merged_mask_filepath
113
+
114
  def upload_page():
115
  if 'file_uploaded' not in st.session_state:
116
  st.session_state.file_uploaded = False
117
 
 
 
 
 
 
 
118
  image = st.file_uploader('Choose a satellite image', type=['jpg', 'png', 'jpeg', 'tiff', 'tif'])
119
 
120
  if image is not None:
121
+ reset_state()
 
 
122
  timestamp = int(time.time())
123
+ filename, filepath = process_image(image, timestamp)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
 
125
  img = Image.open(filepath)
126
  st.image(img, caption='Uploaded Image', use_column_width=True)
127
  st.success(f'Image saved as {filename}')
128
 
 
129
  st.session_state.filename = filename
 
 
130
  img_array = np.array(img)
131
+
132
+ mask_filepath = predict_image(img_array, filename, timestamp)
133
+ st.session_state.mask_filename = mask_filepath
134
 
135
+ log_image_details(timestamp, filename, os.path.basename(mask_filepath))
136
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  st.session_state.file_uploaded = True
138
 
139
  if st.session_state.file_uploaded and st.button('View result'):
 
156
 
157
  col1, col2 = st.columns(2)
158
 
 
159
  original_img_path = os.path.join(UPLOAD_DIR, st.session_state.filename)
160
+ mask_path = st.session_state.mask_filename
161
+
162
  if os.path.exists(original_img_path):
163
  original_img = Image.open(original_img_path)
164
  col1.image(original_img, caption='Original Image', use_column_width=True)
165
  else:
166
  col1.error(f"Original image file not found: {original_img_path}")
167
 
 
 
168
  if os.path.exists(mask_path):
169
  mask = Image.open(mask_path)
170
  col2.image(mask, caption='Predicted Mask', use_column_width=True)
 
173
 
174
  st.subheader("Overlay with Area of Buildings (sqft)")
175
 
 
176
  if os.path.exists(original_img_path) and os.path.exists(mask_path):
177
  original_np = cv2.imread(original_img_path)
178
  mask_np = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
179
 
 
180
  _, mask_np = cv2.threshold(mask_np, 127, 255, cv2.THRESH_BINARY)
181
 
 
182
  if original_np.shape[:2] != mask_np.shape[:2]:
183
  mask_np = cv2.resize(mask_np, (original_np.shape[1], original_np.shape[0]))
184
 
 
185
  overlay_img = process_and_overlay_image(original_np, mask_np, 'output.png')
186
 
187
  st.image(overlay_img, caption='Overlay Image', use_column_width=True)
 
204
  result_page()
205
 
206
  if __name__ == '__main__':
207
+ main()