"""Streamlit app for building area estimation from uploaded satellite images."""

import csv
import os
import shutil
import sys
import time
from datetime import datetime

import cv2
import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import torch
from PIL import Image

# The project packages are imported only after sys.path has been extended,
# matching the original import order.
sys.path.append('Utils')
sys.path.append('model')

from model.CBAM.reunet_cbam import reunet_cbam
from model.transform import transforms
from model.unet import UNET
from Utils.area import pixel_to_sqft, process_and_overlay_image
from Utils.split_merge import split, merge
from Utils.convert import read_pansharpened_rgb

# Define directories
UPLOAD_DIR = "uploaded_images/"
MASK_DIR = "generated_masks/"
patches_folder = 'data/Patches/'
pred_patches = 'data/Patch_pred/'
CSV_LOG_PATH = "image_log.csv"

# Create directories
for directory in [UPLOAD_DIR, MASK_DIR, patches_folder, pred_patches]:
    os.makedirs(directory, exist_ok=True)


@st.cache_resource
def load_model():
    """Build the network, load the weights from 'latest.pth' and switch to eval mode.

    Streamlit re-executes the whole script on every interaction; caching the
    model as a resource avoids re-reading the checkpoint each time.
    """
    net = reunet_cbam()
    checkpoint = torch.load('latest.pth', map_location='cpu')
    net.load_state_dict(checkpoint['model_state_dict'])
    net.eval()
    return net


# Load model
model = load_model()


def predict(image):
    """Run the model on one image tensor and return the result as a NumPy array.

    A batch dimension is added before the forward pass and all singleton
    dimensions are squeezed from the output. Gradients are disabled.
    """
    with torch.no_grad():
        output = model(image.unsqueeze(0))
    return output.squeeze().cpu().numpy()


def log_image_details(image_id, image_filename, mask_filename):
    """Append one row describing a processed image to the CSV log.

    The header row is written when the log is missing or empty. The serial
    number is 1 for the first entry; otherwise it equals the number of rows
    already in the file (the header counts as one row, so this is the number
    of existing data rows plus one).
    """
    now = datetime.now()
    date_str = now.strftime('%Y-%m-%d')
    # Named time_str so the imported `time` module is not shadowed.
    time_str = now.strftime('%H:%M:%S')

    # Count the existing rows before opening the file for appending, instead
    # of reading it while an append handle is already open.
    existing_rows = 0
    if os.path.exists(CSV_LOG_PATH):
        with open(CSV_LOG_PATH, mode='r', newline='') as f:
            existing_rows = sum(1 for _ in csv.reader(f))

    with open(CSV_LOG_PATH, mode='a', newline='') as file:
        writer = csv.writer(file)
        if existing_rows == 0:
            writer.writerow(['S.No', 'Date', 'Time', 'Image ID', 'Image Filename', 'Mask Filename'])
            sno = 1
        else:
            sno = existing_rows
        writer.writerow([sno, date_str, time_str, image_id, image_filename, mask_filename])


def overlay_mask(image, mask, alpha=0.5, rgb=(255, 0, 0)):
    """Blend a solid colour over the masked pixels of an image.

    Args:
        image: Image array; a 2-D (grayscale) array is converted to 3 channels.
        mask: Array whose non-zero entries mark the pixels to colour.
        alpha: Weight of the colour layer in the blend.
        rgb: Colour written into the overlay at masked pixels.

    Returns:
        The result of ``cv2.addWeighted(image, 1, color_mask, alpha, 0)``.

    Raises:
        ValueError: If the first two dimensions of mask and image differ.
    """
    # Ensure image is 3-channel
    if len(image.shape) == 2:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)

    # Ensure mask is binary and same shape as image
    mask = mask.astype(bool)
    if mask.shape[:2] != image.shape[:2]:
        raise ValueError("Mask and image must have the same dimensions")

    # Create color overlay
    color_mask = np.zeros_like(image)
    color_mask[mask] = rgb

    # Blend the image and color mask
    return cv2.addWeighted(image, 1, color_mask, alpha, 0)


def _reset_patch_dirs():
    """Delete the temporary patch folders and recreate them empty."""
    for folder in (patches_folder, pred_patches):
        shutil.rmtree(folder, ignore_errors=True)
        os.makedirs(folder, exist_ok=True)


def _reset_upload_state():
    """Return the session to the upload page with no image selected."""
    st.session_state.page = 'upload'
    st.session_state.file_uploaded = False
    st.session_state.filename = None
    st.session_state.mask_filename = None


def _binarize(prediction):
    """Threshold a prediction at 0.5 into a uint8 mask with values 0 and 255."""
    return (prediction > 0.5).astype(np.uint8) * 255


def _segment_in_patches(converted_filepath, timestamp, image_shape):
    """Split a large image into patches, predict each one and merge the masks.

    Returns the path of the merged mask file.
    """
    # Start from empty folders so leftovers of an interrupted run are not
    # predicted and merged into this image's mask.
    _reset_patch_dirs()

    # Split image into patches
    split(converted_filepath, patch_size=512)

    # Display buffer while analyzing
    with st.spinner('Analyzing...'):
        # Predict on each patch
        for patch_filename in os.listdir(patches_folder):
            if not patch_filename.endswith(".png"):
                continue
            patch_path = os.path.join(patches_folder, patch_filename)
            # Close the patch file right after use: the folder is removed
            # below, which fails on some platforms while files are open.
            with Image.open(patch_path) as patch_img:
                patch_tr_img = transforms(patch_img)
            mask = _binarize(predict(patch_tr_img))
            mask_filepath = os.path.join(pred_patches, f"mask_{patch_filename}")
            Image.fromarray(mask).save(mask_filepath)

    # Merge predicted patches
    merged_mask_filename = os.path.join(MASK_DIR, f"mask_{timestamp}.png")
    merge(pred_patches, merged_mask_filename, image_shape)

    # Clean up temporary patch files
    st.info('Cleaning up temporary files...')
    _reset_patch_dirs()
    st.success('Temporary files cleaned up')

    return merged_mask_filename


def _segment_whole(img, timestamp):
    """Predict on the whole image, save the mask and return the mask path."""
    st.session_state.tr_img = transforms(img)
    mask = _binarize(predict(st.session_state.tr_img))
    mask_filepath = os.path.join(MASK_DIR, f"mask_{timestamp}.png")
    Image.fromarray(mask).save(mask_filepath)
    return mask_filepath


def upload_page():
    """Render the upload page: store the uploaded image and generate its mask.

    GeoTIFF uploads are converted to PNG first. Images with a side longer than
    650 pixels are processed patch by patch; smaller images are predicted in
    one pass. The image filename and mask path are kept in the session state.
    """
    if 'file_uploaded' not in st.session_state:
        st.session_state.file_uploaded = False
    if 'filename' not in st.session_state:
        st.session_state.filename = None
    if 'mask_filename' not in st.session_state:
        st.session_state.mask_filename = None

    image = st.file_uploader('Choose a satellite image', type=['jpg', 'png', 'jpeg', 'tiff', 'tif'])

    if image is not None and not st.session_state.file_uploaded:
        bytes_data = image.getvalue()
        timestamp = int(time.time())
        original_filename = image.name
        file_extension = os.path.splitext(original_filename)[1].lower()
        is_geotiff = file_extension in ['.tiff', '.tif']

        if is_geotiff:
            filename = f"image_{timestamp}.tif"
            converted_filename = f"image_{timestamp}_converted.png"
        else:
            # NOTE(review): the raw upload bytes are stored under a .png name
            # even for JPEG uploads; readers here rely on content sniffing.
            filename = f"image_{timestamp}.png"
            converted_filename = filename

        filepath = os.path.join(UPLOAD_DIR, filename)
        converted_filepath = os.path.join(UPLOAD_DIR, converted_filename)

        with open(filepath, "wb") as f:
            f.write(bytes_data)

        # Check if the uploaded file is a GeoTIFF
        if is_geotiff:
            st.info('Processing GeoTIFF image...')
            rgb_image = read_pansharpened_rgb(filepath)
            cv2.imwrite(converted_filepath, cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR))
            st.success(f'GeoTIFF converted to 8-bit image and saved as {converted_filename}')
            img = Image.open(converted_filepath)
        else:
            img = Image.open(filepath)

        st.image(img, caption='Uploaded Image', use_column_width=True)
        st.success(f'Image saved as {converted_filename}')

        # Store the name of the converted image (relative to UPLOAD_DIR)
        st.session_state.filename = converted_filename

        # Convert image to numpy array
        img_array = np.array(img)

        # Check if image shape is more than 650x650
        if img_array.shape[0] > 650 or img_array.shape[1] > 650:
            st.session_state.mask_filename = _segment_in_patches(
                converted_filepath, timestamp, img_array.shape
            )
        else:
            # Predict on whole image
            st.session_state.mask_filename = _segment_whole(img, timestamp)

        st.session_state.file_uploaded = True

    if st.session_state.file_uploaded and st.button('View result'):
        if st.session_state.filename is None:
            st.error("Please upload an image before viewing the result.")
        else:
            st.success('Image analyzed')
            st.session_state.page = 'result'
            st.rerun()


def result_page():
    """Render the result page: original image, predicted mask and overlay."""
    st.title('Analysis Result')

    # upload_page initialises both keys to None, so test the values rather
    # than key presence; a None filename would crash os.path.join below.
    if st.session_state.get('filename') is None or st.session_state.get('mask_filename') is None:
        st.error("No image or mask file found. Please upload and process an image first.")
        if st.button('Back to Upload'):
            _reset_upload_state()
            st.rerun()
        return

    col1, col2 = st.columns(2)

    # Display original image
    original_img_path = os.path.join(UPLOAD_DIR, st.session_state.filename)
    original_exists = os.path.exists(original_img_path)
    if original_exists:
        original_img = Image.open(original_img_path)
        col1.image(original_img, caption='Original Image', use_column_width=True)
    else:
        col1.error(f"Original image file not found: {original_img_path}")

    # Display predicted mask
    mask_path = st.session_state.mask_filename
    mask_exists = os.path.exists(mask_path)
    if mask_exists:
        mask = Image.open(mask_path)
        col2.image(mask, caption='Predicted Mask', use_column_width=True)
    else:
        col2.error(f"Predicted mask file not found: {mask_path}")

    st.subheader("Overlay with Area of Buildings (sqft)")

    # Display overlayed image
    if original_exists and mask_exists:
        original_np = cv2.imread(original_img_path)
        mask_np = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)

        # Ensure mask is binary
        _, mask_np = cv2.threshold(mask_np, 127, 255, cv2.THRESH_BINARY)

        # Resize mask to match original image size if necessary
        if original_np.shape[:2] != mask_np.shape[:2]:
            mask_np = cv2.resize(mask_np, (original_np.shape[1], original_np.shape[0]))

        # Process and overlay image
        overlay_img = process_and_overlay_image(original_np, mask_np, 'output.png')

        # NOTE(review): original_np is BGR (cv2.imread) and the result is shown
        # without a BGR->RGB conversion; whether process_and_overlay_image
        # already returns RGB is not visible here - confirm.
        st.image(overlay_img, caption='Overlay Image', use_column_width=True)
    else:
        st.error("Image or mask file not found for overlay.")

    if st.button('Back to Upload'):
        # Empty the patch folders but keep them in place for the next upload.
        _reset_patch_dirs()
        _reset_upload_state()
        st.rerun()


def main():
    """Show the page selected in the session state (upload by default)."""
    st.title('Building area estimation')

    if 'page' not in st.session_state:
        st.session_state.page = 'upload'

    if st.session_state.page == 'upload':
        upload_page()
    elif st.session_state.page == 'result':
        result_page()


if __name__ == '__main__':
    main()