Learning Curve Plus Plus (LCPP)
svm.h
Go to the documentation of this file.
1 /**
2  * @file svm.h
3  * @author Ozgur Taylan Turan
4  *
5  * SVM Classifier
6  * fanSMO -> Fan, R.-E., Chen, P.-H., & Lin, C.-J. (2005).
7  * Working set selection using second order information for training
8  * support vector machines.
9  * Journal of Machine Learning Research, 6, 1889–1918.
10  * Available at: https://www.jmlr.org/papers/v6/fan05a.html
11  */
12 
13 #ifndef SVM_H
14 #define SVM_H
15 
16 namespace algo {
17 namespace classification {
18 //-----------------------------------------------------------------------------
19 // SVM : Kernel version of the support vector classifier
20 //-----------------------------------------------------------------------------
21 template<class KERNEL=mlpack::LinearKernel,
22  size_t SOLVER=0,
23  class T=DTYPE>
24 class SVM
25 {
26  public:
27 
28  SVM ( ) = default;
29 
30  /**
31  * Non-working model
32  * @param num_class : number of classes
33  * @param args : kernel parameters
34  */
35  /* template<class... Args> */
36  /* SVM ( const size_t& num_class, */
37  /* const Args&... args ) : C_(T(1.)), cov_(args...), oneclass_(false) */
38  /* { }; */
39 
40  /**
41  * @param num_class : number of classes
42  * @param C : regularization
43  * @param args : kernel parameters
44  */
45  template<class... Args>
46  SVM ( const size_t num_class, const T& C, const Args&... args ) :
47  solver_("fanSMO"),C_(C),cov_(args...), oneclass_(false) { } ;
48  /**
49  * @param num_class : number of classes
50  * @param solver : which optimization method fanSMO
51  * @param C : regularization
52  * @param args : kernel parameters
53  */
54  template<class... Args>
55  SVM ( const size_t num_class, const std::string solver,
56  const T& C, const Args&... args ) :
57  solver_(solver),C_(C),cov_(args...), oneclass_(false) { } ;
58  /**
59  * @param num_class : number of classes
60  * @param inputs : X
61  * @param labels : y
62  * @param C : regularization
63  * @param args : kernel parameters
64  */
65  template<class... Args>
66  SVM ( const arma::Mat<T>& inputs,
67  const arma::Row<size_t>& labels,
68  const size_t num_class,
69  const T& C,
70  const Args&... args );
71  /**
72  * @param num_class : number of classes
73  * @param inputs : X
74  * @param labels : y
75  * @param args : kernel parameters
76  */
77  template<class... Args>
78  SVM ( const arma::Mat<T>& inputs,
79  const arma::Row<size_t>& labels,
80  const size_t num_class,
81  const Args&... args );
82 
83  /**
84  * @param inputs : X
85  * @param labels : y
86  */
87  void Train ( const arma::Mat<T>& inputs,
88  const arma::Row<size_t>& labels,
89  const size_t num_class );
90  /**
91  * @param inputs : X
92  * @param labels : y
93  * @param num_class : y
94  */
95  void Train ( const arma::Mat<T>& inputs,
96  const arma::Row<size_t>& labels );
97 
98  /**
99  * @param inputs : X*
100  * @param labels : y*
101  */
102  void Classify ( const arma::Mat<T>& inputs,
103  arma::Row<size_t>& labels ) const ;
104 
105 /**
106  * @param inputs : X*
107  * @param labels : y*
108  * @param dec_func : f(x)
109  */
110  void Classify ( const arma::Mat<T>& inputs,
111  arma::Row<size_t>& labels,
112  arma::Mat<T>& dec_func ) const;
113  /**
114  * Calculate the Error Rate
115  *
116  * @param inputs : X*
117  * @param labels : y*
118  */
119  T ComputeError ( const arma::Mat<T>& points,
120  const arma::Row<size_t>& responses );
121  /**
122  * Calculate the Accuracy
123  *
124  * @param inputs : X*
125  * @param labels : y*
126  *
127  */
128  T ComputeAccuracy ( const arma::Mat<T>& points,
129  const arma::Row<size_t>& responses );
130 
131  /**
132  * Serialize the model.
133  */
134  template<typename Archive>
135  void serialize ( Archive& ar,
136  const unsigned int /* version */ )
137  {
138  ar ( cereal::make_nvp("X",X_),
139  cereal::make_nvp("nclass",nclass_),
140  cereal::make_nvp("y",y_),
141  cereal::make_nvp("alphas",alphas_),
142  cereal::make_nvp("oldalphas",old_alphas_),
143  cereal::make_nvp("ulab",ulab_),
144  cereal::make_nvp("idx",idx_),
145  cereal::make_nvp("C",C_),
146  cereal::make_nvp("oneclass",oneclass_),
147  cereal::make_nvp("eps",eps_),
148  cereal::make_nvp("tau",tau_),
149  cereal::make_nvp("cov",cov_),
150  cereal::make_nvp("max_iter",max_iter_),
151  cereal::make_nvp("iter",iter_),
152  cereal::make_nvp("solver",solver_) );
153  }
154 
155 private:
156  std::map<int,std::string> solvers_ = {{0,"fanSMO"}};
157  std::string solver_ = solvers_[SOLVER];
158  size_t nclass_;
159  T C_;
160  data::Gram<KERNEL> cov_;
161  arma::Row<int> y_;
162  arma::Mat<T> X_; // When I use here a pointer to the matrix it goes wrong...
163  arma::Row<size_t> ulab_;
164  arma::Row<T> alphas_;
165  arma::Row<T> old_alphas_;
166  arma::uvec idx_;
167  bool oneclass_ = false;
168  T eps_ = 1e-3;
169  T tau_ = 1e-12;
170  size_t max_iter_ = 5000;
171  size_t iter_ = 0;
172 
173  OnevAll<SVM<KERNEL,SOLVER,T>> ova_;
174 
175  /**
176  * @param inputs : X
177  * @param labels : y
178  */
179 
180  void _fanSMO ( const arma::Mat<T>& inputs,
181  const arma::Row<size_t>& labels );
182 
183  /**
184  * @param G : Vector containing the Gradients
185  * @param Q : Matrix with the Kernel x y
186  */
187  std::pair<int, int> _selectset ( arma::Row<T> G, arma::Mat<T> Q );
188 
189 
190 
191 };
192 
193 } // namespace classification
194 } // namespace algo
195 #include "svm_impl.h"
196 #endif
void Train(const arma::Mat< T > &inputs, const arma::Row< size_t > &labels, const size_t num_class)
Definition: svm_impl.h:85
SVM(const size_t num_class, const T &C, const Args &... args)
Definition: svm.h:46
SVM(const size_t num_class, const std::string solver, const T &C, const Args &... args)
Definition: svm.h:55
T ComputeAccuracy(const arma::Mat< T > &points, const arma::Row< size_t > &responses)
Definition: svm_impl.h:292
T ComputeError(const arma::Mat< T > &points, const arma::Row< size_t > &responses)
Definition: svm_impl.h:282
void serialize(Archive &ar, const unsigned int)
Definition: svm.h:135
void Classify(const arma::Mat< T > &inputs, arma::Row< size_t > &labels) const
Definition: svm_impl.h:212