# Source: RIP-AV-su-lab / AV / Tools / centerline_evaluation.py
# Author: weidai00
# Commit: 6c0075d ("Upload 72 files", verified)
import numpy as np
import cv2
import os
import natsort
from Tools.Hemelings_eval import evaluation_code
import pandas as pd
def _list_files(path):
    """Return the names of the files directly inside directory *path*.

    The listing is non-recursive (only the first ``os.walk`` entry is used).

    Raises:
        FileNotFoundError: If *path* yields nothing from ``os.walk``, which
            happens silently for a missing or unreadable directory.
    """
    for _dirpath, _dirnames, filenames in os.walk(path):
        return filenames
    raise FileNotFoundError(f"Not a readable directory: {path}")


def getFolds(ImgPath, LabelPath, k_fold_idx, k_fold, trainset=True):
    """Select the image/label file names covered by ``k_fold`` equal folds.

    Both directories are listed non-recursively and natural-sorted. The i-th
    image is therefore paired with the i-th label purely by sort order.

    NOTE(review): the per-fold selection on ``k_fold_idx`` is disabled in
    this version (``k_fold_idx`` and ``trainset`` are only logged). Both the
    train and test splits therefore contain every file from all complete
    folds. The trailing ``len(images) % k_fold`` files are dropped.

    Args:
        ImgPath: Directory containing the images.
        LabelPath: Directory containing the label images.
        k_fold_idx: Fold index (currently only printed; see note above).
        k_fold: Number of folds. When ``<= 0``, no files are selected.
        trainset: Which split is requested (currently only printed).

    Returns:
        A ``(image_names, label_names)`` tuple of file-name lists.

    Raises:
        FileNotFoundError: If either path is not a readable directory.
    """
    print(f'ImgPath: {ImgPath}')
    print(f'LabelPath: {LabelPath}')
    print(f'k_fold_idx: {k_fold_idx}')
    print(f'k_fold: {k_fold}')
    print(f'trainset: {trainset}')
    ImgDirAll = _list_files(ImgPath)
    LabelDirAll = _list_files(LabelPath)
    if k_fold <= 0:
        return [], []
    ImgDirAll = natsort.natsorted(ImgDirAll)
    LabelDirAll = natsort.natsorted(LabelDirAll)
    # Only complete folds are kept: concatenating the k_fold slices of size
    # len // k_fold is the same as taking this prefix.
    n_used = (len(ImgDirAll) // k_fold) * k_fold
    # With the `if i == k_fold_idx` selection disabled, the train and test
    # splits are identical; both are (images, labels).
    return ImgDirAll[:n_used], LabelDirAll[:n_used]
def _assign_av_labels(ArteryPred, VeinPred, vessel_mask):
    """Return an (h, w, 3) uint8 RGB map labelling the ``vessel_mask`` pixels.

    Masked pixels where ``ArteryPred >= VeinPred`` become red (255, 0, 0).
    The remaining masked pixels become blue (0, 0, 255). All other pixels
    stay 0. Ties therefore go to artery, and NaN comparisons go to vein.

    Raises:
        ValueError: If the mask and prediction shapes differ.
    """
    if vessel_mask.shape != ArteryPred.shape:
        raise ValueError(
            f"Ground-truth shape {vessel_mask.shape} does not match "
            f"prediction shape {ArteryPred.shape}")
    AVSeg = np.zeros(ArteryPred.shape + (3,), np.uint8)
    is_artery = ArteryPred >= VeinPred
    AVSeg[vessel_mask & is_artery] = (255, 0, 0)
    AVSeg[vessel_mask & ~is_artery] = (0, 0, 255)
    return AVSeg


def _format_centerline_report(overall_value):
    """Print the averaged metrics line by line and return them as one string.

    ``overall_value`` holds five entries:
    - three ``[acc, f1]`` pairs;
    - one ``[acc, f1, sens, spec]`` list;
    - one scalar ratio.
    """
    metrics_names = ['full image', 'discovered centerline pixels', 'vessels wider than two pixels', 'all centerline', 'vessel detection rate']
    lines = ["--------------------------Centerline---------------------------------"]
    for j, (name, value) in enumerate(zip(metrics_names, overall_value)):
        if j == 4:
            lines.append("{} - Ratio:{}".format(name, value))
        elif j == 3:
            lines.append("{} - Acc: {} , F1:{}, Sens:{}, Spec:{}".format(name, *value))
        else:
            lines.append("{} - Acc: {} , F1:{}".format(name, *value))
    for line in lines:
        print(line)
    return "".join(line + "\n" for line in lines)


def centerline_eval(ProMap, config):
    """Score artery/vein classification on the ground-truth vessel pixels.

    For image ``i``, every pixel with a non-zero red or blue channel in the
    ground-truth label is assigned to artery or vein by comparing
    ``ProMap[i, 0]`` (artery score) with ``ProMap[i, 1]`` (vein score).
    The resulting RGB map is scored with ``evaluation_code``, and the
    per-image metrics are averaged over all images.

    Args:
        ProMap: Indexable as ``ProMap[i, c, :, :]`` with ``ProMap.shape[0]``
            images. Presumably an (N, C, H, W) array or tensor; TODO confirm.
        config: Object with a ``dataset_name`` attribute. Accepted values are
            'hrf', 'DRIVE', or anything else for the INSPIRE layout. For
            'hrf' it must also provide ``trainset_path``, ``k_fold_idx`` and
            ``k_fold``.

    Returns:
        The multi-line text report, which is also printed to stdout.

    Raises:
        ValueError: If ``ProMap`` holds no images, or a label's shape differs
            from the prediction's.
        FileNotFoundError: If a ground-truth label cannot be read.
    """
    img_num = ProMap.shape[0]
    if img_num == 0:
        raise ValueError("ProMap contains no images to evaluate")

    if config.dataset_name == 'hrf':
        ImgPath = os.path.join(config.trainset_path, 'test', 'images')
        LabelPath = os.path.join(config.trainset_path, 'test', 'ArteryVein_0410_final')
    elif config.dataset_name == 'DRIVE':
        LabelPath = os.path.join('./data/AV_DRIVE/test', 'av')
    else:
        # Any other dataset name falls back to the INSPIRE layout.
        LabelPath = os.path.join('./data/INSPIRE_AV', 'label')
        # Optic-disc table; row i presumably describes image i + 1 (TODO confirm).
        DF_disc = pd.read_excel('./Tools/DiskParameters_INSPIRE_resize.xls', sheet_name=0)
    if config.dataset_name == 'hrf':
        # Image file names are joined with LabelPath below, so labels are
        # presumably named like their images (TODO confirm).
        ImgList0, _ = getFolds(ImgPath, LabelPath, config.k_fold_idx, config.k_fold, trainset=False)

    # Running sums: three [acc, f1] pairs, one [acc, f1, sens, spec], one ratio.
    totals = [[0, 0], [0, 0], [0, 0], [0, 0, 0, 0], 0]
    for i in range(img_num):
        if config.dataset_name == 'hrf':
            imgName = ImgList0[i]
        elif config.dataset_name == 'DRIVE':
            imgName = str(i + 1).zfill(2) + '_test.png'
        else:
            imgName = 'image' + str(i + 1) + '_ManualAV.png'
        gt_path = os.path.join(LabelPath, imgName)
        print(gt_path)
        gt = cv2.imread(gt_path)
        if gt is None:  # cv2.imread reports an unreadable file by returning None
            raise FileNotFoundError(f"Could not read ground-truth image: {gt_path}")
        gt = cv2.cvtColor(gt, cv2.COLOR_BGR2RGB)
        if config.dataset_name == 'hrf':
            # Resized to width 1200 x height 800, presumably the HRF
            # prediction resolution (TODO confirm).
            gt = cv2.resize(gt, (1200, 800))
        # Test each channel separately. Summing the uint8 channels wraps
        # modulo 256 and drops pixels whose R + B == 256 (e.g. the
        # interpolated edge pixels produced by cv2.resize above).
        vessel_mask = (gt[:, :, 0] > 0) | (gt[:, :, 2] > 0)
        AVSeg1 = _assign_av_labels(
            np.float32(ProMap[i, 0, :, :]),
            np.float32(ProMap[i, 1, :, :]),
            vessel_mask,
        )
        if config.dataset_name == 'INSPIRE':
            row = int(round(DF_disc.loc[i, 'DiskCenterRow']))
            col = int(round(DF_disc.loc[i, 'DiskCenterCol']))
            radius = int(round(DF_disc.loc[i, 'DiskRadius']))
            # Zero out a filled disc at the optic-disc centre; presumably
            # evaluation_code ignores mask == 0 pixels (TODO confirm).
            MaskDisc = np.ones(vessel_mask.shape, np.uint8)
            cv2.circle(MaskDisc, center=(col, row), radius=radius, color=0, thickness=-1)
            out = evaluation_code(AVSeg1, gt, mask=MaskDisc, use_mask=True)
        else:
            out = evaluation_code(AVSeg1, gt)
        for j, metric in enumerate(out):
            if j == 4:
                totals[j] += metric
            else:
                for m in range(len(totals[j])):
                    totals[j][m] += metric[m]

    overall_value = [
        total / img_num if j == 4 else [v / img_num for v in total]
        for j, total in enumerate(totals)
    ]
    return _format_centerline_report(overall_value)