Source code for ignite_simple.helper

"""This acts as a potential runner for files, or as an import to reduce the
amount of boilerplate in a runner. Given a module which has a model() function,
a dataset() function, a loss() callable, and an accuracy_style attribute, this
uses argparse to fill in the rest of the parameters to train() or reanalyze()
as requested.


.. code-block:: python

    # in file
    import ignite_simple.helper
    import torch

    def model():
        pass # omitted, should return torch.nn.Module

    def dataset():
        pass # omitted, return train_set, val_set

    accuracy_style = 'multiclass'
    loss = torch.nn.MSELoss # any callable that returns a loss works

    if __name__ == '__main__':
        ignite_simple.helper.handle('mymod')

.. code-block:: none

    > python3 -m mymod --help
"""

import ignite_simple
import os
import argparse
import importlib
import logging.config

def train(module, args):
    """Call :func:`ignite_simple.train` using settings parsed by argparse.

    The model, dataset and loss are handed over as loader tuples that name
    attributes of ``module``. The accuracy style is read directly from the
    imported module. Results are written to ``<args.folder>/current``, and
    ``<args.folder>/history`` is passed as the history folder.

    :param module: name of the module containing model(), dataset(), loss(),
        and accuracy_style
    :param args: argparse result (as prepared by handle())
    """
    user_module = importlib.import_module(module)

    def from_module(attr_name):
        # (module name, attribute name, positional args, keyword args);
        # presumably ignite_simple imports and calls the attribute with these
        return (module, attr_name, tuple(), dict())

    settings = {
        'folder': os.path.join(args.folder, 'current'),
        'hyperparameters': args.hparams,
        'analysis': args.analysis,
        'allow_later_analysis_up_to': args.analysis_up_to,
        'accuracy_style': user_module.accuracy_style,
        'trials': args.trials,
        'is_continuation': args.is_continuation,
        'history_folder': os.path.join(args.folder, 'history'),
        'cores': args.cores,
        'trials_strict': args.strict_trials,
    }
    ignite_simple.train(from_module('model'), from_module('dataset'),
                        from_module('loss'), **settings)
def reanalyze(module, args):
    """Uses the given arguments from argparse to determine the arguments to
    :func:`ignite_simple.analyze`, re-running analysis on existing trials.

    Like train(), the dataset and loss loaders are resolved from ``module``.
    The accuracy style is read from the imported module.

    :param module: name of the module containing model(), dataset(), loss(),
        and accuracy_style
    :param args: argparse result (as prepared by handle())
    """
    mod = importlib.import_module(module)
    ignite_simple.analyze(
        # Must name the user's module, not this helper module (``__name__``),
        # which defines neither dataset() nor loss().
        (module, 'dataset', tuple(), dict()),
        (module, 'loss', tuple(), dict()),
        folder=os.path.join(args.folder, 'current'),
        settings=args.analysis,
        accuracy_style=mod.accuracy_style,
        cores=args.cores)
def _build_parser():
    """Create the command-line parser whose options handle() consumes.

    :returns: an argparse.ArgumentParser with every option read by handle(),
        train(), and reanalyze()
    """
    parser = argparse.ArgumentParser(description='Simple model/dataset helper')
    parser.add_argument(
        '--folder', type=str, default=None,
        help='Where to store the output')
    parser.add_argument(
        '--hparams', type=str, default='fast',
        help='Level of hyperparameter tuning, one of \'fastest\', \'fast\', '
        + '\'slow\', and \'slowest\'')
    parser.add_argument(
        '--analysis', type=str, default='images',
        help='Level of analysis to perform, typically images or videos')
    parser.add_argument(
        '--analysis_up_to', type=str, default='videos',
        help='Level of analysis that will be possible without repeating '
        + 'trials')
    parser.add_argument(
        '--trials', type=int, default=1,
        help='Minimum number of trials to perform')
    parser.add_argument(
        '--not_continuation', action='store_true',
        help='If specified, trials will be archived if they exist first')
    parser.add_argument(
        '--cores', type=int, default=-1,
        help='Number of cores to use, default is all physical cores available')
    parser.add_argument(
        '--reanalyze', action='store_true',
        help='Instead of training, just perform analysis on existing trials')
    parser.add_argument(
        '--module', type=str,
        help='Which module to load the model, dataset, loss, and '
        + ' accuracy style from')
    parser.add_argument(
        '--strict_trials', action='store_true',
        help='Instead of using all available resources, just perform the '
        + 'specified number of trials'
    )
    parser.add_argument(
        '--loggercfg', type=str, default='logging.conf',
        help='The logging configuration file to use should it exist'
    )
    return parser


def handle(module=None, argv=None):
    """Uses the given module containing model(), dataset(), loss() and
    accuracy_style as the module for train() or reanalyze(), with everything
    else determined by the command line arguments.

    :param module: name of the module containing model(), dataset(), loss(),
        and accuracy_style; a ``--module`` command-line option takes
        precedence over this value
    :param argv: list of arguments to parse; ``None`` (the default) lets
        argparse read ``sys.argv[1:]`` as before
    :raises ValueError: if no module was given either here or via --module
    """
    args = _build_parser().parse_args(argv)
    args.is_continuation = not args.not_continuation
    args.cores = args.cores if args.cores != -1 else 'all'
    if args.module is None:
        args.module = module

    # Validate before the module name is used to derive the default folder;
    # otherwise a missing module surfaces as an AttributeError on None.split.
    if args.module is None:
        raise ValueError('Module must be set from command line or '
                         + 'delegating file')
    if args.folder is None:
        args.folder = os.path.join('out', *args.module.split('.'))

    if os.path.exists(args.loggercfg):
        logging.config.fileConfig(args.loggercfg)
    else:
        print('No logging configuration found - continuing without logging')

    if args.reanalyze:
        reanalyze(args.module, args)
    else:
        train(args.module, args)
if __name__ == '__main__':
    # Run directly as a script: the target module must then come from --module.
    handle()