# RIP-AV-su-lab: AV/config/config_test_general.py
# Author: weidai00
# Last commit: b20df59 (verified) "Update AV/config/config_test_general.py"
import torch
import os
# ---------------------------------------------------------------------------
# Device selection: use GPU 0 when CUDA is available, otherwise the CPU.
# ---------------------------------------------------------------------------
use_cuda = torch.cuda.is_available()
gpu_ids = [0] if use_cuda else []
device = torch.device('cuda' if use_cuda else 'cpu')

# ---------------------------------------------------------------------------
# Dataset selection. Values recognized by set_dataset(): 'DRIVE', 'LES',
# 'hrf', 'ukbb', 'all'. Any other value falls into set_dataset()'s final
# `else` branch, the same one 'all' uses.
# ---------------------------------------------------------------------------
dataset_name = 'all'
dataset = dataset_name

# ---------------------------------------------------------------------------
# Optimisation / logging schedule.
# ---------------------------------------------------------------------------
max_step = 30000  # 30000 for ukbb
batch_size = 8  # default: 4
print_iter = 100  # default: 100
display_iter = 100  # default: 100
save_iter = 5000  # default: 5000
# Metrics are first displayed one save interval before the final step.
first_display_metric_iter = max_step - save_iter  # default: 25000
lr = 0.0002  # default: 0.0002 (an earlier variant used 0.00005 for LES)
step_size = 7000  # 7000 for DRIVE
lr_decay_gamma = 0.5  # default: 0.5
use_SGD = False  # default: False
beta1 = 0.5  # Adam beta1

# ---------------------------------------------------------------------------
# Network / discriminator settings.
# ---------------------------------------------------------------------------
input_nc = 3
ndf = 32
netD_type = 'basic'
n_layers_D = 5
norm = 'instance'
no_lsgan = False
init_type = 'normal'
init_gain = 0.02
use_sigmoid = no_lsgan  # mirrors no_lsgan
use_noise_input_D = False
use_dropout_D = False
use_GAN = True  # default: True

# ---------------------------------------------------------------------------
# GAN loss settings.
# ---------------------------------------------------------------------------
num_classes_D = 1
lambda_GAN_D = 0.01
lambda_GAN_G = 0.01
lambda_GAN_gp = 100
lambda_BCE = 5
lambda_DICE = 5
# Discriminator input = image channels + 3 extra channels
# (presumably the 3-class A/V/vessel map -- TODO confirm against the model).
input_nc_D = input_nc + 3

# ---------------------------------------------------------------------------
# Centerness settings.
# ---------------------------------------------------------------------------
use_centerness = True  # default: True
lambda_centerness = 1
center_loss_type = 'centerness'
centerness_map_size = [128, 128]

# ---------------------------------------------------------------------------
# Pretrained model. The path and step are placeholders that set_dataset()
# assigns for the chosen dataset.
# ---------------------------------------------------------------------------
use_pretrained_G = True
use_pretrained_D = False
model_path_pretrained_G = ''
model_step_pretrained_G = 0

# ---------------------------------------------------------------------------
# Patch / stride placeholders, assigned by set_dataset(). patch_size is
# defined here too, so the attribute exists before set_dataset() is called
# (set_dataset declares it global but nothing else defines it).
# ---------------------------------------------------------------------------
stride_height = 0
stride_width = 0
patch_size_list = []
patch_size = 0

# CAM / resize switches. set_dataset() may override use_CAM and use_resize;
# it does not modify resize_w_h.
use_CAM = False
use_resize = False
resize_w_h = (1920, 512)
def set_dataset(name):
    """Configure all dataset-dependent settings by rebinding module globals.

    Sets ``dataset_name`` / ``dataset``, the pretrained generator checkpoint
    (``model_path_pretrained_G`` / ``model_step_pretrained_G``),
    ``patch_size_list`` and ``patch_size`` (its third entry), the sliding
    window strides, and the ``use_CAM`` / ``use_resize`` switches.

    Recognized names are 'DRIVE', 'LES', 'hrf' and 'ukbb'. Any other name,
    including 'all', gets the combined-dataset preset.

    ``use_CAM`` and ``use_resize`` are assigned in every branch. As a result,
    calling this function again for a different dataset never inherits
    switches left over from an earlier call.

    Args:
        name: dataset identifier (case-sensitive).
    """
    global dataset_name, model_path_pretrained_G, model_step_pretrained_G
    global stride_height, stride_width, patch_size, patch_size_list, dataset, use_CAM, use_resize, resize_w_h
    dataset_name = name
    dataset = name

    # dataset -> (pretrained G checkpoint dir, checkpoint step, patch sizes)
    presets = {
        'DRIVE': ('./AV/log/DRIVE-2023_10_20_08_36_50(6500)', 6500, [64, 128, 256]),
        'LES': ('./AV/log/LES-2023_09_28_14_04_06(0)', 0, [96, 384, 256]),
        'hrf': ('./AV/log/HRF-2023_10_19_11_07_31(1500)', 1500, [64, 384, 256]),
        'ukbb': ('./AV/log/UKBB-2023_11_02_23_22_07(5000)', 5000, [96, 384, 256]),
    }
    fallback = ('./AV/log/ALL-2024_09_06_09_17_18(9000)', 9000, [96, 384, 512])
    model_path_pretrained_G, model_step_pretrained_G, patch_size_list = presets.get(
        dataset_name, fallback
    )
    patch_size = patch_size_list[2]

    # Sliding-window strides and CAM / resize switches. Both switches are set
    # explicitly in every branch so that state cannot leak between calls.
    if dataset_name in ('DRIVE', 'LES'):
        use_CAM = False
        use_resize = False
        stride_height = stride_width = 50
    elif dataset_name in ('ukbb', 'hrf'):
        use_CAM = False
        use_resize = True
        stride_height = stride_width = 150
    else:
        use_CAM = True
        use_resize = True
        stride_height = stride_width = 150
# Output classes and initial step counter.
n_classes = 3
model_step = 0

# A/V cross-branch and high-level semantic options (all disabled here).
use_av_cross = False
use_high_semantic = False
lambda_high = 1  # A,V,Vessel

# Global semantic features in the local branch; per the original note, kept
# False for the Hugging Face build.
use_global_semantic = False

# No global warm-up is needed when a pretrained generator is loaded.
if use_pretrained_G:
    global_warmup_step = 0
else:
    global_warmup_step = 5000