# Source code for tinder.saver

import os
import torch
import urllib
import time
import sys
from types import SimpleNamespace


def _reporthook(count, block_size, total_size):
    """Progress callback for ``urllib.request.urlretrieve``.

    Prints one continuously overwritten status line with the percentage done,
    megabytes received, average speed and elapsed time.

    Args:
        count (int): number of blocks transferred so far (0 on the first call).
        block_size (int): size of one block in bytes.
        total_size (int): total size in bytes; a value <= 0 means "unknown".
    """
    global start_time
    if count == 0:
        # First call: only start the clock.
        start_time = time.time()
        return
    duration = time.time() - start_time
    progress_size = int(count * block_size)
    # The first block can arrive within the clock's resolution; avoid dividing by zero.
    speed = int(progress_size / (1024 * duration)) if duration > 0 else 0
    if total_size > 0:
        # The last block is usually partial, so count * block_size can overshoot.
        percent = min(int(progress_size * 100 / total_size), 100)
    else:
        # urlretrieve passes -1 when the server sends no Content-Length.
        percent = 0
    sys.stdout.write(
        "\r...%d%%, %d MB, %d KB/s, %d seconds passed"
        % (percent, progress_size / (1024 * 1024), speed, duration)
    )
    sys.stdout.flush()


def assert_download(weight_url, weight_dest):
    """Ensure ``weight_dest`` exists, downloading it from ``weight_url`` if needed.

    The download goes to ``weight_dest + ".part"`` first and is renamed into
    place only when it completes. An interrupted download therefore never
    leaves a partial file that a later call would mistake for a finished one.

    Args:
        weight_url (str): URL to fetch the file from; may be empty/None when
            the file is expected to be present already.
        weight_dest (str): local path of the file.

    Raises:
        NotImplementedError: if ``weight_dest`` is missing and no URL is given.
    """
    # ``import urllib`` alone does not load the ``request`` submodule.
    import urllib.request

    if os.path.exists(weight_dest):
        return
    if not weight_url:
        raise NotImplementedError("please specify url to download in your model")

    print("downloading weight:")
    print("    " + weight_url)
    print("    " + weight_dest)
    os.makedirs(os.path.dirname(weight_dest) or ".", exist_ok=True)
    tmp_dest = weight_dest + ".part"
    try:
        urllib.request.urlretrieve(weight_url, tmp_dest, reporthook=_reporthook)
        os.replace(tmp_dest, weight_dest)
    finally:
        # On failure/interruption, drop the partial file; on success it was renamed away.
        if os.path.exists(tmp_dest):
            os.remove(tmp_dest)
    print()  # terminate the "\r" progress line


class Saver(object):
    """A helper class to save and load your model.

    Checkpoints are written to ``<weight_dir>/<exp_name>/epoch_XXXX.pth``. The
    best epoch and its score are kept in ``<weight_dir>/<exp_name>/best_epoch``.

    Example::

        saver = Saver('/data/weights/', 'resnet152-cosine')
        model = {'net': alexnet, 'opt': opt}
        saver.load_latest(model)  # resume from the latest checkpoint, if any
        for epoch in range(100):
            ..
            saver.save(model, epoch=epoch, score=acc)

        # inference
        saver.load_best({'net': alexnet})  # no need for the optimizer

    Args:
        weight_dir (str): directory for your weights
        exp_name (str): name of your experiment (e.g. resnet152-cosine)
    """

    def __init__(self, weight_dir, exp_name):
        self.weight_dir = weight_dir
        self.exp_name = exp_name
        self.dir_path = weight_dir + "/" + exp_name
        os.makedirs(self.dir_path, exist_ok=True)
        self.best_epoch_path = self.dir_path + "/best_epoch"
        # Restore the best-so-far record written by save() in a previous run.
        if os.path.exists(self.best_epoch_path):
            with open(self.best_epoch_path) as f:
                self.best_epoch = int(f.readline())
                self.best_score = float(f.readline())
        else:
            self.best_epoch = None
            self.best_score = None

    def path_for_epoch(self, epoch):
        """Return the checkpoint path for ``epoch``.

        For example: ``~/imagenet/weights/alexnet/epoch_0001.pth``.
        """
        return self.dir_path + "/" + "epoch_%04d.pth" % epoch

    def _saved_epochs(self):
        """Return the epoch numbers of all ``epoch_<digits>.pth`` files in ``dir_path``."""
        prefix, suffix = "epoch_", ".pth"
        epochs = []
        for name in os.listdir(self.dir_path):
            if name.startswith(prefix) and name.endswith(suffix):
                digits = name[len(prefix):-len(suffix)]
                # Skip unrelated *.pth files; parse all digits (epochs >= 10000 have 5+).
                if digits.isdecimal():
                    epochs.append(int(digits))
        return epochs

    def save(self, dic: dict, epoch: int, score: float = None):
        """Save the model.

        `score` is used to choose the best model (higher is better). An
        example for score is validation accuracy.

        Example::

            model = {
                'net': net,
                'opt': opt,
                'scheduler': cosine_annealing,
                'lr': 0.01,
            }
            saver = Saver('/data/weights/', 'resnet152-cosine')
            saver.save(model, epoch=3, score=val_acc)
            saver.save(model, epoch=4, score=val_acc)

            saver.load_latest(model)  # restores epoch 4 into `model`
            print(model['lr'])

        Values that have a ``state_dict()`` method are stored as their state
        dict; all other values are stored as-is. The epoch number is stored
        under the ``'epoch'`` key, overriding any ``'epoch'`` entry in ``dic``.

        Args:
            dic (dict or SimpleNamespace): the values are objects with
                `state_dict` and `load_state_dict`, or plain values
            epoch (int): number of epochs completed
            score (float, optional): Defaults to None
        """
        if isinstance(dic, SimpleNamespace):
            dic = dic.__dict__

        new_dic = {}
        for key, value in dic.items():
            new_dic[key] = value.state_dict() if hasattr(value, "state_dict") else value
        new_dic["epoch"] = epoch

        # Write to a temporary file and rename it, so an interrupted save never
        # leaves a truncated epoch_XXXX.pth for load_latest() to pick up.
        path = self.path_for_epoch(epoch)
        tmp_path = path + ".tmp"
        torch.save(new_dic, tmp_path)
        os.replace(tmp_path, path)

        # Record the best epoch only after its checkpoint is safely on disk.
        if score is not None:
            # Normalise e.g. 0-dim tensors; otherwise "tensor(...)" would be
            # written to best_epoch and __init__ could not parse it back.
            score = float(score)
            if self.best_score is None or self.best_score < score:
                self.best_epoch = epoch
                self.best_score = score
                with open(self.best_epoch_path, "w") as f:
                    print(epoch, file=f)
                    print(score, file=f)

    def load(self, model_dict: dict, epoch: int) -> bool:
        """Load the model saved for ``epoch``.

        It is recommended to use `load_latest` or `load_best` instead.

        Objects with ``load_state_dict`` are restored in place. Other keys are
        overwritten in ``model_dict`` (or on the SimpleNamespace). Keys absent
        from the checkpoint are reported and left untouched.

        Args:
            model_dict (dict or SimpleNamespace): see save()
            epoch (int): epoch to load

        Return:
            bool: True if the checkpoint was found and loaded, False otherwise.
        """
        if isinstance(model_dict, SimpleNamespace):
            model_dict = model_dict.__dict__
        p = self.path_for_epoch(epoch)
        if not os.path.exists(p):
            print("[tinder] weight doesn't exist: ", p)
            return False
        print("[tinder] loading weights: ", p)
        # Map every tensor to CPU storage regardless of the device it was saved from.
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        states = torch.load(p, map_location=lambda storage, loc: storage)
        assert epoch == states["epoch"], "epoch mismatch in checkpoint %s" % p
        for key, value in model_dict.items():
            if key in states:
                if hasattr(value, "load_state_dict"):
                    value.load_state_dict(states[key])
                else:
                    model_dict[key] = states[key]
            else:
                print("missing key in the checkpoint: ", key)
        return True

    def load_latest(self, dic: dict) -> bool:
        """Load the checkpoint with the highest epoch number.

        Args:
            dic (dict or SimpleNamespace): see save()

        Return:
            bool: True if a checkpoint was loaded, False if none exists.
        """
        epochs = self._saved_epochs()
        if not epochs:
            print("[tinder] no weights found in ", self.dir_path)
            return False
        # Compare numerically: string order breaks once epochs exceed 9999.
        return self.load(dic, max(epochs))

    def load_best(self, dic: dict) -> bool:
        """Load the checkpoint with the best recorded score.

        Args:
            dic (dict or SimpleNamespace): see save()

        Return:
            bool: True if the best checkpoint was loaded, False if no best
            epoch has been recorded or its file is missing.
        """
        if self.best_epoch is not None:
            return self.load(dic, self.best_epoch)
        return False