:mod:`base_loader`
==================

.. py:module:: base_loader

.. autoapi-nested-parse::

   Base Dataset class to work with fwi-forcings data.


Module Contents
---------------

.. py:class:: ModelDataset(out_var=None, out_mean=None, forecast_dir=None, forcings_dir=None, reanalysis_dir=None, transform=None, hparams=None, **kwargs)

   Bases: :class:`torch.utils.data.Dataset`

   .. autoapi-inheritance-diagram:: base_loader.ModelDataset
      :parts: 1

   The dataset class responsible for loading the data and providing the samples for training.

   :param Dataset: Base Dataset class to use with PyTorch models
   :type Dataset: torch.utils.data.Dataset

   .. method:: __len__(self)

      Internal method used to obtain the number of samples available for iteration.

      :return: The maximum number of iterations possible with the provided data
      :rtype: int

   .. method:: __getitem__(self, idx)

      Internal method used by PyTorch to fetch the input and corresponding output tensors.

      :param idx: Index of the data sample
      :type idx: int
      :return: Batch of data containing the input and output tensors
      :rtype: tuple

   .. method:: get_cb_loss_factor(self, y)

      Compute the Class-Balanced loss factor mask from the frequency distribution of the output values and the supplied beta factor.

      :param y: The 1D ground truth value tensor
      :type y: torch.Tensor

   .. method:: apply_mask(self, *y_list)

      Apply the mask to each supplied tensor and return matrices of size ``batch_size x channels x N``.

      :param \*y_list: Iterable of tensors to be masked
      :type \*y_list: torch.Tensor
      :return: List of masked tensors
      :rtype: list(torch.Tensor)

   .. method:: get_loss(self, y, y_hat)

      Apply the required processing and return the loss for the supplied prediction and label tensors.

      :param y: Label tensor
      :type y: torch.Tensor
      :param y_hat: Predicted tensor
      :type y_hat: torch.Tensor
      :return: Prediction loss
      :rtype: torch.Tensor

   .. method:: training_step(self, model, batch)

      Called inside the training loop with the data from the training dataloader passed in as ``batch``.

      :param model: The chosen model
      :type model: Model
      :param batch: Batch of input and ground truth variables
      :type batch: tuple
      :return: Loss and logs
      :rtype: dict

   .. method:: validation_step(self, model, batch)

      Called inside the validation loop with the data from the validation dataloader passed in as ``batch``.

      :param model: The chosen model
      :type model: Model
      :param batch: Batch of input and ground truth variables
      :type batch: tuple
      :return: Loss and logs
      :rtype: dict

   .. method:: inference_step(self, y_pre, y_hat_pre)

      Run inference on the target and predicted values and return the loss together with the metric values as logs.

      :param y_pre: Label values
      :type y_pre: torch.Tensor
      :param y_hat_pre: Predicted values
      :type y_hat_pre: torch.Tensor
      :return: Loss and the log dictionary
      :rtype: tuple

   .. method:: test_step(self, model, batch)

      Called inside the testing loop with the data from the testing dataloader passed in as ``batch``.

      :param model: The chosen model
      :type model: Model
      :param batch: Batch of input and ground truth variables
      :type batch: tuple
      :return: Loss and logs
      :rtype: dict

   .. method:: benchmark_step(self, batch)

      Called inside the testing loop with the data from the testing dataloader passed in as ``batch``.

      :param batch: Batch of input and ground truth variables
      :type batch: tuple
      :return: Loss and logs
      :rtype: dict
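
``ModelDataset`` follows PyTorch's map-style dataset protocol: ``__len__`` reports how many indices are valid and ``__getitem__`` returns the sample for one index, which a ``DataLoader`` then collates into batches. The following is a minimal sketch of that protocol only, assuming the inputs and targets are two aligned sequences. The class ``PairedDataset`` is hypothetical and does not show how ``ModelDataset`` actually reads the forecast, forcings or reanalysis directories; plain Python is used so the sketch runs without PyTorch.

.. code-block:: python

   from typing import List, Sequence, Tuple


   class PairedDataset:
       """Hypothetical map-style dataset pairing inputs with targets.

       A real PyTorch dataset would subclass ``torch.utils.data.Dataset``.
       """

       def __init__(self, inputs: Sequence[List[float]], targets: Sequence[List[float]]):
           self.inputs = inputs
           self.targets = targets

       def __len__(self) -> int:
           # Only indices that have both an input and a target can be served.
           return min(len(self.inputs), len(self.targets))

       def __getitem__(self, idx: int) -> Tuple[List[float], List[float]]:
           if not 0 <= idx < len(self):
               raise IndexError(idx)
           # Return the (input, output) pair for one sample.
           return self.inputs[idx], self.targets[idx]


   ds = PairedDataset([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], [[0.5], [0.7]])
   print(len(ds))  # 2
   print(ds[1])    # ([3.0, 4.0], [0.7])

Taking the shorter of the two sequences in ``__len__`` is one assumed reading of "the maximum number of iterations possible with the provided data"; the real class may compute it differently.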
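
The Class-Balanced loss of Cui et al. (2019) weights each class by the inverse of its effective number of samples, :math:`(1 - \beta^{n}) / (1 - \beta)`, where :math:`n` is the class frequency and :math:`\beta \in [0, 1)`. The description of ``get_cb_loss_factor`` does not say how continuous output values are grouped into classes, so the sketch below assumes equal-width histogram bins and normalizes the weights of occupied bins to sum to the number of such bins. The function ``cb_loss_factor`` and its ``n_bins`` and ``beta`` defaults are hypothetical.

.. code-block:: python

   from typing import List, Sequence


   def cb_loss_factor(y: Sequence[float], beta: float = 0.99, n_bins: int = 10) -> List[float]:
       """Return one class-balanced weight per element of ``y`` (hypothetical helper)."""
       if not 0.0 <= beta < 1.0:
           raise ValueError("beta must lie in [0, 1)")
       lo, hi = min(y), max(y)
       # Equal-width bins act as the "classes"; use width 1 for constant input.
       width = (hi - lo) / n_bins or 1.0

       # Bin index per value; the maximum value is clamped into the last bin.
       bins = [min(int((v - lo) / width), n_bins - 1) for v in y]
       counts = [0] * n_bins
       for b in bins:
           counts[b] += 1

       # Inverse effective number per occupied bin: (1 - beta) / (1 - beta ** n).
       raw = {b: (1.0 - beta) / (1.0 - beta ** n) for b, n in enumerate(counts) if n > 0}

       # Normalize so the weights of occupied bins sum to the number of such bins.
       scale = len(raw) / sum(raw.values())
       return [raw[b] * scale for b in bins]


   mask = cb_loss_factor([0.0, 0.0, 0.0, 1.0], beta=0.9, n_bins=2)
   print([round(w, 3) for w in mask])  # [0.539, 0.539, 0.539, 1.461]

Rare bins receive larger weights, so errors on infrequent values contribute more to the loss. With ``beta=0`` every occupied bin gets weight 1, and as ``beta`` approaches 1 the weights approach inverse bin frequency.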
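
``apply_mask`` returns ``batch_size x channels x N`` matrices, and ``get_loss`` scores predictions against labels. The sketch below assumes each input is a ``batch_size x channels x H x W`` grid and that a boolean ``H x W`` mask selects the ``N`` valid cells, after which mean squared error is computed over those cells only. The grid layout, the mask, the loss choice and the helper names ``apply_grid_mask`` and ``masked_mse`` are all hypothetical; with PyTorch tensors the same selection is usually written as ``y[:, :, mask]``.

.. code-block:: python

   from typing import List

   Grid = List[List[float]]            # H x W values for one channel
   Batch = List[List[Grid]]            # batch_size x channels x H x W
   Masked = List[List[List[float]]]    # batch_size x channels x N


   def apply_grid_mask(mask: List[List[bool]], *y_list: Batch) -> List[Masked]:
       """Keep only the grid cells where ``mask`` is True (hypothetical helper)."""

       def flatten(grid: Grid) -> List[float]:
           # Row-major scan of the H x W grid, dropping masked-out cells.
           return [v for row, keep_row in zip(grid, mask)
                   for v, keep in zip(row, keep_row) if keep]

       return [[[flatten(grid) for grid in sample] for sample in y] for y in y_list]


   def masked_mse(y: Batch, y_hat: Batch, mask: List[List[bool]]) -> float:
       """Mean squared error over the masked cells only (assumed loss)."""
       y_m, y_hat_m = apply_grid_mask(mask, y, y_hat)
       diffs = [(a - b) ** 2
                for s, s_hat in zip(y_m, y_hat_m)
                for c, c_hat in zip(s, s_hat)
                for a, b in zip(c, c_hat)]
       return sum(diffs) / len(diffs)


   mask = [[True, False], [True, True]]   # 3 valid cells out of 4
   y = [[[[1.0, 9.0], [2.0, 3.0]]]]       # batch_size=1, channels=1, 2 x 2 grid
   y_hat = [[[[1.5, 0.0], [2.0, 2.0]]]]
   print(apply_grid_mask(mask, y))        # [[[[1.0, 2.0, 3.0]]]]
   print(masked_mse(y, y_hat, mask))      # 0.4166666666666667

The large error at the masked-out cell (``9.0`` against ``0.0``) does not enter the loss, which is the point of masking before computing it.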
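
``training_step``, ``validation_step`` and ``test_step`` are documented as returning "loss and logs", and ``inference_step`` as returning the loss and a log dictionary of metric values. The sketch below assumes a batch is an ``(x, y)`` pair, uses mean squared error as the loss and mean absolute error as an extra metric, and returns a ``{"loss": ..., "log": ...}`` dictionary. The helper ``run_step``, the metric choice and the dictionary keys are assumptions, not the module's actual behaviour.

.. code-block:: python

   from typing import Callable, Dict, List, Sequence, Tuple

   Model = Callable[[Sequence[float]], List[float]]


   def inference_step(y: Sequence[float], y_hat: Sequence[float]) -> Tuple[float, Dict[str, float]]:
       """Return the loss and a dictionary of metric values (assumed: MSE loss, MAE metric)."""
       n = len(y)
       mse = sum((a - b) ** 2 for a, b in zip(y, y_hat)) / n
       mae = sum(abs(a - b) for a, b in zip(y, y_hat)) / n
       return mse, {"mse": mse, "mae": mae}


   def run_step(model: Model, batch: Tuple[Sequence[float], Sequence[float]], stage: str) -> Dict[str, object]:
       """Shared body for hypothetical training/validation/test steps."""
       x, y = batch                      # assumed (input, ground truth) pair
       y_hat = model(x)
       loss, metrics = inference_step(y, y_hat)
       # Prefix metric names with the stage so train/val/test logs stay separate.
       logs = {f"{stage}_{name}": value for name, value in metrics.items()}
       return {"loss": loss, "log": logs}


   def double(xs: Sequence[float]) -> List[float]:
       """Stand-in "model" that doubles every input value."""
       return [2.0 * v for v in xs]


   out = run_step(double, ([1.0, 2.0], [2.0, 5.0]), stage="val")
   print(out)  # {'loss': 0.5, 'log': {'val_mse': 0.5, 'val_mae': 0.5}}

Routing all three loops through one shared body keeps the loss and metric definitions identical across training, validation and testing, with only the log prefix differing.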