Pavan2k4 committed on
Commit
35d85a5
·
1 Parent(s): 5f0c972
Utils/area.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+
4
def pixel_to_sqft(pixel_area, resolution_cm=30):
    """Convert an area measured in pixels to square feet.

    Args:
        pixel_area: Area in pixels.
        resolution_cm: Edge length of one pixel in centimetres.

    Returns:
        The area in square feet.
    """
    cm2_per_m2 = 10000.0
    sqft_per_m2 = 10.7639
    # pixels -> cm^2 -> m^2 -> ft^2, evaluated left to right.
    return pixel_area * (resolution_cm ** 2) / cm2_per_m2 * sqft_per_m2
10
+
11
def process_and_overlay_image(original_image, mask_prediction, output_image_path = None, resolution_cm=30):
    """Outline every mask region on the image and label it with its area in sq ft.

    The external contours of ``mask_prediction`` are drawn onto
    ``original_image`` and the rounded area of each contour (converted with
    ``pixel_to_sqft``) is written at the contour's centroid.

    Args:
        original_image: Image array to annotate. It is modified in place by
            the OpenCV drawing calls and the same object is returned.
        mask_prediction: Array exposing ``astype``; it is cast to ``uint8``
            and scaled by 255 before contour extraction.
            NOTE(review): assumes a single-channel 0/1 mask - confirm against
            the caller.
        output_image_path: Accepted for backward compatibility; currently
            unused (nothing is written to disk).
        resolution_cm: Pixel edge length in centimetres, forwarded to
            ``pixel_to_sqft``.

    Returns:
        ``original_image`` with contours and area labels drawn on it.
    """
    # Convert mask prediction to binary mask
    mask = mask_prediction.astype(np.uint8) * 255

    # Only the outermost contours are of interest (holes are ignored).
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    for contour in contours:
        area_sqft = pixel_to_sqft(cv2.contourArea(contour), resolution_cm)

        # The original passed int(0.5) == 0 as thickness; request the thinnest
        # outline explicitly instead.
        cv2.drawContours(original_image, [contour], -1, (0, 255, 0), 1)

        # Label position: centroid from image moments when the contour has area.
        moments = cv2.moments(contour)
        if moments["m00"] != 0:
            cX = int(moments["m10"] / moments["m00"])
            cY = int(moments["m01"] / moments["m00"])
        else:
            # Degenerate contour (single pixel or line): use the mean of its
            # points instead of (0, 0), which stacked stray labels in the
            # top-left corner of the image.
            mean_x, mean_y = contour.reshape(-1, 2).mean(axis=0)
            cX, cY = int(mean_x), int(mean_y)

        cv2.putText(original_image, f'{area_sqft:.0f}', (cX, cY),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 1)

    return original_image
51
+
Utils/convert_raster.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from osgeo import gdal
2
+ import numpy as np
3
+ import os
4
+ import tempfile
5
+ from PIL import Image
6
+
7
def convert_gtiff_to_8bit(src):
    """Rewrite the GeoTIFF at ``src`` in place as an 8-bit raster.

    Each band is linearly rescaled from its (approximate) min/max range to
    0-255. The geotransform and projection of the source are copied to the
    rewritten file.

    Args:
        src: Path of the GeoTIFF to convert. The file is overwritten.

    Returns:
        The GDAL dataset object of the rewritten file.

    Raises:
        RuntimeError: If GDAL cannot open ``src``.
    """
    img = gdal.Open(src)
    if img is None:
        raise RuntimeError(f"GDAL could not open raster: {src}")

    width = img.RasterXSize
    height = img.RasterYSize
    band_count = img.RasterCount
    geo_transform = img.GetGeoTransform()
    projection = img.GetProjection()

    # Read and rescale every band BEFORE creating the output: the output path
    # is the source path, so creating it first would overwrite the file that
    # is still being read.
    scaled_bands = []
    for band_index in range(1, band_count + 1):  # GDAL bands are 1-based
        band = img.GetRasterBand(band_index)
        band_array = band.ReadAsArray()
        band_min, band_max = band.ComputeRasterMinMax(1)
        if band_max > band_min:
            scaled = np.interp(band_array, (band_min, band_max), (0, 255))
        else:
            # Constant band: np.interp needs increasing breakpoints.
            scaled = np.zeros_like(band_array, dtype=np.float64)
        scaled_bands.append(scaled.astype(np.uint8))

    # Drop the source handle so the file can be safely recreated.
    band = None
    img = None

    driver = gdal.GetDriverByName('GTiff')
    # NOTE(review): GDAL creation options are normally written NAME=VALUE
    # without spaces; the literal is kept as-is - confirm it is honoured.
    output_image = driver.Create(src, width, height, band_count, gdal.GDT_Byte, ['PHOTOMETRIC = RGB'])
    output_image.SetGeoTransform(geo_transform)
    output_image.SetProjection(projection)

    for band_index, scaled in enumerate(scaled_bands, start=1):
        out = output_image.GetRasterBand(band_index)
        out.WriteArray(scaled)
        out.FlushCache()

    output_image.FlushCache()
    return output_image
30
+
Utils/split_merge.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import os
3
+ import numpy as np
4
+
5
# Default working directories: image patches and their predicted masks.
# Both are created at import time if they do not exist yet.
patches_folder = '/home/pavan/Desktop/Kus/Stream/Patches'
pred_patches = '/home/pavan/Desktop/Kus/Stream/Patch_pred'
for _folder in (patches_folder, pred_patches):
    os.makedirs(_folder, exist_ok=True)
9
+
10
+
11
def split(image, destination = patches_folder, patch_size = 256):
    """Cut an image file into square patches and save them as PNG files.

    Patches are written as ``patch_{y}_{x}.png`` where ``y``/``x`` are the
    top-left pixel coordinates of the patch in the source image. Patches on
    the right and bottom edges are smaller when the image size is not a
    multiple of ``patch_size``.

    Args:
        image: Path of the image to read.
        destination: Directory that receives the patches; created if missing.
        patch_size: Edge length of a patch in pixels.

    Raises:
        FileNotFoundError: If the image cannot be read.
    """
    img = cv2.imread(image)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {image}")

    # cv2.imwrite does not create missing directories.
    os.makedirs(destination, exist_ok=True)

    h, w = img.shape[:2]
    for y in range(0, h, patch_size):
        for x in range(0, w, patch_size):
            patch = img[y:y+patch_size, x:x+patch_size]
            patch_path = os.path.join(destination, f"patch_{y}_{x}.png")
            cv2.imwrite(patch_path, patch)
22
+
23
def merge(patch_folder , dest_image = 'out.png', image_shape = None):
    """Reassemble ``*_{y}_{x}.png`` patches into one image and write it to disk.

    Args:
        patch_folder: Directory containing the patch PNG files.
        dest_image: Path of the merged image to write.
        image_shape: ``(height, width[, ...])`` of the full image. Required;
            only the first two entries are used and the result always has
            three channels.

    Raises:
        ValueError: If ``image_shape`` is missing or too short, a PNG file
            name does not carry ``y``/``x`` coordinates, or a patch cannot be
            read.
    """
    if image_shape is None or len(image_shape) < 2:
        raise ValueError("image_shape must be given as (height, width[, channels])")
    height, width = int(image_shape[0]), int(image_shape[1])
    merged = np.zeros((height, width, 3), dtype=np.uint8)

    # Sorted so the result is deterministic when shifted patches overlap.
    for filename in sorted(os.listdir(patch_folder)):
        if not filename.endswith(".png"):
            continue

        patch = cv2.imread(os.path.join(patch_folder, filename))
        if patch is None:
            raise ValueError(f"Could not read patch: {filename}")
        patch_height, patch_width = patch.shape[:2]

        # Extract patch coordinates from filename: the last "_" field (minus
        # the extension) is x, the last purely numeric field before it is y.
        *leading, last = filename.split("_")
        numeric_fields = [part for part in leading if part.isdigit()]
        try:
            x = int(last.split(".")[0])
        except ValueError:
            raise ValueError(f"Invalid filename: {filename}") from None
        if not numeric_fields:
            raise ValueError(f"Invalid filename: {filename}")
        y = int(numeric_fields[-1])

        # Shift a patch that would overrun the image back inside it, never
        # past the top-left corner (a negative index would wrap around).
        x = max(0, min(x, width - patch_width))
        y = max(0, min(y, height - patch_height))

        # Crop patches that are larger than the image itself.
        patch = patch[:height - y, :width - x]
        patch_height, patch_width = patch.shape[:2]

        # Merge patch into the main image
        merged[y:y+patch_height, x:x+patch_width, :] = patch

    cv2.imwrite(dest_image, merged)
image_log.csv ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ S.No,Date,Time,Image ID,Image Filename,Mask Filename
2
+ 1,2024-08-27,23:35:06,1724781904,image_1724781904.png,mask_1724781906.png
3
+ 2,2024-08-27,23:35:14,1724781912,image_1724781912.png,mask_1724781914.png
4
+ 3,2024-08-28,21:36:50,1724861208,image_1724861208.png,mask_1724861210.png
5
+ 4,2024-08-28,21:43:36,1724861611,image_1724861611.png,mask_1724861616.png
6
+ 5,2024-08-28,21:47:06,1724861823,image_1724861823.png,mask_1724861826.png
7
+ 6,2024-08-29,11:23:04,1724910780,image_1724910780.png,mask_1724910784.png
8
+ 7,2024-09-04,21:13:10,1725464583,image_1725464583.tif,mask_1725464590.png
9
+ 8,2024-09-04,21:14:25,1725464658,image_1725464658.tif,mask_1725464665.png
10
+ 9,2024-09-04,21:17:15,1725464832,image_1725464832.tif,mask_1725464835.png
11
+ 10,2024-09-04,22:21:21,1725468679,image_1725468679.tif,mask_1725468681.png
12
+ 11,2024-09-04,22:45:45,1725470142,image_1725470142.tif,mask_1725470145.png
13
+ 12,2024-09-04,22:46:39,1725470196,image_1725470196.tif,mask_1725470199.png
14
+ 13,2024-09-04,22:47:30,1725470247,image_1725470247.png,mask_1725470250.png
15
+ 14,2024-09-04,22:47:55,1725470272,image_1725470272.jpg,mask_1725470275.png
16
+ 15,2024-09-04,22:48:12,1725470289,image_1725470289.png,mask_1725470292.png
17
+ 16,2024-09-04,22:49:15,1725470353,image_1725470353.png,mask_1725470355.png
18
+ 17,2024-09-04,22:50:32,1725470430,image_1725470430.png,mask_1725470432.png
19
+ 18,2024-09-05,00:00:05,1725474603,image_1725474603.tif,mask_1725474605.png
20
+ 19,2024-09-05,22:51:32,1725556889,image_1725556889.tif,mask_1725556892.png
21
+ 20,2024-09-05,22:56:54,1725557179,image_1725557179.png,mask_1725557214.png
latest.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5b3e02b56c25ad19386b0dbad700e0db322b871a5d281d7f14b7825b963ee163
3
+ size 100116298
model/CBAM/__pycache__/cbam.cpython-310.pyc ADDED
Binary file (832 Bytes). View file
 
model/CBAM/__pycache__/channel_att.cpython-310.pyc ADDED
Binary file (1.13 kB). View file
 
model/CBAM/__pycache__/reunet_cbam.cpython-310.pyc ADDED
Binary file (4.97 kB). View file
 
model/CBAM/__pycache__/spatial_att.cpython-310.pyc ADDED
Binary file (1.01 kB). View file
 
model/CBAM/bn.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
class batch_norm(nn.Module):
    """2-D batch normalisation followed by a ReLU activation."""

    def __init__(self, inp):
        """Create the layers for a feature map with ``inp`` channels."""
        super().__init__()
        self.batch = nn.BatchNorm2d(inp)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``relu(batch_norm(x))``."""
        return self.relu(self.batch(x))
model/CBAM/cbam.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from channel_att import Channel_attention
4
+ from spatial_att import Spatial_attention
5
class CBAM(nn.Module):
    """Attention block: channel attention followed by spatial attention."""

    def __init__(self, ch):
        """Build both attention stages for a feature map with ``ch`` channels."""
        super().__init__()
        self.channel = Channel_attention(ch)
        self.spatial = Spatial_attention()

    def forward(self, x):
        """Apply channel attention to ``x``, then spatial attention to the result."""
        channel_refined = self.channel(x)
        return self.spatial(channel_refined)
model/CBAM/channel_att.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
class Channel_attention(nn.Module):
    """Channel attention gate.

    Global average- and max-pooled channel descriptors are passed through a
    shared two-layer bottleneck MLP, summed, squashed with a sigmoid and used
    to rescale the channels of the input.

    Args:
        ch: Number of input channels.
        ratio: Reduction factor of the bottleneck (hidden width is
            ``ch // ratio``, but never less than 1).
    """

    def __init__(self,ch, ratio = 8):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # With ch < ratio, ``ch // ratio`` is 0 and the MLP would have no
        # hidden units: its output is then always zero and the gate a constant
        # 0.5. Keep at least one hidden unit. Unchanged for ch >= ratio.
        hidden = max(1, ch // ratio)
        self.mlp = nn.Sequential(
            nn.Linear(ch, hidden, bias = False),
            nn.ReLU(inplace = True),
            nn.Linear(hidden, ch, bias = False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` scaled per channel by the attention gate."""
        # Pooled maps have trailing 1x1 spatial dims; drop them for the MLP.
        avg_desc = self.mlp(self.avg_pool(x).squeeze(-1).squeeze(-1))
        max_desc = self.mlp(self.max_pool(x).squeeze(-1).squeeze(-1))
        # Both descriptors share the MLP and are combined by addition.
        gate = self.sigmoid(avg_desc + max_desc).unsqueeze(-1).unsqueeze(-1)
        return x * gate
model/CBAM/decoder.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from residual import residual
4
class decoder(nn.Module):
    """Decoder stage: 2x bilinear upsampling, skip concatenation, residual block."""

    def __init__(self, inp, out):
        """``inp`` channels arrive from below; ``out`` is both the skip and output width."""
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        # The residual block sees the upsampled map stacked on the skip map.
        self.block = residual(inp + out, out)

    def forward(self, x, skip):
        """Upsample ``x``, concatenate ``skip`` along dim 1 and refine."""
        stacked = torch.cat((self.upsample(x), skip), dim=1)
        return self.block(stacked)
model/CBAM/residual.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from bn import batch_norm
4
+ from cbam import CBAM
5
class residual(nn.Module):
    """Pre-activation residual block with CBAM attention on the main branch.

    Main branch: BN+ReLU -> 3x3 conv (strided) -> BN+ReLU -> 3x3 conv -> CBAM.
    Shortcut: 1x1 conv with the same stride. The two branches are added.
    """

    def __init__(self, inp, out, stride=1):
        super().__init__()
        self.bn1 = batch_norm(inp)
        self.conv1 = nn.Conv2d(inp, out, 3, stride=stride, padding=1)
        self.bn2 = batch_norm(out)
        self.conv2 = nn.Conv2d(out, out, 3, padding=1)
        # 1x1 projection so the shortcut matches the main branch's shape
        self.concat = nn.Conv2d(inp, out, 1, stride=stride)
        # attention applied to the main branch before the addition
        self.cbam = CBAM(out)

    def forward(self, input):
        """Return ``cbam(main_branch(input)) + shortcut(input)``."""
        main = self.conv1(self.bn1(input))
        main = self.conv2(self.bn2(main))
        main = self.cbam(main)
        shortcut = self.concat(input)
        return main + shortcut
model/CBAM/reunet_cbam.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
class Spatial_attention(nn.Module):
    """Spatial attention gate.

    The channel-wise mean and max maps of the input are stacked, convolved
    down to a single map, squashed with a sigmoid and multiplied onto the
    input.

    Args:
        kernel: Size of the convolution kernel. Must be odd so that the
            attention map keeps the input's spatial size.
    """

    def __init__(self, kernel = 7):
        super().__init__()
        # "Same" padding for an odd kernel. The padding used to be fixed at 3,
        # which only matches kernel == 7; any other size produced a map of a
        # different spatial size and the multiplication in forward() failed.
        self.conv = nn.Conv2d(2, 1, padding = kernel // 2, kernel_size=kernel, bias = False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` scaled by a per-location attention map."""
        mean_map = torch.mean(x, dim=1, keepdim = True)
        max_map, _ = torch.max(x, dim = 1, keepdim = True)
        stacked = torch.cat([mean_map, max_map], dim = 1)
        gate = self.sigmoid(self.conv(stacked))
        return x * gate
18
+
19
class Channel_attention(nn.Module):
    """Channel attention gate.

    Global average- and max-pooled channel descriptors go through a shared
    two-layer bottleneck MLP; their sum is squashed with a sigmoid and used
    to rescale the input's channels.

    Args:
        ch: Number of input channels.
        ratio: Reduction factor of the bottleneck (hidden width is
            ``ch // ratio``, but never less than 1).
    """

    def __init__(self,ch, ratio = 8):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # ``ch // ratio`` is 0 for ch < ratio, which leaves the MLP without
        # hidden units (constant 0.5 gate). Clamp to one hidden unit; the
        # layer sizes are unchanged whenever ch >= ratio.
        hidden = max(1, ch // ratio)
        self.mlp = nn.Sequential(
            nn.Linear(ch, hidden, bias = False),
            nn.ReLU(inplace = True),
            nn.Linear(hidden, ch, bias = False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` scaled per channel by the attention gate."""
        # Drop the trailing 1x1 spatial dims of the pooled maps for the MLP.
        avg_desc = self.mlp(self.avg_pool(x).squeeze(-1).squeeze(-1))
        max_desc = self.mlp(self.max_pool(x).squeeze(-1).squeeze(-1))
        gate = self.sigmoid(avg_desc + max_desc).unsqueeze(-1).unsqueeze(-1)
        return x * gate
43
class CBAM(nn.Module):
    """Channel attention followed by spatial attention."""

    def __init__(self, ch):
        """Create both attention stages for ``ch`` input channels."""
        super().__init__()
        self.channel = Channel_attention(ch)
        self.spatial = Spatial_attention()

    def forward(self, x):
        """Run ``x`` through the channel gate, then through the spatial gate."""
        after_channel = self.channel(x)
        return self.spatial(after_channel)
53
+
54
class residual(nn.Module):
    """Residual block whose main branch ends in a CBAM attention module.

    Main branch: BN+ReLU -> 3x3 conv (strided) -> BN+ReLU -> 3x3 conv -> CBAM.
    Shortcut: 1x1 conv with the same stride; the branches are summed.
    """

    def __init__(self, inp, out, stride=1):
        super().__init__()
        self.bn1 = batch_norm(inp)
        self.conv1 = nn.Conv2d(inp, out, 3, stride=stride, padding=1)
        self.bn2 = batch_norm(out)
        self.conv2 = nn.Conv2d(out, out, 3, padding=1)
        # shortcut projection (1x1) matching channels and stride
        self.concat = nn.Conv2d(inp, out, 1, stride=stride)
        # attention on the main branch
        self.cbam = CBAM(out)

    def forward(self, input):
        """Return the attended main branch plus the projected shortcut."""
        main = self.conv1(self.bn1(input))
        main = self.conv2(self.bn2(main))
        main = self.cbam(main)
        shortcut = self.concat(input)
        return main + shortcut
75
+
76
class batch_norm(nn.Module):
    """``BatchNorm2d`` followed by ``ReLU``."""

    def __init__(self, inp):
        """Set up normalisation for ``inp`` channels."""
        super().__init__()
        self.batch = nn.BatchNorm2d(inp)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Normalise ``x`` and apply the ReLU."""
        return self.relu(self.batch(x))
85
+
86
class decoder(nn.Module):
    """Upsample by 2 (bilinear), concatenate the skip map, apply a residual block."""

    def __init__(self, inp, out):
        """``inp``: channels of the incoming map; ``out``: channels of skip and output."""
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        # input of the block = upsampled channels + skip channels
        self.block = residual(inp + out, out)

    def forward(self, x, skip):
        """Return the refined fusion of the upsampled ``x`` and ``skip``."""
        upsampled = self.upsample(x)
        fused = torch.cat((upsampled, skip), dim=1)
        return self.block(fused)
96
class reunet_cbam(nn.Module):
    """Residual U-Net with CBAM attention.

    Takes a 3-channel input, runs it through three encoder stages, a bridge
    and three decoder stages with skip connections, and returns a 1-channel
    map passed through a sigmoid.
    """

    def __init__(self):
        super().__init__()
        # encoder 1: conv -> BN+ReLU -> conv -> CBAM, with a 1x1 shortcut (conv3)
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.bn1 = batch_norm(64)
        self.conv2 = nn.Conv2d(64, 64, 3, padding=1)
        self.conv3 = nn.Conv2d(3, 64, 1)
        self.cbam1 = CBAM(64)
        # encoders 2 and 3 halve the resolution (stride 2)
        self.enc2 = residual(64, 128, stride=2)
        self.enc3 = residual(128, 256, stride=2)
        # bridge
        self.bridge = residual(256, 512, stride=2)
        # decoder
        self.d1 = decoder(512, 256)
        self.d2 = decoder(256, 128)
        self.d3 = decoder(128, 64)
        # output head
        self.output = nn.Conv2d(64, 1, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input):
        """Return the sigmoid-activated 1-channel prediction for ``input``."""
        # encoder 1
        stem = self.conv2(self.bn1(self.conv1(input)))
        stem = self.cbam1(stem)
        skip1 = stem + self.conv3(input)
        # encoders 2 and 3
        skip2 = self.enc2(skip1)
        skip3 = self.enc3(skip2)
        # bridge
        bottom = self.bridge(skip3)
        # decoder, consuming the skips deepest-first
        up = self.d1(bottom, skip3)
        up = self.d2(up, skip2)
        up = self.d3(up, skip1)
        # output
        return self.sigmoid(self.output(up))
model/CBAM/spatial_att.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
class Spatial_attention(nn.Module):
    """Spatial attention gate.

    Stacks the channel-wise mean and max of the input, convolves the pair
    down to one map, applies a sigmoid and multiplies the map onto the input.

    Args:
        kernel: Convolution kernel size; must be odd so the attention map
            has the same spatial size as the input.
    """

    def __init__(self, kernel = 7):
        super().__init__()
        # Padding was hard-coded to 3, which is only "same" padding for
        # kernel == 7; derive it from the kernel so other odd sizes work.
        self.conv = nn.Conv2d(2, 1, padding = kernel // 2, kernel_size=kernel, bias = False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` scaled by a per-location attention map."""
        mean_map = torch.mean(x, dim=1, keepdim = True)
        max_map, _ = torch.max(x, dim = 1, keepdim = True)
        stacked = torch.cat([mean_map, max_map], dim = 1)
        gate = self.sigmoid(self.conv(stacked))
        return x * gate
17
+
18
+
model/__init__.py ADDED
File without changes
model/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (141 Bytes). View file
 
model/__pycache__/bn.cpython-310.pyc ADDED
Binary file (797 Bytes). View file
 
model/__pycache__/decoder.cpython-310.pyc ADDED
Binary file (971 Bytes). View file
 
model/__pycache__/model.cpython-310.pyc ADDED
Binary file (1.49 kB). View file
 
model/__pycache__/residual.cpython-310.pyc ADDED
Binary file (1.05 kB). View file
 
model/__pycache__/transform.cpython-310.pyc ADDED
Binary file (430 Bytes). View file
 
model/__pycache__/unet.cpython-310.pyc ADDED
Binary file (2.08 kB). View file
 
model/bn.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.functional as F
4
+
5
class batch_norm(nn.Module):
    """Batch normalisation (2-D) with a trailing ReLU."""

    def __init__(self, inp):
        """Build the layers for ``inp`` feature channels."""
        super().__init__()
        self.batch = nn.BatchNorm2d(inp)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return the normalised, rectified ``x``."""
        normalised = self.batch(x)
        return self.relu(normalised)
model/decoder.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.functional as F
4
+ from bn import batch_norm
5
+ from residual import residual
6
class decoder(nn.Module):
    """Decoder stage: bilinear 2x upsampling, skip concatenation, residual block."""

    def __init__(self, inp, out):
        """``inp``: channels coming from below; ``out``: skip and output channels."""
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        # The block receives the upsampled and the skip channels together.
        self.block = residual(inp + out, out)

    def forward(self, x, skip):
        """Upsample ``x``, stack it with ``skip`` on dim 1 and refine."""
        merged = torch.cat((self.upsample(x), skip), dim=1)
        return self.block(merged)
model/model.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.functional as F
4
+
5
+ from bn import batch_norm
6
+ from residual import residual
7
+ from decoder import decoder
8
+
9
class reunet(nn.Module):
    """Residual U-Net.

    Takes a 3-channel input through three encoder stages, a bridge and three
    decoder stages with skip connections; returns a 1-channel map passed
    through a sigmoid.
    """

    def __init__(self):
        super().__init__()
        # encoder 1: conv -> BN+ReLU -> conv, plus a 1x1 shortcut (conv3)
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.bn1 = batch_norm(64)
        self.conv2 = nn.Conv2d(64, 64, 3, padding=1)
        self.conv3 = nn.Conv2d(3, 64, 1)

        # encoders 2 and 3 (stride 2)
        self.enc2 = residual(64, 128, stride=2)
        self.enc3 = residual(128, 256, stride=2)

        # bridge
        self.bridge = residual(256, 512, stride=2)

        # decoder
        self.d1 = decoder(512, 256)
        self.d2 = decoder(256, 128)
        self.d3 = decoder(128, 64)

        # output head
        self.output = nn.Conv2d(64, 1, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input):
        """Return the sigmoid-activated 1-channel prediction for ``input``."""
        # encoder 1
        stem = self.conv2(self.bn1(self.conv1(input)))
        skip1 = stem + self.conv3(input)

        # encoders 2 and 3
        skip2 = self.enc2(skip1)
        skip3 = self.enc3(skip2)

        # bridge
        bottom = self.bridge(skip3)

        # decoder, consuming the skips deepest-first
        up = self.d1(bottom, skip3)
        up = self.d2(up, skip2)
        up = self.d3(up, skip1)

        # output
        return self.sigmoid(self.output(up))
59
+
model/residual.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.functional as F
4
+ from bn import batch_norm
5
class residual(nn.Module):
    """Pre-activation residual block.

    Main branch: BN+ReLU -> 3x3 conv (strided) -> BN+ReLU -> 3x3 conv.
    Shortcut: 1x1 conv with the same stride; the branches are summed.
    """

    def __init__(self, inp, out, stride=1):
        super().__init__()
        self.bn1 = batch_norm(inp)
        self.conv1 = nn.Conv2d(inp, out, 3, stride=stride, padding=1)
        self.bn2 = batch_norm(out)
        self.conv2 = nn.Conv2d(out, out, 3, padding=1)
        # skip connection: 1x1 projection matching channels and stride
        self.concat = nn.Conv2d(inp, out, 1, stride=stride)

    def forward(self, input):
        """Return ``main_branch(input) + shortcut(input)``."""
        main = self.conv1(self.bn1(input))
        main = self.conv2(self.bn2(main))
        shortcut = self.concat(input)
        return main + shortcut
model/t.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from bn import batch_norm
2
+ print(batch_norm)
model/transform.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ import torch
2
+ import torchvision.transforms as T
3
# Preprocessing pipeline: convert to a tensor, resize to 256x256 without
# antialiasing, then cast to float32.
transforms = T.Compose(
    [
        T.ToTensor(),
        T.Resize((256, 256), antialias=False),
        T.Lambda(lambda x: x.to(torch.float32)),
    ]
)
model/unet.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
class DoubleConv(nn.Module):
    """Two consecutive (3x3 conv -> BatchNorm -> ReLU) stages.

    The convolutions use stride 1 and padding 1, so the spatial size of the
    input is preserved.
    """

    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Apply both conv stages to ``x``."""
        return self.conv(x)


class UNET(nn.Module):
    """U-Net built from ``DoubleConv`` blocks.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the returned map (no activation applied).
        features: Channel widths of the encoder levels, shallowest first.
            The decoder mirrors them in reverse.
    """

    def __init__(
            self, in_channels=3, out_channels=1, features=(64, 128, 256, 512),
    ):
        super(UNET, self).__init__()
        # An immutable default plus a local copy avoids sharing one list
        # object between instances.
        features = list(features)
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder: one DoubleConv per level.
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Decoder: ``ups`` alternates [transposed conv, DoubleConv] per level.
        for feature in reversed(features):
            self.ups.append(
                nn.ConvTranspose2d(
                    feature*2, feature, kernel_size=2, stride=2,
                )
            )
            self.ups.append(DoubleConv(feature*2, feature))

        self.bottleneck = DoubleConv(features[-1], features[-1]*2)
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        """Return the ``out_channels`` map for ``x`` at the input resolution."""
        skip_connections = []
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]  # deepest skip first

        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx//2]

            if x.shape != skip_connection.shape:
                # Odd sizes lose a pixel in pooling, so the upsampled map can
                # be smaller than its skip. The original called ``TF.resize``,
                # but ``TF`` was never imported (NameError); resize with
                # torch's own bilinear interpolation instead.
                x = nn.functional.interpolate(
                    x, size=skip_connection.shape[2:],
                    mode="bilinear", align_corners=False,
                )

            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx+1](concat_skip)

        return self.final_conv(x)
63
+
requirements.txt ADDED
@@ -0,0 +1,546 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==2.1.0
2
+ accelerate==0.31.0
3
+ access==1.1.9
4
+ aenum==3.1.15
5
+ affine==2.4.0
6
+ aiohttp==3.9.1
7
+ aiosignal==1.3.1
8
+ alabaster==0.7.16
9
+ albumentations==1.3.1
10
+ altair==5.2.0
11
+ annotated-types==0.6.0
12
+ antlr4-python3-runtime==4.9.3
13
+ anyio==4.0.0
14
+ appdirs==1.4.4
15
+ apturl==0.5.2
16
+ arabic-reshaper==3.0.0
17
+ argon2-cffi==23.1.0
18
+ argon2-cffi-bindings==21.2.0
19
+ arrow==1.3.0
20
+ asn1crypto==1.5.1
21
+ astor==0.8.1
22
+ asttokens==2.4.1
23
+ astunparse==1.6.3
24
+ async-lru==2.0.4
25
+ async-timeout==4.0.3
26
+ attrs==23.1.0
27
+ audioread==3.0.1
28
+ awscli==1.22.34
29
+ Babel==2.13.1
30
+ backoff==2.2.1
31
+ bcrypt==3.2.0
32
+ beautifulsoup4==4.12.2
33
+ beniget==0.4.1
34
+ bleach==6.1.0
35
+ blessings==1.7
36
+ blinker==1.4
37
+ boto3==1.34.54
38
+ botocore==1.34.54
39
+ bqplot==0.12.42
40
+ branca==0.7.0
41
+ Brlapi==0.8.3
42
+ Brotli==1.0.9
43
+ build==1.1.1
44
+ cachetools==5.3.2
45
+ cairocffi==1.7.1
46
+ CairoSVG==2.7.1
47
+ Cartopy==0.23.0
48
+ certifi==2023.7.22
49
+ cffi==1.16.0
50
+ chardet==4.0.0
51
+ charset-normalizer==3.3.1
52
+ click==8.1.7
53
+ click-plugins==1.1.1
54
+ clifford==1.4.0
55
+ cligj==0.7.2
56
+ colorama==0.4.4
57
+ coloredlogs==15.0.1
58
+ colour==0.1.5
59
+ comm==0.1.4
60
+ command-not-found==0.3
61
+ contourpy==1.2.0
62
+ coverage==5.3.1
63
+ cryptography==42.0.5
64
+ cssselect2==0.7.0
65
+ csv342==1.0.1
66
+ cupshelpers==1.0
67
+ cycler==0.10.0
68
+ Cython==3.0.7
69
+ data-gradients==0.3.2
70
+ dataclasses-json==0.6.3
71
+ datasets==2.18.0
72
+ dateutils==0.6.12
73
+ dbus-python==1.2.18
74
+ debugpy==1.8.0
75
+ decorator==5.1.1
76
+ defer==1.0.6
77
+ defusedxml==0.7.1
78
+ Deprecated==1.2.14
79
+ deprecation==2.1.0
80
+ devkit==0.0.12
81
+ diffusers==0.28.2
82
+ dill==0.3.8
83
+ distlib==0.3.4
84
+ distro==1.7.0
85
+ distro-info==1.1+ubuntu0.2
86
+ dm-tree==0.1.8
87
+ docopt==0.6.2
88
+ docutils==0.17.1
89
+ dotadevkit==1.3.0
90
+ duckdb==0.10.3
91
+ duplicity==0.8.21
92
+ ee==0.2
93
+ einops==0.3.2
94
+ ephem==4.1.5
95
+ esda==2.5.1
96
+ espeak-phonemizer==1.3.1
97
+ et-xmlfile==1.1.0
98
+ exceptiongroup==1.1.3
99
+ executing==2.0.1
100
+ fasteners==0.14.1
101
+ fastjsonschema==2.18.1
102
+ filelock==3.6.0
103
+ fiona==1.9.5
104
+ Flask==1.1.2
105
+ flatbuffers==23.5.26
106
+ folium==0.15.1
107
+ fonttools==4.46.0
108
+ fqdn==1.5.1
109
+ frozenlist==1.4.1
110
+ fs==2.4.12
111
+ fsspec==2023.12.1
112
+ future==0.18.2
113
+ gast==0.5.4
114
+ GDAL==3.8.4
115
+ gdown==4.7.1
116
+ geoio==1.3.0
117
+ geojson==3.1.0
118
+ geomet==1.1.0
119
+ geopandas==0.14.2
120
+ georasters==0.5.29
121
+ ghp-import==2.1.0
122
+ giddy==2.3.5
123
+ gitdb==4.0.11
124
+ GitPython==3.1.40
125
+ google-api-core==2.16.2
126
+ google-api-python-client==2.116.0
127
+ google-auth==2.27.0
128
+ google-auth-httplib2==0.2.0
129
+ google-auth-oauthlib==1.2.0
130
+ google-cloud-vision==3.7.2
131
+ google-pasta==0.2.0
132
+ googleapis-common-protos==1.62.0
133
+ graphviz==0.20.3
134
+ greenlet==3.0.3
135
+ grpcio==1.62.1
136
+ grpcio-status==1.62.1
137
+ gTTS==2.5.1
138
+ gw_dsl_parser==0.1.48a6
139
+ gyp==0.1
140
+ h11==0.14.0
141
+ h5py==3.10.0
142
+ html2text==2024.2.26
143
+ html5lib==1.1
144
+ httpcore==1.0.2
145
+ httplib2==0.20.2
146
+ httpx==0.26.0
147
+ huggingface-hub==0.23.3
148
+ humanfriendly==10.0
149
+ hydra-core==1.3.2
150
+ idna==2.10
151
+ imagecodecs==2024.1.1
152
+ imagecodes==0.0.1
153
+ imagededup==0.3.2
154
+ imageio==2.33.1
155
+ imagesize==1.4.1
156
+ imgaug==0.4.0
157
+ importlib-metadata==4.6.4
158
+ importlib_resources==6.4.0
159
+ imutils==0.5.4
160
+ inequality==1.0.1
161
+ invisible-watermark==0.2.0
162
+ ipyevents==2.0.2
163
+ ipyfilechooser==0.6.0
164
+ ipykernel==6.26.0
165
+ ipyleaflet==0.18.1
166
+ ipython==8.17.2
167
+ ipython-genutils==0.2.0
168
+ ipytree==0.2.2
169
+ ipywidgets==8.1.1
170
+ isoduration==20.11.0
171
+ itsdangerous==2.1.2
172
+ jedi==0.19.1
173
+ jeepney==0.7.1
174
+ Jinja2==3.1.2
175
+ jmespath==1.0.1
176
+ joblib==1.3.2
177
+ json-tricks==3.16.1
178
+ json5==0.9.14
179
+ jsonpatch==1.33
180
+ jsonpointer==2.4
181
+ jsonschema==4.19.2
182
+ jsonschema-specifications==2023.7.1
183
+ jupyter==1.0.0
184
+ jupyter-console==6.6.3
185
+ jupyter-contrib-core==0.4.2
186
+ jupyter-contrib-nbextensions==0.7.0
187
+ jupyter-events==0.9.0
188
+ jupyter-highlight-selected-word==0.2.0
189
+ jupyter-lsp==2.2.0
190
+ jupyter-nbextensions-configurator==0.6.3
191
+ jupyter-tabnine==1.2.3
192
+ jupyter_client==8.6.0
193
+ jupyter_core==5.5.0
194
+ jupyter_server==2.11.1
195
+ jupyter_server_terminals==0.4.4
196
+ jupyterlab==4.0.7
197
+ jupyterlab-pygments==0.2.2
198
+ jupyterlab-widgets==3.0.9
199
+ jupyterlab_server==2.25.0
200
+ kanaries_track==0.0.5
201
+ keras==2.15.0
202
+ keyring==23.5.0
203
+ kiwisolver==1.4.5
204
+ labelImg==1.8.6
205
+ langchain==0.0.354
206
+ langchain-community==0.0.8
207
+ langchain-core==0.1.6
208
+ langsmith==0.0.77
209
+ language-selector==0.1
210
+ launchpadlib==1.10.16
211
+ lazr.restfulclient==0.14.4
212
+ lazr.uri==1.0.6
213
+ lazy_loader==0.3
214
+ leafmap==0.30.0
215
+ Levenshtein==0.25.1
216
+ libclang==16.0.6
217
+ libpysal==4.10
218
+ librosa==0.10.1
219
+ lightning==2.2.5
220
+ lightning-utilities==0.11.2
221
+ llvmlite==0.42.0
222
+ lockfile==0.12.2
223
+ louis==3.20.0
224
+ lxml==4.9.3
225
+ lz4==3.1.3+dfsg
226
+ macaroonbakery==1.3.1
227
+ Mako==1.1.3
228
+ mapclassify==2.6.1
229
+ Markdown==3.5.2
230
+ markdown-it-py==3.0.0
231
+ MarkupSafe==2.1.3
232
+ marshmallow==3.20.1
233
+ matplotlib==3.8.3
234
+ matplotlib-inline==0.1.6
235
+ mdurl==0.1.2
236
+ mergedeep==1.3.4
237
+ metview==1.15.0
238
+ mgwr==2.2.1
239
+ mistune==3.0.2
240
+ mkdocs==1.5.3
241
+ ml-dtypes==0.2.0
242
+ momepy==0.7.0
243
+ monotonic==1.6
244
+ more-itertools==8.10.0
245
+ mpl-toolkits.clifford==0.0.3
246
+ mplcursors==0.5.3
247
+ mpmath==1.3.0
248
+ mrcnn==0.2
249
+ msgpack==1.0.7
250
+ mtcnn==0.1.1
251
+ multidict==6.0.4
252
+ multiprocess==0.70.16
253
+ mypy-extensions==1.0.0
254
+ namex==0.0.7
255
+ nbclassic==1.0.0
256
+ nbclient==0.8.0
257
+ nbconvert==7.10.0
258
+ nbformat==5.9.2
259
+ nest-asyncio==1.5.8
260
+ netifaces==0.11.0
261
+ netron==7.8.6
262
+ networkx==3.2.1
263
+ nibabel==5.2.1
264
+ nilearn==0.10.3
265
+ nltk==3.8.1
266
+ notebook==7.0.6
267
+ notebook_shim==0.2.3
268
+ numba==0.59.0
269
+ numpy==1.23.0
270
+ nvidia-cublas-cu12==12.1.3.1
271
+ nvidia-cuda-cupti-cu12==12.1.105
272
+ nvidia-cuda-nvrtc-cu12==12.1.105
273
+ nvidia-cuda-runtime-cu12==12.1.105
274
+ nvidia-cudnn-cu12==8.9.2.26
275
+ nvidia-cufft-cu12==11.0.2.54
276
+ nvidia-curand-cu12==10.3.2.106
277
+ nvidia-cusolver-cu12==11.4.5.107
278
+ nvidia-cusparse-cu12==12.1.0.106
279
+ nvidia-nccl-cu12==2.18.1
280
+ nvidia-nvjitlink-cu12==12.3.101
281
+ nvidia-nvtx-cu12==12.1.105
282
+ oauthlib==3.2.0
283
+ olefile==0.46
284
+ omegaconf==2.3.0
285
+ onnx==1.13.0
286
+ onnxruntime==1.13.1
287
+ onnxsim==0.4.35
288
+ openai==1.6.1
289
+ opencv-contrib-python==4.8.1.78
290
+ opencv-python==4.9.0.80
291
+ opencv-python-headless==4.10.0.84
292
+ openpyxl==3.1.2
293
+ opt-einsum==3.3.0
294
+ oscrypto==1.3.0
295
+ osmnx==1.8.1
296
+ overrides==7.4.0
297
+ OWSLib==0.28.1
298
+ packaging==23.2
299
+ pandas==2.1.4
300
+ pandocfilters==1.5.0
301
+ paramiko==2.9.3
302
+ parsimonious==0.10.0
303
+ parso==0.8.3
304
+ patchify==0.2.3
305
+ pathspec==0.12.1
306
+ patsy==0.5.6
307
+ pbr==5.8.0
308
+ pexpect==4.8.0
309
+ pillow==10.2.0
310
+ pip-tools==7.4.0
311
+ platformdirs==2.6.2
312
+ plotly==5.18.0
313
+ ply==3.11
314
+ pointpats==2.4.0
315
+ pooch==1.8.0
316
+ prometheus-client==0.17.1
317
+ prompt-toolkit==3.0.39
318
+ proto-plus==1.23.0
319
+ protobuf==4.25.3
320
+ psutil==5.9.6
321
+ psycopg2==2.9.2
322
+ ptyprocess==0.7.0
323
+ PuLP==2.8.0
324
+ pure-eval==0.2.2
325
+ py-cpuinfo==9.0.0
326
+ pyarrow==14.0.2
327
+ pyarrow-hotfix==0.6
328
+ pyasn1==0.5.1
329
+ pyasn1-modules==0.3.0
330
+ pycairo==1.20.1
331
+ pycocotools==2.0.6
332
+ pycparser==2.21
333
+ PyCRS==1.0.2
334
+ pycups==2.0.1
335
+ pydantic==2.5.3
336
+ pydantic_core==2.14.6
337
+ pydeck==0.8.1b0
338
+ pyDeprecate==0.3.2
339
+ PyGithub==2.2.0
340
+ Pygments==2.16.1
341
+ PyGObject==3.42.1
342
+ pygrib==2.1.5
343
+ pygwalker==0.4.8.6
344
+ pyHanko==0.21.0
345
+ pyhanko-certvalidator==0.26.3
346
+ PyJWT==2.8.0
347
+ pymacaroons==0.13.0
348
+ PyNaCl==1.5.0
349
+ pyparsing==2.4.7
350
+ pypdf==4.1.0
351
+ pypng==0.20220715.0
352
+ pyproj==3.6.1
353
+ pyproject_hooks==1.0.0
354
+ PyQt5==5.15.6
355
+ PyQt5-sip==12.9.1
356
+ pyqt5ac==1.2.1
357
+ pyRFC3339==1.1
358
+ pyrsistent==0.18.1
359
+ pysal==24.1
360
+ pyseeyou==1.0.2
361
+ pyshp==2.3.1
362
+ PySocks==1.7.1
363
+ pystac==1.9.0
364
+ pystac-client==0.7.5
365
+ pytesseract==0.3.10
366
+ python-apt==2.4.0+ubuntu3
367
+ python-bidi==0.4.2
368
+ python-box==7.1.1
369
+ python-dateutil==2.8.2
370
+ python-debian==0.1.43+ubuntu1.1
371
+ python-dotenv==1.0.0
372
+ python-json-logger==2.0.7
373
+ python-Levenshtein==0.25.1
374
+ python-magic==0.4.27
375
+ python-slugify==8.0.4
376
+ pythran==0.10.0
377
+ pytorch-ignite==0.5.1
378
+ pytorch-lightning==2.2.5
379
+ pyttsx3==2.90
380
+ pytz==2022.1
381
+ PyWavelets==1.5.0
382
+ pyxdg==0.27
383
+ PyYAML==6.0.1
384
+ pyyaml_env_tag==0.1
385
+ pyzmq==25.1.1
386
+ qgis-plugin-ci==2.8.8
387
+ qrcode==7.4.2
388
+ QScintilla==2.11.6
389
+ qtconsole==5.4.4
390
+ QtPy==2.4.1
391
+ quantecon==0.7.2
392
+ qudida==0.0.4
393
+ rapidfuzz==3.9.3
394
+ rasterio==1.3.9
395
+ rasterstats==0.19.0
396
+ referencing==0.30.2
397
+ regex==2024.5.15
398
+ reportlab==3.6.8
399
+ requests==2.31.0
400
+ requests-oauthlib==1.3.1
401
+ requests-toolbelt==1.0.0
402
+ rfc3339-validator==0.1.4
403
+ rfc3986-validator==0.1.1
404
+ rich==13.7.0
405
+ roboflow==1.1.14
406
+ roman==3.3
407
+ rpds-py==0.10.6
408
+ rsa==4.9
409
+ Rtree==1.2.0
410
+ s3transfer==0.10.0
411
+ safetensors==0.4.3
412
+ scikit-fuzzy==0.4.2
413
+ scikit-image==0.22.0
414
+ scikit-learn==1.3.2
415
+ scipy==1.11.4
416
+ scooby==0.9.2
417
+ seaborn==0.13.0
418
+ SecretStorage==3.3.1
419
+ segment-analytics-python==2.2.3
420
+ segment-anything==1.0
421
+ segregation==2.5
422
+ Send2Trash==1.8.2
423
+ sentinelhub==3.10.2
424
+ sentinelloader @ git+https://github.com/flaviostutz/sentinelloader@b107badeedf4ccc7c9eed74c5663d849348d887a
425
+ sentinelsat==1.2.1
426
+ shapely==2.0.2
427
+ simplejson==3.19.2
428
+ six==1.16.0
429
+ smmap==5.0.1
430
+ sniffio==1.3.0
431
+ snowballstemmer==2.2.0
432
+ snuggs==1.4.7
433
+ sounddevice==0.4.6
434
+ soundfile==0.12.1
435
+ soupsieve==2.5
436
+ soxr==0.3.7
437
+ spaghetti==1.7.5.post1
438
+ sparse==0.15.1
439
+ spglm==1.1.0
440
+ Sphinx==4.0.3
441
+ sphinx-rtd-theme==1.3.0
442
+ sphinxcontrib-applehelp==1.0.8
443
+ sphinxcontrib-devhelp==1.0.6
444
+ sphinxcontrib-htmlhelp==2.0.5
445
+ sphinxcontrib-jquery==4.1
446
+ sphinxcontrib-jsmath==1.0.1
447
+ sphinxcontrib-qthelp==1.0.7
448
+ sphinxcontrib-serializinghtml==1.1.10
449
+ spint==1.0.7
450
+ splot==1.1.5.post1
451
+ spopt==0.6.0
452
+ spreg==1.4.2
453
+ spvcm==0.3.0
454
+ SQLAlchemy==2.0.25
455
+ sqlglot==24.0.1
456
+ ssh-import-id==5.11
457
+ stack-data==0.6.3
458
+ statsmodels==0.14.1
459
+ streamlit==1.29.0
460
+ stringcase==1.2.0
461
+ super-gradients==3.6.0
462
+ supervision==0.17.1
463
+ svglib==1.5.1
464
+ sympy==1.12
465
+ systemd-python==234
466
+ tenacity==8.2.3
467
+ tensorboard==2.15.2
468
+ tensorboard-data-server==0.7.2
469
+ tensorflow==2.15.0.post1
470
+ tensorflow-estimator==2.15.0
471
+ tensorflow-io-gcs-filesystem==0.36.0
472
+ termcolor==1.1.0
473
+ terminado==0.17.1
474
+ tesseract==0.1.3
475
+ text-unidecode==1.3
476
+ thop==0.1.1.post2209072238
477
+ threadpoolctl==3.2.0
478
+ tifffile==2023.12.9
479
+ tiffile==2018.10.18
480
+ tinycss2==1.2.1
481
+ tinytools==1.1.1
482
+ tkintertable==1.3.3
483
+ tobler==0.11.2
484
+ tokenizers==0.19.1
485
+ toml==0.10.2
486
+ tomli==2.0.1
487
+ tomli_w==1.0.0
488
+ toolz==0.12.0
489
+ torch==2.1.1
490
+ torcheval==0.0.7
491
+ torchinfo==1.8.0
492
+ torchmetrics==0.8.0
493
+ torchvision==0.16.1
494
+ torchviz==0.0.2
495
+ tornado==6.4
496
+ tqdm==4.66.1
497
+ traitlets==5.13.0
498
+ traittypes==0.2.1
499
+ transformers==4.41.2
500
+ transifex-python==3.5.0
501
+ treelib==1.6.1
502
+ trimesh==4.1.7
503
+ triton==2.1.0
504
+ types-python-dateutil==2.8.19.14
505
+ typing-inspect==0.9.0
506
+ typing_extensions==4.8.0
507
+ tzdata==2023.3
508
+ tzlocal==5.2
509
+ tzwhere==3.0.3
510
+ ubuntu-drivers-common==0.0.0
511
+ ubuntu-pro-client==8001
512
+ ufoLib2==0.13.1
513
+ ufw==0.36.1
514
+ ultralytics==8.0.196
515
+ unattended-upgrades==0.1
516
+ unicodedata2==14.0.0
517
+ uri-template==1.3.0
518
+ uritemplate==4.1.1
519
+ uritools==4.0.2
520
+ urllib3==2.0.7
521
+ usb-creator==0.3.7
522
+ utm==0.7.0
523
+ validators==0.22.0
524
+ virtualenv==20.13.0+ds
525
+ wadllib==1.3.6
526
+ wasmtime==21.0.0
527
+ watchdog==3.0.0
528
+ wavio==0.0.8
529
+ wcwidth==0.2.8
530
+ webcolors==1.13
531
+ webencodings==0.5.1
532
+ websocket-client==1.6.4
533
+ Werkzeug==3.0.1
534
+ whitebox==2.3.1
535
+ whiteboxgui==2.3.0
536
+ widgetsnbextension==4.0.9
537
+ wrapt==1.14.1
538
+ xarray==2024.1.1
539
+ xdg==5
540
+ xhtml2pdf==0.2.11
541
+ xkit==0.0.0
542
+ xmltodict==0.13.0
543
+ xxhash==3.4.1
544
+ xyzservices==2023.10.1
545
+ yarl==1.9.4
546
+ zipp==1.0.0
streamlit_app.py ADDED
@@ -0,0 +1,257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import sys
3
+
4
+ import torch
5
+ from model.CBAM.reunet_cbam import reunet_cbam
6
+ import cv2
7
+ from PIL import Image
8
+ from model.transform import transforms
9
+ import numpy as np
10
+ from model.unet import UNET
11
+ from area import pixel_to_sqft, process_and_overlay_image
12
+ import matplotlib.pyplot as plt
13
+ import time
14
+ import os
15
+ import csv
16
+ from datetime import datetime
17
+ from split_merge import split, merge
18
+ from convert_raster import convert_gtiff_to_8bit
19
+ import shutil
20
+
21
# Working folders for the tiled-prediction path: upload_page lists
# `patches_folder` for patch PNGs and writes one mask per patch into
# `pred_patches`. Both are created up front so later listing/saving works.
# NOTE(review): presumably split() writes its patches into `patches_folder`
# -- confirm against split_merge.
patches_folder = 'data/Patches'
pred_patches = 'data/Patch_pred'
os.makedirs(patches_folder, exist_ok=True)
os.makedirs(pred_patches, exist_ok=True)

# Define the upload directories
UPLOAD_DIR = "data/uploaded_images"   # uploaded images are saved here
MASK_DIR = "data/generated_masks"     # whole-image masks are saved here
CSV_LOG_PATH = "image_log.csv"        # CSV file appended to by log_image_details

# Create the directories if they don't exist
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(MASK_DIR, exist_ok=True)

# Build the network and load the weights stored under the 'model_state_dict'
# key of 'latest.pth' onto the CPU, then switch to evaluation mode.
# NOTE(review): this runs at module level, so it executes every time the
# script body runs -- consider caching the loaded model if that is costly.
model = reunet_cbam()
model.load_state_dict(torch.load('latest.pth', map_location='cpu')['model_state_dict'])
model.eval()
38
+
39
def predict(image):
    """Run the module-level model on one image tensor.

    A leading batch dimension is added before the forward pass, which is
    executed with gradient tracking disabled. The output is squeezed, moved
    to the CPU and returned as a NumPy array.
    """
    batch = image.unsqueeze(0)
    with torch.no_grad():
        raw_output = model(batch)
    squeezed = raw_output.squeeze()
    return squeezed.cpu().numpy()
43
+
44
def log_image_details(image_id, image_filename, mask_filename):
    """Append one row describing a processed image to the CSV log.

    The row holds a running serial number, the current date and time, and the
    three values passed in. A header row is written first whenever the log
    does not exist yet or exists but contains no rows.

    Args:
        image_id: Identifier stored in the 'Image ID' column.
        image_filename: Value stored in the 'Image Filename' column.
        mask_filename: Value stored in the 'Mask Filename' column.
    """
    now = datetime.now()
    date_str = now.strftime('%Y-%m-%d')
    # Named time_str so the local does not shadow the imported `time` module.
    time_str = now.strftime('%H:%M:%S')

    # Count the rows already in the log *before* opening it for append, so
    # the read never observes a half-written file. The header counts as a
    # row, which makes the count equal to the next serial number.
    try:
        with open(CSV_LOG_PATH, mode='r', newline='') as existing:
            row_count = sum(1 for _ in csv.reader(existing))
    except FileNotFoundError:
        row_count = 0

    with open(CSV_LOG_PATH, mode='a', newline='') as file:
        writer = csv.writer(file)
        if row_count == 0:
            # Missing or empty log: start it with the header; first entry is 1.
            writer.writerow(['S.No', 'Date', 'Time', 'Image ID', 'Image Filename', 'Mask Filename'])
            row_count = 1

        writer.writerow([row_count, date_str, time_str, image_id, image_filename, mask_filename])
65
+
66
def overlay_mask(image, mask, alpha=0.5, rgb=(255, 0, 0)):
    """Tint the masked pixels of an image with a solid colour.

    Args:
        image: Image array. A 2-D (single-channel) array is expanded to three
            channels before blending.
        mask: Array whose non-zero entries mark the pixels to tint. Its first
            two dimensions must equal those of ``image``.
        alpha: Weight applied to the colour layer when it is added to the
            image (the image itself keeps weight 1).
        rgb: Three channel values written into the colour layer for every
            masked pixel. An immutable tuple is used as the default so the
            default can never be altered between calls.

    Returns:
        The blended image returned by ``cv2.addWeighted``.

    Raises:
        ValueError: If ``mask`` and ``image`` differ in height or width.
    """
    # Ensure image is 3-channel
    if len(image.shape) == 2:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)

    # Ensure mask is boolean and has the same height/width as the image
    mask = mask.astype(bool)
    if mask.shape[:2] != image.shape[:2]:
        raise ValueError("Mask and image must have the same dimensions")

    # Colour layer: zero everywhere except the masked pixels
    color_mask = np.zeros_like(image)
    color_mask[mask] = rgb

    # Add the weighted colour layer on top of the unmodified image
    return cv2.addWeighted(image, 1, color_mask, alpha, 0)
84
+
85
def reset_state():
    """Return the session to its pre-upload state.

    Marks no file as uploaded, clears the stored image name, mask path and
    transformed image, and removes the current page entry if one is set.
    """
    state = st.session_state
    state['file_uploaded'] = False
    for key in ('filename', 'mask_filename', 'tr_img'):
        state[key] = None
    if 'page' in state:
        del state['page']
92
+
93
def upload_page():
    """Render the upload view: save the uploaded image, predict its mask.

    The uploaded bytes are written into UPLOAD_DIR under a timestamped name.
    Images taller or wider than 650 pixels are split into patches that are
    predicted one by one and merged; smaller images are predicted in a single
    pass. The image name and mask path are stored in ``st.session_state`` for
    result_page, and a 'View result' button switches to that page.
    """
    # Make sure the session keys read further down exist.
    if 'file_uploaded' not in st.session_state:
        st.session_state.file_uploaded = False

    if 'filename' not in st.session_state:
        st.session_state.filename = None

    if 'mask_filename' not in st.session_state:
        st.session_state.mask_filename = None

    image = st.file_uploader('Choose a satellite image', type=['jpg', 'png', 'jpeg', 'tiff', 'tif'])

    if image is not None:
        # Clear previous results before processing this upload.
        # NOTE(review): this branch runs whenever the uploader holds a file,
        # so the image is presumably re-saved and re-predicted on each rerun
        # -- confirm whether that is intended.
        reset_state()
        bytes_data = image.getvalue()

        # The timestamp gives the saved image and its mask matching names.
        timestamp = int(time.time())
        original_filename = image.name
        file_extension = os.path.splitext(original_filename)[1].lower()

        # TIFF uploads keep a .tif name; every other accepted type is stored
        # under a .png name (the bytes themselves are written unchanged).
        if file_extension in ['.tiff', '.tif']:
            filename = f"image_{timestamp}.tif"
        else:
            filename = f"image_{timestamp}.png"

        filepath = os.path.join(UPLOAD_DIR, filename)

        with open(filepath, "wb") as f:
            f.write(bytes_data)

        # TIFF uploads are passed through the 8-bit conversion helper.
        # NOTE(review): presumably the helper rewrites `filepath` in place,
        # since the same path is opened right after -- confirm.
        if file_extension in ['.tiff', '.tif']:
            st.info('Processing GeoTIFF image...')
            convert_gtiff_to_8bit(filepath)
            st.success('GeoTIFF converted to 8-bit image')

        img = Image.open(filepath)
        st.image(img, caption='Uploaded Image', use_column_width=True)
        st.success(f'Image saved as {filename}')

        # Store only the file name; result_page joins it with UPLOAD_DIR.
        st.session_state.filename = filename

        # Convert image to numpy array
        img_array = np.array(img)

        # Images with a height or width above 650 pixels go through the
        # patch-based path; everything else is predicted in one pass.
        if img_array.shape[0] > 650 or img_array.shape[1] > 650:
            # Split image into patches
            split(filepath, patch_size=256)

            # Show a spinner while the patches are being predicted
            with st.spinner('Analyzing...'):
                # Predict on each patch
                for patch_filename in os.listdir(patches_folder):
                    if patch_filename.endswith(".png"):
                        patch_path = os.path.join(patches_folder, patch_filename)
                        patch_img = Image.open(patch_path)
                        patch_tr_img = transforms(patch_img)
                        prediction = predict(patch_tr_img)
                        # Threshold at 0.5 and scale to a 0/255 uint8 mask
                        mask = (prediction > 0.5).astype(np.uint8) * 255
                        mask_filename = f"mask_{patch_filename}"
                        mask_filepath = os.path.join(pred_patches, mask_filename)
                        Image.fromarray(mask).save(mask_filepath)

            # Merge predicted patches
            # NOTE(review): this path is relative to the working directory and
            # is not built from MASK_DIR ("data/generated_masks"), unlike the
            # whole-image branch below. result_page checks this exact string
            # with os.path.exists -- confirm where merge() actually writes.
            merged_mask_filename = f"generated_masks/mask_{timestamp}.png"
            merge(pred_patches, merged_mask_filename, img_array.shape)

            # Remember the merged mask path for result_page
            st.session_state.mask_filename = merged_mask_filename

            # Clean up temporary patch files
            st.info('Cleaning up temporary files...')
            shutil.rmtree(patches_folder)
            shutil.rmtree(pred_patches)
            os.makedirs(patches_folder)  # Recreate empty folders
            os.makedirs(pred_patches)
            st.success('Temporary files cleaned up')
        else:
            # Predict on whole image
            st.session_state.tr_img = transforms(img)
            prediction = predict(st.session_state.tr_img)
            # Threshold at 0.5 and scale to a 0/255 uint8 mask
            mask = (prediction > 0.5).astype(np.uint8) * 255
            mask_filename = f"mask_{timestamp}.png"
            mask_filepath = os.path.join(MASK_DIR, mask_filename)
            Image.fromarray(mask).save(mask_filepath)
            st.session_state.mask_filename = mask_filepath

        st.session_state.file_uploaded = True

    # The button is only offered once an upload has been processed.
    if st.session_state.file_uploaded and st.button('View result'):
        if st.session_state.filename is None:
            st.error("Please upload an image before viewing the result.")
        else:
            st.success('Image analyzed')
            st.session_state.page = 'result'
            st.rerun()
191
+
192
def result_page():
    """Render the result view: original image, predicted mask and area overlay.

    Reads ``filename`` and ``mask_filename`` from ``st.session_state``. When
    either is absent or still ``None`` (the value reset_state stores), an
    error and a 'Back to Upload' button are shown instead of the results.
    """
    st.title('Analysis Result')

    # reset_state() leaves these keys present but set to None, so test the
    # values rather than mere key presence; a None here would otherwise reach
    # os.path.join / os.path.exists and raise TypeError.
    filename = st.session_state.get('filename')
    mask_path = st.session_state.get('mask_filename')

    if filename is None or mask_path is None:
        st.error("No image or mask file found. Please upload and process an image first.")
        if st.button('Back to Upload'):
            reset_state()
            st.rerun()
        return

    col1, col2 = st.columns(2)

    # Display original image
    original_img_path = os.path.join(UPLOAD_DIR, filename)
    original_found = os.path.exists(original_img_path)
    if original_found:
        original_img = Image.open(original_img_path)
        col1.image(original_img, caption='Original Image', use_column_width=True)
    else:
        col1.error(f"Original image file not found: {original_img_path}")

    # Display predicted mask
    mask_found = os.path.exists(mask_path)
    if mask_found:
        mask = Image.open(mask_path)
        col2.image(mask, caption='Predicted Mask', use_column_width=True)
    else:
        col2.error(f"Predicted mask file not found: {mask_path}")

    st.subheader("Overlay with Area of Buildings (sqft)")

    # Display overlayed image
    if original_found and mask_found:
        original_np = cv2.imread(original_img_path)
        mask_np = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)

        if original_np is None or mask_np is None:
            # cv2.imread signals an undecodable file by returning None
            # instead of raising.
            st.error("Image or mask file could not be read for overlay.")
        else:
            # Ensure mask is binary
            _, mask_np = cv2.threshold(mask_np, 127, 255, cv2.THRESH_BINARY)

            # Resize mask to match original image size if necessary.
            # Nearest-neighbour keeps the mask strictly 0/255; interpolating
            # would reintroduce intermediate values after thresholding.
            if original_np.shape[:2] != mask_np.shape[:2]:
                mask_np = cv2.resize(
                    mask_np,
                    (original_np.shape[1], original_np.shape[0]),
                    interpolation=cv2.INTER_NEAREST,
                )

            # Process and overlay image
            overlay_img = process_and_overlay_image(original_np, mask_np, 'output.png')

            # cv2.imread yields BGR channel order; tell Streamlit so the
            # picture is not shown with red and blue swapped.
            st.image(overlay_img, caption='Overlay Image', use_column_width=True, channels='BGR')
    else:
        st.error("Image or mask file not found for overlay.")

    if st.button('Back to Upload'):
        reset_state()
        st.rerun()
244
+
245
def main():
    """Show the app title and render whichever page the session selects.

    The page defaults to 'upload' when none has been chosen yet; any value
    other than 'upload' or 'result' renders nothing below the title.
    """
    st.title('Building area estimation')

    if 'page' not in st.session_state:
        st.session_state['page'] = 'upload'

    # Map each page name to the function that renders it.
    renderers = {'upload': upload_page, 'result': result_page}
    render = renderers.get(st.session_state['page'])
    if render is not None:
        render()
255
+
256
# Start the app only when this file is executed as the main script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()