# Source: RIP-AV-su-lab / AV / config / config_train_general.py
# Uploaded by weidai00 ("Upload 72 files", commit 6c0075d, verified).
import torch
import os
# Pick the compute device: GPU 0 when CUDA is available, otherwise the CPU
# (in which case no GPU ids are listed).
use_cuda = torch.cuda.is_available()
if use_cuda:
    gpu_ids = [0]
    device = torch.device('cuda')
else:
    gpu_ids = []
    device = torch.device('cpu')
# --- Dataset selection: keep exactly one of these assignments active ---
dataset_name = 'DRIVE'  # DRIVE
# dataset_name = 'LES'  # LES
# dataset_name = 'hrf'  # HRF
dataset = dataset_name  # alias of dataset_name
max_step = 30000  # 30000 for ukbb

# Patch sizes per dataset; entry at index 2 becomes `patch_size` below.
# Keys must match `dataset_name` exactly (note the lowercase 'hrf').
_PATCH_SIZES = {
    'DRIVE': [64, 128, 256],
    'LES': [96, 384, 256],
    'hrf': [64, 384, 256],
}
if dataset_name not in _PATCH_SIZES:
    # Fail fast with a readable message; an unmatched if/elif chain would
    # instead leave patch_size_list undefined and raise a NameError later.
    raise ValueError(
        f"Unknown dataset_name {dataset_name!r}; "
        f"expected one of {sorted(_PATCH_SIZES)}"
    )
patch_size_list = _PATCH_SIZES[dataset_name]
patch_size = patch_size_list[2]

# --- Training / logging schedule ---
batch_size = 8  # default: 4
print_iter = 100  # default: 100
display_iter = 100  # default: 100
save_iter = 5000  # default: 5000
first_display_metric_iter = max_step - save_iter  # default: 25000
lr = 0.0002  # default: 0.0002 (commented-out alternative: 0.00005 for LES)
step_size = 7000  # 7000 for DRIVE
lr_decay_gamma = 0.5  # default: 0.5
use_SGD = False  # default: False
# --- Network / discriminator (D) settings ---
input_nc = 3
ndf = 32
netD_type = 'basic'
n_layers_D = 5
norm = 'instance'
no_lsgan = False
use_sigmoid = no_lsgan  # simply mirrors no_lsgan
init_type = 'normal'
init_gain = 0.02
use_noise_input_D = use_dropout_D = False
# Disabled: torch.cuda.set_device(gpu_ids[0])

# --- GAN training ---
use_GAN = True  # default: True
beta1 = 0.5  # Adam beta1

# GAN loss weights
num_classes_D = 1
lambda_GAN_D = lambda_GAN_G = 0.01
lambda_GAN_gp = 100
lambda_BCE = lambda_DICE = 5
# Presumably D's input channel count: the input image channels plus 3 more.
input_nc_D = input_nc + 3

# --- Centerness supervision ---
use_centerness = True  # default: True
lambda_centerness = 1
center_loss_type = 'centerness'
centerness_map_size = [128, 128]
# --- Pretrained weights ---
use_pretrained_G = True
use_pretrained_D = False
model_path_pretrained_G = r"../RIP/weight"
model_step_pretrained_G = 'best_drive'

# --- Strides, class count and starting step ---
stride_height = stride_width = 50
n_classes = 3
model_step = 0

# --- Optional modules / preprocessing switches ---
use_CAM = False
use_resize = False
resize_w_h = (256, 256)
use_av_cross = False
use_high_semantic = False
lambda_high = 1  # A,V,Vessel

# Global semantic branch: no warm-up steps when G starts from pretrained weights.
use_global_semantic = True
if use_pretrained_G:
    global_warmup_step = 0
else:
    global_warmup_step = 5000
# --- Backbone and training-set location ---
use_network = 'convnext_tiny'  # swin_t, convnext_tiny
dataset_path = dict(
    DRIVE='./data/AV_DRIVE/training/',
    hrf='./data/hrf/training/',
    LES='./data/LES_AV/training/',
)
trainset_path = dataset_path[dataset_name]
# Echo the selected training-set path and backbone, one item per line.
print("Dataset:", trainset_path, use_network, sep="\n")