--I wrote a decorator that reads parameters from a file and passes them to a function as keyword arguments, and I'd like to share how to use it.
--The motivation is to make it easier to load hyperparameters for models such as deep learning (DL) models.
--Writing code that threads every parameter from the main function down to where it is used is tedious...
--omegaconf is convenient, and I want people who haven't heard of it to know about it.
--If you know of other useful tools, please let me know!
--Preparation
--Install omegaconf
%%bash
pip install omegaconf
--Preparing the decorator
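import functools
from typing import Any, Callable

from omegaconf import OmegaConf


def add_args(params_file: str, as_default: bool = False) -> Callable:
    """Read params_file and pass its entries to the decorated function as keyword arguments."""
    def _decorator(f: Callable) -> Callable:
        @functools.wraps(f)  # keep f's name and docstring on the wrapper
        def _wrapper(*args: Any, **kwargs: Any) -> Any:
            # Load the YAML/JSON file on every call and convert it to a plain dict;
            # resolve=True expands ${...} interpolations.
            cfg_params = OmegaConf.to_container(OmegaConf.load(params_file), resolve=True)
            if as_default:
                # File values are defaults; keyword arguments from the call win.
                cfg_params.update(kwargs)
                kwargs = cfg_params
            else:
                # Values from the file overwrite keyword arguments from the call.
                kwargs.update(cfg_params)
            return f(*args, **kwargs)
        return _wrapper
    return _decorator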
--Prepare a parameter file to read (yaml or json)
--omegaconf can load both YAML and JSON files
%%bash
cat <<__YML__ > params.yml
n_encoder_layer: 3
n_decoder_layer: 5
n_heads: 4
n_embedding: 16
__YML__
:
echo "===== [ params.yml ] ====="
cat params.yml
echo "====="
--Calling the decorated function
@add_args("params.yml")
def use_params(a, b, n_encoder_layer, n_decoder_layer, n_heads, n_embedding):
    assert a == 0.25
    assert b == "world"
    assert n_encoder_layer == 3
    assert n_decoder_layer == 5
    assert n_heads == 4
    assert n_embedding == 16
use_params(a=0.25, b="world")
print("OK")
Here, only a and b are passed explicitly to use_params(); the remaining four arguments are filled in from params.yml.
If you pass as_default=True to the decorator, as shown below, the values in params.yml act as defaults that arguments passed in code can override. (Conversely, with as_default=False, which is the decorator's default, values from the configuration file take precedence over arguments passed in code.)
@add_args("params.yml", as_default=True)
def use_params(n_encoder_layer, n_decoder_layer, n_heads, n_embedding):
    assert n_encoder_layer == 128   # notice !!
    assert n_decoder_layer == 5
    assert n_heads == 4
    assert n_embedding == 16
use_params(n_encoder_layer=128)
print("OK")
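The two modes differ only in which side wins when a key appears both in the file and in the call. A minimal stdlib-only sketch of that rule, using a hypothetical helper merge_params that mirrors the two branches of _wrapper, with a plain dict standing in for the loaded file:

```python
from typing import Any, Dict


def merge_params(file_params: Dict[str, Any], call_kwargs: Dict[str, Any],
                 as_default: bool = False) -> Dict[str, Any]:
    """Hypothetical helper mirroring the two branches in _wrapper."""
    if as_default:
        # File values are defaults; keyword arguments from the call win.
        merged = dict(file_params)
        merged.update(call_kwargs)
    else:
        # Keyword arguments from the call are overwritten by the file.
        merged = dict(call_kwargs)
        merged.update(file_params)
    return merged


file_params = {"n_encoder_layer": 3, "n_decoder_layer": 5, "n_heads": 4, "n_embedding": 16}

print(merge_params(file_params, {"n_encoder_layer": 128}, as_default=True)["n_encoder_layer"])   # 128
print(merge_params(file_params, {"n_encoder_layer": 128}, as_default=False)["n_encoder_layer"])  # 3
```

Keys that appear only on one side are kept in both modes; as_default only decides the winner on conflicts.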
--Other
--The decorator can also be applied to a class's __init__, so give it a try.
--In omegaconf, values in the configuration file can reference environment variables and other keys in the same file through interpolation.
--For more information on omegaconf, see the official omegaconf documentation.
--Copying the same code into every project is a bit tedious, so I'd like to make it pip-installable.