# Building_area / app.py
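# Streamlit app for building-footprint segmentation and area estimation on
# satellite images, deployed as a Hugging Face Space. Uploaded images, the
# generated masks and a CSV log are mirrored back to the Space repository
# via the Hugging Face Hub API.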
import streamlit as st
import sys
import os
import shutil
import time
from datetime import datetime
import csv
import cv2
import numpy as np
from PIL import Image
import torch
from huggingface_hub import HfApi
# Adjust import paths as needed
sys.path.append('Utils')
sys.path.append('model')
from model.CBAM.reunet_cbam import reunet_cbam
from model.transform import transforms
from model.unet import UNET
from Utils.area import pixel_to_sqft, process_and_overlay_image
from split_merge import split, merge
from Utils.convert import read_pansharpened_rgb
# Initialize Hugging Face API
hf_api = HfApi()
# Get the token from secrets
HF_TOKEN = st.secrets.get("HF_TOKEN")
if not HF_TOKEN:
st.error("HF_TOKEN not found in secrets. Please set it in your Space's Configuration > Secrets.")
st.stop()
# Your Space ID (this should match exactly with your Hugging Face Space URL)
REPO_ID = "Pavan2k4/Building_area"
REPO_TYPE = "space"
# Define base directory for Hugging Face Spaces
BASE_DIR = "/home/user"
# Define subdirectories
UPLOAD_DIR = os.path.join(BASE_DIR, "uploaded_images")
MASK_DIR = os.path.join(BASE_DIR, "generated_masks")
PATCHES_DIR = os.path.join(BASE_DIR, "patches")
PRED_PATCHES_DIR = os.path.join(BASE_DIR, "pred_patches")
CSV_LOG_PATH = os.path.join(BASE_DIR, "image_log.csv")
# Create directories
for directory in [UPLOAD_DIR, MASK_DIR, PATCHES_DIR, PRED_PATCHES_DIR]:
    os.makedirs(directory, exist_ok=True)
# Load model
@st.cache_resource
def load_model():
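    """Load the ReUNet-CBAM segmentation model with its trained weights.

    Cached with st.cache_resource so the checkpoint ('latest.pth') is read
    only once per Streamlit session.
    """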
    model = reunet_cbam()
    model.load_state_dict(torch.load('latest.pth', map_location='cpu')['model_state_dict'])
    model.eval()
    return model
model = load_model()
def predict(image):
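    """Run inference on a single transformed image tensor.

    Adds a batch dimension, runs the model without tracking gradients and
    returns the prediction as a NumPy array of per-pixel scores, which the
    callers threshold at 0.5 to obtain a binary building mask.
    """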
    with torch.no_grad():
        output = model(image.unsqueeze(0))
    return output.squeeze().cpu().numpy()
def save_to_hf_repo(local_path, repo_path):
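    """Upload a local file to this Space's repository via the Hub API.

    Errors are surfaced in the Streamlit UI rather than raised, so a failed
    upload does not interrupt the analysis flow.
    """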
    try:
        hf_api.upload_file(
            path_or_fileobj=local_path,
            path_in_repo=repo_path,
            repo_id=REPO_ID,
            repo_type=REPO_TYPE,
            token=HF_TOKEN
        )
        st.success(f"File uploaded successfully to {repo_path}")
    except Exception as e:
        st.error(f"Error uploading file: {str(e)}")
        st.error("Detailed error information:")
        st.exception(e)
def log_image_details(image_id, image_filename, mask_filename):
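    """Append a row for the processed image to the CSV log and sync the log to the Space repo.

    Columns: S.No, Date, Time, Image ID, Image Filename, Mask Filename.
    """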
    file_exists = os.path.exists(CSV_LOG_PATH)
    current_time = datetime.now()
    date_str = current_time.strftime('%Y-%m-%d')
    time_str = current_time.strftime('%H:%M:%S')  # avoid shadowing the time module

    # Get the next S.No: counting existing rows includes the header, so a file
    # with N data rows yields a count of N + 1, which is the next serial number.
    if file_exists:
        with open(CSV_LOG_PATH, mode='r') as f:
            sno = sum(1 for _ in csv.reader(f))
    else:
        sno = 1

    with open(CSV_LOG_PATH, mode='a', newline='') as file:
        writer = csv.writer(file)
        if not file_exists:
            writer.writerow(['S.No', 'Date', 'Time', 'Image ID', 'Image Filename', 'Mask Filename'])
        writer.writerow([sno, date_str, time_str, image_id, image_filename, mask_filename])

    # Save CSV to Hugging Face repo
    save_to_hf_repo(CSV_LOG_PATH, 'image_log.csv')
def upload_page():
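    """Upload page: accept a satellite image, run segmentation and store the results.

    GeoTIFF inputs are first converted to an 8-bit RGB PNG. Images larger than
    650 px on either side are split into 512 px patches, predicted patch by
    patch and merged back into one mask; smaller images are predicted in a
    single pass. The image, mask and CSV log are mirrored to the Space repo.
    """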
    if 'file_uploaded' not in st.session_state:
        st.session_state.file_uploaded = False
    if 'filename' not in st.session_state:
        st.session_state.filename = None
    if 'mask_filename' not in st.session_state:
        st.session_state.mask_filename = None

    image = st.file_uploader('Choose a satellite image', type=['jpg', 'png', 'jpeg', 'tiff', 'tif'])

    if image is not None and not st.session_state.file_uploaded:
        try:
            bytes_data = image.getvalue()
            timestamp = int(time.time())
            original_filename = image.name
            file_extension = os.path.splitext(original_filename)[1].lower()

            if file_extension in ['.tiff', '.tif']:
                filename = f"image_{timestamp}.tif"
                converted_filename = f"image_{timestamp}_converted.png"
            else:
                filename = f"image_{timestamp}.png"
                converted_filename = filename

            filepath = os.path.join(UPLOAD_DIR, filename)
            converted_filepath = os.path.join(UPLOAD_DIR, converted_filename)

            with open(filepath, "wb") as f:
                f.write(bytes_data)
            st.success(f"Image saved to {filepath}")

            # Save image to Hugging Face repo
            save_to_hf_repo(filepath, f'uploaded_images/{filename}')

            # Check if the uploaded file is a GeoTIFF
            if file_extension in ['.tiff', '.tif']:
                st.info('Processing GeoTIFF image...')
                rgb_image = read_pansharpened_rgb(filepath)
                cv2.imwrite(converted_filepath, cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR))
                st.success(f'GeoTIFF converted to 8-bit image and saved as {converted_filename}')
                img = Image.open(converted_filepath)
            else:
                img = Image.open(filepath)

            st.image(img, caption='Uploaded Image', use_column_width=True)
            st.success(f'Image processed and saved as {converted_filename}')

            # Store the filename of the converted image
            st.session_state.filename = converted_filename

            # Convert image to numpy array
            img_array = np.array(img)

            # Check if the image is larger than 650x650
            if img_array.shape[0] > 650 or img_array.shape[1] > 650:
                # Split image into patches
                split(converted_filepath, patch_size=512)

                # Display buffer while analyzing
                with st.spinner('Analyzing...'):
                    # Predict on each patch
                    for patch_filename in os.listdir(PATCHES_DIR):
                        if patch_filename.endswith(".png"):
                            patch_path = os.path.join(PATCHES_DIR, patch_filename)
                            patch_img = Image.open(patch_path)
                            patch_tr_img = transforms(patch_img)
                            prediction = predict(patch_tr_img)
                            mask = (prediction > 0.5).astype(np.uint8) * 255
                            mask_filename = f"mask_{patch_filename}"
                            mask_filepath = os.path.join(PRED_PATCHES_DIR, mask_filename)
                            Image.fromarray(mask).save(mask_filepath)

                # Merge predicted patches
                merged_mask_filename = f"mask_{timestamp}.png"
                merged_mask_path = os.path.join(MASK_DIR, merged_mask_filename)
                merge(PRED_PATCHES_DIR, merged_mask_path, img_array.shape)

                # Save merged mask
                st.session_state.mask_filename = merged_mask_filename

                # Clean up temporary patch files
                st.info('Cleaning up temporary files...')
                shutil.rmtree(PATCHES_DIR)
                shutil.rmtree(PRED_PATCHES_DIR)
                os.makedirs(PATCHES_DIR)  # Recreate empty folders
                os.makedirs(PRED_PATCHES_DIR)
                st.success('Temporary files cleaned up')
            else:
                # Predict on whole image
                st.session_state.tr_img = transforms(img)
                prediction = predict(st.session_state.tr_img)
                mask = (prediction > 0.5).astype(np.uint8) * 255
                mask_filename = f"mask_{timestamp}.png"
                mask_filepath = os.path.join(MASK_DIR, mask_filename)
                Image.fromarray(mask).save(mask_filepath)
                st.session_state.mask_filename = mask_filename

            # Save mask to Hugging Face repo
            mask_filepath = os.path.join(MASK_DIR, st.session_state.mask_filename)
            save_to_hf_repo(mask_filepath, f'generated_masks/{st.session_state.mask_filename}')

            # Log image details
            log_image_details(timestamp, converted_filename, st.session_state.mask_filename)

            st.session_state.file_uploaded = True
        except Exception as e:
            st.error(f"An error occurred: {str(e)}")
            st.error("Please check the logs for more details.")
            print(f"Error in upload_page: {str(e)}")  # This will appear in the Streamlit logs

    if st.session_state.file_uploaded and st.button('View result'):
        if st.session_state.filename is None:
            st.error("Please upload an image before viewing the result.")
        else:
            st.success('Image analyzed')
            st.session_state.page = 'result'
            st.rerun()
def result_page():
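    """Result page: show the original image, the predicted mask and an overlay
    annotated with the estimated building area in sqft (via process_and_overlay_image).
    """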
    st.title('Analysis Result')

    if 'filename' not in st.session_state or 'mask_filename' not in st.session_state:
        st.error("No image or mask file found. Please upload and process an image first.")
        if st.button('Back to Upload'):
            st.session_state.page = 'upload'
            st.session_state.file_uploaded = False
            st.session_state.filename = None
            st.session_state.mask_filename = None
            st.rerun()
        return

    col1, col2 = st.columns(2)

    # Display original image
    original_img_path = os.path.join(UPLOAD_DIR, st.session_state.filename)
    if os.path.exists(original_img_path):
        original_img = Image.open(original_img_path)
        col1.image(original_img, caption='Original Image', use_column_width=True)
    else:
        col1.error(f"Original image file not found: {original_img_path}")

    # Display predicted mask
    mask_path = os.path.join(MASK_DIR, st.session_state.mask_filename)
    if os.path.exists(mask_path):
        mask = Image.open(mask_path)
        col2.image(mask, caption='Predicted Mask', use_column_width=True)
    else:
        col2.error(f"Predicted mask file not found: {mask_path}")

    st.subheader("Overlay with Area of Buildings (sqft)")

    # Display overlaid image
    if os.path.exists(original_img_path) and os.path.exists(mask_path):
        original_np = cv2.imread(original_img_path)
        mask_np = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)

        # Ensure mask is binary
        _, mask_np = cv2.threshold(mask_np, 127, 255, cv2.THRESH_BINARY)

        # Resize mask to match original image size if necessary
        if original_np.shape[:2] != mask_np.shape[:2]:
            mask_np = cv2.resize(mask_np, (original_np.shape[1], original_np.shape[0]))

        # Process and overlay image
        overlay_img = process_and_overlay_image(original_np, mask_np, 'output.png')
        st.image(overlay_img, caption='Overlay Image', use_column_width=True)
    else:
        st.error("Image or mask file not found for overlay.")

    if st.button('Back to Upload'):
        # Clear and recreate the patch folders so the next upload starts clean
        shutil.rmtree(PATCHES_DIR, ignore_errors=True)
        shutil.rmtree(PRED_PATCHES_DIR, ignore_errors=True)
        os.makedirs(PATCHES_DIR, exist_ok=True)
        os.makedirs(PRED_PATCHES_DIR, exist_ok=True)
        st.session_state.page = 'upload'
        st.session_state.file_uploaded = False
        st.session_state.filename = None
        st.session_state.mask_filename = None
        st.rerun()
def main():
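    """Simple two-page router driven by st.session_state.page ('upload' / 'result')."""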
    st.title('Building area estimation')

    if 'page' not in st.session_state:
        st.session_state.page = 'upload'

    if st.session_state.page == 'upload':
        upload_page()
    elif st.session_state.page == 'result':
        result_page()
if __name__ == '__main__':
    main()