"""
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""
import datetime
import json
import time
from pathlib import Path

import torch
import torch.nn as nn

from ..misc import dist_utils
from ._solver import BaseSolver
from .clas_engine import evaluate, train_one_epoch
class ClasSolver(BaseSolver):
    def fit(self):
        print("Start training")
        self.train()
        args = self.cfg

        n_parameters = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        print("Number of params:", n_parameters)

        output_dir = Path(args.output_dir)
        output_dir.mkdir(exist_ok=True)

        start_time = time.time()
        start_epoch = self.last_epoch + 1

        for epoch in range(start_epoch, args.epochs):
            # Reshuffle per epoch when training with a DistributedSampler.
            if dist_utils.is_dist_available_and_initialized():
                self.train_dataloader.sampler.set_epoch(epoch)

            train_stats = train_one_epoch(
                self.model,
                self.criterion,
                self.train_dataloader,
                self.optimizer,
                self.ema,
                epoch=epoch,
                device=self.device,
            )
            self.lr_scheduler.step()
            self.last_epoch += 1

            if output_dir:
                checkpoint_paths = [output_dir / "checkpoint.pth"]
                # Keep a rolling checkpoint, plus an extra snapshot every
                # args.checkpoint_freq epochs.
                if (epoch + 1) % args.checkpoint_freq == 0:
                    checkpoint_paths.append(output_dir / f"checkpoint{epoch:04}.pth")
                for checkpoint_path in checkpoint_paths:
                    dist_utils.save_on_master(self.state_dict(epoch), checkpoint_path)

            # Evaluate the EMA weights when available, otherwise the raw model.
            module = self.ema.module if self.ema else self.model
            test_stats = evaluate(module, self.criterion, self.val_dataloader, self.device)

            log_stats = {
                **{f"train_{k}": v for k, v in train_stats.items()},
                **{f"test_{k}": v for k, v in test_stats.items()},
                "epoch": epoch,
                "n_parameters": n_parameters,
            }

            # Append one JSON line per epoch, from the main process only.
            if output_dir and dist_utils.is_main_process():
                with (output_dir / "log.txt").open("a") as f:
                    f.write(json.dumps(log_stats) + "\n")

        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print("Training time {}".format(total_time_str))