File size: 3,665 Bytes
a0dc447
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62977e4
a0dc447
 
f0ab20b
 
 
a0dc447
 
 
 
f0ab20b
a0dc447
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b20df59
 
a0dc447
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import torch
import os

# ---------------------------------------------------------------------------
# Device selection
# ---------------------------------------------------------------------------
# Use the first CUDA device when one is available; otherwise run on the CPU
# with an empty GPU id list.
use_cuda = torch.cuda.is_available()
gpu_ids = [0] if use_cuda else []
device = torch.device('cuda' if use_cuda else 'cpu')

# ---------------------------------------------------------------------------
# Dataset and training schedule
# ---------------------------------------------------------------------------
# Active dataset. set_dataset() has dedicated settings for 'DRIVE', 'LES',
# 'hrf' and 'ukbb'; any other value (including the default 'all') selects the
# combined-model settings.
dataset_name = 'all'
dataset = dataset_name  # same value under a second name; set_dataset() updates both
max_step = 30000  # 30000 for ukbb

# NOTE: the "default: ..." remarks below are carried over from the original
# file; they appear to record reference values, which can differ from the
# value actually set (e.g. batch_size).
batch_size = 8  # default: 4
print_iter = 100  # default: 100
display_iter = 100  # default: 100
save_iter = 5000  # default: 5000
first_display_metric_iter = max_step - save_iter  # default: 25000
lr = 0.0002  # default: 0.0002 (a commented-out variant used 0.00005 for LES)
step_size = 7000  # 7000 for DRIVE
lr_decay_gamma = 0.5  # default: 0.5
use_SGD = False  # default: False

# ---------------------------------------------------------------------------
# Discriminator / GAN settings
# ---------------------------------------------------------------------------
input_nc = 3
ndf = 32
netD_type = 'basic'
n_layers_D = 5
norm = 'instance'
no_lsgan = False
init_type = 'normal'
init_gain = 0.02
use_sigmoid = no_lsgan  # always mirrors no_lsgan
use_noise_input_D = False
use_dropout_D = False
use_GAN = True  # default: True

# adam
beta1 = 0.5

# settings for GAN loss
num_classes_D = 1
lambda_GAN_D = 0.01
lambda_GAN_G = 0.01
lambda_GAN_gp = 100
lambda_BCE = 5
lambda_DICE = 5

# Discriminator input = image channels plus 3 more channels
# (presumably the 3-channel prediction map -- TODO confirm).
input_nc_D = input_nc + 3

# settings for centerness
use_centerness = True  # default: True
lambda_centerness = 1
center_loss_type = 'centerness'
centerness_map_size = [128, 128]

# ---------------------------------------------------------------------------
# Pretrained model and patch settings
# ---------------------------------------------------------------------------
use_pretrained_G = True
use_pretrained_D = False

# Placeholders: set_dataset() replaces every value below with the settings
# for the chosen dataset.
model_path_pretrained_G = ''
model_step_pretrained_G = 0
stride_height = 0
stride_width = 0
patch_size_list = []
# set_dataset() also assigns patch_size; give it a placeholder like its
# siblings so reading it before the first set_dataset() call does not raise.
patch_size = 0
use_CAM = False

# use resize
use_resize = False
resize_w_h = (1920, 512)
def set_dataset(name):
    """Switch the module-level configuration to the settings for *name*.

    Rebinds these module globals: ``dataset_name``, ``dataset``,
    ``model_path_pretrained_G``, ``model_step_pretrained_G``,
    ``patch_size_list``, ``patch_size`` (the third entry of
    ``patch_size_list``), ``stride_height``, ``stride_width``, ``use_CAM``
    and ``use_resize``.

    ``use_CAM`` and ``use_resize`` are assigned for every dataset, so the
    resulting configuration does not depend on which dataset was selected
    by an earlier call.

    Args:
        name: Dataset identifier. 'DRIVE', 'LES', 'hrf' and 'ukbb' have
            dedicated settings; every other value selects the combined
            ("ALL") settings.
    """
    global dataset_name, dataset
    global model_path_pretrained_G, model_step_pretrained_G
    global patch_size_list, patch_size
    global stride_height, stride_width, use_CAM, use_resize

    # One row per dataset. The table is rebuilt on every call so that
    # patch_size_list is always a fresh list object.
    presets = {
        'DRIVE': {
            'path': './AV/log/DRIVE-2023_10_20_08_36_50(6500)',
            'step': 6500,
            'patches': [64, 128, 256],
            'stride': 50,
            'cam': False,
            'resize': False,
        },
        'LES': {
            'path': './AV/log/LES-2023_09_28_14_04_06(0)',
            'step': 0,
            'patches': [96, 384, 256],
            'stride': 50,
            'cam': False,
            'resize': False,
        },
        'hrf': {
            'path': './AV/log/HRF-2023_10_19_11_07_31(1500)',
            'step': 1500,
            'patches': [64, 384, 256],
            'stride': 150,
            'cam': False,
            'resize': True,
        },
        'ukbb': {
            'path': './AV/log/UKBB-2023_11_02_23_22_07(5000)',
            'step': 5000,
            'patches': [96, 384, 256],
            'stride': 150,
            'cam': False,
            'resize': True,
        },
    }
    # Used for 'all' and for any name without a dedicated row.
    combined = {
        'path': './AV/log/ALL-2024_09_06_09_17_18(9000)',
        'step': 9000,
        'patches': [96, 384, 512],
        'stride': 150,
        'cam': True,
        'resize': True,
    }
    chosen = presets.get(name, combined)

    dataset_name = name
    dataset = name

    model_path_pretrained_G = chosen['path']
    model_step_pretrained_G = chosen['step']

    patch_size_list = chosen['patches']
    patch_size = patch_size_list[2]

    # The same stride is used in both directions.
    stride_height = chosen['stride']
    stride_width = chosen['stride']

    # Always assigned (also for DRIVE/LES) so values from a previously
    # selected dataset cannot leak into this one.
    use_CAM = chosen['cam']
    use_resize = chosen['resize']
        

n_classes = 3

model_step = 0


# Optional model features; all switched off in this configuration.
use_av_cross = False  # use av_cross

use_high_semantic = False
lambda_high = 1  # A, V, Vessel

# Use global semantics in the local branch; set to False for Hugging Face.
use_global_semantic = False

# Global warm-up runs for 5000 steps, or 0 steps when the generator is
# loaded from pretrained weights.
if use_pretrained_G:
    global_warmup_step = 0
else:
    global_warmup_step = 5000