# Source code for secml.model_zoo.load_model

"""
.. module:: LoadModel
   :synopsis: Functions to load pre-trained SecML models

.. moduleauthor:: Marco Melis <marco.melis@unica.it>

"""
import json
import re
from datetime import datetime, timedelta

import secml
from secml.settings import SECML_LOGS_PATH, SECML_STORE_LOGS
from secml.utils import fm, CLog
from secml.utils.download_utils import dl_file_gitlab, md5

from secml.settings import SECML_MODELS_DIR

# URL of the repository the model zoo resources are downloaded from
MODEL_ZOO_REPO_URL = 'https://gitlab.com/secml/secml-zoo'
# Name of the JSON file with the models definitions; it is also the path
# of that file relative to the root of the repository
MODELS_DICT_FILE = 'models_dict.json'
# Local path the models definitions file is read from after download
MODELS_DICT_PATH = fm.join(SECML_MODELS_DIR, MODELS_DICT_FILE)

# Module-level logger. A file handler pointing to SECML_LOGS_PATH is passed
# only if SECML_STORE_LOGS is exactly True, otherwise None is passed
_logger = CLog(
    logger_id=__name__,
    file_handler=SECML_LOGS_PATH if SECML_STORE_LOGS is True else None)


def _dl_data_versioned(file_path, output_dir, md5_digest=None):
    """Download a model zoo file from a branch depending on version.

    This function tries to download a model zoo resource from:
     1. the branch corresponding to current version,
        e.g. branch `v0.12` for `0.12.*` version
     2. the `master` branch, if the first attempt fails for any reason
        (the error of the first attempt is only logged at debug level)

    Parameters
    ----------
    file_path : str
        Path to the file to download, relative to the repository.
    output_dir : str
        Path to the directory where the file should be stored.
        If folder does not exists, will be created.
    md5_digest : str or None, optional
        Expected MD5 digest of the downloaded file.
        If a different digest is computed, the downloaded file will be
        removed and ValueError is raised.

    Raises
    ------
    Exception
        Any error raised while downloading from the `master` branch
        is propagated to the caller.

    """
    try:
        # Try downloading from the branch corresponding to current version.
        # The dot is escaped so that only a literal `major.minor` prefix is
        # accepted (an unescaped `.` would match any character, e.g. `123`).
        # If the version string has no such prefix, `re.search` returns
        # None and the resulting AttributeError triggers the fallback
        # below, exactly as a failed download does.
        min_version = re.search(r'^\d+\.\d+', secml.__version__).group(0)
        dl_file_gitlab(MODEL_ZOO_REPO_URL, file_path, output_dir,
                       branch='v' + min_version, md5_digest=md5_digest)

    except Exception as e:  # Try looking into 'master' branch...
        # Deliberately broad: whatever went wrong with the versioned
        # branch, the `master` branch is the last resort
        _logger.debug(e)
        _logger.debug("Looking in the `master` branch...")
        dl_file_gitlab(MODEL_ZOO_REPO_URL, file_path, output_dir,
                       branch='master', md5_digest=md5_digest)


def _get_models_dict():
    """Download (if needed) and return the dictionary of models definitions.

    File will be re-downloaded every 30 minutes (upon request) to update
    the models definitions from repository. If the update fails but a
    previously downloaded copy is available, that copy is used.

    Returns
    -------
    models_dict : dict
        Dictionary with models definitions. Each key is an available model.
        Each model entry is defined by:
         - "model", path to the script with model definition
         - "state", path to the archive containing the pre-saved model state
         - "model_md5", md5 checksum of model definition
         - "state_md5", md5 checksum of pre-saved model state

    Raises
    ------
    Exception
        Any error raised by the download, if no local copy of the
        models definitions is available to fall back to.
    RuntimeError
        If the download completes without errors but the models
        definitions file is still not available locally.

    """
    # The `.last_update` contains the last time MODELS_DICT_FILE
    # has been downloaded. Read the last update time if this file is
    # available. Otherwise the file will be created later
    last_update_path = fm.join(SECML_MODELS_DIR, '.last_update')
    last_update_format = "%d %m %Y %H:%M"  # Specific format to avoid locale
    current_datetime = datetime.utcnow()  # UTC datetime to avoid locale

    # Trigger flag for model definitions update. Three states:
    #  - None: no local copy available, download is mandatory
    #  - True: local copy available, but an update should be attempted
    #  - False: local copy available and recent enough, skip the download
    update_models_dict = None
    if fm.file_exist(MODELS_DICT_PATH):
        update_models_dict = True  # By default, trigger update
        if fm.file_exist(last_update_path):
            try:
                with open(last_update_path) as fp:
                    last_update = \
                        datetime.strptime(fp.read(), last_update_format)
                    # Compute the threshold for triggering an update
                    last_update_th = last_update + timedelta(minutes=30)
            except ValueError as e:
                # Error occurred while parsing the last update date from file
                # Clean it and re-create later. Definitions update stays True
                _logger.debug(e)  # Log the error for debug purposes
                _logger.debug("Removing `{:}`".format(last_update_path))
                fm.remove_file(last_update_path)
            else:
                # Do not trigger update if last update threshold is not passed
                if current_datetime < last_update_th:
                    update_models_dict = False

    if update_models_dict is not False:
        # if update_models_dict is None means that models dict is not available
        # if it is True means that an update has been triggered
        # Either cases, we need to download the data and extract it

        try:  # Catch download errors

            # Download definitions from current version's branch first,
            # then from master branch
            _dl_data_versioned(MODELS_DICT_FILE, SECML_MODELS_DIR)

        except Exception as e:
            if update_models_dict is None:
                # If update_models_dict is still None, means that models dict
                # is not available, so we propagate the error (bare `raise`
                # keeps the original traceback untouched). Otherwise pass
                raise
            _logger.debug(e)  # Log the error for debug purposes
            _logger.debug("Error when updating the models definitions. "
                          "Using the last available ones...")

        else:  # No error raised during download process

            # Check if file has been correctly downloaded.
            # Raised from the `else` clause, so this error is not
            # swallowed by the `except` above
            if not fm.file_exist(MODELS_DICT_PATH):
                raise RuntimeError(
                    'Something wrong happened while downloading the '
                    'models definitions. Please try again.')

            # Update or create the "last update" file
            with open(last_update_path, "w") as fp:
                fp.write(current_datetime.strftime(last_update_format))

    # JSON text is UTF-8 encoded: do not depend on the platform default
    with open(MODELS_DICT_PATH, encoding='utf-8') as fp:
        return json.load(fp)


def load_model(model_id):
    """Load a pre-trained classifier.

    Returns a pre-trained SecML classifier given the id of the model.

    Check https://gitlab.com/secml/secml-zoo for the list of available models.

    Parameters
    ----------
    model_id : str
        Identifier of the pre-trained model to load.

    Returns
    -------
    CClassifier
        Desired pre-trained model.

    Raises
    ------
    KeyError
        If `model_id` is not a key of the models definitions.
    RuntimeError
        If the model's script or state is not available locally
        after the download.

    """
    models_dict = _get_models_dict()
    try:
        model_info = models_dict[model_id]
    except KeyError:
        # Same exception type as a plain lookup, but with a useful message
        raise KeyError(
            "model `{:}` not found. Available models: {:}".format(
                model_id, sorted(models_dict))) from None

    # Download (if needed) model's script and check md5
    model_path = _dl_model_file(
        model_info['model'] + '.py', model_info['model_md5'])

    # Download (if needed) model's state and check md5
    state_path = _dl_model_file(
        model_info['state'] + '.gz', model_info['state_md5'])

    # Name of the function returning the model
    # (last component of the "model" entry, same as the script name)
    model_name = model_info["model"].split('/')[-1]

    # Import the python module containing the function returning the model
    model_module = _import_module_from_path(model_name, model_path)

    # Run the function returning the model
    model = getattr(model_module, model_name)()

    # Restore the state of the model from file
    model.load_state(state_path)

    return model


def _dl_model_file(rel_path, md5_digest):
    """Download a model zoo file if missing or outdated, return local path.

    Parameters
    ----------
    rel_path : str
        Path of the file relative to both the local models directory
        and the `models` folder of the repository.
    md5_digest : str
        Expected MD5 digest of the file. A local file with a different
        digest is downloaded again.

    Returns
    -------
    str
        Local path of the file.

    Raises
    ------
    RuntimeError
        If the file is not available locally after the download.

    """
    local_path = fm.join(SECML_MODELS_DIR, rel_path)

    if not fm.file_exist(local_path) or md5_digest != md5(local_path):
        # Path of the file inside the repository, with no repeated slashes
        file_url = '/'.join(s.strip('/') for s in ('models', rel_path))
        # NOTE(review): `fm.abspath` is expected to return the directory
        # containing `local_path` - confirm against secml.utils.fm
        out_dir = fm.abspath(local_path)

        # Download requested file from current version's branch first,
        # then from master branch
        _dl_data_versioned(file_url, out_dir, md5_digest)

        # Check if file has been correctly downloaded
        if not fm.file_exist(local_path):
            raise RuntimeError('Something wrong happened while '
                               'downloading the model. Please try again.')

    return local_path


def _import_module_from_path(full_name, path):
    """Import a python module from a path."""
    from importlib import util
    spec = util.spec_from_file_location(full_name, path)
    mod = util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    return mod