Spaces:
Runtime error
Runtime error
import time | |
import torch | |
def str2bool(s):
    """Interpret a string as a boolean flag.

    The comparison is case-insensitive. Only the texts ``'true'`` and ``'1'``
    count as True; every other string (including ``'yes'``, the empty string,
    or text with surrounding whitespace) yields False.

    Args:
        s: The string to interpret; must provide ``.lower()``.

    Returns:
        bool: True for ``'true'``/``'1'`` (any letter case), otherwise False.
    """
    lowered = s.lower()
    return lowered == 'true' or lowered == '1'
class Timer:
    """Stopwatch that tracks several independently named intervals.

    Call ``start(key)`` to begin timing under a name and ``end(key)`` to get
    the elapsed seconds and forget that name again.
    """

    def __init__(self):
        # Maps each running key to the clock reading taken when it started.
        self.clock = {}

    def start(self, key="default"):
        """Begin (or restart) timing for ``key``.

        Starting a key that is already running overwrites its start reading.
        """
        # perf_counter is monotonic, so the measured interval cannot be
        # skewed by system clock adjustments the way time.time() can.
        self.clock[key] = time.perf_counter()

    def end(self, key="default"):
        """Stop timing ``key`` and return the elapsed time.

        Args:
            key: Name previously passed to ``start``.

        Returns:
            float: Seconds elapsed since the matching ``start`` call.

        Raises:
            Exception: If ``key`` was never started or has already been ended.
        """
        now = time.perf_counter()
        try:
            # pop both reads and removes the entry, so a key can only be
            # ended once per start.
            started = self.clock.pop(key)
        except KeyError:
            raise Exception(f"{key} is not in the clock.") from None
        return now - started
def save_checkpoint(epoch, net_state_dict, optimizer_state_dict, best_score, checkpoint_path, model_path):
    """Write a training checkpoint and a separate copy of the model weights.

    Two files are produced with ``torch.save``:

    * ``checkpoint_path`` receives a dict with the keys ``'epoch'``,
      ``'model'``, ``'optimizer'`` and ``'best_score'``.
    * ``model_path`` receives ``net_state_dict`` on its own.

    Args:
        epoch: Value stored under ``'epoch'``.
        net_state_dict: Value stored under ``'model'`` and also saved alone.
        optimizer_state_dict: Value stored under ``'optimizer'``.
        best_score: Value stored under ``'best_score'``.
        checkpoint_path: Destination of the full checkpoint dict.
        model_path: Destination of the weights-only file.
    """
    checkpoint = dict(
        epoch=epoch,
        model=net_state_dict,
        optimizer=optimizer_state_dict,
        best_score=best_score,
    )
    # Full checkpoint first, then the stand-alone weights file.
    torch.save(checkpoint, checkpoint_path)
    torch.save(net_state_dict, model_path)
def load_checkpoint(checkpoint_path, map_location=None):
    """Load an object previously written with ``torch.save``.

    Args:
        checkpoint_path: Path of the file to load.
        map_location: Optional device remapping forwarded to ``torch.load``
            (for example ``"cpu"`` to load a checkpoint saved from GPU
            tensors on a machine without a GPU). The default ``None`` keeps
            tensors on the devices they were saved from.

    Returns:
        Whatever object the file contains.
    """
    # NOTE(security): torch.load relies on pickle; only load checkpoint files
    # that come from a trusted source.
    return torch.load(checkpoint_path, map_location=map_location)
def freeze_net_layers(net):
    """Turn off gradient tracking for every parameter of ``net``.

    Sets ``requires_grad`` to False on each item yielded by
    ``net.parameters()``. The change is made in place and nothing is
    returned.

    Args:
        net: Object exposing a ``parameters()`` iterable.
    """
    for weight in net.parameters():
        weight.requires_grad = False
def store_labels(path, labels):
    """Write labels to a text file, one per line.

    The labels are joined with newline characters; no trailing newline is
    appended after the last label. An existing file at ``path`` is
    overwritten.

    Args:
        path: Destination file path.
        labels: Iterable of label strings.
    """
    # Pin the encoding so non-ASCII labels are written the same way on every
    # platform instead of depending on the locale's default encoding.
    with open(path, "w", encoding="utf-8") as f:
        f.write("\n".join(labels))