from abc import ABC, abstractmethod
from enum import Enum
from typing import Tuple, Optional

import tensorflow as tf
from tensorflow.keras.layers import *
from tensorflow.keras.models import *


class BaseUNet(ABC):
    """ Base interface for UNet architectures """

    def __init__(self, model: Model):
        self.model: Model = model

    def get_model(self) -> Model:
        return self.model

    @staticmethod
    @abstractmethod
    def build_model(input_size: Tuple[int, int, int], filters: Tuple, kernels: Tuple):
        pass


class UNet(Enum):
    """ Enum class defining the different architecture types available """

    DEFAULT = 0
    DEFAULT_IMAGENET_EMBEDDING = 1
    RESNET = 3
    RESIDUAL_ATTENTION_UNET_SEPARABLE_CONV = 4

    def build_model(self, input_size: Tuple[int, int, int], filters: Optional[Tuple] = None,
                    kernels: Optional[Tuple] = None) -> BaseUNet:
        # set default filters
        if filters is None:
            filters = (16, 32, 64, 128, 256)
        # set default kernels
        if kernels is None:
            kernels = tuple(3 for _ in range(len(filters)))
        # filters and kernels must be given pairwise
        if len(filters) != len(kernels):
            raise ValueError('Kernel and filter counts have to match.')

        if self == UNet.DEFAULT_IMAGENET_EMBEDDING:
            print('Using default UNet model with imagenet embedding')
            return UNetDefault.build_model(input_size, filters, kernels, use_embedding=True)
        elif self == UNet.RESNET:
            print('Using UNet Resnet model')
            return UNet_resnet.build_model(input_size, filters, kernels)
        elif self == UNet.RESIDUAL_ATTENTION_UNET_SEPARABLE_CONV:
            print('Using UNet Resnet model with attention mechanism and separable convolutions')
            return UNet_ResNet_Attention_SeparableConv.build_model(input_size, filters, kernels)

        print('Using default UNet model')
        return UNetDefault.build_model(input_size, filters, kernels, use_embedding=False)


class Attention(Layer):

    def __init__(self, **kwargs):
        super(Attention, self).__init__(**kwargs)

    def build(self, input_shape):
        # Create trainable weight variables for this layer.
        self.kernel = self.add_weight(name='kernel',
                                      shape=(input_shape[-1], 1),
                                      initializer='glorot_normal',
                                      trainable=True)
        self.bias = self.add_weight(name='bias',
                                    shape=(1,),
                                    initializer='zeros',
                                    trainable=True)
        super(Attention, self).build(input_shape)  # Be sure to call this at the end

    def call(self, x):
        attention = tf.nn.softmax(tf.matmul(x, self.kernel) + self.bias, axis=-1)
        return tf.multiply(x, attention)

    def compute_output_shape(self, input_shape):
        return input_shape


class UNet_ResNet_Attention_SeparableConv(BaseUNet):
    """ UNet architecture with resnet blocks, attention mechanism and separable convolutions """

    @staticmethod
    def build_model(input_size: Tuple[int, int, int], filters: Tuple, kernels: Tuple):
        p0 = Input(shape=input_size)
        conv_outputs = []

        first_layer = SeparableConv2D(filters[0], kernels[0], padding='same')(p0)
        int_layer = first_layer

        # encoder: one residual down block per filter size, storing the skip connections
        for i, f in enumerate(filters):
            int_layer, skip = UNet_ResNet_Attention_SeparableConv.down_block(int_layer, f, kernels[i])
            conv_outputs.append(skip)

        int_layer = UNet_ResNet_Attention_SeparableConv.bottleneck(int_layer, filters[-1], kernels[-1])

        # decoder: walk the filters in reverse and consume the stored skip connections
        conv_outputs = list(reversed(conv_outputs))
        reversed_filter = list(reversed(filters))
        reversed_kernels = list(reversed(kernels))
        for i, f in enumerate(reversed_filter):
            if i + 1 < len(reversed_filter):
                num_filters_next = reversed_filter[i + 1]
                num_kernels_next = reversed_kernels[i + 1]
            else:
                num_filters_next = f
                num_kernels_next = reversed_kernels[i]
            int_layer = UNet_ResNet_Attention_SeparableConv.up_block(int_layer, conv_outputs[i], f,
                                                                     num_filters_next, num_kernels_next)

        int_layer = Attention()(int_layer)
        # concatenate with the first layer
        int_layer = Concatenate()([first_layer, int_layer])
        int_layer = SeparableConv2D(filters[0], kernels[0], padding="same", activation="relu")(int_layer)
        outputs = SeparableConv2D(3, (1, 1), padding="same", activation="sigmoid")(int_layer)

        model = Model(p0, outputs)
        return UNet_ResNet_Attention_SeparableConv(model)

    @staticmethod
    def down_block(x, num_filters: int = 64, kernel: int = 3):
        # down-sample inputs
        # note: Keras convolutions do not allow strides > 1 together with dilation_rate > 1,
        # so the strided convolution keeps the default dilation rate
        x = SeparableConv2D(num_filters, kernel, padding='same', strides=2)(x)

        # inner block
        out = SeparableConv2D(num_filters, kernel, padding='same')(x)
        # out = BatchNormalization()(out)
        out = Activation('relu')(out)
        out = SeparableConv2D(num_filters, kernel, padding='same')(out)

        # merge with the skip connection
        out = Add()([out, x])
        # out = BatchNormalization()(out)
        return Activation('relu')(out), x

    @staticmethod
    def up_block(x, skip, num_filters: int = 64, num_filters_next: int = 64, kernel: int = 3):
        # add U-Net skip connection - before up-sampling
        concat = Concatenate()([x, skip])

        # inner block
        out = SeparableConv2D(num_filters, kernel, padding='same', dilation_rate=2)(concat)
        # out = BatchNormalization()(out)
        out = Activation('relu')(out)
        out = SeparableConv2D(num_filters, kernel, padding='same')(out)

        # merge with the skip connection
        out = Add()([out, x])
        # out = BatchNormalization()(out)
        out = Activation('relu')(out)

        # up-sample
        out = UpSampling2D((2, 2))(out)
        out = SeparableConv2D(num_filters_next, kernel, padding='same')(out)
        # out = BatchNormalization()(out)
        return Activation('relu')(out)

    @staticmethod
    def bottleneck(x, num_filters: int = 64, kernel: int = 3):
        # inner block
        out = SeparableConv2D(num_filters, kernel, padding='same', dilation_rate=2)(x)
        # out = BatchNormalization()(out)
        out = Activation('relu')(out)
        out = SeparableConv2D(num_filters, kernel, padding='same')(out)

        # merge with the skip connection
        out = Add()([out, x])
        # out = BatchNormalization()(out)
        return Activation('relu')(out)


# Class for UNet with Resnet blocks
class UNet_resnet(BaseUNet):
    """ UNet architecture with resnet blocks """

    @staticmethod
    def build_model(input_size: Tuple[int, int, int], filters: Tuple, kernels: Tuple):
        p0 = Input(shape=input_size)
        conv_outputs = []

        first_layer = Conv2D(filters[0], kernels[0], padding='same')(p0)
        int_layer = first_layer

        # encoder: one residual down block per filter size, storing the skip connections
        for i, f in enumerate(filters):
            int_layer, skip = UNet_resnet.down_block(int_layer, f, kernels[i])
            conv_outputs.append(skip)

        int_layer = UNet_resnet.bottleneck(int_layer, filters[-1], kernels[-1])

        # decoder: walk the filters in reverse and consume the stored skip connections
        conv_outputs = list(reversed(conv_outputs))
        reversed_filter = list(reversed(filters))
        reversed_kernels = list(reversed(kernels))
        for i, f in enumerate(reversed_filter):
            if i + 1 < len(reversed_filter):
                num_filters_next = reversed_filter[i + 1]
                num_kernels_next = reversed_kernels[i + 1]
            else:
                num_filters_next = f
                num_kernels_next = reversed_kernels[i]
            int_layer = UNet_resnet.up_block(int_layer, conv_outputs[i], f, num_filters_next, num_kernels_next)
        # concatenate with the first layer
        int_layer = Concatenate()([first_layer, int_layer])
        int_layer = Conv2D(filters[0], kernels[0], padding="same", activation="relu")(int_layer)
        outputs = Conv2D(3, (1, 1), padding="same", activation="sigmoid")(int_layer)

        model = Model(p0, outputs)
        return UNet_resnet(model)

    @staticmethod
    def down_block(x, num_filters: int = 64, kernel: int = 3):
        # down-sample inputs
        x = Conv2D(num_filters, kernel, padding='same', strides=2)(x)

        # inner block
        out = Conv2D(num_filters, kernel, padding='same')(x)
        # out = BatchNormalization()(out)
        out = Activation('relu')(out)
        out = Conv2D(num_filters, kernel, padding='same')(out)

        # merge with the skip connection
        out = Add()([out, x])
        # out = BatchNormalization()(out)
        return Activation('relu')(out), x

    @staticmethod
    def up_block(x, skip, num_filters: int = 64, num_filters_next: int = 64, kernel: int = 3):
        # add U-Net skip connection - before up-sampling
        concat = Concatenate()([x, skip])

        # inner block
        out = Conv2D(num_filters, kernel, padding='same')(concat)
        # out = BatchNormalization()(out)
        out = Activation('relu')(out)
        out = Conv2D(num_filters, kernel, padding='same')(out)

        # merge with the skip connection
        out = Add()([out, x])
        # out = BatchNormalization()(out)
        out = Activation('relu')(out)

        # add U-Net skip connection - before up-sampling
        concat = Concatenate()([out, skip])

        # up-sample
        # out = UpSampling2D((2, 2))(concat)
        out = Conv2DTranspose(num_filters_next, kernel, padding='same', strides=2)(concat)
        out = Conv2D(num_filters_next, kernel, padding='same')(out)
        # out = BatchNormalization()(out)
        return Activation('relu')(out)

    @staticmethod
    def bottleneck(x, filters, kernel: int = 3):
        x = Conv2D(filters, kernel, padding='same', name='bottleneck')(x)
        # x = BatchNormalization()(x)
        return Activation('relu')(x)


class UNetDefault(BaseUNet):
    """
    UNet architecture for image segmentation, based on the following GitHub notebooks:
    https://github.com/nikhilroxtomar/UNet-Segmentation-in-Keras-TensorFlow/blob/master/unet-segmentation.ipynb
    https://github.com/nikhilroxtomar/Polyp-Segmentation-using-UNET-in-TensorFlow-2.0
    """

    @staticmethod
    def build_model(input_size: Tuple[int, int, int], filters: Tuple, kernels: Tuple, use_embedding: bool = True):
        p0 = Input(shape=input_size)

        if use_embedding:
            # frozen MobileNetV2 backbone used as an imagenet feature embedding
            mobilenet_model = tf.keras.applications.MobileNetV2(
                input_shape=input_size, include_top=False, weights='imagenet'
            )
            mobilenet_model.trainable = False
            mn1 = mobilenet_model(p0)
            mn1 = Reshape((16, 16, 320))(mn1)

        conv_outputs = []
        int_layer = p0
        for f in filters:
            conv_output, int_layer = UNetDefault.down_block(int_layer, f)
            conv_outputs.append(conv_output)

        int_layer = UNetDefault.bottleneck(int_layer, filters[-1])

        if use_embedding:
            int_layer = Concatenate()([int_layer, mn1])

        conv_outputs = list(reversed(conv_outputs))
        for i, f in enumerate(reversed(filters)):
            int_layer = UNetDefault.up_block(int_layer, conv_outputs[i], f)

        int_layer = Conv2D(filters[0] // 2, 3, padding="same", activation="relu")(int_layer)
        outputs = Conv2D(3, (1, 1), padding="same", activation="sigmoid")(int_layer)

        model = Model(p0, outputs)
        return UNetDefault(model)

    @staticmethod
    def down_block(x, filters, kernel_size=(3, 3), padding="same", strides=1):
        c = Conv2D(filters, kernel_size, padding=padding, strides=strides, activation="relu")(x)
        # c = BatchNormalization()(c)
        p = MaxPool2D((2, 2), (2, 2))(c)
        return c, p

    @staticmethod
    def up_block(x, skip, filters, kernel_size=(3, 3), padding="same", strides=1):
        us = UpSampling2D((2, 2))(x)
        c = Conv2D(filters, kernel_size, padding=padding, strides=strides, activation="relu")(us)
        # c = BatchNormalization()(c)
        concat = Concatenate()([c, skip])
        c = Conv2D(filters, kernel_size, padding=padding, strides=strides, activation="relu")(concat)
        # c = BatchNormalization()(c)
        return c

    @staticmethod
    def bottleneck(x, filters, kernel_size=(3, 3), padding="same", strides=1):
        c = Conv2D(filters, kernel_size, padding=padding, strides=strides, activation="relu")(x)
        # c = BatchNormalization()(c)
        return c


if __name__ == "__main__":
    filters = (64, 128, 128, 256, 256, 512)
    kernels = (7, 7, 7, 3, 3, 3)
    input_image_size = (256, 256, 3)

    # The resnet variant has to be built through its static factory; instantiating the
    # class directly fails because __init__ expects an already built Model:
    # model = UNet_resnet.build_model(input_size=input_image_size, filters=filters, kernels=kernels)
    # model.get_model().summary()

    model = UNetDefault.build_model(input_size=input_image_size, filters=filters, kernels=kernels)
    model.get_model().summary()
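    # Added usage sketch (not part of the original script): the same architectures can
    # also be built through the UNet enum dispatcher, which falls back to its built-in
    # default filters and kernels when none are supplied.
    enum_unet = UNet.DEFAULT.build_model(input_size=input_image_size)
    enum_unet.get_model().summary()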