
Beta Scheduler

VAE.utils.beta_schedulers

Collection of beta schedulers.

Beta schedulers set the value of beta, the weight of the KL divergence term, for each epoch during training. They are used in the generators.
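To illustrate what scheduling beta means, the sketch below weights a KL divergence term by an epoch-dependent beta in a β-VAE style loss (reconstruction loss + beta * KL). The function `beta_vae_loss`, the hand-written ramp and the loss values are hypothetical. How the generators actually combine beta with the loss is not shown in this section.

```python
import numpy as np


def beta_vae_loss(reconstruction_loss: float, kl_divergence: float, beta: float) -> float:
    """Hypothetical beta-weighted VAE loss: reconstruction + beta * KL."""
    return reconstruction_loss + beta * kl_divergence


# Hand-written schedule for 8 epochs: beta ramps from 0 to 1 (reached at epoch 4), then stays at 1.
betas = np.minimum(np.arange(8) / 4, 1.0)

for epoch, beta in enumerate(betas):
    # Early epochs barely penalize the KL term; later epochs use the full weight.
    loss = beta_vae_loss(reconstruction_loss=10.0, kl_divergence=2.0, beta=float(beta))
    print(f'epoch {epoch}: beta={beta:.2f}, loss={loss:.2f}')
```

Starting with a small beta and increasing it is a common way to let the model learn useful reconstructions before the KL penalty takes full effect.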

VAE.utils.beta_schedulers.BetaScheduler

BetaScheduler(dtype='float')

Bases: ABC

Abstract class for beta schedulers.

This is the abstract base class for all beta schedulers.

Subclasses must implement the method __call__ and can optionally override the method get_config.

Parameters:

  • dtype (str, default: 'float' ) –

    Data type of the returned beta values.

Source code in VAE/utils/beta_schedulers.py
def __init__(self, dtype: str = 'float'):
    self.dtype = dtype

VAE.utils.beta_schedulers.BetaScheduler.__call__ abstractmethod

__call__(epoch, shape=(1,))

Return beta value.

Abstract method that has to be implemented by all beta schedulers. This method is called during training to obtain the beta values for the current epoch. The shape parameter defines the shape of the returned array of constant beta values.

Parameters:

  • epoch (int) –

    Training epoch for which the beta values will be returned.

  • shape (tuple[int, ...], default: (1,) ) –

    Output shape for the beta values that will be returned.

Returns:

  • ndarray

    Array of shape `shape` filled with beta values.

Source code in VAE/utils/beta_schedulers.py
@abc.abstractmethod
def __call__(self, epoch: int, shape: tuple[int, ...] = (1, )) -> np.ndarray:
    """Return beta value.

    Abstract method that has to be implemented by all beta schedulers. This method is called during training to
    obtain the beta values for the current `epoch`. The `shape` parameter defines the shape of the returned array
    of constant beta values.

    Parameters:
        epoch:
            Training epoch for which the beta values will be returned.
        shape:
            Output shape for the beta values that will be returned.

    Returns:
        Array of shape `shape` filled with beta values.
    """
    pass
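A minimal sketch of a concrete subclass, showing what an implementation of `__call__` has to provide: an array of the requested `shape` with the configured `dtype`. The base class is copied from the source above so the block runs on its own; the subclass `StepBeta` and its `start_epoch` parameter are hypothetical and not part of this module.

```python
import abc

import numpy as np


class BetaScheduler(abc.ABC):
    """Minimal copy of the base class shown above."""

    def __init__(self, dtype: str = 'float'):
        self.dtype = dtype

    @abc.abstractmethod
    def __call__(self, epoch: int, shape=(1, )) -> np.ndarray:
        pass

    def get_config(self) -> dict:
        return {'dtype': self.dtype}


class StepBeta(BetaScheduler):
    """Hypothetical scheduler: beta is 0 before `start_epoch` and 1 from then on."""

    def __init__(self, start_epoch: int = 5, **kwargs):
        super().__init__(**kwargs)
        self.start_epoch = start_epoch

    def __call__(self, epoch: int, shape=(1, )) -> np.ndarray:
        value = 0.0 if epoch < self.start_epoch else 1.0
        # Constant array of the requested shape, cast to the configured dtype.
        return np.full(shape, value, dtype=self.dtype)


scheduler = StepBeta(start_epoch=3)
print(scheduler(1, shape=(2, 3)))  # 2x3 array of zeros
print(scheduler(3, shape=(4,)))    # 4 ones
```

Because `__call__` is abstract, instantiating `BetaScheduler` directly raises a `TypeError`; only subclasses that implement it can be created.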

VAE.utils.beta_schedulers.BetaScheduler.get_config

get_config()

Get configuration.

Returns:

  • dict

    Dictionary with the configuration of the beta scheduler.

Source code in VAE/utils/beta_schedulers.py
def get_config(self) -> dict:
    """Get configuration.

    Returns:
        Dictionary with the configuration of the beta scheduler.

    """
    return {'dtype': self.dtype}
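Subclasses that add parameters would typically extend `get_config` by starting from the base configuration, so that `dtype` is kept. The sketch below assumes this pattern; the class `Scaled` and its `scale` parameter are hypothetical, and the base class is a minimal copy of the one above.

```python
import abc

import numpy as np


class BetaScheduler(abc.ABC):
    """Minimal copy of the base class shown above."""

    def __init__(self, dtype: str = 'float'):
        self.dtype = dtype

    @abc.abstractmethod
    def __call__(self, epoch: int, shape=(1, )) -> np.ndarray:
        pass

    def get_config(self) -> dict:
        return {'dtype': self.dtype}


class Scaled(BetaScheduler):
    """Hypothetical scheduler returning a fixed `scale` value."""

    def __init__(self, scale: float = 0.5, **kwargs):
        super().__init__(**kwargs)
        self.scale = scale

    def __call__(self, epoch: int, shape=(1, )) -> np.ndarray:
        return np.full(shape, self.scale, dtype=self.dtype)

    def get_config(self) -> dict:
        # Start from the base configuration so that `dtype` is kept.
        config = super().get_config()
        config['scale'] = self.scale
        return config


config = Scaled(scale=0.1, dtype='float32').get_config()
print(config)  # {'dtype': 'float32', 'scale': 0.1}
```

When the keys match the constructor arguments, `Scaled(**config)` rebuilds an equivalent scheduler. This is a common convention (for example in Keras), but whether this package relies on it is not shown here.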

VAE.utils.beta_schedulers.BetaScheduler.summary

summary()

Print a summary of the beta scheduler.

Source code in VAE/utils/beta_schedulers.py
def summary(self):
    """Print a summary of the beta scheduler."""
    config = self.get_config()
    cols = len(max(config.keys(), key=len))
    print(f'Summary of "{self.__class__.__name__}" (BetaScheduler)')
    for key, value in config.items():
        print(f'  {key:{cols}} : {value}')
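The sketch below shows the layout produced by `summary`: the keys are left-aligned to the length of the longest key. The base class is a minimal copy of the methods shown above; the class `Demo` and its configuration values are hypothetical and only serve to produce output.

```python
import abc
import contextlib
import io

import numpy as np


class BetaScheduler(abc.ABC):
    """Minimal copy of the base class shown above."""

    def __init__(self, dtype: str = 'float'):
        self.dtype = dtype

    @abc.abstractmethod
    def __call__(self, epoch: int, shape=(1, )) -> np.ndarray:
        pass

    def get_config(self) -> dict:
        return {'dtype': self.dtype}

    def summary(self):
        config = self.get_config()
        cols = len(max(config.keys(), key=len))  # width of the longest key
        print(f'Summary of "{self.__class__.__name__}" (BetaScheduler)')
        for key, value in config.items():
            print(f'  {key:{cols}} : {value}')


class Demo(BetaScheduler):
    """Hypothetical scheduler used only to show the summary layout."""

    def __call__(self, epoch: int, shape=(1, )) -> np.ndarray:
        return np.zeros(shape, dtype=self.dtype)

    def get_config(self) -> dict:
        config = super().get_config()
        config.update(lower=0.0, upper=1.0, epochs=10)
        return config


scheduler = Demo()
scheduler.summary()
# Summary of "Demo" (BetaScheduler)
#   dtype  : float
#   lower  : 0.0
#   upper  : 1.0
#   epochs : 10

# The summary can also be captured as a string, e.g. to write it to a log file.
buffer = io.StringIO()
with contextlib.redirect_stdout(buffer):
    scheduler.summary()
summary_text = buffer.getvalue()
```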

VAE.utils.beta_schedulers.Constant

Constant(value=1.0, **kwargs)

Bases: BetaScheduler

Return constant beta value.

This beta scheduler returns a constant beta value for all epochs.

Parameters:

  • value (float, default: 1.0 ) –

    Value of beta.

  • **kwargs

    Additional arguments for BetaScheduler.

Source code in VAE/utils/beta_schedulers.py
def __init__(self, value: float = 1., **kwargs):
    super().__init__(**kwargs)
    self.value = value
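Only the constructor of `Constant` is shown above. A plausible `__call__`, assumed here rather than taken from the source, ignores the epoch and fills the requested shape with `value`. The base class is a minimal copy so the block runs on its own.

```python
import abc

import numpy as np


class BetaScheduler(abc.ABC):
    """Minimal copy of the base class shown above."""

    def __init__(self, dtype: str = 'float'):
        self.dtype = dtype

    @abc.abstractmethod
    def __call__(self, epoch: int, shape=(1, )) -> np.ndarray:
        pass


class Constant(BetaScheduler):
    """Sketch of the constant scheduler; `__call__` is an assumption."""

    def __init__(self, value: float = 1., **kwargs):
        super().__init__(**kwargs)
        self.value = value

    def __call__(self, epoch: int, shape=(1, )) -> np.ndarray:
        # The epoch is ignored: every epoch gets the same beta.
        return np.full(shape, self.value, dtype=self.dtype)


beta = Constant(value=0.5)
print(beta(0, shape=(3,)))    # [0.5 0.5 0.5]
print(beta(100, shape=(3,)))  # [0.5 0.5 0.5]
```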

VAE.utils.beta_schedulers.Linear

Linear(lower=0.0, upper=1.0, epochs=10, **kwargs)

Bases: BetaScheduler

Linearly increase beta value.

This beta scheduler returns a linearly increasing beta value.

Parameters:

  • lower (float, default: 0.0 ) –

    Lower (left) bound of beta.

  • upper (float, default: 1.0 ) –

    Upper (right) bound of beta.

  • epochs (int, default: 10 ) –

    Number of epochs over which beta increases from the lower to the upper bound. Once this number of epochs is reached, beta stays constant at the upper bound.

  • **kwargs

    Additional arguments for BetaScheduler.

Source code in VAE/utils/beta_schedulers.py
def __init__(self, lower: float = 0., upper: float = 1., epochs: int = 10, **kwargs):
    super().__init__(**kwargs)
    self.lower = lower
    self.upper = upper
    self.epochs = epochs
    self.values = np.linspace(lower, upper, epochs)
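The constructor precomputes `epochs` evenly spaced values with `np.linspace`. A plausible `__call__`, assumed here since it is not shown, looks up the value for the current epoch and clamps at the last entry so that beta stays at `upper` afterwards. This sketch assumes 0-based epochs, so `upper` is first reached at epoch `epochs - 1`. Because `np.linspace` requires an integer number of samples, `epochs` is annotated as `int`.

```python
import abc

import numpy as np


class BetaScheduler(abc.ABC):
    """Minimal copy of the base class shown above."""

    def __init__(self, dtype: str = 'float'):
        self.dtype = dtype

    @abc.abstractmethod
    def __call__(self, epoch: int, shape=(1, )) -> np.ndarray:
        pass


class Linear(BetaScheduler):
    """Sketch of the linear scheduler; `__call__` is an assumption."""

    def __init__(self, lower: float = 0., upper: float = 1., epochs: int = 10, **kwargs):
        super().__init__(**kwargs)
        self.lower = lower
        self.upper = upper
        self.epochs = epochs
        self.values = np.linspace(lower, upper, epochs)  # `epochs` values from lower to upper

    def __call__(self, epoch: int, shape=(1, )) -> np.ndarray:
        # Clamp the index: epochs past the ramp reuse the last (upper) value.
        index = min(max(epoch, 0), self.epochs - 1)
        return np.full(shape, self.values[index], dtype=self.dtype)


beta = Linear(lower=0., upper=1., epochs=5)  # values 0, 0.25, 0.5, 0.75, 1
print(beta(2))   # [0.5]
print(beta(20))  # [1.]
```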

VAE.utils.beta_schedulers.LogisticGrowth

LogisticGrowth(lower=0.0, upper=1.0, midpoint=5.0, rate=1.0, **kwargs)

Bases: BetaScheduler

Increase beta towards a maximum value at a given rate.

This beta scheduler returns a beta value that increases to a maximum value at a given rate. The beta value follows a logistic growth function.

Parameters:

  • lower (float, default: 0.0 ) –

    Lower (left) asymptote of beta.

  • upper (float, default: 1.0 ) –

    Upper (right) asymptote of beta.

  • midpoint (float, default: 5.0 ) –

    Epoch at which beta equals the mean of the upper and lower asymptotes.

  • rate (float, default: 1.0 ) –

    Growth rate at which beta increases.

  • **kwargs

    Additional arguments for BetaScheduler.

Source code in VAE/utils/beta_schedulers.py
def __init__(self, lower: float = 0., upper: float = 1., midpoint: float = 5., rate: float = 1., **kwargs):
    super().__init__(**kwargs)
    self.lower = lower
    self.upper = upper
    self.midpoint = midpoint
    self.rate = rate
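The parameter descriptions match the standard logistic function beta(t) = lower + (upper - lower) / (1 + exp(-rate * (t - midpoint))): it approaches `lower` and `upper` asymptotically and equals their mean at `t = midpoint`. The sketch below assumes `__call__` evaluates this formula at the current epoch; the actual implementation is not shown in this section.

```python
import abc

import numpy as np


class BetaScheduler(abc.ABC):
    """Minimal copy of the base class shown above."""

    def __init__(self, dtype: str = 'float'):
        self.dtype = dtype

    @abc.abstractmethod
    def __call__(self, epoch: int, shape=(1, )) -> np.ndarray:
        pass


class LogisticGrowth(BetaScheduler):
    """Sketch of the logistic scheduler; `__call__` is an assumption."""

    def __init__(self, lower: float = 0., upper: float = 1., midpoint: float = 5., rate: float = 1., **kwargs):
        super().__init__(**kwargs)
        self.lower = lower
        self.upper = upper
        self.midpoint = midpoint
        self.rate = rate

    def __call__(self, epoch: int, shape=(1, )) -> np.ndarray:
        # Standard logistic curve between the two asymptotes, centered at `midpoint`.
        value = self.lower + (self.upper - self.lower) / (1. + np.exp(-self.rate * (epoch - self.midpoint)))
        return np.full(shape, value, dtype=self.dtype)


beta = LogisticGrowth(lower=0., upper=1., midpoint=5., rate=1.)
print(beta(5))  # [0.5]
```

A larger `rate` makes the transition from `lower` to `upper` steeper around `midpoint`.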

VAE.utils.beta_schedulers.LogUniform

LogUniform(lower=0.01, upper=1.0, **kwargs)

Bases: BetaScheduler

Draw beta values from log-uniform distribution.

This beta scheduler draws beta values from a log-uniform distribution. The log-uniform distribution is a uniform distribution in log-space.

Parameters:

  • lower (float, default: 0.01 ) –

    Lower (minimum) value of the distribution.

  • upper (float, default: 1.0 ) –

    Upper (maximum) value of the distribution.

  • **kwargs

    Additional arguments for BetaScheduler.

Source code in VAE/utils/beta_schedulers.py
def __init__(self, lower: float = 0.01, upper: float = 1., **kwargs):
    super().__init__(**kwargs)
    self.lower = lower
    self.upper = upper
    self.fcn = loguniform(lower, upper)
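The constructor stores a frozen log-uniform distribution in `self.fcn` (presumably `scipy.stats.loguniform`); `__call__` is not shown. The NumPy-only sketch below draws equivalent samples by sampling uniformly in log-space and exponentiating: exp(U(log lower, log upper)). The `seed` parameter is hypothetical, and whether the package draws one value per call or one per array element is not shown; this sketch draws one per element.

```python
import abc

import numpy as np


class BetaScheduler(abc.ABC):
    """Minimal copy of the base class shown above."""

    def __init__(self, dtype: str = 'float'):
        self.dtype = dtype

    @abc.abstractmethod
    def __call__(self, epoch: int, shape=(1, )) -> np.ndarray:
        pass


class LogUniform(BetaScheduler):
    """Sketch of a log-uniform scheduler using NumPy instead of `self.fcn`."""

    def __init__(self, lower: float = 0.01, upper: float = 1., seed=None, **kwargs):
        super().__init__(**kwargs)
        self.lower = lower
        self.upper = upper
        self.rng = np.random.default_rng(seed)  # hypothetical `seed` for reproducibility

    def __call__(self, epoch: int, shape=(1, )) -> np.ndarray:
        # Uniform in log-space, then mapped back with exp: log-uniform on [lower, upper).
        log_values = self.rng.uniform(np.log(self.lower), np.log(self.upper), size=shape)
        return np.exp(log_values).astype(self.dtype)


beta = LogUniform(lower=0.01, upper=1., seed=0)
samples = beta(0, shape=(5,))  # five draws between 0.01 and 1
```

Sampling in log-space gives each order of magnitude (0.01 to 0.1 and 0.1 to 1) the same probability, which a plain uniform distribution on [0.01, 1] would not.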