# Source code for secml.optim.function.c_function_quadratic

"""
.. module:: CFunctionQuadratic
   :synopsis: Quadratic function

.. moduleauthor:: Battista Biggio <battista.biggio@unica.it>

"""
from secml.optim.function import CFunction
from secml.array import CArray


class CFunctionQuadratic(CFunction):
    """Implements quadratic functions of the form:
        f(x) = x' A x + b' x + c

    Parameters
    ----------
    A : CArray
        Square 2D matrix; its size defines the number of dimensions of the
        function. It does not need to be symmetric: the gradient is computed
        as (A + A') x + b, which is correct for any square A.
    b : CArray
        Column vector whose number of rows equals the size of A.
    c : scalar
        Constant term added to the function value (presumably a scalar;
        it is not validated).

    Attributes
    ----------
    class_type : 'quadratic'

    Raises
    ------
    ValueError
        If A is not a 2D square matrix, if b is not a column vector,
        or if the number of rows of b differs from the size of A.

    """
    __class_type = 'quadratic'

    def __init__(self, A, b, c):
        if len(A.shape) != 2:
            raise ValueError('A is not a 2D matrix!')
        elif A.shape[0] != A.shape[1]:
            raise ValueError('A is not a squared matrix!')
        # A is intentionally not required to be symmetric: the gradient
        # below uses (A + A') x + b, which holds for any square A.

        if len(b.shape) != 2 or b.shape[1] != 1:
            raise ValueError('b is not a column vector!')

        if b.shape[0] != A.shape[0]:
            raise ValueError(
                'A and b have inconsistent dimensions!')

        self._A = A
        self._b = b
        self._c = c

        # Passing data to CFunction
        super(CFunctionQuadratic, self).__init__(fun=self._quadratic_fun,
                                                 n_dim=A.shape[0],
                                                 gradient=self._quadratic_grad)

    def _quadratic_fun(self, x):
        """Apply quadratic function to point x.

        Parameters
        ----------
        x : CArray
            Data point.

        Returns
        -------
        scalar
            Result of the function applied to input point.

        """
        return x.dot(self._A).dot(x.T) + x.dot(self._b) + self._c

    def _quadratic_grad(self, x):
        """Implements gradient of quadratic function wrt point x.

        The gradient of x' A x + b' x + c is (A + A') x + b; it reduces to
        2 A x + b only when A is symmetric.

        Parameters
        ----------
        x : CArray
            Data point.

        Returns
        -------
        CArray
            Gradient of the function at x, flattened with `ravel`.

        """
        # x A + x A' is the row-vector form of (A + A') x. For symmetric A
        # this equals the former 2 * x A exactly; for non-symmetric A it is
        # the correct gradient, whereas 2 * x A is not.
        grad = x.dot(self._A) + x.dot(self._A.T) + self._b.T
        return CArray(grad).ravel()