Losses

VAE.losses

Collection of loss functions.

VAE.losses.KLDivergence

KLDivergence(z_mean, z_log_var, kl_threshold=None, free_bits=None)

Kullback-Leibler divergence.

This is the KL divergence between N(z_mean, z_log_var) and the prior N(0, 1).

Parameters:

  • z_mean (Tensor) –

    Tensor of shape (batch_size, latent_dim) specifying mean.

  • z_log_var (Tensor) –

    Tensor of shape (batch_size, latent_dim) specifying log of variance.

  • kl_threshold (float, default: None ) –

    Lower bound for the KL divergence. Default is None.

  • free_bits (float, default: None ) –

    Number of bits to keep free for the KL divergence per latent dimension; cf. Appendix C8 in [1]. Default is None.

Returns:

  • callable

    Loss function that returns a tensor of shape (batch_size, 1, 1).

References

[1] Kingma et al. (2016): Improved Variational Inference with Inverse Autoregressive Flow. NIPS 2016.

Source code in VAE/losses.py
def KLDivergence(z_mean: tf.Tensor,
                 z_log_var: tf.Tensor,
                 kl_threshold: float = None,
                 free_bits: float = None) -> callable:
    """Kullback-Leibler divergence.

    This is the KL divergence between N(`z_mean`, `z_log_var`) and the prior N(0, 1).

    Parameters:
        z_mean:
            Tensor of shape `(batch_size, latent_dim)` specifying mean.
        z_log_var:
            Tensor of shape `(batch_size, latent_dim)` specifying log of variance.
        kl_threshold:
            Lower bound for the KL divergence. Default is `None`.
        free_bits:
            Number of bits to keep free for the KL divergence per latent dimension; cf. Appendix C8 in [1]. Default is
            `None`.

    Returns:
        Loss function that returns a tensor of shape `(batch_size, 1, 1)`.

    References:
        [1] Kingma et al. (2016): Improved Variational Inference with Inverse Autoregressive Flow. NIPS 2016.

    """
    def kl_divergence(y_true, y_pred):
        # KL divergence to N(0, 1) of shape (batch_size, latent_dim)
        kl_loss = 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)
        kl_loss *= -0.5

        # apply threshold per latent dimension
        if free_bits is not None:
            kl_loss = tf.maximum(kl_loss, free_bits)

        # reduce to shape (batch_size, 1)
        kl_loss = tf.reduce_sum(kl_loss, axis=-1, keepdims=True)

        # apply global threshold
        if kl_threshold is not None:
            kl_loss = tf.maximum(kl_loss, kl_threshold)

        # expand to shape (batch_size, 1, 1)
        kl_loss = tf.expand_dims(kl_loss, axis=-1)

        return kl_loss

    return kl_divergence
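
A minimal usage sketch (the import path follows the module name above; the tensor shapes and the free_bits value are illustrative assumptions):

import tensorflow as tf
from VAE import losses

batch_size, latent_dim = 4, 8
z_mean = tf.random.normal((batch_size, latent_dim))
z_log_var = tf.random.normal((batch_size, latent_dim))

# build the loss closure; the returned function ignores y_true and y_pred
kl_fcn = losses.KLDivergence(z_mean, z_log_var, free_bits=0.1)
kl_loss = kl_fcn(None, None)
print(kl_loss.shape)  # (4, 1, 1)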

VAE.losses.SquaredError

SquaredError(size=1, taper=None)

Squared error loss.

This is the reconstruction loss of the model, without the KL divergence.

Parameters:

  • size (int, default: 1 ) –

    Size of the model output of shape (set_size, output_length, channels).

  • taper (array_like, default: None ) –

    Array of length output_length to taper the squared error.

Returns:

  • callable

    Loss function that returns a tensor of shape (batch_size, set_size, 1).

Source code in VAE/losses.py
def SquaredError(size: int = 1, taper=None) -> callable:
    """Squared error loss.

    This is the reconstruction loss of the model, without the KL divergence.

    Parameters:
        size:
            Size of the model output of shape `(set_size, output_length, channels)`.
        taper (array_like):
            Array of length `output_length` to taper the squared error.

    Returns:
        Loss function that returns a tensor of shape `(batch_size, set_size, 1)`.

    """
    def squared_error(y_true, y_pred):
        # losses reduce last channel dimension to shape (batch_size, set_size, output_length)
        reconstruction_loss = ks.losses.mse(y_true, y_pred)
        if taper is not None:
            reconstruction_loss *= taper
        # further reduce to shape (batch_size, set_size, 1)
        reconstruction_loss = tf.reduce_mean(reconstruction_loss, axis=-1, keepdims=True)
        # scale back to sum of squared errors
        reconstruction_loss *= size

        return reconstruction_loss

    return squared_error
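
A short usage sketch (shapes are illustrative; size is set to the total number of output elements, in line with the VAEloss description further below):

import numpy as np
import tensorflow as tf
from VAE import losses

batch_size, set_size, output_length, channels = 4, 2, 16, 3
y_true = tf.random.normal((batch_size, set_size, output_length, channels))
y_pred = tf.random.normal((batch_size, set_size, output_length, channels))

# optional taper that down-weights the end of the output window (illustrative choice)
taper = np.linspace(1., 0.5, output_length, dtype=np.float32)

mse_fcn = losses.SquaredError(size=set_size * output_length * channels, taper=taper)
loss = mse_fcn(y_true, y_pred)
print(loss.shape)  # (4, 2, 1)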

VAE.losses.Similarity

Similarity(temperature=1.0)

Similarity loss.

This loss flattens all but the leading batch dimension of y_pred to the shape (batch_size, -1). The similarity is then calculated on the reshaped input.

Parameters:

  • temperature (float, default: 1.0 ) –

    Temperature for the softmax.

Returns:

  • callable

    Loss function that returns a tensor of shape (batch_size, 1, 1).

Source code in VAE/losses.py
def Similarity(temperature: float = 1.) -> callable:
    """Similarity loss.

    This loss flattens all but the leading batch dimension of `y_pred` to the shape `(batch_size, -1)`. The similarity
    is then calculated on the reshaped input.

    Parameters:
        temperature:
            Temperature for the softmax.
    Returns:
        Loss function.

    """
    def sim(y_true, y_pred):
        batch_size = tf.shape(y_pred)[0]
        # flatten input to shape (batch_size, -1)
        x = tf.reshape(y_pred, (batch_size, -1))
        # compute similarity loss
        loss = vaemath.similarity(x, temperature=temperature)
        # broadcast loss to shape (batch_size, 1, 1)
        loss = tf.reshape(loss, (-1, 1, 1))

        return loss

    return sim
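
A usage sketch (the shape is illustrative; only y_pred is used by the returned function):

import tensorflow as tf
from VAE import losses

# (batch_size, set_size, output_length, channels); the loss flattens all but the batch axis
y_pred = tf.random.normal((8, 1, 32, 3))

sim_fcn = losses.Similarity(temperature=0.5)
loss = sim_fcn(None, y_pred)
print(loss.shape)  # (8, 1, 1)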

VAE.losses.SimilarityBetween

SimilarityBetween(repeat_samples=1, temperature=1.0)

Similarity between repeated samples (fast implementation).

This function returns the similarity between repeated samples. The input y_pred is first reshaped to the shape (batch_size // repeat_samples, repeat_samples, ...) and the similarity is calculated for each slice along the first dimension.

Note: This is an implementation with einsum that avoids calling math.similarity with tf.map_fn.

Parameters:

  • repeat_samples (int, default: 1 ) –

    Number of repeated samples.

  • temperature (float, default: 1.0 ) –

    Temperature for the softmax.

Returns:

  • callable

    Loss function that returns a tensor of shape (batch_size, 1, 1).

Source code in VAE/losses.py
def SimilarityBetween(repeat_samples: int = 1, temperature: float = 1.) -> callable:
    """Similarity between repeated samples (fast implementation).

    This function returns the similarity between repeated samples. The input `y_pred` is first reshaped to the shape
    `(batch_size // repeat_samples, repeat_samples, ...)` and the similarity is calculated for each slice along the
    first dimension.

    Note: This is an implementation with einsum that avoids calling :func:`math.similarity` with :func:`tf.map_fn`.

    Parameters:
        repeat_samples:
            Number of repeated samples.
        temperature:
            Temperature for the softmax.
    Returns:
        Loss function.

    """
    def sim_between(y_true, y_pred):
        batch_size = tf.shape(y_pred)[0]
        # reshape to (batch_size // repeat_samples, repeat_samples, -1)
        inputs = tf.reshape(y_pred, (batch_size // repeat_samples, repeat_samples, -1))
        # normalize input
        l2 = tf.math.l2_normalize(inputs, axis=-1)
        # correlation matrices of slices along first axis, each of shape (repeat_samples, -1)
        # shape (batch_size // repeat_samples, repeat_samples, repeat_samples)
        similarity = tf.einsum('ijk, ilk -> ijl', l2, l2)
        # reshape to (batch_size, repeat_samples)
        similarity = tf.reshape(similarity, (-1, repeat_samples))
        # apply temperature
        similarity /= temperature
        # target labels = diagonal elements
        labels = tf.tile(tf.range(repeat_samples), (batch_size // repeat_samples, ))
        # cross entropy loss between target labels and similarity matrices
        loss = ks.losses.sparse_categorical_crossentropy(labels, similarity, from_logits=True, axis=-1)

        # broadcast to shape (batch_size, 1, 1)
        loss = tf.reshape(loss, (-1, 1, 1))

        return loss

    return sim_between
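
A usage sketch; following the reshape in the code above, the repeat_samples repetitions of each input sample are assumed to be adjacent along the batch axis (shapes are illustrative):

import tensorflow as tf
from VAE import losses

repeat_samples = 4
batch_size = 8 * repeat_samples                 # must be divisible by repeat_samples
y_pred = tf.random.normal((batch_size, 1, 32, 3))

sim_fcn = losses.SimilarityBetween(repeat_samples=repeat_samples, temperature=0.5)
loss = sim_fcn(None, y_pred)
print(loss.shape)  # (32, 1, 1)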

VAE.losses.TotalCorrelation

TotalCorrelation(z, z_mean, z_log_var)

Total correlation.

The total correlation is already part of the KL divergence, see KL decomposition in [1]. This function only returns the batch-wise sampled total correlation loss that will be added on top of the KL divergence. The total correlation is computed separately for each step along the axis of length set_size of the input tensor z.

Parameters:

  • z (Tensor) –

    Sample from latent space of shape (batch_size, set_size, latent_dim).

  • z_mean (Tensor) –

    Mean of latent space of shape (batch_size, latent_dim).

  • z_log_var (Tensor) –

    Log of variance of latent space of shape (batch_size, latent_dim).

Returns:

  • callable

    Loss function that returns a tensor of shape (batch_size, set_size, 1).

References

[1] Chen et al. (2018): Isolating Sources of Disentanglement in Variational Autoencoders. NeurIPS 2018.

Source code in VAE/losses.py
def TotalCorrelation(z: tf.Tensor, z_mean: tf.Tensor, z_log_var: tf.Tensor) -> callable:
    """Total correlation.

    The total correlation is already part of the KL divergence, see KL decomposition in [1]. This function only returns
    the batch-wise sampled total correlation loss that will be added on top of the KL divergence. The total correlation
    is computed separately for each step along the axis of length `set_size` of the input tensor `z`.

    Parameters:
        z:
            Sample from latent space of shape `(batch_size, set_size, latent_dim)`.
        z_mean:
            Mean of latent space of shape `(batch_size, latent_dim)`.
        z_log_var:
            Log of variance of latent space of shape `(batch_size, latent_dim)`.

    Returns:
        Loss function that returns a tensor of shape `(batch_size, set_size, 1)`.

    References:
        [1] Chen et al. (2018): Isolating sources of disentanglement in Variational Autoencoders.

    """
    # expand to shape (batch_size, 1, latent_dim) for broadcasting with z of shape (batch_size, set_size, latent_dim)
    z_mean = tf.expand_dims(z_mean, axis=1)
    z_log_var = tf.expand_dims(z_log_var, axis=1)

    def tc(y_true, y_pred) -> tf.Tensor:
        # log prob of all combinations along first axis of length batch_size
        # shape (batch_size, batch_size, set_size, latent_dim)
        mat_log_qz = vaemath.log_density_gaussian(z, z_mean, z_log_var, all_combinations=True, axis=0)

        # log prob of joint distribution of shape (batch_size, set_size, 1)
        log_qz = vaemath.reduce_logmeanexp(tf.reduce_sum(mat_log_qz, axis=-1, keepdims=True), axis=1)

        # log prob of product of marginal distributions of shape (batch_size, set_size, 1)
        log_prod_qz = tf.reduce_sum(vaemath.reduce_logmeanexp(mat_log_qz, axis=1), axis=-1, keepdims=True)

        # total correlation loss of shape (batch_size, set_size, 1)
        tc_loss = log_qz - log_prod_qz

        return tc_loss

    return tc
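
A minimal sketch with random latent tensors (shapes follow the docstring; the sampling of z via the reparameterization trick is an illustrative assumption):

import tensorflow as tf
from VAE import losses

batch_size, set_size, latent_dim = 16, 3, 8
z_mean = tf.random.normal((batch_size, latent_dim))
z_log_var = tf.random.normal((batch_size, latent_dim)) * 0.1 - 1.
# sample z and broadcast to shape (batch_size, set_size, latent_dim)
eps = tf.random.normal((batch_size, set_size, latent_dim))
z = z_mean[:, None, :] + tf.exp(0.5 * z_log_var)[:, None, :] * eps

tc_fcn = losses.TotalCorrelation(z, z_mean, z_log_var)
tc_loss = tc_fcn(None, None)
print(tc_loss.shape)  # (16, 3, 1)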

VAE.losses.TotalCorrelationBetween

TotalCorrelationBetween(z, z_mean, z_log_var, repeat_samples=1)

Total correlation between repeated samples.

Returns the total correlation loss between repeated samples. This is the same as the total correlation, but restricted to the repetitions of the same input sample. This means that z_mean and z_log_var are split into segments of length repeat_samples along the first axis and the total correlation is computed within each segment.

This version of the total correlation is useful for the case where the model is trained with repeated input samples. It helps increase the diversity of the latent distribution between repeated samples.

Parameters:

  • z (Tensor) –

    Sample from latent space of shape (batch_size, set_size, latent_dim).

  • z_mean (Tensor) –

    Mean of latent space of shape (batch_size, latent_dim).

  • z_log_var (Tensor) –

    Log of variance of latent space of shape (batch_size, latent_dim).

  • repeat_samples (int, default: 1 ) –

    Number of repeated samples.

Returns:

  • callable

    Loss function that returns a tensor of shape (batch_size, set_size, 1).

Source code in VAE/losses.py
def TotalCorrelationBetween(z: tf.Tensor, z_mean: tf.Tensor, z_log_var: tf.Tensor, repeat_samples: int = 1) -> callable:
    """Total correlation between repeated samples.

    Returns the total correlation loss between repeated samples. This is the same as the total correlation but
    restricted to the same repeated samples. This means that `z_mean` and `z_log_var` are split into segments of length
    `repeat_samples` along the first axis and the total correlation is computed within each segment.

    This version of the total correlation is useful for the case where the model is trained with repeated input samples.
    It helps increase the diversity of the latent distribution between repeated samples.

    Parameters:
        z:
            Sample from latent space of shape `(batch_size, set_size, latent_dim)`.
        z_mean:
            Mean of latent space of shape `(batch_size, latent_dim)`.
        z_log_var:
            Log of variance of latent space of shape `(batch_size, latent_dim)`.
        repeat_samples:
            Number of repeated samples.

    Returns:
        Loss function that returns a tensor of shape `(batch_size, set_size, 1)`.

    """
    # reshape z_mean and z_log_var to shape (batch_size // repeat_samples, repeat_samples, 1, latent_dim)
    z_mean = tf.reshape(z_mean, (-1, repeat_samples, 1, z_mean.shape[-1]))
    z_log_var = tf.reshape(z_log_var, (-1, repeat_samples, 1, z_log_var.shape[-1]))

    # reshape z to shape (batch_size // repeat_samples, repeat_samples, set_size, latent_dim)
    z = tf.reshape(z, (-1, repeat_samples, z.shape[-2], z.shape[-1]))

    def tc_between(y_true, y_pred):
        # log prob of all combinations along second axis of size repeat_samples
        # shape (batch_size // repeat_samples, repeat_samples, repeat_samples, set_size, latent_dim)
        mat_log_qz = vaemath.log_density_gaussian(z, z_mean, z_log_var, all_combinations=True, axis=1)

        # log prob of joint distribution, shape (batch_size // repeat_samples, repeat_samples, set_size, 1)
        log_qz = vaemath.reduce_logmeanexp(tf.reduce_sum(mat_log_qz, axis=-1, keepdims=True), axis=2)

        # log prob of product of marg. distribution, shape (batch_size // repeat_samples, repeat_samples, set_size, 1)
        log_prod_qz = tf.reduce_sum(vaemath.reduce_logmeanexp(mat_log_qz, axis=2), axis=-1, keepdims=True)

        # total correlation loss, shape (batch_size // repeat_samples, repeat_samples, set_size, 1)
        tc_loss = log_qz - log_prod_qz

        # reshape to shape (batch_size, set_size, 1)
        tc_loss = tf.reshape(tc_loss, (-1, tc_loss.shape[-2], 1))

        return tc_loss

    return tc_between
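
A usage sketch; as with SimilarityBetween, the repetitions of one input sample are assumed to be adjacent along the batch axis (shapes are illustrative):

import tensorflow as tf
from VAE import losses

repeat_samples, set_size, latent_dim = 5, 3, 8
batch_size = 4 * repeat_samples
z_mean = tf.random.normal((batch_size, latent_dim))
z_log_var = tf.random.normal((batch_size, latent_dim)) * 0.1 - 1.
eps = tf.random.normal((batch_size, set_size, latent_dim))
z = z_mean[:, None, :] + tf.exp(0.5 * z_log_var)[:, None, :] * eps

tc_fcn = losses.TotalCorrelationBetween(z, z_mean, z_log_var, repeat_samples=repeat_samples)
tc_loss = tc_fcn(None, None)
print(tc_loss.shape)  # (20, 3, 1)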

VAE.losses.TotalCorrelationWithin

TotalCorrelationWithin(z, z_mean, z_log_var, repeat_samples=1)

Total correlation within repeated samples.

Returns the total correlation loss within all samples of the same repetition. This is the same as the total correlation but restricted to samples of the same repetition. This means that z_mean and z_log_var are split into strided views with stride repeat_samples along the first axis and the total correlation is computed within each view.

Parameters:

  • z (Tensor) –

    Sample from latent space of shape (batch_size, set_size, latent_dim).

  • z_mean (Tensor) –

    Mean of latent space of shape (batch_size, latent_dim).

  • z_log_var (Tensor) –

    Log of variance of latent space of shape (batch_size, latent_dim).

  • repeat_samples (int, default: 1 ) –

    Number of repeated samples.

Returns:

  • callable

    Loss function that returns a tensor of shape (batch_size, set_size, 1).

Source code in VAE/losses.py
def TotalCorrelationWithin(z: tf.Tensor, z_mean: tf.Tensor, z_log_var: tf.Tensor, repeat_samples: int = 1) -> callable:
    """Total correlation within repeated samples.

    Returns the total correlation loss within all samples of the same repetition. This is the same as the total correlation
    but restricted to samples of the same repetition. This means that `z_mean` and `z_log_var` are split into strided
    views with stride `repeat_samples` along the first axis and the total correlation is computed within each view.

    Parameters:
        z:
            Sample from latent space of shape `(batch_size, set_size, latent_dim)`.
        z_mean:
            Mean of latent space of shape `(batch_size, latent_dim)`.
        z_log_var:
            Log of variance of latent space of shape `(batch_size, latent_dim)`.
        repeat_samples:
            Number of repeated samples.

    Returns:
        Loss function that returns a tensor of shape `(batch_size, set_size, 1)`.

    """
    # reshape z_mean and z_log_var to shape (batch_size // repeat_samples, repeat_samples, 1, latent_dim)
    z_mean = tf.reshape(z_mean, (-1, repeat_samples, 1, z_mean.shape[-1]))
    z_log_var = tf.reshape(z_log_var, (-1, repeat_samples, 1, z_log_var.shape[-1]))

    # reshape z to shape (batch_size // repeat_samples, repeat_samples, set_size, latent_dim)
    z = tf.reshape(z, (-1, repeat_samples, z.shape[-2], z.shape[-1]))

    def tc_within(y_true, y_pred):
        # log prob of all combinations along first axis of size batch_size // repeat_samples
        # shape (batch_size // repeat_samples, batch_size // repeat_samples, repeat_samples, set_size, latent_dim)
        mat_log_qz = vaemath.log_density_gaussian(z, z_mean, z_log_var, all_combinations=True, axis=0)

        # log prob of joint distribution, shape (batch_size // repeat_samples, repeat_samples, set_size, 1)
        log_qz = vaemath.reduce_logmeanexp(tf.reduce_sum(mat_log_qz, axis=-1, keepdims=True), axis=1)

        # log prob of product of marg. distribution, shape (batch_size // repeat_samples, repeat_samples, set_size, 1)
        log_prod_qz = tf.reduce_sum(vaemath.reduce_logmeanexp(mat_log_qz, axis=1), axis=-1, keepdims=True)

        # total correlation loss, shape (batch_size // repeat_samples, repeat_samples, set_size, 1)
        tc_loss = log_qz - log_prod_qz

        # revert shape to (batch_size, set_size, 1)
        tc_loss = tf.reshape(tc_loss, (-1, tc_loss.shape[-2], 1))

        return tc_loss

    return tc_within
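
The difference to TotalCorrelationBetween lies only in which batch entries are compared with each other. A small NumPy sketch of the two groupings, assuming the repetitions of a sample are adjacent in the batch:

import numpy as np

repeat_samples = 3
batch = np.arange(6)                        # position of each sample in the batch

# TotalCorrelationBetween: contiguous segments of length repeat_samples,
# i.e. the repetitions of one and the same input sample
print(batch.reshape(-1, repeat_samples))    # [[0 1 2]
                                            #  [3 4 5]]

# TotalCorrelationWithin: strided views with stride repeat_samples,
# i.e. samples that share the same repetition index
print(batch.reshape(-1, repeat_samples).T)  # [[0 3]
                                            #  [1 4]
                                            #  [2 5]]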

VAE.losses.VAEloss

VAEloss(z, z_mean, z_log_var, beta=1.0, size=1, gamma=0.0, gamma_between=0.0, gamma_within=0.0, delta=0.0, delta_between=0.0, kl_threshold=None, free_bits=None, repeat_samples=1, taper=None)

Variational auto-encoder loss function.

Loss function of the variational auto-encoder, for use in models.VAE. The input to the loss function has shape (batch_size, set_size, output_length, channels). The output of the loss function has shape (batch_size, set_size, 1). The sample weights from the generator must have shape (batch_size, set_size). This makes the sample weights sample-dependent; see also sample_weight_mode='temporal' when compiling the model.

Parameters:

  • z (Tensor) –

    Sample from latent space of shape (batch_size, set_size, latent_dim).

  • z_mean (Tensor) –

    Mean of latent space of shape (batch_size, latent_dim).

  • z_log_var (Tensor) –

    Log of variance of latent space of shape (batch_size, latent_dim).

  • size (int, default: 1 ) –

    Size of decoder output, i.e. total number of elements.

  • beta (Union[float, Tensor], default: 1.0 ) –

    Loss weight of the KL divergence. If beta is a float, the loss weight is constant. If beta is a tensor, it should have shape (batch_size, 1).

  • gamma (float, default: 0.0 ) –

    Scale of total correlation loss. See losses.TotalCorrelation.

  • gamma_between (float, default: 0.0 ) –

    Scale of total correlation loss between repeated samples. See losses.TotalCorrelationBetween.

  • gamma_within (float, default: 0.0 ) –

    Scale of total correlation loss within repeated samples. See losses.TotalCorrelationWithin.

  • delta (float, default: 0.0 ) –

    Scale of similarity loss. See losses.Similarity.

  • delta_between (float, default: 0.0 ) –

    Scale of similarity loss between repeated samples. See losses.SimilarityBetween.

  • kl_threshold (float, default: None ) –

    Lower bound for the KL divergence. See losses.KLDivergence.

  • free_bits (float, default: None ) –

    Number of bits to keep free for the KL divergence per latent dimension. See losses.KLDivergence.

  • repeat_samples (int, default: 1 ) –

    Number of repetitions of input samples present in the batch.

  • taper (array_like, default: None ) –

    Numpy array of length output_length to taper the squared error. See losses.SquaredError.

Returns:

  • callable

    Loss function that returns a tensor of shape (batch_size, set_size, 1).

Source code in VAE/losses.py
def VAEloss(z: tf.Tensor,
            z_mean: tf.Tensor,
            z_log_var: tf.Tensor,
            beta: Union[float, tf.Tensor] = 1.,
            size: int = 1,
            gamma: float = 0.,
            gamma_between: float = 0.,
            gamma_within: float = 0.,
            delta: float = 0.,
            delta_between: float = 0.,
            kl_threshold: float = None,
            free_bits: float = None,
            repeat_samples: int = 1,
            taper=None) -> callable:
    """Variational auto-encoder loss function.

    Loss function of variational auto-encoder, for use in :func:`models.VAE`. The input to the loss function has shape
    `(batch_size, set_size, output_length, channels)`. The output of the loss function has shape `(batch_size, set_size,
    1)`. The sample weights from the generator must have shape `(batch_size, set_size)`. This will make the sample
    weights sample-dependent; see also `sample_weight_mode='temporal'` in model compile.

    Parameters:
        z:
            Sample from latent space of shape `(batch_size, set_size, latent_dim)`.
        z_mean:
            Mean of latent space of shape `(batch_size, latent_dim)`.
        z_log_var:
            Log of variance of latent space of shape `(batch_size, latent_dim)`.
        size:
            Size of decoder output, i.e. total number of elements.
        beta:
            Loss weight of the KL divergence. If `beta` is a float, the loss weight is constant. If `beta` is a
            tensor, it should have shape `(batch_size, 1)`.
        gamma:
            Scale of total correlation loss. See :func:`losses.TotalCorrelation`.
        gamma_between:
            Scale of total correlation loss between repeated samples. See :func:`losses.TotalCorrelationBetween`.
        gamma_within:
            Scale of total correlation loss within repeated samples. See :func:`losses.TotalCorrelationWithin`.
        delta:
            Scale of similarity loss. See :func:`losses.Similarity`.
        delta_between:
            Scale of similarity loss between repeated samples. See :func:`losses.SimilarityBetween`.
        kl_threshold:
            Lower bound for the KL divergence. See :func:`losses.KLDivergence`.
        free_bits:
            Number of bits to keep free for the KL divergence per latent dimension. See :func:`losses.KLDivergence`.
        repeat_samples:
            Number of repetitions of input samples present in the batch.
        taper (array_like):
            Numpy array of length `output_length` to taper the squared error. See :func:`losses.SquaredError`.

    Returns:
        Loss function that returns a tensor of shape `(batch_size, set_size, 1)`.

    """
    if isinstance(beta, tf.Tensor):
        # add singleton dimension to beta to shape (batch_size, 1, 1)
        beta = tf.expand_dims(beta, axis=-1)

    def vae_loss(y_true, y_pred):
        squared_error_fcn = SquaredError(size=size, taper=taper)
        squared_error = squared_error_fcn(y_true, y_pred)

        kl_loss_fcn = KLDivergence(z_mean, z_log_var, kl_threshold=kl_threshold, free_bits=free_bits)
        entropy = kl_loss_fcn(y_true, y_pred)

        if gamma:
            tc_loss_fcn = TotalCorrelation(z, z_mean, z_log_var)
            entropy += gamma * tc_loss_fcn(y_true, y_pred)

        if gamma_between:
            tc_between_loss_fcn = TotalCorrelationBetween(z, z_mean, z_log_var, repeat_samples=repeat_samples)
            entropy += gamma_between * tc_between_loss_fcn(y_true, y_pred)

        if gamma_within:
            tc_within_loss_fcn = TotalCorrelationWithin(z, z_mean, z_log_var, repeat_samples=repeat_samples)
            entropy += gamma_within * tc_within_loss_fcn(y_true, y_pred)

        if delta:
            sim_loss_fcn = Similarity()
            entropy += delta * sim_loss_fcn(y_true, y_pred)

        if delta_between:
            sim_between_loss_fcn = SimilarityBetween(repeat_samples=repeat_samples)
            entropy += delta_between * sim_between_loss_fcn(y_true, y_pred)

        return squared_error + beta * entropy

    return vae_loss
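
A self-contained sketch with random tensors (shapes and loss weights are illustrative assumptions); in training, the returned closure is typically passed as the loss when compiling the model together with sample_weight_mode='temporal', as described above:

import tensorflow as tf
from VAE import losses

batch_size, set_size, output_length, channels, latent_dim = 8, 2, 16, 3, 6

z_mean = tf.random.normal((batch_size, latent_dim))
z_log_var = tf.random.normal((batch_size, latent_dim)) * 0.1 - 1.
eps = tf.random.normal((batch_size, set_size, latent_dim))
z = z_mean[:, None, :] + tf.exp(0.5 * z_log_var)[:, None, :] * eps

y_true = tf.random.normal((batch_size, set_size, output_length, channels))
y_pred = tf.random.normal((batch_size, set_size, output_length, channels))

loss_fcn = losses.VAEloss(z, z_mean, z_log_var,
                          beta=0.5,
                          size=set_size * output_length * channels,
                          gamma=1.,
                          free_bits=0.1)
loss = loss_fcn(y_true, y_pred)
print(loss.shape)  # (8, 2, 1)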

VAE.losses.VAEploss

VAEploss(beta=1.0, delta=0.0, delta_between=0.0, repeat_samples=1, size=1, taper=None)

VAE prediction loss function.

Parameters:

  • beta (Union[float, Tensor], default: 1.0 ) –

    Loss weight of the entropy term, which in this prediction loss consists of the similarity losses only (there is no KL divergence here). If beta is a float, the loss weight is constant. If beta is a tensor, it should have shape (batch_size, 1).

  • delta (float, default: 0.0 ) –

    Scale of similarity loss. See losses.Similarity.

  • delta_between (float, default: 0.0 ) –

    Scale of similarity loss between repeated samples. See losses.SimilarityBetween.

  • size (int, default: 1 ) –

    Size of prediction output, i.e. total number of elements.

  • repeat_samples (int, default: 1 ) –

    Number of repetitions of input samples present in the batch.

  • taper (array_like, default: None ) –

    Array of length output_length to taper the squared error. See losses.SquaredError.

Returns:

  • callable

    Loss function that returns a tensor of shape (batch_size, set_size, 1).

Source code in VAE/losses.py
def VAEploss(beta: Union[float, tf.Tensor] = 1.,
             delta: float = 0.,
             delta_between: float = 0.,
             repeat_samples: int = 1,
             size: int = 1,
             taper=None) -> callable:
    """VAE prediction loss function.

    Parameters:
        beta:
            Loss weight of the KL divergence. If `beta` is a float, the loss weight is constant. If `beta` is a
            tensor, it should have shape `(batch_size, 1)`.
        delta:
            Scale of similarity loss. See :func:`losses.Similarity`.
        delta_between:
            Scale of similarity loss between repeated samples. See :func:`losses.SimilarityBetween`.
        size:
            Size of prediction output, i.e. total number of elements.
        repeat_samples:
            Number of repetitions of input samples present in the batch.
        taper (array_like):
            Array of length `output_length` to taper the squared error. See :func:`losses.SquaredError`.

    Returns:
        Loss function that returns a tensor of shape `(batch_size, set_size, 1)`.

    """
    if isinstance(beta, tf.Tensor):
        # add singleton dimension to beta to shape (batch_size, 1, 1)
        beta = tf.expand_dims(beta, axis=-1)

    def vaep_loss(y_true, y_pred):
        squared_error_fcn = SquaredError(size=size, taper=taper)
        squared_error = squared_error_fcn(y_true, y_pred)

        entropy = tf.zeros_like(squared_error)

        if delta:
            sim_loss_fcn = Similarity()
            entropy += delta * sim_loss_fcn(y_true, y_pred)

        if delta_between:
            sim_between_loss_fcn = SimilarityBetween(repeat_samples=repeat_samples)
            entropy += delta_between * sim_between_loss_fcn(y_true, y_pred)

        return squared_error + beta * entropy

    return vaep_loss
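
A short sketch with random tensors (shapes and loss weights are illustrative assumptions):

import tensorflow as tf
from VAE import losses

batch_size, set_size, output_length, channels = 8, 2, 16, 3
y_true = tf.random.normal((batch_size, set_size, output_length, channels))
y_pred = tf.random.normal((batch_size, set_size, output_length, channels))

loss_fcn = losses.VAEploss(beta=0.5, delta=1., size=set_size * output_length * channels)
loss = loss_fcn(y_true, y_pred)
print(loss.shape)  # (8, 2, 1)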

VAE.losses.example_total_correlation_losses

example_total_correlation_losses()

Example of total correlation loss functions.

Source code in VAE/losses.py
def example_total_correlation_losses():
    """Example of total correlation loss functions."""
    batch_size = 32
    repeat_samples = 20
    shape = (batch_size * repeat_samples, 8)
    set_size = 7

    z_mean = tf.random.normal(shape)
    z_log_var = tf.random.normal(shape) * 0.1 - 1.
    z = z_mean + tf.exp(z_log_var * 0.5) * tf.random.normal(shape)
    z = tf.expand_dims(z, axis=1)
    z = tf.repeat(z, repeats=set_size, axis=1)

    fcns = {
        'TC loss': TotalCorrelation(z, z_mean, z_log_var),
        'TC loss between': TotalCorrelationBetween(z, z_mean, z_log_var, repeat_samples=repeat_samples),
        'TC loss within': TotalCorrelationWithin(z, z_mean, z_log_var, repeat_samples=repeat_samples),
    }

    print(f'{"Batch size":<20} {batch_size} * {repeat_samples} = {batch_size * repeat_samples}')

    for name, fcn in fcns.items():
        tc_loss = fcn(None, None)
        tc_mean = tf.reduce_mean(tc_loss)
        tc_std = tf.math.reduce_std(tc_loss)
        print(f'{name:<20} mean={tc_mean:.2f}  std={tc_std:.2f}  shape={tc_loss.shape}')

VAE.losses.example_similarity_losses

example_similarity_losses()

Example of similarity loss functions.

Source code in VAE/losses.py
def example_similarity_losses():
    """Example of similarity loss functions."""
    batch_size = 32
    repeat_samples = 5
    shape = (batch_size * repeat_samples, 1, 160, 3)

    inputs = tf.random.normal(shape)

    fcns = {
        'Sim loss': Similarity(),
        'Sim loss between': _SimilarityBetween(repeat_samples=repeat_samples),
        'Sim loss between (fast)': SimilarityBetween(repeat_samples=repeat_samples),
    }

    print(f'{"Batch size":<25} {batch_size} * {repeat_samples} = {batch_size * repeat_samples}')

    losses = []
    for name, fcn in fcns.items():
        loss = fcn(None, inputs)
        losses.append(loss)
        mean_loss = tf.reduce_mean(loss)
        std_loss = tf.math.reduce_std(loss)
        print(f'{name:<25} mean={mean_loss:.2f}  std={std_loss:.2f}  shape={loss.shape}')

    return losses