# Source code for ignite_simple.utils

"""Utility functions which are used through ignite-simple but are short
enough not to warrant their own module.
"""

import torch.utils.data as data
import torch
import numpy as np
import typing
import importlib
import os
import sys

def noop(*args, **kwargs):
    """Accepts any positional and keyword arguments and does nothing."""
    return None
def fix_imports(loader: typing.Tuple[str, str, tuple, dict]):
    """Returns the loader which represents the same callable as the given
    one, except potentially with the __main__ name replaced with the
    correct module name.

    :param loader: (module, attrname, args, kwargs) - the loader whose
        imports might need to be cleaned
    :returns: the cleaned loader
    :raises ValueError: if python was invoked with -c, or the invoked
        script is not inside the current working directory
    """
    if loader[0] != '__main__':
        # nothing to fix; the module name is already importable
        return loader

    called_file = sys.argv[0]
    if called_file == '-c':
        raise ValueError('cannot fix __main__ when python invoked with -c')

    called_file = os.path.abspath(called_file)
    cwd = os.getcwd()
    # BUGFIX: require the cwd prefix to end at a path-component boundary.
    # A bare startswith(cwd) would wrongly accept e.g. cwd=/home/ab for
    # the script /home/abc/script.py. When cwd is already the filesystem
    # root it ends with the separator, so don't append another one.
    prefix = cwd if cwd.endswith(os.sep) else cwd + os.sep
    if not called_file.startswith(prefix):
        raise ValueError('cannot fix __main__ when not called from a parent '
                         + 'directory. you should be called with something '
                         + 'like python -m examples.mnist_direct or '
                         + 'python examples/mnist_direct.py but never '
                         + 'python ../examples/mnist_direct.py')
    # path of the script relative to cwd, without its extension
    called_file = called_file[len(prefix):]
    called_file, _ = os.path.splitext(called_file)

    # split the relative path into its components to build the dotted
    # module name, e.g. examples/mnist_direct -> examples.mnist_direct
    new_path = []
    cur_path, tail = os.path.split(called_file)
    new_path.append(tail)
    while cur_path:
        new_cur_path, tail = os.path.split(cur_path)
        if tail:
            new_path.append(tail)
        if new_cur_path == cur_path:
            # reached an unsplittable head (e.g. a drive or root); stop
            break
        cur_path = new_cur_path
    new_path.reverse()

    correct_module = '.'.join(new_path)
    res = [correct_module]
    res.extend(loader[1:])
    return tuple(res)
def invoke(loader: typing.Tuple[str, str, tuple, dict]):
    """Invokes the callable which has the given name in the given module,
    using the given arguments and keyword arguments.

    :param loader: (module, attrname, args, kwargs) - the callable to invoke
    :returns: the result of the callable
    """
    module_name, attr_name, pos_args, kw_args = loader
    target = getattr(importlib.import_module(module_name), attr_name)
    return target(*pos_args, **kw_args)
def create_partial_loader(dset: data.Dataset, amt: int,
                          batch_size: int = 256) -> data.DataLoader:
    """Creates a dataloader which loads only a random subset of the
    specified length from the dataset, using the specified batch size.

    :param dset: the dataset to create a loader for a partial subset of
    :param amt: the number of items in the partial subset
    :param batch_size: the number of items per batch
    :returns: the described dataloader
    """
    num_items = len(dset)
    if amt == num_items:
        # the full dataset was requested; no subsetting necessary
        return data.DataLoader(dset, batch_size=batch_size)
    chosen = np.random.choice(num_items, amt, replace=False)
    subset = data.Subset(dset, torch.from_numpy(chosen).long())
    return data.DataLoader(subset, batch_size=batch_size)
def task_loader(dataset_loader, batch_size, shuffle, drop_last):
    """Creates a task loader from a dataset loader.

    :param dataset_loader: the dataset loader (str, str, tuple, dict)
    :param batch_size: the batch size for the data loader
    :param shuffle: if the training dataset should be shuffled between epochs
    :param drop_last: if the last batch should be dropped if its not the
        same size as the rest. should only be used if shuffle is True
    """
    train_set, val_set = invoke(dataset_loader)
    loader = data.DataLoader(train_set,
                             batch_size=batch_size,
                             shuffle=shuffle,
                             drop_last=drop_last)
    return train_set, val_set, loader
def split(full: data.Dataset, val_perc: float,
          filen: str = None) -> typing.Tuple[data.Dataset, data.Dataset]:
    """Splits the given dataset into two datasets, the first of which has
    (1 - val_perc) fraction of the data and the other has val_perc fraction
    of the data, distributed randomly.

    :param data.Dataset full: the entire dataset to split into two
    :param float val_perc: the amount to be broken away from full
    :param str filen: optional .npz path; when given, the held-out indices
        are saved there (or reloaded from there if it already exists) so
        the split is reproducible across runs
    :returns: (train_set, val_set)
    """
    if filen is not None:
        _, ext = os.path.splitext(filen)
        if not ext:
            filen += '.npz'
        elif ext != '.npz':
            raise ValueError(f'bad file extension: {filen} (should be .npz)')

    if filen is not None and os.path.exists(filen):
        # reuse the split stored on a previous run
        with np.load(filen) as infile:
            held_out_inds_np = infile['held_out_inds']
    else:
        n_held_out = int(len(full) * val_perc)
        held_out_inds_np = np.random.choice(
            len(full), n_held_out, replace=False)
        if filen is not None:
            np.savez_compressed(filen, held_out_inds=held_out_inds_np)

    held_out_inds = torch.from_numpy(held_out_inds_np).long()
    # arange gives each index its own value; mark held-out positions with
    # -1 then drop them, leaving the complement of the held-out set
    not_held_out = torch.arange(len(full))
    not_held_out[held_out_inds] = -1
    not_held_out = not_held_out[not_held_out != -1]

    return data.Subset(full, not_held_out), data.Subset(full, held_out_inds)