Spaces:
Sleeping
Sleeping
File size: 9,164 Bytes
35d85a5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 |
import streamlit as st
import sys
import torch
from model.CBAM.reunet_cbam import reunet_cbam
import cv2
from PIL import Image
from model.transform import transforms
import numpy as np
from model.unet import UNET
from area import pixel_to_sqft, process_and_overlay_image
import matplotlib.pyplot as plt
import time
import os
import csv
from datetime import datetime
from split_merge import split, merge
from convert_raster import convert_gtiff_to_8bit
import shutil
# Scratch folders for large uploads: the patch loop in upload_page reads image
# patches from `patches_folder` and writes one mask per patch to `pred_patches`.
patches_folder = 'data/Patches'
pred_patches = 'data/Patch_pred'

# Persistent locations for uploaded images, generated masks, and the CSV log.
UPLOAD_DIR = "data/uploaded_images"
MASK_DIR = "data/generated_masks"
CSV_LOG_PATH = "image_log.csv"

# Make sure every working directory exists before any request is handled.
for _work_dir in (patches_folder, pred_patches, UPLOAD_DIR, MASK_DIR):
    os.makedirs(_work_dir, exist_ok=True)
del _work_dir

# Build the segmentation network once at import time and load its weights on
# CPU; the checkpoint is indexed by 'model_state_dict' to get the weights.
model = reunet_cbam()
model.load_state_dict(
    torch.load('latest.pth', map_location='cpu')['model_state_dict']
)
model.eval()  # switch to evaluation mode for inference
def predict(image):
    """Run the module-level ``model`` on a single input and return its output as NumPy.

    A leading batch dimension of size 1 is added before inference, and every
    size-1 dimension is squeezed out of the result. Gradients are not tracked.

    Args:
        image: torch tensor for one image (presumably produced by
            ``transforms``; its shape is not checked here).

    Returns:
        numpy.ndarray: the squeezed model output, moved to the CPU.
    """
    with torch.no_grad():
        batched_output = model(image.unsqueeze(0))
        result = batched_output.squeeze().cpu().numpy()
    return result
def log_image_details(image_id, image_filename, mask_filename):
    """Append one row describing a processed image to the CSV log at CSV_LOG_PATH.

    The log starts with a header row
    ``S.No, Date, Time, Image ID, Image Filename, Mask Filename``; the header is
    written whenever the file is missing or empty. ``S.No`` numbers data rows
    from 1.

    Args:
        image_id: identifier recorded in the "Image ID" column.
        image_filename: value recorded in the "Image Filename" column.
        mask_filename: value recorded in the "Mask Filename" column.
    """
    now = datetime.now()
    log_date = now.strftime('%Y-%m-%d')
    # Named log_time (not `time`) so the imported `time` module is not shadowed.
    log_time = now.strftime('%H:%M:%S')

    # Count existing rows *before* opening the file for appending, instead of
    # reading it while a writer on the same file is open. An existing but empty
    # file counts as zero rows, so it still receives a header.
    existing_rows = 0
    if os.path.exists(CSV_LOG_PATH):
        with open(CSV_LOG_PATH, mode='r', newline='') as f:
            existing_rows = sum(1 for _ in csv.reader(f))
    needs_header = existing_rows == 0
    # Row 1 of the file is the header, so the k-th data row is file row k + 1.
    sno = 1 if needs_header else existing_rows

    with open(CSV_LOG_PATH, mode='a', newline='') as file:
        writer = csv.writer(file)
        if needs_header:
            writer.writerow(['S.No', 'Date', 'Time', 'Image ID', 'Image Filename', 'Mask Filename'])
        writer.writerow([sno, log_date, log_time, image_id, image_filename, mask_filename])
def overlay_mask(image, mask, alpha=0.5, rgb=(255, 0, 0)):
    """Tint the pixels of ``image`` selected by ``mask`` with a solid colour.

    Args:
        image: image array. A 2-D (grayscale) array is first converted to
            3-channel RGB with ``cv2.cvtColor``.
        mask: array whose truthy entries mark the pixels to tint; its first two
            dimensions must equal the image's height and width.
        alpha: weight applied to the colour layer before it is added to the image.
        rgb: tint colour, one value per channel of the (converted) image. A tuple
            default avoids the shared-mutable-default pitfall of a list.

    Returns:
        The result of ``cv2.addWeighted(image, 1, colour_layer, alpha, 0)``,
        i.e. the image with ``alpha * rgb`` added on masked pixels.

    Raises:
        ValueError: if the mask's height/width differ from the image's.
    """
    if len(image.shape) == 2:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)

    mask = mask.astype(bool)
    if mask.shape[:2] != image.shape[:2]:
        raise ValueError("Mask and image must have the same dimensions")

    # Colour layer: zero everywhere except the masked pixels, which get `rgb`.
    # zeros_like keeps the image dtype, as cv2.addWeighted requires matching types.
    color_mask = np.zeros_like(image)
    color_mask[mask] = rgb

    # Weight 1 keeps the original image intact; only the tint is scaled by alpha.
    return cv2.addWeighted(image, 1, color_mask, alpha, 0)
def reset_state():
    """Forget the current upload and its results, and drop the stored page choice.

    Sets ``file_uploaded`` to False and ``filename``, ``mask_filename`` and
    ``tr_img`` to None in ``st.session_state``, then removes ``page`` if present.
    """
    fresh_values = {
        'file_uploaded': False,
        'filename': None,
        'mask_filename': None,
        'tr_img': None,
    }
    for key, value in fresh_values.items():
        st.session_state[key] = value
    # Deleting 'page' (instead of assigning it) lets main() fall back to 'upload'.
    if 'page' in st.session_state:
        del st.session_state['page']
# Uploads taller or wider than this many pixels are processed in patches.
_MAX_WHOLE_IMAGE_SIDE = 650
# patch_size passed to split() for large uploads.
_PATCH_SIZE = 256


def _reset_patch_dirs():
    """Empty the temporary patch folders by deleting and recreating them."""
    for folder in (patches_folder, pred_patches):
        shutil.rmtree(folder, ignore_errors=True)
        os.makedirs(folder, exist_ok=True)


def _binary_mask(prediction):
    """Threshold a prediction array at 0.5 into a uint8 mask holding 0 or 255."""
    return (prediction > 0.5).astype(np.uint8) * 255


def _save_upload(uploaded, timestamp):
    """Write an uploaded file's bytes into UPLOAD_DIR.

    Returns:
        (filename, filepath, is_tiff): the stored file name, its full path, and
        whether the upload had a .tif/.tiff extension.
    """
    extension = os.path.splitext(uploaded.name)[1].lower()
    is_tiff = extension in ('.tiff', '.tif')
    # Every non-TIFF upload (JPEG included) is stored under a .png name; the bytes
    # are written unchanged, so readers rely on content-based format detection.
    filename = f"image_{timestamp}.tif" if is_tiff else f"image_{timestamp}.png"
    filepath = os.path.join(UPLOAD_DIR, filename)
    with open(filepath, "wb") as f:
        f.write(uploaded.getvalue())
    return filename, filepath, is_tiff


def _predict_in_patches(filepath, image_shape, timestamp):
    """Split a large image into patches, predict each one, and merge the masks.

    Returns:
        str: path of the merged mask inside MASK_DIR.
    """
    # Start from empty folders: patches left over from a run that failed midway
    # would otherwise be predicted and merged into this image's mask.
    # NOTE(review): these folders are shared by every Streamlit session, so two
    # concurrent large uploads can still interfere with each other.
    _reset_patch_dirs()
    # Same directory as whole-image masks (previously a CWD-relative
    # "generated_masks/" path outside MASK_DIR).
    merged_mask_path = os.path.join(MASK_DIR, f"mask_{timestamp}.png")
    try:
        split(filepath, patch_size=_PATCH_SIZE)
        with st.spinner('Analyzing...'):
            for patch_filename in os.listdir(patches_folder):
                if not patch_filename.endswith(".png"):
                    continue
                patch_path = os.path.join(patches_folder, patch_filename)
                # Close each patch promptly; open file handles can prevent the
                # folder from being deleted during cleanup (e.g. on Windows).
                with Image.open(patch_path) as patch_img:
                    prediction = predict(transforms(patch_img))
                mask_path = os.path.join(pred_patches, f"mask_{patch_filename}")
                Image.fromarray(_binary_mask(prediction)).save(mask_path)
        merge(pred_patches, merged_mask_path, image_shape)
    finally:
        # Clean up even when prediction or merging raises.
        st.info('Cleaning up temporary files...')
        _reset_patch_dirs()
    st.success('Temporary files cleaned up')
    return merged_mask_path


def _predict_whole_image(img, timestamp):
    """Predict a mask for the whole image in one pass and save it into MASK_DIR.

    Also stores the transformed input in ``st.session_state.tr_img``.

    Returns:
        str: path of the saved mask.
    """
    st.session_state.tr_img = transforms(img)
    prediction = predict(st.session_state.tr_img)
    mask_filepath = os.path.join(MASK_DIR, f"mask_{timestamp}.png")
    Image.fromarray(_binary_mask(prediction)).save(mask_filepath)
    return mask_filepath


def upload_page():
    """Render the upload page: save an image, segment it, and offer the result view.

    The uploaded file is stored in UPLOAD_DIR (GeoTIFFs are converted to 8-bit
    first). Images larger than _MAX_WHOLE_IMAGE_SIDE in either dimension are
    predicted patch by patch and merged; smaller ones in a single pass. The
    resulting mask path is stored in ``st.session_state.mask_filename``, and a
    "View result" button switches ``st.session_state.page`` to 'result'.
    """
    defaults = {'file_uploaded': False, 'filename': None, 'mask_filename': None}
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value

    image = st.file_uploader('Choose a satellite image', type=['jpg', 'png', 'jpeg', 'tiff', 'tif'])
    if image is not None:
        # NOTE(review): Streamlit reruns this script on every interaction and the
        # uploader keeps returning the same file, so the image is re-saved and
        # re-analysed on each rerun (including the "View result" click).
        reset_state()  # Reset the state when a new image is uploaded
        timestamp = int(time.time())
        filename, filepath, is_tiff = _save_upload(image, timestamp)

        if is_tiff:
            st.info('Processing GeoTIFF image...')
            convert_gtiff_to_8bit(filepath)
            st.success('GeoTIFF converted to 8-bit image')

        img = Image.open(filepath)
        st.image(img, caption='Uploaded Image', use_column_width=True)
        st.success(f'Image saved as {filename}')
        # result_page joins this bare name with UPLOAD_DIR.
        st.session_state.filename = filename

        img_array = np.array(img)
        height, width = img_array.shape[:2]
        if height > _MAX_WHOLE_IMAGE_SIDE or width > _MAX_WHOLE_IMAGE_SIDE:
            st.session_state.mask_filename = _predict_in_patches(filepath, img_array.shape, timestamp)
        else:
            st.session_state.mask_filename = _predict_whole_image(img, timestamp)
        st.session_state.file_uploaded = True

    if st.session_state.file_uploaded and st.button('View result'):
        if st.session_state.filename is None:
            st.error("Please upload an image before viewing the result.")
        else:
            st.success('Image analyzed')
            st.session_state.page = 'result'
            st.rerun()
def _build_overlay(original_img_path, mask_path):
    """Load image and mask with OpenCV and return the area overlay.

    Returns:
        The value of ``process_and_overlay_image`` (which also writes
        'output.png'), or None if OpenCV cannot decode either file.
    """
    original_np = cv2.imread(original_img_path)
    mask_np = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    # cv2.imread signals an unreadable file by returning None, not by raising.
    if original_np is None or mask_np is None:
        return None
    # Resize to the original image size if needed, then threshold: interpolating
    # a 0/255 mask creates intermediate grey values along edges, so the binary
    # threshold must come last to guarantee a strictly 0/255 mask.
    if original_np.shape[:2] != mask_np.shape[:2]:
        mask_np = cv2.resize(mask_np, (original_np.shape[1], original_np.shape[0]))
    _, mask_np = cv2.threshold(mask_np, 127, 255, cv2.THRESH_BINARY)
    # NOTE(review): cv2.imread yields BGR channel order while st.image assumes RGB
    # by default; confirm whether process_and_overlay_image converts.
    return process_and_overlay_image(original_np, mask_np, 'output.png')


def result_page():
    """Render the result page: original image, predicted mask, and area overlay.

    Reads ``filename`` (a name inside UPLOAD_DIR) and ``mask_filename`` (a path)
    from ``st.session_state``. If either is missing or None, shows an error and a
    button to return to the upload page.
    """
    st.title('Analysis Result')

    filename = st.session_state.get('filename')
    mask_path = st.session_state.get('mask_filename')
    # reset_state() leaves these keys present but set to None, so check the
    # values; os.path.join/os.path.exists would raise TypeError on None.
    if not filename or not mask_path:
        st.error("No image or mask file found. Please upload and process an image first.")
        if st.button('Back to Upload'):
            reset_state()
            st.rerun()
        return

    col1, col2 = st.columns(2)

    original_img_path = os.path.join(UPLOAD_DIR, filename)
    original_exists = os.path.exists(original_img_path)
    if original_exists:
        col1.image(Image.open(original_img_path), caption='Original Image', use_column_width=True)
    else:
        col1.error(f"Original image file not found: {original_img_path}")

    mask_exists = os.path.exists(mask_path)
    if mask_exists:
        col2.image(Image.open(mask_path), caption='Predicted Mask', use_column_width=True)
    else:
        col2.error(f"Predicted mask file not found: {mask_path}")

    st.subheader("Overlay with Area of Buildings (sqft)")
    if not (original_exists and mask_exists):
        st.error("Image or mask file not found for overlay.")
    else:
        overlay_img = _build_overlay(original_img_path, mask_path)
        if overlay_img is None:
            st.error("Image or mask file could not be decoded for overlay.")
        else:
            st.image(overlay_img, caption='Overlay Image', use_column_width=True)

    if st.button('Back to Upload'):
        reset_state()
        st.rerun()
def main():
    """App entry point: draw the title, then render the page chosen in session state.

    A first visit (no 'page' key yet) lands on the upload page. Unknown page
    values render nothing below the title.
    """
    st.title('Building area estimation')
    if 'page' not in st.session_state:
        st.session_state.page = 'upload'
    renderers = {'upload': upload_page, 'result': result_page}
    render = renderers.get(st.session_state.page)
    if render is not None:
        render()


if __name__ == '__main__':
    main()
|