Learning Curve Plus Plus (LCPP)
algo::regression::KernelRidge< KERNEL, T > Class Template Reference

Public Member Functions

template<typename... Ts>
 KernelRidge (const arma::Mat< T > &inputs, const arma::Row< T > &labels, const double &lambda, const Ts &... args)
 
template<typename... Ts>
 KernelRidge (const arma::Mat< T > &inputs, const arma::Row< T > &labels)
 
template<typename... Ts>
 KernelRidge (const Ts &... args)
 
void Train (const arma::Mat< T > &inputs, const arma::Row< T > &labels)
 
void Predict (const arma::Mat< T > &inputs, arma::Row< T > &labels) const
 
T ComputeError (const arma::Mat< T > &points, const arma::Row< T > &responses) const
 
const arma::Row< T > & Parameters () const
 
arma::Row< T > & Parameters ()
 
double Lambda () const
 
double & Lambda ()
 
template<typename Archive >
void serialize (Archive &ar, const unsigned int)
 

Detailed Description

template<class KERNEL, class T = DTYPE>
class algo::regression::KernelRidge< KERNEL, T >

Definition at line 21 of file kernelridge.h.

Constructor & Destructor Documentation

◆ KernelRidge() [1/2]

template<class KERNEL , class T >
template<class... Ts>
algo::regression::KernelRidge< KERNEL, T >::KernelRidge ( const arma::Mat< T > &  inputs,
const arma::Row< T > &  labels,
const double &  lambda,
const Ts &...  args 
)
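Parameters
inputs — training inputs X (passed straight to Train())
labels — training targets y (passed straight to Train())
lambda — regularization hyper-parameter λ
args — arguments forwarded to the KERNEL constructor (cov_(args...))

Constructs the kernel from args, stores lambda, and immediately trains the model on (inputs, labels).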

Definition at line 20 of file kernelridge_impl.h.

23  :
24  cov_(args...), lambda_(lambda)
25 {
26  Train(inputs, labels);
27 }
void Train(const arma::Mat< T > &inputs, const arma::Row< T > &labels)

References algo::regression::KernelRidge< KERNEL, T >::Train().

+ Here is the call graph for this function:

◆ KernelRidge() [2/2]

template<class KERNEL , class T = DTYPE>
template<typename... Ts>
algo::regression::KernelRidge< KERNEL, T >::KernelRidge ( const Ts &...  args)
inline

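Constructs an untrained model: the kernel is built from args and lambda is set to 0.0. No training data or parameters are set, so Train() must be called before Predict() or ComputeError().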

Definition at line 44 of file kernelridge.h.

44 : cov_(args...), lambda_(0.0) { }

Member Function Documentation

◆ ComputeError()

template<class KERNEL , class T >
T algo::regression::KernelRidge< KERNEL, T >::ComputeError ( const arma::Mat< T > &  points,
const arma::Row< T > &  responses 
) const

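Calculate the mean squared (L2) error of the model's predictions: the squared L2 norm of the residuals (responses − predictions) divided by the number of points (columns of points), i.e. $\frac{1}{n}\lVert y - \hat{y} \rVert_2^2$.

Parameters
points — input points at which the model is evaluated (named inputs in the implementation)
responses — true responses to compare against (named labels in the implementation)

Returns
The mean squared error, as a value of type T.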

Definition at line 63 of file kernelridge_impl.h.

65 {
66  arma::Row<T> temp;
67  Predict(inputs, temp);
68  const size_t n_points = inputs.n_cols;
69 
70  temp = labels - temp;
71 
72  const T cost = (arma::dot(temp, temp) / n_points);
73 
74  return cost;
75 }
void Predict(const arma::Mat< T > &inputs, arma::Row< T > &labels) const

◆ Predict()

template<class KERNEL , class T >
void algo::regression::KernelRidge< KERNEL, T >::Predict ( const arma::Mat< T > &  inputs,
arma::Row< T > &  labels 
) const
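Parameters
inputs — test inputs X* at which to predict
labels — [out] predicted values y*, overwritten with parameters_ · K(X_train, X*)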

Definition at line 54 of file kernelridge_impl.h.

56 {
57 
58  arma::Mat<T> k_xpx = cov_.GetMatrix(train_inp_,inputs);
59  labels = (parameters_ * k_xpx );
60 }

◆ serialize()

template<class KERNEL , class T = DTYPE>
template<typename Archive >
void algo::regression::KernelRidge< KERNEL, T >::serialize ( Archive &  ar,
const unsigned int   
)
inline

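Serialize the model with cereal: the fitted parameters, lambda, the kernel object, and the stored training inputs, under the names "parameters", "lambda", "cov" and "train_inp".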

Definition at line 81 of file kernelridge.h.

82  {
83  ar ( cereal::make_nvp("parameters",parameters_),
84  cereal::make_nvp("lambda",lambda_),
85  cereal::make_nvp("cov",cov_),
86  cereal::make_nvp("train_inp",train_inp_));
87  }

◆ Train()

template<class KERNEL , class T >
void algo::regression::KernelRidge< KERNEL, T >::Train ( const arma::Mat< T > &  inputs,
const arma::Row< T > &  labels 
)
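Parameters
inputs — training inputs X (a copy is stored in the model and reused by Predict())
labels — training targets y

Builds the kernel matrix K = K(X, X) and solves $(K + (\lambda + 10^{-6}) I)\,\alpha = y^\top$ for the dual parameters α, which are stored as a row vector. The hard-coded 10⁻⁶ added to the diagonal is a small jitter (presumably for numerical stability), so it is applied even when lambda is 0.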

Definition at line 39 of file kernelridge_impl.h.

41 {
42  train_inp_ = inputs;
43 
44  arma::Mat<T> k_xx = cov_.GetMatrix(train_inp_,train_inp_);
45 
46  arma::Mat<T> KLambda = k_xx+
47  (lambda_ + 1.e-6) * arma::eye<arma::Mat<T>>(k_xx.n_rows, k_xx.n_rows);
48 
49  parameters_ =
50  arma::conv_to<arma::Row<T>>::from(arma::solve(KLambda, labels.t()));
51 }

Referenced by algo::regression::KernelRidge< KERNEL, T >::KernelRidge().

+ Here is the caller graph for this function:

The documentation for this class was generated from the following files: