diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..ea97f612fc0722c890a4d3d11fb48137946a14ef 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +examples/all.jpg filter=lfs diff=lfs merge=lfs -text +examples/DRIVE.tif filter=lfs diff=lfs merge=lfs -text +examples/hrf.png filter=lfs diff=lfs merge=lfs -text +examples/LES.png filter=lfs diff=lfs merge=lfs -text +examples/tmp_upload.png filter=lfs diff=lfs merge=lfs -text +examples/ukbb.png filter=lfs diff=lfs merge=lfs -text diff --git a/AV/Tools/AVclassifiation.py b/AV/Tools/AVclassifiation.py new file mode 100644 index 0000000000000000000000000000000000000000..979357ccbf918a0579539375d96d5640fb00f450 --- /dev/null +++ b/AV/Tools/AVclassifiation.py @@ -0,0 +1,202 @@ +import cv2 +import numpy as np +import os +# import natsort +import pandas as pd +from skimage.morphology import skeletonize, erosion, square,dilation +from AV.Tools.BinaryPostProcessing import binaryPostProcessing3 +from PIL import Image +from scipy.signal import convolve2d +from collections import OrderedDict +import time +######################################### + + + + +def Skeleton(a_or_v, a_and_v): + th = np.uint8(a_and_v) + # Distance transform for maximum diameter + vessels = th.copy() + dist = cv2.distanceTransform(a_or_v, cv2.DIST_L2, 3) + thinned = np.uint8(skeletonize((vessels / 255))) * 255 + return thinned, dist + + +def cal_crosspoint(vessel): + # Removing bifurcation points by using specially designed kernels + # Can be optimized further! (not the best implementation) + thinned1, dist = Skeleton(vessel, vessel) + thh = thinned1.copy() + thh = thh / 255 + kernel1 = np.array([[1, 1, 1], [1, 10, 1], [1, 1, 1]]) + + th = convolve2d(thh, kernel1, mode="same") + for u in range(th.shape[0]): + for j in range(th.shape[1]): + if th[u, j] >= 13.0: + cv2.circle(vessel, (j, u), 2 * int(dist[u, j]), (0, 0, 0), -1) + # thi = cv2.cvtColor(thi, cv2.COLOR_BGR2GRAY) + return vessel + + +def AVclassifiation(out_path, PredAll1, PredAll2, VesselPredAll, DataSet=0, image_basename=''): + """ + predAll1: predition results of artery + predAll2: predition results of vein + VesselPredAll: predition results of vessel + DataSet: the length of dataset + image_basename: the name of saved mask + """ + + ImgN = DataSet + + for ImgNumber in range(ImgN): + + height, width = PredAll1.shape[2:4] + + VesselProb = VesselPredAll[ImgNumber, 0, :, :] + + ArteryProb = PredAll1[ImgNumber, 0, :, :] + VeinProb = PredAll2[ImgNumber, 0, :, :] + + VesselSeg = (VesselProb >= 0.1) & ((ArteryProb >0.2) | (VeinProb > 0.2)) + # VesselSeg = (VesselProb >= 0.5) & ((ArteryProb >= 0.5) | (VeinProb >= 0.5)) + crossSeg = (VesselProb >= 0.1) & ((ArteryProb >= 0.6) & (VeinProb >= 0.6)) + VesselSeg = binaryPostProcessing3(VesselSeg, removeArea=100, fillArea=20) + + vesselPixels = np.where(VesselSeg > 0) + + ArteryProb2 = np.zeros((height, width)) + VeinProb2 = np.zeros((height, width)) + crossProb2 = np.zeros((height, width)) + image_color = np.zeros((3, height, width), dtype=np.uint8) + for i in range(len(vesselPixels[0])): + row = vesselPixels[0][i] + col = vesselPixels[1][i] + probA = ArteryProb[row, col] + probV = VeinProb[row, col] + #probA,probV = softmax([probA,probV]) + ArteryProb2[row, col] = probA + VeinProb2[row, col] = probV + + test_use_vessel = np.zeros((height, width), 
np.uint8) + ArteryPred2 = ((ArteryProb2 >= 0.2) & (ArteryProb2 >= VeinProb2)) + VeinPred2 = ((VeinProb2 >= 0.2) & (VeinProb2 >= ArteryProb2)) + + ArteryPred2 = binaryPostProcessing3(ArteryPred2, removeArea=100, fillArea=20) + VeinPred2 = binaryPostProcessing3(VeinPred2, removeArea=100, fillArea=20) + + image_color[0, :, :] = ArteryPred2 * 255 + image_color[2, :, :] = VeinPred2 * 255 + image_color = image_color.transpose((1, 2, 0)) + + #Image.fromarray(image_color).save(os.path.join(out_path, f'{image_basename[ImgNumber].split(".")[0]}_ori.png')) + + imgBin_vessel = ArteryPred2 + VeinPred2 + imgBin_vessel[imgBin_vessel[:, :] == 2] = 1 + test_use_vessel = imgBin_vessel.copy() * 255 + + vessel = cal_crosspoint(test_use_vessel) + + contours_vessel, hierarchy_c = cv2.findContours(vessel, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE) + + # inter continuity + for vessel_seg in range(len(contours_vessel)): + C_vessel = np.zeros(vessel.shape, np.uint8) + C_vessel = cv2.drawContours(C_vessel, contours_vessel, vessel_seg, (255, 255, 255), cv2.FILLED) + cli = np.mean(VeinProb2[C_vessel == 255]) / np.mean(ArteryProb2[C_vessel == 255]) + if cli < 1: + image_color[ + (C_vessel[:, :] == 255) & (test_use_vessel[:, :] == 255)] = [255, 0, 0] + else: + image_color[ + (C_vessel[:, :] == 255) & (test_use_vessel[:, :] == 255)] = [0, 0, 255] + loop=0 + while loop<2: + # out vein continuity + vein = image_color[:, :, 2] + contours_vein, hierarchy_b = cv2.findContours(vein, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE) + + vein_size = [] + for z in range(len(contours_vein)): + vein_size.append(contours_vein[z].size) + vein_size = np.sort(np.array(vein_size)) + # image_color_copy = np.uint8(image_color).copy() + for vein_seg in range(len(contours_vein)): + judge_number = min(np.mean(vein_size),500) + # cv2.putText(image_color_copy, str(vein_seg), (int(contours_vein[vein_seg][0][0][0]), int(contours_vein[vein_seg][0][0][1])), 3, 1, + # color=(255, 0, 0), thickness=2) + if contours_vein[vein_seg].size < judge_number: + C_vein = np.zeros(vessel.shape, np.uint8) + C_vein = cv2.drawContours(C_vein, contours_vein, vein_seg, (255, 255, 255), cv2.FILLED) + max_diameter = np.max(Skeleton(C_vein, C_vein)[1]) + + image_color_copy_vein = image_color[:, :, 2].copy() + image_color_copy_arter = image_color[:, :, 0].copy() + # a_ori = cv2.drawContours(a_ori, contours_b, k, (0, 0, 0), cv2.FILLED) + image_color_copy_vein = cv2.drawContours(image_color_copy_vein, contours_vein, vein_seg, + (0, 0, 0), + cv2.FILLED) + # image_color[(C_cross[:, :] == 255) & (image_color[:, :, 1] == 255)] = [255, 0, 0] + kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, ( + 4 * int(np.ceil(max_diameter)), 4 * int(np.ceil(max_diameter)))) + C_vein_dilate = cv2.dilate(C_vein, kernel, iterations=1) + # cv2.imwrite(path_out_3, C_vein_dilate) + C_vein_dilate_judge = np.zeros(vessel.shape, np.uint8) + C_vein_dilate_judge[ + (C_vein_dilate[:, :] == 255) & (image_color_copy_vein == 255)] = 1 + C_arter_dilate_judge = np.zeros(vessel.shape, np.uint8) + C_arter_dilate_judge[ + (C_vein_dilate[:, :] == 255) & (image_color_copy_arter == 255)] = 1 + if (len(np.unique(C_vein_dilate_judge)) == 1) & ( + len(np.unique(C_arter_dilate_judge)) != 1) & (np.mean(VeinProb2[C_vein == 255]) < 0.6): + image_color[ + (C_vein[:, :] == 255) & (image_color[:, :, 2] == 255)] = [255, 0, + 0] + + # out artery continuity + arter = image_color[:, :, 0] + contours_arter, hierarchy_a = cv2.findContours(arter, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE) + arter_size = [] + for z in range(len(contours_arter)): + 
arter_size.append(contours_arter[z].size) + arter_size = np.sort(np.array(arter_size)) + for arter_seg in range(len(contours_arter)): + judge_number = min(np.mean(arter_size),500) + + if contours_arter[arter_seg].size < judge_number: + + C_arter = np.zeros(vessel.shape, np.uint8) + C_arter = cv2.drawContours(C_arter, contours_arter, arter_seg, (255, 255, 255), cv2.FILLED) + max_diameter = np.max(Skeleton(C_arter, test_use_vessel)[1]) + + image_color_copy_vein = image_color[:, :, 2].copy() + image_color_copy_arter = image_color[:, :, 0].copy() + kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, ( + 4 * int(np.ceil(max_diameter)), 4 * int(np.ceil(max_diameter)))) + image_color_copy_arter = cv2.drawContours(image_color_copy_arter, contours_arter, arter_seg, + (0, 0, 0), + cv2.FILLED) + C_arter_dilate = cv2.dilate(C_arter, kernel, iterations=1) + # image_color[(C_cross[:, :] == 255) & (image_color[:, :, 1] == 255)] = [255, 0, 0] + C_arter_dilate_judge = np.zeros(arter.shape, np.uint8) + C_arter_dilate_judge[ + (C_arter_dilate[:, :] == 255) & (image_color_copy_arter[:, :] == 255)] = 1 + C_vein_dilate_judge = np.zeros(arter.shape, np.uint8) + C_vein_dilate_judge[ + (C_arter_dilate[:, :] == 255) & (image_color_copy_vein[:, :] == 255)] = 1 + + if (len(np.unique(C_arter_dilate_judge)) == 1) & ( + len(np.unique(C_vein_dilate_judge)) != 1) & (np.mean(ArteryProb2[C_arter == 255]) < 0.6): + image_color[ + (C_arter[:, :] == 255) & (image_color[:, :, 0] == 255)] = [0, + 0, + 255] + loop=loop+1 + + # image_basename = os.path.basename(image_basename) + # Image.fromarray(image_color).save(os.path.join(out_path, f'{image_basename.split(".")[0]}.png')) + # Image.fromarray(np.uint8(VesselProb*255)).save(os.path.join(out_path, f'{image_basename.split(".")[0]}_vessel.png')) + return image_color + diff --git a/AV/Tools/AVclassifiationMetrics.py b/AV/Tools/AVclassifiationMetrics.py new file mode 100644 index 0000000000000000000000000000000000000000..e94d6c8f6508af385234cd7f80ed41722de4fa71 --- /dev/null +++ b/AV/Tools/AVclassifiationMetrics.py @@ -0,0 +1,465 @@ +import cv2 +import numpy as np +import matplotlib.pyplot as plt +import os +# import natsort +import pandas as pd +from skimage import morphology +from sklearn import metrics +from Tools.BinaryPostProcessing import binaryPostProcessing3 +from PIL import Image +from scipy.signal import convolve2d +import time + +######################################### +def softmax(x): + e_x = np.exp(x - np.max(x)) + return e_x / e_x.sum() + + +def Skeleton(a_or_v, a_and_v): + th = np.uint8(a_and_v) + # Distance transform for maximum diameter + vessels = th.copy() + dist = cv2.distanceTransform(a_or_v, cv2.DIST_L2, 3) + thinned = np.uint8(morphology.skeletonize((vessels / 255))) * 255 + return thinned, dist + + +def cal_crosspoint(vessel): + # Removing bifurcation points by using specially designed kernels + # Can be optimized further! 
(not the best implementation) + thinned1, dist = Skeleton(vessel, vessel) + thh = thinned1.copy() + thh = thh / 255 + kernel1 = np.array([[1, 1, 1], [1, 10, 1], [1, 1, 1]]) + + th = convolve2d(thh, kernel1, mode="same") + for u in range(th.shape[0]): + for j in range(th.shape[1]): + if th[u, j] >= 13.0: + cv2.circle(vessel, (j, u), 3 * int(dist[u, j]), (0, 0, 0), -1) + # thi = cv2.cvtColor(thi, cv2.COLOR_BGR2GRAY) + return vessel + + + +def AVclassifiation_pos_ve(out_path, PredAll1, PredAll2, VesselPredAll, DataSet=0, image_basename=''): + """ + predAll1: predition results of artery + predAll2: predition results of vein + VesselPredAll: predition results of vessel + DataSet: the length of dataset + image_basename: the name of saved mask + """ + + ImgN = DataSet + + for ImgNumber in range(ImgN): + + height, width = PredAll1.shape[2:4] + + VesselProb = VesselPredAll[ImgNumber, 0, :, :] + + ArteryProb = PredAll1[ImgNumber, 0, :, :] + VeinProb = PredAll2[ImgNumber, 0, :, :] + + VesselSeg = (VesselProb >= 0.1) & ((ArteryProb >= 0.2) | (VeinProb >= 0.2)) + # VesselSeg = (VesselProb >= 0.5) & ((ArteryProb >= 0.5) | (VeinProb >= 0.5)) + crossSeg = (VesselProb >= 0.1) & ((ArteryProb >= 0.6) & (VeinProb >= 0.6)) + VesselSeg = binaryPostProcessing3(VesselSeg, removeArea=100, fillArea=20) + + vesselPixels = np.where(VesselSeg > 0) + + ArteryProb2 = np.zeros((height, width)) + VeinProb2 = np.zeros((height, width)) + crossProb2 = np.zeros((height, width)) + image_color = np.zeros((3, height, width), dtype=np.uint8) + for i in range(len(vesselPixels[0])): + row = vesselPixels[0][i] + col = vesselPixels[1][i] + probA = ArteryProb[row, col] + probV = VeinProb[row, col] + ArteryProb2[row, col] = probA + VeinProb2[row, col] = probV + + test_use_vessel = np.zeros((height, width), np.uint8) + ArteryPred2 = ((ArteryProb2 >= 0.2) & (ArteryProb2 > VeinProb2)) + VeinPred2 = ((VeinProb2 >= 0.2) & (VeinProb2 > ArteryProb2)) + + ArteryPred2 = binaryPostProcessing3(ArteryPred2, removeArea=100, fillArea=20) + VeinPred2 = binaryPostProcessing3(VeinPred2, removeArea=100, fillArea=20) + + image_color[0, :, :] = ArteryPred2 * 255 + image_color[2, :, :] = VeinPred2 * 255 + image_color = image_color.transpose((1, 2, 0)) + + imgBin_vessel = ArteryPred2 + VeinPred2 + imgBin_vessel[imgBin_vessel[:, :] == 2] = 1 + test_use_vessel = imgBin_vessel.copy() * 255 + + vessel = cal_crosspoint(test_use_vessel) + + contours_vessel, hierarchy_c = cv2.findContours(vessel, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE) + + # inter continuity + for vessel_seg in range(len(contours_vessel)): + C_vessel = np.zeros(vessel.shape, np.uint8) + C_vessel = cv2.drawContours(C_vessel, contours_vessel, vessel_seg, (255, 255, 255), cv2.FILLED) + cli = np.mean(VeinProb2[C_vessel == 255]) / np.mean(ArteryProb2[C_vessel == 255]) + if cli < 1: + image_color[ + (C_vessel[:, :] == 255) & (test_use_vessel[:, :] == 255)] = [255, 0, 0] + else: + image_color[ + (C_vessel[:, :] == 255) & (test_use_vessel[:, :] == 255)] = [0, 0, 255] + + + Image.fromarray(image_color).save(os.path.join(out_path, f'{image_basename[ImgNumber].split(".")[0]}.png')) + + +def AVclassifiation(out_path, PredAll1, PredAll2, VesselPredAll, DataSet=0, image_basename=''): + """ + predAll1: predition results of artery + predAll2: predition results of vein + VesselPredAll: predition results of vessel + DataSet: the length of dataset + image_basename: the name of saved mask + """ + + ImgN = DataSet + + for ImgNumber in range(ImgN): + + height, width = PredAll1.shape[2:4] + + VesselProb = 
VesselPredAll[ImgNumber, 0, :, :] + + ArteryProb = PredAll1[ImgNumber, 0, :, :] + VeinProb = PredAll2[ImgNumber, 0, :, :] + + VesselSeg = (VesselProb >= 0.1) & ((ArteryProb >= 0.2) | (VeinProb >= 0.2)) + # VesselSeg = (VesselProb >= 0.5) & ((ArteryProb >= 0.5) | (VeinProb >= 0.5)) + crossSeg = (VesselProb >= 0.1) & ((ArteryProb >= 0.6) & (VeinProb >= 0.6)) + VesselSeg = binaryPostProcessing3(VesselSeg, removeArea=100, fillArea=20) + + vesselPixels = np.where(VesselSeg > 0) + + ArteryProb2 = np.zeros((height, width)) + VeinProb2 = np.zeros((height, width)) + crossProb2 = np.zeros((height, width)) + image_color = np.zeros((3, height, width), dtype=np.uint8) + for i in range(len(vesselPixels[0])): + row = vesselPixels[0][i] + col = vesselPixels[1][i] + probA = ArteryProb[row, col] + probV = VeinProb[row, col] + ArteryProb2[row, col] = probA + VeinProb2[row, col] = probV + + test_use_vessel = np.zeros((height, width), np.uint8) + ArteryPred2 = ((ArteryProb2 >= 0.2) & (ArteryProb2 > VeinProb2)) + VeinPred2 = ((VeinProb2 >= 0.2) & (VeinProb2 > ArteryProb2)) + + ArteryPred2 = binaryPostProcessing3(ArteryPred2, removeArea=100, fillArea=20) + VeinPred2 = binaryPostProcessing3(VeinPred2, removeArea=100, fillArea=20) + + image_color[0, :, :] = ArteryPred2 * 255 + image_color[2, :, :] = VeinPred2 * 255 + image_color = image_color.transpose((1, 2, 0)) + + imgBin_vessel = ArteryPred2 + VeinPred2 + imgBin_vessel[imgBin_vessel[:, :] == 2] = 1 + test_use_vessel = imgBin_vessel.copy() * 255 + + vessel = cal_crosspoint(test_use_vessel) + + contours_vessel, hierarchy_c = cv2.findContours(vessel, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE) + + # inter continuity + for vessel_seg in range(len(contours_vessel)): + C_vessel = np.zeros(vessel.shape, np.uint8) + C_vessel = cv2.drawContours(C_vessel, contours_vessel, vessel_seg, (255, 255, 255), cv2.FILLED) + cli = np.mean(VeinProb2[C_vessel == 255]) / np.mean(ArteryProb2[C_vessel == 255]) + if cli < 1: + image_color[ + (C_vessel[:, :] == 255) & (test_use_vessel[:, :] == 255)] = [255, 0, 0] + else: + image_color[ + (C_vessel[:, :] == 255) & (test_use_vessel[:, :] == 255)] = [0, 0, 255] + + # out vein continuity + vein = image_color[:, :, 2] + contours_vein, hierarchy_b = cv2.findContours(vein, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE) + + vein_size = [] + for z in range(len(contours_vein)): + vein_size.append(contours_vein[z].size) + vein_size = np.sort(np.array(vein_size)) + # image_color_copy = np.uint8(image_color).copy() + for vein_seg in range(len(contours_vein)): + judge_number = min(np.mean(vein_size),500) + # cv2.putText(image_color_copy, str(vein_seg), (int(contours_vein[vein_seg][0][0][0]), int(contours_vein[vein_seg][0][0][1])), 3, 1, + # color=(255, 0, 0), thickness=2) + if contours_vein[vein_seg].size < judge_number: + C_vein = np.zeros(vessel.shape, np.uint8) + C_vein = cv2.drawContours(C_vein, contours_vein, vein_seg, (255, 255, 255), cv2.FILLED) + max_diameter = np.max(Skeleton(C_vein, C_vein)[1]) + + image_color_copy_vein = image_color[:, :, 2].copy() + image_color_copy_arter = image_color[:, :, 0].copy() + # a_ori = cv2.drawContours(a_ori, contours_b, k, (0, 0, 0), cv2.FILLED) + image_color_copy_vein = cv2.drawContours(image_color_copy_vein, contours_vein, vein_seg, + (0, 0, 0), + cv2.FILLED) + # image_color[(C_cross[:, :] == 255) & (image_color[:, :, 1] == 255)] = [255, 0, 0] + kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, ( + 4 * int(np.ceil(max_diameter)), 4 * int(np.ceil(max_diameter)))) + C_vein_dilate = cv2.dilate(C_vein, kernel, iterations=1) 
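+                # The dilated fragment neighbourhood is compared against the two colour
+                # channels: C_vein_dilate_judge marks where it touches vein pixels other
+                # than this fragment (the fragment itself is erased from the copy), and
+                # C_arter_dilate_judge where it touches artery pixels. An isolated
+                # fragment (no neighbouring vein, some neighbouring artery) whose mean
+                # vein probability is below 0.5 is then reassigned to the artery class.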
+                # cv2.imwrite(path_out_3, C_vein_dilate)
+                C_vein_dilate_judge = np.zeros(vessel.shape, np.uint8)
+                C_vein_dilate_judge[
+                    (C_vein_dilate[:, :] == 255) & (image_color_copy_vein == 255)] = 1
+                C_arter_dilate_judge = np.zeros(vessel.shape, np.uint8)
+                C_arter_dilate_judge[
+                    (C_vein_dilate[:, :] == 255) & (image_color_copy_arter == 255)] = 1
+                if (len(np.unique(C_vein_dilate_judge)) == 1) & (
+                        len(np.unique(C_arter_dilate_judge)) != 1) & (np.mean(VeinProb2[C_vein == 255]) < 0.5):
+                    image_color[
+                        (C_vein[:, :] == 255) & (image_color[:, :, 2] == 255)] = [255, 0, 0]
+
+        # out artery continuity
+        arter = image_color[:, :, 0]
+        contours_arter, hierarchy_a = cv2.findContours(arter, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
+        arter_size = []
+        for z in range(len(contours_arter)):
+            arter_size.append(contours_arter[z].size)
+        arter_size = np.sort(np.array(arter_size))
+        for arter_seg in range(len(contours_arter)):
+            judge_number = min(np.mean(arter_size), 500)
+
+            if contours_arter[arter_seg].size < judge_number:
+
+                C_arter = np.zeros(vessel.shape, np.uint8)
+                C_arter = cv2.drawContours(C_arter, contours_arter, arter_seg, (255, 255, 255), cv2.FILLED)
+                max_diameter = np.max(Skeleton(C_arter, test_use_vessel)[1])
+
+                image_color_copy_vein = image_color[:, :, 2].copy()
+                image_color_copy_arter = image_color[:, :, 0].copy()
+                kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (
+                    4 * int(np.ceil(max_diameter)), 4 * int(np.ceil(max_diameter))))
+                image_color_copy_arter = cv2.drawContours(image_color_copy_arter, contours_arter, arter_seg,
+                                                          (0, 0, 0),
+                                                          cv2.FILLED)
+                C_arter_dilate = cv2.dilate(C_arter, kernel, iterations=1)
+                # image_color[(C_cross[:, :] == 255) & (image_color[:, :, 1] == 255)] = [255, 0, 0]
+                C_arter_dilate_judge = np.zeros(arter.shape, np.uint8)
+                C_arter_dilate_judge[
+                    (C_arter_dilate[:, :] == 255) & (image_color_copy_arter[:, :] == 255)] = 1
+                C_vein_dilate_judge = np.zeros(arter.shape, np.uint8)
+                C_vein_dilate_judge[
+                    (C_arter_dilate[:, :] == 255) & (image_color_copy_vein[:, :] == 255)] = 1
+
+                # symmetric to the vein branch: an isolated artery fragment with low
+                # mean artery probability is reassigned to the vein class
+                if (len(np.unique(C_arter_dilate_judge)) == 1) & (
+                        len(np.unique(C_vein_dilate_judge)) != 1) & (np.mean(ArteryProb2[C_arter == 255]) < 0.5):
+                    image_color[
+                        (C_arter[:, :] == 255) & (image_color[:, :, 0] == 255)] = [0, 0, 255]
+
+        Image.fromarray(image_color).save(os.path.join(out_path, f'{image_basename[ImgNumber].split(".")[0]}.png'))
+
+
+
+def AVclassifiationMetrics_skeletonPixles(PredAll1,PredAll2,VesselPredAll,LabelAll1,LabelAll2,LabelVesselAll,DataSet=0, onlyMeasureSkeleton=False, strict_mode=True):
+
+    """
+    predAll1: prediction results of artery
+    predAll2: prediction results of vein
+    VesselPredAll: prediction results of vessel
+    LabelAll1: label of artery
+    LabelAll2: label of vein
+    LabelVesselAll: label of vessel
+    DataSet: the length of dataset
+    onlyMeasureSkeleton: if True, report only the skeleton-level metrics
+    strict_mode: if True, evaluate on the ground-truth vessel mask rather than the predicted one
+    """
+
+
+    ImgN = DataSet
+
+    senList = []
+    specList = []
+    accList = []
+    f1List = []
+    ioulist = []
+    diceList = []
+
+
+
+    senList_sk = []
+    specList_sk = []
+    accList_sk = []
+    f1List_sk = []
+    ioulist_sk = []
+    diceList_sk = []
+
+    bad_case_count = 0
+    bad_case_index = []
+
+    for ImgNumber in range(ImgN):
+
+        height, width = PredAll1.shape[2:4]
+
+
+        VesselProb = VesselPredAll[ImgNumber, 0, :, :]
+        VesselLabel = LabelVesselAll[ImgNumber, 0, :, :]
+
+
+        ArteryLabel = LabelAll1[ImgNumber, 0, :, :]
+        VeinLabel = LabelAll2[ImgNumber, 0, :, :]
+
+        ArteryProb = PredAll1[ImgNumber, 0, :, :]
+        VeinProb = PredAll2[ImgNumber, 0, :, :]
+
+        if strict_mode:
+            VesselSeg = VesselLabel
+        else:
+            VesselSeg = (VesselProb >= 0.1) & ((ArteryProb >= 0.2) | (VeinProb >= 0.2))
+            VesselSeg = binaryPostProcessing3(VesselSeg, removeArea=100, fillArea=20)
+
+        vesselPixels = np.where(VesselSeg > 0)
+
+        ArteryProb2 = np.zeros((height, width))
+        VeinProb2 = np.zeros((height, width))
+
+        for i in range(len(vesselPixels[0])):
+            row = vesselPixels[0][i]
+            col = vesselPixels[1][i]
+            probA = ArteryProb[row, col]
+            probV = VeinProb[row, col]
+            ArteryProb2[row, col] = probA
+            VeinProb2[row, col] = probV
+
+
+        ArteryLabelImg2 = ArteryLabel.copy()
+        VeinLabelImg2 = VeinLabel.copy()
+        ArteryLabelImg2[VesselSeg == 0] = 0
+        VeinLabelImg2[VesselSeg == 0] = 0
+        ArteryVeinLabelImg = np.zeros((height, width, 3), np.uint8)
+        ArteryVeinLabelImg[ArteryLabelImg2 > 0] = (255, 0, 0)
+        ArteryVeinLabelImg[VeinLabelImg2 > 0] = (0, 0, 255)
+        ArteryVeinLabelCommon = np.bitwise_and(ArteryLabelImg2 > 0, VeinLabelImg2 > 0)
+
+        if strict_mode:
+            ArteryPred2 = ArteryProb2 > 0.5
+            VeinPred2 = VeinProb2 >= 0.5
+        else:
+            ArteryPred2 = (ArteryProb2 > 0.2) & (ArteryProb2 > VeinProb2)
+            VeinPred2 = (VeinProb2 >= 0.2) & (ArteryProb2 <= VeinProb2)
+
+        TPimg = np.bitwise_and(ArteryPred2 > 0, ArteryLabelImg2 > 0)  # ground truth artery, predicted artery
+        TNimg = np.bitwise_and(VeinPred2 > 0, VeinLabelImg2 > 0)  # ground truth vein, predicted vein
+        FPimg = np.bitwise_and(ArteryPred2 > 0, VeinLabelImg2 > 0)  # ground truth vein, predicted artery
+        FPimg = np.bitwise_and(FPimg, np.bitwise_not(ArteryVeinLabelCommon))  # ... and outside the artery-vein overlap region
+
+        FNimg = np.bitwise_and(VeinPred2 > 0, ArteryLabelImg2 > 0)  # ground truth artery, predicted vein
+        FNimg = np.bitwise_and(FNimg, np.bitwise_not(ArteryVeinLabelCommon))  # ... and outside the artery-vein overlap region
+
+
+        if not onlyMeasureSkeleton:
+            TPa = np.count_nonzero(TPimg)
+            TNa = np.count_nonzero(TNimg)
+            FPa = np.count_nonzero(FPimg)
+            FNa = np.count_nonzero(FNimg)
+
+            sensitivity = TPa/(TPa+FNa)
+            specificity = TNa/(TNa + FPa)
+            acc = (TPa + TNa) /(TPa + TNa + FPa + FNa)
+            f1 = 2*TPa/(2*TPa + FPa + FNa)
+            dice = 2*TPa/(2*TPa + FPa + FNa)
+            iou = TPa/(TPa + FPa + FNa)
+            #print('Pixel-wise Metrics', acc, sensitivity, specificity)
+
+            senList.append(sensitivity)
+            specList.append(specificity)
+            accList.append(acc)
+            f1List.append(f1)
+            diceList.append(dice)
+            ioulist.append(iou)
+            # print('Avg Per:', np.mean(accList), np.mean(senList), np.mean(specList))
+
+        ##################################################################################################
+        """Skeleton Performance Measurement"""
+        Skeleton = np.uint8(morphology.skeletonize(VesselSeg))
+        #np.save('./tmpfile/tmp_skeleton'+str(ImgNumber)+'.npy',Skeleton)
+
+        ArterySkeletonLabel = cv2.bitwise_and(ArteryLabelImg2, ArteryLabelImg2, mask=Skeleton)
+        VeinSkeletonLabel = cv2.bitwise_and(VeinLabelImg2, VeinLabelImg2, mask=Skeleton)
+
+        ArterySkeletonPred = cv2.bitwise_and(ArteryPred2, ArteryPred2, mask=Skeleton)
+        VeinSkeletonPred = cv2.bitwise_and(VeinPred2, VeinPred2, mask=Skeleton)
+
+
+        skeletonPixles = np.where(Skeleton > 0)
+
+        TPa_sk = 0
+        TNa_sk = 0
+        FPa_sk = 0
+        FNa_sk = 0
+        for i in range(len(skeletonPixles[0])):
+            row = skeletonPixles[0][i]
+            col = skeletonPixles[1][i]
+            if ArterySkeletonLabel[row, col] == 1 and ArterySkeletonPred[row, col] == 1:
+                TPa_sk = TPa_sk + 1
+
+            elif VeinSkeletonLabel[row, col] == 1 and VeinSkeletonPred[row, col] == 1:
+                TNa_sk = TNa_sk + 1
+
+            elif ArterySkeletonLabel[row, col] == 1 and VeinSkeletonPred[row, col] == 1\
+                    and ArteryVeinLabelCommon[row, col] == 0:
+                FNa_sk = FNa_sk + 1
+
+            elif VeinSkeletonLabel[row, col] == 1 and ArterySkeletonPred[row, col] == 1\
+                    and ArteryVeinLabelCommon[row, col] == 0:
+                FPa_sk = FPa_sk + 1
+
+            else:
+                pass
+
+        if (TPa_sk+FNa_sk) == 0 and
(TNa_sk + FPa_sk) == 0 and (TPa_sk + TNa_sk + FPa_sk + FNa_sk) == 0:
+            bad_case_count += 1
+            bad_case_index.append(ImgNumber)
+            continue  # no skeleton pixels were matched; every ratio below would divide by zero
+        sensitivity_sk = TPa_sk/(TPa_sk+FNa_sk)
+        specificity_sk = TNa_sk/(TNa_sk + FPa_sk)
+        acc_sk = (TPa_sk + TNa_sk) /(TPa_sk + TNa_sk + FPa_sk + FNa_sk)
+        f1_sk = 2*TPa_sk/(2*TPa_sk + FPa_sk + FNa_sk)
+        dice_sk = 2*TPa_sk/(2*TPa_sk + FPa_sk + FNa_sk)
+        iou_sk = TPa_sk/(TPa_sk + FPa_sk + FNa_sk)
+
+        senList_sk.append(sensitivity_sk)
+        specList_sk.append(specificity_sk)
+        accList_sk.append(acc_sk)
+        f1List_sk.append(f1_sk)
+        diceList_sk.append(dice_sk)
+        ioulist_sk.append(iou_sk)
+        #print('Skeletonal Metrics', acc_sk, sensitivity_sk, specificity_sk)
+    if onlyMeasureSkeleton:
+        print('Avg Skeleton Performance:', np.mean(accList_sk), np.mean(senList_sk), np.mean(specList_sk))
+        return np.mean(accList_sk), np.mean(specList_sk), np.mean(senList_sk), np.mean(f1List_sk), np.mean(diceList_sk), np.mean(ioulist_sk), bad_case_index
+    else:
+        print('Avg Pixel-wise Performance:', np.mean(accList), np.mean(senList), np.mean(specList))
+        return np.mean(accList), np.mean(specList), np.mean(senList), np.mean(f1List), np.mean(diceList), np.mean(ioulist)
+
+
+
+if __name__ == '__main__':
+
+
+    pro_path = r'F:\dw\RIP-AV\AV\log\DRIVE\running_result\ProMap_testset.npy'
+    ps = np.load(pro_path)
+    AVclassifiation(r'./', ps[:, 0:1, :, :], ps[:, 1:2, :, :], ps[:, 2:, :, :], DataSet=ps.shape[0], image_basename=[str(i)+'.png' for i in range(20)])
\ No newline at end of file
diff --git a/AV/Tools/BGR2RGB.py b/AV/Tools/BGR2RGB.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8fd090a2caf089ccca9b2a03654d9b465a3848c
--- /dev/null
+++ b/AV/Tools/BGR2RGB.py
@@ -0,0 +1,25 @@
+
+
+import cv2
+
+def BGR2RGB(Image):
+    """
+    Convert an OpenCV BGR image to RGB.
+    :param Image:
+    :return: RGBImage
+    """
+
+    #the input image is BGR image from OpenCV
+    RGBImage = cv2.cvtColor(Image, cv2.COLOR_BGR2RGB)
+    return RGBImage
+
+
+def RGB2BGR(Image):
+    """
+    Convert an RGB image to OpenCV BGR.
+    :param Image:
+    :return: BGRImage
+    """
+    #the input image is RGB image
+    BGRImage = cv2.cvtColor(Image, cv2.COLOR_RGB2BGR)
+    return BGRImage
\ No newline at end of file
diff --git a/AV/Tools/BinaryPostProcessing.py b/AV/Tools/BinaryPostProcessing.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a5f0e7352e44d03ad368026d879517bb783a3b2
--- /dev/null
+++ b/AV/Tools/BinaryPostProcessing.py
@@ -0,0 +1,108 @@
+
+
+import numpy as np
+from skimage import morphology, measure
+from AV.Tools.Remove_small_holes import remove_small_holes
+import scipy.ndimage.morphology as scipyMorphology
+
+
+def binaryPostProcessing(BinaryImage, removeArea):
+    """
+    Post process the binary segmentation
+    :param BinaryImage:
+    :param removeArea:
+    :return: Img_BW
+    """
+
+    BinaryImage[BinaryImage > 0] = 1
+
+    ###9s
+    # Img_BW = pymorph.binary(BinaryImage)
+    # Img_BW = pymorph.areaopen(Img_BW, removeArea)
+    # Img_BW = pymorph.areaclose(Img_BW, 50)
+    # Img_BW = np.uint8(Img_BW)
+
+    ###2.5 s
+    # Img_BW = np.uint8(BinaryImage)
+    # Img_BW = ITK_LabelImage(Img_BW, removeArea)
+    # Img_BW[Img_BW >0] = 1
+
+    Img_BW = BinaryImage.copy()
+    BinaryImage_Label = measure.label(Img_BW)
+    for i, region in enumerate(measure.regionprops(BinaryImage_Label)):
+        if region.area < removeArea:
+            Img_BW[BinaryImage_Label == i + 1] = 0
+        else:
+            pass
+
+    Img_BW = morphology.binary_closing(Img_BW, morphology.disk(3))
+    Img_BW = remove_small_holes(Img_BW, 50)
+    Img_BW = np.uint8(Img_BW)
+
+    return Img_BW
+
+
+################Three parameters
+def binaryPostProcessing3(BinaryImage, removeArea, fillArea):
+    """
Post process the binary image + :param BinaryImage: + :param removeArea: + :param fillArea: + :return: Img_BW + """ + + BinaryImage[BinaryImage>0]=1 + + ####takes 0.9s, result is good + Img_BW = BinaryImage.copy() + BinaryImage_Label = measure.label(Img_BW) + for i, region in enumerate(measure.regionprops(BinaryImage_Label)): + if region.area < removeArea: + Img_BW[BinaryImage_Label == i + 1] = 0 + else: + pass + + # ####takes 0.01s, result is bad + # temptime = time.time() + # Img_BW = morphology.remove_small_objects(BinaryImage, removeArea) + # print "binaryPostProcessing3, ITK_LabelImage time:", time.time() - temptime + + + Img_BW = morphology.binary_closing(Img_BW, morphology.square(3)) + # Img_BW = remove_small_holes(Img_BW, fillArea) + + Img_BW_filled = scipyMorphology.binary_fill_holes(Img_BW) + Img_BW_dif = np.uint8(Img_BW_filled) - np.uint8(Img_BW) + Img_BW_difLabel = measure.label(Img_BW_dif) + FilledImg = np.zeros(Img_BW.shape) + for i, region in enumerate(measure.regionprops(Img_BW_difLabel)): + if region.area < fillArea: + FilledImg[Img_BW_difLabel == i + 1] = 1 + else: + pass + Img_BW[FilledImg > 0] = 1 + + Img_BW = np.uint8(Img_BW) + return Img_BW + + +def removeSmallBLobs(BinaryImage, removeArea): + """ + Post process the binary image + :param BinaryImage: + :param removeArea: + """ + + BinaryImage[BinaryImage>0]=1 + + ####takes 0.9s, result is good + Img_BW = BinaryImage.copy() + BinaryImage_Label = measure.label(Img_BW) + for i, region in enumerate(measure.regionprops(BinaryImage_Label)): + if region.area < removeArea: + Img_BW[BinaryImage_Label == i + 1] = 0 + else: + pass + return np.uint8(Img_BW) + diff --git a/AV/Tools/FakePad.py b/AV/Tools/FakePad.py new file mode 100644 index 0000000000000000000000000000000000000000..dc08748a7e1f3ac340cb25966ac747428473356a --- /dev/null +++ b/AV/Tools/FakePad.py @@ -0,0 +1,115 @@ + + +from __future__ import division + +import cv2 +import numpy as np +from skimage import morphology +np.seterr(divide='ignore', invalid='ignore') + +"""This is the profiled code, very fast, takes 0.25s""" +def fakePad(Image, Mask, iterations=50): + """ + add an extra padding around the front mask + :param Image: + :param Mask: + :param iterations: + :return: DilatedImg + """ + + if len(Image.shape) == 3: ##for RGB Image + """for RGB Images""" + + Mask0 = Mask.copy() + height, width = Mask0.shape[:2] + Mask0[0, :] = 0 # np.zeros(width) + Mask0[-1, :] = 0 # np.zeros(width) + Mask0[:, 0] = 0 # np.zeros(height) + Mask0[:, -1] = 0 # np.zeros(height) + + # Erodes the mask to avoid weird region near the border. 
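+        # After the erosion, the loop below repeatedly dilates the mask one pixel
+        # at a time (diamond structuring element) and fills each newly added pixel
+        # with the mean of its already-filled 3x3 neighbours, so the fake padding
+        # fades smoothly outward from the field of view.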
+ structureElement1 = morphology.disk(5) + Mask0 = cv2.morphologyEx(Mask0, cv2.MORPH_ERODE, structureElement1, iterations=1) + + # DilatedImg = Img_green_reverse * Mask + DilatedImg = cv2.bitwise_and(Image, Image, mask=Mask0) + OldMask = Mask0.copy() + + filter = np.ones((3, 3)) + filterRows, filterCols = np.where(filter > 0) + filterRows = filterRows - 1 + filterCols = filterCols - 1 + + structureElement2 = morphology.diamond(1) + for i in range(0, iterations): + NewMask = cv2.morphologyEx(OldMask, cv2.MORPH_DILATE, structureElement2, iterations=1) + pixelIndex = np.where(NewMask - OldMask) # [rows, cols] + imgValues = np.zeros((len(pixelIndex[0]), len(filterRows), 3)) + for k in range(len(filterRows)): + filterRowIndexes = pixelIndex[0] - filterRows[k] + filterColIndexes = pixelIndex[1] - filterCols[k] + + selectMask0 = np.bitwise_and(np.bitwise_and(filterRowIndexes < height, filterRowIndexes >= 0), + np.bitwise_and(filterColIndexes < width, filterColIndexes >= 0)) + selectMask1 = OldMask[filterRowIndexes[selectMask0], filterColIndexes[selectMask0]] > 0 + selectedPositions = [filterRowIndexes[selectMask0][selectMask1], + filterColIndexes[selectMask0][selectMask1]] + imgValues[np.arange(len(pixelIndex[0]))[selectMask0][selectMask1], k, :] = DilatedImg[ + selectedPositions[0], + selectedPositions[1], :] + + DilatedImg[pixelIndex[0], pixelIndex[1], :] = np.sum(imgValues, axis=1) // np.sum(imgValues > 0, axis=1) + + OldMask = NewMask + + return DilatedImg + + ######################################################################## + + else: #for green channel only + """for green channel only""" + + Mask0 = Mask.copy() + height, width = Mask0.shape + Mask0[0, :] = 0 # np.zeros(width) + Mask0[-1, :] = 0 # np.zeros(width) + Mask0[:, 0] = 0 # np.zeros(height) + Mask0[:, -1] = 0 # np.zeros(height) + + # Erodes the mask to avoid weird region near the border. 
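+        # Same ring-growing scheme as the RGB branch above, accumulating a single
+        # channel instead of three. A hypothetical usage sketch (file name assumed,
+        # not part of this module):
+        #   img = cv2.imread('fundus.png')            # BGR image from OpenCV
+        #   _, fov = creatMask(img, threshold=10)     # see AV/Tools/ImageResize.py
+        #   padded = fakePad(img, fov, iterations=50)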
+ structureElement1 = morphology.disk(5) + Mask0 = cv2.morphologyEx(Mask0, cv2.MORPH_ERODE, structureElement1, iterations=1) + + # DilatedImg = Img_green_reverse * Mask + DilatedImg = cv2.bitwise_and(Image, Image, mask=Mask0) + + OldMask = Mask0.copy() + + filter = np.ones((3, 3)) + filterRows, filterCols = np.where(filter > 0) + filterRows = filterRows - 1 + filterCols = filterCols - 1 + + structureElement2 = morphology.diamond(1) + for i in range(0, iterations): + NewMask = cv2.morphologyEx(OldMask, cv2.MORPH_DILATE, structureElement2, iterations=1) + pixelIndex = np.where(NewMask - OldMask) # [rows, cols] + + imgValues = np.zeros((len(pixelIndex[0]), len(filterRows))) + for k in range(len(filterRows)): + filterRowIndexes = pixelIndex[0] - filterRows[k] + filterColIndexes = pixelIndex[1] - filterCols[k] + + selectMask0 = np.bitwise_and(np.bitwise_and(filterRowIndexes < height, filterRowIndexes >= 0), + np.bitwise_and(filterColIndexes < width, filterColIndexes >= 0)) + selectMask1 = OldMask[filterRowIndexes[selectMask0], filterColIndexes[selectMask0]] > 0 + selectedPositions = [filterRowIndexes[selectMask0][selectMask1], filterColIndexes[selectMask0][selectMask1]] + imgValues[np.arange(len(pixelIndex[0]))[selectMask0][selectMask1], k] = DilatedImg[selectedPositions[0], selectedPositions[1]] + + DilatedImg[pixelIndex[0], pixelIndex[1]] = np.sum(imgValues, axis=1) / np.sum(imgValues > 0, axis=1) + + OldMask = NewMask + + return DilatedImg + + diff --git a/AV/Tools/Float2Uint.py b/AV/Tools/Float2Uint.py new file mode 100644 index 0000000000000000000000000000000000000000..2f977f64915c2a86065a13772dea796d0ece53ad --- /dev/null +++ b/AV/Tools/Float2Uint.py @@ -0,0 +1,18 @@ + + +import numpy as np + +def float2Uint(Image_float): + """ + Transfer float image to np.uint8 type + :param Image_float: + :return:LnGray + """ + + MaxLn = np.max(Image_float) + MinLn = np.min(Image_float) + # LnGray = 255*(Image_float - MinLn)//(MaxLn - MinLn + 1e-6) + LnGray = 255 * ((Image_float - MinLn) / float((MaxLn - MinLn + 1e-6))) + LnGray = np.array(LnGray, dtype = np.uint8) + + return LnGray diff --git a/AV/Tools/Hemelings_eval.py b/AV/Tools/Hemelings_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..cbfbef4a77dc86a2d80ce9504aa0784d6c61298d --- /dev/null +++ b/AV/Tools/Hemelings_eval.py @@ -0,0 +1,119 @@ +import numpy as np +from skimage.morphology import skeletonize, erosion +from sklearn.metrics import f1_score, accuracy_score,classification_report,confusion_matrix +import cv2 +from Tools.BinaryPostProcessing import binaryPostProcessing3 + +def evaluation_code(prediction, groundtruth, mask=None,use_mask=False): + ''' + Function to evaluate the performance of AV predictions with a given ground truth + - prediction: should be an image array of [dim1, dim2, img_channels = 3] with arteries in red and veins in blue + - groundtruth: same as above + ''' + encoded_pred = np.zeros(prediction.shape[:2], dtype=int) + encoded_gt = np.zeros(groundtruth.shape[:2], dtype=int) + + # convert white pixels to green pixels (which are ignored) + white_ind = np.where(np.logical_and(groundtruth[:,:,0] == 255, groundtruth[:,:,1] == 255, groundtruth[:,:,2] == 255)) + if white_ind[0].size != 0: + groundtruth[white_ind] = [0,255,0] + + # translate the images to arrays suited for sklearn metrics + # --- original ------- + arteriole = np.where(np.logical_and(groundtruth[:,:,0] == 255, groundtruth[:,:,1] == 0)); encoded_gt[arteriole] = 1 + venule = np.where(np.logical_and(groundtruth[:,:,2] == 255, 
groundtruth[:,:,1] == 0)); encoded_gt[venule] = 2 + arteriole = np.where(prediction[:,:,0] == 255); encoded_pred[arteriole] = 1 + venule = np.where(prediction[:,:,2] == 255); encoded_pred[venule] = 2 + # -------------------- + + # disk and cup + if use_mask: + groundtruth = cv2.bitwise_and(groundtruth, groundtruth, mask=mask) + prediction = cv2.bitwise_and(prediction, prediction, mask=mask) + + encoded_pred = cv2.bitwise_and(encoded_pred, encoded_pred, mask=mask) + encoded_gt = cv2.bitwise_and(encoded_gt, encoded_gt, mask=mask) + + # retrieve the indices for the centerline pixels present in the prediction + center = np.where(np.logical_and( + np.logical_or((skeletonize(groundtruth[:,:,0] > 0)),(skeletonize(groundtruth[:,:,2] > 0))), + encoded_pred[:,:] > 0)) + + encoded_pred_center = encoded_pred[center] + encoded_gt_center = encoded_gt[center] + + # retrieve the indices for the centerline pixels present in the groundtruth + center_comp = np.where( + np.logical_or(skeletonize(groundtruth[:,:,0] > 0),skeletonize(groundtruth[:,:,2] > 0))) + + encoded_pred_center_comp = encoded_pred[center_comp] + encoded_gt_center_comp = encoded_gt[center_comp] + + # retrieve the indices for discovered centerline pixels - limited to vessels wider than two pixels (for DRIVE) + center_eroded = np.where(np.logical_and( + np.logical_or(skeletonize(erosion(groundtruth[:,:,0] > 0)),skeletonize(erosion(groundtruth[:,:,2] > 0))), + encoded_pred[:,:] > 0)) + + encoded_pred_center_eroded = encoded_pred[center_eroded] + encoded_gt_center_eroded = encoded_gt[center_eroded] + + # metrics over full image + cur1_acc = accuracy_score(encoded_gt.flatten(),encoded_pred.flatten()) + cur1_F1 = f1_score(encoded_gt.flatten(),encoded_pred.flatten(),average='weighted') + # cls_report = classification_report(encoded_gt.flatten(), encoded_pred.flatten(), target_names=['class_1', 'class_2', 'class_3']) + # print('Full image') + # print('Accuracy: {}\nF1: {}\n'.format(cur1_acc, cur1_F1)) + # print('Class report:') + # print(cls_report) + metrics1 = [cur1_acc, cur1_F1] + + # metrics over discovered centerline pixels + cur2_acc = accuracy_score(encoded_gt_center.flatten(),encoded_pred_center.flatten()) + cur2_F1 = f1_score(encoded_gt_center.flatten(),encoded_pred_center.flatten(),average='weighted') + # print('Discovered centerline pixels') + # print('Accuracy: {}\nF1: {}\n'.format(cur2_acc, cur2_F1)) + metrics2 = [cur2_acc, cur2_F1] + + # metrics over discovered centerline pixels - limited to vessels wider than two pixels + cur3_acc = accuracy_score(encoded_gt_center_eroded.flatten(),encoded_pred_center_eroded.flatten()) + cur3_F1 = f1_score(encoded_gt_center_eroded.flatten(),encoded_pred_center_eroded.flatten(),average='weighted') + # print('Discovered centerline pixels of vessels wider than two pixels') + # print('Accuracy: {}\nF1: {}\n'.format(cur3_acc, cur3_F1)) + metrics3 = [cur3_acc, cur3_F1] + + # metrics over all centerline pixels in ground truth + cur4_acc = accuracy_score(encoded_gt_center_comp.flatten(),encoded_pred_center_comp.flatten()) + cur4_F1 = f1_score(encoded_gt_center_comp.flatten(),encoded_pred_center_comp.flatten(),average='weighted') + # print('Centerline pixels') + # print('Accuracy: {}\nF1: {}\n'.format(cur4_acc, cur4_F1)) + # confusion matrix + out = confusion_matrix(encoded_gt_center_comp,encoded_pred_center_comp)#.ravel() + sens = 0 + sepc = 0 + + if out.shape[0] == 2: + tn, fp,fn,tp = out.ravel() + # print(tn, fp,fn,tp) + else: + tn = out[1,1] + fp = out[1,2] + fn = out[2,1] + tp = out[2,2] + + # sens = tpr 
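+    # Clarifying note: the encoded classes are {1: artery, 2: vein}, so in the 3x3
+    # case tn = out[1,1] is the artery-artery cell and tp = out[2,2] the vein-vein
+    # cell. The two rates computed below are therefore per-class recalls: spec is
+    # vein recall, tp/(tp+fn), and sens is artery recall, tn/(fp+tn). sklearn lays
+    # out confusion_matrix as row = ground truth, column = prediction, e.g.
+    #   >>> from sklearn.metrics import confusion_matrix
+    #   >>> confusion_matrix([1, 1, 2, 2], [1, 2, 2, 2])
+    #   array([[1, 1],
+    #          [0, 2]])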
+ spec = tp/ (tp+fn) + # spec = TNR + sens = tn/(fp+tn) + + metrics4 = [cur4_acc, cur4_F1, sens, spec] + + + # finally, compute vessel detection rate + vessel_ind = np.where(encoded_gt>0) + vessel_gt = encoded_gt[vessel_ind] + vessel_pred = encoded_pred[vessel_ind] + + detection_rate = accuracy_score(vessel_gt.flatten(),vessel_pred.flatten()) + # print('Amount of vessels detected: ' + str(detection_rate)) + + return [metrics1,metrics2,metrics3,metrics4,detection_rate]#,encoded_pred,encoded_gt,center,center_comp,center_eroded#,encoded_pred_center_comp,encoded_gt_center_comp \ No newline at end of file diff --git a/AV/Tools/Im2Double.py b/AV/Tools/Im2Double.py new file mode 100644 index 0000000000000000000000000000000000000000..e4dad09384c7602fa66c59e975f9a578d4ad1d54 --- /dev/null +++ b/AV/Tools/Im2Double.py @@ -0,0 +1,15 @@ + + +import numpy as np + +def im2double(im): + """ + Transfer np.uint8 to float type + :param im: + :return: output image + """ + + min_val = np.min(im.ravel()) + max_val = np.max(im.ravel()) + out = (im.astype('float') - min_val) / (max_val - min_val) + return out \ No newline at end of file diff --git a/AV/Tools/ImageResize.py b/AV/Tools/ImageResize.py new file mode 100644 index 0000000000000000000000000000000000000000..11cd8ed5690e5ccb0eecd2f9ab312969728e5f66 --- /dev/null +++ b/AV/Tools/ImageResize.py @@ -0,0 +1,214 @@ +import os.path + +import cv2 +import numpy as np +from skimage import measure + +def imageResize(Image, downsizeRatio): + + ##This program resize the original image + ##Input: original image and downsizeRatio (user defined parameter: 0.75, 0.5 or 0.2) + ##Output: the resized image according to the given ratio + + if downsizeRatio < 1:#len(ImgFileList) + ImgResized = cv2.resize(Image, dsize=None, fx=downsizeRatio, fy=downsizeRatio) + else: + ImgResized = Image + + ImgResized = np.uint8(ImgResized) + return ImgResized + + +def creatMask(Image, threshold = 10): + ##This program try to creat the mask for the filed-of-view + ##Input original image (RGB or green channel), threshold (user set parameter, default 10) + ##Output: the filed-of-view mask + + if len(Image.shape) == 3: ##RGB image + gray = cv2.cvtColor(Image, cv2.COLOR_BGR2GRAY) + Mask0 = gray >= threshold + + else: #for green channel image + Mask0 = Image >= threshold + + + # ######get the largest blob, this takes 0.18s + cvVersion = int(cv2.__version__.split('.')[0]) + + Mask0 = np.uint8(Mask0) + + contours, hierarchy = cv2.findContours(Mask0, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) + + areas = [cv2.contourArea(c) for c in contours] + max_index = np.argmax(areas) + Mask = np.zeros(Image.shape[:2], dtype=np.uint8) + cv2.drawContours(Mask, contours, max_index, 1, -1) + + ResultImg = Image.copy() + if len(Image.shape) == 3: + ResultImg[Mask ==0] = (255,255,255) + else: + ResultImg[Mask==0] = 255 + + return ResultImg, Mask + +def shift_rgb(img, *args): + + + + + result_img = np.empty_like(img) + shifts = args + max_value = 255 + # print(shifts) + for i, shift in enumerate(shifts): + lut = np.arange(0, max_value + 1).astype("float32") + lut += shift + + lut = np.clip(lut, 0, max_value).astype(img.dtype) + if len(img.shape)==2: + print(f'=========grey image=======') + result_img = cv2.LUT(img,lut) + else: + result_img[..., i] = cv2.LUT(img[...,i],lut) + + return result_img +def cropImage_bak(Image, Mask): + Image = Image.copy() + Mask = Mask.copy() + + leftLimit, rightLimit, upperLimit, lowerLimit = getLimit(Mask) + + + + if len(Image.shape) == 3: + ImgCropped = Image[upperLimit:lowerLimit, 
leftLimit:rightLimit, :] + MaskCropped = Mask[upperLimit:lowerLimit, leftLimit:rightLimit] + + ImgCropped[:20, :, :] = 0 + ImgCropped[-20:, :, :] = 0 + ImgCropped[:, :20, :] = 0 + ImgCropped[:, -20:, :] = 0 + MaskCropped[:20, :] = 0 + MaskCropped[-20:, :] = 0 + MaskCropped[:, :20] = 0 + MaskCropped[:, -20:] = 0 + else: #len(Image.shape) == 2: + ImgCropped = Image[upperLimit:lowerLimit, leftLimit:rightLimit] + MaskCropped = Mask[upperLimit:lowerLimit, leftLimit:rightLimit] + ImgCropped[:20, :] = 0 + ImgCropped[-20:, :] = 0 + ImgCropped[:, :20] = 0 + ImgCropped[:, -20:] = 0 + MaskCropped[:20, :] = 0 + MaskCropped[-20:, :] = 0 + MaskCropped[:, :20] = 0 + MaskCropped[:, -20:] = 0 + + + cropLimit = [upperLimit, lowerLimit, leftLimit, rightLimit] + + return ImgCropped, MaskCropped, cropLimit + + + +######################################################## +###new function to get the limit for cropping. +###try to get higher speed than np.where, but not working. + +def getLimit(Mask): + + Mask1 = Mask > 0 + colSums = np.sum(Mask1, axis=1) + rowSums = np.sum(Mask1, axis=0) + maxColSum = np.max(colSums) + maxRowSum = np.max(rowSums) + + colList = np.where(colSums >= 0.01*maxColSum)[0] + rowList = np.where(rowSums >= 0.01*maxRowSum)[0] + + leftLimit0 = np.min(rowList) + rightLimit0 = np.max(rowList) + upperLimit0 = np.min(colList) + lowerLimit0 = np.max(colList) + + margin = 50 + leftLimit = np.clip(leftLimit0-margin, 0, Mask.shape[1]) + rightLimit = np.clip(rightLimit0+margin, 0, Mask.shape[1]) + upperLimit = np.clip(upperLimit0 - margin, 0, Mask.shape[0]) + lowerLimit = np.clip(lowerLimit0 + margin, 0, Mask.shape[0]) + + + return leftLimit, rightLimit, upperLimit, lowerLimit + + + + + +def cropImage(Image, Mask): + ##This program will crop the filed of view based on the mask + ##Input: orginal image, origimal Mask (the image needs to be RGB resized image) + ##Output: Cropped image, Cropped Mask, the cropping limit + + height, width = Image.shape[:2] + + rowsMask0, colsMask0 = np.where(Mask > 0) + minColIndex0, maxColIndex0 = np.argmin(colsMask0), np.argmax(colsMask0) + minCol, maxCol = colsMask0[minColIndex0], colsMask0[maxColIndex0] + + minRowIndex0, maxRowIndex0 = np.argmin(rowsMask0), np.argmax(rowsMask0) + minRow, maxRow = rowsMask0[minRowIndex0], rowsMask0[maxRowIndex0] + + upperLimit = np.maximum(0, minRow - 50) #20 + lowerLimit = np.minimum(maxRow + 50, height) #20 + leftLimit = np.maximum(0, minCol - 50) #lowerLimit = np.minimum(maxCol + 50, width) #20 + rightLimit = np.minimum(maxCol + 50, width) + + if len(Image.shape) == 3: + ImgCropped = Image[upperLimit:lowerLimit, leftLimit:rightLimit, :] + MaskCropped = Mask[upperLimit:lowerLimit, leftLimit:rightLimit] + + ImgCropped[:20, :, :] = 0 + ImgCropped[-20:, :, :] = 0 + ImgCropped[:, :20, :] = 0 + ImgCropped[:, -20:, :] = 0 + MaskCropped[:20, :] = 0 + MaskCropped[-20:, :] = 0 + MaskCropped[:, :20] = 0 + MaskCropped[:, -20:] = 0 + elif len(Image.shape) == 2: + ImgCropped = Image[upperLimit:lowerLimit, leftLimit:rightLimit] + MaskCropped = Mask[upperLimit:lowerLimit, leftLimit:rightLimit] + ImgCropped[:20, :] = 0 + ImgCropped[-20:, :] = 0 + ImgCropped[:, :20] = 0 + ImgCropped[:, -20:] = 0 + MaskCropped[:20, :] = 0 + MaskCropped[-20:, :] = 0 + MaskCropped[:, :20] = 0 + MaskCropped[:, -20:] = 0 + else: + pass + + + cropLimit = [upperLimit, lowerLimit, leftLimit, rightLimit] + + return ImgCropped, MaskCropped, cropLimit + + +if __name__ == '__main__': + if not os.path.exists(os.path.join('../data','AV_DRIVE','test','mask')): + 
os.makedirs(os.path.join('../data','AV_DRIVE','test','mask')) + + for file in os.listdir(os.path.join('../data','AV_DRIVE','test','images')): + # suffix file name + if file.endswith('.jpg') or file.endswith('.png'): + # read image + img = cv2.imread(os.path.join('../data','AV_DRIVE','test','images',file)) + + + _,mask = creatMask(img) + + + # save mask + cv2.imwrite(os.path.join('../data','AV_DRIVE','test','mask',file),mask) \ No newline at end of file diff --git a/AV/Tools/Remove_small_holes.py b/AV/Tools/Remove_small_holes.py new file mode 100644 index 0000000000000000000000000000000000000000..204a2120cd6e6d81536e72c51bf64b98c05edd21 --- /dev/null +++ b/AV/Tools/Remove_small_holes.py @@ -0,0 +1,88 @@ + + +import numpy as np +import functools +import warnings +from scipy import ndimage as ndi +from skimage import morphology + + +def remove_small_holes(ar, min_size=64, connectivity=1, in_place=False): + """Remove continguous holes smaller than the specified size. + Parameters + ---------- + ar : ndarray (arbitrary shape, int or bool type) + The array containing the connected components of interest. + min_size : int, optional (default: 64) + The hole component size. + connectivity : int, {1, 2, ..., ar.ndim}, optional (default: 1) + The connectivity defining the neighborhood of a pixel. + in_place : bool, optional (default: False) + If `True`, remove the connected components in the input array itself. + Otherwise, make a copy. + Raises + ------ + TypeError + If the input array is of an invalid type, such as float or string. + ValueError + If the input array contains negative values. + Returns + ------- + out : ndarray, same shape and type as input `ar` + The input array with small holes within connected components removed. + Examples + -------- + # >>> from skimage import morphology + # >>> a = np.array([[1, 1, 1, 1, 1, 0], + # ... [1, 1, 1, 0, 1, 0], + # ... [1, 0, 0, 1, 1, 0], + # ... [1, 1, 1, 1, 1, 0]], bool) + # >>> b = morphology.remove_small_holes(a, 2) + # >>> b + # array([[ True, True, True, True, True, False], + # [ True, True, True, True, True, False], + # [ True, False, False, True, True, False], + # [ True, True, True, True, True, False]], dtype=bool) + # >>> c = morphology.remove_small_holes(a, 2, connectivity=2) + # >>> c + # array([[ True, True, True, True, True, False], + # [ True, True, True, False, True, False], + # [ True, False, False, True, True, False], + # [ True, True, True, True, True, False]], dtype=bool) + # >>> d = morphology.remove_small_holes(a, 2, in_place=True) + # >>> d is a + # True + # Notes + # ----- + # If the array type is int, it is assumed that it contains already-labeled + # objects. The labels are not kept in the output image (this function always + # outputs a bool image). It is suggested that labeling is completed after + # using this function. + # """ + # _check_dtype_supported(ar) + + #Creates warning if image is an integer image + # if ar.dtype != bool: + # warnings.warn("Any labeled images will be returned as a boolean array. 
" + # "Did you mean to use a boolean array?", UserWarning) + + if in_place: + out = ar + else: + out = ar.copy() + + #Creating the inverse of ar + if in_place: + out = np.logical_not(out,out) + else: + out = np.logical_not(out) + + #removing small objects from the inverse of ar + out = morphology.remove_small_objects(out, min_size, connectivity, in_place) + + if in_place: + out = np.logical_not(out,out) + else: + out = np.logical_not(out) + + return out \ No newline at end of file diff --git a/AV/Tools/Standardize.py b/AV/Tools/Standardize.py new file mode 100644 index 0000000000000000000000000000000000000000..92bdac60bb77b55b4f3dff419c3ad955a34345c0 --- /dev/null +++ b/AV/Tools/Standardize.py @@ -0,0 +1,45 @@ + + +from __future__ import division +import numpy as np +import cv2 + +def standardize(img,mask,wsize): + """ + Convert the image values to standard images. + :param img: + :param mask: + :param wsize: + :return: + """ + + if wsize == 0: + simg=globalstandardize(img,mask) + else: + img[mask == 0]=0 + img_mean=cv2.blur(img, ksize=wsize) + img_squared_mean = cv2.blur(img*img, ksize=wsize) + img_std = np.sqrt(img_squared_mean - img_mean*img_mean) + simg=(img - img_mean) / img_std + simg[img_std == 0]=0 + simg[mask == 0]=0 + return simg + +def globalstandardize(img,mask): + + usedpixels = np.double(img[mask == 1]) + m=np.mean(usedpixels) + s=np.std(usedpixels) + simg=np.zeros(img.shape) + simg[mask == 1]=(usedpixels - m) / s + return simg + +def getmean(x): + usedx=x[x != 0] + m=np.mean(usedx) + return m + +def getstd(x): + usedx=x[x != 0] + s=np.std(usedx) + return s \ No newline at end of file diff --git a/AV/Tools/__init__.py b/AV/Tools/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/AV/Tools/__pycache__/AVclassifiation.cpython-39.pyc b/AV/Tools/__pycache__/AVclassifiation.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..66837d8d71cbb459df23f117273b704e968c4d81 Binary files /dev/null and b/AV/Tools/__pycache__/AVclassifiation.cpython-39.pyc differ diff --git a/AV/Tools/__pycache__/AVclassifiationMetrics.cpython-310.pyc b/AV/Tools/__pycache__/AVclassifiationMetrics.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c8bd295fc6e44c7c6f0e60ccc29ed46961bf407c Binary files /dev/null and b/AV/Tools/__pycache__/AVclassifiationMetrics.cpython-310.pyc differ diff --git a/AV/Tools/__pycache__/AVclassifiationMetrics.cpython-39.pyc b/AV/Tools/__pycache__/AVclassifiationMetrics.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b6bf05d3b5412115698d86fb672565017e877362 Binary files /dev/null and b/AV/Tools/__pycache__/AVclassifiationMetrics.cpython-39.pyc differ diff --git a/AV/Tools/__pycache__/AVclassifiationMetrics_v1.cpython-39.pyc b/AV/Tools/__pycache__/AVclassifiationMetrics_v1.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f66b828c6877df0ba5fcb370533d15383cf4f66c Binary files /dev/null and b/AV/Tools/__pycache__/AVclassifiationMetrics_v1.cpython-39.pyc differ diff --git a/AV/Tools/__pycache__/BGR2RGB.cpython-39.pyc b/AV/Tools/__pycache__/BGR2RGB.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49650d1ce3064fe473aa3ca0062981303c5870d7 Binary files /dev/null and b/AV/Tools/__pycache__/BGR2RGB.cpython-39.pyc differ diff --git a/AV/Tools/__pycache__/BinaryPostProcessing.cpython-310.pyc 
b/AV/Tools/__pycache__/BinaryPostProcessing.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..275515e52a1154df6b51dc72a4c50c17089efb11 Binary files /dev/null and b/AV/Tools/__pycache__/BinaryPostProcessing.cpython-310.pyc differ diff --git a/AV/Tools/__pycache__/BinaryPostProcessing.cpython-39.pyc b/AV/Tools/__pycache__/BinaryPostProcessing.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..71742981670a468af20ddedc986f81c548bf6c2e Binary files /dev/null and b/AV/Tools/__pycache__/BinaryPostProcessing.cpython-39.pyc differ diff --git a/AV/Tools/__pycache__/ImageResize.cpython-310.pyc b/AV/Tools/__pycache__/ImageResize.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a5ecebdf5fd76bef4dc5230de412ae07ed98e77 Binary files /dev/null and b/AV/Tools/__pycache__/ImageResize.cpython-310.pyc differ diff --git a/AV/Tools/__pycache__/ImageResize.cpython-39.pyc b/AV/Tools/__pycache__/ImageResize.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0f23d2e996d0f414516ad28dbd3f457a5696b504 Binary files /dev/null and b/AV/Tools/__pycache__/ImageResize.cpython-39.pyc differ diff --git a/AV/Tools/__pycache__/Remove_small_holes.cpython-310.pyc b/AV/Tools/__pycache__/Remove_small_holes.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8caadc03e98fd42208ba4feea8cc2ce7af43c116 Binary files /dev/null and b/AV/Tools/__pycache__/Remove_small_holes.cpython-310.pyc differ diff --git a/AV/Tools/__pycache__/Remove_small_holes.cpython-39.pyc b/AV/Tools/__pycache__/Remove_small_holes.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..781b89b1730225535ce9f74f63fc0b23d0756aa4 Binary files /dev/null and b/AV/Tools/__pycache__/Remove_small_holes.cpython-39.pyc differ diff --git a/AV/Tools/__pycache__/__init__.cpython-310.pyc b/AV/Tools/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d29148b44a6c051473c4a207a9fdf6bec1f52b7 Binary files /dev/null and b/AV/Tools/__pycache__/__init__.cpython-310.pyc differ diff --git a/AV/Tools/__pycache__/__init__.cpython-39.pyc b/AV/Tools/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d12fba52c3d5b28acd46ecd5f124dea04c7b64d7 Binary files /dev/null and b/AV/Tools/__pycache__/__init__.cpython-39.pyc differ diff --git a/AV/Tools/__pycache__/data_augmentation.cpython-310.pyc b/AV/Tools/__pycache__/data_augmentation.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..54da7f23dc75ca8adf5b6156cee5162b401249dd Binary files /dev/null and b/AV/Tools/__pycache__/data_augmentation.cpython-310.pyc differ diff --git a/AV/Tools/__pycache__/data_augmentation.cpython-39.pyc b/AV/Tools/__pycache__/data_augmentation.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d7a5819aec24717e78e0df6f3e31cc4b49b67bd Binary files /dev/null and b/AV/Tools/__pycache__/data_augmentation.cpython-39.pyc differ diff --git a/AV/Tools/__pycache__/evalution_vessel.cpython-310.pyc b/AV/Tools/__pycache__/evalution_vessel.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1349de755660f00d2fa54be02cce372da16bb88a Binary files /dev/null and b/AV/Tools/__pycache__/evalution_vessel.cpython-310.pyc differ diff --git a/AV/Tools/__pycache__/global2patch_AND_patch2global.cpython-39.pyc 
b/AV/Tools/__pycache__/global2patch_AND_patch2global.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..38b3bf7084c924336f1cf3fe5e744cfb5bff4234 Binary files /dev/null and b/AV/Tools/__pycache__/global2patch_AND_patch2global.cpython-39.pyc differ diff --git a/AV/Tools/__pycache__/utils.cpython-310.pyc b/AV/Tools/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1b55765dce95166e4d1e62445738e0a412690517 Binary files /dev/null and b/AV/Tools/__pycache__/utils.cpython-310.pyc differ diff --git a/AV/Tools/__pycache__/utils.cpython-39.pyc b/AV/Tools/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2803b205ca7dcf0eb4b76cabe72212b112aa183e Binary files /dev/null and b/AV/Tools/__pycache__/utils.cpython-39.pyc differ diff --git a/AV/Tools/__pycache__/utils_test.cpython-39.pyc b/AV/Tools/__pycache__/utils_test.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c20daac8ff3a46de6fbbe7691f26e183f7aea672 Binary files /dev/null and b/AV/Tools/__pycache__/utils_test.cpython-39.pyc differ diff --git a/AV/Tools/__pycache__/warmup.cpython-39.pyc b/AV/Tools/__pycache__/warmup.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4f19ff2efb8f885f9e3efcaaeb207f89a745d843 Binary files /dev/null and b/AV/Tools/__pycache__/warmup.cpython-39.pyc differ diff --git a/AV/Tools/centerline_evaluation.py b/AV/Tools/centerline_evaluation.py new file mode 100644 index 0000000000000000000000000000000000000000..c66f2aa1aab3176075544c6c640bee36e44225cc --- /dev/null +++ b/AV/Tools/centerline_evaluation.py @@ -0,0 +1,160 @@ +import numpy as np +import cv2 +import os +import natsort +from Tools.Hemelings_eval import evaluation_code +import pandas as pd + +def getFolds(ImgPath, LabelPath, k_fold_idx,k_fold, trainset=True): + print(f'ImgPath: {ImgPath}') + print(f'LabelPath: {LabelPath}') + print(f'k_fold_idx: {k_fold_idx}') + print(f'k_fold: {k_fold}') + print(f'trainset: {trainset}') + for dirpath,dirnames,filenames in os.walk(ImgPath): + ImgDirAll = filenames + break + + for dirpath,dirnames,filenames in os.walk(LabelPath): + LabelDirAll = filenames + break + ImgDir = [] + LabelDir = [] + + ImgDir_testset = [] + LabelDir_testset = [] + + if k_fold >0: + ImgDirAll = natsort.natsorted(ImgDirAll) + LabelDirAll = natsort.natsorted(LabelDirAll) + num_fold = len(ImgDirAll) // k_fold + for i in range(k_fold): + start_idx = i * num_fold + end_idx = (i+1) * num_fold + # if i == k_fold_idx: + ImgDir_testset.extend(ImgDirAll[start_idx:end_idx]) + LabelDir_testset.extend(LabelDirAll[start_idx:end_idx]) + #continue + ImgDir.extend(ImgDirAll[start_idx:end_idx]) + LabelDir.extend(LabelDirAll[start_idx:end_idx]) + if not trainset: + return ImgDir_testset, LabelDir_testset + return ImgDir, ImgDir + +def centerline_eval(ProMap, config): + if config.dataset_name == 'hrf': + ImgPath = os.path.join(config.trainset_path,'test', 'images') + LabelPath = os.path.join(config.trainset_path,'test', 'ArteryVein_0410_final') + elif config.dataset_name == 'DRIVE': + dataroot = './data/AV_DRIVE/test' + LabelPath = os.path.join(dataroot, 'av') + else: #if config.dataset_name == 'INSPIRE': + dataroot = './data/INSPIRE_AV' + LabelPath = os.path.join(dataroot, 'label') + DF_disc = pd.read_excel('./Tools/DiskParameters_INSPIRE_resize.xls', sheet_name=0) + if config.dataset_name == 'hrf': + k_fold_idx = config.k_fold_idx + k_fold = config.k_fold + + ImgList0 , 
+def centerline_eval(ProMap, config):
+    if config.dataset_name == 'hrf':
+        ImgPath = os.path.join(config.trainset_path, 'test', 'images')
+        LabelPath = os.path.join(config.trainset_path, 'test', 'ArteryVein_0410_final')
+    elif config.dataset_name == 'DRIVE':
+        dataroot = './data/AV_DRIVE/test'
+        LabelPath = os.path.join(dataroot, 'av')
+    else:  # config.dataset_name == 'INSPIRE'
+        dataroot = './data/INSPIRE_AV'
+        LabelPath = os.path.join(dataroot, 'label')
+        DF_disc = pd.read_excel('./Tools/DiskParameters_INSPIRE_resize.xls', sheet_name=0)
+    if config.dataset_name == 'hrf':
+        k_fold_idx = config.k_fold_idx
+        k_fold = config.k_fold
+
+        ImgList0, LabelList0 = getFolds(ImgPath, LabelPath, k_fold_idx, k_fold, trainset=False)
+    overall_value = [[0, 0] for i in range(3)]
+    overall_value.append([0, 0, 0, 0])
+    overall_value.append(0)
+    img_num = ProMap.shape[0]
+    for i in range(img_num):
+        arteryImg = ProMap[i, 0, :, :]
+        veinImg = ProMap[i, 1, :, :]
+        vesselImg = ProMap[i, 2, :, :]
+
+        idx = str(i+1)
+        idx = idx.zfill(2)
+        if config.dataset_name == 'hrf':
+            imgName = ImgList0[i]
+        elif config.dataset_name == 'DRIVE':
+            imgName = idx + '_test.png'
+        else:
+            imgName = 'image' + str(i+1) + '_ManualAV.png'
+        gt_path = os.path.join(LabelPath, imgName)
+        print(gt_path)
+        gt = cv2.imread(gt_path)
+        gt = cv2.cvtColor(gt, cv2.COLOR_BGR2RGB)
+        if config.dataset_name == 'hrf':
+            gt = cv2.resize(gt, (1200, 800))
+        gt_vessel = gt[:,:,0] + gt[:,:,2]
+
+        h, w = arteryImg.shape
+
+        ArteryPred = np.float32(arteryImg)
+        VeinPred = np.float32(veinImg)
+        VesselPred = np.float32(vesselImg)
+        AVSeg1 = np.zeros((h, w, 3))
+        vesselSeg = np.zeros((h, w))
+        th = 0
+        vesselPixels = np.where(gt_vessel)  # alternative: np.where(VesselPred > th)
+
+        for k in np.arange(len(vesselPixels[0])):
+            row = vesselPixels[0][k]
+            col = vesselPixels[1][k]
+            if ArteryPred[row, col] >= VeinPred[row, col]:
+                AVSeg1[row, col] = (255, 0, 0)
+            else:
+                AVSeg1[row, col] = (0, 0, 255)
+
+        AVSeg1 = np.uint8(AVSeg1)
+
+        if config.dataset_name == 'INSPIRE':
+            discCenter = (DF_disc.loc[i, 'DiskCenterRow'], DF_disc.loc[i, 'DiskCenterCol'])
+            discRadius = DF_disc.loc[i, 'DiskRadius']
+            MaskDisc = np.ones((h, w), np.uint8)
+            cv2.circle(MaskDisc, center=(discCenter[1], discCenter[0]), radius=discRadius, color=0, thickness=-1)
+            out = evaluation_code(AVSeg1, gt, mask=MaskDisc, use_mask=True)
+        else:
+            out = evaluation_code(AVSeg1, gt)
+
+        for j in range(len(out)):
+            if j == 4:
+                overall_value[j] += out[j]
+                continue
+            if j == 3:
+                overall_value[j][0] += out[j][0]
+                overall_value[j][1] += out[j][1]
+                overall_value[j][2] += out[j][2]
+                overall_value[j][3] += out[j][3]
+                continue
+
+            overall_value[j][0] += out[j][0]
+            overall_value[j][1] += out[j][1]
+
+    # print("overall_value:", overall_value)
+    for j in range(len(overall_value)):
+        if j == 4:
+            overall_value[j] /= img_num
+            continue
+        if j == 3:
+            overall_value[j][0] /= img_num
+            overall_value[j][1] /= img_num
+            overall_value[j][2] /= img_num
+            overall_value[j][3] /= img_num
+            continue
+        overall_value[j][0] /= img_num
+        overall_value[j][1] /= img_num
+
+    metrics_names = ['full image', 'discovered centerline pixels', 'vessels wider than two pixels', 'all centerline', 'vessel detection rate']
+    filewriter = ""
+    print("--------------------------Centerline---------------------------------")
+    filewriter += "--------------------------Centerline---------------------------------\n"
+    for j in range(len(overall_value)):
+        if j == 4:
+            print("{} - Ratio:{}".format(metrics_names[j], overall_value[j]))
+            filewriter += "{} - Ratio:{}\n".format(metrics_names[j], overall_value[j])
+            continue
+        if j == 3:
+            print("{} - Acc: {} , F1:{}, Sens:{}, Spec:{}".format(metrics_names[j], overall_value[j][0], overall_value[j][1], overall_value[j][2], overall_value[j][3]))
+            filewriter += "{} - Acc: {} , F1:{}, Sens:{}, Spec:{}\n".format(metrics_names[j], overall_value[j][0], overall_value[j][1], overall_value[j][2], overall_value[j][3])
+            continue
+
+        print("{} - Acc: {} , F1:{}".format(metrics_names[j], overall_value[j][0], overall_value[j][1]))
+        filewriter += "{} - Acc: {} , F1:{}\n".format(metrics_names[j], overall_value[j][0], overall_value[j][1])
+
+    return
filewriter diff --git a/AV/Tools/data_augmentation.py b/AV/Tools/data_augmentation.py new file mode 100644 index 0000000000000000000000000000000000000000..afc3529472a4d594d368775851aef5cae8e954c8 --- /dev/null +++ b/AV/Tools/data_augmentation.py @@ -0,0 +1,94 @@ +import numpy as np +import tensorlayer as tl + +def data_augmentation1_5(*args): + # image3 = np.expand_dims(image3,-1) + args = tl.prepro.rotation_multi(args, rg=180, is_random=True, + fill_mode='reflect') + args = np.squeeze(args).astype(np.float32) + + return args + +def data_augmentation3_5(*args): + # image3 = np.expand_dims(image3,-1) + args = tl.prepro.shift_multi(args, wrg=0.10, hrg=0.10, is_random=True, + fill_mode='reflect') + args = np.squeeze(args).astype(np.float32) + + return args + +def data_augmentation4_5(*args): + + args = tl.prepro.swirl_multi(args,is_random=True) + args = np.squeeze(args).astype(np.float32) + + return args + +def data_augmentation2_5(*args): + # image3 = np.expand_dims(image3,-1) + args = tl.prepro.zoom_multi(args, zoom_range=[0.5, 2.5], is_random=True, + fill_mode='reflect') + args = np.squeeze(args).astype(np.float32) + + return args + +def data_aug5_old(data_mat, label_mat, label_data_centerness, choice): + data_mat = np.transpose(data_mat, (1, 2, 0)) + label_mat = np.transpose(label_mat, (1, 2, 0)) + label_data_centerness = np.transpose(label_data_centerness, (1, 2, 0)) + + if choice == 0: + data_mat = data_mat + label_mat = label_mat + label_data_centerness = label_data_centerness + + elif choice == 1: + data_mat = np.fliplr(data_mat) + label_mat = np.fliplr(label_mat) + label_data_centerness = np.fliplr(label_data_centerness) + + elif choice == 2: + data_mat = np.flipud(data_mat) + label_mat = np.flipud(label_mat) + label_data_centerness = np.flipud(label_data_centerness) + + elif choice == 3: + data_mat, label_mat, label_data_centerness= data_augmentation1_5(data_mat, label_mat, label_data_centerness) + elif choice == 4: + data_mat, label_mat, label_data_centerness= data_augmentation2_5(data_mat, label_mat, label_data_centerness) + elif choice == 5: + data_mat, label_mat, label_data_centerness= data_augmentation3_5(data_mat, label_mat, label_data_centerness) + elif choice == 6: + data_mat, label_mat, label_data_centerness= data_augmentation4_5(data_mat, label_mat, label_data_centerness) + + data_mat = np.transpose(data_mat, (2, 0, 1)) + label_mat = np.transpose(label_mat, (2, 0, 1)) + label_data_centerness = np.transpose(label_data_centerness, (2, 0, 1)) + + + return data_mat, label_mat, label_data_centerness + +# data augmentation for variable number of input +def data_aug5(*args,choice): + datas=[np.transpose(item, (1, 2, 0)) for item in args] + + if choice==1: + datas=[np.fliplr(item) for item in datas] + elif choice==2: + datas = [np.flipud(item) for item in datas] + elif choice==3: + datas = data_augmentation1_5(*datas) + elif choice==4: + datas = data_augmentation2_5(*datas) + elif choice==5: + datas = data_augmentation3_5(*datas) + elif choice==6: + datas = data_augmentation4_5(*datas) + + datas = [np.transpose(item, (2, 0, 1)) for item in datas] + + return tuple(datas) + + + + diff --git a/AV/Tools/evalution_vessel.py b/AV/Tools/evalution_vessel.py new file mode 100644 index 0000000000000000000000000000000000000000..b7701da64677320d3ad048c585b325a0f29a4e72 --- /dev/null +++ b/AV/Tools/evalution_vessel.py @@ -0,0 +1,75 @@ +# -*- coding: utf-8 -*- +################################################### +# +# Script to +# - Calculate prediction of the test dataset +# - Calculate the 
parameters to evaluate the prediction +# +################################################## + +#Python +import numpy as np +from sklearn.metrics import roc_curve +from sklearn.metrics import roc_auc_score,f1_score,jaccard_score +from sklearn.metrics import confusion_matrix +from sklearn.metrics import precision_recall_curve + +from lib.extract_patches2 import pred_only_FOV + + +def evalue(preImg, gtruth_masks, test_border_masks): + + #predictions only inside the FOV + y_scores, y_true = pred_only_FOV(preImg,gtruth_masks, test_border_masks) #returns data only inside the FOV + + #Area under the ROC curve + fpr, tpr, thresholds = roc_curve((y_true), y_scores) + AUC_ROC = roc_auc_score(y_true, y_scores) + # test_integral = np.trapz(tpr,fpr) #trapz is numpy integration + + #Precision-recall curve + precision, recall, thresholds = precision_recall_curve(y_true, y_scores) + precision = np.fliplr([precision])[0] #so the array is increasing (you won't get negative AUC) + recall = np.fliplr([recall])[0] #so the array is increasing (you won't get negative AUC) + + + #Confusion matrix + threshold_confusion = 0.5 + + y_pred = np.empty((y_scores.shape[0])) + for i in range(y_scores.shape[0]): + if y_scores[i]>=threshold_confusion: + y_pred[i]=1 + else: + y_pred[i]=0 + confusion = confusion_matrix(y_true, y_pred) + + accuracy = 0 + if float(np.sum(confusion))!=0: + accuracy = float(confusion[0,0]+confusion[1,1])/float(np.sum(confusion)) + + specificity = 0 + if float(confusion[0,0]+confusion[0,1])!=0: + specificity = float(confusion[0,0])/float(confusion[0,0]+confusion[0,1]) + + sensitivity = 0 + if float(confusion[1,1]+confusion[1,0])!=0: + sensitivity = float(confusion[1,1])/float(confusion[1,1]+confusion[1,0]) + + precision = 0 + if float(confusion[1,1]+confusion[0,1])!=0: + precision = float(confusion[1,1])/float(confusion[1,1]+confusion[0,1]) + + #Jaccard similarity index + #jaccard_index = jaccard_similarity_score(y_true, y_pred, normalize=True) + + + #F1 score + F1_score = f1_score(y_true, y_pred, labels=None, average='binary', sample_weight=None) + iou_score = jaccard_score(y_true, y_pred) + dice_score = 2*iou_score/(1+iou_score) + + + return AUC_ROC,accuracy,specificity,sensitivity,F1_score,dice_score,iou_score + + diff --git a/AV/Tools/global2patch_AND_patch2global.py b/AV/Tools/global2patch_AND_patch2global.py new file mode 100644 index 0000000000000000000000000000000000000000..7918f28d2e405ae713ad2b70bc9d89bc7de019cd --- /dev/null +++ b/AV/Tools/global2patch_AND_patch2global.py @@ -0,0 +1,106 @@ +# -*- encoding: utf-8 -*- +# author Victorshengw + + +import os +import numpy as np +from torchvision import transforms +from torch.autograd import Variable +import torch +from PIL import Image + + +def get_patch_info(shape, p_size): + ''' + shape: origin image size, (x, y) + p_size: patch size (square) + return: n_x, n_y, step_x, step_y + ''' + x = shape[0] + y = shape[1] + if x==p_size and y==p_size: + return 1, 1, 0, 0 + + n = m = 1 + while x > n * p_size: + n += 1 + while p_size - 1.0 * (x - p_size) / (n - 1) < p_size/4: + n += 1 + while y > m * p_size: + m += 1 + while p_size - 1.0 * (y - p_size) / (m - 1) < p_size/4: + m += 1 + return n, m, (x - p_size) * 1.0 / (n - 1), (y - p_size) * 1.0 / (m - 1) + + + +def global2patch(images, p_size): + ''' + image/label => patches + p_size: patch size + return: list of PIL patch images; coordinates: images->patches; ratios: (h, w) + ''' + patches = []; coordinates = []; templates = []; sizes = []; ratios = [(0, 0)] * len(images); patch_ones = 
np.ones(p_size)
+    for i in range(len(images)):
+        w, h = images[i].size
+        size = (h, w)
+        sizes.append(size)
+        ratios[i] = (float(p_size[0]) / size[0], float(p_size[1]) / size[1])
+        template = np.zeros(size)
+        n_x, n_y, step_x, step_y = get_patch_info(size, p_size[0])
+        patches.append([images[i]] * (n_x * n_y))
+        coordinates.append([(0, 0)] * (n_x * n_y))
+        for x in range(n_x):
+            if x < n_x - 1: top = int(np.round(x * step_x))
+            else: top = size[0] - p_size[0]
+            for y in range(n_y):
+                if y < n_y - 1: left = int(np.round(y * step_y))
+                else: left = size[1] - p_size[1]
+                template[top:top+p_size[0], left:left+p_size[1]] += patch_ones
+                coordinates[i][x * n_y + y] = (1.0 * top / size[0], 1.0 * left / size[1])
+                patches[i][x * n_y + y] = transforms.functional.crop(images[i], top, left, p_size[0], p_size[1])
+
+                # patches[i][x * n_y + y].show()
+        templates.append(Variable(torch.Tensor(template).expand(1, 1, -1, -1)))
+    return patches, coordinates, templates, sizes, ratios
+
+def patch2global(patches, n_class, sizes, coordinates, p_size, flag=0):
+    '''
+    predicted patches (after classify layer) => predictions
+    return: list of np.array
+    '''
+    patches = np.array(torch.detach(patches).cpu().numpy())
+    predictions = [np.zeros((n_class, size[0], size[1])) for size in sizes]
+
+    for i in range(len(sizes)):
+        for j in range(len(coordinates[i])):
+            top, left = coordinates[i][j]
+            top = int(np.round(top * sizes[i][0])); left = int(np.round(left * sizes[i][1]))
+
+            patches_tmp = np.zeros(patches[j][:,:,:].shape)
+            whole_img_tmp = predictions[i][:, top: top + p_size[0], left: left + p_size[1]]
+            # The larger of the two overlapping values (elementwise max) becomes the final prediction:
+            # patches[j] is the small patch being pasted into the full-size map, whole_img_tmp is the
+            # region of the full-size map it lands on; compare them elementwise and keep the larger.
+            if flag == 0:
+                patches_tmp[patches[j][:, :, :] > whole_img_tmp] = patches[j][:,:,:][patches[j][:, :, :] > whole_img_tmp]  # patch value larger than the current full-image value
+                patches_tmp[patches[j][:, :, :] < whole_img_tmp] = whole_img_tmp[patches[j][:, :, :] < whole_img_tmp]  # patch value smaller than the current full-image value
+                predictions[i][:, top: top + p_size[0], left: left + p_size[1]] += patches_tmp
+            else:
+                predictions[i][:, top: top + p_size[0], left: left + p_size[1]] += patches[j][:, :, :]
+
+    return predictions
+
+
+if __name__ == '__main__':
+    images = []
+
+    img = Image.open(os.path.join(r"../train_valid/003DRIVE/image", "01.png"))
+    images.append(img)
+    p_size = (224, 224)
+    patches, coordinates, templates, sizes, ratios = global2patch(images, p_size)
+    # predictions = patch2global(patches, 3, sizes, coordinates, p_size)
+    # print(type(predictions))
\ No newline at end of file
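
To make the tiling concrete, here is a small sanity-check sketch of get_patch_info(); the import path is an assumption based on the file layout introduced above. For a DRIVE-sized fundus image (584 x 565) and 256 x 256 patches, the while-loops keep adding rows/columns of patches until neighbouring patches overlap by at least a quarter patch:

from AV.Tools.global2patch_AND_patch2global import get_patch_info

n_x, n_y, step_x, step_y = get_patch_info((584, 565), 256)
print(n_x, n_y, step_x, step_y)   # 3 3 164.0 154.5
# Overlap check: adjacent patches in a column share 256 - 164 = 92 rows
# (>= 256/4 = 64); patch2global() later resolves these overlapping
# predictions with the elementwise max shown above.
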
diff --git a/AV/Tools/utils_test.py b/AV/Tools/utils_test.py new file mode 100644 index 0000000000000000000000000000000000000000..08d89c326d96993bed3f5ef4e6fe5778acb8b5a2 --- /dev/null +++ b/AV/Tools/utils_test.py @@ -0,0 +1,353 @@
+import cv2
+import numpy as np
+
+def paint_border_overlap(img, patch_h, patch_w, stride_h, stride_w):
+    img_h = img.shape[0]  # height of the full image
+    img_w = img.shape[1]  # width of the full image
+    leftover_h = (img_h-patch_h)%stride_h  # leftover on the h dim
+    leftover_w = (img_w-patch_w)%stride_w  # leftover on the w dim
+    if (leftover_h != 0):  # change dimension of img_h
+        tmp_full_imgs = np.zeros((img_h+(stride_h-leftover_h), img_w, 3))
+        tmp_full_imgs[0:img_h, 0:img_w, :] = img
+        img = tmp_full_imgs
+    if (leftover_w != 0):  # change dimension of img_w
+        tmp_full_imgs = np.zeros((img.shape[0], img_w+(stride_w - leftover_w), 3))
+        tmp_full_imgs[0:img.shape[0], 0:img_w, :] = img
+        img = tmp_full_imgs
+    return img
+
+def paint_border_overlap_trad(img, patch_h, patch_w, stride_h, stride_w):
+    img_h = img.shape[0]  # height of the full image
+    img_w = img.shape[1]  # width of the full image
+    leftover_h = (img_h-patch_h)%stride_h  # leftover on the h dim
+    leftover_w = (img_w-patch_w)%stride_w  # leftover on the w dim
+    if (leftover_h != 0):  # change dimension of img_h
+        tmp_full_imgs = np.zeros((img_h+(stride_h-leftover_h), img_w, 2))
+        tmp_full_imgs[0:img_h, 0:img_w, :] = img
+        img = tmp_full_imgs
+    if (leftover_w != 0):  # change dimension of img_w
+        tmp_full_imgs = np.zeros((img.shape[0], img_w+(stride_w - leftover_w), 2))
+        tmp_full_imgs[0:img.shape[0], 0:img_w, :] = img
+        img = tmp_full_imgs
+    return img
+
+# Keep only FOV pixels whose ground-truth masks pass the confidence gate.
+def pred_only_FOV_AV(data_imgs1, data_imgs2, data_masks1, data_masks2, original_imgs_border_masks, threshold_confusion):
+    assert (len(data_imgs1.shape)==4 and len(data_masks1.shape)==4)  # 4D arrays
+    assert (data_imgs1.shape[0]==data_masks1.shape[0])
+    assert (data_imgs1.shape[2]==data_masks1.shape[2])
+    assert (data_imgs1.shape[3]==data_masks1.shape[3])
+    assert (data_imgs1.shape[1]==1 and data_masks1.shape[1]==1)  # check the channel is 1
+    height = data_imgs1.shape[2]
+    width = data_imgs1.shape[3]
+    new_pred_imgs1 = []
+    new_pred_masks1 = []
+    new_pred_imgs2 = []
+    new_pred_masks2 = []
+    for i in range(data_imgs1.shape[0]):  # loop over the full images
+        for x in range(width):
+            for y in range(height):
+                if inside_FOV_DRIVE_AV(i, x, y, data_masks1, data_masks2, original_imgs_border_masks, threshold_confusion)==True:
+                    new_pred_imgs1.append(data_imgs1[i,:,y,x])
+                    new_pred_masks1.append(data_masks1[i,:,y,x])
+                    new_pred_imgs2.append(data_imgs2[i,:,y,x])
+                    new_pred_masks2.append(data_masks2[i,:,y,x])
+    new_pred_imgs1 = np.asarray(new_pred_imgs1)
+    new_pred_masks1 = np.asarray(new_pred_masks1)
+    new_pred_imgs2 = np.asarray(new_pred_imgs2)
+    new_pred_masks2 = np.asarray(new_pred_masks2)
+    return new_pred_imgs1, new_pred_masks1, new_pred_imgs2, new_pred_masks2
+
+
+def inside_FOV_DRIVE_AV(i, x, y, data_imgs1, data_imgs2, DRIVE_masks, threshold_confusion):
+    assert (len(DRIVE_masks.shape)==4)  # 4D arrays
+    assert (DRIVE_masks.shape[1]==1)  # DRIVE masks is black and white
+    # DRIVE_masks = DRIVE_masks/255.
#NOOO!! otherwise with float numbers takes forever!! + + if (x >= DRIVE_masks.shape[3] or y >= DRIVE_masks.shape[2]): #my image bigger than the original + return False + + if (DRIVE_masks[i,0,y,x]>0)&((data_imgs1[i,0,y,x]>threshold_confusion)|(data_imgs2[i,0,y,x]>threshold_confusion)): #0==black pixels + # print DRIVE_masks[i,0,y,x] #verify it is working right + return True + else: + return False + +def extract_ordered_overlap_trad(img, patch_h, patch_w,stride_h,stride_w,ratio): + img_h = img.shape[0] #height of the full image + img_w = img.shape[1] #width of the full image + assert ((img_h-patch_h)%stride_h==0 and (img_w-patch_w)%stride_w==0) + N_patches_img = ((img_h-patch_h)//stride_h+1)*((img_w-patch_w)//stride_w+1) #// --> division between integers + patches = np.empty((N_patches_img, patch_h//ratio, patch_w//ratio, 2)) + iter_tot = 0 #iter over the total number of patches (N_patches) + for h in range((img_h-patch_h)//stride_h+1): + for w in range((img_w-patch_w)//stride_w+1): + patch = img[h*stride_h:(h*stride_h)+patch_h, w*stride_w:(w*stride_w)+patch_w, :] + patch = cv2.resize(patch,(patch_h//ratio, patch_w//ratio)) + patches[iter_tot]=patch + iter_tot +=1 #total + assert (iter_tot==N_patches_img) + return patches #array with all the img divided in patches + +def make_pad(Patches,pad,sign='',normalize=True): + # h,w,3 + mean = [0.485, 0.456, 0.406] + std = [0.229, 0.224, 0.225] + if pad<=0: + return Patches + p = np.zeros((256,256,3),dtype=np.float32) + if sign=='img' and normalize: + p[:,:,0] = (p[:,:,0] - mean[0]) / std[0] + p[:,:,1] = (p[:,:,1] - mean[1]) / std[1] + p[:,:,2] = (p[:,:,2] - mean[2]) / std[2] + p[:Patches.shape[0],:Patches.shape[1],:] = Patches + return p + +def extract_ordered_overlap_big(img, patch_h=256, patch_w=256, stride_h=256, stride_w=256): + img_h = img.shape[0] #height of the full image + img_w = img.shape[1] #width of the full image + + big_patch_h = int(patch_h * 1.5) + big_patch_w = int(patch_w * 1.5) + assert ((img_h-patch_h)%stride_h==0 and (img_w-patch_w)%stride_w==0) + N_patches_img = ((img_h-patch_h)//stride_h+1)*((img_w-patch_w)//stride_w+1) #// --> division between integers + patches = np.empty((N_patches_img, patch_h, patch_w, 3) ) + patches_big = np.empty((N_patches_img, patch_h, patch_w, 3)) + img_big = np.zeros((img_h+(big_patch_h-patch_h), img_w+(big_patch_w-patch_w), 3)) + + img_big[(big_patch_h-patch_h)//2:(big_patch_h-patch_h)//2+img_h, (big_patch_w-patch_w)//2:(big_patch_w-patch_w)//2+img_w, :] = img + + iter_tot = 0 #iter over the total number of patches (N_patches) + for h in range((img_h-patch_h)//stride_h+1): + for w in range((img_w-patch_w)//stride_w+1): + patch = img[h*stride_h:(h*stride_h)+patch_h, w*stride_w:(w*stride_w)+patch_w, :] + #patch = cv2.resize(patch,(256, 256)) + patches[iter_tot] = patch + if np.unique(patch).shape[0] == 1: + patches_big[iter_tot] = patch + else: + patch_big = img_big[h*stride_h:h*stride_h+big_patch_h, w*stride_w:w*stride_w+big_patch_w, :] + # print(patch_big.shape) + patch_big = cv2.resize(patch_big,(patch_h, patch_w)) + patches_big[iter_tot] = patch_big + + iter_tot += 1 # total + assert (iter_tot == N_patches_img) + return patches, patches_big # array with all the img divided in patches + +def extract_ordered_overlap_big_v2(img, patch_h=256, patch_w=256, stride_h=256, stride_w=256): + img_h = img.shape[0] #height of the full image + img_w = img.shape[1] #width of the full image + assert ((img_h-patch_h)%stride_h==0 and (img_w-patch_w)%stride_w==0) + N_patches_img = 
((img_h-patch_h)//stride_h+1)*((img_w-patch_w)//stride_w+1) #// --> division between integers + patches = np.empty((N_patches_img, 256, 256, 3)) + patches_big = np.empty((N_patches_img, 256, 256, 3)) + + iter_tot = 0 #iter over the total number of patches (N_patches) + for h in range((img_h-patch_h)//stride_h+1): + for w in range((img_w-patch_w)//stride_w+1): + patch = img[h*stride_h:(h*stride_h)+patch_h, w*stride_w:(w*stride_w)+patch_w, :] + pad = max(0,256-patch_h) + patch = make_pad(patch,pad,normalize=False) + #patch = cv2.resize(patch,(256, 256)) + patches[iter_tot]=patch + + # patch_big = img[max(0,h*stride_h-patch_h//4):min((h*stride_h)+patch_h+patch_h//4,img_h),max(0,w*stride_w-patch_w//4):min((w*stride_w)+patch_w+patch_w//4,img_w), :] + if h==0 and w==0: + patch_big = img[0:patch_h+patch_h//2, 0:patch_w+patch_w//2, :] + + elif h==0 and w!=0 and w!=((img_w-patch_w)//stride_w): + patch_big = img[0:0+patch_h+patch_h//2, w*stride_w-patch_w//4:(w*stride_w)+patch_w+patch_w//4, :] + elif h==0 and w==((img_w-patch_w)//stride_w): + patch_big = img[0:0+patch_h+patch_h//2, (w)*stride_w-patch_w//2:(w*stride_w)+patch_w, :] + elif h!=0 and h!=((img_h-patch_h)//stride_h) and w==0: + patch_big = img[h*stride_h-patch_h//4:(h*stride_h)+patch_h+patch_h//4, 0:0+patch_w+patch_w//2, :] + + elif h==((img_h-patch_h)//stride_h) and w==0: + patch_big = img[(h)*stride_h-patch_h//2:(h*stride_h)+patch_h, 0:patch_w+patch_w//2, :] + elif h==((img_h-patch_h)//stride_h) and w!=0 and w!=((img_w-patch_w)//stride_w): + patch_big = img[h*stride_h-patch_h//2:(h*stride_h)+patch_h, w*stride_w-patch_w//4:(w*stride_w)+patch_w+patch_w//4, :] + elif h==((img_h-patch_h)//stride_h) and w==((img_w-patch_w)//stride_w): + patch_big = img[h*stride_h-patch_h//2:(h*stride_h)+patch_h, (w)*stride_w-patch_w//2:(w*stride_w)+patch_w, :] + elif h!=0 and h!=((img_h-patch_h)//stride_h) and w==((img_w-patch_w)//stride_w): + patch_big = img[h*stride_h-patch_h//4:(h*stride_h)+patch_h+patch_h//4, (w)*stride_w-patch_w//2:(w*stride_w)+patch_w, :] + else: + patch_big = img[h*stride_h-patch_h//4:(h*stride_h)+patch_h+patch_h//4,w*stride_w-patch_w//4:(w*stride_w)+patch_w+patch_w//4, :] + # print(patch_big.shape) + patch_big = cv2.resize(patch_big,(256, 256)) + + patches_big[iter_tot]=patch_big + + iter_tot +=1 #total + assert (iter_tot==N_patches_img) + return patches,patches_big #array with all the img divided in patches + + +def extract_ordered_overlap_big_v1(img, patch_h, patch_w,stride_h,stride_w): + img_h = img.shape[0] #height of the full image + img_w = img.shape[1] #width of the full image + assert ((img_h-patch_h)%stride_h==0 and (img_w-patch_w)%stride_w==0) + N_patches_img = ((img_h-patch_h)//stride_h+1)*((img_w-patch_w)//stride_w+1) #// --> division between integers + patches = np.empty((N_patches_img, patch_h, patch_w, 3)) + patches_big = np.empty((N_patches_img, patch_h, patch_w, 3)) + + iter_tot = 0 #iter over the total number of patches (N_patches) + for h in range((img_h-patch_h)//stride_h+1): + for w in range((img_w-patch_w)//stride_w+1): + patch = img[h*stride_h:(h*stride_h)+patch_h, w*stride_w:(w*stride_w)+patch_w, :] + #patch = cv2.resize(patch,(256, 256)) + patches[iter_tot]=patch + if h==0 and w==0: + patch_big = img[h*stride_h:(h*stride_h)+patch_h+stride_h, w*stride_w:(w*stride_w)+patch_w+stride_w, :] + elif h==0 and w!=0 and w!=((img_w-patch_w)//stride_w): + patch_big = img[h*stride_h:(h*stride_h)+patch_h+stride_h, int((w-0.5)*stride_w):(w*stride_w)+patch_w+stride_w//2, :] + elif h==0 and w==((img_w-patch_w)//stride_w): + 
patch_big = img[h*stride_h:(h*stride_h)+patch_h+stride_h, (w-1)*stride_w:(w*stride_w)+patch_w, :] + elif h!=0 and h!=((img_h-patch_h)//stride_h) and w==0: + patch_big = img[int((h-0.5)*stride_h):(h*stride_h)+patch_h+stride_h//2, w*stride_w:(w*stride_w)+patch_w+stride_w, :] + + elif h==((img_h-patch_h)//stride_h) and w==0: + patch_big = img[(h-1)*stride_h:(h*stride_h)+patch_h, w*stride_w:(w*stride_w)+patch_w+stride_w, :] + elif h==((img_h-patch_h)//stride_h) and w!=0 and w!=((img_w-patch_w)//stride_w): + patch_big = img[(h-1)*stride_h:(h*stride_h)+patch_h, int((w-0.5)*stride_w):(w*stride_w)+patch_w+stride_w//2, :] + elif h==((img_h-patch_h)//stride_h) and w==((img_w-patch_w)//stride_w): + patch_big = img[(h-1)*stride_h:(h*stride_h)+patch_h, (w-1)*stride_w:(w*stride_w)+patch_w, :] + elif h!=0 and h!=((img_h-patch_h)//stride_h) and w==((img_w-patch_w)//stride_w): + patch_big = img[int((h-0.5)*stride_h):(h*stride_h)+patch_h+stride_h//2, (w-1)*stride_w:(w*stride_w)+patch_w, :] + else: + patch_big = img[int((h-0.5)*stride_h):(h*stride_h)+patch_h+stride_h//2, int((w-0.5)*stride_w):(w*stride_w)+patch_w+stride_w//2, :] + + patch_big = cv2.resize(patch_big,(256, 256)) + + patches_big[iter_tot]=patch_big + + iter_tot +=1 #total + assert (iter_tot==N_patches_img) + return patches,patches_big #array with all the img divided in patches + + + + + +def pred_to_imgs(pred,mode="original"): + assert (len(pred.shape)==3) #3D array: (Npatches,height*width,2) + assert (pred.shape[2]==2 ) #check the classes are 2 + pred_images = np.empty((pred.shape[0],pred.shape[1])) #(Npatches,height*width) + if mode=="original": + for i in range(pred.shape[0]): + for pix in range(pred.shape[1]): + pred_images[i,pix]=pred[i,pix,1] + elif mode=="threshold": + for i in range(pred.shape[0]): + for pix in range(pred.shape[1]): + if pred[i,pix,1]>=0.5: + pred_images[i,pix]=1 + else: + pred_images[i,pix]=0 + else: + print("mode " +str(mode) +" not recognized, it can be 'original' or 'threshold'") + exit() + pred_images = np.reshape(pred_images,(pred_images.shape[0],1,48,48)) + return pred_images + +def recompone_overlap(pred_patches, img_h, img_w, stride_h, stride_w): + assert (len(pred_patches.shape)==4) #4D arrays + #assert (pred_patches.shape[1]==2 or pred_patches.shape[1]==3) #check the channel is 1 or 3 + patch_h = pred_patches.shape[2] + patch_w = pred_patches.shape[3] + N_patches_h = (img_h-patch_h)//stride_h+1 + N_patches_w = (img_w-patch_w)//stride_w+1 + N_patches_img = N_patches_h * N_patches_w + #assert (pred_patches.shape[0]%N_patches_img==0) + #N_full_imgs = pred_patches.shape[0]//N_patches_img + full_prob = np.zeros((pred_patches.shape[1], img_h,img_w,)) #itialize to zero mega array with sum of Probabilities + full_sum = np.zeros((pred_patches.shape[1], img_h,img_w)) + + k = 0 #iterator over all the patches + for h in range(N_patches_h): + for w in range(N_patches_w): + full_prob[:, h*stride_h:(h*stride_h)+patch_h, w*stride_w:(w*stride_w)+patch_w]+=pred_patches[k] + full_sum[:, h*stride_h:(h*stride_h)+patch_h, w*stride_w:(w*stride_w)+patch_w]+=1 + k+=1 + assert(k==pred_patches.shape[0]) + assert(np.min(full_sum)>=1.0) #at least one + final_avg = full_prob/full_sum + #print(final_avg.shape) + # assert(np.max(final_avg)<=1.0) #max value for a pixel is 1.0 + # assert(np.min(final_avg)>=0.0) #min value for a pixel is 0.0 + return final_avg + + +def Normalize(Patches): + mean = [0.485, 0.456, 0.406] + std = [0.229, 0.224, 0.225] + + # mean = [0.3261, 0.2287, 0.1592] + # std = [0.2589, 0.1882, 0.1369] + + + Patches[:,0,:,:] 
= (Patches[:,0,:,:] - mean[0]) / std[0] + Patches[:,1,:,:] = (Patches[:,1,:,:] - mean[1]) / std[1] + Patches[:,2,:,:] = (Patches[:,2,:,:] - mean[2]) / std[2] + return Patches + +def Normalize_patch(Patches): + mean = [0.485, 0.456, 0.406] + std = [0.229, 0.224, 0.225] + + Patches[0,:,:] = (Patches[0,:,:] - mean[0]) / std[0] + Patches[1,:,:] = (Patches[1,:,:] - mean[1]) / std[1] + Patches[2,:,:] = (Patches[2,:,:] - mean[2]) / std[2] + return Patches + +def sigmoid(x): + return np.exp((x)) / (1 + np.exp(x)) + +def inside_FOV_DRIVE(x, y, DRIVE_masks): + # DRIVE_masks = DRIVE_masks/255. #NOOO!! otherwise with float numbers takes forever!! + if (x >= DRIVE_masks.shape[1] or y >= DRIVE_masks.shape[0]): #my image bigger than the original + return False + + if (DRIVE_masks[y,x]>0): #0==black pixels + return True + else: + return False + + +def kill_border(pred_img, border_masks): + height = pred_img.shape[1] + width = pred_img.shape[2] + for x in range(width): + for y in range(height): + if inside_FOV_DRIVE(x,y, border_masks)==False: + pred_img[:,y,x]=0.0 + return pred_img + diff --git a/AV/Tools/warmup.py b/AV/Tools/warmup.py new file mode 100644 index 0000000000000000000000000000000000000000..e903565039fce937eb164d0609cf41d6803d95d9 --- /dev/null +++ b/AV/Tools/warmup.py @@ -0,0 +1,69 @@ + + +import torch +from torch.optim.lr_scheduler import StepLR, ExponentialLR + +class GradualWarmupScheduler(torch.optim.lr_scheduler._LRScheduler): + """ Gradually warm-up(increasing) learning rate in optimizer. + Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'. + + Args: + optimizer (Optimizer): Wrapped optimizer. + multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. if multiplier = 1.0, lr starts from 0 and ends up with the base_lr. + total_epoch: target learning rate is reached at total_epoch, gradually + after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau) + """ + + def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None): + self.multiplier = multiplier + if self.multiplier < 1.: + raise ValueError('multiplier should be greater thant or equal to 1.') + self.total_epoch = total_epoch + self.after_scheduler = after_scheduler + self.finished = False + super(GradualWarmupScheduler, self).__init__(optimizer) + + def get_lr(self): + if self.last_epoch > self.total_epoch: + if self.after_scheduler: + if not self.finished: + self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs] + self.finished = True + return self.after_scheduler.get_last_lr() + return [base_lr * self.multiplier for base_lr in self.base_lrs] + + if self.multiplier == 1.0: + return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs] + else: + return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) 
for base_lr in self.base_lrs]
+
+    def step(self, epoch=None, metrics=None):
+        if self.finished and self.after_scheduler:
+            if epoch is None:
+                self.after_scheduler.step()
+            else:
+                # Step the wrapped scheduler relative to the end of warm-up.
+                self.after_scheduler.step(epoch - self.total_epoch)
+            self._last_lr = self.after_scheduler.get_last_lr()
+        else:
+            return super(GradualWarmupScheduler, self).step(epoch)
+
+
+if __name__ == '__main__':
+    model = [torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))]
+    optim = torch.optim.Adam(model, 0.0002)
+
+    # scheduler_warmup is chained with scheduler_steplr
+    scheduler_steplr = StepLR(optim, step_size=80, gamma=0.1)
+    scheduler_warmup = GradualWarmupScheduler(optim, multiplier=2, total_epoch=10, after_scheduler=scheduler_steplr)
+
+    # this zero gradient update is needed to avoid a warning message, issue #8.
+    optim.zero_grad()
+    optim.step()
+
+    for epoch in range(1, 20):
+        scheduler_warmup.step(epoch)
+        print(epoch, optim.param_groups[0]['lr'])
+
+        optim.step()
\ No newline at end of file
diff --git a/AV/config/__pycache__/config_test_general.cpython-310.pyc b/AV/config/__pycache__/config_test_general.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97ca7235e9c9bbaca45d140a3cebfb4828c905fa Binary files /dev/null and b/AV/config/__pycache__/config_test_general.cpython-310.pyc differ
diff --git a/AV/config/__pycache__/config_test_general.cpython-39.pyc b/AV/config/__pycache__/config_test_general.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03c699f443248192e21db798d8154c3ee5998c2c Binary files /dev/null and b/AV/config/__pycache__/config_test_general.cpython-39.pyc differ
diff --git a/AV/config/__pycache__/config_train_general.cpython-310.pyc b/AV/config/__pycache__/config_train_general.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0fe225e18a71d6d29d2c35bcddc29e09d95018cb Binary files /dev/null and b/AV/config/__pycache__/config_train_general.cpython-310.pyc differ
diff --git a/AV/config/__pycache__/config_train_general.cpython-39.pyc b/AV/config/__pycache__/config_train_general.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6dfab89c1a16a028e018af16af3ed808821cfc79 Binary files /dev/null and b/AV/config/__pycache__/config_train_general.cpython-39.pyc differ
diff --git a/AV/config/config_test_general.py b/AV/config/config_test_general.py new file mode 100644 index 0000000000000000000000000000000000000000..46fb4d1ec5ee00d63cfbc6a1124709c9061c8a07 --- /dev/null +++ b/AV/config/config_test_general.py @@ -0,0 +1,134 @@
+import torch
+import os
+
+# Check GPU availability
+use_cuda = torch.cuda.is_available()
+gpu_ids = [0] if use_cuda else []
+device = torch.device('cuda' if use_cuda else 'cpu')
+
+dataset_name = 'all'  # DRIVE
+#dataset_name = 'LES'  # LES
+# dataset_name = 'hrf'  # HRF
+# dataset_name = 'ukbb'  # UKBB
+# dataset_name = 'all'
+dataset = dataset_name
+max_step = 30000  # 30000 for ukbb
+
+batch_size = 8  # default: 4
+print_iter = 100  # default: 100
+display_iter = 100  # default: 100
+save_iter = 5000  # default: 5000
+first_display_metric_iter = max_step - save_iter  # default: 25000
+lr = 0.0002  # if dataset_name != 'LES' else 0.00005  # default: 0.0002
+step_size = 7000  # 7000 for DRIVE
+lr_decay_gamma = 0.5  # default: 0.5
+use_SGD = False  # default: False
+
+input_nc = 3
+ndf = 32
+netD_type = 'basic'
+n_layers_D = 5
+norm = 'instance'
+no_lsgan = False
+init_type = 'normal'
+init_gain = 0.02
+use_sigmoid = no_lsgan
+use_noise_input_D = False
+use_dropout_D = False
+# 
torch.cuda.set_device(gpu_ids[0]) +use_GAN = True # default: True + +# adam +beta1 = 0.5 + +# settings for GAN loss +num_classes_D = 1 +lambda_GAN_D = 0.01 +lambda_GAN_G = 0.01 +lambda_GAN_gp = 100 +lambda_BCE = 5 +lambda_DICE = 5 + +input_nc_D = input_nc + 3 + +# settings for centerness +use_centerness = True # default: True +lambda_centerness = 1 +center_loss_type = 'centerness' +centerness_map_size = [128, 128] + +# pretrained model +use_pretrained_G = True +use_pretrained_D = False +# model_path_pretrained_G = './log/patch_pretrain' +model_path_pretrained_G = '' +model_step_pretrained_G = 0 +stride_height = 0 +stride_width = 0 +patch_size_list=[] + +def set_dataset(name): + global dataset_name, model_path_pretrained_G, model_step_pretrained_G + global stride_height, stride_width,patch_size,patch_size_list,dataset + dataset_name = name + dataset = name + if dataset_name == 'DRIVE': + model_path_pretrained_G = './AV/log/DRIVE-2023_10_20_08_36_50(6500)' + model_step_pretrained_G = 6500 + elif dataset_name == 'LES': + model_path_pretrained_G = './AV/log/LES-2023_09_28_14_04_06(0)' + model_step_pretrained_G = 0 + elif dataset_name == 'hrf': + model_path_pretrained_G = './AV/log/HRF-2023_10_19_11_07_31(1500)' + model_step_pretrained_G = 1500 + elif dataset_name == 'ukbb': + model_path_pretrained_G = './AV/log/UKBB-2023_11_02_23_22_07(5000)' + model_step_pretrained_G = 5000 + else: + model_path_pretrained_G = './AV/log/ALL-2024_09_06_09_17_18(9000)' + model_step_pretrained_G = 9000 + if dataset_name == 'DRIVE': + patch_size_list = [64, 128, 256] + elif dataset_name == 'LES': + patch_size_list = [96, 384, 256] + elif dataset_name == 'hrf': + patch_size_list = [64, 384, 256] + elif dataset_name == 'ukbb': + patch_size_list = [96, 384, 256] + else: + patch_size_list = [96, 384, 512] + patch_size = patch_size_list[2] + +# path for dataset + if dataset_name == 'DRIVE' or dataset_name == 'LES' or dataset_name == 'hrf': + stride_height = 50 + stride_width = 50 + else: + stride_height = 150 + stride_width = 150 + +n_classes = 3 + +model_step = 0 + +# use CAM +use_CAM = False + +#use resize +use_resize = True +resize_w_h = (1920,512) + +# use av_cross +use_av_cross = False + +use_high_semantic = False +lambda_high = 1 # A,V,Vessel + +# use global semantic +use_global_semantic = False +global_warmup_step = 0 if use_pretrained_G else 5000 + + + + + diff --git a/AV/config/config_train_general.py b/AV/config/config_train_general.py new file mode 100644 index 0000000000000000000000000000000000000000..98b025091efed5931da26b7aa0fd6ebc8df3d898 --- /dev/null +++ b/AV/config/config_train_general.py @@ -0,0 +1,121 @@ +import torch +import os + +# Check GPU availability +use_cuda = torch.cuda.is_available() +gpu_ids = [0] if use_cuda else [] +device = torch.device('cuda' if use_cuda else 'cpu') + + +dataset_name = 'DRIVE' # DRIVE +#dataset_name = 'LES' # LES +#dataset_name = 'hrf' # HRF +dataset = dataset_name + +max_step = 30000 # 30000 for ukbb +if dataset_name=='DRIVE': + patch_size_list = [64, 128, 256] +elif dataset_name=='LES': + patch_size_list = [96,384, 256] +elif dataset_name=='hrf': + patch_size_list = [64, 384, 256] +patch_size = patch_size_list[2] +batch_size = 8 # default: 4 +print_iter = 100 # default: 100 +display_iter = 100 # default: 100 +save_iter = 5000 # default: 5000 +first_display_metric_iter = max_step-save_iter # default: 25000 +lr = 0.0002 #if dataset_name!='LES' else 0.00005 # default: 0.0002 +step_size = 7000 # 7000 for DRIVE +lr_decay_gamma = 0.5 # default: 0.5 +use_SGD = False # 
default:False + +input_nc = 3 +ndf = 32 +netD_type = 'basic' +n_layers_D = 5 +norm = 'instance' +no_lsgan = False +init_type = 'normal' +init_gain = 0.02 +use_sigmoid = no_lsgan +use_noise_input_D = False +use_dropout_D = False + +# torch.cuda.set_device(gpu_ids[0]) +use_GAN = True # default: True + +# adam +beta1 = 0.5 + +# settings for GAN loss +num_classes_D = 1 +lambda_GAN_D = 0.01 +lambda_GAN_G = 0.01 +lambda_GAN_gp = 100 +lambda_BCE = 5 +lambda_DICE = 5 + +input_nc_D = input_nc + 3 + +# settings for centerness +use_centerness =True # default: True +lambda_centerness = 1 +center_loss_type = 'centerness' +centerness_map_size = [128,128] + +# pretrained model +use_pretrained_G = True +use_pretrained_D = False + +model_path_pretrained_G = r"../RIP/weight" + +model_step_pretrained_G = 'best_drive' + + +# path for dataset +stride_height = 50 +stride_width = 50 + + +n_classes = 3 + +model_step = 0 + +# use CAM +use_CAM = False + +#use resize +use_resize = False +resize_w_h = (256,256) + +#use av_cross +use_av_cross = False + +use_high_semantic = False +lambda_high = 1 # A,V,Vessel + +# use global semantic +use_global_semantic = True +global_warmup_step = 0 if use_pretrained_G else 5000 + +# use network +use_network = 'convnext_tiny' # swin_t,convnext_tiny + +dataset_path = {'DRIVE': './data/AV_DRIVE/training/', + + 'hrf': './data/hrf/training/', + + 'LES': './data/LES_AV/training/', + + } +trainset_path = dataset_path[dataset_name] + + +print("Dataset:") +print(trainset_path) +print(use_network) + + + + diff --git a/AV/log/ALL-2024_09_06_09_17_18(9000)/G_9000.pkl b/AV/log/ALL-2024_09_06_09_17_18(9000)/G_9000.pkl new file mode 100644 index 0000000000000000000000000000000000000000..f6580a6bf04a1852206981672dab50805235478d --- /dev/null +++ b/AV/log/ALL-2024_09_06_09_17_18(9000)/G_9000.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:15bff178940810286c7df89a4ac416a167a3f2a14c015ce3bf3a4c65b082ae49 +size 115842242 diff --git a/AV/log/DRIVE-2023_10_20_08_36_50(6500)/G_6500.pkl b/AV/log/DRIVE-2023_10_20_08_36_50(6500)/G_6500.pkl new file mode 100644 index 0000000000000000000000000000000000000000..910ea248f103128c8ebf15b8a5f41ceac48cef1e --- /dev/null +++ b/AV/log/DRIVE-2023_10_20_08_36_50(6500)/G_6500.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6c2f29a05c014e372e9f3421b34df62cf3a557b9f4834da27e37c555a4f465d5 +size 115848956 diff --git a/AV/log/HRF-2023_10_19_11_07_31(1500)/G_1500.pkl b/AV/log/HRF-2023_10_19_11_07_31(1500)/G_1500.pkl new file mode 100644 index 0000000000000000000000000000000000000000..2181adccf6a620c683648ad4feb44672864d5a33 --- /dev/null +++ b/AV/log/HRF-2023_10_19_11_07_31(1500)/G_1500.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:325ce3ee74ecd9fb1c554fe810f70e471595e2f521b5fb094a34a682846efb7d +size 115848956 diff --git a/AV/log/LES-2023_09_28_14_04_06(0)/G_0.pkl b/AV/log/LES-2023_09_28_14_04_06(0)/G_0.pkl new file mode 100644 index 0000000000000000000000000000000000000000..7af40635582f07bee0d182a10fd3989cbf75b184 --- /dev/null +++ b/AV/log/LES-2023_09_28_14_04_06(0)/G_0.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:21d2261e124773dfe15113c2bb9c4e16c59988c99613cba68e0209a69be7a770 +size 115830579 diff --git a/AV/log/UKBB-2023_11_02_23_22_07(5000)/G_5000.pkl b/AV/log/UKBB-2023_11_02_23_22_07(5000)/G_5000.pkl new file mode 100644 index 0000000000000000000000000000000000000000..b8507164d1b0c0e336db43b36f0ae9e819ea6013 --- /dev/null +++ 
b/AV/log/UKBB-2023_11_02_23_22_07(5000)/G_5000.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3d17ff9ab13f85f6e29ab4ddff08a918a962e73a895168fb4525c55cfc0493b1 +size 115848956 diff --git a/AV/models/__pycache__/layers.cpython-310.pyc b/AV/models/__pycache__/layers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..515cc643d5d0bf424bc746b8b5ea6e205df1e83d Binary files /dev/null and b/AV/models/__pycache__/layers.cpython-310.pyc differ diff --git a/AV/models/__pycache__/layers.cpython-38.pyc b/AV/models/__pycache__/layers.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..abf15e2e63570daa6ec05c222ca7bfd43bd4a719 Binary files /dev/null and b/AV/models/__pycache__/layers.cpython-38.pyc differ diff --git a/AV/models/__pycache__/layers.cpython-39.pyc b/AV/models/__pycache__/layers.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..85b15d465d54f2a9db6a802c94d2e8eaf8998c55 Binary files /dev/null and b/AV/models/__pycache__/layers.cpython-39.pyc differ diff --git a/AV/models/__pycache__/network.cpython-310.pyc b/AV/models/__pycache__/network.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4addeb38390cc8a1831c6259df6b90dd8a556c50 Binary files /dev/null and b/AV/models/__pycache__/network.cpython-310.pyc differ diff --git a/AV/models/__pycache__/network.cpython-38.pyc b/AV/models/__pycache__/network.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..729b7fb4c031dcf0e557ac2624284433734f03da Binary files /dev/null and b/AV/models/__pycache__/network.cpython-38.pyc differ diff --git a/AV/models/__pycache__/network.cpython-39.pyc b/AV/models/__pycache__/network.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..01ee33ae8c23503b4e839922371310b2d3098e52 Binary files /dev/null and b/AV/models/__pycache__/network.cpython-39.pyc differ diff --git a/AV/models/__pycache__/networks_gan.cpython-310.pyc b/AV/models/__pycache__/networks_gan.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..731c69d26e06d2cdbb66b8d530b5315fc83933bb Binary files /dev/null and b/AV/models/__pycache__/networks_gan.cpython-310.pyc differ diff --git a/AV/models/__pycache__/networks_gan.cpython-39.pyc b/AV/models/__pycache__/networks_gan.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..71239deb3d4086bfe32d5e43d495c7fc42f0e120 Binary files /dev/null and b/AV/models/__pycache__/networks_gan.cpython-39.pyc differ diff --git a/AV/models/__pycache__/sw_gan.cpython-310.pyc b/AV/models/__pycache__/sw_gan.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d5b45e5e40c5860b20cb95544ba2375824a08c5 Binary files /dev/null and b/AV/models/__pycache__/sw_gan.cpython-310.pyc differ diff --git a/AV/models/__pycache__/sw_gan.cpython-38.pyc b/AV/models/__pycache__/sw_gan.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6649b71a5a88dbd58bafc22ab84b0df2b4260a56 Binary files /dev/null and b/AV/models/__pycache__/sw_gan.cpython-38.pyc differ diff --git a/AV/models/__pycache__/sw_gan.cpython-39.pyc b/AV/models/__pycache__/sw_gan.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bc43c866845b5056df1bb3cc64799e7ef2989d20 Binary files /dev/null and b/AV/models/__pycache__/sw_gan.cpython-39.pyc differ diff --git a/AV/models/layers.py b/AV/models/layers.py new file mode 
100644 index 0000000000000000000000000000000000000000..0c0d732c300434d9abc129e0c7a51d2c3aa23650 --- /dev/null +++ b/AV/models/layers.py @@ -0,0 +1,674 @@ +# -*- coding: utf-8 -*- + +import torch +from torch import nn +import torch.nn.functional as F +# from timm.models.layers.cbam import CbamModule +import numpy as np +from einops import rearrange, repeat +import math + + +class ConvBn2d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, padding): + super(ConvBn2d, self).__init__() + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding) + self.bn = nn.BatchNorm2d(out_channels) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return x + + +class sSE(nn.Module): + def __init__(self, out_channels): + super(sSE, self).__init__() + self.conv = ConvBn2d(in_channels=out_channels, out_channels=1, kernel_size=1, padding=0) + + def forward(self, x): + x = self.conv(x) + # print('spatial',x.size()) + x = F.sigmoid(x) + return x + + +class cSE(nn.Module): + def __init__(self, out_channels): + super(cSE, self).__init__() + self.conv1 = ConvBn2d(in_channels=out_channels, out_channels=int(out_channels / 2), kernel_size=1, padding=0) + self.conv2 = ConvBn2d(in_channels=int(out_channels / 2), out_channels=out_channels, kernel_size=1, padding=0) + + def forward(self, x): + x = nn.AvgPool2d(x.size()[2:])(x) + # print('channel',x.size()) + x = self.conv1(x) + x = F.relu(x) + x = self.conv2(x) + x = F.sigmoid(x) + return x + + +class scSEBlock(nn.Module): + def __init__(self, out_channels): + super(scSEBlock, self).__init__() + self.spatial_gate = sSE(out_channels) + self.channel_gate = cSE(out_channels) + + def forward(self, x): + g1 = self.spatial_gate(x) + g2 = self.channel_gate(x) + x = g1 * x + g2 * x + return x + + +class SaveFeatures(): + features = None + + def __init__(self, m): + self.hook = m.register_forward_hook(self.hook_fn) + + def hook_fn(self, module, input, output): + # print('input',input) + # print('output',output.size()) + if len(output.shape) == 3: + B, L, C = output.shape + h = int(L ** 0.5) + output = output.view(B, h, h, C) + + output = output.permute(0, 3, 1, 2).contiguous() + if len(output.shape) == 4 and output.shape[2] != output.shape[3]: + output = output.permute(0, 3, 1, 2).contiguous() + # print(module) + self.features = output + + def remove(self): + self.hook.remove() + + +class DBlock(nn.Module): + + def __init__(self, in_channels, out_channels, use_batchnorm=True, attention_type=None): + + super(DBlock, self).__init__() + + self.conv1 = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + ) + + if attention_type == 'scse': + self.attention1 = scSEBlock(in_channels) + elif attention_type == 'cbam': + self.attention1 = nn.Identity() + + elif attention_type == 'transformer': + + self.attention1 = nn.Identity() + + + else: + self.attention1 = nn.Identity() + + self.conv2 = \ + nn.Sequential( + nn.Conv2d(out_channels * 2, out_channels, kernel_size=3, padding=1, stride=1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + ) + + self.conv3 = nn.Sequential( + nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, stride=1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + ) + if attention_type == 'scse': + self.attention2 = scSEBlock(out_channels) + elif attention_type == 'cbam': + self.attention2 = CbamModule(channels=out_channels) + + elif 
attention_type == 'transformer': + self.attention2 = nn.Identity() + + else: + self.attention2 = nn.Identity() + + def forward(self, x, skip): + if x.shape[1] != skip.shape[1]: + x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True) + + # print(x.shape,skip.shape) + x = self.attention1(x) + x = self.conv1(x) + + x = torch.cat([x, skip], dim=1) + + x = self.conv2(x) + x = self.conv3(x) + x = self.attention2(x) + + return x + + +class DBlock_res(nn.Module): + + def __init__(self, in_channels, out_channels, use_batchnorm=True, attention_type=None): + + super(DBlock_res, self).__init__() + + self.conv1 = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=1), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + ) + + if attention_type == 'scse': + self.attention1 = scSEBlock(in_channels) + elif attention_type == 'cbam': + self.attention1 = CbamModule(channels=in_channels) + + elif attention_type == 'transformer': + + self.attention1 = nn.Identity() + + + else: + self.attention1 = nn.Identity() + + self.conv2 = \ + nn.Sequential( + nn.Conv2d(out_channels * 2, out_channels, kernel_size=3, padding=1, stride=1), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + ) + + self.conv3 = nn.Sequential( + nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, stride=1), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + ) + if attention_type == 'scse': + self.attention2 = scSEBlock(out_channels) + elif attention_type == 'cbam': + self.attention2 = CbamModule(channels=out_channels) + + elif attention_type == 'transformer': + self.attention2 = nn.Identity() + + else: + self.attention2 = nn.Identity() + + def forward(self, x, skip): + + x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True) + + # print(x.shape,skip.shape) + x = self.attention1(x) + x = self.conv1(x) + + x = torch.cat([x, skip], dim=1) + + x = self.conv2(x) + x = self.conv3(x) + x = self.attention2(x) + + return x + + +class DBlock_att(nn.Module): + + def __init__(self, in_channels, out_channels, use_batchnorm=True, attention_type='transformer'): + + super(DBlock_att, self).__init__() + + self.conv1 = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + ) + + if attention_type == 'scse': + self.attention1 = scSEBlock(in_channels) + elif attention_type == 'cbam': + self.attention1 = CbamModule(channels=in_channels) + + elif attention_type == 'transformer': + + self.attention1 = nn.Identity() + + + else: + self.attention1 = nn.Identity() + + self.conv2 = \ + nn.Sequential( + nn.Conv2d(out_channels * 2, out_channels, kernel_size=3, padding=1, stride=1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + ) + + self.conv3 = nn.Sequential( + nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, stride=1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + ) + if attention_type == 'scse': + self.attention2 = scSEBlock(out_channels) + elif attention_type == 'cbam': + self.attention2 = CbamModule(channels=out_channels) + + elif attention_type == 'transformer': + self.attention2 = nn.Identity() + + else: + self.attention2 = nn.Identity() + + def forward(self, x, skip): + if x.shape[1] != skip.shape[1]: + x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True) + + # print(x.shape,skip.shape) + x = self.attention1(x) + x = self.conv1(x) + + x = torch.cat([x, skip], dim=1) + x = 
self.conv2(x) + x = self.conv3(x) + + x = self.attention2(x) + + return x + + +class SegmentationHead(nn.Module): + def __init__(self, in_channels, num_class, kernel_size=3, upsample=4): + super(SegmentationHead, self).__init__() + self.upsample = nn.UpsamplingBilinear2d(scale_factor=upsample) if upsample > 1 else nn.Identity() + self.conv = nn.Conv2d(in_channels, num_class, kernel_size=kernel_size, padding=kernel_size // 2) + + def forward(self, x): + x = self.upsample(x) + x = self.conv(x) + return x + + +class AV_Cross(nn.Module): + + def __init__(self, channels=2, r=2, residual=True, block=4, kernel_size=1): + super(AV_Cross, self).__init__() + out_channels = int(channels // r) + self.residual = residual + self.block = block + self.bn = nn.BatchNorm2d(3) + self.relu = False + self.kernel_size = kernel_size + self.a_ve_att = nn.ModuleList() + self.v_ve_att = nn.ModuleList() + self.ve_att = nn.ModuleList() + for i in range(self.block): + self.a_ve_att.append(nn.Sequential( + nn.Conv2d(channels, out_channels, kernel_size=self.kernel_size, stride=1, + padding=(self.kernel_size - 1) // 2), + nn.BatchNorm2d(out_channels), + )) + self.v_ve_att.append(nn.Sequential( + nn.Conv2d(channels, out_channels, kernel_size=self.kernel_size, stride=1, + padding=(self.kernel_size - 1) // 2), + nn.BatchNorm2d(out_channels), + )) + self.ve_att.append(nn.Sequential( + nn.Conv2d(3, out_channels, kernel_size=self.kernel_size, stride=1, padding=(self.kernel_size - 1) // 2), + nn.BatchNorm2d(out_channels), + )) + self.sigmoid = nn.Sigmoid() + self.final = nn.Conv2d(3, 3, kernel_size=1, stride=1, padding=0) + + def forward(self, x): + a, ve, v = x[:, 0:1, :, :], x[:, 1:2, :, :], x[:, 2:, :, :] + for i in range(self.block): + # x = self.relu(self.bn(x)) + a_ve = torch.concat([a, ve], dim=1) + v_ve = torch.concat([v, ve], dim=1) + a_v_ve = torch.concat([a, ve, v], dim=1) + x_a = self.a_ve_att[i](a_ve) + x_v = self.v_ve_att[i](v_ve) + x_a_v = self.ve_att[i](a_v_ve) + a_weight = self.sigmoid(x_a) + v_weight = self.sigmoid(x_v) + ve_weight = self.sigmoid(x_a_v) + if self.residual: + a = a + a * a_weight + v = v + v * v_weight + ve = ve + ve * ve_weight + else: + a = a * a_weight + v = v * v_weight + ve = ve * ve_weight + + out = torch.concat([a, ve, v], dim=1) + + if self.relu: + out = F.relu(out) + out = self.final(out) + return out + + +class AV_Cross_v2(nn.Module): + + def __init__(self, channels=2, r=2, residual=True, block=1, relu=False, kernel_size=1): + super(AV_Cross_v2, self).__init__() + out_channels = int(channels // r) + self.residual = residual + self.block = block + self.relu = relu + self.kernel_size = kernel_size + self.a_ve_att = nn.ModuleList() + self.v_ve_att = nn.ModuleList() + self.ve_att = nn.ModuleList() + for i in range(self.block): + self.a_ve_att.append(nn.Sequential( + nn.Conv2d(channels, out_channels, kernel_size=self.kernel_size, stride=1, + padding=(self.kernel_size - 1) // 2), + nn.BatchNorm2d(out_channels) + )) + self.v_ve_att.append(nn.Sequential( + nn.Conv2d(channels, out_channels, kernel_size=self.kernel_size, stride=1, + padding=(self.kernel_size - 1) // 2), + nn.BatchNorm2d(out_channels) + )) + self.ve_att.append(nn.Sequential( + nn.Conv2d(channels, out_channels, kernel_size=self.kernel_size, stride=1, + padding=(self.kernel_size - 1) // 2), + nn.BatchNorm2d(out_channels) + )) + + self.sigmoid = nn.Sigmoid() + self.final = nn.Conv2d(3, 3, kernel_size=1, stride=1, padding=0) + + def forward(self, x): + a, ve, v = x[:, 0:1, :, :], x[:, 1:2, :, :], x[:, 2:, :, :] + + for i in 
range(self.block):
+            tmp = torch.cat([a, ve, v], dim=1)
+            a_ve = torch.concat([a, ve], dim=1)
+            a_ve = torch.cat([torch.max(a_ve, dim=1, keepdim=True)[0], torch.mean(a_ve, dim=1, keepdim=True)], dim=1)
+            v_ve = torch.concat([v, ve], dim=1)
+            v_ve = torch.cat([torch.max(v_ve, dim=1, keepdim=True)[0], torch.mean(v_ve, dim=1, keepdim=True)], dim=1)
+            a_v_ve = torch.concat([torch.max(tmp, dim=1, keepdim=True)[0], torch.mean(tmp, dim=1, keepdim=True)], dim=1)
+
+            a_ve = self.a_ve_att[i](a_ve)
+            v_ve = self.v_ve_att[i](v_ve)
+            a_v_ve = self.ve_att[i](a_v_ve)
+            a_weight = self.sigmoid(a_ve)
+            v_weight = self.sigmoid(v_ve)
+            ve_weight = self.sigmoid(a_v_ve)
+            if self.residual:
+                a = a + a * a_weight
+                v = v + v * v_weight
+                ve = ve + ve * ve_weight
+            else:
+                a = a * a_weight
+                v = v * v_weight
+                ve = ve * ve_weight
+
+        out = torch.concat([a, ve, v], dim=1)
+
+        if self.relu:
+            out = F.relu(out)
+        out = self.final(out)
+        return out
+
+
+class MultiHeadAttention(nn.Module):
+    def __init__(self, embedding_dim, head_num):
+        super().__init__()
+
+        self.head_num = head_num
+        self.dk = (embedding_dim // head_num) ** (1 / 2)  # sqrt of the per-head dimension
+
+        self.qkv_layer = nn.Linear(embedding_dim, embedding_dim * 3, bias=False)
+        self.out_attention = nn.Linear(embedding_dim, embedding_dim, bias=False)
+
+    def forward(self, x, mask=None):
+        qkv = self.qkv_layer(x)
+
+        query, key, value = tuple(rearrange(qkv, 'b t (d k h ) -> k b h t d ', k=3, h=self.head_num))
+        # Scaled dot-product attention: the logits are divided by sqrt(d_k).
+        energy = torch.einsum("... i d , ... j d -> ... i j", query, key) / self.dk
+
+        if mask is not None:
+            energy = energy.masked_fill(mask, -np.inf)
+
+        attention = torch.softmax(energy, dim=-1)
+
+        x = torch.einsum("... i j , ... j d -> ... i d", attention, value)
+
+        x = rearrange(x, "b h t d -> b t (h d)")
+        x = self.out_attention(x)
+
+        return x
+
+
+class MLP(nn.Module):
+    def __init__(self, embedding_dim, mlp_dim):
+        super().__init__()
+
+        self.mlp_layers = nn.Sequential(
+            nn.Linear(embedding_dim, mlp_dim),
+            nn.GELU(),
+            nn.Dropout(0.1),
+            nn.Linear(mlp_dim, embedding_dim),
+            nn.Dropout(0.1)
+        )
+
+    def forward(self, x):
+        x = self.mlp_layers(x)
+
+        return x
+
+
+class TransformerEncoderBlock(nn.Module):
+    def __init__(self, embedding_dim, head_num, mlp_dim):
+        super().__init__()
+
+        self.multi_head_attention = MultiHeadAttention(embedding_dim, head_num)
+        self.mlp = MLP(embedding_dim, mlp_dim)
+
+        self.layer_norm1 = nn.LayerNorm(embedding_dim)
+        self.layer_norm2 = nn.LayerNorm(embedding_dim)
+
+        self.dropout = nn.Dropout(0.1)
+
+    def forward(self, x):
+        _x = self.multi_head_attention(x)
+        _x = self.dropout(_x)
+        x = x + _x
+        x = self.layer_norm1(x)
+
+        _x = self.mlp(x)
+        x = x + _x
+        x = self.layer_norm2(x)
+
+        return x
+
+
+class TransformerEncoder(nn.Module):
+    """
+    embedding_dim: length of each token vector
+    head_num: number of self-attention heads
+    block_num: number of transformer blocks
+    """
+
+    def __init__(self, embedding_dim, head_num, block_num=2):
+        super().__init__()
+        self.layer_blocks = nn.ModuleList(
+            [TransformerEncoderBlock(embedding_dim, head_num, 2 * embedding_dim) for _ in range(block_num)])
+
+    def forward(self, x):
+        for layer_block in self.layer_blocks:
+            x = layer_block(x)
+        return x
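
As a quick shape sanity check for the encoder above (a minimal sketch, not part of the patch): the attention and MLP sub-layers are residual, so TransformerEncoder maps a (batch, tokens, embedding_dim) tensor to a tensor of the same shape, which is what PathEmbedding below relies on when it prepends the cls token.

import torch

enc = TransformerEncoder(embedding_dim=64, head_num=8, block_num=2)
tokens = torch.randn(2, 17, 64)       # e.g. 16 patch tokens + 1 cls token
out = enc(tokens)
assert out.shape == tokens.shape      # (2, 17, 64): shape-preserving
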
projection + self.projection = nn.Linear(self.token_dim, embedding_dim) + # 2. position embedding + self.embedding = nn.Parameter(torch.rand(self.num_tokens + 1, embedding_dim)) + # 3. cls token + self.cls_token = nn.Parameter(torch.randn(1, 1, embedding_dim)) + + def forward(self, x): + img_patches = rearrange(x, + 'b c (patch_x x) (patch_y y) -> b (x y) (patch_x patch_y c)', + patch_x=self.patch_size, patch_y=self.patch_size) + + batch_size, tokens_num, _ = img_patches.shape + + patch_token = self.projection(img_patches) + cls_token = repeat(self.cls_token, 'b ... -> (b batch_size) ...', + batch_size=batch_size) + + patches = torch.cat([cls_token, patch_token], dim=1) + # add postion embedding + patches += self.embedding[:tokens_num + 1, :] + + # B,tokens_num+1,embedding_dim + return patches + + +class TransformerBottleNeck(nn.Module): + def __init__(self, img_dim, in_channels, embedding_dim, head_num, + block_num, patch_size=1, classification=False, dropout=0.1, num_classes=1): + super().__init__() + self.patch_embedding = PathEmbedding(img_dim, in_channels, embedding_dim, patch_size) + self.transformer = TransformerEncoder(embedding_dim, head_num, block_num) + self.dropout = nn.Dropout(dropout) + self.classification = classification + if self.classification: + self.mlp_head = nn.Linear(embedding_dim, num_classes) + + def forward(self, x): + x = self.patch_embedding(x) + x = self.dropout(x) + x = self.transformer(x) + x = self.mlp_head(x[:, 0, :]) if self.classification else x[:, 1:, :] + return x + + +class PGFusion(nn.Module): + + def __init__(self, in_channel=384, out_channel=384): + + super(PGFusion, self).__init__() + + self.in_channel = in_channel + self.out_channel = out_channel + + self.patch_query = nn.Conv2d(in_channel, in_channel, kernel_size=1) + self.patch_key = nn.Conv2d(in_channel, in_channel, kernel_size=1) + self.patch_value = nn.Conv2d(in_channel, in_channel, kernel_size=1, bias=False) + self.patch_global_query = nn.Conv2d(in_channel, in_channel, kernel_size=1) + + self.global_key = nn.Conv2d(in_channel, in_channel, kernel_size=1) + self.global_value = nn.Conv2d(in_channel, in_channel, kernel_size=1, bias=False) + + self.fusion = nn.Conv2d(in_channel * 2, in_channel * 2, kernel_size=1) + + self.out_patch = nn.Conv2d(in_channel, out_channel, kernel_size=1) + self.out_global = nn.Conv2d(in_channel, out_channel, kernel_size=1) + + self.softmax = nn.Softmax(dim=2) + self.softmax_concat = nn.Softmax(dim=0) + + # self.gamma_patch_self = nn.Parameter(torch.zeros(1)) + # self.gamma_patch_global = nn.Parameter(torch.zeros(1)) + + self.init_parameters() + + def init_parameters(self): + for m in self.modules(): + if isinstance(m, nn.Conv3d) or isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv1d): + nn.init.normal_(m.weight, 0, 0.01) + # nn.init.xavier_uniform_(m.weight.data) + if m.bias is not None: + nn.init.zeros_(m.bias) + # nn.init.constant_(m.bias, 0) + m.inited = True + + def forward(self, patch_rep, global_rep): + patch_rep_ = patch_rep.clone() + patch_value = self.patch_value(patch_rep) + patch_value = patch_value.view(patch_value.size(0), patch_value.size(1), -1) + patch_key = self.patch_key(patch_rep) + patch_key = patch_key.view(patch_key.size(0), patch_key.size(1), -1) + dim_k = patch_key.shape[-1] + patch_query = self.patch_query(patch_rep) + patch_query = patch_query.view(patch_query.size(0), patch_query.size(1), -1) + + patch_global_query = self.patch_global_query(patch_rep) + patch_global_query = patch_global_query.view(patch_global_query.size(0), 
patch_global_query.size(1), -1) + + global_value = self.global_value(global_rep) + global_value = global_value.view(global_value.size(0), global_value.size(1), -1) + global_key = self.global_key(global_rep) + global_key = global_key.view(global_key.size(0), global_key.size(1), -1) + + ### patch self attention + patch_self_sim_map = patch_query @ patch_key.transpose(-2, -1) / math.sqrt(dim_k) + patch_self_sim_map = self.softmax(patch_self_sim_map) + patch_self_sim_map = patch_self_sim_map @ patch_value + patch_self_sim_map = patch_self_sim_map.view(patch_self_sim_map.size(0), patch_self_sim_map.size(1), + *patch_rep.size()[2:]) + # patch_self_sim_map = self.gamma_patch_self * patch_self_sim_map + patch_self_sim_map = 1 * patch_self_sim_map + ### patch global attention + patch_global_sim_map = patch_global_query @ global_key.transpose(-2, -1) / math.sqrt(dim_k) + patch_global_sim_map = self.softmax(patch_global_sim_map) + patch_global_sim_map = patch_global_sim_map @ global_value + patch_global_sim_map = patch_global_sim_map.view(patch_global_sim_map.size(0), patch_global_sim_map.size(1), + *patch_rep.size()[2:]) + # patch_global_sim_map = self.gamma_patch_global * patch_global_sim_map + patch_global_sim_map = 1 * patch_global_sim_map + + fusion_sim_weight_map = torch.cat((patch_self_sim_map, patch_global_sim_map), dim=1) + fusion_sim_weight_map = self.fusion(fusion_sim_weight_map) + fusion_sim_weight_map = 1 * fusion_sim_weight_map + + patch_self_sim_weight_map = torch.split(fusion_sim_weight_map, dim=1, split_size_or_sections=self.in_channel)[0] + patch_self_sim_weight_map = torch.sigmoid(patch_self_sim_weight_map) # 0-1 + + patch_global_sim_weight_map = torch.split(fusion_sim_weight_map, dim=1, split_size_or_sections=self.in_channel)[ + 1] + patch_global_sim_weight_map = torch.sigmoid(patch_global_sim_weight_map) # 0-1 + + patch_self_sim_weight_map = torch.unsqueeze(patch_self_sim_weight_map, 0) + patch_global_sim_weight_map = torch.unsqueeze(patch_global_sim_weight_map, 0) + + ct = torch.concat((patch_self_sim_weight_map, patch_global_sim_weight_map), 0) + ct = self.softmax_concat(ct) + + out = patch_rep_ + patch_self_sim_map * ct[0] + patch_global_sim_map * (1 - ct[0]) + + return out + + +if __name__ == '__main__': + x = torch.randn((2, 384, 16, 16)) + m = PGFusion() + print(m) + # y = TransformerBottleNeck(x.shape[2],x.shape[1],x.shape[1],8,4) + print(m(x, x).shape) diff --git a/AV/models/network.py b/AV/models/network.py new file mode 100644 index 0000000000000000000000000000000000000000..33dc17dfe63f3e9e10aad683ae098e545465bcb5 --- /dev/null +++ b/AV/models/network.py @@ -0,0 +1,203 @@ +# -*- coding: utf-8 -*- +import torchvision.models +from torch import nn +import torch + +import torch.nn.functional as F +from AV.models.layers import * +from torchvision.models.convnext import convnext_tiny, ConvNeXt_Tiny_Weights +import numpy as np +import math +from torchvision import models +import copy + +class PGNet(nn.Module): + def __init__(self, input_ch=3, resnet='convnext_tiny', num_classes=3, use_cuda=False, pretrained=True,centerness=False, centerness_map_size=[128,128],use_global_semantic=False): + super(PGNet, self).__init__() + self.resnet = resnet + base_model = convnext_tiny + # layers = list(base_model(pretrained=pretrained,num_classes=num_classes,input_ch=input_ch).children())[:cut] + self.use_high_semantic = False + + cut = 6 + if pretrained: + layers = list(base_model(weights=ConvNeXt_Tiny_Weights.IMAGENET1K_V1).features)[:cut] + else: + layers = 
list(base_model().features)[:cut] + + base_layers = nn.Sequential(*layers) + self.use_global_semantic = use_global_semantic + ### global momentum + if self.use_global_semantic: + + self.pg_fusion = PGFusion() + self.base_layers_global_momentum = copy.deepcopy(base_layers) + set_requires_grad(self.base_layers_global_momentum,requires_grad=False) + + # self.stage = [SaveFeatures(base_layers[0][1])] # stage 1 c=96 + + self.stage = [] + self.stage.append(SaveFeatures(base_layers[0][1])) # stem c=96 + self.stage.append(SaveFeatures(base_layers[1][2])) # stage 1 c=96 + self.stage.append(SaveFeatures(base_layers[3][2])) # stage 2 c=192 + self.stage.append(SaveFeatures(base_layers[5][8])) # stage 3 c=384 + # self.stage.append(SaveFeatures(base_layers[7][2])) # stage 5 c=768 + + self.up2 = DBlock(384, 192) + self.up3 = DBlock(192, 96) + self.up4 = DBlock(96, 96) + + # final convolutional layers + # predict artery, vein and vessel + + self.seg_head = SegmentationHead(96, num_classes, 3, upsample=4) + + self.sn_unet = base_layers + self.num_classes = num_classes + + self.bn_out = nn.BatchNorm2d(3) + #self.av_cross = AV_Cross(block=4,kernel_size=1) + # use centerness block + self.centerness = centerness + + if self.centerness and centerness_map_size[0] == 128: + + # block 1 + self.cenBlock1 = [ + nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), + ] + self.cenBlock1 = nn.Sequential(*self.cenBlock1) + + # centerness block + self.cenBlockMid = [ + nn.Conv2d(96, 48, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(48), + # nn.Conv2d(48, 48, kernel_size=3, padding=3, bias=False), + # nn.BatchNorm2d(48), + nn.Conv2d(48, 96, kernel_size=1, padding=0, bias=False), + ] + self.cenBlockMid = nn.Sequential(*self.cenBlockMid) + self.cenBlockFinal = [ + nn.BatchNorm2d(96), + nn.ReLU(inplace=True), + nn.Conv2d(96, 3, kernel_size=1, padding=0, bias=True), + nn.Sigmoid() + + ] + self.cenBlockFinal = nn.Sequential(*self.cenBlockFinal) + def forward(self, x,y=None): + + x = self.sn_unet(x) + global_rep = None + if self.use_global_semantic: + global_rep = self.base_layers_global_momentum(y) + x = self.pg_fusion(x,global_rep) + if len(x.shape) == 4 and x.shape[2] != x.shape[3]: + B, H, W, C = x.shape + x = x.permute(0, 3, 1, 2).contiguous() + elif len(x.shape) == 3: + B, L, C = x.shape + h = int(L ** 0.5) + x = x.view(B, h, h, C) + x = x.permute(0, 3, 1, 2).contiguous() + else: + x = x + + if self.use_high_semantic: + high_out = x.clone() + else: + high_out = x.clone() + if self.resnet == 'swin_t' or self.resnet == 'convnext_tiny': + # feature = self.stage[1:] + feature = self.stage[::-1] + # head = feature[0] + skip = feature[1:] + + # x = self.up1(x,skip[0].features) + x = self.up2(x, skip[0].features) + x = self.up3(x, skip[1].features) + x = self.up4(x, skip[2].features) + x_out = self.seg_head(x) + ######################## + # baseline output + # artery, vein and vessel + output = x_out.clone() + #av cross + #output = self.av_cross(output) + #output = F.relu(self.bn_out(output)) + # use centerness block + centerness_maps = None + if self.centerness: + + block1 = self.cenBlock1(self.stage[1].features) # [96,64] + _block1 = self.cenBlockMid(block1) # [96,64] + block1 = block1 + _block1 + blocks = [block1] + blocks = torch.cat(blocks, dim=1) + + # print("blocks", blocks.shape) + centerness_maps = self.cenBlockFinal(blocks) + # print("maps:", centerness_maps.shape) + + return output, centerness_maps + + + def forward_patch_rep(self, x): + patch_rep = self.sn_unet(x) + return patch_rep + + def 
forward_global_rep_momentum(self, x): + global_rep = self.base_layers_global_momentum(x) + return global_rep + def close(self): + for sf in self.stage: sf.remove() + + +# set requires_grad=False to avoid computation + +def set_requires_grad(nets, requires_grad=False): + if not isinstance(nets, list): + nets = [nets] + for net in nets: + if net is not None: + for param in net.parameters(): + param.requires_grad = requires_grad + +pretrained_mean = torch.tensor([0.485, 0.456, 0.406], requires_grad=False).view((1, 3, 1, 1)) +pretrained_std = torch.tensor([0.229, 0.224, 0.225], requires_grad=False).view((1, 3, 1, 1)) + +if __name__ == '__main__': + s = PGNet(input_ch=3, resnet='convnext_tiny', centerness=True, pretrained=False, use_global_semantic=False) + + x = torch.randn(2, 3, 256, 256) + y, y2 = s(x) + + print(y.shape) + print(y2.shape) diff --git a/AV/models/networks_gan.py b/AV/models/networks_gan.py new file mode 100644 index 0000000000000000000000000000000000000000..eac2e45f31894494e9ee0dd0de3587becc4ca41a --- /dev/null +++ b/AV/models/networks_gan.py @@ -0,0 +1,1167 @@ +import torch +import torch.nn as nn +from torch.nn import init +import functools +from torch.optim import lr_scheduler + +############################################################################### +# Helper Functions +############################################################################### + +def get_norm_layer(norm_type='instance'): + if norm_type == 'batch': + norm_layer = functools.partial(nn.BatchNorm2d, affine=True) + elif norm_type == 'instance': + norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False) + elif norm_type == 'none': + norm_layer = None + else: + raise NotImplementedError('normalization layer [%s] is not found' % norm_type) + return norm_layer + +def get_scheduler(optimizer, opt): + if opt.lr_policy == 'lambda': + def lambda_rule(epoch): + lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1) + return lr_l + scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) + elif opt.lr_policy == 'step': + scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1) + elif opt.lr_policy == 'plateau': + scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) + elif opt.lr_policy == 'cosine': + scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0) + else: + raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy) + return scheduler + +def init_weights(net, init_type='normal', gain=0.02): + def init_func(m): + classname = m.__class__.__name__ + if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): + if init_type == 'normal': + init.normal_(m.weight.data, 0.0, gain) + elif init_type == 'xavier': + init.xavier_normal_(m.weight.data, gain=gain) + elif init_type == 'kaiming': + init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') + elif init_type ==
'orthogonal': + init.orthogonal_(m.weight.data, gain=gain) + else: + raise NotImplementedError('initialization method [%s] is not implemented' % init_type) + if hasattr(m, 'bias') and m.bias is not None: + init.constant_(m.bias.data, 0.0) + elif classname.find('BatchNorm2d') != -1: + init.normal_(m.weight.data, 1.0, gain) + init.constant_(m.bias.data, 0.0) + + print('initialize network with %s' % init_type) + net.apply(init_func) + + +def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]): + if len(gpu_ids) > 0: + assert(torch.cuda.is_available()) + net.to(gpu_ids[0]) + net = torch.nn.DataParallel(net, gpu_ids) + init_weights(net, init_type, gain=init_gain) + return net + + +def calc_mean_std(feat, eps=1e-5): + # eps is a small value added to the variance to avoid divide-by-zero. + size = feat.size() + assert (len(size) == 4) + N, C = size[:2] + feat_var = feat.view(N, C, -1).var(dim=2) + eps + feat_std = feat_var.sqrt().view(N, C, 1, 1) + feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1) + return feat_mean, feat_std + + +# FiLM-style conditioning: instance-normalize X per channel, then apply the +# spatial scale (alpha) and shift (beta) predicted from the other stream. +def affine_transformation(X, alpha, beta): + x = X.clone() + mean, std = calc_mean_std(x) + mean = mean.expand_as(x) + std = std.expand_as(x) + return alpha * ((x-mean)/std) + beta + + +############################################################################### +# Defining G/D +############################################################################### + +def define_G(input_nc, guide_nc, output_nc, ngf, netG, n_layers=8, n_downsampling=3, n_blocks=9, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[]): + net = None + norm_layer = get_norm_layer(norm_type=norm) + + if netG == 'bFT_resnet': + net = bFT_Resnet(input_nc, guide_nc, output_nc, ngf, norm_layer=norm_layer, n_blocks=n_blocks) + elif netG == 'bFT_unet': + net = bFT_Unet(input_nc, guide_nc, output_nc, n_layers, ngf, norm_layer=norm_layer) + elif netG == 'bFT_unet_cat': + net = bFT_Unet_cat(input_nc, guide_nc, output_nc, n_layers, ngf, norm_layer=norm_layer) + elif netG == 'uFT_unet': + net = uFT_Unet(input_nc, guide_nc, output_nc, n_layers, ngf, norm_layer=norm_layer) + elif netG == 'concat_Unet': + net = concat_Unet(input_nc, guide_nc, output_nc, n_layers, ngf, norm_layer=norm_layer) + else: + raise NotImplementedError('Generator model name [%s] is not recognized' % netG) + net = init_net(net, init_type, init_gain, gpu_ids) + + return net + +def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', use_sigmoid=False, init_type='normal', init_gain=0.02, gpu_ids=[], num_classes_D=1, use_noise=False, use_dropout=False): + net = None + norm_layer = get_norm_layer(norm_type=norm) + + if netD == 'basic': + net = NLayerDiscriminator(input_nc, ndf, n_layers=n_layers_D, norm_layer=norm_layer, use_sigmoid=use_sigmoid, num_classes_D=num_classes_D, use_noise=use_noise, use_dropout=use_dropout) + elif netD == 'n_layers': + net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, use_sigmoid=use_sigmoid) + elif netD == 'pixel': + net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer, use_sigmoid=use_sigmoid) + else: + raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD) + return init_net(net, init_type, init_gain, gpu_ids) + + +############################################################################## +# Classes +############################################################################## + +class GANLoss(nn.Module): + def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0): + super(GANLoss,
self).__init__() + self.register_buffer('real_label', torch.tensor(target_real_label)) + self.register_buffer('fake_label', torch.tensor(target_fake_label)) + if use_lsgan: + self.loss = nn.MSELoss() + else: + self.loss = nn.BCELoss() + + def get_target_tensor(self, input, target_is_real): + if target_is_real: + target_tensor = self.real_label + else: + target_tensor = self.fake_label + return target_tensor.expand_as(input) + + def __call__(self, input, target_is_real): + target_tensor = self.get_target_tensor(input, target_is_real) + return self.loss(input, target_tensor) + + +# Define a resnet block +class ResnetBlock(nn.Module): + def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True)): + super(ResnetBlock, self).__init__() + self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation) + def build_conv_block(self, dim, padding_type, norm_layer, activation): + conv_block = [] + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), + norm_layer(dim), + activation] + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), + norm_layer(dim)] + return nn.Sequential(*conv_block) + def forward(self, x): + out = x + self.conv_block(x) + return out + + +############################################################################## +# Discriminators +############################################################################## + +# Defines the PatchGAN discriminator with the specified arguments. 
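+# Editor's sketch (illustrative, not part of the original patch): the
+# discriminator below returns flattened per-patch logits by default, or the
+# flattened intermediate activations when a list of module indices is passed
+# via `layers` (useful for feature-matching losses). Assuming the constructor
+# defaults shown in this file:
+#
+#   disc = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3, num_classes_D=1)
+#   x = torch.randn(1, 3, 256, 256)
+#   logits = disc(x)                # (1, 1, 1024): one logit per cell of a 32x32 patch map
+#   feats = disc(x, layers=[1, 3])  # activations after modules 1 and 3, flattened per sample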
+class NLayerDiscriminator(nn.Module): + def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, num_classes_D=1, use_noise=False, use_dropout=False): + super(NLayerDiscriminator, self).__init__() + self.use_noise = use_noise + if type(norm_layer) == functools.partial: + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + kw = 4 + padw = 1 + sequence = [ + nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), + nn.LeakyReLU(0.2, True) + ] + if use_dropout: + sequence.append(nn.Dropout(p=0.2)) + + nf_mult = 1 + nf_mult_prev = 1 + for n in range(1, n_layers): + nf_mult_prev = nf_mult + nf_mult = min(2**n, 16) + sequence += [ + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, + kernel_size=kw, stride=2, padding=padw, bias=use_bias), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True) + ] + if use_dropout: + sequence.append(nn.Dropout(p=0.2)) + # nf_mult_prev = nf_mult + # nf_mult = min(2**n_layers, 8) + # sequence += [ + # nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, + # kernel_size=kw, stride=1, padding=padw, bias=use_bias), + # norm_layer(ndf * nf_mult), + # nn.LeakyReLU(0.2, True) + # ] + sequence += [nn.Conv2d(ndf * nf_mult, num_classes_D, kernel_size=3, stride=1, padding=padw)] + if use_sigmoid: + sequence += [nn.Sigmoid()] + self.model = nn.ModuleList(list(nn.Sequential(*sequence))) + def forward(self, input, layers=None): + input = input + torch.randn_like(input) if self.use_noise else input + #output = self.model(input) + output = input + results = [] + for ii, model in enumerate(self.model): + output = model(output) + if layers and ii in layers: + results.append(output.view(output.shape[0], -1)) + if layers == None: + return output.reshape([output.shape[0], output.shape[1], -1]) + return results + + +class PixelDiscriminator(nn.Module): + def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d, use_sigmoid=False): + super(PixelDiscriminator, self).__init__() + if type(norm_layer) == functools.partial: + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + self.net = [ + nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0), + nn.LeakyReLU(0.2, True), + nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias), + norm_layer(ndf * 2), + nn.LeakyReLU(0.2, True), + nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)] + if use_sigmoid: + self.net.append(nn.Sigmoid()) + self.net = nn.Sequential(*self.net) + def forward(self, input): + return self.net(input) + +############################################################################## +# Generators +############################################################################## + +class bFT_Unet(nn.Module): + def __init__(self, input_nc, guide_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, bottleneck_depth=100): + super(bFT_Unet, self).__init__() + + self.num_downs = num_downs + + if type(norm_layer) == functools.partial: + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + + self.downconv1 = nn.Sequential(*[nn.Conv2d(input_nc, ngf, kernel_size=4, stride=2, padding=1, bias=use_bias)]) + self.downconv2 = nn.Sequential(*[nn.LeakyReLU(0.2, True), nn.Conv2d(ngf, ngf * 2, kernel_size=4, stride=2, padding=1, bias=use_bias)]) + self.downconv3 = nn.Sequential(*[nn.LeakyReLU(0.2, True), nn.Conv2d(ngf * 2, ngf * 4, kernel_size=4, stride=2, padding=1, bias=use_bias)]) + self.downconv4 
= nn.Sequential(*[nn.LeakyReLU(0.2, True), nn.Conv2d(ngf * 4, ngf * 8, kernel_size=4, stride=2, padding=1, bias=use_bias)]) + + downconv = [] ## this has #(num_downs - 5) layers each with [relu-downconv-norm] + for i in range(num_downs - 5): + downconv += [nn.LeakyReLU(0.2, True), nn.Conv2d(ngf * 8, ngf * 8, kernel_size=4, stride=2, padding=1, bias=use_bias)] + self.downconv = nn.Sequential(*downconv) + self.downconv5 = nn.Sequential(*[nn.LeakyReLU(0.2, True), nn.Conv2d(ngf * 8, ngf * 8, kernel_size=4, stride=2, padding=1, bias=use_bias)]) + + ### bottleneck ------ + + self.upconv1 = nn.Sequential(*[nn.ReLU(True), nn.ConvTranspose2d(ngf * 8, ngf * 8, kernel_size=4, stride=2, padding=1, bias=use_bias), norm_layer(ngf * 8)]) + upconv = [] ## this has #(num_downs - 5) layers each with [relu-upconv-norm] + for i in range(num_downs - 5): + upconv += [nn.ReLU(True), nn.ConvTranspose2d(ngf * 8 * 2, ngf * 8, kernel_size=4, stride=2, padding=1, bias=use_bias), norm_layer(ngf * 8)] + self.upconv = nn.Sequential(*upconv) + self.upconv2 = nn.Sequential(*[nn.ReLU(True), nn.ConvTranspose2d(ngf * 8 * 2, ngf * 4, kernel_size=4, stride=2, padding=1, bias=use_bias), norm_layer(ngf * 4)]) + self.upconv3 = nn.Sequential(*[nn.ReLU(True), nn.ConvTranspose2d(ngf * 4 * 2, ngf * 2, kernel_size=4, stride=2, padding=1, bias=use_bias), norm_layer(ngf * 2)]) + self.upconv4 = nn.Sequential(*[nn.ReLU(True), nn.ConvTranspose2d(ngf * 2 * 2, ngf, kernel_size=4, stride=2, padding=1, bias=use_bias), norm_layer(ngf)]) + self.upconv5 = nn.Sequential(*[nn.ReLU(True), nn.ConvTranspose2d(ngf * 2, output_nc, kernel_size=4, stride=2, padding=1)]) + #self.upconv5 = nn.Sequential(*[nn.ReLU(True), nn.ConvTranspose2d(ngf * 2, output_nc, kernel_size=4, stride=2, padding=1), nn.Tanh()]) + + + ### guide downsampling + self.G_downconv1 = nn.Sequential(*[nn.Conv2d(guide_nc, ngf, kernel_size=4, stride=2, padding=1, bias=use_bias)]) + self.G_downconv2 = nn.Sequential(*[nn.LeakyReLU(0.2, True), nn.Conv2d(ngf, ngf * 2, kernel_size=4, stride=2, padding=1, bias=use_bias)]) + self.G_downconv3 = nn.Sequential(*[nn.LeakyReLU(0.2, True), nn.Conv2d(ngf * 2, ngf * 4, kernel_size=4, stride=2, padding=1, bias=use_bias)]) + self.G_downconv4 = nn.Sequential(*[nn.LeakyReLU(0.2, True), nn.Conv2d(ngf * 4, ngf * 8, kernel_size=4, stride=2, padding=1, bias=use_bias)]) + G_downconv = [] ## this has #(num_downs - 5) layers each with [relu-downconv-norm] + for i in range(num_downs - 5): + G_downconv += [nn.LeakyReLU(0.2, True), nn.Conv2d(ngf * 8, ngf * 8, kernel_size=4, stride=2, padding=1, bias=use_bias)] + self.G_downconv = nn.Sequential(*G_downconv) + + ### bottlenecks for param generation + self.bottleneck_alpha_2 = nn.Sequential(*self.bottleneck_layer(ngf * 2, bottleneck_depth)) + self.bottleneck_beta_2 = nn.Sequential(*self.bottleneck_layer(ngf * 2, bottleneck_depth)) + self.bottleneck_alpha_3 = nn.Sequential(*self.bottleneck_layer(ngf * 4, bottleneck_depth)) + self.bottleneck_beta_3 = nn.Sequential(*self.bottleneck_layer(ngf * 4, bottleneck_depth)) + self.bottleneck_alpha_4 = nn.Sequential(*self.bottleneck_layer(ngf * 8, bottleneck_depth)) + self.bottleneck_beta_4 = nn.Sequential(*self.bottleneck_layer(ngf * 8, bottleneck_depth)) + bottleneck_alpha = [] + bottleneck_beta = [] + for i in range(num_downs - 5): + bottleneck_alpha += self.bottleneck_layer(ngf * 8, bottleneck_depth) + bottleneck_beta += self.bottleneck_layer(ngf * 8, bottleneck_depth) + self.bottleneck_alpha = nn.Sequential(*bottleneck_alpha) + self.bottleneck_beta = 
nn.Sequential(*bottleneck_beta) + ### for guide + self.G_bottleneck_alpha_2 = nn.Sequential(*self.bottleneck_layer(ngf * 2, bottleneck_depth)) + self.G_bottleneck_beta_2 = nn.Sequential(*self.bottleneck_layer(ngf * 2, bottleneck_depth)) + self.G_bottleneck_alpha_3 = nn.Sequential(*self.bottleneck_layer(ngf * 4, bottleneck_depth)) + self.G_bottleneck_beta_3 = nn.Sequential(*self.bottleneck_layer(ngf * 4, bottleneck_depth)) + self.G_bottleneck_alpha_4 = nn.Sequential(*self.bottleneck_layer(ngf * 8, bottleneck_depth)) + self.G_bottleneck_beta_4 = nn.Sequential(*self.bottleneck_layer(ngf * 8, bottleneck_depth)) + G_bottleneck_alpha = [] + G_bottleneck_beta = [] + for i in range(num_downs - 5): + G_bottleneck_alpha += self.bottleneck_layer(ngf * 8, bottleneck_depth) + G_bottleneck_beta += self.bottleneck_layer(ngf * 8, bottleneck_depth) + self.G_bottleneck_alpha = nn.Sequential(*G_bottleneck_alpha) + self.G_bottleneck_beta = nn.Sequential(*G_bottleneck_beta) + + def bottleneck_layer(self, nc, bottleneck_depth): + return [nn.Conv2d(nc, bottleneck_depth, kernel_size=1), nn.ReLU(True), nn.Conv2d(bottleneck_depth, nc, kernel_size=1)] + + # per pixel + def get_FiLM_param_(self, X, i, guide=False): + x = X.clone() + # bottleneck + if guide: + if (i=='2'): + alpha_layer = self.G_bottleneck_alpha_2 + beta_layer = self.G_bottleneck_beta_2 + elif (i=='3'): + alpha_layer = self.G_bottleneck_alpha_3 + beta_layer = self.G_bottleneck_beta_3 + elif (i=='4'): + alpha_layer = self.G_bottleneck_alpha_4 + beta_layer = self.G_bottleneck_beta_4 + else: # a number i will be given to specify which bottleneck to use + alpha_layer = self.G_bottleneck_alpha[i:i+3] + beta_layer = self.G_bottleneck_beta[i:i+3] + else: + if (i=='2'): + alpha_layer = self.bottleneck_alpha_2 + beta_layer = self.bottleneck_beta_2 + elif (i=='3'): + alpha_layer = self.bottleneck_alpha_3 + + beta_layer = self.bottleneck_beta_3 + elif (i=='4'): + alpha_layer = self.bottleneck_alpha_4 + beta_layer = self.bottleneck_beta_4 + else: # a number i will be given to specify which bottleneck to use + alpha_layer = self.bottleneck_alpha[i:i+3] + beta_layer = self.bottleneck_beta[i:i+3] + + alpha = alpha_layer(x) + beta = beta_layer(x) + return alpha, beta + + def forward (self, input, guide): + ## downconv + down1 = self.downconv1(input) + G_down1 = self.G_downconv1(guide) + + down2 = self.downconv2(down1) + G_down2 = self.G_downconv2(G_down1) + + g_alpha2, g_beta2 = self.get_FiLM_param_(G_down2, '2', guide=True) + i_alpha2, i_beta2 = self.get_FiLM_param_(down2, '2') + down2 = affine_transformation(down2, g_alpha2, g_beta2) + G_down2 = affine_transformation(G_down2, i_alpha2, i_beta2) + + + down3 = self.downconv3(down2) + G_down3 = self.G_downconv3(G_down2) + + g_alpha3, g_beta3 = self.get_FiLM_param_(G_down3, '3', guide=True) + i_alpha3, i_beta3 = self.get_FiLM_param_(down3, '3') + down3 = affine_transformation(down3, g_alpha3, g_beta3) + G_down3 = affine_transformation(G_down3, i_alpha3, i_beta3) + + down4 = self.downconv4(down3) + G_down4 = self.G_downconv4(G_down3) + + g_alpha4, g_beta4 = self.get_FiLM_param_(G_down4, '4', guide=True) + i_alpha4, i_beta4 = self.get_FiLM_param_(down4, '4') + down4 = affine_transformation(down4, g_alpha4, g_beta4) + G_down4 = affine_transformation(G_down4, i_alpha4, i_beta4) + + ## (num_downs - 5) layers + down = [] + G_down = [] + for i in range(self.num_downs - 5): + layer = 2 * i + bottleneck_layer = 3 * i + downconv = self.downconv[layer:layer+2] + G_downconv = self.G_downconv[layer:layer+2] + if (layer == 0): + 
down += [downconv(down4)] + G_down += [G_downconv(G_down4)] + else: + down += [downconv(down[i-1])] + G_down += [G_downconv(G_down[i-1])] + + g_alpha, g_beta = self.get_FiLM_param_(G_down[i], bottleneck_layer, guide=True) + i_alpha, i_beta = self.get_FiLM_param_(down[i], bottleneck_layer) + down[i] = affine_transformation(down[i], g_alpha, g_beta) + G_down[i] = affine_transformation(G_down[i], i_alpha, i_beta) + + down5 = self.downconv5(down[-1]) + + ## concat and upconv + up = self.upconv1(down5) + num_down = self.num_downs - 5 + for i in range(self.num_downs - 5): + layer = 3 * i + upconv = self.upconv[layer:layer+3] + num_down -= 1 + up = upconv(torch.cat([down[num_down], up], 1)) + up = self.upconv2(torch.cat([down4,up],1)) + up = self.upconv3(torch.cat([down3,up],1)) + up = self.upconv4(torch.cat([down2,up],1)) + up = self.upconv5(torch.cat([down1,up],1)) + return up + +class bFT_Resnet(nn.Module): + def __init__(self, input_nc, guide_nc, output_nc, ngf=64, n_blocks=9, norm_layer=nn.BatchNorm2d, + padding_type='reflect', bottleneck_depth=100): + super(bFT_Resnet, self).__init__() + + self.activation = nn.ReLU(True) + + n_downsampling=3 + + ## input + padding_in = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0)] + self.padding_in = nn.Sequential(*padding_in) + self.conv1 = nn.Conv2d(ngf, ngf * 2, kernel_size=3, stride=2, padding=1) + self.conv2 = nn.Conv2d(ngf * 2, ngf * 4, kernel_size=3, stride=2, padding=1) + self.conv3 = nn.Conv2d(ngf * 4, ngf * 8, kernel_size=3, stride=2, padding=1) + + ## guide + padding_g = [nn.ReflectionPad2d(3), nn.Conv2d(guide_nc, ngf, kernel_size=7, padding=0)] + self.padding_g = nn.Sequential(*padding_g) + self.conv1_g = nn.Conv2d(ngf, ngf * 2, kernel_size=3, stride=2, padding=1) + self.conv2_g = nn.Conv2d(ngf * 2, ngf * 4, kernel_size=3, stride=2, padding=1) + self.conv3_g = nn.Conv2d(ngf * 4, ngf * 8, kernel_size=3, stride=2, padding=1) + + # bottleneck1 + self.bottleneck_alpha_1 = self.bottleneck_layer(ngf, bottleneck_depth) + self.G_bottleneck_alpha_1 = self.bottleneck_layer(ngf, bottleneck_depth) + self.bottleneck_beta_1 = self.bottleneck_layer(ngf, bottleneck_depth) + self.G_bottleneck_beta_1 = self.bottleneck_layer(ngf, bottleneck_depth) + # bottleneck2 + self.bottleneck_alpha_2 = self.bottleneck_layer(ngf*2, bottleneck_depth) + self.G_bottleneck_alpha_2 = self.bottleneck_layer(ngf*2, bottleneck_depth) + self.bottleneck_beta_2 = self.bottleneck_layer(ngf*2, bottleneck_depth) + self.G_bottleneck_beta_2 = self.bottleneck_layer(ngf*2, bottleneck_depth) + # bottleneck3 + self.bottleneck_alpha_3 = self.bottleneck_layer(ngf*4, bottleneck_depth) + self.G_bottleneck_alpha_3 = self.bottleneck_layer(ngf*4, bottleneck_depth) + self.bottleneck_beta_3 = self.bottleneck_layer(ngf*4, bottleneck_depth) + self.G_bottleneck_beta_3 = self.bottleneck_layer(ngf*4, bottleneck_depth) + # bottleneck4 + self.bottleneck_alpha_4 = self.bottleneck_layer(ngf*8, bottleneck_depth) + self.G_bottleneck_alpha_4 = self.bottleneck_layer(ngf*8, bottleneck_depth) + self.bottleneck_beta_4 = self.bottleneck_layer(ngf*8, bottleneck_depth) + self.G_bottleneck_beta_4 = self.bottleneck_layer(ngf*8, bottleneck_depth) + + resnet = [] + mult = 2**n_downsampling + for i in range(n_blocks): + resnet += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=self.activation, norm_layer=norm_layer)] + self.resnet = nn.Sequential(*resnet) + decoder = [] + for i in range(n_downsampling): + mult = 2**(n_downsampling - i) + decoder += [nn.ConvTranspose2d(ngf * mult, 
int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1), + norm_layer(int(ngf * mult / 2)), self.activation] + self.pre_decoder = nn.Sequential(*decoder) + self.decoder = nn.Sequential(*[nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]) + + def bottleneck_layer(self, nc, bottleneck_depth): + return nn.Sequential(*[nn.Conv2d(nc, bottleneck_depth, kernel_size=1), self.activation, nn.Conv2d(bottleneck_depth, nc, kernel_size=1)]) + + def get_FiLM_param_(self, X, i, guide=False): + x = X.clone() + # bottleneck + if guide: + if (i==1): + alpha_layer = self.G_bottleneck_alpha_1 + beta_layer = self.G_bottleneck_beta_1 + elif (i==2): + alpha_layer = self.G_bottleneck_alpha_2 + beta_layer = self.G_bottleneck_beta_2 + elif (i==3): + alpha_layer = self.G_bottleneck_alpha_3 + beta_layer = self.G_bottleneck_beta_3 + elif (i==4): + alpha_layer = self.G_bottleneck_alpha_4 + beta_layer = self.G_bottleneck_beta_4 + else: + if (i==1): + alpha_layer = self.bottleneck_alpha_1 + beta_layer = self.bottleneck_beta_1 + elif (i==2): + alpha_layer = self.bottleneck_alpha_2 + beta_layer = self.bottleneck_beta_2 + elif (i==3): + alpha_layer = self.bottleneck_alpha_3 + beta_layer = self.bottleneck_beta_3 + elif (i==4): + alpha_layer = self.bottleneck_alpha_4 + beta_layer = self.bottleneck_beta_4 + alpha = alpha_layer(x) + beta = beta_layer(x) + return alpha, beta + + + def forward(self, input, guidance): + input = self.padding_in(input) + guidance = self.padding_g(guidance) + + g_alpha1, g_beta1 = self.get_FiLM_param_(guidance, 1, guide=True) + i_alpha1, i_beta1 = self.get_FiLM_param_(input, 1) + guidance = affine_transformation(guidance, i_alpha1, i_beta1) + input = affine_transformation(input, g_alpha1, g_beta1) + + input = self.activation(input) + guidance = self.activation(guidance) + + input = self.conv1(input) + guidance = self.conv1_g(guidance) + + g_alpha2, g_beta2 = self.get_FiLM_param_(guidance, 2, guide=True) + i_alpha2, i_beta2 = self.get_FiLM_param_(input, 2) + input = affine_transformation(input, g_alpha2, g_beta2) + guidance = affine_transformation(guidance, i_alpha2, i_beta2) + + input = self.activation(input) + guidance = self.activation(guidance) + + input = self.conv2(input) + guidance = self.conv2_g(guidance) + + g_alpha3, g_beta3 = self.get_FiLM_param_(guidance, 3, guide=True) + i_alpha3, i_beta3 = self.get_FiLM_param_(input, 3) + input = affine_transformation(input, g_alpha3, g_beta3) + guidance = affine_transformation(guidance, i_alpha3, i_beta3) + + input = self.activation(input) + guidance = self.activation(guidance) + + input = self.conv3(input) + guidance = self.conv3_g(guidance) + + g_alpha4, g_beta4 = self.get_FiLM_param_(guidance, 4, guide=True) + input = affine_transformation(input, g_alpha4, g_beta4) + + input = self.activation(input) + + input = self.resnet(input) + input = self.pre_decoder(input) + output = self.decoder(input) + return output + +class bFT_Unet_cat(nn.Module): + def __init__(self, input_nc, guide_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, bottleneck_depth=100): + super(bFT_Unet_cat, self).__init__() + + self.num_downs = num_downs + + if type(norm_layer) == functools.partial: + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + + self.downconv1 = nn.Sequential(*[nn.Conv2d(input_nc, ngf, kernel_size=4, stride=2, padding=1, bias=use_bias)]) + self.downconv2 = nn.Sequential(*[nn.LeakyReLU(0.2, True), nn.Conv2d(ngf, ngf * 2, kernel_size=4, stride=2, 
padding=1, bias=use_bias)]) + self.downconv3 = nn.Sequential(*[nn.LeakyReLU(0.2, True), nn.Conv2d(ngf * 2, ngf * 4, kernel_size=4, stride=2, padding=1, bias=use_bias)]) + self.downconv4 = nn.Sequential(*[nn.LeakyReLU(0.2, True), nn.Conv2d(ngf * 4, ngf * 8, kernel_size=4, stride=2, padding=1, bias=use_bias)]) + + downconv = [] ## this has #(num_downs - 5) layers each with [relu-downconv-norm] + for i in range(num_downs - 5): + downconv += [nn.LeakyReLU(0.2, True), nn.Conv2d(ngf * 8, ngf * 8, kernel_size=4, stride=2, padding=1, bias=use_bias)] + self.downconv = nn.Sequential(*downconv) + self.downconv5 = nn.Sequential(*[nn.LeakyReLU(0.2, True), nn.Conv2d(ngf * 8, ngf * 8, kernel_size=4, stride=2, padding=1, bias=use_bias)]) + + ### bottleneck ------ + + self.upconv1 = nn.Sequential(*[nn.ReLU(True), nn.ConvTranspose2d(ngf * 8, ngf * 8, kernel_size=4, stride=2, padding=1, bias=use_bias), norm_layer(ngf * 8)]) + upconv = [] ## this has #(num_downs - 5) layers each with [relu-upconv-norm] + for i in range(num_downs - 5): + upconv += [nn.ReLU(True), nn.ConvTranspose2d(ngf * 8 * 2, ngf * 8, kernel_size=4, stride=2, padding=1, bias=use_bias), norm_layer(ngf * 8)] + self.upconv = nn.Sequential(*upconv) + self.upconv2 = nn.Sequential(*[nn.ReLU(True), nn.ConvTranspose2d(ngf * 8 * 2, ngf * 4, kernel_size=4, stride=2, padding=1, bias=use_bias), norm_layer(ngf * 4)]) + self.upconv3 = nn.Sequential(*[nn.ReLU(True), nn.ConvTranspose2d(ngf * 4 * 2, ngf * 2, kernel_size=4, stride=2, padding=1, bias=use_bias), norm_layer(ngf * 2)]) + self.upconv4 = nn.Sequential(*[nn.ReLU(True), nn.ConvTranspose2d(ngf * 2 * 2, ngf, kernel_size=4, stride=2, padding=1, bias=use_bias), norm_layer(ngf)]) + self.upconv5 = nn.Sequential(*[nn.ReLU(True), nn.ConvTranspose2d(ngf * 2, output_nc, kernel_size=4, stride=2, padding=1), nn.Tanh()]) + + ### guide downsampling + self.G_downconv1 = nn.Sequential(*[nn.Conv2d(guide_nc, ngf, kernel_size=4, stride=2, padding=1, bias=use_bias)]) + self.G_downconv2 = nn.Sequential(*[nn.LeakyReLU(0.2, True), nn.Conv2d(ngf, ngf * 2, kernel_size=4, stride=2, padding=1, bias=use_bias)]) + self.G_downconv3 = nn.Sequential(*[nn.LeakyReLU(0.2, True), nn.Conv2d(ngf * 2, ngf * 4, kernel_size=4, stride=2, padding=1, bias=use_bias)]) + self.G_downconv4 = nn.Sequential(*[nn.LeakyReLU(0.2, True), nn.Conv2d(ngf * 4, ngf * 8, kernel_size=4, stride=2, padding=1, bias=use_bias)]) + G_downconv = [] ## this has #(num_downs - 5) layers each with [relu-downconv-norm] + for i in range(num_downs - 5): + G_downconv += [nn.LeakyReLU(0.2, True), nn.Conv2d(ngf * 8, ngf * 8, kernel_size=4, stride=2, padding=1, bias=use_bias)] + self.G_downconv = nn.Sequential(*G_downconv) + + ### bottlenecks for param generation + self.bottleneck_alpha_2 = nn.Sequential(*self.bottleneck_layer(ngf * 2, bottleneck_depth)) + self.bottleneck_beta_2 = nn.Sequential(*self.bottleneck_layer(ngf * 2, bottleneck_depth)) + self.bottleneck_alpha_3 = nn.Sequential(*self.bottleneck_layer(ngf * 4, bottleneck_depth)) + self.bottleneck_beta_3 = nn.Sequential(*self.bottleneck_layer(ngf * 4, bottleneck_depth)) + self.bottleneck_alpha_4 = nn.Sequential(*self.bottleneck_layer(ngf * 8, bottleneck_depth)) + self.bottleneck_beta_4 = nn.Sequential(*self.bottleneck_layer(ngf * 8, bottleneck_depth)) + bottleneck_alpha = [] + bottleneck_beta = [] + for i in range(num_downs - 5): + bottleneck_alpha += self.bottleneck_layer(ngf * 8, bottleneck_depth) + bottleneck_beta += self.bottleneck_layer(ngf * 8, bottleneck_depth) + self.bottleneck_alpha = 
nn.Sequential(*bottleneck_alpha) + self.bottleneck_beta = nn.Sequential(*bottleneck_beta) + ### for guide + self.G_bottleneck_alpha_2 = nn.Sequential(*self.bottleneck_layer(ngf * 2, bottleneck_depth)) + self.G_bottleneck_beta_2 = nn.Sequential(*self.bottleneck_layer(ngf * 2, bottleneck_depth)) + self.G_bottleneck_alpha_3 = nn.Sequential(*self.bottleneck_layer(ngf * 4, bottleneck_depth)) + self.G_bottleneck_beta_3 = nn.Sequential(*self.bottleneck_layer(ngf * 4, bottleneck_depth)) + self.G_bottleneck_alpha_4 = nn.Sequential(*self.bottleneck_layer(ngf * 8, bottleneck_depth)) + self.G_bottleneck_beta_4 = nn.Sequential(*self.bottleneck_layer(ngf * 8, bottleneck_depth)) + G_bottleneck_alpha = [] + G_bottleneck_beta = [] + for i in range(num_downs - 5): + G_bottleneck_alpha += self.bottleneck_layer(ngf * 8, bottleneck_depth) + G_bottleneck_beta += self.bottleneck_layer(ngf * 8, bottleneck_depth) + self.G_bottleneck_alpha = nn.Sequential(*G_bottleneck_alpha) + self.G_bottleneck_beta = nn.Sequential(*G_bottleneck_beta) + + def bottleneck_layer(self, nc, bottleneck_depth): + return [nn.Conv2d(nc, bottleneck_depth, kernel_size=1), nn.ReLU(True), nn.Conv2d(bottleneck_depth, nc, kernel_size=1)] + + # per pixel + def get_FiLM_param_(self, X, i, guide=False): + x = X.clone() + # bottleneck + if guide: + if (i=='2'): + alpha_layer = self.G_bottleneck_alpha_2 + beta_layer = self.G_bottleneck_beta_2 + elif (i=='3'): + alpha_layer = self.G_bottleneck_alpha_3 + beta_layer = self.G_bottleneck_beta_3 + elif (i=='4'): + alpha_layer = self.G_bottleneck_alpha_4 + beta_layer = self.G_bottleneck_beta_4 + else: # a number i will be given to specify which bottleneck to use + alpha_layer = self.G_bottleneck_alpha[i:i+3] + beta_layer = self.G_bottleneck_beta[i:i+3] + else: + if (i=='2'): + alpha_layer = self.bottleneck_alpha_2 + beta_layer = self.bottleneck_beta_2 + elif (i=='3'): + alpha_layer = self.bottleneck_alpha_3 + beta_layer = self.bottleneck_beta_3 + elif (i=='4'): + alpha_layer = self.bottleneck_alpha_4 + beta_layer = self.bottleneck_beta_4 + else: # a number i will be given to specify which bottleneck to use + alpha_layer = self.bottleneck_alpha[i:i+3] + beta_layer = self.bottleneck_beta[i:i+3] + + alpha = alpha_layer(x) + beta = beta_layer(x) + return alpha, beta + + def forward (self, input, guide): + ## downconv + down1 = self.downconv1(input) + G_down1 = self.G_downconv1(guide) + + down2 = self.downconv2(down1) + G_down2 = self.G_downconv2(G_down1) + + # g_alpha2, g_beta2 = self.get_FiLM_param_(G_down2, '2', guide=True) + # i_alpha2, i_beta2 = self.get_FiLM_param_(down2, '2') + # down2 = affine_transformation(down2, g_alpha2, g_beta2) + # G_down2 = affine_transformation(G_down2, i_alpha2, i_beta2) + + + down3 = self.downconv3(down2) + G_down3 = self.G_downconv3(G_down2) + + # g_alpha3, g_beta3 = self.get_FiLM_param_(G_down3, '3', guide=True) + # i_alpha3, i_beta3 = self.get_FiLM_param_(down3, '3') + # down3 = affine_transformation(down3, g_alpha3, g_beta3) + # G_down3 = affine_transformation(G_down3, i_alpha3, i_beta3) + + down4 = self.downconv4(down3) + G_down4 = self.G_downconv4(G_down3) + + # g_alpha4, g_beta4 = self.get_FiLM_param_(G_down4, '4', guide=True) + # i_alpha4, i_beta4 = self.get_FiLM_param_(down4, '4') + # down4 = affine_transformation(down4, g_alpha4, g_beta4) + # G_down4 = affine_transformation(G_down4, i_alpha4, i_beta4) + + ## (num_downs - 5) layers + down = [] + G_down = [] + for i in range(self.num_downs - 5): + layer = 2 * i + bottleneck_layer = 3 * i + downconv = 
self.downconv[layer:layer+2] + G_downconv = self.G_downconv[layer:layer+2] + if (layer == 0): + down += [downconv(down4)] + G_down += [G_downconv(G_down4)] + else: + down += [downconv(down[i-1])] + G_down += [G_downconv(G_down[i-1])] + + # g_alpha, g_beta = self.get_FiLM_param_(G_down[i], bottleneck_layer, guide=True) + # i_alpha, i_beta = self.get_FiLM_param_(down[i], bottleneck_layer) + # down[i] = affine_transformation(down[i], g_alpha, g_beta) + # G_down[i] = affine_transformation(G_down[i], i_alpha, i_beta) + + down5 = self.downconv5(down[-1]) + + ## concat and upconv + up = self.upconv1(down5) + num_down = self.num_downs - 5 + for i in range(self.num_downs - 5): + layer = 3 * i + upconv = self.upconv[layer:layer+3] + num_down -= 1 + up = upconv(torch.cat([down[num_down], up], 1)) + up = self.upconv2(torch.cat([down4,up],1)) + up = self.upconv3(torch.cat([down3,up],1)) + up = self.upconv4(torch.cat([down2,up],1)) + up = self.upconv5(torch.cat([down1,up],1)) + return up + + +class uFT_Unet(nn.Module): + def __init__(self, input_nc, guide_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, bottleneck_depth=100): + super(uFT_Unet, self).__init__() + + self.num_downs = num_downs + + if type(norm_layer) == functools.partial: + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + + self.downconv1 = nn.Sequential(*[nn.Conv2d(input_nc, ngf, kernel_size=4, stride=2, padding=1, bias=use_bias)]) + self.downconv2 = nn.Sequential(*[nn.LeakyReLU(0.2, True), nn.Conv2d(ngf, ngf * 2, kernel_size=4, stride=2, padding=1, bias=use_bias)]) + self.downconv3 = nn.Sequential(*[nn.LeakyReLU(0.2, True), nn.Conv2d(ngf * 2, ngf * 4, kernel_size=4, stride=2, padding=1, bias=use_bias)]) + self.downconv4 = nn.Sequential(*[nn.LeakyReLU(0.2, True), nn.Conv2d(ngf * 4, ngf * 8, kernel_size=4, stride=2, padding=1, bias=use_bias)]) + + downconv = [] ## this has #(num_downs - 5) layers each with [relu-downconv-norm] + for i in range(num_downs - 5): + downconv += [nn.LeakyReLU(0.2, True), nn.Conv2d(ngf * 8, ngf * 8, kernel_size=4, stride=2, padding=1, bias=use_bias)] + self.downconv = nn.Sequential(*downconv) + self.downconv5 = nn.Sequential(*[nn.LeakyReLU(0.2, True), nn.Conv2d(ngf * 8, ngf * 8, kernel_size=4, stride=2, padding=1, bias=use_bias)]) + + ### bottleneck ------ + + self.upconv1 = nn.Sequential(*[nn.ReLU(True), nn.ConvTranspose2d(ngf * 8, ngf * 8, kernel_size=4, stride=2, padding=1, bias=use_bias), norm_layer(ngf * 8)]) + upconv = [] ## this has #(num_downs - 5) layers each with [relu-upconv-norm] + for i in range(num_downs - 5): + upconv += [nn.ReLU(True), nn.ConvTranspose2d(ngf * 8 * 2, ngf * 8, kernel_size=4, stride=2, padding=1, bias=use_bias), norm_layer(ngf * 8)] + self.upconv = nn.Sequential(*upconv) + self.upconv2 = nn.Sequential(*[nn.ReLU(True), nn.ConvTranspose2d(ngf * 8 * 2, ngf * 4, kernel_size=4, stride=2, padding=1, bias=use_bias), norm_layer(ngf * 4)]) + self.upconv3 = nn.Sequential(*[nn.ReLU(True), nn.ConvTranspose2d(ngf * 4 * 2, ngf * 2, kernel_size=4, stride=2, padding=1, bias=use_bias), norm_layer(ngf * 2)]) + self.upconv4 = nn.Sequential(*[nn.ReLU(True), nn.ConvTranspose2d(ngf * 2 * 2, ngf, kernel_size=4, stride=2, padding=1, bias=use_bias), norm_layer(ngf)]) + self.upconv5 = nn.Sequential(*[nn.ReLU(True), nn.ConvTranspose2d(ngf * 2, output_nc, kernel_size=4, stride=2, padding=1), nn.Tanh()]) + + + ### guide downsampling + self.G_downconv1 = nn.Sequential(*[nn.Conv2d(guide_nc, ngf, kernel_size=4, stride=2, padding=1, bias=use_bias)]) 
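+ # Editor's note (illustrative, not part of the original patch): uFT_Unet is
+ # the one-directional variant; FiLM parameters are generated from the guide
+ # stream only and applied to the input stream (the reverse transforms are
+ # commented out in forward()). The FiLM step itself, built from the
+ # affine_transformation()/calc_mean_std() helpers defined earlier in this file:
+ #
+ #   feat = torch.randn(1, 128, 32, 32)   # input-stream features
+ #   alpha = torch.randn(1, 128, 32, 32)  # per-pixel scale from the guide branch
+ #   beta = torch.randn(1, 128, 32, 32)   # per-pixel shift from the guide branch
+ #   out = affine_transformation(feat, alpha, beta)
+ #   # = alpha * (feat - mean_c) / std_c + beta, using per-channel instance statistics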
+ self.G_downconv2 = nn.Sequential(*[nn.LeakyReLU(0.2, True), nn.Conv2d(ngf, ngf * 2, kernel_size=4, stride=2, padding=1, bias=use_bias)]) + self.G_downconv3 = nn.Sequential(*[nn.LeakyReLU(0.2, True), nn.Conv2d(ngf * 2, ngf * 4, kernel_size=4, stride=2, padding=1, bias=use_bias)]) + self.G_downconv4 = nn.Sequential(*[nn.LeakyReLU(0.2, True), nn.Conv2d(ngf * 4, ngf * 8, kernel_size=4, stride=2, padding=1, bias=use_bias)]) + G_downconv = [] ## this has #(num_downs - 5) layers each with [relu-downconv-norm] + for i in range(num_downs - 5): + G_downconv += [nn.LeakyReLU(0.2, True), nn.Conv2d(ngf * 8, ngf * 8, kernel_size=4, stride=2, padding=1, bias=use_bias)] + self.G_downconv = nn.Sequential(*G_downconv) + + ### bottlenecks for param generation + self.bottleneck_alpha_2 = nn.Sequential(*self.bottleneck_layer(ngf * 2, bottleneck_depth)) + self.bottleneck_beta_2 = nn.Sequential(*self.bottleneck_layer(ngf * 2, bottleneck_depth)) + self.bottleneck_alpha_3 = nn.Sequential(*self.bottleneck_layer(ngf * 4, bottleneck_depth)) + self.bottleneck_beta_3 = nn.Sequential(*self.bottleneck_layer(ngf * 4, bottleneck_depth)) + self.bottleneck_alpha_4 = nn.Sequential(*self.bottleneck_layer(ngf * 8, bottleneck_depth)) + self.bottleneck_beta_4 = nn.Sequential(*self.bottleneck_layer(ngf * 8, bottleneck_depth)) + bottleneck_alpha = [] + bottleneck_beta = [] + for i in range(num_downs - 5): + bottleneck_alpha += self.bottleneck_layer(ngf * 8, bottleneck_depth) + bottleneck_beta += self.bottleneck_layer(ngf * 8, bottleneck_depth) + self.bottleneck_alpha = nn.Sequential(*bottleneck_alpha) + self.bottleneck_beta = nn.Sequential(*bottleneck_beta) + ### for guide + self.G_bottleneck_alpha_2 = nn.Sequential(*self.bottleneck_layer(ngf * 2, bottleneck_depth)) + self.G_bottleneck_beta_2 = nn.Sequential(*self.bottleneck_layer(ngf * 2, bottleneck_depth)) + self.G_bottleneck_alpha_3 = nn.Sequential(*self.bottleneck_layer(ngf * 4, bottleneck_depth)) + self.G_bottleneck_beta_3 = nn.Sequential(*self.bottleneck_layer(ngf * 4, bottleneck_depth)) + self.G_bottleneck_alpha_4 = nn.Sequential(*self.bottleneck_layer(ngf * 8, bottleneck_depth)) + self.G_bottleneck_beta_4 = nn.Sequential(*self.bottleneck_layer(ngf * 8, bottleneck_depth)) + G_bottleneck_alpha = [] + G_bottleneck_beta = [] + for i in range(num_downs - 5): + G_bottleneck_alpha += self.bottleneck_layer(ngf * 8, bottleneck_depth) + G_bottleneck_beta += self.bottleneck_layer(ngf * 8, bottleneck_depth) + self.G_bottleneck_alpha = nn.Sequential(*G_bottleneck_alpha) + self.G_bottleneck_beta = nn.Sequential(*G_bottleneck_beta) + + def bottleneck_layer(self, nc, bottleneck_depth): + return [nn.Conv2d(nc, bottleneck_depth, kernel_size=1), nn.ReLU(True), nn.Conv2d(bottleneck_depth, nc, kernel_size=1)] + + # per pixel + def get_FiLM_param_(self, X, i, guide=False): + x = X.clone() + # bottleneck + if guide: + if (i=='2'): + alpha_layer = self.G_bottleneck_alpha_2 + beta_layer = self.G_bottleneck_beta_2 + elif (i=='3'): + alpha_layer = self.G_bottleneck_alpha_3 + beta_layer = self.G_bottleneck_beta_3 + elif (i=='4'): + alpha_layer = self.G_bottleneck_alpha_4 + beta_layer = self.G_bottleneck_beta_4 + else: # a number i will be given to specify which bottleneck to use + alpha_layer = self.G_bottleneck_alpha[i:i+3] + beta_layer = self.G_bottleneck_beta[i:i+3] + else: + if (i=='2'): + alpha_layer = self.bottleneck_alpha_2 + beta_layer = self.bottleneck_beta_2 + elif (i=='3'): + alpha_layer = self.bottleneck_alpha_3 + beta_layer = self.bottleneck_beta_3 + elif (i=='4'): + alpha_layer = 
self.bottleneck_alpha_4 + beta_layer = self.bottleneck_beta_4 + else: # a number i will be given to specify which bottleneck to use + alpha_layer = self.bottleneck_alpha[i:i+3] + beta_layer = self.bottleneck_beta[i:i+3] + + alpha = alpha_layer(x) + beta = beta_layer(x) + return alpha, beta + + def forward (self, input, guide): + ## downconv + down1 = self.downconv1(input) + G_down1 = self.G_downconv1(guide) + + down2 = self.downconv2(down1) + G_down2 = self.G_downconv2(G_down1) + + g_alpha2, g_beta2 = self.get_FiLM_param_(G_down2, '2', guide=True) + #i_alpha2, i_beta2 = self.get_FiLM_param_(down2, '2') + down2 = affine_transformation(down2, g_alpha2, g_beta2) + #G_down2 = affine_transformation(G_down2, i_alpha2, i_beta2) + + + down3 = self.downconv3(down2) + G_down3 = self.G_downconv3(G_down2) + + g_alpha3, g_beta3 = self.get_FiLM_param_(G_down3, '3', guide=True) + #i_alpha3, i_beta3 = self.get_FiLM_param_(down3, '3') + down3 = affine_transformation(down3, g_alpha3, g_beta3) + #G_down3 = affine_transformation(G_down3, i_alpha3, i_beta3) + + down4 = self.downconv4(down3) + G_down4 = self.G_downconv4(G_down3) + + g_alpha4, g_beta4 = self.get_FiLM_param_(G_down4, '4', guide=True) + #i_alpha4, i_beta4 = self.get_FiLM_param_(down4, '4') + down4 = affine_transformation(down4, g_alpha4, g_beta4) + #G_down4 = affine_transformation(G_down4, i_alpha4, i_beta4) + + ## (num_downs - 5) layers + down = [] + G_down = [] + for i in range(self.num_downs - 5): + layer = 2 * i + bottleneck_layer = 3 * i + downconv = self.downconv[layer:layer+2] + G_downconv = self.G_downconv[layer:layer+2] + if (layer == 0): + down += [downconv(down4)] + G_down += [G_downconv(G_down4)] + else: + down += [downconv(down[i-1])] + G_down += [G_downconv(G_down[i-1])] + + g_alpha, g_beta = self.get_FiLM_param_(G_down[i], bottleneck_layer, guide=True) + #i_alpha, i_beta = self.get_FiLM_param_(down[i], bottleneck_layer) + down[i] = affine_transformation(down[i], g_alpha, g_beta) + #G_down[i] = affine_transformation(G_down[i], i_alpha, i_beta) + + down5 = self.downconv5(down[-1]) + + ## concat and upconv + up = self.upconv1(down5) + num_down = self.num_downs - 5 + for i in range(self.num_downs - 5): + layer = 3 * i + upconv = self.upconv[layer:layer+3] + num_down -= 1 + up = upconv(torch.cat([down[num_down], up], 1)) + up = self.upconv2(torch.cat([down4,up],1)) + up = self.upconv3(torch.cat([down3,up],1)) + up = self.upconv4(torch.cat([down2,up],1)) + up = self.upconv5(torch.cat([down1,up],1)) + return up + +# concat input and guide image + +class concat_Unet(nn.Module): + def __init__(self, input_nc, guide_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, + bottleneck_depth=100): + super(concat_Unet, self).__init__() + + self.num_downs = num_downs + + if type(norm_layer) == functools.partial: + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + + self.downconv1 = nn.Sequential(*[nn.Conv2d(input_nc+guide_nc, ngf, kernel_size=4, stride=2, padding=1, bias=use_bias)]) + self.downconv2 = nn.Sequential( + *[nn.LeakyReLU(0.2, True), nn.Conv2d(ngf, ngf * 2, kernel_size=4, stride=2, padding=1, bias=use_bias)]) + self.downconv3 = nn.Sequential( + *[nn.LeakyReLU(0.2, True), nn.Conv2d(ngf * 2, ngf * 4, kernel_size=4, stride=2, padding=1, bias=use_bias)]) + self.downconv4 = nn.Sequential( + *[nn.LeakyReLU(0.2, True), nn.Conv2d(ngf * 4, ngf * 8, kernel_size=4, stride=2, padding=1, bias=use_bias)]) + + downconv = [] ## this has #(num_downs - 5) layers each with [relu-downconv-norm] + 
for i in range(num_downs - 5): + downconv += [nn.LeakyReLU(0.2, True), + nn.Conv2d(ngf * 8, ngf * 8, kernel_size=4, stride=2, padding=1, bias=use_bias)] + self.downconv = nn.Sequential(*downconv) + self.downconv5 = nn.Sequential( + *[nn.LeakyReLU(0.2, True), nn.Conv2d(ngf * 8, ngf * 8, kernel_size=4, stride=2, padding=1, bias=use_bias)]) + + ### bottleneck ------ + + self.upconv1 = nn.Sequential( + *[nn.ReLU(True), nn.ConvTranspose2d(ngf * 8, ngf * 8, kernel_size=4, stride=2, padding=1, bias=use_bias), + norm_layer(ngf * 8)]) + upconv = [] ## this has #(num_downs - 5) layers each with [relu-upconv-norm] + for i in range(num_downs - 5): + upconv += [nn.ReLU(True), + nn.ConvTranspose2d(ngf * 8 * 2, ngf * 8, kernel_size=4, stride=2, padding=1, bias=use_bias), + norm_layer(ngf * 8)] + self.upconv = nn.Sequential(*upconv) + self.upconv2 = nn.Sequential(*[nn.ReLU(True), + nn.ConvTranspose2d(ngf * 8 * 2, ngf * 4, kernel_size=4, stride=2, padding=1, + bias=use_bias), norm_layer(ngf * 4)]) + self.upconv3 = nn.Sequential(*[nn.ReLU(True), + nn.ConvTranspose2d(ngf * 4 * 2, ngf * 2, kernel_size=4, stride=2, padding=1, + bias=use_bias), norm_layer(ngf * 2)]) + self.upconv4 = nn.Sequential( + *[nn.ReLU(True), nn.ConvTranspose2d(ngf * 2 * 2, ngf, kernel_size=4, stride=2, padding=1, bias=use_bias), + norm_layer(ngf)]) + self.upconv5 = nn.Sequential( + *[nn.ReLU(True), nn.ConvTranspose2d(ngf * 2, output_nc, kernel_size=4, stride=2, padding=1), nn.Tanh()]) + + ### guide downsampling + self.G_downconv1 = nn.Sequential(*[nn.Conv2d(guide_nc, ngf, kernel_size=4, stride=2, padding=1, bias=use_bias)]) + self.G_downconv2 = nn.Sequential( + *[nn.LeakyReLU(0.2, True), nn.Conv2d(ngf, ngf * 2, kernel_size=4, stride=2, padding=1, bias=use_bias)]) + self.G_downconv3 = nn.Sequential( + *[nn.LeakyReLU(0.2, True), nn.Conv2d(ngf * 2, ngf * 4, kernel_size=4, stride=2, padding=1, bias=use_bias)]) + self.G_downconv4 = nn.Sequential( + *[nn.LeakyReLU(0.2, True), nn.Conv2d(ngf * 4, ngf * 8, kernel_size=4, stride=2, padding=1, bias=use_bias)]) + G_downconv = [] ## this has #(num_downs - 5) layers each with [relu-downconv-norm] + for i in range(num_downs - 5): + G_downconv += [nn.LeakyReLU(0.2, True), + nn.Conv2d(ngf * 8, ngf * 8, kernel_size=4, stride=2, padding=1, bias=use_bias)] + self.G_downconv = nn.Sequential(*G_downconv) + + ### bottlenecks for param generation + self.bottleneck_alpha_2 = nn.Sequential(*self.bottleneck_layer(ngf * 2, bottleneck_depth)) + self.bottleneck_beta_2 = nn.Sequential(*self.bottleneck_layer(ngf * 2, bottleneck_depth)) + self.bottleneck_alpha_3 = nn.Sequential(*self.bottleneck_layer(ngf * 4, bottleneck_depth)) + self.bottleneck_beta_3 = nn.Sequential(*self.bottleneck_layer(ngf * 4, bottleneck_depth)) + self.bottleneck_alpha_4 = nn.Sequential(*self.bottleneck_layer(ngf * 8, bottleneck_depth)) + self.bottleneck_beta_4 = nn.Sequential(*self.bottleneck_layer(ngf * 8, bottleneck_depth)) + bottleneck_alpha = [] + bottleneck_beta = [] + for i in range(num_downs - 5): + bottleneck_alpha += self.bottleneck_layer(ngf * 8, bottleneck_depth) + bottleneck_beta += self.bottleneck_layer(ngf * 8, bottleneck_depth) + self.bottleneck_alpha = nn.Sequential(*bottleneck_alpha) + self.bottleneck_beta = nn.Sequential(*bottleneck_beta) + ### for guide + self.G_bottleneck_alpha_2 = nn.Sequential(*self.bottleneck_layer(ngf * 2, bottleneck_depth)) + self.G_bottleneck_beta_2 = nn.Sequential(*self.bottleneck_layer(ngf * 2, bottleneck_depth)) + self.G_bottleneck_alpha_3 = nn.Sequential(*self.bottleneck_layer(ngf * 4, 
+    def bottleneck_layer(self, nc, bottleneck_depth):
+        return [nn.Conv2d(nc, bottleneck_depth, kernel_size=1), nn.ReLU(True),
+                nn.Conv2d(bottleneck_depth, nc, kernel_size=1)]
+
+    # per pixel
+    def get_FiLM_param_(self, X, i, guide=False):
+        x = X.clone()
+        # bottleneck
+        if guide:
+            if (i == '2'):
+                alpha_layer = self.G_bottleneck_alpha_2
+                beta_layer = self.G_bottleneck_beta_2
+            elif (i == '3'):
+                alpha_layer = self.G_bottleneck_alpha_3
+                beta_layer = self.G_bottleneck_beta_3
+            elif (i == '4'):
+                alpha_layer = self.G_bottleneck_alpha_4
+                beta_layer = self.G_bottleneck_beta_4
+            else:  # a numeric index i selects which stacked bottleneck to use
+                alpha_layer = self.G_bottleneck_alpha[i:i + 3]
+                beta_layer = self.G_bottleneck_beta[i:i + 3]
+        else:
+            if (i == '2'):
+                alpha_layer = self.bottleneck_alpha_2
+                beta_layer = self.bottleneck_beta_2
+            elif (i == '3'):
+                alpha_layer = self.bottleneck_alpha_3
+                beta_layer = self.bottleneck_beta_3
+            elif (i == '4'):
+                alpha_layer = self.bottleneck_alpha_4
+                beta_layer = self.bottleneck_beta_4
+            else:  # a numeric index i selects which stacked bottleneck to use
+                alpha_layer = self.bottleneck_alpha[i:i + 3]
+                beta_layer = self.bottleneck_beta[i:i + 3]
+
+        alpha = alpha_layer(x)
+        beta = beta_layer(x)
+        return alpha, beta
+
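The integer branch of `get_FiLM_param_` relies on the fact that each stacked bottleneck occupies exactly three consecutive modules (conv, relu, conv) inside `self.bottleneck_alpha`, and that slicing an `nn.Sequential` returns a new, callable `nn.Sequential`. A small demonstration (not from the diff):

```python
import torch.nn as nn

stack = nn.Sequential(
    nn.Conv2d(8, 4, kernel_size=1), nn.ReLU(True), nn.Conv2d(4, 8, kernel_size=1),  # bottleneck 0
    nn.Conv2d(8, 4, kernel_size=1), nn.ReLU(True), nn.Conv2d(4, 8, kernel_size=1),  # bottleneck 1
)
first = stack[0:3]           # slicing yields another nn.Sequential
print(type(first).__name__)  # Sequential -- callable on its own
```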
+    def forward(self, input, guide):
+        ## downconv
+        input = torch.cat((input, guide), dim=1)
+        down1 = self.downconv1(input)
+        #G_down1 = self.G_downconv1(guide)
+
+        down2 = self.downconv2(down1)
+        #G_down2 = self.G_downconv2(G_down1)
+
+        #g_alpha2, g_beta2 = self.get_FiLM_param_(G_down2, '2', guide=True)
+        #i_alpha2, i_beta2 = self.get_FiLM_param_(down2, '2')
+        #down2 = affine_transformation(down2, g_alpha2, g_beta2)
+        #G_down2 = affine_transformation(G_down2, i_alpha2, i_beta2)
+
+        down3 = self.downconv3(down2)
+        #G_down3 = self.G_downconv3(G_down2)
+
+        #g_alpha3, g_beta3 = self.get_FiLM_param_(G_down3, '3', guide=True)
+        #i_alpha3, i_beta3 = self.get_FiLM_param_(down3, '3')
+        #down3 = affine_transformation(down3, g_alpha3, g_beta3)
+        #G_down3 = affine_transformation(G_down3, i_alpha3, i_beta3)
+
+        down4 = self.downconv4(down3)
+        #G_down4 = self.G_downconv4(G_down3)
+
+        #g_alpha4, g_beta4 = self.get_FiLM_param_(G_down4, '4', guide=True)
+        #i_alpha4, i_beta4 = self.get_FiLM_param_(down4, '4')
+        #down4 = affine_transformation(down4, g_alpha4, g_beta4)
+        #G_down4 = affine_transformation(G_down4, i_alpha4, i_beta4)
+
+        ## (num_downs - 5) layers
+        down = []
+        #G_down = []
+        for i in range(self.num_downs - 5):
+            layer = 2 * i
+            bottleneck_layer = 3 * i
+            downconv = self.downconv[layer:layer + 2]
+            G_downconv = self.G_downconv[layer:layer + 2]
+            if (layer == 0):
+                down += [downconv(down4)]
+                #G_down += [G_downconv(G_down4)]
+            else:
+                down += [downconv(down[i - 1])]
+                #G_down += [G_downconv(G_down[i - 1])]
+
+            #g_alpha, g_beta = self.get_FiLM_param_(G_down[i], bottleneck_layer, guide=True)
+            #i_alpha, i_beta = self.get_FiLM_param_(down[i], bottleneck_layer)
+            #down[i] = affine_transformation(down[i], g_alpha, g_beta)
+            #G_down[i] = affine_transformation(G_down[i], i_alpha, i_beta)
+
+        down5 = self.downconv5(down[-1])
+
+        ## concat and upconv
+        up = self.upconv1(down5)
+        num_down = self.num_downs - 5
+        for i in range(self.num_downs - 5):
+            layer = 3 * i
+            upconv = self.upconv[layer:layer + 3]
+            num_down -= 1
+            up = upconv(torch.cat([down[num_down], up], 1))
+        up = self.upconv2(torch.cat([down4, up], 1))
+        up = self.upconv3(torch.cat([down3, up], 1))
+        up = self.upconv4(torch.cat([down2, up], 1))
+        up = self.upconv5(torch.cat([down1, up], 1))
+        return up
+
+if __name__ == '__main__':
+    model = NLayerDiscriminator(input_nc=3, ndf=64, num_classes_D=1, n_layers=3, norm_layer=nn.BatchNorm2d)
+    x = torch.randn(1, 3, 256, 256)
+    o = model(x)
+    label_shape = [1, 1, o.shape[2]]
+    # 0 for real, 1 for fake
+    label_real = torch.zeros(label_shape)
+    label_fake = torch.ones(label_shape)
+    print(label_real.shape)
+
+    k = nn.BCEWithLogitsLoss()(o, label_real)
+    print(k)
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ae691f60652f8977c2d151d84fb05a7ab34c408
--- /dev/null
+++ b/app.py
@@ -0,0 +1,384 @@
+import torch
+import gradio as gr
+from PIL import Image
+import cv2
+from AV.models.network import PGNet
+from AV.Tools.AVclassifiation import AVclassifiation
+from AV.Tools.utils_test import paint_border_overlap, extract_ordered_overlap_big, Normalize, sigmoid, recompone_overlap, \
+    kill_border
+from AV.config import config_test_general as cfg
+import torch.autograd as autograd
+import numpy as np
+import os
+from datetime import datetime
+
+def creatMask(Image, threshold=5):
+    ## This function creates the mask for the field-of-view.
+    ## Input: original image (RGB or green channel) and a threshold (user-set parameter, default 5).
+    ## Output: the field-of-view mask.
+
+    if len(Image.shape) == 3:  ## RGB image
+        gray = cv2.cvtColor(Image, cv2.COLOR_BGR2GRAY)
+        Mask0 = gray >= threshold
+
+    else:  # for green channel image
+        Mask0 = Image >= threshold
+
+    # ###### get the largest blob, this takes 0.18s
+    cvVersion = int(cv2.__version__.split('.')[0])
+
+    Mask0 = np.uint8(Mask0)
+
+    contours, hierarchy = cv2.findContours(Mask0, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
+
+    areas = [cv2.contourArea(c) for c in contours]
+    max_index = np.argmax(areas)
+    Mask = np.zeros(Image.shape[:2], dtype=np.uint8)
+    cv2.drawContours(Mask, contours, max_index, 1, -1)
+
+    ResultImg = Image.copy()
+    if len(Image.shape) == 3:
+        ResultImg[Mask == 0] = (255, 255, 255)
+    else:
+        ResultImg[Mask == 0] = 255
+    Mask[Mask > 0] = 255
+    kernel = cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3))
+    Mask = cv2.morphologyEx(Mask, cv2.MORPH_OPEN, kernel, iterations=3)
+    return ResultImg, Mask
+
+
+def shift_rgb(img, *args):
+    result_img = np.empty_like(img)
+    shifts = args
+    max_value = 255
+    # print(shifts)
+    for i, shift in enumerate(shifts):
+        lut = np.arange(0, max_value + 1).astype("float32")
+        lut += shift
+
+        lut = np.clip(lut, 0, max_value).astype(img.dtype)
+        if len(img.shape) == 2:
+            print('=========grey image=======')
+            result_img = cv2.LUT(img, lut)
+        else:
+            result_img[..., i] = cv2.LUT(img[..., i], lut)
+
+    return result_img
+
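`shift_rgb` adds a per-channel offset through a 256-entry lookup table and clips to the valid 0..255 range. With the function above in scope, an illustrative call (the shift values are arbitrary):

```python
import numpy as np

img = np.full((2, 2, 3), 250, dtype=np.uint8)
out = shift_rgb(img, 10, 0, -20)  # shift_rgb as defined above
print(out[0, 0])                  # [255 250 230]: red clipped at 255, green kept, blue lowered
```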
+
+def CAM(x, img_path, rate=0.8, ind=0):
+    """
+    :param x: array, the single image to be adjusted
+    :param img_path: path to the reference image whose mean RGB channel values are matched
+    :return: the color-adjusted image as an array
+    """
+    # The mean RGB values must be recomputed whenever a new dataset is used.
+    # RGB --> R shift --> CLAHE
+
+    x = np.uint8(x)
+    _, Mask0 = creatMask(x, threshold=10)
+    Mask = np.zeros((x.shape[0], x.shape[1]), np.float32)
+    Mask[Mask0 > 0] = 1
+
+    resize = False  # currently unused
+    R_mea_num, G_mea_num, B_mea_num = [], [], []
+
+    dataset_path = img_path
+    image = np.array(Image.open(dataset_path))
+    R_mea_num.append(np.mean(image[:, :, 0]))
+    G_mea_num.append(np.mean(image[:, :, 1]))
+    B_mea_num.append(np.mean(image[:, :, 2]))
+
+    mea2stand = int((np.mean(R_mea_num) - np.mean(x[:, :, 0])) * rate)
+    mea2standg = int((np.mean(G_mea_num) - np.mean(x[:, :, 1])) * rate)
+    mea2standb = int((np.mean(B_mea_num) - np.mean(x[:, :, 2])) * rate)
+
+    y = shift_rgb(x, mea2stand, mea2standg, mea2standb)
+
+    y[Mask == 0, :] = 0
+
+    return y
+
+
+def modelEvalution_out_big(net, use_cuda=False, dataset='', is_kill_border=True, input_ch=3,
+                           config=None, output_dir='', evaluate_metrics=False):
+    # path for images to save
+    n_classes = 3
+    Net = PGNet(use_global_semantic=config.use_global_semantic, input_ch=input_ch,
+                num_classes=n_classes, use_cuda=use_cuda, pretrained=False, centerness=config.use_centerness,
+                centerness_map_size=config.centerness_map_size)
+    msg = Net.load_state_dict(net, strict=False)
+
+    if use_cuda:
+        Net.cuda()
+    Net.eval()
+
+    image_basename = dataset
+
+    # if not os.path.exists(output_dir):
+    #     os.makedirs(output_dir)
+
+    step = 1
+    # step size kept from the batched version; this entry point processes a single image
+
+    # for start_end in start_end_list:
+    image0 = cv2.imread(image_basename)
+    test_image_height = image0.shape[0]
+    test_image_width = image0.shape[1]
+
+    if config.use_resize:
+        # small images: upscale so the shorter side reaches 512
+        if min(test_image_height, test_image_width) <= 256:
+            scaling = 512 / min(test_image_height, test_image_width)
+            new_width = int(test_image_width * scaling)
+            new_height = int(test_image_height * scaling)
+            test_image_width, test_image_height = new_width, new_height
+
+        # large images: downscale so the longer side is at most 2048
+        elif max(test_image_height, test_image_width) >= 2048:
+            scaling = 2048 / max(test_image_height, test_image_width)
+            new_width = int(test_image_width * scaling)
+            new_height = int(test_image_height * scaling)
+            test_image_width, test_image_height = new_width, new_height
+
+    ArteryPredAll = np.zeros((1, 1, test_image_height, test_image_width), np.float32)
+    VeinPredAll = np.zeros((1, 1, test_image_height, test_image_width), np.float32)
+    VesselPredAll = np.zeros((1, 1, test_image_height, test_image_width), np.float32)
+    ProMap = np.zeros((1, 3, test_image_height, test_image_width), np.float32)
+    MaskAll = np.zeros((1, 1, test_image_height, test_image_width), np.float32)
+    ArteryPred, VeinPred, VesselPred, Mask, LabelArtery, LabelVein, LabelVessel = GetResult_out_big(Net, 0,
+                                                                                                    use_cuda=use_cuda,
+                                                                                                    dataset=image_basename,
+                                                                                                    is_kill_border=is_kill_border,
+                                                                                                    config=config,
+                                                                                                    resize_w_h=(
+                                                                                                        test_image_width,
+                                                                                                        test_image_height)
+                                                                                                    )
+    ArteryPredAll[0 % step, :, :, :] = ArteryPred
+    VeinPredAll[0 % step, :, :, :] = VeinPred
+    VesselPredAll[0 % step, :, :, :] = VesselPred
+
+    MaskAll[0 % step, :, :, :] = Mask
+
+    image_color = AVclassifiation(output_dir, ArteryPredAll, VeinPredAll, VesselPredAll, 1, image_basename)
+
+    return image_color
+
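The `use_resize` branch normalizes extreme input sizes before tiling: small images are upscaled until the shorter side reaches 512, and very large ones are downscaled until the longer side is at most 2048. The same arithmetic as a standalone helper (`resize_policy` is a hypothetical name used only for illustration):

```python
def resize_policy(height, width):
    # mirrors the use_resize branch of modelEvalution_out_big
    if min(height, width) <= 256:
        scaling = 512 / min(height, width)
    elif max(height, width) >= 2048:
        scaling = 2048 / max(height, width)
    else:
        return height, width
    return int(height * scaling), int(width * scaling)

print(resize_policy(200, 300))    # (512, 768)
print(resize_policy(3000, 2000))  # (2048, 1365)
```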
+
+def GetResult_out_big(Net, k, use_cuda=False, dataset='', is_kill_border=False, config=None,
+                      resize_w_h=None):
+    ImgName = dataset
+    Img0 = cv2.imread(ImgName)
+
+    _, Mask0 = creatMask(Img0, threshold=-1)
+    Mask = np.zeros((Img0.shape[0], Img0.shape[1]), np.float32)
+    Mask[Mask0 > 0] = 1
+
+    if config.use_resize:
+        Img0 = cv2.resize(Img0, resize_w_h)
+        Mask = cv2.resize(Mask, resize_w_h, interpolation=cv2.INTER_NEAREST)
+
+    Img = Img0
+    height, width = Img.shape[:2]
+    n_classes = 3
+    patch_height = config.patch_size
+    patch_width = config.patch_size
+    stride_height = config.stride_height
+    stride_width = config.stride_width
+
+    Img = cv2.cvtColor(Img, cv2.COLOR_BGR2RGB)
+    if cfg.dataset == 'all':
+        # convert the image to the LAB color space
+        lab = cv2.cvtColor(Img, cv2.COLOR_RGB2LAB)
+
+        # split the LAB channels
+        l, a, b = cv2.split(lab)
+
+        # create a CLAHE object and apply it to the L channel
+        clahe = cv2.createCLAHE(clipLimit=2, tileGridSize=(8, 8))
+        l_clahe = clahe.apply(l)
+
+        # merge the CLAHE-processed L channel with the original A and B channels
+        lab_clahe = cv2.merge((l_clahe, a, b))
+
+        # convert the image back to the RGB color space
+        Img = cv2.cvtColor(lab_clahe, cv2.COLOR_LAB2RGB)
+
+    if cfg.use_CAM:
+        Img = CAM(Img, dataset)
+
+    Img = np.float32(Img / 255.)
+    Img_enlarged = paint_border_overlap(Img, patch_height, patch_width, stride_height, stride_width)
+    patch_size = config.patch_size
+    batch_size = 2
+    patches_imgs, global_images = extract_ordered_overlap_big(Img_enlarged, patch_height, patch_width,
+                                                              stride_height,
+                                                              stride_width)
+
+    patches_imgs = np.transpose(patches_imgs, (0, 3, 1, 2))
+    patches_imgs = Normalize(patches_imgs)
+    global_images = np.transpose(global_images, (0, 3, 1, 2))
+    global_images = Normalize(global_images)
+    patchNum = patches_imgs.shape[0]
+    max_iter = int(np.ceil(patchNum / float(batch_size)))
+
+    pred_patches = np.zeros((patchNum, n_classes, patch_size, patch_size), np.float32)
+
+    for i in range(max_iter):
+        begin_index = i * batch_size
+        end_index = (i + 1) * batch_size
+
+        patches_temp1 = patches_imgs[begin_index:end_index, :, :, :]
+
+        patches_input_temp1 = torch.FloatTensor(patches_temp1)
+        global_input_temp1 = patches_input_temp1
+        if config.use_global_semantic:
+            global_temp1 = global_images[begin_index:end_index, :, :, :]
+            global_input_temp1 = torch.FloatTensor(global_temp1)
+        if use_cuda:
+            patches_input_temp1 = autograd.Variable(patches_input_temp1.cuda())
+            if config.use_global_semantic:
+                global_input_temp1 = autograd.Variable(global_input_temp1.cuda())
+        else:
+            patches_input_temp1 = autograd.Variable(patches_input_temp1)
+            if config.use_global_semantic:
+                global_input_temp1 = autograd.Variable(global_input_temp1)
+
+        output_temp, _1 = Net(patches_input_temp1, global_input_temp1)
+
+        pred_patches_temp = np.float32(output_temp.data.cpu().numpy())
+
+        pred_patches_temp_sigmoid = sigmoid(pred_patches_temp)
+
+        pred_patches[begin_index:end_index, :, :, :] = pred_patches_temp_sigmoid[:, :, :patch_size, :patch_size]
+
+        del patches_input_temp1
+        del pred_patches_temp
+        del patches_temp1
+        del output_temp
+        del pred_patches_temp_sigmoid
+
+    new_height, new_width = Img_enlarged.shape[0], Img_enlarged.shape[1]
+
+    pred_img = recompone_overlap(pred_patches, new_height, new_width, stride_height, stride_width)  # predictions
+    pred_img = pred_img[:, 0:height, 0:width]
+
+    if is_kill_border:
+        pred_img = kill_border(pred_img, Mask)
+
+    ArteryPred = np.float32(pred_img[0, :, :])
+    VeinPred = np.float32(pred_img[2, :, :])
+    VesselPred = np.float32(pred_img[1, :, :])
+
+    ArteryPred = ArteryPred[np.newaxis, :, :]
+    VeinPred = VeinPred[np.newaxis, :, :]
+    VesselPred = VesselPred[np.newaxis, :, :]
+    Mask = Mask[np.newaxis, :, :]
+
+    # the predictions are returned twice: the trailing copies stand in for the
+    # label outputs the caller unpacks (no ground truth exists at inference time)
+    return ArteryPred, VeinPred, VesselPred, Mask, ArteryPred, VeinPred, VesselPred
+
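`paint_border_overlap` and `extract_ordered_overlap_big` come from `AV.Tools.utils_test` and are not shown in this diff. The usual contract for overlapped tiling of this kind, which the sketch below assumes, is that the image is padded until the strides divide evenly, giving a predictable patch count:

```python
def padded_size(dim, patch, stride):
    # assumed padding rule: grow dim until (dim - patch) is a multiple of stride
    rest = (dim - patch) % stride
    return dim if rest == 0 else dim + (stride - rest)

def patch_count(h, w, patch, stride):
    H, W = padded_size(h, patch, stride), padded_size(w, patch, stride)
    return ((H - patch) // stride + 1) * ((W - patch) // stride + 1)

# e.g. a DRIVE-sized image (584 x 565) with 256-pixel patches and stride 128:
print(patch_count(584, 565, 256, 128))  # 16
```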
+
+def out_test(cfg, output_dir='', evaluate_metrics=False, img_name='out_test'):
+    device = torch.device("cuda" if cfg.use_cuda else "cpu")
+    model_root = cfg.model_path_pretrained_G
+    model_path = os.path.join(model_root, 'G_' + str(cfg.model_step_pretrained_G) + '.pkl')
+    net = torch.load(model_path, map_location=device)
+
+    image_color = modelEvalution_out_big(net,
+                                         use_cuda=cfg.use_cuda,
+                                         dataset=img_name,
+                                         input_ch=cfg.input_nc,
+                                         config=cfg,
+                                         output_dir=output_dir, evaluate_metrics=evaluate_metrics)
+
+    return image_color
+
+
+def segment_by_out_test(image, model_name):
+    print("✅ model name passed to the backend:", model_name)
+
+    cfg.set_dataset(model_name)
+    if image is None:
+        raise gr.Error("Please upload an image.")
+    os.makedirs("./examples", exist_ok=True)
+    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+    temp_path = f"./examples/tmp_upload_{timestamp}.png"
+    image.save(temp_path)
+
+    image_color = out_test(cfg, output_dir='', evaluate_metrics=False, img_name=temp_path)
+    return Image.fromarray(image_color)
+
+def gradio_interface():
+    model_info_md = """
+### 📘 Model notes
+
+| Model | Dataset | Patch size | Running time |
+|------|--------|------------|--------|
+| DRIVE | low-resolution vessel images | 256 | under 30 s |
+| HRF | high-resolution images (healthy, glaucoma, etc.) | | under 2 min |
+| LES | optic-disc-centered images | 256 | under 2 min |
+| UKBB | UK Biobank images | 256 | under 2 min |
+| General model (512) | ultra-high-resolution images, broadly applicable | 512 | under 2 min |
+"""
+    model_choices = [
+        ("1: DRIVE-specific model", "DRIVE"),
+        ("2: HRF-specific model", "hrf"),
+        ("3: LES-specific model", "LES"),
+        ("4: UKBB-specific model", "ukbb"),
+        ("5: General model", "all"),
+    ]
+
+    with gr.Blocks(theme=gr.themes.Soft()) as demo:
+        gr.Markdown("# 👁️ Retinal Artery/Vein Segmentation of Fundus Images")
+        gr.Markdown("Upload a fundus image and pick a model; the segmentation result is generated automatically.")
+
+        with gr.Row():
+            image_input = gr.Image(type="pil", label="📤 Upload image", height=300)
+
+        with gr.Row():
+            with gr.Column():
+                model_select = gr.Radio(
+                    choices=model_choices,
+                    label="🎯 Select a model",
+                    value="DRIVE",
+                    interactive=True
+                )
+                submit_btn = gr.Button("🚀 Run segmentation")
+            with gr.Column():
+                output_image = gr.Image(label="🖼️ Segmentation result")
+
+        gr.Markdown("### 📁 Example images (click to load)")
+        gr.Examples(
+            examples=[
+                ["examples/DRIVE.tif", "DRIVE"],
+                ["examples/LES.png", "LES"],
+                ["examples/hrf.png", "hrf"],
+                ["examples/ukbb.png", "ukbb"],
+                ["examples/all.jpg", "all"]
+            ],
+            inputs=[image_input, model_select],
+            label="Example images",
+            examples_per_page=5
+        )
+        with gr.Accordion("📖 Model notes (click to expand)", open=False):
+            gr.Markdown(model_info_md)
+
+        # wire the button to the segmentation callback
+        submit_btn.click(
+            fn=segment_by_out_test,
+            inputs=[image_input, model_select],
+            outputs=[output_image]
+        )
+        gr.Markdown("📚 **Dataset-specific models**: RIP-AV: Joint Representative Instance Pre-training with Context Aware Network for Retinal Artery/Vein Segmentation")
+        gr.Markdown("📚 **General model**: An Efficient and Interpretable Foundation Model for Retinal Image Analysis in Disease Diagnosis.")
+    demo.queue()
+    demo.launch()
+
+
+if __name__ == '__main__':
+    # cfg.set_dataset('all')
+    # image_color = out_test(cfg=cfg, evaluate_metrics=False, img_name=r'.\AV\data\AV-DRIVE\test\images\01_test.tif')
+    # Image.fromarray(image_color).save('image_color.png')
+    # print(cfg.patch_size)
+    gradio_interface()
diff --git a/examples/DRIVE.tif b/examples/DRIVE.tif
new file mode 100644
index 0000000000000000000000000000000000000000..b6ca65248e81288b999fe916e0402b31cd1f69a8
--- /dev/null
+++ b/examples/DRIVE.tif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:de6b4dc1f72dfdcda32334060b490b2b448f940c654327fcaa6dc98f7c527302
+size 729790
diff --git a/examples/LES.png b/examples/LES.png
new file mode 100644
index 0000000000000000000000000000000000000000..47d277bade3c1dfec704d12b1353b616599512d4
--- /dev/null
+++ b/examples/LES.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5f29b5b9f5578d0f2eaf26b55afc7749ac62b2af257f8c26821e47d36df6783f
+size 1825230
diff --git a/examples/all.jpg b/examples/all.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..ed3e1e07cbc5913d07fcf9cc1059713924d3257f
--- /dev/null
+++ b/examples/all.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9565b12446acb256f5484d9c73b357c4168cbf32504b06acabfe0143c9fad837
+size 557777
diff --git a/examples/hrf.png b/examples/hrf.png
new file mode 100644
index 0000000000000000000000000000000000000000..926c740416566a26fa3488f719e3efef3bc35d16
--- /dev/null
+++ b/examples/hrf.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:afe035f574d3a96494cbce396887d06ec7884ea7b1f58bc426dfe8fb0236318f
+size 2007754
diff --git a/examples/tmp_upload.png b/examples/tmp_upload.png
new file mode 100644
index 0000000000000000000000000000000000000000..f3b2686dd2c03cfba03b11e5e3038a54fa5b144c
--- /dev/null
+++ b/examples/tmp_upload.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3ec1886ca18076c7658eecf208e2182865071f38b20db33bb5fcadde4a603dc8
+size 392206
diff --git a/examples/ukbb.png b/examples/ukbb.png
new file mode 100644
index 0000000000000000000000000000000000000000..586451cc9a0bf9d46d4aff49e0d6ae5dc39ed953
--- /dev/null
+++ b/examples/ukbb.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:31e91635bbcf090dc10730cb99dc7c6da05e214c7f0c3546a8d65bd3e671cd50
+size 2233770
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..b426b28a2a482b3d1a40a64c62357dd7e9e0d187
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,72 @@
+absl-py==2.1.0
+altgraph==0.17.4
+astunparse==1.6.3
+cachetools==5.5.0
+certifi==2024.8.30
+charset-normalizer==3.4.0
+colorama==0.4.6
+contourpy==1.3.1
+cycler==0.12.1
+defusedxml==0.7.1
+einops==0.3.0
+filelock==3.16.1
+flatbuffers==1.12
+fonttools==4.54.1
+fpdf2==2.8.1
+fsspec==2024.10.0
+gast==0.4.0
+google-auth==2.36.0
+google-auth-oauthlib==0.4.6
+google-pasta==0.2.0
+grpcio==1.67.1
+h5py==3.12.1
+huggingface-hub==0.16.4
+idna==3.10
+imageio==2.36.0
+keras==2.9.0
+keras-preprocessing==1.1.2
+kiwisolver==1.4.7
+libclang==18.1.1
+markdown==3.7
+markupsafe==3.0.2
+matplotlib==3.8.0
+networkx==3.4.2
+numpy==1.23.5
+oauthlib==3.2.2
+opencv-python==4.10.0.84
+opt-einsum==3.4.0
+packaging==24.2
+pandas==1.5.3
+pefile==2023.2.7
+pillow==11.0.0
+protobuf==3.19.6
+pyasn1==0.6.1
+pyasn1-modules==0.4.1
+pyinstaller==6.11.1
+pyinstaller-hooks-contrib==2024.10
+pyparsing==3.2.0
+pyside2==5.15.2.1
+python-dateutil==2.9.0.post0
+pytz==2024.2
+pywavelets==1.7.0
+pywin32-ctypes==0.2.3
+pyyaml==6.0
+requests==2.32.3
+requests-oauthlib==2.0.0
+rsa==4.9
+scikit-image==0.19.3
+scipy==1.14.1
+shiboken2==5.15.2.1
+six==1.16.0
+termcolor==2.5.0
+tifffile==2024.9.20
+torch==1.13.1
+torchaudio==0.13.1
+torchvision==0.14.1
+tqdm==4.67.0
+treelib==1.7.0
+typing-extensions==4.12.2
+urllib3==2.2.3
+werkzeug==3.1.3
+wrapt==1.16.0
+gradio==4.44.1
\ No newline at end of file
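With the pinned dependencies installed, the segmentation callback can also be driven without the UI (a sketch; the paths refer to the bundled example assets, and `DRIVE_av.png` is just an arbitrary output name):

```python
from PIL import Image
from app import segment_by_out_test

img = Image.open("examples/DRIVE.tif")
result = segment_by_out_test(img, "DRIVE")  # PIL image with arteries in red and veins in blue
result.save("DRIVE_av.png")
```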