File size: 7,464 Bytes
35d85a5
 
ebff05b
 
35d85a5
 
 
 
 
 
 
ebff05b
35d85a5
 
 
 
 
ebff05b
e6126b5
35d85a5
 
359a4fd
2cc088e
 
 
 
35d85a5
 
359a4fd
 
 
35d85a5
359a4fd
35d85a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
359a4fd
35d85a5
 
 
 
 
 
 
 
 
 
359a4fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35d85a5
 
 
 
 
 
 
359a4fd
35d85a5
359a4fd
35d85a5
 
 
 
 
 
 
359a4fd
 
 
35d85a5
359a4fd
 
35d85a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
359a4fd
 
35d85a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
359a4fd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
import streamlit as st
import sys
sys.path.append('Utils')
sys.path.append('model')
import torch
from model.CBAM.reunet_cbam import reunet_cbam
import cv2
from PIL import Image 
from model.transform import transforms
import numpy as np
from model.unet import UNET
from Utils.area import pixel_to_sqft, process_and_overlay_image
import matplotlib.pyplot as plt
import time
import os
import csv 
from datetime import datetime
from Utils.split_merge import split, merge 
from Utils.convert import convert_gtiff_to_8bit
import shutil

# Filesystem locations used by the app; all paths are relative to the
# current working directory.
UPLOAD_DIR = "data/uploaded_images/"     # saved uploads
MASK_DIR = "data/generated_masks/"       # final predicted masks
PATCHES_DIR = 'data/Patches/'            # scratch tiles of large images
PRED_PATCHES_DIR = 'data/Patch_pred/'    # scratch predicted tile masks
CSV_LOG_PATH = "image_log.csv"           # append-mode log of processed uploads

# Ensure every working directory exists before any page is rendered.
for _working_dir in (UPLOAD_DIR, MASK_DIR, PATCHES_DIR, PRED_PATCHES_DIR):
    os.makedirs(_working_dir, exist_ok=True)

# Load model.
# Streamlit re-executes this whole script on every widget interaction, so the
# checkpoint is loaded through st.cache_resource: it is read from disk once per
# server process and the same network object is reused on later reruns.
@st.cache_resource
def _load_model(checkpoint_path='latest.pth'):
    """Build reunet_cbam, load its weights from *checkpoint_path*, and set eval mode.

    The checkpoint is loaded onto the CPU and must contain a
    ``'model_state_dict'`` entry.
    """
    net = reunet_cbam()
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    net.load_state_dict(checkpoint['model_state_dict'])
    net.eval()
    return net


# Module-level name kept so predict() can keep referring to `model`.
model = _load_model()

def predict(image):
    """Run the module-level ``model`` on one unbatched input tensor.

    A leading batch dimension is added for the forward pass, which runs with
    autograd disabled. Every size-1 dimension of the output is then squeezed
    away, and the result is returned as a NumPy array on the CPU.
    """
    with torch.no_grad():
        batched_output = model(image.unsqueeze(0))
    squeezed = batched_output.squeeze()
    return squeezed.cpu().numpy()

def log_image_details(image_id, image_filename, mask_filename):
    """Append one row describing a processed upload to the CSV log.

    The log at ``CSV_LOG_PATH`` gets a header row when it is first created,
    and also when the file exists but is empty. Each data row is
    ``[S.No, Date, Time, Image ID, Image Filename, Mask Filename]``. ``S.No``
    numbers the data rows starting at 1. Date and Time come from the local
    clock when the row is written.

    Args:
        image_id: value for the "Image ID" column.
        image_filename: value for the "Image Filename" column.
        mask_filename: value for the "Mask Filename" column.
    """
    now = datetime.now()
    log_date = now.strftime('%Y-%m-%d')
    log_time = now.strftime('%H:%M:%S')  # not named `time`: avoid shadowing the module

    # Count existing CSV records (header included) with a reader that is
    # closed before the file is reopened for appending.
    existing_rows = 0
    if os.path.exists(CSV_LOG_PATH):
        with open(CSV_LOG_PATH, newline='') as log_file:
            existing_rows = sum(1 for _ in csv.reader(log_file))

    with open(CSV_LOG_PATH, mode='a', newline='') as log_file:
        writer = csv.writer(log_file)
        if existing_rows == 0:
            writer.writerow(['S.No', 'Date', 'Time', 'Image ID', 'Image Filename', 'Mask Filename'])
            sno = 1
        else:
            # header + k data rows already present -> this row is number k + 1
            sno = existing_rows
        writer.writerow([sno, log_date, log_time, image_id, image_filename, mask_filename])

def reset_state():
    """Clear the per-upload session values and forget the selected page.

    ``file_uploaded`` is reset to False and ``filename``, ``mask_filename``
    and ``tr_img`` to None. The ``page`` key is removed if present.
    """
    cleared_values = (
        ('file_uploaded', False),
        ('filename', None),
        ('mask_filename', None),
        ('tr_img', None),
    )
    for state_key, initial_value in cleared_values:
        st.session_state[state_key] = initial_value

    if 'page' in st.session_state:
        del st.session_state['page']

def process_image(image, timestamp):
    """Save an uploaded file under a timestamped name, converting GeoTIFFs.

    Args:
        image: object returned by ``st.file_uploader``. Its ``name`` supplies
            the extension and ``getvalue()`` supplies the bytes to write.
        timestamp: value embedded in the stored name ``image_<timestamp><ext>``.

    Returns:
        ``(filename, filepath)`` of the saved file inside ``UPLOAD_DIR``.
    """
    _, extension = os.path.splitext(image.name)
    filename = f"image_{timestamp}{extension}"
    filepath = os.path.join(UPLOAD_DIR, filename)

    with open(filepath, "wb") as out_file:
        out_file.write(image.getvalue())

    is_geotiff = filename.lower().endswith(('.tiff', '.tif'))
    if is_geotiff:
        st.info('Processing GeoTIFF image...')
        # NOTE(review): callers reopen `filepath` afterwards, so this presumably
        # rewrites the file in place as 8-bit — confirm in Utils.convert.
        convert_gtiff_to_8bit(filepath)
        st.success('GeoTIFF converted to 8-bit image')

    return filename, filepath

def _reset_dirs(*directories):
    """Delete each directory (if it exists) and recreate it empty."""
    for directory in directories:
        if os.path.isdir(directory):
            shutil.rmtree(directory)
        os.makedirs(directory, exist_ok=True)


def predict_image(img_array, filename, timestamp, *, max_side=650, patch_size=256, threshold=0.5):
    """Predict a building mask for an uploaded image and save it in ``MASK_DIR``.

    Images with height or width above ``max_side`` are tiled with ``split``,
    which is expected to write ``.png`` tiles into ``PATCHES_DIR``. Each tile is
    predicted and its mask saved to ``PRED_PATCHES_DIR``. ``merge`` then
    assembles the tile masks into the final mask. Smaller images are predicted
    in a single pass. Output values above ``threshold`` become 255, all others 0.

    Args:
        img_array: NumPy array of the uploaded image; only its ``shape`` is used.
        filename: name of the upload inside ``UPLOAD_DIR``.
        timestamp: value embedded in the output name ``mask_<timestamp>.png``.
        max_side: largest height/width predicted without tiling.
        patch_size: tile size forwarded to ``split``.
        threshold: cut-off applied to the model output.

    Returns:
        Path of the saved mask inside ``MASK_DIR``.
    """
    image_path = os.path.join(UPLOAD_DIR, filename)
    merged_mask_filepath = os.path.join(MASK_DIR, f"mask_{timestamp}.png")

    if img_array.shape[0] <= max_side and img_array.shape[1] <= max_side:
        # Close the file as soon as the transform has consumed the pixels.
        with Image.open(image_path) as img:
            tr_img = transforms(img)
        mask = (predict(tr_img) > threshold).astype(np.uint8) * 255
        Image.fromarray(mask).save(merged_mask_filepath)
        return merged_mask_filepath

    # Start from empty scratch directories so that tiles left behind by an
    # earlier, interrupted run cannot be merged into this image's mask.
    _reset_dirs(PATCHES_DIR, PRED_PATCHES_DIR)
    split(image_path, patch_size=patch_size)

    try:
        with st.spinner('Analyzing...'):
            for patch_filename in os.listdir(PATCHES_DIR):
                if not patch_filename.endswith(".png"):
                    continue
                with Image.open(os.path.join(PATCHES_DIR, patch_filename)) as patch_img:
                    patch_tr_img = transforms(patch_img)
                mask = (predict(patch_tr_img) > threshold).astype(np.uint8) * 255
                mask_filepath = os.path.join(PRED_PATCHES_DIR, f"mask_{patch_filename}")
                Image.fromarray(mask).save(mask_filepath)

            merge(PRED_PATCHES_DIR, merged_mask_filepath, img_array.shape)
            st.info('Cleaning up temporary files...')
    finally:
        # Always remove the scratch tiles, even if prediction or merging failed.
        _reset_dirs(PATCHES_DIR, PRED_PATCHES_DIR)
    st.success('Temporary files cleaned up')

    return merged_mask_filepath

def upload_page():
    """Render the upload page: accept an image, analyze it once, offer the result.

    Streamlit re-executes the script on every widget interaction, and
    ``st.file_uploader`` keeps returning the same file on those reruns. Each
    upload is therefore keyed, and saving, prediction and CSV logging run only
    when a new upload arrives or the previous processing did not complete.
    Without this, clicking "View result" would repeat inference and append a
    duplicate log row.
    """
    if 'file_uploaded' not in st.session_state:
        st.session_state.file_uploaded = False

    image = st.file_uploader('Choose a satellite image', type=['jpg', 'png', 'jpeg', 'tiff', 'tif'])

    if image is not None:
        # file_id identifies one upload in recent Streamlit releases; fall back
        # to name + byte length where the attribute is unavailable.
        upload_key = getattr(image, 'file_id', None) or (image.name, len(image.getvalue()))
        already_processed = (
            st.session_state.file_uploaded
            and st.session_state.get('upload_key') == upload_key
        )

        if already_processed:
            # Rerun for an upload that is already analyzed: just redisplay it.
            filepath = os.path.join(UPLOAD_DIR, st.session_state.filename)
            st.image(Image.open(filepath), caption='Uploaded Image', use_column_width=True)
        else:
            reset_state()
            timestamp = int(time.time())
            filename, filepath = process_image(image, timestamp)

            img = Image.open(filepath)
            st.image(img, caption='Uploaded Image', use_column_width=True)
            st.success(f'Image saved as {filename}')

            st.session_state.filename = filename
            img_array = np.array(img)

            mask_filepath = predict_image(img_array, filename, timestamp)
            st.session_state.mask_filename = mask_filepath

            log_image_details(timestamp, filename, os.path.basename(mask_filepath))

            # Only mark the upload as done once every step above succeeded.
            st.session_state.file_uploaded = True
            st.session_state.upload_key = upload_key

    if st.session_state.file_uploaded and st.button('View result'):
        if st.session_state.filename is None:
            st.error("Please upload an image before viewing the result.")
        else:
            st.success('Image analyzed')
            st.session_state.page = 'result'
            st.rerun()

def result_page():
    """Show the uploaded image, its predicted mask, and the building-area overlay.

    Reads ``filename`` and ``mask_filename`` from session state. Missing or
    undecodable files are reported in the UI instead of raising. "Back to
    Upload" calls ``reset_state()`` and reruns the script.
    """
    st.title('Analysis Result')

    # reset_state() stores None rather than deleting these keys, so test the
    # values, not just the presence of the keys.
    filename = st.session_state.get('filename')
    mask_path = st.session_state.get('mask_filename')
    if filename is None or mask_path is None:
        st.error("No image or mask file found. Please upload and process an image first.")
        if st.button('Back to Upload'):
            reset_state()
            st.rerun()
        return

    col1, col2 = st.columns(2)

    original_img_path = os.path.join(UPLOAD_DIR, filename)

    if os.path.exists(original_img_path):
        original_img = Image.open(original_img_path)
        col1.image(original_img, caption='Original Image', use_column_width=True)
    else:
        col1.error(f"Original image file not found: {original_img_path}")

    if os.path.exists(mask_path):
        mask = Image.open(mask_path)
        col2.image(mask, caption='Predicted Mask', use_column_width=True)
    else:
        col2.error(f"Predicted mask file not found: {mask_path}")

    st.subheader("Overlay with Area of Buildings (sqft)")

    if os.path.exists(original_img_path) and os.path.exists(mask_path):
        original_np = cv2.imread(original_img_path)
        mask_np = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)

        if original_np is None or mask_np is None:
            # cv2.imread returns None instead of raising on unreadable files.
            st.error("Image or mask file could not be decoded for overlay.")
        else:
            if original_np.shape[:2] != mask_np.shape[:2]:
                # Nearest-neighbour interpolation avoids the blended edge values
                # that bilinear resizing would introduce into the mask.
                mask_np = cv2.resize(
                    mask_np,
                    (original_np.shape[1], original_np.shape[0]),
                    interpolation=cv2.INTER_NEAREST,
                )
            # Threshold after any resize so the mask passed on is strictly 0/255.
            _, mask_np = cv2.threshold(mask_np, 127, 255, cv2.THRESH_BINARY)

            # NOTE(review): cv2.imread yields BGR while st.image assumes RGB by
            # default — confirm which order process_and_overlay_image returns.
            overlay_img = process_and_overlay_image(original_np, mask_np, 'output.png')

            st.image(overlay_img, caption='Overlay Image', use_column_width=True)
    else:
        st.error("Image or mask file not found for overlay.")

    if st.button('Back to Upload'):
        reset_state()
        st.rerun()

def main():
    """Draw the app title and render whichever page session state selects.

    On the first run the page defaults to ``'upload'``. An unrecognised page
    value renders nothing below the title.
    """
    st.title('Building area estimation')

    if 'page' not in st.session_state:
        st.session_state['page'] = 'upload'

    page_renderers = {
        'upload': upload_page,
        'result': result_page,
    }
    render_page = page_renderers.get(st.session_state['page'])
    if render_page is not None:
        render_page()


if __name__ == '__main__':
    main()