# Source code for secml.data.loader.c_dataloader_mnist

"""
.. module:: DataLoaderMNIST
   :synopsis: Loader of the MNIST Handwritten Digits dataset

.. moduleauthor:: Marco Melis <marco.melis@unica.it>

"""
import gzip
import shutil
import struct
from array import array
from multiprocessing import Lock

import numpy as np

from secml.data.loader import CDataLoader
from secml.data import CDataset, CDatasetHeader
from secml.array import CArray
from secml.utils import fm
from secml.utils.download_utils import dl_file_gitlab, md5
from secml.settings import SECML_DS_DIR


# Repository hosting the dataset files and the folder, inside that
# repository, where the MNIST archives are stored.
MODEL_ZOO_REPO_URL = 'https://gitlab.com/secml/secml-zoo'
MNIST_REPO_PATH = 'datasets/MNIST/'

# Names of the four gzipped IDX archives to download, each paired with the
# expected MD5 checksum of the *unpacked* file (checked after extraction).
TRAIN_DATA_FILE = 'train-images-idx3-ubyte.gz'
TRAIN_DATA_MD5 = '6bbc9ace898e44ae57da46a324031adb'
TRAIN_LABELS_FILE = 'train-labels-idx1-ubyte.gz'
TRAIN_LABELS_MD5 = 'a25bea736e30d166cdddb491f175f624'
TEST_DATA_FILE = 't10k-images-idx3-ubyte.gz'
TEST_DATA_MD5 = '2646ac647ad5339dbf082846283269ea'
TEST_LABELS_FILE = 't10k-labels-idx1-ubyte.gz'
TEST_LABELS_MD5 = '27ae3e4e09519cfbb04c329615203637'

# Local folder where the dataset files are downloaded and extracted.
MNIST_PATH = fm.join(SECML_DS_DIR, 'mnist')


class CDataLoaderMNIST(CDataLoader):
    """Loads the MNIST Handwritten Digits dataset.

    This dataset has a training set of 60,000 examples,
    and a test set of 10,000 examples. All images are
    28 x 28 black and white 8bit (0 - 255).

    Available at: http://yann.lecun.com/exdb/mnist/

    Attributes
    ----------
    class_type : 'mnist'

    """
    __class_type = 'mnist'

    # Lock to prevent multiple parallel download/extraction
    __lock = Lock()

    def __init__(self):
        # Full local paths of the four extracted MNIST files
        self.train_data_path = fm.join(MNIST_PATH, 'train-images-idx3-ubyte')
        self.train_labels_path = fm.join(MNIST_PATH, 'train-labels-idx1-ubyte')
        self.test_data_path = fm.join(MNIST_PATH, 't10k-images-idx3-ubyte')
        self.test_labels_path = fm.join(MNIST_PATH, 't10k-labels-idx1-ubyte')

        # (remote archive name, local extracted path, expected md5) triplets,
        # in the same order the files were originally checked
        required_files = (
            (TRAIN_DATA_FILE, self.train_data_path, TRAIN_DATA_MD5),
            (TRAIN_LABELS_FILE, self.train_labels_path, TRAIN_LABELS_MD5),
            (TEST_DATA_FILE, self.test_data_path, TEST_DATA_MD5),
            (TEST_LABELS_FILE, self.test_labels_path, TEST_LABELS_MD5),
        )

        with CDataLoaderMNIST.__lock:
            # Download/extract every file that is missing locally or whose
            # checksum does not match the expected one
            for dl_name, local_path, md5sum in required_files:
                if not fm.file_exist(local_path) or \
                        md5(local_path) != md5sum:
                    self._get_data(dl_name, MNIST_PATH, local_path, md5sum)
[docs] def load(self, ds, digits=tuple(range(0, 10)), num_samples=None): """Load all images of specified format inside given path. Adapted from: http://cvxopt.org/_downloads/mnist.py Extra dataset attributes: - 'img_w', 'img_h': size of the images in pixels. - 'y_original': array with the original labels (before renumbering) Parameters ---------- ds : str Identifier of the dataset to download, either 'training' or 'testing'. digits : tuple Tuple with the digits to load. By default all digits are loaded. num_samples : int or None, optional Number of expected samples in resulting ds. If int, an equal number of samples will be taken from each class until `num_samples` have been loaded. If None, all samples will be loaded. """ if ds == "training": data_path = self.train_data_path lbl_path = self.train_labels_path elif ds == "testing": data_path = self.test_data_path lbl_path = self.test_labels_path else: raise ValueError("ds must be 'training' or 'testing'") self.logger.info( "Loading MNIST {:} dataset from {:}...".format(ds, MNIST_PATH)) # Opening the labels data flbl = open(lbl_path, 'rb') magic_nr, size = struct.unpack(">II", flbl.read(8)) if magic_nr != 2049: raise ValueError('Magic number mismatch, expected 2049,' 'got {}'.format(magic_nr)) lbl = array("b", flbl.read()) flbl.close() # Opening the images data fimg = open(data_path, 'rb') magic_nr, size, rows, cols = struct.unpack(">IIII", fimg.read(16)) if magic_nr != 2051: raise ValueError('Magic number mismatch, expected 2051,' 'got {}'.format(magic_nr)) img = array("B", fimg.read()) fimg.close() # Convert digits to tuple in case was passed as array/list digits = tuple(digits) # Number of samples per class if num_samples is not None: div = len(digits) n_samples_class = [ int(num_samples / div) + (1 if x < num_samples % div else 0) for x in range(div)] n_samples_class = { e: n_samples_class[e_i] for e_i, e in enumerate(digits)} else: # No constraint on the number of samples n_samples_class = {e: size for e in 
digits} # Counter of already taken sample for a class count_samples_class = {e: 0 for e in digits} # Extract the indices of samples to load ind = [] for k in range(size): if lbl[k] in digits: # Check the maximum number of samples for current digits if count_samples_class[lbl[k]] < n_samples_class[lbl[k]]: ind += [k] count_samples_class[lbl[k]] += 1 # Number of loaded samples num_loaded = sum(count_samples_class.values()) # Check if dataset has enough samples if num_samples is not None and num_loaded < num_samples: min_val = min(count_samples_class.values()) raise ValueError( "not enough samples in dataset for one ore more of the " "desired classes ({:} available)".format(min_val)) images = CArray.zeros((len(ind), rows * cols), dtype=np.uint8) labels = CArray.zeros(len(ind), dtype=int) digs_array = CArray(digits) # To use find method for i in range(len(ind)): images[i, :] = CArray(img[ ind[i] * rows * cols: (ind[i] + 1) * rows * cols]) labels[i] = CArray(digs_array.find(digs_array == lbl[ind[i]])) header = CDatasetHeader(img_w=28, img_h=28, y_original=digits) return CDataset(images, labels, header=header)
def _get_data(self, file_name, dl_folder, output_path, md5sum): """Download input datafile, unzip and store in output_path. Parameters ---------- file_name : str Name of the file to download. dl_folder : str Path to the folder where to store the downloaded file. output_path : str Full path of output file. md5sum : str Expected MD5 of the downloaded file (after unpacking). """ # Download file and unpack fh = dl_file_gitlab( MODEL_ZOO_REPO_URL, MNIST_REPO_PATH + file_name, dl_folder) with gzip.open(fh, 'rb') as infile: with open(output_path, 'wb') as outfile: for line in infile: outfile.write(line) # Remove download zipped file fm.remove_file(fh) # Check the hash of the downloaded file (unpacked) if md5(output_path) != md5sum: raise RuntimeError('Something wrong happened while ' 'downloading the dataset. Please try again.')