File size: 9,164 Bytes
35d85a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
import streamlit as st
import sys

import torch
from model.CBAM.reunet_cbam import reunet_cbam
import cv2
from PIL import Image 
from model.transform import transforms
import numpy as np
from model.unet import UNET
from area import pixel_to_sqft, process_and_overlay_image
import matplotlib.pyplot as plt
import time
import os
import csv 
from datetime import datetime
from split_merge import split, merge 
from convert_raster import convert_gtiff_to_8bit
import shutil

# Scratch folders used while tiling large images: input tiles and their predicted masks.
patches_folder = 'data/Patches'
pred_patches = 'data/Patch_pred'

# Persistent storage for uploaded images and generated masks, plus the CSV log file.
UPLOAD_DIR = "data/uploaded_images"
MASK_DIR = "data/generated_masks"
CSV_LOG_PATH = "image_log.csv"

# Make sure every working directory exists before any page touches it.
for _required_dir in (patches_folder, pred_patches, UPLOAD_DIR, MASK_DIR):
    os.makedirs(_required_dir, exist_ok=True)
del _required_dir

# Build the network once at import time and load the CPU-mapped weights stored
# under 'model_state_dict' in the 'latest.pth' checkpoint.
model = reunet_cbam()
_checkpoint = torch.load('latest.pth', map_location='cpu')
model.load_state_dict(_checkpoint['model_state_dict'])
del _checkpoint
model.eval()

def predict(image):
    """Run the module-level ``model`` on a single image tensor.

    The input is given a leading batch dimension with ``unsqueeze(0)`` and the
    forward pass runs with gradients disabled. Every size-1 dimension is then
    squeezed out of the output, which is returned as a NumPy array on the CPU.
    """
    # NOTE(review): callers threshold the result at 0.5, so the model presumably
    # emits probabilities rather than logits -- confirm against reunet_cbam.
    with torch.no_grad():
        batched = image.unsqueeze(0)
        raw_output = model(batched)
    squeezed = raw_output.squeeze()
    return squeezed.cpu().numpy()

def log_image_details(image_id, image_filename, mask_filename, log_path=None):
    """Append one row describing a processed image to the CSV log.

    Columns are: S.No, Date (YYYY-MM-DD), Time (HH:MM:SS), Image ID,
    Image Filename, Mask Filename. A header row is written first whenever the
    log file is missing or empty.

    Args:
        image_id: Identifier stored in the "Image ID" column.
        image_filename: Value stored in the "Image Filename" column.
        mask_filename: Value stored in the "Mask Filename" column.
        log_path: CSV file to append to; defaults to ``CSV_LOG_PATH``.
    """
    path = CSV_LOG_PATH if log_path is None else log_path

    # A missing *or empty* file has no header yet. Checking existence alone
    # wrote a headerless row numbered 0 for an empty file.
    needs_header = not os.path.exists(path) or os.path.getsize(path) == 0

    if needs_header:
        sno = 1
    else:
        # The file holds one header row plus one row per earlier entry, so the
        # total row count is exactly the next serial number. Count before
        # opening for append instead of reading the file while it is open.
        with open(path, mode='r', newline='') as existing:
            sno = sum(1 for _ in csv.reader(existing))

    now = datetime.now()
    date_str = now.strftime('%Y-%m-%d')
    time_str = now.strftime('%H:%M:%S')  # not named `time`: avoid shadowing the module

    with open(path, mode='a', newline='') as log_file:
        writer = csv.writer(log_file)
        if needs_header:
            writer.writerow(['S.No', 'Date', 'Time', 'Image ID', 'Image Filename', 'Mask Filename'])
        writer.writerow([sno, date_str, time_str, image_id, image_filename, mask_filename])

def overlay_mask(image, mask, alpha=0.5, rgb=(255, 0, 0)):
    """Tint the pixels selected by ``mask`` on top of ``image``.

    Args:
        image: 2-D (grayscale) or 3-channel image array. Grayscale input is
            expanded to 3 channels with ``cv2.COLOR_GRAY2RGB`` first.
        mask: Array whose first two dimensions match the image's; any non-zero
            entry marks a pixel to tint.
        alpha: Weight applied to the colour layer before it is added.
        rgb: Colour written into the masked pixels. This is an immutable tuple;
            it used to be a shared mutable list default.

    Returns:
        ``cv2.addWeighted(image, 1, color_mask, alpha, 0)``, i.e. the image plus
        ``alpha`` times the colour layer.

    Raises:
        ValueError: if the mask's height/width differ from the image's.
    """
    # NOTE(review): if the caller passes a BGR image (e.g. from cv2.imread), the
    # default colour lands in the blue channel -- confirm the expected order.
    if len(image.shape) == 2:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)

    # Any non-zero mask value selects a pixel; only height/width must match.
    mask = mask.astype(bool)
    if mask.shape[:2] != image.shape[:2]:
        raise ValueError("Mask and image must have the same dimensions")

    # Colour layer: ``rgb`` inside the mask, zero elsewhere, same dtype as image.
    color_mask = np.zeros_like(image)
    color_mask[mask] = rgb

    # OpenCV saturates the weighted sum for integer dtypes.
    return cv2.addWeighted(image, 1, color_mask, alpha, 0)

def reset_state():
    """Forget the current upload and drop the active page selection.

    ``file_uploaded`` becomes False; ``filename``, ``mask_filename`` and
    ``tr_img`` become None. ``page`` is deleted if present, so ``main`` falls
    back to the upload page on the next run.
    """
    state = st.session_state
    state.file_uploaded = False
    for cleared_key in ('filename', 'mask_filename', 'tr_img'):
        setattr(state, cleared_key, None)
    if 'page' in state:
        del state.page

def _upload_key(name, data):
    """Identify an upload by its original name plus a SHA-256 digest of its bytes."""
    import hashlib  # local import keeps this block self-contained

    return name, hashlib.sha256(data).hexdigest()


def _clear_dir(path):
    """Delete ``path`` if present and recreate it as an empty directory."""
    shutil.rmtree(path, ignore_errors=True)
    os.makedirs(path, exist_ok=True)


def _mask_from_prediction(prediction):
    """Binarise a model output: values above 0.5 become 255, everything else 0 (uint8)."""
    return (prediction > 0.5).astype(np.uint8) * 255


def _predict_tiled(filepath, image_shape, timestamp):
    """Tile the image at ``filepath``, predict each tile and merge the tile masks.

    Returns the path of the merged mask, which is written inside ``MASK_DIR``.
    """
    # Start from empty scratch folders so tiles left behind by an earlier,
    # interrupted run can never be merged into this image's mask.
    _clear_dir(patches_folder)
    _clear_dir(pred_patches)
    try:
        # NOTE(review): assumes split() writes its .png tiles into patches_folder
        # and merge() writes the stitched mask to the given path -- confirm in split_merge.
        split(filepath, patch_size=256)
        for patch_filename in os.listdir(patches_folder):
            if not patch_filename.endswith(".png"):
                continue
            patch_img = Image.open(os.path.join(patches_folder, patch_filename))
            mask = _mask_from_prediction(predict(transforms(patch_img)))
            Image.fromarray(mask).save(os.path.join(pred_patches, f"mask_{patch_filename}"))

        # Store the merged mask under MASK_DIR like the whole-image path does;
        # the former bare "generated_masks/" prefix pointed at a directory this
        # module never creates.
        merged_mask_path = os.path.join(MASK_DIR, f"mask_{timestamp}.png")
        merge(pred_patches, merged_mask_path, image_shape)
    finally:
        # Remove the tiles even if splitting, prediction or merging failed.
        st.info('Cleaning up temporary files...')
        _clear_dir(patches_folder)
        _clear_dir(pred_patches)
    st.success('Temporary files cleaned up')
    return merged_mask_path


def _process_upload(original_filename, bytes_data):
    """Save an uploaded image, segment it and record the results in session state."""
    timestamp = int(time.time())
    file_extension = os.path.splitext(original_filename)[1].lower()
    is_tiff = file_extension in ('.tiff', '.tif')

    filename = f"image_{timestamp}.tif" if is_tiff else f"image_{timestamp}.png"
    filepath = os.path.join(UPLOAD_DIR, filename)
    with open(filepath, "wb") as f:
        f.write(bytes_data)

    if is_tiff:
        # NOTE(review): assumes convert_gtiff_to_8bit rewrites filepath in place.
        st.info('Processing GeoTIFF image...')
        convert_gtiff_to_8bit(filepath)
        st.success('GeoTIFF converted to 8-bit image')

    img = Image.open(filepath)
    st.image(img, caption='Uploaded Image', use_column_width=True)
    st.success(f'Image saved as {filename}')

    # result_page resolves this name relative to UPLOAD_DIR.
    st.session_state.filename = filename

    img_array = np.array(img)
    with st.spinner('Analyzing...'):
        # Images larger than 650 px on either side are segmented tile by tile.
        if img_array.shape[0] > 650 or img_array.shape[1] > 650:
            st.session_state.mask_filename = _predict_tiled(filepath, img_array.shape, timestamp)
        else:
            st.session_state.tr_img = transforms(img)
            mask = _mask_from_prediction(predict(st.session_state.tr_img))
            mask_filepath = os.path.join(MASK_DIR, f"mask_{timestamp}.png")
            Image.fromarray(mask).save(mask_filepath)
            st.session_state.mask_filename = mask_filepath

    st.session_state.file_uploaded = True


def upload_page():
    """Render the upload page: accept an image, segment it once, then offer the result.

    Streamlit re-executes the script on every widget interaction, and
    ``st.file_uploader`` keeps returning the same file on each rerun. The
    expensive work (saving the file, running the model) is therefore done only
    when a new upload appears. Later reruns, such as the one triggered by
    "View result", reuse the paths already stored in ``st.session_state``.
    """
    for key, default in (('file_uploaded', False), ('filename', None), ('mask_filename', None)):
        if key not in st.session_state:
            st.session_state[key] = default

    image = st.file_uploader('Choose a satellite image', type=['jpg', 'png', 'jpeg', 'tiff', 'tif'])

    if image is not None:
        bytes_data = image.getvalue()
        upload_key = _upload_key(image.name, bytes_data)
        already_processed = (
            st.session_state.get('processed_upload') == upload_key
            and st.session_state.file_uploaded
            and st.session_state.filename is not None
        )
        if already_processed:
            # Same file as the previous run: just show the saved copy again.
            saved_path = os.path.join(UPLOAD_DIR, st.session_state.filename)
            if os.path.exists(saved_path):
                st.image(Image.open(saved_path), caption='Uploaded Image', use_column_width=True)
        else:
            reset_state()  # new upload: discard results of any previous image
            _process_upload(image.name, bytes_data)
            # Recorded only after processing succeeds, so a failure is retried.
            st.session_state.processed_upload = upload_key

    if st.session_state.file_uploaded and st.button('View result'):
        if st.session_state.filename is None:
            st.error("Please upload an image before viewing the result.")
        else:
            st.success('Image analyzed')
            st.session_state.page = 'result'
            st.rerun()

def result_page():
    """Render the analysis result: original image, predicted mask and area overlay.

    Reads ``filename`` (resolved relative to ``UPLOAD_DIR``) and
    ``mask_filename`` (a path) from ``st.session_state``. If either is missing
    or None, an error and a "Back to Upload" button are shown instead.
    """
    st.title('Analysis Result')

    filename = st.session_state.get('filename')
    mask_path = st.session_state.get('mask_filename')
    # reset_state() stores None rather than deleting these keys, so checking
    # key membership is not enough: os.path.join / os.path.exists raise
    # TypeError on None.
    if filename is None or mask_path is None:
        st.error("No image or mask file found. Please upload and process an image first.")
        if st.button('Back to Upload'):
            reset_state()
            st.rerun()
        return

    col1, col2 = st.columns(2)

    # Left column: original image.
    original_img_path = os.path.join(UPLOAD_DIR, filename)
    original_exists = os.path.exists(original_img_path)
    if original_exists:
        col1.image(Image.open(original_img_path), caption='Original Image', use_column_width=True)
    else:
        col1.error(f"Original image file not found: {original_img_path}")

    # Right column: predicted mask.
    mask_exists = os.path.exists(mask_path)
    if mask_exists:
        col2.image(Image.open(mask_path), caption='Predicted Mask', use_column_width=True)
    else:
        col2.error(f"Predicted mask file not found: {mask_path}")

    st.subheader("Overlay with Area of Buildings (sqft)")

    if original_exists and mask_exists:
        original_np = cv2.imread(original_img_path)
        mask_np = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if original_np is None or mask_np is None:
            # cv2.imread returns None instead of raising when it cannot decode a file.
            st.error("Could not decode the image or mask for overlay.")
        else:
            if original_np.shape[:2] != mask_np.shape[:2]:
                # Nearest-neighbour keeps the mask strictly two-valued; the default
                # bilinear interpolation would blur building edges into grey values.
                mask_np = cv2.resize(
                    mask_np,
                    (original_np.shape[1], original_np.shape[0]),
                    interpolation=cv2.INTER_NEAREST,
                )
            # Threshold after resizing so the mask handed on is binary (0/255).
            _, mask_np = cv2.threshold(mask_np, 127, 255, cv2.THRESH_BINARY)

            overlay_img = process_and_overlay_image(original_np, mask_np, 'output.png')
            # NOTE(review): original_np is BGR (cv2.imread) while st.image assumes
            # RGB by default -- confirm process_and_overlay_image converts it.
            st.image(overlay_img, caption='Overlay Image', use_column_width=True)
    else:
        st.error("Image or mask file not found for overlay.")

    if st.button('Back to Upload'):
        reset_state()
        st.rerun()

def main():
    """Draw the app title and render whichever page is currently active.

    ``st.session_state.page`` selects the page. It is initialised to
    ``'upload'``, and any value other than ``'upload'`` or ``'result'`` renders
    nothing below the title.
    """
    st.title('Building area estimation')

    if 'page' not in st.session_state:
        st.session_state.page = 'upload'

    page_renderers = {'upload': upload_page, 'result': result_page}
    render = page_renderers.get(st.session_state.page)
    if render is not None:
        render()


if __name__ == '__main__':
    main()