base_loader
Base Dataset class to work with fwi-forcings data.
Module Contents
class base_loader.ModelDataset(out_var=None, out_mean=None, forecast_dir=None, forcings_dir=None, reanalysis_dir=None, transform=None, hparams=None, **kwargs)
Bases: torch.utils.data.Dataset
The dataset class responsible for loading the data and providing the samples for training.
- Base class
torch.utils.data.Dataset – the base Dataset class used with PyTorch models
__len__(self)
Internal method that returns the number of samples available for iteration.
- Returns
The maximum number of iterations possible with the provided data.
- Return type
int
__getitem__(self, idx)
Internal method used by PyTorch to fetch an input tensor and its corresponding output tensor.
- Parameters
idx (int) – The index of the data sample.
- Returns
Batch of data containing input and output tensors
- Return type
tuple
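`__len__` and `__getitem__` together form the standard `torch.utils.data.Dataset` protocol that `DataLoader` relies on. A minimal sketch, assuming the inputs and targets already sit in in-memory tensors; the class name `ToyFwiDataset` and its arguments are hypothetical and not part of `base_loader`:

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyFwiDataset(Dataset):
    """Hypothetical stand-in for ModelDataset that serves in-memory tensors."""

    def __init__(self, inputs: torch.Tensor, targets: torch.Tensor):
        if len(inputs) != len(targets):
            raise ValueError("inputs and targets must have the same length")
        self.inputs = inputs
        self.targets = targets

    def __len__(self) -> int:
        # Number of samples the DataLoader can draw from.
        return len(self.inputs)

    def __getitem__(self, idx: int):
        # One (input, output) pair; the DataLoader stacks these into batches.
        return self.inputs[idx], self.targets[idx]


# 10 samples: 5 input channels and 1 target channel on a 4 x 4 grid.
dataset = ToyFwiDataset(torch.randn(10, 5, 4, 4), torch.randn(10, 1, 4, 4))
loader = DataLoader(dataset, batch_size=4)
x, y = next(iter(loader))
print(len(dataset), tuple(x.shape), tuple(y.shape))
```

With `batch_size=4`, the loader yields three batches from the ten samples, and the final batch holds the remaining two.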
get_cb_loss_factor(self, y)
Compute the Class-Balanced loss factor mask from the frequency distribution of the output values and the supplied beta factor.
- Parameters
y (torch.Tensor) – The 1D ground-truth value tensor
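The method name points to the Class-Balanced loss. In that loss, a value that occurs n times usually receives the weight (1 - beta) / (1 - beta**n), which is the inverse of its "effective number" of samples. The sketch below is a guess at how such a factor could be derived from `y`. It assumes that `y` holds discrete (for example pre-binned) values and that the weights are normalised to sum to the number of distinct values. Neither assumption is confirmed by this page, and the function name and `beta` default are hypothetical.

```python
import torch


def cb_loss_factor(y: torch.Tensor, beta: float = 0.99) -> torch.Tensor:
    """Hypothetical sketch: per-element Class-Balanced weights for tensor y.

    Assumes y holds discrete (e.g. pre-binned) values; each distinct value is
    treated as one class with n occurrences.
    """
    # Frequency distribution: distinct values, element-to-value map, counts.
    _, inverse, counts = torch.unique(y, return_inverse=True, return_counts=True)
    # Effective number of samples per value: (1 - beta**n) / (1 - beta).
    effective_num = (1.0 - beta ** counts.double()) / (1.0 - beta)
    # Weight is the inverse effective number, normalised so the per-value
    # weights sum to the number of distinct values.
    weights = 1.0 / effective_num
    weights = weights / weights.sum() * len(counts)
    # Give every element the weight of its value.
    return weights[inverse].float()


w = cb_loss_factor(torch.tensor([0.0, 0.0, 0.0, 1.0]), beta=0.5)
print(w)  # the rarer value 1.0 receives the larger weight
```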
apply_mask(self, *y_list)
Apply the mask to each tensor and return matrices of size batch_size x channels x N.
- Parameters
*y_list – The iterable of tensors to be masked
- Returns
The list of masked tensors
- Return type
list(torch.Tensor)
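The following is one way to produce the batch_size x channels x N shape. It is a sketch that assumes each tensor is laid out as (batch_size, channels, H, W) and that the mask is a boolean (H, W) grid, so N is the number of True cells. In the real method the mask presumably lives on `self`. Here it is passed explicitly, and the free-function signature is hypothetical.

```python
import torch


def apply_mask(mask: torch.Tensor, *y_list: torch.Tensor) -> list:
    """Hypothetical sketch: keep only the grid cells where `mask` is True.

    Each tensor is assumed to be (batch_size, channels, H, W) and `mask` a
    boolean (H, W) grid; each result is (batch_size, channels, N), N = mask.sum().
    """
    flat_mask = mask.flatten()
    # Flatten H and W into one axis, then select the masked cells along it.
    return [y.flatten(start_dim=2)[:, :, flat_mask] for y in y_list]


mask = torch.tensor([[True, False], [False, True]])  # keep 2 of 4 cells
y = torch.arange(16.0).reshape(2, 2, 2, 2)
y_hat = y * 2
y_masked, y_hat_masked = apply_mask(mask, y, y_hat)
print(y_masked.shape)
```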
get_loss(self, y, y_hat)
Apply any required processing and return the loss for the supplied prediction and label tensors.
- Parameters
y (torch.Tensor) – Label tensor
y_hat (torch.Tensor) – Predicted tensor
- Returns
Prediction loss
- Return type
torch.Tensor
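The page does not specify what "applicable processing" means. A minimal sketch of one plausible version, assuming missing labels are stored as NaN and mean squared error is the loss, follows. Both choices are illustrative assumptions and not the module's documented behaviour.

```python
import torch
import torch.nn.functional as F


def get_loss(y: torch.Tensor, y_hat: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: MSE computed only where a label exists.

    Assumes missing labels are stored as NaN; the real preprocessing is not
    documented here.
    """
    valid = ~torch.isnan(y)
    # Drop cells without a label before computing the loss.
    return F.mse_loss(y_hat[valid], y[valid])


y = torch.tensor([1.0, float("nan"), 3.0])
y_hat = torch.tensor([2.0, 100.0, 3.0])
print(get_loss(y, y_hat))
```

The NaN cell is excluded entirely, so the large error at that position does not affect the loss.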
training_step(self, model, batch)
Called inside the training loop with the data from the training dataloader passed in as batch.
- Parameters
model (Model) – The chosen model
batch – Batch of input and ground-truth variables
- Returns
Loss and logs
- Return type
dict
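The following is a sketch of a training step that returns the loss together with a log dictionary. It assumes `batch` unpacks into an (inputs, targets) pair and uses MSE as the loss. The free-function form and the "loss"/"log" key names are illustrative assumptions.

```python
import torch
import torch.nn.functional as F
from torch import nn


def training_step(model: nn.Module, batch) -> dict:
    """Hypothetical sketch: forward pass and loss for one training batch.

    Assumes `batch` is an (inputs, targets) pair and uses MSE as the loss;
    the "loss"/"log" key names are illustrative.
    """
    x, y = batch
    y_hat = model(x)
    loss = F.mse_loss(y_hat, y)
    # Keep the graph on "loss" for backprop; log a detached copy.
    return {"loss": loss, "log": {"train_loss": loss.detach()}}


model = nn.Linear(3, 1)
out = training_step(model, (torch.randn(8, 3), torch.randn(8, 1)))
out["loss"].backward()  # gradients now populate model.parameters()
```

Detaching the logged value keeps the log dictionary from holding a reference to the autograd graph.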
validation_step(self, model, batch)
Called inside the validation loop with the data from the validation dataloader passed in as batch.
- Parameters
model (Model) – The chosen model
batch – Batch of input and ground-truth variables
- Returns
Loss and logs
- Return type
dict
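Validation mirrors the training step but does not need gradients. The sketch below makes the same assumptions as a training step would: an (inputs, targets) batch, MSE as the loss, and hypothetical key names. It uses `torch.no_grad()` so that no autograd graph is built.

```python
import torch
import torch.nn.functional as F
from torch import nn


@torch.no_grad()
def validation_step(model: nn.Module, batch) -> dict:
    """Hypothetical sketch: evaluate one validation batch without autograd.

    Assumes `batch` is an (inputs, targets) pair and MSE as the loss. Callers
    usually put the model in eval() mode before the validation loop.
    """
    x, y = batch
    loss = F.mse_loss(model(x), y)
    return {"val_loss": loss, "log": {"val_loss": loss}}


model = nn.Linear(3, 1)
nn.init.zeros_(model.weight)  # zero weights and bias make every prediction 0
nn.init.zeros_(model.bias)
out = validation_step(model, (torch.randn(4, 3), torch.ones(4, 1)))
print(out["val_loss"])
```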
inference_step(self, y_pre, y_hat_pre)
Run inference on the target and predicted values and return the loss and the metric values as logs.
- Parameters
y_pre (torch.Tensor) – Label values
y_hat_pre (torch.Tensor) – Predicted values
- Returns
Loss and the log dictionary
- Return type
tuple
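The documented return value is a (loss, logs) tuple. The metric set is not listed here, so the following sketch assumes MSE as the loss and reports MAE and RMSE alongside it. The metric names and choices are hypothetical.

```python
import torch


def inference_step(y_pre: torch.Tensor, y_hat_pre: torch.Tensor):
    """Hypothetical sketch: loss plus metric values as a log dictionary.

    MSE as the loss and MAE/RMSE as metrics are assumptions.
    """
    err = y_hat_pre - y_pre
    loss = (err ** 2).mean()
    logs = {
        "mse": loss.item(),
        "mae": err.abs().mean().item(),
        "rmse": loss.sqrt().item(),
    }
    return loss, logs


loss, logs = inference_step(torch.zeros(4), torch.tensor([1.0, -1.0, 3.0, -3.0]))
print(logs)
```

The loss is returned as a tensor, while the logged metrics are plain Python floats. This makes the dictionary easy to hand to a logger.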