# Source code for secml.data.c_dataset_pytorch

"""
.. module:: CDatasetPyTorch
   :synopsis: An interface for using the CDataset in a PyTorch model

.. moduleauthor:: Maura Pintor <maura.pintor@unica.it>

"""
import torch
from torch.utils.data import Dataset

from secml.array import CArray
from secml.core.type_utils import is_int
from secml.data import CDataset


class CDatasetPyTorch(Dataset):
    """CDataset to PyTorch Dataset wrapper.

    Parameters
    ----------
    data : CDataset or CArray
        Dataset to be wrapped. Can also be a CArray with the samples and in
        this case the labels can be passed using the `labels` parameter.
    labels : None or CArray
        Labels of the dataset. Can be defined if the samples have been
        passed to the `data` parameter. Input must be a flat array of shape
        (num_samples, ) or a 2-D array with shape (num_samples, num_classes).
    transform : torchvision.transforms or None, optional
        Transformation(s) to be applied to each ds sample.

    """

    def __init__(self, data, labels=None, transform=None):
        """Class constructor.

        Raises
        ------
        TypeError
            If `data` is a CDataset and `labels` is passed as well.

        """
        if isinstance(data, CDataset):
            if labels is not None:
                raise TypeError("labels must be defined inside the dataset")
            self._samples = data.X.atleast_2d()
            # Labels inside a CDataset are always stored as flat arrays
            self._labels = data.Y
        else:
            self._samples = data.atleast_2d()
            self._labels = labels  # 1-D, 2-D or None
        self.transform = transform

    @property
    def X(self):
        """Samples of the wrapped dataset, as stored by the constructor."""
        return self._samples

    @property
    def Y(self):
        """Labels of the wrapped dataset, or None if no labels are set."""
        return self._labels

    def __len__(self):
        """Returns dataset size."""
        return self._samples.shape[0]

    def __getitem__(self, i):
        """Return desired pair (sample, label) from the dataset.

        Parameters
        ----------
        i : int
            Index of the desired sample.

        Returns
        -------
        sample : torch.Tensor
            The i-th sample, after applying `transform` (if any),
            casted to float.
        label : torch.Tensor
            The label of the i-th sample. A scalar tensor if labels are
            stored as a flat array, the tensor built from the i-th row if
            labels are stored as a 2-D array, or a tensor holding -1 if
            the dataset has no labels.

        Raises
        ------
        ValueError
            If `i` is not an integer.

        """
        if not is_int(i):
            raise ValueError("only integer indexing is supported")

        sample = CArray(self._samples[i, :]).tondarray()

        if self.transform is not None:
            sample = self.transform(sample)

        # Ensure we return tensors (the transform may already provide one)
        if not isinstance(sample, torch.Tensor):
            sample = torch.from_numpy(sample)

        if self._labels is None:
            label = torch.tensor(-1)  # Tensor with null label
        elif self._labels.ndim == 1:  # (num_samples, )
            label = torch.tensor(self._labels[i].item())
        else:  # (num_samples, num_classes)
            # The array was just produced by `tondarray`,
            # so it can be converted to a tensor directly
            label = torch.from_numpy(
                CArray(self._labels[i, :]).tondarray())

        return sample.float(), label