Source code for secml.optim.function.c_function_linear

"""
.. module:: CFunctionLinear
   :synopsis: Linear function

.. moduleauthor:: Marco Melis <marco.melis@unica.it>

"""
from secml.optim.function import CFunction
from secml.array import CArray


class CFunctionLinear(CFunction):
    """Implements linear functions of the form:

        b' x + c = 0

    Parameters
    ----------
    b : CArray
        Coefficients of the linear function. Must be a column vector,
        i.e. a 2-dimensional array of shape (n_dim, 1).
    c : scalar
        Offset added to the product between the input point and `b`.

    Attributes
    ----------
    class_type : 'linear'

    Raises
    ------
    ValueError
        If `b` is not a 2-dimensional array with exactly one column.

    """
    __class_type = 'linear'

    def __init__(self, b, c):

        # Only column vectors are accepted: 2 dimensions, 1 column
        if b.ndim != 2 or b.shape[1] != 1:
            raise ValueError('b is not a column vector!')

        self._b = b
        self._c = c

        # Passing data to CFunction. The number of rows of the
        # column vector `b` is used as dimensionality of the input
        super(CFunctionLinear, self).__init__(fun=self._linear_fun,
                                              n_dim=b.shape[0],
                                              gradient=self._linear_grad)

    def _linear_fun(self, x):
        """Apply linear function to point x.

        Parameters
        ----------
        x : CArray
            Data point.

        Returns
        -------
        scalar
            Result of the function applied to input point,
            computed as ``x.dot(b) + c``.

        """
        return x.dot(self._b) + self._c

    def _linear_grad(self, x=None):
        """Implements gradient of linear function wrt point x.

        The gradient of a linear function is constant (equal to `b`),
        so the evaluation point is accepted but never used.

        Parameters
        ----------
        x : CArray or None, optional
            Point where the gradient is requested. Ignored.
            NOTE(review): presumably CFunction invokes the gradient
            callable passing the evaluation point, like it does for
            `fun` - confirm against CFunction. The default keeps
            calls without arguments working.

        Returns
        -------
        CArray
            New array built from the flattened coefficients `b`.

        """
        return CArray(self._b.ravel())