
Generators

VAE.generators

Collection of generators for VAE model training.

Generator classes for data given as NumPy array(s).

VAE.generators.FitGenerator

FitGenerator(datasets, input_length, batch_size=32, beta_scheduler=None, condition=None, ensemble_size=None, ensemble_type='random', ensemble_index=None, ensemble_range=None, ensemble_replace=False, ensemble_sync=False, filter_length=None, initial_epoch=0, input_channels=None, latitude=0, longitude=None, prediction_channels=None, prediction_length=None, repeat_samples=1, sample_weights=None, shuffle=True, sph_degree=None, strides=1, time=None, tp_period=None, dtype='float32', **kwargs)

Bases: Sequence

Generator class for model training.

Given a NumPy array of shape (set_size, data_length, channels), the generator prepares the inputs and targets for model training with keras.Model.fit_generator().

Parameters:

  • datasets (Union[ndarray, list[ndarray]]) –

    Dataset used for training. The dataset can be either a single NumPy array of shape (set_size, data_length, channels) or a list of NumPy arrays. In the case of a list, set_size and channels must be the same for all arrays, while data_length can vary. Missing (non-finite) values will be excluded from the samples.

  • input_length (int) –

    Length of input to the encoder.

  • batch_size (int, default: 32 ) –

    Batch size. Note that the effective batch size is batch_size * repeat_samples.

  • beta_scheduler (default: None ) –

    Instance of BetaScheduler that returns the beta parameter of the KL loss for each epoch.

  • condition (Union[ndarray, list[ndarray], dict], default: None ) –

    Additional data used as condition. Each NumPy array must have length data_length, matching the corresponding array in datasets. If a list is provided, its length must match the length of datasets. If a dict is provided, it must contain the key encoder and may contain the key decoder, which allows passing different conditions to the encoder and decoder as the corresponding dict values. If no decoder condition is given, the encoder condition is also used for the decoder.

  • ensemble_size (int, default: None ) –

    Size for the one-hot encoded ensemble condition.

  • ensemble_type (str, default: 'random' ) –

    Whether to use the dataset index ('index') or a random ensemble condition ('random' or 'random_full'). If 'index', the ensemble condition corresponds to the dataset index. If 'random' or 'random_full', the ensemble condition is drawn uniformly from ensemble_range. With 'random', the same draws are used for all samples in a batch; with 'random_full', they are drawn independently for each sample in the batch. Defaults to 'random'.

  • ensemble_index (int, default: None ) –

    Array of indices used as ensemble condition if ensemble_type is 'index'. Entry i is the ensemble index assigned to dataset i, so the array must have the same length as datasets and its values must lie in [0, ensemble_size). Defaults to None, meaning the dataset index itself is used.

  • ensemble_range (tuple[int, int], default: None ) –

    Range of the random ensemble condition, interpreted as a half-open range [start, stop) (it is passed to np.arange). Must lie within [0, ensemble_size). Defaults to None, in which case it is set to (0, ensemble_size).

  • ensemble_replace (bool, default: False ) –

    Whether to sample the random ensemble condition with replacement if repeat_samples > 1. Defaults to False.

  • ensemble_sync (bool, default: False ) –

    Synchronize random ensemble conditions between encoder and decoder. If True, the random ensemble conditions of the encoder and decoder are the same. Defaults to False, i.e. the random ensemble conditions of the encoder and decoder are independent random draws. Note that the ensemble conditions of the decoder and prediction are always the same.

  • filter_length (Union[int, tuple[int, int]], default: None ) –

    Length of the temporal filter for the inputs and targets. A centered moving average filter of length 2 * filter_length + 1 is applied to the inputs and targets. If a tuple of two ints is given, the first int is the length of the filter for the input to the encoder and the target to the decoder. The second int is the length of the filter for the target to the prediction. Defaults to None, i.e. no filter.

  • initial_epoch (int, default: 0 ) –

    Initial epoch at which the generator will start. This will affect the beta parameter.

  • input_channels (list[int], default: None ) –

    Range of channels used as input. The items in the list are passed to slice() as start, stop and step. Defaults to None, meaning all channels are used.

  • latitude (ndarray, default: 0 ) –

    Latitude in degrees, used if spherical harmonics are used for the spatial condition.

  • longitude (ndarray, default: None ) –

    Longitude of the data in degrees, used if spherical harmonics are used for the spatial condition. The length of longitude must be equal to set_size. Defaults to None, in which case it is set to np.arange(0, 360, 360 / set_size).

  • prediction_channels (list[int], default: None ) –

    Range of channels used for prediction. The items in the list are passed to slice() as start, stop and step. Defaults to None, meaning all channels are used.

  • prediction_length (int, default: None ) –

    Length of prediction. Defaults to None, meaning no prediction.

  • repeat_samples (int, default: 1 ) –

    Number of times each sample is repeated in the batch, which multiplies the effective batch size. This option is useful in combination with the ensemble condition, where the same sample is then presented multiple times with different random draws of the ensemble condition. Defaults to 1.

  • sample_weights (ndarray, default: None ) –

    Sample weights of shape (nr_datasets, set_size). Defaults to None.

  • shuffle (bool, default: True ) –

    Whether to shuffle the order of the samples, at initialization and after each epoch.

  • sph_degree (int, default: None ) –

    Maximum degree of the spherical harmonics if they are used for the spatial condition.

  • strides (int, default: 1 ) –

    Stride between consecutive samples along the second dimension of the data (of size data_length).

  • time (Union[ndarray, list[ndarray]], default: None ) –

    Time of the data if the time-periodic harmonics are used for temporal condition. Must be of length data_length. If a list is provided, the length of the list must match the length of datasets.

  • tp_period (float, default: None ) –

    Maximal period for the temporal harmonics. See get_tp_harmonics.

  • dtype (str, default: 'float32' ) –

    Dtype of the data that will be returned.
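The filter_length option describes a centered moving average of width 2 * filter_length + 1 along the time axis. The NumPy sketch below shows that operation on an array of shape (set_size, data_length, channels). It is an illustration only: the helper name centered_moving_average is hypothetical, and dropping the incomplete edge windows (mode='valid') is an assumption, since this page does not say how the generator treats the boundaries or non-finite values.

```python
import numpy as np


def centered_moving_average(data: np.ndarray, filter_length) -> np.ndarray:
    """Centered moving average of width 2 * filter_length + 1 along axis 1.

    Hypothetical helper. `data` has shape (set_size, data_length, channels);
    edge positions where the window is incomplete are dropped (mode='valid').
    """
    if not filter_length:
        return data  # None or 0: no filtering
    width = 2 * filter_length + 1
    kernel = np.full(width, 1.0 / width)
    # convolve every (set, channel) time series along the time axis
    return np.apply_along_axis(lambda s: np.convolve(s, kernel, mode='valid'), 1, data)


x = np.arange(10, dtype=float).reshape(1, 10, 1)
y = centered_moving_average(x, filter_length=1)
print(y.shape)  # (1, 8, 1)
# y[0, :, 0] holds the means of consecutive triples: 1, 2, ..., 8
```

With filter_length=1 the window spans three time steps, so each output value is the mean of a sample and its two neighbours.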

Source code in VAE/generators.py
def __init__(self,
             datasets: Union[np.ndarray, list[np.ndarray]],
             input_length: int,
             batch_size: int = 32,
             beta_scheduler=None,
             condition: Union[np.ndarray, list[np.ndarray], dict] = None,
             ensemble_size: int = None,
             ensemble_type: str = 'random',
             ensemble_index: int = None,
             ensemble_range: tuple[int, int] = None,
             ensemble_replace: bool = False,
             ensemble_sync: bool = False,
             filter_length: Union[int, tuple[int, int]] = None,
             initial_epoch: int = 0,
             input_channels: list[int] = None,
             latitude: np.ndarray = 0,
             longitude: np.ndarray = None,
             prediction_channels: list[int] = None,
             prediction_length: int = None,
             repeat_samples: int = 1,
             sample_weights: np.ndarray = None,
             shuffle: bool = True,
             sph_degree: int = None,
             strides: int = 1,
             time: Union[np.ndarray, list[np.ndarray]] = None,
             tp_period: float = None,
             dtype: str = 'float32',
             **kwargs):
    """Instantiate generator."""
    if not isinstance(datasets, (list, tuple)):
        datasets = [datasets]

    shapes = {dataset[:, 0, :].shape for dataset in datasets}
    if len(shapes) > 1:
        raise ValueError('all datasets must have the same set_size and number of channels')

    set_size, channels = shapes.pop()
    self.channels = channels
    self.set_size = set_size

    self.batch_size = batch_size
    self.beta_scheduler = beta_scheduler
    self.datasets = datasets
    self.dtype = dtype
    self.ensemble_replace = ensemble_replace
    self.ensemble_size = ensemble_size
    self.ensemble_sync = ensemble_sync
    self.epoch = initial_epoch
    self.filter_length = np.broadcast_to(filter_length, (2, ))
    self.input_length = input_length
    self.repeat_samples = repeat_samples
    self.shuffle = shuffle
    self.strides = strides
    self.prediction_length = prediction_length if prediction_length is not None else 0

    self.input_channels = slice(*input_channels) if input_channels is not None else slice(None)
    self.prediction_channels = slice(*prediction_channels) if prediction_channels is not None else slice(None)

    if condition is not None:
        if not isinstance(condition, dict):
            # same condition for encoder and decoder
            condition = {'encoder': condition}
        else:
            if 'encoder' not in condition.keys():
                raise KeyError('Require at least `encoder` item in `condition`.')

        # prepare condition
        for key, value in condition.items():
            if isinstance(value, (list, tuple)):
                condition.update({key: [self._prepare_condition(v) for v in value]})
            else:
                condition.update({key: self._prepare_condition(value)})

    self.condition = condition

    if ensemble_size is not None:
        if ensemble_range is None:
            ensemble_range = (0, ensemble_size)

        if repeat_samples > len(range(*ensemble_range)) and not ensemble_replace:
            error_msg = f'{repeat_samples=} must not be larger than {ensemble_range=}' \
                f' if sampling without replacement ({ensemble_replace=})'
            raise ValueError(error_msg)

        val_ensemble_type = {'index', 'random', 'random_full'}
        if ensemble_type not in val_ensemble_type:
            raise ValueError(f'{ensemble_type=} must be in {val_ensemble_type}')

    self.ensemble_index = np.array(ensemble_index) if ensemble_index is not None else None
    self.ensemble_type = ensemble_type
    self.ensemble_range = ensemble_range

    self._prepare_embedding()
    if self.shuffle:
        self._shuffle_data()

    if tp_period is not None:
        if time is None:
            raise ValueError('time must be given if tp_period is given')

        if isinstance(time, (list, tuple)):
            if len(time) == len(datasets):
                self.tp_harmonics = [self.get_tp_harmonics(t, tp_period) for t in time]
            else:
                raise ValueError('Length of `time` must match length of `datasets`')
        else:
            self.tp_harmonics = self.get_tp_harmonics(time, tp_period)

    else:
        self.tp_harmonics = None

    if sph_degree is not None:
        if longitude is None:
            longitude = np.arange(0, 360, 360 / set_size)

        self.sph_harmonics = self.get_sph_harmonics(latitude, longitude, sph_degree)
    else:
        self.sph_harmonics = None

    if sample_weights is not None:
        if len(sample_weights) != len(datasets):
            raise ValueError('`sample_weights` must have the same length as `datasets`')

    self.sample_weights = sample_weights
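In __init__, input_channels and prediction_channels are converted with slice(*...), so the list items act as start, stop and step of a Python slice. A short NumPy illustration with a made-up array and channel lists:

```python
import numpy as np

# dummy dataset of shape (set_size, data_length, channels)
data = np.zeros((2, 5, 4))

input_channels = [0, 3, 2]   # start=0, stop=3, step=2 -> channels 0 and 2
prediction_channels = None   # None -> slice(None), i.e. all channels

# same conversion as in __init__
in_slice = slice(*input_channels) if input_channels is not None else slice(None)
pred_slice = slice(*prediction_channels) if prediction_channels is not None else slice(None)

print(data[..., in_slice].shape)    # (2, 5, 2)
print(data[..., pred_slice].shape)  # (2, 5, 4)
print(tuple(range(4))[in_slice])    # (0, 2)
```

Because of the slice semantics, a single-element list such as [2] is read as the stop value and selects channels 0 and 1, not channel 2.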

VAE.generators.FitGenerator.nr_samples property

nr_samples

Return number of samples.

VAE.generators.FitGenerator.__getitem__

__getitem__(idx)

Return batch of data.

Note that the effective batch size is batch_size * repeat_samples.

Parameters:

  • idx (int) –

    Batch index.

Returns:

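  • tuple[dict, dict] or tuple[dict, dict, dict]

    Two dicts, one for the inputs and one for the targets, followed by a third dict with the sample weights if sample_weights was given.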
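Depending on whether sample_weights was passed, __getitem__ returns either (inputs, targets) or (inputs, targets, sample_weights). The summary() method handles both cases by padding the tuple with None through itertools.chain; the sketch below reuses that idiom. The helper name unpack_batch and the dummy dicts are illustrative.

```python
from itertools import chain

import numpy as np


def unpack_batch(items):
    """Return (inputs, targets, sample_weights); sample_weights is None if absent."""
    # pad with None so that both 2-tuples and 3-tuples unpack into three names
    inputs, targets, sample_weights, *_ = chain(items, [None] * 3)
    return inputs, targets, sample_weights


x = np.zeros((4, 1, 3))
two = ({'encoder_input': x}, {'decoder': x})
three = two + ({'decoder': np.ones((4, 1)), 'prediction': np.ones((4, 1))},)

print(unpack_batch(two)[2])            # None
print(sorted(unpack_batch(three)[2]))  # ['decoder', 'prediction']
```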

Source code in VAE/generators.py
def __getitem__(self, idx: int) -> tuple[dict, dict]:
    """Return batch of data.

    Note that the effective batch size is `batch_size * repeat_samples`.

    Parameters:
        idx:
            Batch index.

    Returns:
        Two dicts, one for the inputs and one for the targets, followed by a third dict with the sample
        weights if `sample_weights` was given.
    """

    if idx >= self.__len__():
        raise IndexError('batch index out of range')

    inputs = dict()
    targets = dict()

    batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size, ..., self.input_channels]
    batch_size = len(batch_x)
    batch_x = self._repeat_samples(batch_x)
    inputs['encoder_input'] = batch_x
    targets['decoder'] = batch_x

    encoder_cond = []
    decoder_cond = []

    if self.sph_harmonics is not None:
        sph_cond = self.sph_harmonics[None, :, :]
        sph_cond = np.repeat(sph_cond, batch_size * self.repeat_samples, axis=0)
        encoder_cond.append(sph_cond)
        decoder_cond.append(sph_cond)

    if self.tp_harmonics is not None:
        tp_cond = self._get_condition(self.tp_harmonics, idx)
        encoder_cond.append(tp_cond)
        decoder_cond.append(tp_cond)

    if self.condition is not None:
        ex_cond = self._get_condition(self.condition['encoder'], idx)
        if 'decoder' in self.condition.keys():
            dx_cond = self._get_condition(self.condition['decoder'], idx)
        else:
            dx_cond = ex_cond
        encoder_cond.append(ex_cond)
        decoder_cond.append(dx_cond)

    if self.ensemble_size is not None:
        if self.ensemble_sync:
            ens_cond = self.get_ensemble_condition(batch_size, idx)
            encoder_cond.append(ens_cond)
            decoder_cond.append(ens_cond)
        else:
            encoder_cond.append(self.get_ensemble_condition(batch_size, idx))
            decoder_cond.append(self.get_ensemble_condition(batch_size, idx))

    if encoder_cond:
        inputs['encoder_cond'] = np.concatenate(encoder_cond, axis=-1)
        inputs['decoder_cond'] = np.concatenate(decoder_cond, axis=-1)

    if self.y is not None:
        batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size, ..., self.prediction_channels]
        batch_y = self._repeat_samples(batch_y)
        targets['prediction'] = batch_y
        if 'decoder_cond' in inputs:
            inputs['prediction_cond'] = inputs['decoder_cond']

    if self.beta_scheduler is not None:
        inputs['beta'] = self.beta_scheduler(self.epoch, shape=(batch_size * self.repeat_samples, 1))

    if self.sample_weights is not None:
        indices = self.get_index(idx)[:, 0]
        sw = [self.sample_weights[i] for i in indices]
        sw = np.stack(sw, axis=0)
        samples_weights = {'decoder': sw, 'prediction': sw}
        return inputs, targets, samples_weights

    else:
        return inputs, targets

VAE.generators.FitGenerator.get_ensemble_condition

get_ensemble_condition(batch_size, idx=None)

Return ensemble condition for given batch.

A one-hot encoded ensemble index is returned and broadcast along the second dimension of size set_size. For ensemble_type='random', the same indices are used for all samples in the batch; for 'random_full', they are drawn independently for each sample; for 'index', they follow the dataset index of each sample.

In case of repeat_samples > 1, the actual batch size is batch_size * repeat_samples and a set of repeat_samples random indices is sampled.

To switch between sampling with and without replacement, set the ensemble_replace flag.

Parameters:

  • batch_size (int) –

    Batch size.

  • idx (int, default: None ) –

    Batch index. Required if ensemble_type='index', in which case the ensemble condition corresponding to the batch with index idx is returned.

Returns:

  • ndarray

    Array of shape (batch_size * repeat_samples, set_size, ensemble_size)
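For ensemble_type='random', the method draws repeat_samples member indices once per call and tiles the resulting one-hot block batch_size times. The standalone sketch below restates that logic so the output shape and layout can be checked. It uses numpy.random.Generator instead of the legacy np.random.choice for reproducibility, and the function name is hypothetical.

```python
import numpy as np


def random_ensemble_condition(batch_size, repeat_samples, set_size, ensemble_size,
                              ensemble_range=None, replace=False, rng=None):
    """Sketch of the 'random' ensemble condition: one set of draws shared by the batch."""
    rng = np.random.default_rng() if rng is None else rng
    lo, hi = ensemble_range if ensemble_range is not None else (0, ensemble_size)
    # draw repeat_samples member indices once for the whole batch
    idx = rng.choice(np.arange(lo, hi), size=repeat_samples, replace=replace)
    cond = np.zeros((repeat_samples, set_size, ensemble_size), dtype='float32')
    # row k gets a one at member idx[k], broadcast over all set_size positions
    cond[np.arange(repeat_samples), :, idx] = 1
    # repeat the block for each of the batch_size samples
    return np.tile(cond, (batch_size, 1, 1)), idx


cond, idx = random_ensemble_condition(batch_size=4, repeat_samples=3, set_size=5,
                                      ensemble_size=10, rng=np.random.default_rng(0))
print(cond.shape)  # (12, 5, 10)
```

Because the block is tiled, the member sequence along the batch axis is idx repeated batch_size times; with replace=False the repeat_samples draws within one block are distinct, which is why __init__ rejects repeat_samples larger than the range.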

Source code in VAE/generators.py
def get_ensemble_condition(self, batch_size: int, idx: int = None) -> np.ndarray:
    """Return ensemble condition for given batch.

    A one-hot encoded ensemble index is returned and broadcast along the second dimension of size `set_size`.
    With `ensemble_type='random'`, the same indices are used for all samples in the batch; with `'random_full'`,
    they are drawn independently for each sample; with `'index'`, they follow the dataset index of each sample.

    In case of `repeat_samples > 1`, the actual batch size is `batch_size * repeat_samples` and a set of
    `repeat_samples` random indices is sampled.

    To switch between sampling with and without replacement, set the `ensemble_replace` flag.

    Parameters:
        batch_size:
            Batch size.
        idx:
            Batch index. Required if `ensemble_type='index'`, in which case the ensemble condition corresponding
            to the batch with index `idx` is returned.

    Returns:
        Array of shape `(batch_size * repeat_samples, set_size, ensemble_size)`
    """
    if self.ensemble_type == 'random':
        ensemble_idx = np.random.choice(np.arange(*self.ensemble_range),
                                        size=self.repeat_samples,
                                        replace=self.ensemble_replace)
        condition = np.zeros((self.repeat_samples, self.set_size, self.ensemble_size), dtype=self.dtype)
        condition[np.arange(self.repeat_samples), :, ensemble_idx] = 1
        condition = np.tile(condition, (batch_size, 1, 1))

    elif self.ensemble_type == 'random_full':
        ensemble_idx = [
            np.random.choice(np.arange(*self.ensemble_range),
                             size=self.repeat_samples,
                             replace=self.ensemble_replace) for _ in range(batch_size)
        ]
        ensemble_idx = np.stack(ensemble_idx, axis=0)
        condition = np.zeros((batch_size * self.repeat_samples, self.set_size, self.ensemble_size),
                             dtype=self.dtype)
        condition[np.arange(batch_size * self.repeat_samples), :, ensemble_idx.flat] = 1

    else:  # `index`
        # check for a missing index first: `None >= int` would raise a TypeError
        if idx is None:
            raise ValueError('idx must be given')
        if idx >= self.__len__():
            raise IndexError('batch index out of range')

        ensemble_idx = self.get_index(idx)[:, 0]
        if self.ensemble_index is not None:
            ensemble_idx = self.ensemble_index[ensemble_idx]

        condition = np.zeros((len(ensemble_idx), self.set_size, self.ensemble_size), dtype=self.dtype)
        condition[np.arange(len(ensemble_idx)), :, ensemble_idx] = 1
    return condition
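For ensemble_type='index', the member index of each sample is its dataset index (column 0 of get_index), optionally remapped through ensemble_index. The standalone sketch below mirrors that branch with made-up dataset indices; the function name is hypothetical.

```python
import numpy as np


def index_ensemble_condition(dataset_idx, set_size, ensemble_size, ensemble_index=None):
    """One-hot ensemble condition from the dataset index of each sample."""
    member = np.asarray(dataset_idx)
    if ensemble_index is not None:
        # entry i of ensemble_index is the member assigned to dataset i
        member = np.asarray(ensemble_index)[member]
    cond = np.zeros((len(member), set_size, ensemble_size), dtype='float32')
    cond[np.arange(len(member)), :, member] = 1
    return cond


dataset_idx = [0, 2, 1, 2]  # e.g. column 0 of get_index(idx)
cond = index_ensemble_condition(dataset_idx, set_size=2, ensemble_size=4,
                                ensemble_index=[3, 0, 1])
print(cond[:, 0, :].argmax(axis=-1))  # [3 1 0 1]
```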

VAE.generators.FitGenerator.get_index

get_index(idx)

Return array of dataset and time index of samples in batch.

The returned array is of shape (batch_size * repeat_samples, 2) with the first column containing the dataset index and the second column the time index of the sample. The time index refers to the first sample of the target sequence for the prediction.

Parameters:

  • idx (int) –

    Batch index.

Returns:

  • ndarray

    Array of shape (batch_size * repeat_samples, 2).
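When sample_weights is given, __getitem__ uses the first column of get_index(idx), the dataset index, to pick one row of sample_weights (shape (nr_datasets, set_size)) per sample. An illustration with a made-up index array:

```python
import numpy as np

# sample weights of shape (nr_datasets, set_size) = (3, 2)
sample_weights = np.array([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])

# made-up output of get_index: columns are (dataset index, time index)
index = np.array([[2, 5], [0, 7], [2, 9]])

# equivalent to `np.stack([sample_weights[i] for i in index[:, 0]], axis=0)`
sw = sample_weights[index[:, 0]]
print(sw)
# [[3. 3.]
#  [1. 1.]
#  [3. 3.]]
```

The time index in column 1 is not used for the weights, so every sample drawn from the same dataset receives the same weight row.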

Source code in VAE/generators.py
def get_index(self, idx: int) -> np.ndarray:
    """Return array of dataset and time index of samples in batch.

    The returned array is of shape `(batch_size * repeat_samples, 2)` with the first column containing the dataset
    index and the second column the time index of the sample. The time index refers to the first sample of the
    target sequence for the prediction.

    Parameters:
        idx:
            Batch index.

    Returns:
        Array of shape `(batch_size * repeat_samples, 2)`.
    """
    if idx >= self.__len__():
        raise IndexError('batch index out of range')

    indices = self.index[idx * self.batch_size:(idx + 1) * self.batch_size, ...]
    indices = self._repeat_samples(indices)
    return indices

VAE.generators.FitGenerator.get_sph_harmonics

get_sph_harmonics(latitude, longitude, sph_degree)

Get spherical harmonics.

The returned array is of shape (set_size, 2 * sph_degree + 1) with the rows containing the spherical harmonics for the given latitude and longitude values.

Parameters:

  • latitude (float) –

    Latitude value in degrees.

  • longitude (array_like) –

    Array of longitude values in degrees of shape (set_size,).

  • sph_degree (int) –

    Maximum degree of the spherical harmonics.

Returns:

  • ndarray

    Array of shape (set_size, 2 * sph_degree + 1).
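get_sph_harmonics keeps only the sectoral harmonics (order m = n and m = -n) for degrees 0 to sph_degree, which gives 1 + 2 * sph_degree columns. Up to normalization constants and sign conventions, these columns are 1, sin^n(colatitude) cos(n * lon) and sin^n(colatitude) sin(n * lon). The NumPy-only sketch below builds that unnormalized basis to show the column layout; it is not numerically identical to the scipy.special.sph_harm values the method returns, and the function name is hypothetical.

```python
import numpy as np


def sectoral_features(latitude, longitude, sph_degree):
    """Unnormalized sectoral basis with the same column layout as get_sph_harmonics."""
    colat = np.pi / 2 - np.deg2rad(latitude)                   # colatitude in [0, pi]
    lon = np.deg2rad(np.asarray(longitude, dtype=float)) % (2 * np.pi)
    amp = np.sin(colat) * np.ones_like(lon)                    # broadcast to (set_size,)
    cols = [np.ones_like(lon)]                                 # degree 0: constant
    for n in range(1, sph_degree + 1):
        cols.append(amp**n * np.cos(n * lon))                  # ~ Re(Y_n^n)
        cols.append(amp**n * np.sin(n * lon))                  # ~ Im(Y_n^-n), up to sign
    return np.stack(cols, axis=1)


set_size = 4
longitude = np.arange(0, 360, 360 / set_size)  # default longitude used by FitGenerator
feat = sectoral_features(0.0, longitude, sph_degree=2)
print(feat.shape)  # (4, 5)
```

At the default latitude of 0 (the equator) the amplitude factor is 1, so the columns reduce to plain Fourier modes in longitude.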

Source code in VAE/generators.py
def get_sph_harmonics(self, latitude: np.ndarray, longitude: np.ndarray, sph_degree: int) -> np.ndarray:
    """Get spherical harmonics.

    The returned array is of shape `(set_size, 2 * sph_degree + 1)` with the rows containing the spherical harmonics
    for the given latitude and longitude values.

    Parameters:
        latitude:
            Latitude value in degrees.
        longitude:
            Array of longitude values in degrees of shape `(set_size,)`.
        sph_degree:
            Maximum degree of the spherical harmonics.

    Returns:
        Array of shape `(set_size, 2 * sph_degree + 1)`.
    """
    colat = np.pi / 2 - np.deg2rad(latitude)  # [0, pi]
    lon = np.deg2rad(longitude) % (2 * np.pi)  # [0, 2*pi]

    sph = []
    for n in range(sph_degree + 1):
        s = sph_harm(n, n, lon, colat)
        sph.append(np.real(s))
        if n > 0:
            s = sph_harm(-n, n, lon, colat)
            sph.append(np.imag(s))

    return np.stack(sph, axis=1).astype(self.dtype)

VAE.generators.FitGenerator.get_tp_harmonics

get_tp_harmonics(time, tp_period)

Get temporal harmonics.

Parameters:

  • time (ndarray) –

    Array of time values for which the harmonics are calculated.

  • tp_period (int) –

    Maximal period for the temporal harmonics in time units.

Returns:

  • ndarray

    Array of shape (len(time), 2 * (tp_period // 2)), i.e. tp_period columns for an even tp_period.
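The column count follows from np.fft.rfftfreq: after dropping the zero frequency, tp_period // 2 frequencies k / tp_period remain, each contributing a sine and a cosine column. The standalone copy of the method body below checks the shape for an even and an odd period, and the periodicity in time.

```python
import numpy as np


def tp_harmonics(time, tp_period):
    """Standalone copy of the computation in get_tp_harmonics."""
    f = np.fft.rfftfreq(tp_period)[None, 1:]  # frequencies k / tp_period for k >= 1
    time = np.array(time)[:, None]
    return np.concatenate([np.sin(2 * np.pi * f * time), np.cos(2 * np.pi * f * time)], axis=1)


h12 = tp_harmonics(np.arange(24), tp_period=12)
h5 = tp_harmonics(np.arange(24), tp_period=5)
print(h12.shape, h5.shape)  # (24, 12) (24, 4)
```

For integer time values and an even tp_period, the sine column at the highest frequency (0.5) is sin(pi * t) and therefore numerically zero, so it carries no information.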

Source code in VAE/generators.py
def get_tp_harmonics(self, time: np.ndarray, tp_period: int) -> np.ndarray:
    """Get temporal harmonics.

    Parameters:
        time:
            Array of time values for which the harmonics are calculated.
        tp_period:
            Maximal period for the temporal harmonics in time units.

    Returns:
        Array of shape `(len(time), 2 * (tp_period // 2))`.
    """
    f = np.fft.rfftfreq(tp_period)[None, 1:]  # omit DC
    time = np.array(time)[:, None]
    harmonics = np.concatenate([np.sin(2 * np.pi * f * time), np.cos(2 * np.pi * f * time)], axis=1)

    return harmonics

VAE.generators.FitGenerator.on_epoch_end

on_epoch_end()

Shuffle data after each epoch.

Keras calls this method at the end of each epoch. It increments the epoch counter (which is passed to beta_scheduler) and shuffles the data if shuffle=True.

Source code in VAE/generators.py
def on_epoch_end(self):
    """Shuffle data after each epoch.

    This method is called after each epoch. It increments the epoch counter and shuffles the data if `shuffle=True`.

    """
    self.epoch += 1
    if self.shuffle:
        self._shuffle_data()

VAE.generators.FitGenerator.summary

summary()

Print a summary of the generator settings and of the output shapes of the first batch.

Source code in VAE/generators.py
def summary(self):
    """Print summary."""
    total_size = sum([dataset.size for dataset in self.datasets])
    total_length = sum([dataset.shape[1] for dataset in self.datasets])
    print(f'Number of datasets : {len(self.datasets):,}')
    print(f'Total data size    : {total_size:,}')
    print(f'Total data length  : {total_length:,}')
    print(f'Strides            : {self.strides:,}')
    print(f'Number of samples  : {self.nr_samples:,}')
    print(f'Batch size         : {self.batch_size:,}')
    print(f'Number of batches  : {len(self):,}')

    act_batch_size = self.batch_size * self.repeat_samples
    print(f'Sample repetitions : {self.repeat_samples:,}')
    print(f'Actual batch size  : {self.batch_size:,} * {self.repeat_samples} = {act_batch_size:,}')

    print(f'Shuffle            : {self.shuffle}')

    nx, ny = self.filter_length
    if (nx is not None) or (ny is not None):
        print('Filter length')
        print(f'  input      : {nx}')
        print(f'  prediction : {ny}')

    if self.ensemble_size is not None:
        print('Ensemble condition')
        print(f'  size : {self.ensemble_size}')
        print(f'  type : {self.ensemble_type}')
        if self.ensemble_type in ['random', 'random_full']:
            print(f'  range   : {self.ensemble_range}')
            print(f'  sync    : {self.ensemble_sync}')
            print(f'  replace : {self.ensemble_replace}')

    channels = tuple(range(self.channels))[self.input_channels]
    if len(channels) == self.channels:
        print('Input channels     : all')
    else:
        print(f'Input channels     : {channels}')

    if self.prediction_length:
        channels = tuple(range(self.channels))[self.prediction_channels]
        if len(channels) == self.channels:
            print('Predicted channels : all')
        else:
            print(f'Predicted channels : {channels}')

    # get samples
    items = self.__getitem__(0)
    # unpack items
    inputs, targets, sample_weights, *_ = chain(items, [None] * 3)

    print('Output shapes')
    print('  inputs')
    for key, value in inputs.items():
        print(f'    {key:<16.16} : {value.shape}')

    if targets is not None:
        print('  targets')
        for key, value in targets.items():
            print(f'    {key:<16.16} : {value.shape}')

    if sample_weights is not None:
        print('  sample_weights')
        for key, value in sample_weights.items():
            print(f'    {key:<16.16} : {value.shape}')

VAE.generators.PredictGenerator

PredictGenerator(datasets, input_length, batch_size=32, beta_scheduler=None, condition=None, ensemble_size=None, ensemble_type='random', ensemble_index=None, ensemble_range=None, ensemble_replace=False, ensemble_sync=False, filter_length=None, initial_epoch=0, input_channels=None, latitude=0, longitude=None, prediction_channels=None, prediction_length=None, repeat_samples=1, sample_weights=None, shuffle=True, sph_degree=None, strides=1, time=None, tp_period=None, dtype='float32', **kwargs)

Bases: FitGenerator

Generator class for model prediction.

The generator prepares the inputs for model prediction with keras.Model.predict().

Parameters:

  • **kwargs

    See FitGenerator for parameters.

Returns:

  • Dictionary containing the inputs for the model prediction.

Source code: PredictGenerator inherits __init__ unchanged from FitGenerator (VAE/generators.py, lines 104-227); see the FitGenerator source listing above.

VAE.generators.PredictGenerator.__getitem__

__getitem__(idx)

Return inputs to the model for given batch.

Source code in VAE/generators.py
def __getitem__(self, idx: int) -> dict:
    """Return inputs to the model for given batch."""
    inputs, _ = super().__getitem__(idx)
    return inputs

VAE.generators.example_FitGenerator

example_FitGenerator()

Example of FitGenerator.

This example shows how to use the FitGenerator class.

Source code in VAE/generators.py
def example_FitGenerator():
    """Example of :class:`FitGenerator`.

    This example shows how to use the :class:`FitGenerator` class.

    """

    # first we create some dummy data
    shape = (1, 32, 3)  # (set_size, data_length, channels)
    dataset = np.reshape(np.arange(np.prod(shape)), shape)
    datasets = [dataset] * 3

    # the corresponding time values
    time = range(shape[1])

    # and the corresponding conditions
    # the encoder and decoder conditions are different
    encoder_cond = np.linspace(-1, 1, 32)
    decoder_cond = np.linspace(1, -1, 32)

    # then we create the generator
    fit_gen = FitGenerator(datasets,
                           condition={
                               'encoder': encoder_cond,
                               'decoder': decoder_cond
                           },
                           input_length=1,
                           prediction_length=4,
                           batch_size=128,
                           ensemble_size=len(datasets),
                           ensemble_type='index',
                           tp_period=12,
                           time=time,
                           shuffle=False)

    # we can see the summary of the generator
    fit_gen.summary()

    # we can now use the generator to get the inputs for the model
    inputs, *_ = fit_gen[0]

    # we can plot the inputs, to see what the model will get
    # we show the encoder and decoder conditions
    fig, (lax, rax) = plt.subplots(1, 2, sharex=True, sharey=True, figsize=(16, 5))
    lax.pcolormesh(inputs['encoder_cond'][:, 0, :])
    lax.set_title("inputs['encoder_cond']")
    mp = rax.pcolormesh(inputs['decoder_cond'][:, 0, :])
    rax.set_title("inputs['decoder_cond']")

    fig.colorbar(mp, ax=(lax, rax))