Pavan2k4 commited on
Commit
8058707
·
verified ·
1 Parent(s): ee47532

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +128 -70
app.py CHANGED
@@ -42,6 +42,7 @@ def predict(image):
42
 
43
  def log_image_details(image_id, image_filename, mask_filename):
44
  file_exists = os.path.exists(CSV_LOG_PATH)
 
45
  current_time = datetime.now()
46
  date = current_time.strftime('%Y-%m-%d')
47
  time = current_time.strftime('%H:%M:%S')
@@ -51,89 +52,131 @@ def log_image_details(image_id, image_filename, mask_filename):
51
  if not file_exists:
52
  writer.writerow(['S.No', 'Date', 'Time', 'Image ID', 'Image Filename', 'Mask Filename'])
53
 
54
- sno = sum(1 for row in open(CSV_LOG_PATH)) if file_exists else 1
 
 
 
 
 
 
 
55
  writer.writerow([sno, date, time, image_id, image_filename, mask_filename])
56
 
57
- def reset_state():
58
- st.session_state.file_uploaded = False
59
- st.session_state.filename = None
60
- st.session_state.mask_filename = None
61
- st.session_state.tr_img = None
62
- if 'page' in st.session_state:
63
- del st.session_state.page
64
-
65
- def process_image(image, timestamp):
66
- filename = f"image_{timestamp}{os.path.splitext(image.name)[1]}"
67
- filepath = os.path.join(UPLOAD_DIR, filename)
68
 
69
- with open(filepath, "wb") as f:
70
- f.write(image.getvalue())
 
 
71
 
72
- if filename.lower().endswith(('.tiff', '.tif')):
73
- st.info('Processing GeoTIFF image...')
74
- convert_gtiff_to_8bit(filepath)
75
- st.success('GeoTIFF converted to 8-bit image')
76
 
77
- return filename, filepath
78
-
79
- def predict_image(img_array, filename, timestamp):
80
- if img_array.shape[0] > 650 or img_array.shape[1] > 650:
81
- split(os.path.join(UPLOAD_DIR, filename), patch_size=256)
82
-
83
- with st.spinner('Analyzing...'):
84
- for patch_filename in os.listdir(PATCHES_DIR):
85
- if patch_filename.endswith(".png"):
86
- patch_path = os.path.join(PATCHES_DIR, patch_filename)
87
- patch_img = Image.open(patch_path)
88
- patch_tr_img = transforms(patch_img)
89
- prediction = predict(patch_tr_img)
90
- mask = (prediction > 0.5).astype(np.uint8) * 255
91
- mask_filename = f"mask_{patch_filename}"
92
- mask_filepath = os.path.join(PRED_PATCHES_DIR, mask_filename)
93
- Image.fromarray(mask).save(mask_filepath)
94
-
95
- merged_mask_filename = f"mask_{timestamp}.png"
96
- merged_mask_filepath = os.path.join(MASK_DIR, merged_mask_filename)
97
- merge(PRED_PATCHES_DIR, merged_mask_filepath, img_array.shape)
98
-
99
- st.info('Cleaning up temporary files...')
100
- for dir in [PATCHES_DIR, PRED_PATCHES_DIR]:
101
- shutil.rmtree(dir)
102
- os.makedirs(dir)
103
- st.success('Temporary files cleaned up')
104
- else:
105
- tr_img = transforms(Image.open(os.path.join(UPLOAD_DIR, filename)))
106
- prediction = predict(tr_img)
107
- mask = (prediction > 0.5).astype(np.uint8) * 255
108
- merged_mask_filename = f"mask_{timestamp}.png"
109
- merged_mask_filepath = os.path.join(MASK_DIR, merged_mask_filename)
110
- Image.fromarray(mask).save(merged_mask_filepath)
111
 
112
- return merged_mask_filepath
 
 
113
 
114
  def upload_page():
115
  if 'file_uploaded' not in st.session_state:
116
  st.session_state.file_uploaded = False
117
 
 
 
 
 
 
 
118
  image = st.file_uploader('Choose a satellite image', type=['jpg', 'png', 'jpeg', 'tiff', 'tif'])
119
 
120
- if image is not None:
121
- reset_state()
 
122
  timestamp = int(time.time())
123
- filename, filepath = process_image(image, timestamp)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
 
125
- img = Image.open(filepath)
126
  st.image(img, caption='Uploaded Image', use_column_width=True)
127
- st.success(f'Image saved as {filename}')
128
 
129
- st.session_state.filename = filename
 
 
 
130
  img_array = np.array(img)
131
-
132
- mask_filepath = predict_image(img_array, filename, timestamp)
133
- st.session_state.mask_filename = mask_filepath
134
 
135
- log_image_details(timestamp, filename, os.path.basename(mask_filepath))
136
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  st.session_state.file_uploaded = True
138
 
139
  if st.session_state.file_uploaded and st.button('View result'):
@@ -143,28 +186,31 @@ def upload_page():
143
  st.success('Image analyzed')
144
  st.session_state.page = 'result'
145
  st.rerun()
146
-
147
  def result_page():
148
  st.title('Analysis Result')
149
 
150
  if 'filename' not in st.session_state or 'mask_filename' not in st.session_state:
151
  st.error("No image or mask file found. Please upload and process an image first.")
152
  if st.button('Back to Upload'):
153
- reset_state()
 
 
 
154
  st.rerun()
155
  return
156
 
157
  col1, col2 = st.columns(2)
158
 
 
159
  original_img_path = os.path.join(UPLOAD_DIR, st.session_state.filename)
160
- mask_path = st.session_state.mask_filename
161
-
162
  if os.path.exists(original_img_path):
163
  original_img = Image.open(original_img_path)
164
  col1.image(original_img, caption='Original Image', use_column_width=True)
165
  else:
166
  col1.error(f"Original image file not found: {original_img_path}")
167
 
 
 
168
  if os.path.exists(mask_path):
169
  mask = Image.open(mask_path)
170
  col2.image(mask, caption='Predicted Mask', use_column_width=True)
@@ -173,23 +219,35 @@ def result_page():
173
 
174
  st.subheader("Overlay with Area of Buildings (sqft)")
175
 
 
176
  if os.path.exists(original_img_path) and os.path.exists(mask_path):
177
  original_np = cv2.imread(original_img_path)
178
  mask_np = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
179
 
 
180
  _, mask_np = cv2.threshold(mask_np, 127, 255, cv2.THRESH_BINARY)
181
 
 
182
  if original_np.shape[:2] != mask_np.shape[:2]:
183
  mask_np = cv2.resize(mask_np, (original_np.shape[1], original_np.shape[0]))
184
 
 
185
  overlay_img = process_and_overlay_image(original_np, mask_np, 'output.png')
186
 
 
 
 
187
  st.image(overlay_img, caption='Overlay Image', use_column_width=True)
188
  else:
189
  st.error("Image or mask file not found for overlay.")
190
 
191
  if st.button('Back to Upload'):
192
- reset_state()
 
 
 
 
 
193
  st.rerun()
194
 
195
  def main():
@@ -204,4 +262,4 @@ def main():
204
  result_page()
205
 
206
  if __name__ == '__main__':
207
- main()
 
42
 
43
  def log_image_details(image_id, image_filename, mask_filename):
44
  file_exists = os.path.exists(CSV_LOG_PATH)
45
+
46
  current_time = datetime.now()
47
  date = current_time.strftime('%Y-%m-%d')
48
  time = current_time.strftime('%H:%M:%S')
 
52
  if not file_exists:
53
  writer.writerow(['S.No', 'Date', 'Time', 'Image ID', 'Image Filename', 'Mask Filename'])
54
 
55
+ # Get the next S.No
56
+ if file_exists:
57
+ with open(CSV_LOG_PATH, mode='r') as f:
58
+ reader = csv.reader(f)
59
+ sno = sum(1 for row in reader)
60
+ else:
61
+ sno = 1
62
+
63
  writer.writerow([sno, date, time, image_id, image_filename, mask_filename])
64
 
65
+ def overlay_mask(image, mask, alpha=0.5, rgb=[255, 0, 0]):
66
+ # Ensure image is 3-channel
67
+ if len(image.shape) == 2:
68
+ image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
 
 
 
 
 
 
 
69
 
70
+ # Ensure mask is binary and same shape as image
71
+ mask = mask.astype(bool)
72
+ if mask.shape[:2] != image.shape[:2]:
73
+ raise ValueError("Mask and image must have the same dimensions")
74
 
75
+ # Create color overlay
76
+ color_mask = np.zeros_like(image)
77
+ color_mask[mask] = rgb
 
78
 
79
+ # Blend the image and color mask
80
+ output = cv2.addWeighted(image, 1, color_mask, alpha, 0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
 
82
+ return output
83
+
84
+ import shutil # Add this import at the top of your file
85
 
86
  def upload_page():
87
  if 'file_uploaded' not in st.session_state:
88
  st.session_state.file_uploaded = False
89
 
90
+ if 'filename' not in st.session_state:
91
+ st.session_state.filename = None
92
+
93
+ if 'mask_filename' not in st.session_state:
94
+ st.session_state.mask_filename = None
95
+
96
  image = st.file_uploader('Choose a satellite image', type=['jpg', 'png', 'jpeg', 'tiff', 'tif'])
97
 
98
+ if image is not None and not st.session_state.file_uploaded:
99
+ bytes_data = image.getvalue()
100
+
101
  timestamp = int(time.time())
102
+ original_filename = image.name
103
+ file_extension = os.path.splitext(original_filename)[1].lower()
104
+
105
+ if file_extension in ['.tiff', '.tif']:
106
+ filename = f"image_{timestamp}.tif"
107
+ converted_filename = f"image_{timestamp}_converted.png"
108
+ else:
109
+ filename = f"image_{timestamp}.png"
110
+ converted_filename = filename
111
+
112
+ filepath = os.path.join(UPLOAD_DIR, filename)
113
+ converted_filepath = os.path.join(UPLOAD_DIR, converted_filename)
114
+
115
+ with open(filepath, "wb") as f:
116
+ f.write(bytes_data)
117
+
118
+ # Check if the uploaded file is a GeoTIFF
119
+ if file_extension in ['.tiff', '.tif']:
120
+ st.info('Processing GeoTIFF image...')
121
+ rgb_image = read_pansharpened_rgb(filepath)
122
+ cv2.imwrite(converted_filepath, cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR))
123
+ st.success(f'GeoTIFF converted to 8-bit image and saved as {converted_filename}')
124
+ img = Image.open(converted_filepath)
125
+ else:
126
+ img = Image.open(filepath)
127
 
 
128
  st.image(img, caption='Uploaded Image', use_column_width=True)
129
+ st.success(f'Image saved as {converted_filename}')
130
 
131
+ # Store the full path of the converted image
132
+ st.session_state.filename = converted_filename
133
+
134
+ # Convert image to numpy array
135
  img_array = np.array(img)
 
 
 
136
 
137
+ # Check if image shape is more than 650x650
138
+ if img_array.shape[0] > 650 or img_array.shape[1] > 650:
139
+ # Split image into patches
140
+ split(converted_filepath, patch_size=512)
141
+
142
+ # Display buffer while analyzing
143
+ with st.spinner('Analyzing...'):
144
+ # Predict on each patch
145
+ for patch_filename in os.listdir(patches_folder):
146
+ if patch_filename.endswith(".png"):
147
+ patch_path = os.path.join(patches_folder, patch_filename)
148
+ patch_img = Image.open(patch_path)
149
+ patch_tr_img = transforms(patch_img)
150
+ prediction = predict(patch_tr_img)
151
+ mask = (prediction > 0.5).astype(np.uint8) * 255
152
+ mask_filename = f"mask_{patch_filename}"
153
+ mask_filepath = os.path.join(pred_patches, mask_filename)
154
+ Image.fromarray(mask).save(mask_filepath)
155
+
156
+ # Merge predicted patches
157
+ merged_mask_filename = f"generated_masks/mask_{timestamp}.png"
158
+ merge(pred_patches, merged_mask_filename, img_array.shape)
159
+
160
+ # Save merged mask
161
+ st.session_state.mask_filename = merged_mask_filename
162
+
163
+ # Clean up temporary patch files
164
+ st.info('Cleaning up temporary files...')
165
+ shutil.rmtree(patches_folder)
166
+ shutil.rmtree(pred_patches)
167
+ os.makedirs(patches_folder) # Recreate empty folders
168
+ os.makedirs(pred_patches)
169
+ st.success('Temporary files cleaned up')
170
+ else:
171
+ # Predict on whole image
172
+ st.session_state.tr_img = transforms(img)
173
+ prediction = predict(st.session_state.tr_img)
174
+ mask = (prediction > 0.5).astype(np.uint8) * 255
175
+ mask_filename = f"mask_{timestamp}.png"
176
+ mask_filepath = os.path.join(MASK_DIR, mask_filename)
177
+ Image.fromarray(mask).save(mask_filepath)
178
+ st.session_state.mask_filename = mask_filepath
179
+
180
  st.session_state.file_uploaded = True
181
 
182
  if st.session_state.file_uploaded and st.button('View result'):
 
186
  st.success('Image analyzed')
187
  st.session_state.page = 'result'
188
  st.rerun()
 
189
  def result_page():
190
  st.title('Analysis Result')
191
 
192
  if 'filename' not in st.session_state or 'mask_filename' not in st.session_state:
193
  st.error("No image or mask file found. Please upload and process an image first.")
194
  if st.button('Back to Upload'):
195
+ st.session_state.page = 'upload'
196
+ st.session_state.file_uploaded = False
197
+ st.session_state.filename = None
198
+ st.session_state.mask_filename = None
199
  st.rerun()
200
  return
201
 
202
  col1, col2 = st.columns(2)
203
 
204
+ # Display original image
205
  original_img_path = os.path.join(UPLOAD_DIR, st.session_state.filename)
 
 
206
  if os.path.exists(original_img_path):
207
  original_img = Image.open(original_img_path)
208
  col1.image(original_img, caption='Original Image', use_column_width=True)
209
  else:
210
  col1.error(f"Original image file not found: {original_img_path}")
211
 
212
+ # Display predicted mask
213
+ mask_path = st.session_state.mask_filename
214
  if os.path.exists(mask_path):
215
  mask = Image.open(mask_path)
216
  col2.image(mask, caption='Predicted Mask', use_column_width=True)
 
219
 
220
  st.subheader("Overlay with Area of Buildings (sqft)")
221
 
222
+ # Display overlayed image
223
  if os.path.exists(original_img_path) and os.path.exists(mask_path):
224
  original_np = cv2.imread(original_img_path)
225
  mask_np = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
226
 
227
+ # Ensure mask is binary
228
  _, mask_np = cv2.threshold(mask_np, 127, 255, cv2.THRESH_BINARY)
229
 
230
+ # Resize mask to match original image size if necessary
231
  if original_np.shape[:2] != mask_np.shape[:2]:
232
  mask_np = cv2.resize(mask_np, (original_np.shape[1], original_np.shape[0]))
233
 
234
+ # Process and overlay image
235
  overlay_img = process_and_overlay_image(original_np, mask_np, 'output.png')
236
 
237
+ # Convert BGR to RGB for displaying with st.image
238
+ # overlay_rgb = cv2.cvtColor(overlay_img, cv2.COLOR_BGR2RGB)
239
+
240
  st.image(overlay_img, caption='Overlay Image', use_column_width=True)
241
  else:
242
  st.error("Image or mask file not found for overlay.")
243
 
244
  if st.button('Back to Upload'):
245
+ shutil.rmtree(patches_folder)
246
+ shutil.rmtree(pred_patches)
247
+ st.session_state.page = 'upload'
248
+ st.session_state.file_uploaded = False
249
+ st.session_state.filename = None
250
+ st.session_state.mask_filename = None
251
  st.rerun()
252
 
253
  def main():
 
262
  result_page()
263
 
264
  if __name__ == '__main__':
265
+ main()