import argparse

import numpy as np
import torch

from models.models import MedCapModel
from modules.dataloader import R2DataLoader
from modules.loss import compute_loss
from modules.metrics import compute_scores
from modules.optimizers import build_optimizer, build_lr_scheduler
from modules.tokenizers import Tokenizer
from modules.trainer import Trainer
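
# Entry point for training MedCap: parse the command-line configuration, seed
# the RNGs, build the tokenizer, data loaders, model, optimizer and scheduler,
# then hand everything to the Trainer.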
def main():
    parser = argparse.ArgumentParser()

    # Data input settings
    parser.add_argument('--json_path', default='data/mimic_cxr/annotation.json',
                        help='the path to the annotation json file.')
    parser.add_argument('--image_dir', default='data/mimic_cxr/images/',
                        help='the directory of the images.')

    # Dataloader settings
    parser.add_argument('--dataset', default='mimic_cxr', help='the dataset for training MedCap.')
    parser.add_argument('--bs', type=int, default=16, help='the batch size for training.')
    parser.add_argument('--threshold', type=int, default=10, help='the cut-off frequency for the words.')
    parser.add_argument('--num_workers', type=int, default=2, help='the number of workers for the dataloader.')
    parser.add_argument('--max_seq_length', type=int, default=1024, help='the maximum sequence length of the reports.')

    # Trainer settings
    parser.add_argument('--epochs', type=int, default=30, help='the number of training epochs.')
    parser.add_argument('--n_gpu', type=int, default=1, help='the number of gpus to be used.')
    parser.add_argument('--save_dir', type=str, default='results/mimic_cxr/', help='the path to save the models.')
    parser.add_argument('--record_dir', type=str, default='./record_dir/',
                        help='the path to save the results of experiments.')
    parser.add_argument('--log_period', type=int, default=1000, help='the logging interval (in batches).')
    parser.add_argument('--save_period', type=int, default=1, help='the checkpoint saving interval (in epochs).')
    parser.add_argument('--monitor_mode', type=str, default='max', choices=['min', 'max'],
                        help='whether to maximize or minimize the monitored metric.')
    parser.add_argument('--monitor_metric', type=str, default='BLEU_4', help='the metric to be monitored.')
    parser.add_argument('--early_stop', type=int, default=50, help='the patience for early stopping.')

    # Training related
    parser.add_argument('--noise_inject', default='no', choices=['yes', 'no'],
                        help='whether to inject noise during training.')

    # Sample related
    parser.add_argument('--sample_method', type=str, default='greedy', help='the sampling method used to generate a report.')
    parser.add_argument('--prompt', default='/prompt/prompt.pt', help='the path to the prompt file.')
    parser.add_argument('--prompt_load', default='no', choices=['yes', 'no'],
                        help='whether to load the prompt from --prompt.')

    # Optimization
    parser.add_argument('--optim', type=str, default='Adam', help='the type of the optimizer.')
    parser.add_argument('--lr_ve', type=float, default=1e-5, help='the learning rate for the visual extractor.')
    parser.add_argument('--lr_ed', type=float, default=5e-4, help='the learning rate for the remaining parameters.')
    parser.add_argument('--weight_decay', type=float, default=5e-5, help='the weight decay.')
    # argparse's type=tuple would split the argument string into single
    # characters, so parse a comma-separated string such as "0.9,0.98" instead.
    parser.add_argument('--adam_betas', type=lambda s: tuple(float(x) for x in s.split(',')),
                        default=(0.9, 0.98), help='the betas for the Adam optimizer.')
    parser.add_argument('--adam_eps', type=float, default=1e-9, help='the epsilon for the Adam optimizer.')
    # type=bool is a known argparse pitfall (bool('False') is True), so parse
    # the flag from a true/false-style string instead.
    parser.add_argument('--amsgrad', type=lambda s: s.lower() in ('true', '1', 'yes'), default=True,
                        help='whether to use AMSGrad with Adam.')
    parser.add_argument('--noamopt_warmup', type=int, default=5000, help='the number of warmup steps for the Noam schedule.')
    parser.add_argument('--noamopt_factor', type=int, default=1, help='the scaling factor for the Noam schedule.')

    # Learning rate scheduler
    parser.add_argument('--lr_scheduler', type=str, default='StepLR', help='the type of the learning rate scheduler.')
    parser.add_argument('--step_size', type=int, default=50, help='the step size of the learning rate scheduler.')
    parser.add_argument('--gamma', type=float, default=0.1, help='the gamma of the learning rate scheduler.')

    # Others
    parser.add_argument('--seed', type=int, default=9153, help='the random seed.')
    parser.add_argument('--resume', type=str, help='the checkpoint path to resume training from.')
    parser.add_argument('--train_mode', default='base', choices=['base', 'fine-tuning'],
                        help='training mode: base (autoencoding) or fine-tuning (full supervised training or fine-tuning on downstream datasets).')
    parser.add_argument('--F_version', default='v1', choices=['v1', 'v2'])
    parser.add_argument('--clip_update', default='no', choices=['yes', 'no'])

    # Fine-tuning
    parser.add_argument('--random_init', default='yes', choices=['yes', 'no'],
                        help='whether to randomly initialize the model instead of loading pre-trained weights for fine-tuning.')
    parser.add_argument('--weight_path', default='path_to_default_weights', type=str,
                        help='the path to the pre-trained model weights.')

    args = parser.parse_args()
    # fix random seeds for reproducibility (seed the CUDA RNG as well, so that
    # GPU-side operations such as dropout are covered)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(args.seed)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # create tokenizer
    tokenizer = Tokenizer(args)

    # create data loaders for each split
    train_dataloader = R2DataLoader(args, tokenizer, split='train', shuffle=True)
    val_dataloader = R2DataLoader(args, tokenizer, split='val', shuffle=False)
    test_dataloader = R2DataLoader(args, tokenizer, split='test', shuffle=False)

    # get function handles of loss and metrics
    criterion = compute_loss
    metrics = compute_scores

    # build the model
    model = MedCapModel(args, tokenizer)
    if args.train_mode == 'fine-tuning' and args.random_init == 'no':
        # load pre-trained weights from the specified path onto the current device
        checkpoint = torch.load(args.weight_path, map_location=device)
        model.load_state_dict(checkpoint)
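        # NOTE: strict key matching is assumed here; a checkpoint that wraps the
        # weights (e.g. under a 'state_dict' key) would need unwrapping first.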
    # build optimizer and learning rate scheduler
    optimizer = build_optimizer(args, model)
    lr_scheduler = build_lr_scheduler(args, optimizer)

    # build trainer and start training
    trainer = Trainer(model, criterion, metrics, optimizer, args, lr_scheduler,
                      train_dataloader, val_dataloader, test_dataloader)
    trainer.train()
if __name__ == '__main__':
    main()
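
# Example usage (a minimal sketch; the script name "main.py" and the checkpoint
# filename are assumptions, adjust them to the actual project layout):
#   python main.py --dataset mimic_cxr --bs 16 --epochs 30
#   python main.py --train_mode fine-tuning --random_init no \
#       --weight_path results/mimic_cxr/model_best.pth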