import torch import os # Check GPU availability use_cuda = torch.cuda.is_available() gpu_ids = [0] if use_cuda else [] device = torch.device('cuda' if use_cuda else 'cpu') dataset_name = 'DRIVE' # DRIVE #dataset_name = 'LES' # LES #dataset_name = 'hrf' # HRF dataset = dataset_name max_step = 30000 # 30000 for ukbb if dataset_name=='DRIVE': patch_size_list = [64, 128, 256] elif dataset_name=='LES': patch_size_list = [96,384, 256] elif dataset_name=='hrf': patch_size_list = [64, 384, 256] patch_size = patch_size_list[2] batch_size = 8 # default: 4 print_iter = 100 # default: 100 display_iter = 100 # default: 100 save_iter = 5000 # default: 5000 first_display_metric_iter = max_step-save_iter # default: 25000 lr = 0.0002 #if dataset_name!='LES' else 0.00005 # default: 0.0002 step_size = 7000 # 7000 for DRIVE lr_decay_gamma = 0.5 # default: 0.5 use_SGD = False # default:False input_nc = 3 ndf = 32 netD_type = 'basic' n_layers_D = 5 norm = 'instance' no_lsgan = False init_type = 'normal' init_gain = 0.02 use_sigmoid = no_lsgan use_noise_input_D = False use_dropout_D = False # torch.cuda.set_device(gpu_ids[0]) use_GAN = True # default: True # adam beta1 = 0.5 # settings for GAN loss num_classes_D = 1 lambda_GAN_D = 0.01 lambda_GAN_G = 0.01 lambda_GAN_gp = 100 lambda_BCE = 5 lambda_DICE = 5 input_nc_D = input_nc + 3 # settings for centerness use_centerness =True # default: True lambda_centerness = 1 center_loss_type = 'centerness' centerness_map_size = [128,128] # pretrained model use_pretrained_G = True use_pretrained_D = False model_path_pretrained_G = r"../RIP/weight" model_step_pretrained_G = 'best_drive' # path for dataset stride_height = 50 stride_width = 50 n_classes = 3 model_step = 0 # use CAM use_CAM = False #use resize use_resize = False resize_w_h = (256,256) #use av_cross use_av_cross = False use_high_semantic = False lambda_high = 1 # A,V,Vessel # use global semantic use_global_semantic = True global_warmup_step = 0 if use_pretrained_G else 5000 # use network use_network = 'convnext_tiny' # swin_t,convnext_tiny dataset_path = {'DRIVE': './data/AV_DRIVE/training/', 'hrf': './data/hrf/training/', 'LES': './data/LES_AV/training/', } trainset_path = dataset_path[dataset_name] print("Dataset:") print(trainset_path) print(use_network)