Learning Curve Plus Plus (LCPP)
algo::classification::SVM< KERNEL, SOLVER, T > Class Template Reference

Public Member Functions

template<class... Args>
 SVM (const size_t num_class, const T &C, const Args &... args)
 
template<class... Args>
 SVM (const size_t num_class, const std::string solver, const T &C, const Args &... args)
 
template<class... Args>
 SVM (const arma::Mat< T > &inputs, const arma::Row< size_t > &labels, const size_t num_class, const T &C, const Args &... args)
 
template<class... Args>
 SVM (const arma::Mat< T > &inputs, const arma::Row< size_t > &labels, const size_t num_class, const Args &... args)
 
void Train (const arma::Mat< T > &inputs, const arma::Row< size_t > &labels, const size_t num_class)
 
void Train (const arma::Mat< T > &inputs, const arma::Row< size_t > &labels)
 
void Classify (const arma::Mat< T > &inputs, arma::Row< size_t > &labels) const
 
void Classify (const arma::Mat< T > &inputs, arma::Row< size_t > &labels, arma::Mat< T > &dec_func) const
 
T ComputeError (const arma::Mat< T > &points, const arma::Row< size_t > &responses)
 
T ComputeAccuracy (const arma::Mat< T > &points, const arma::Row< size_t > &responses)
 
template<typename Archive >
void serialize (Archive &ar, const unsigned int)
 

Detailed Description

template<class KERNEL = mlpack::LinearKernel, size_t SOLVER = 0, class T = DTYPE>
class algo::classification::SVM< KERNEL, SOLVER, T >

Definition at line 24 of file svm.h.

Constructor & Destructor Documentation

◆ SVM() [1/4]

template<class KERNEL = mlpack::LinearKernel, size_t SOLVER = 0, class T = DTYPE>
template<class... Args>
algo::classification::SVM< KERNEL, SOLVER, T >::SVM ( const size_t  num_class,
const T &  C,
const Args &...  args 
)
inline

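Untrained model: stores the solver ("fanSMO"), the regularization C and the kernel parameters, but performs no training. Note that num_class is not used by this constructor's initializer list; supply it later through Train(inputs, labels, num_class).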

Parameters
num_class: number of classes
C: regularization
args: kernel parameters

Definition at line 46 of file svm.h.

46  :
47  solver_("fanSMO"),C_(C),cov_(args...), oneclass_(false) { } ;

◆ SVM() [2/4]

template<class KERNEL = mlpack::LinearKernel, size_t SOLVER = 0, class T = DTYPE>
template<class... Args>
algo::classification::SVM< KERNEL, SOLVER, T >::SVM ( const size_t  num_class,
const std::string  solver,
const T &  C,
const Args &...  args 
)
inline
Parameters
num_class: number of classes
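solver: name of the optimization method; only "fanSMO" is implemented (Train() raises the error "Not Implemented: Try fanSMO" for any other value)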
C: regularization
args: kernel parameters

Definition at line 55 of file svm.h.

56  :
57  solver_(solver),C_(C),cov_(args...), oneclass_(false) { } ;

◆ SVM() [3/4]

template<class KERNEL , size_t SOLVER, class T >
template<class... Args>
algo::classification::SVM< KERNEL, SOLVER, T >::SVM ( const arma::Mat< T > &  inputs,
const arma::Row< size_t > &  labels,
const size_t  num_class,
const T &  C,
const Args &...  args 
)
Parameters
num_class: number of classes
inputs: X
labels: y
C: regularization
args: kernel parameters

Definition at line 19 of file svm_impl.h.

23  :
24  nclass_(num_class), C_(C), cov_(args...)
25 
26 {
27  ulab_ = arma::unique(labels);
28 
29  if (ulab_.n_elem == 1)
30  {
31  oneclass_ = true;
32  return;
33  }
34 
35  else
36  {
37  if (nclass_ == 2)
38  {
39  this->Train(inputs,labels);
40  }
41  else
42  {
43  ova_ = OnevAll<SVM<KERNEL,SOLVER,T>>(inputs, labels,
44  nclass_, size_t(2), C, args...);
45  }
46  }
47 }
void Train(const arma::Mat< T > &inputs, const arma::Row< size_t > &labels, const size_t num_class)
Definition: svm_impl.h:85

References algo::classification::SVM< KERNEL, SOLVER, T >::Train().

+ Here is the call graph for this function:

◆ SVM() [4/4]

template<class KERNEL , size_t SOLVER, class T >
template<class... Args>
algo::classification::SVM< KERNEL, SOLVER, T >::SVM ( const arma::Mat< T > &  inputs,
const arma::Row< size_t > &  labels,
const size_t  num_class,
const Args &...  args 
)
Parameters
num_class: number of classes
inputs: X
labels: y
args: kernel parameters
(This overload takes no C argument; the regularization C is fixed to 1.0.)

Definition at line 51 of file svm_impl.h.

54  :
55  nclass_(num_class), C_(T(1.0)), cov_(args...)
56 {
57  ulab_ = arma::unique(labels);
58 
59  if (ulab_.n_elem == 1)
60  {
61  oneclass_ = true;
62  return;
63  }
64  else
65  {
66  if (nclass_ == 2)
67  this->Train(inputs,labels);
68  else
69  ova_ = OnevAll<SVM<KERNEL,SOLVER,T>>(inputs, labels,
70  nclass_, size_t(2),C_,args...);
71  }
72 }

References algo::classification::SVM< KERNEL, SOLVER, T >::Train().

+ Here is the call graph for this function:

Member Function Documentation

◆ Classify() [1/2]

template<class KERNEL , size_t SOLVER, class T >
void algo::classification::SVM< KERNEL, SOLVER, T >::Classify ( const arma::Mat< T > &  inputs,
arma::Row< size_t > &  labels 
) const
Parameters
inputs: X*, the points to classify (one per column)
labels: y*, output row of predicted labels (one per input column)

Definition at line 212 of file svm_impl.h.

214 {
215  if (!oneclass_)
216  {
217  arma::Mat<T> temp;
218  if (nclass_==2)
219  {
220  Classify(inputs,preds,temp);
221  }
222  else
223  {
224  ova_.Classify(inputs,preds,temp);
225  }
226  }
227  else
228  {
229  preds.resize(inputs.n_cols);
230  preds.fill(ulab_[0]);
231  }
232 }
void Classify(const arma::Mat< T > &inputs, arma::Row< size_t > &labels) const
Definition: svm_impl.h:212

◆ Classify() [2/2]

template<class KERNEL , size_t SOLVER, class T >
void algo::classification::SVM< KERNEL, SOLVER, T >::Classify ( const arma::Mat< T > &  inputs,
arma::Row< size_t > &  labels,
arma::Mat< T > &  dec_func 
) const
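Parameters
inputs: X*, the points to classify (one per column)
labels: y*, output row of predicted labels
dec_func: output score matrix with one row per class. In the binary case it is not the raw f(x). It holds logistic-transformed decision values: row 0 = 1 / (1 + exp(f(x))) and row 1 = 1 - row 0. If a binary model has no support vectors, the error "No support vectors->No prediction" is raised and no prediction is made.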

Definition at line 235 of file svm_impl.h.

238 {
239  arma::Mat<T> dec_func;
240  T b = 0;
241  if (!oneclass_)
242  {
243  if (nclass_==2)
244  {
245  if (idx_.n_elem>0)
246  {
247  probs.set_size(nclass_,inputs.n_cols);
248  preds.set_size(inputs.n_cols);
249  arma::Mat<T> svs = X_.cols(idx_);
250  arma::Mat<T> K = cov_.GetMatrix(svs,inputs);
251  arma::Mat<T> Ksv = cov_.GetMatrix(svs);
252 
253  b = arma::accu(arma::conv_to<arma::Row<T>>::from(y_.cols(idx_))
254  - ((alphas_.cols(idx_) % y_.cols(idx_)) * Ksv)) /idx_.n_elem;
255 
256  dec_func = (alphas_.cols(idx_) % y_.cols(idx_)) * K + b;
257 
258  preds.elem( arma::find( dec_func <= 0.) ).fill(ulab_[0]);
259  preds.elem( arma::find( dec_func > 0.) ).fill(ulab_[1]);
260  probs.row(0) = 1. / (1. + arma::exp(dec_func));
261  probs.row(1) = 1 - probs.row(0);
262  }
263  else
264  {
265  ERR("No support vectors->No prediction");
266  return;
267  }
268  }
269  else
270  ova_.Classify(inputs, preds, probs);
271  }
272  else
273  {
274  probs.resize(nclass_,inputs.n_cols);
275  probs.row(ulab_[0]).fill(1.);
276  preds.resize(inputs.n_cols);
277  preds.fill(ulab_[0]);
278  }
279 }

◆ ComputeAccuracy()

template<class KERNEL , size_t SOLVER, class T >
T algo::classification::SVM< KERNEL, SOLVER, T >::ComputeAccuracy ( const arma::Mat< T > &  points,
const arma::Row< size_t > &  responses 
)

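Calculate the accuracy, i.e. 1 - ComputeError(points, responses): the fraction of points whose predicted label matches the corresponding response.

Parameters
points: X*, the points to classify
responses: y*, the true labels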

Definition at line 292 of file svm_impl.h.

294 {
295  return (1. - ComputeError(points, responses));
296 }
T ComputeError(const arma::Mat< T > &points, const arma::Row< size_t > &responses)
Definition: svm_impl.h:282

◆ ComputeError()

template<class KERNEL , size_t SOLVER, class T >
T algo::classification::SVM< KERNEL, SOLVER, T >::ComputeError ( const arma::Mat< T > &  points,
const arma::Row< size_t > &  responses 
)

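Calculate the error rate: the fraction of points whose predicted label differs from the corresponding response.

Parameters
points: X*, the points to classify
responses: y*, the true labels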

Definition at line 282 of file svm_impl.h.

284 {
285  arma::Row<size_t> predictions;
286  Classify(points,predictions);
287  arma::Row<size_t> temp = predictions - responses;
288  return (arma::accu(temp != 0))/T(predictions.n_elem);
289 }

◆ serialize()

template<class KERNEL = mlpack::LinearKernel, size_t SOLVER = 0, class T = DTYPE>
template<typename Archive >
void algo::classification::SVM< KERNEL, SOLVER, T >::serialize ( Archive &  ar,
const unsigned int   
)
inline

Serialize the model.

Definition at line 135 of file svm.h.

137  {
138  ar ( cereal::make_nvp("X",X_),
139  cereal::make_nvp("nclass",nclass_),
140  cereal::make_nvp("y",y_),
141  cereal::make_nvp("alphas",alphas_),
142  cereal::make_nvp("oldalphas",old_alphas_),
143  cereal::make_nvp("ulab",ulab_),
144  cereal::make_nvp("idx",idx_),
145  cereal::make_nvp("C",C_),
146  cereal::make_nvp("oneclass",oneclass_),
147  cereal::make_nvp("eps",eps_),
148  cereal::make_nvp("tau",tau_),
149  cereal::make_nvp("cov",cov_),
150  cereal::make_nvp("max_iter",max_iter_),
151  cereal::make_nvp("iter",iter_),
152  cereal::make_nvp("solver",solver_) );
153  }

◆ Train() [1/2]

template<class KERNEL , size_t SOLVER, class T >
void algo::classification::SVM< KERNEL, SOLVER, T >::Train ( const arma::Mat< T > &  inputs,
const arma::Row< size_t > &  labels 
)
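Parameters
inputs: X
labels: y

Trains with the solver chosen at construction. Only "fanSMO" is implemented; any other solver raises the error "Not Implemented: Try fanSMO".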

Definition at line 75 of file svm_impl.h.

77 {
78  if (solver_ == "fanSMO")
79  _fanSMO(X,y);
80  else
81  ERR("Not Implemented: Try fanSMO");
82 }

◆ Train() [2/2]

template<class KERNEL , size_t SOLVER, class T >
void algo::classification::SVM< KERNEL, SOLVER, T >::Train ( const arma::Mat< T > &  inputs,
const arma::Row< size_t > &  labels,
const size_t  num_class 
)
Parameters
inputs: X
labels: y
num_class: number of classes; stored in the model before delegating to Train(inputs, labels)

Definition at line 85 of file svm_impl.h.

88 {
89  this -> nclass_ = num_class;
90  this -> Train(X,y);
91 }

Referenced by algo::classification::SVM< KERNEL, SOLVER, T >::SVM().

+ Here is the caller graph for this function:

The documentation for this class was generated from the following files: