Losses
VAE.losses
Collection of loss functions.
VAE.losses.KLDivergence
KLDivergence(z_mean, z_log_var, kl_threshold=None, free_bits=None)
Kullback-Leibler divergence.
This is the KL divergence between N(z_mean, z_log_var) and the prior N(0, 1).
Parameters:
- z_mean (Tensor) – Tensor of shape (batch_size, latent_dim) specifying the mean.
- z_log_var (Tensor) – Tensor of shape (batch_size, latent_dim) specifying the log of the variance.
- kl_threshold (float, default: None) – Lower bound for the KL divergence.
- free_bits (float, default: None) – Number of bits to keep free for the KL divergence per latent dimension; cf. Appendix C8 in [1].
Returns:
- callable – Loss function that returns a tensor of shape (batch_size, 1, 1).
References
[1] Kingma et al. (2016): Improved Variational Inference with Inverse Autoregressive Flow. NIPS 2016.
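Example
A minimal usage sketch (not taken from the library's source). It assumes the returned callable follows the Keras loss signature (y_true, y_pred) and is built from the closed-over tensors z_mean and z_log_var, so its arguments only determine the dummy shapes.

```python
import tensorflow as tf
from VAE import losses

batch_size, latent_dim = 32, 16
z_mean = tf.random.normal((batch_size, latent_dim))     # mean, shape (batch_size, latent_dim)
z_log_var = tf.random.normal((batch_size, latent_dim))  # log of variance, same shape

# free_bits=0.5 keeps half a bit per latent dimension free, cf. Appendix C8 in [1].
kl_loss = losses.KLDivergence(z_mean, z_log_var, free_bits=0.5)

# Assumed Keras-style call with dummy arguments; the result should have shape (batch_size, 1, 1).
dummy = tf.zeros((batch_size, 1, 1))
value = kl_loss(dummy, dummy)
```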
Source code in VAE/losses.py
VAE.losses.SquaredError
SquaredError(size=1, taper=None)
Squared error loss.
This is the reconstruction loss of the model, without the KL divergence.
Parameters:
- size (int, default: 1) – Size of the model output of shape (set_size, output_length, channels).
- taper (array_like, default: None) – Array of length output_length to taper the squared error.
Returns:
- callable – Loss function that returns a tensor of shape (batch_size, set_size, 1).
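Example
A hedged sketch of building the reconstruction loss with a taper. The Hann window and the Keras-style (y_true, y_pred) call are assumptions, not taken from the source.

```python
import numpy as np
import tensorflow as tf
from VAE import losses

set_size, output_length, channels = 4, 24, 3
size = set_size * output_length * channels      # total number of output elements

# Assumed choice of taper: a Hann window de-emphasising the edges of each output window.
taper = np.hanning(output_length)

mse_loss = losses.SquaredError(size=size, taper=taper)

# Assumed Keras-style call on outputs of shape (batch, set_size, output_length, channels).
y_true = tf.zeros((8, set_size, output_length, channels))
y_pred = tf.zeros((8, set_size, output_length, channels))
value = mse_loss(y_true, y_pred)                # expected shape (8, set_size, 1)
```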
Source code in VAE/losses.py
VAE.losses.Similarity
Similarity(temperature=1.0)
Similarity loss.
This loss flattens all but the leading batch dimension of y_pred to the shape (batch_size, -1). The similarity is then calculated on the reshaped input.
Parameters:
- temperature (float, default: 1.0) – Temperature for the softmax.
Returns: Loss function.
Source code in VAE/losses.py
VAE.losses.SimilarityBetween
SimilarityBetween(repeat_samples=1, temperature=1.0)
Similarity between repeated samples (fast implementation).
This function returns the similarity between repeated samples. The input y_pred is first reshaped to the shape (batch_size // repeat_samples, repeat_samples, ...) and the similarity is calculated for each slice along the first dimension.
Note: This is an implementation with einsum that avoids calling math.similarity with tf.map_fn.
Parameters:
- repeat_samples (int, default: 1) – Number of repeated samples.
- temperature (float, default: 1.0) – Temperature for the softmax.
Returns: Loss function.
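Example
A hedged sketch illustrating the expected batch layout: the batch size is a multiple of repeat_samples, with the repetitions of each input sample stacked consecutively along the batch axis. The Keras-style (y_true, y_pred) call of the returned function is an assumption.

```python
import tensorflow as tf
from VAE import losses

repeat_samples = 3
sim_loss = losses.SimilarityBetween(repeat_samples=repeat_samples, temperature=0.5)

# Hypothetical model output: 4 distinct samples, each repeated 3 times -> batch_size = 12.
y_pred = tf.random.normal((12, 4, 16))           # (batch_size, set_size, latent_dim)
value = sim_loss(tf.zeros_like(y_pred), y_pred)  # y_true is assumed to be ignored
```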
Source code in VAE/losses.py
VAE.losses.TotalCorrelation
TotalCorrelation(z, z_mean, z_log_var)
Total correlation.
The total correlation is already part of the KL divergence, see the KL decomposition in [1]. This function only returns the batch-wise sampled total correlation loss that will be added on top of the KL divergence. The total correlation is computed separately for each step along the axis of length set_size of the input tensor z.
Parameters:
- z (Tensor) – Sample from latent space of shape (batch_size, set_size, latent_dim).
- z_mean (Tensor) – Mean of latent space of shape (batch_size, latent_dim).
- z_log_var (Tensor) – Log of variance of latent space of shape (batch_size, latent_dim).
Returns:
- callable – Loss function that returns a tensor of shape (batch_size, set_size, 1).
References
[1] Chen et al. (2018): Isolating Sources of Disentanglement in Variational Autoencoders. NeurIPS 2018.
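Example
A minimal sketch of constructing the loss from the latent tensors. As with KLDivergence, the Keras-style (y_true, y_pred) call of the returned function is an assumption.

```python
import tensorflow as tf
from VAE import losses

batch_size, set_size, latent_dim = 32, 4, 16
z = tf.random.normal((batch_size, set_size, latent_dim))   # sample from latent space
z_mean = tf.random.normal((batch_size, latent_dim))
z_log_var = tf.random.normal((batch_size, latent_dim))

tc_loss = losses.TotalCorrelation(z, z_mean, z_log_var)

# Assumed call with dummy arguments; the result should have shape (batch_size, set_size, 1).
dummy = tf.zeros((batch_size, set_size, 1))
value = tc_loss(dummy, dummy)
```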
Source code in VAE/losses.py
VAE.losses.TotalCorrelationBetween
TotalCorrelationBetween(z, z_mean, z_log_var, repeat_samples=1)
Total correlation between repeated samples.
Returns the total correlation loss between repeated samples. This is the same as the total correlation but restricted to the same repeated samples. This means that z_mean and z_log_var are split into segments of length repeat_samples along the first axis and the total correlation is computed within each segment.
This version of the total correlation is useful for the case where the model is trained with repeated input samples. It helps increase the diversity of the latent distribution between repeated samples.
Parameters:
- z (Tensor) – Sample from latent space of shape (batch_size, set_size, latent_dim).
- z_mean (Tensor) – Mean of latent space of shape (batch_size, latent_dim).
- z_log_var (Tensor) – Log of variance of latent space of shape (batch_size, latent_dim).
- repeat_samples (int, default: 1) – Number of repeated samples.
Returns:
- callable – Loss function that returns a tensor of shape (batch_size, set_size, 1).
Source code in VAE/losses.py
VAE.losses.TotalCorrelationWithin
TotalCorrelationWithin(z, z_mean, z_log_var, repeat_samples=1)
Total correlation within repeated samples.
Returns the total correlation loss within all samples of the same repetition. This is the same as the total correlation but restricted to samples of the same repetition. This means that z_mean and z_log_var are split into strided views with stride repeat_samples along the first axis and the total correlation is computed within each view.
Parameters:
- z (Tensor) – Sample from latent space of shape (batch_size, set_size, latent_dim).
- z_mean (Tensor) – Mean of latent space of shape (batch_size, latent_dim).
- z_log_var (Tensor) – Log of variance of latent space of shape (batch_size, latent_dim).
- repeat_samples (int, default: 1) – Number of repeated samples.
Returns:
- callable – Loss function that returns a tensor of shape (batch_size, set_size, 1).
Source code in VAE/losses.py
VAE.losses.VAEloss
VAEloss(z, z_mean, z_log_var, beta=1.0, size=1, gamma=0.0, gamma_between=0.0, gamma_within=0.0, delta=0.0, delta_between=0.0, kl_threshold=None, free_bits=None, repeat_samples=1, taper=None)
Variational auto-encoder loss function.
Loss function of the variational auto-encoder, for use in models.VAE. The input to the loss function has shape (batch_size, set_size, output_length, channels). The output of the loss function has shape (batch_size, set_size, 1). The sample weights from the generator must have shape (batch_size, set_size). This makes the sample weights sample-dependent; see also sample_weight_mode='temporal' in model compile.
Parameters:
- z (Tensor) – Sample from latent space of shape (batch_size, set_size, latent_dim).
- z_mean (Tensor) – Mean of latent space of shape (batch_size, latent_dim).
- z_log_var (Tensor) – Log of variance of latent space of shape (batch_size, latent_dim).
- size (int, default: 1) – Size of decoder output, i.e. total number of elements.
- beta (Union[float, Tensor], default: 1.0) – Loss weight of the KL divergence. If beta is a float, the loss weight is constant. If beta is a tensor, it should have shape (batch_size, 1).
- gamma (float, default: 0.0) – Scale of total correlation loss. See losses.TotalCorrelation.
- gamma_between (float, default: 0.0) – Scale of total correlation loss between repeated samples. See losses.TotalCorrelationBetween.
- gamma_within (float, default: 0.0) – Scale of total correlation loss within repeated samples. See losses.TotalCorrelationWithin.
- delta (float, default: 0.0) – Scale of similarity loss. See losses.Similarity.
- delta_between (float, default: 0.0) – Scale of similarity loss between repeated samples. See losses.SimilarityBetween.
- kl_threshold (float, default: None) – Lower bound for the KL divergence. See losses.KLDivergence.
- free_bits (float, default: None) – Number of bits to keep free for the KL divergence per latent dimension. See losses.KLDivergence.
- repeat_samples (int, default: 1) – Number of repetitions of input samples present in the batch.
- taper (array_like, default: None) – Numpy array of length output_length to taper the squared error. See losses.SquaredError.
Returns:
- callable – Loss function that returns a tensor of shape (batch_size, set_size, 1).
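Example
A hedged sketch of assembling the combined loss and compiling a model with it. The tensors here are random placeholders; in practice z, z_mean and z_log_var are the symbolic latent tensors of the encoder, and model stands for the Keras model built with models.VAE (the compile call is shown as a comment and follows the sample-weight note above).

```python
import tensorflow as tf
from VAE import losses

batch_size, set_size, latent_dim = 30, 4, 16
output_length, channels = 24, 3

# Placeholders; in practice these are the encoder's symbolic latent tensors.
z = tf.random.normal((batch_size, set_size, latent_dim))
z_mean = tf.random.normal((batch_size, latent_dim))
z_log_var = tf.random.normal((batch_size, latent_dim))

loss = losses.VAEloss(
    z, z_mean, z_log_var,
    beta=1.0,                                  # weight of the KL divergence
    gamma=0.5,                                 # total correlation on top of the KL divergence
    delta_between=1.0,                         # similarity between repeated samples
    repeat_samples=3,                          # batch contains 3 repetitions of each sample
    size=set_size * output_length * channels,  # total number of decoder output elements
)

# With sample weights of shape (batch_size, set_size), compile in 'temporal' mode:
# model.compile(optimizer="adam", loss=loss, sample_weight_mode="temporal")
```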
Source code in VAE/losses.py
VAE.losses.VAEploss
VAEploss(beta=1.0, delta=0.0, delta_between=0.0, repeat_samples=1, size=1, taper=None)
VAE prediction loss function.
Parameters:
- beta (Union[float, Tensor], default: 1.0) – Loss weight of the KL divergence. If beta is a float, the loss weight is constant. If beta is a tensor, it should have shape (batch_size, 1).
- delta (float, default: 0.0) – Scale of similarity loss. See losses.Similarity.
- delta_between (float, default: 0.0) – Scale of similarity loss between repeated samples. See losses.SimilarityBetween.
- size (int, default: 1) – Size of prediction output, i.e. total number of elements.
- repeat_samples (int, default: 1) – Number of repetitions of input samples present in the batch.
- taper (array_like, default: None) – Array of length output_length to taper the squared error. See losses.SquaredError.
Returns:
- callable – Loss function that returns a tensor of shape (batch_size, set_size, 1).
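Example
A short sketch of constructing the prediction loss; the shown shapes and the pairing with VAEloss on a separate prediction output are assumptions, not taken from the source.

```python
from VAE import losses

# Hypothetical prediction output of shape (set_size, output_length, channels) = (1, 12, 3).
pred_loss = losses.VAEploss(
    beta=1.0,
    delta_between=1.0,
    repeat_samples=3,
    size=1 * 12 * 3,   # total number of prediction output elements
)
```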
Source code in VAE/losses.py
VAE.losses.example_total_correlation_losses
example_total_correlation_losses()
Example of total correlation loss functions.
Source code in VAE/losses.py
VAE.losses.example_similarity_losses
example_similarity_losses()
Example of similarity loss functions.
Source code in VAE/losses.py