Source code for secml.ml.model_zoo.load_model

"""
.. module:: LoadModel
   :synopsis: Functions to load pre-trained SecML models

.. moduleauthor:: Marco Melis <marco.melis@unica.it>

"""
import json

from secml.utils import fm
from secml.utils.download_utils import dl_file

from secml.settings import SECML_MODELS_DIR

MODEL_ZOO_URL = 'https://gitlab.com/secml/secml-zoo/raw/v0.11/models/'
MODELS_DICT_PATH = fm.join(fm.abspath(__file__), 'models_dict.json')
with open(MODELS_DICT_PATH) as fp:
    MODELS_DICT = json.loads(fp.read())


[docs]def load_model(model_id): """Load a pre-trained classifier. Returns a pre-trained SecML classifier given the id of the model. The following models are available: - `mnist-svm`, multiclass `CClassifierSVM` trained on MNIST - `mnist59-svm`, multiclass `CClassifierSVM` with RBF Kernel trained on MNIST59 - `mnist59-svm-rbf`, multiclass `CClassifierSVM` with RBF Kernel trained on MNIST59 - `mnist159-cnn`, `CClassifierPyTorch` CNN trained on MNIST159 Parameters ---------- model_id : str Identifier of the pre-trained model to load. Returns ------- CClassifier Desired pre-trained model. """ model_info = MODELS_DICT[model_id] data_path = fm.join(SECML_MODELS_DIR, model_id, model_id + '.gz') # Download (if needed) data and extract it if not fm.file_exist(data_path): model_url = MODEL_ZOO_URL + model_info['url'] + '.gz' dl_file(model_url, fm.join(SECML_MODELS_DIR, model_id), md5_digest=model_info['md5']) # Check if file has been correctly downloaded if not fm.file_exist(data_path): raise RuntimeError('Something wrong happened while ' 'downloading the model. Please try again') def import_module(full_name, path): """Import a python module from a path.""" from importlib import util spec = util.spec_from_file_location(full_name, path) mod = util.module_from_spec(spec) spec.loader.exec_module(mod) return mod # Name of the function returning the model model_name = model_info["model"].split('/')[-1] # Import the python module containing the function returning the model model_path = fm.join( fm.abspath(__file__), 'models', model_info["model"] + '.py') model_module = import_module(model_name, model_path) # Run the function returning the model model = getattr(model_module, model_name)() # Restore the state of the model from file model.load_state(data_path) return model