# Spaces: Running
import os

import torch

# Check GPU availability once at import time; everything below is derived
# from this single flag so CPU-only hosts fall back cleanly.
use_cuda = torch.cuda.is_available()
# GPU indices to use: the first device when CUDA is present, otherwise none.
gpu_ids = [0] if use_cuda else []
# torch.device that tensors/models should be placed on.
device = torch.device('cuda' if use_cuda else 'cpu')
# ---------------------------------------------------------------------------
# Dataset selection
# ---------------------------------------------------------------------------
# Active dataset name. set_dataset() recognises 'DRIVE', 'LES', 'hrf' and
# 'ukbb'; any other value (such as 'all') selects the combined "ALL" preset.
dataset_name = 'all'
dataset = dataset_name

# ---------------------------------------------------------------------------
# Training schedule
# ---------------------------------------------------------------------------
max_step = 30000  # 30000 for ukbb
batch_size = 8  # default: 4
print_iter = 100  # default: 100
display_iter = 100  # default: 100
save_iter = 5000  # default: 5000
# Derived from the two values above; 25000 with the defaults.
first_display_metric_iter = max_step - save_iter
lr = 0.0002  # default: 0.0002 (a 0.00005 variant was once used for 'LES')
step_size = 7000  # 7000 for DRIVE
lr_decay_gamma = 0.5  # default: 0.5
use_SGD = False  # default: False

# ---------------------------------------------------------------------------
# Discriminator / network settings
# ---------------------------------------------------------------------------
input_nc = 3
ndf = 32
netD_type = 'basic'
n_layers_D = 5
norm = 'instance'
no_lsgan = False
init_type = 'normal'
init_gain = 0.02
# Tied to no_lsgan: the two flags always carry the same value.
use_sigmoid = no_lsgan
use_noise_input_D = False
use_dropout_D = False
# torch.cuda.set_device(gpu_ids[0])
use_GAN = True  # default: True

# adam
beta1 = 0.5

# settings for GAN loss
num_classes_D = 1
lambda_GAN_D = 0.01
lambda_GAN_G = 0.01
lambda_GAN_gp = 100
lambda_BCE = 5
lambda_DICE = 5
# Discriminator input has three more channels than the image input.
input_nc_D = input_nc + 3

# settings for centerness
use_centerness = True  # default: True
lambda_centerness = 1
center_loss_type = 'centerness'
centerness_map_size = [128, 128]

# ---------------------------------------------------------------------------
# Pretrained model
# ---------------------------------------------------------------------------
use_pretrained_G = True
use_pretrained_D = False
# The values below are placeholders; set_dataset() overwrites them with the
# checkpoint, stride and patch settings of the chosen dataset.
# model_path_pretrained_G = './log/patch_pretrain'
model_path_pretrained_G = ''
model_step_pretrained_G = 0
stride_height = 0
stride_width = 0
patch_size_list = []
use_CAM = False

# use resize
use_resize = False
resize_w_h = (1920, 512)
def set_dataset(name):
    """Switch the module-level configuration to the preset for ``name``.

    Rebinds these module globals: ``dataset_name``, ``dataset``,
    ``model_path_pretrained_G``, ``model_step_pretrained_G``,
    ``patch_size_list``, ``patch_size``, ``stride_height``, ``stride_width``,
    ``use_CAM`` and ``use_resize``.

    Args:
        name: Dataset identifier. 'DRIVE', 'LES', 'hrf' and 'ukbb' have
            dedicated presets (matching is case-sensitive); any other value
            selects the combined "ALL" preset.

    Returns:
        None. The function works purely through the globals listed above.
    """
    global dataset_name, dataset
    global model_path_pretrained_G, model_step_pretrained_G
    global patch_size_list, patch_size
    global stride_height, stride_width, use_CAM, use_resize

    # One row per dataset:
    # (checkpoint dir, checkpoint step, patch sizes, stride, use_CAM, use_resize)
    # 'DRIVE' and 'LES' carry explicit False flags (the module defaults) so
    # that an earlier call for another dataset cannot leave stale True values.
    presets = {
        'DRIVE': ('./AV/log/DRIVE-2023_10_20_08_36_50(6500)', 6500,
                  [64, 128, 256], 50, False, False),
        'LES': ('./AV/log/LES-2023_09_28_14_04_06(0)', 0,
                [96, 384, 256], 50, False, False),
        'hrf': ('./AV/log/HRF-2023_10_19_11_07_31(1500)', 1500,
                [64, 384, 256], 150, False, True),
        'ukbb': ('./AV/log/UKBB-2023_11_02_23_22_07(5000)', 5000,
                 [96, 384, 256], 150, False, True),
    }
    fallback = ('./AV/log/ALL-2024_09_06_09_17_18(9000)', 9000,
                [96, 384, 512], 150, True, True)

    # Non-string (possibly unhashable) names never match a preset; route them
    # to the fallback instead of letting the dict lookup raise TypeError.
    if isinstance(name, str):
        preset = presets.get(name, fallback)
    else:
        preset = fallback
    checkpoint_dir, checkpoint_step, patch_sizes, stride, cam, resize = preset

    dataset_name = name
    dataset = name
    model_path_pretrained_G = checkpoint_dir
    model_step_pretrained_G = checkpoint_step
    patch_size_list = patch_sizes
    # The third entry of the patch list is the one exposed as patch_size.
    patch_size = patch_size_list[2]
    # Height and width strides are always equal in every preset.
    stride_height = stride
    stride_width = stride
    use_CAM = cam
    use_resize = resize
n_classes = 3
model_step = 0

# use av_cross
use_av_cross = False
use_high_semantic = False
lambda_high = 1  # A,V,Vessel

# use global semantic in local, huggingface set false
use_global_semantic = False
# No warm-up steps when a pretrained generator is used; 5000 otherwise.
global_warmup_step = 0 if use_pretrained_G else 5000