File size: 5,630 Bytes
6c0075d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import numpy as np
from skimage.morphology import skeletonize, erosion
from sklearn.metrics import f1_score, accuracy_score,classification_report,confusion_matrix
import cv2
from Tools.BinaryPostProcessing import binaryPostProcessing3

def _accuracy_and_f1(gt_labels, pred_labels):
    """Return ``[accuracy, weighted F1]`` for two equally sized label arrays."""
    gt_flat = gt_labels.ravel()
    pred_flat = pred_labels.ravel()
    return [accuracy_score(gt_flat, pred_flat),
            f1_score(gt_flat, pred_flat, average='weighted')]


def evaluation_code(prediction, groundtruth, mask=None, use_mask=False):
    '''
    Evaluate an artery/vein (AV) prediction against a ground-truth image.

    Both images are encoded into label maps: 0 = background/ignored,
    1 = arteriole, 2 = venule. In the ground truth a pixel is an arteriole
    when R == 255 and G == 0 and a venule when B == 255 and G == 0; in the
    prediction only R == 255 / B == 255 is checked. When both rules match,
    the venule label wins because it is written last.

    Parameters
    - prediction: image array of [dim1, dim2, img_channels = 3] with
      arteries in red and veins in blue.
    - groundtruth: same layout as ``prediction``. Pure white pixels
      (255, 255, 255) are treated as green (0, 255, 0), i.e. ignored.
      The caller's array is not modified.
    - mask: forwarded to ``cv2.bitwise_and(..., mask=mask)`` when
      ``use_mask`` is true; pixels where it is zero are cleared in the
      images and in both label maps.
      NOTE(review): OpenCV presumably expects an 8-bit single-channel
      array of the image size here - confirm against the callers.
    - use_mask: apply ``mask`` when true, otherwise ``mask`` is ignored.

    Returns
    ``[metrics1, metrics2, metrics3, metrics4, detection_rate]`` where
    - metrics1 = [accuracy, weighted F1] over the full image,
    - metrics2 = same, over ground-truth centerline pixels that the
      prediction labels as vessel,
    - metrics3 = same as metrics2 but with the centerline taken from the
      eroded ground truth (per the original note: vessels wider than two
      pixels, for DRIVE),
    - metrics4 = [accuracy, weighted F1, sens, spec] over all ground-truth
      centerline pixels (see the inline comment for how sens/spec are
      defined),
    - detection_rate = accuracy restricted to ground-truth vessel pixels.
    '''
    # Convert white pixels to green pixels (which are ignored).
    # The three channel tests are combined explicitly: np.logical_and only
    # takes two operands, a third positional argument is its ``out`` buffer.
    white = ((groundtruth[:, :, 0] == 255)
             & (groundtruth[:, :, 1] == 255)
             & (groundtruth[:, :, 2] == 255))
    if white.any():
        # Work on a copy so the caller's ground-truth image stays untouched.
        groundtruth = groundtruth.copy()
        groundtruth[white] = [0, 255, 0]

    # Translate the images to label arrays suited for sklearn metrics.
    encoded_pred = np.zeros(prediction.shape[:2], dtype=int)
    encoded_gt = np.zeros(groundtruth.shape[:2], dtype=int)

    gt_green_free = groundtruth[:, :, 1] == 0
    encoded_gt[np.logical_and(groundtruth[:, :, 0] == 255, gt_green_free)] = 1
    encoded_gt[np.logical_and(groundtruth[:, :, 2] == 255, gt_green_free)] = 2
    encoded_pred[prediction[:, :, 0] == 255] = 1
    encoded_pred[prediction[:, :, 2] == 255] = 2

    # Optional region restriction (original note: "disk and cup").
    if use_mask:
        groundtruth = cv2.bitwise_and(groundtruth, groundtruth, mask=mask)
        prediction = cv2.bitwise_and(prediction, prediction, mask=mask)

        encoded_pred = cv2.bitwise_and(encoded_pred, encoded_pred, mask=mask)
        encoded_gt = cv2.bitwise_and(encoded_gt, encoded_gt, mask=mask)

    gt_red = groundtruth[:, :, 0] > 0
    gt_blue = groundtruth[:, :, 2] > 0
    pred_is_vessel = encoded_pred > 0

    # Centerline of the ground truth (computed once and reused below) and
    # of the eroded ground truth.
    gt_centerline = np.logical_or(skeletonize(gt_red), skeletonize(gt_blue))
    gt_centerline_eroded = np.logical_or(skeletonize(erosion(gt_red)),
                                         skeletonize(erosion(gt_blue)))

    # Centerline pixels that are present in the prediction.
    discovered = np.logical_and(gt_centerline, pred_is_vessel)
    discovered_eroded = np.logical_and(gt_centerline_eroded, pred_is_vessel)

    # metrics over full image
    metrics1 = _accuracy_and_f1(encoded_gt, encoded_pred)

    # metrics over discovered centerline pixels
    metrics2 = _accuracy_and_f1(encoded_gt[discovered],
                                encoded_pred[discovered])

    # metrics over discovered centerline pixels of the eroded ground truth
    metrics3 = _accuracy_and_f1(encoded_gt[discovered_eroded],
                                encoded_pred[discovered_eroded])

    # metrics over all centerline pixels in ground truth
    gt_on_centerline = encoded_gt[gt_centerline]
    pred_on_centerline = encoded_pred[gt_centerline]
    cur4_acc, cur4_F1 = _accuracy_and_f1(gt_on_centerline, pred_on_centerline)

    # Fix the label set so the matrix is always 3x3 with rows = ground
    # truth and columns = prediction, regardless of which labels occur.
    out = confusion_matrix(gt_on_centerline, pred_on_centerline,
                           labels=[0, 1, 2])
    artery_as_artery = out[1, 1]
    artery_as_vein = out[1, 2]
    vein_as_artery = out[2, 1]
    vein_as_vein = out[2, 2]

    # sens: share of ground-truth arteriole centerline pixels labelled
    #       arteriole, among those the prediction labelled as any vessel.
    # spec: the same for venules.
    # NOTE(review): the original comments named these the other way round
    # (tn/fp/fn/tp with veins as the positive class); the returned numbers
    # are unchanged. An empty class yields nan (0/0 on numpy integers).
    sens = artery_as_artery / (artery_as_vein + artery_as_artery)
    spec = vein_as_vein / (vein_as_vein + vein_as_artery)

    metrics4 = [cur4_acc, cur4_F1, sens, spec]

    # finally, compute vessel detection rate
    gt_is_vessel = encoded_gt > 0
    detection_rate = accuracy_score(encoded_gt[gt_is_vessel],
                                    encoded_pred[gt_is_vessel])

    return [metrics1, metrics2, metrics3, metrics4, detection_rate]