Learning Curve Plus Plus (LCPP)
svm_impl.h
1 /**
2  * @file svm_impl.h
3  * @author Ozgur Taylan Turan
4  *
5  * SVM Classifier
6  *
7  */
8 
9 #ifndef SVM_IMPL_H
10 #define SVM_IMPL_H
11 
12 namespace algo {
13 namespace classification {
14 //-----------------------------------------------------------------------------
15 // SVM
16 //-----------------------------------------------------------------------------
17 template<class KERNEL,size_t SOLVER,class T>
18 template<class... Args>
19 SVM<KERNEL,SOLVER,T>::SVM ( const arma::Mat<T>& inputs,
20  const arma::Row<size_t>& labels,
21  const size_t num_class,
22  const T& C,
23  const Args&... args) :
24  nclass_(num_class), C_(C), cov_(args...)
25 
26 {
27  ulab_ = arma::unique(labels);
28 
29  if (ulab_.n_elem == 1)
30  {
31  oneclass_ = true;
32  return;
33  }
34 
35  else
36  {
37  if (nclass_ == 2)
38  {
39  this->Train(inputs,labels);
40  }
41  else
42  {
43  ova_ = OnevAll<SVM<KERNEL,SOLVER,T>>(inputs, labels,
44  nclass_, size_t(2), C, args...);
45  }
46  }
47 }
48 ///////////////////////////////////////////////////////////////////////////////
49 template<class KERNEL,size_t SOLVER,class T>
50 template<class... Args>
51 SVM<KERNEL,SOLVER,T>::SVM ( const arma::Mat<T>& inputs,
52  const arma::Row<size_t>& labels,
53  const size_t num_class,
54  const Args&... args ) :
55  nclass_(num_class), C_(T(1.0)), cov_(args...)
56 {
57  ulab_ = arma::unique(labels);
58 
59  if (ulab_.n_elem == 1)
60  {
61  oneclass_ = true;
62  return;
63  }
64  else
65  {
66  if (nclass_ == 2)
67  this->Train(inputs,labels);
68  else
69  ova_ = OnevAll<SVM<KERNEL,SOLVER,T>>(inputs, labels,
70  nclass_, size_t(2),C_,args...);
71  }
72 }
73 ///////////////////////////////////////////////////////////////////////////////
74 template<class KERNEL,size_t SOLVER,class T>
75 void SVM<KERNEL,SOLVER,T>::Train ( const arma::Mat<T>& X,
76  const arma::Row<size_t>& y )
77 {
78  if (solver_ == "fanSMO")
79  _fanSMO(X,y);
80  else
81  ERR("Not Implemented: Try fanSMO");
82 }
83 ///////////////////////////////////////////////////////////////////////////////
84 template<class KERNEL,size_t SOLVER,class T>
85 void SVM<KERNEL,SOLVER,T>::Train ( const arma::Mat<T>& X,
86  const arma::Row<size_t>& y,
87  const size_t num_class )
88 {
89  this -> nclass_ = num_class;
90  this -> Train(X,y);
91 }
92 ///////////////////////////////////////////////////////////////////////////////
93 template<class KERNEL,size_t SOLVER,class T>
94 std::pair<int,int> SVM<KERNEL,SOLVER,T>::_selectset ( arma::Row<T> G,
95  arma::Mat<T> Q )
96 {
97  T inf = arma::datum::inf;
98  T G_max = -inf;
99  T G_min = inf;
100  T obj_min = inf;
101  size_t len = y_.n_elem;
102 
103  int i = -1;
104  for (size_t t=0; t<len; t++)
105  {
106  if ( (y_[t] == +1 && alphas_[t] < C_) || (y_[t] == -1 && alphas_[t] > 0.) )
107  {
108  if ( -y_[t]*G[t] >= G_max )
109  {
110  i = t;
111  G_max = -y_[t] * G[t];
112  }
113  }
114  }
115 
116  int j = -1;
117  for (size_t t=0; t<len; t++)
118  {
119  if ( (y_[t] == +1 && alphas_[t] > 0.) || (y_[t] == -1 && alphas_[t] < C_) )
120  {
121  T b = G_max + y_[t]*G[t];
122  if ( -y_[t]*G[t] <= G_max )
123  G_min = -y_[t]*G[t];
124  if ( b > 0. )
125  {
126  T a = Q(i,i) + Q(t,t) - 2.*y_[i]*y_[t]*Q(i,t);
127  if ( a <= 0. )
128  a = tau_;
129  if (-(b*b)/a <= obj_min)
130  {
131  j = t;
132  obj_min = -(b*b)/a;
133  }
134  }
135 
136  }
137  }
138 
139  if (G_max-G_min < eps_)
140  return {-1,-1};
141  else
142  return {i, j};
143 }
144 ///////////////////////////////////////////////////////////////////////////////
145 template<class KERNEL,size_t SOLVER,class T>
146 void SVM<KERNEL,SOLVER,T>::_fanSMO ( const arma::Mat<T>& X,
147  const arma::Row<size_t>& y )
148 {
149  X_ = X;
150  y_ = (arma::conv_to<arma::Row<int>>::from((y==ulab_(0)) * -2 + 1));
151  size_t N = y_.n_elem;
152  alphas_.resize(N);
153  /* alphas_ = 0.; */
154  arma::Row<T> G(N); G.fill(-1.);
155  arma::Mat<T> K;
156  /* if (N > 200) */
157  /* K = cov_.GetMatrix_approx(X,X,200); */
158  /* else */
159  /* K = cov_.GetMatrix(X,X); */
160  K = cov_.GetMatrix(X,X);
161  arma::Mat<T> Q = (y_.t() * y_) % K;
162  while (max_iter_>iter_++)
163  {
164  auto [i, j] = _selectset(G, Q);
165  if (j == -1) break; // Termination condition if no valid (i, j) is found
166 
167  // Compute `a` and set to tau if non-positive
168  T a = Q(i, i) + Q(j, j) - 2 * y_[i]*y_[j] * Q(i, j);
169  if (a <= 0.)
170  a = tau_;
171 
172  // Compute `b`
173  T b = -y_[i] * G[i] + y_[j] * G[j];
174 
175  // Store old alpha values
176  T oldAi = alphas_[i];
177  T oldAj = alphas_[j];
178 
179  // Update alpha values for i and j
180  alphas_[i] += y_[i] * b / a;
181  alphas_[j] -= y_[j] * b / a;
182 
183  // Project alpha values back to the feasible region [0, C]
184  T sum = y_[i] * oldAi + y_[j] * oldAj;
185 
186  // Project A[i] back to [0, C]
187  alphas_[i] = std::clamp(alphas_[i], T(0.0), C_);
188 
189  // Adjust A[j] based on updated A[i] and maintain feasibility constraint
190  alphas_[j] = y_[j] * (sum - y_[i] * alphas_[i]);
191  alphas_[j] = std::clamp(alphas_[j], T(0.0), C_);
192 
193  // Re-adjust A[i] based on adjusted A[j]
194  alphas_[i] = y_[i] * (sum - y_[j] * alphas_[j]);
195 
196  // Compute changes in alpha
197  T deltaAi = alphas_[i] - oldAi;
198  T deltaAj = alphas_[j] - oldAj;
199 
200  // Vectorized gradient update: G += Q.col(i) * deltaAi + Q.col(j) * deltaAj
201  /* G += (Q.col(i) * deltaAi + Q.col(j) * deltaAj).t(); */
202  size_t len = y_.n_elem;
203  for (size_t h=0; h<len; h++)
204  {
205  G[h] += Q(h,i)*deltaAi+Q(h,j)*deltaAj;
206  }
207  }
208  idx_ = arma::find(alphas_ > tau_);
209 }
210 ///////////////////////////////////////////////////////////////////////////////
211 template<class KERNEL,size_t SOLVER,class T>
212 void SVM<KERNEL,SOLVER,T>::Classify ( const arma::Mat<T>& inputs,
213  arma::Row<size_t>& preds ) const
214 {
215  if (!oneclass_)
216  {
217  arma::Mat<T> temp;
218  if (nclass_==2)
219  {
220  Classify(inputs,preds,temp);
221  }
222  else
223  {
224  ova_.Classify(inputs,preds,temp);
225  }
226  }
227  else
228  {
229  preds.resize(inputs.n_cols);
230  preds.fill(ulab_[0]);
231  }
232 }
233 ///////////////////////////////////////////////////////////////////////////////
234 template<class KERNEL,size_t SOLVER,class T>
235 void SVM<KERNEL,SOLVER,T>::Classify ( const arma::Mat<T>& inputs,
236  arma::Row<size_t>& preds,
237  arma::Mat<T>& probs ) const
238 {
239  arma::Mat<T> dec_func;
240  T b = 0;
241  if (!oneclass_)
242  {
243  if (nclass_==2)
244  {
245  if (idx_.n_elem>0)
246  {
247  probs.set_size(nclass_,inputs.n_cols);
248  preds.set_size(inputs.n_cols);
249  arma::Mat<T> svs = X_.cols(idx_);
250  arma::Mat<T> K = cov_.GetMatrix(svs,inputs);
251  arma::Mat<T> Ksv = cov_.GetMatrix(svs);
252 
253  b = arma::accu(arma::conv_to<arma::Row<T>>::from(y_.cols(idx_))
254  - ((alphas_.cols(idx_) % y_.cols(idx_)) * Ksv)) /idx_.n_elem;
255 
256  dec_func = (alphas_.cols(idx_) % y_.cols(idx_)) * K + b;
257 
258  preds.elem( arma::find( dec_func <= 0.) ).fill(ulab_[0]);
259  preds.elem( arma::find( dec_func > 0.) ).fill(ulab_[1]);
260  probs.row(0) = 1. / (1. + arma::exp(dec_func));
261  probs.row(1) = 1 - probs.row(0);
262  }
263  else
264  {
265  ERR("No support vectors->No prediction");
266  return;
267  }
268  }
269  else
270  ova_.Classify(inputs, preds, probs);
271  }
272  else
273  {
274  probs.resize(nclass_,inputs.n_cols);
275  probs.row(ulab_[0]).fill(1.);
276  preds.resize(inputs.n_cols);
277  preds.fill(ulab_[0]);
278  }
279 }
280 ///////////////////////////////////////////////////////////////////////////////
281 template<class KERNEL,size_t SOLVER,class T>
282 T SVM<KERNEL,SOLVER,T>::ComputeError ( const arma::Mat<T>& points,
283  const arma::Row<size_t>& responses )
284 {
285  arma::Row<size_t> predictions;
286  Classify(points,predictions);
287  arma::Row<size_t> temp = predictions - responses;
288  return (arma::accu(temp != 0))/T(predictions.n_elem);
289 }
290 ///////////////////////////////////////////////////////////////////////////////
291 template<class KERNEL,size_t SOLVER,class T>
292 T SVM<KERNEL,SOLVER,T>::ComputeAccuracy ( const arma::Mat<T>& points,
293  const arma::Row<size_t>& responses )
294 {
295  return (1. - ComputeError(points, responses));
296 }
297 ///////////////////////////////////////////////////////////////////////////////
298 } // namespace classification
299 } // namespace algo
300 #endif
void Train(const arma::Mat< T > &inputs, const arma::Row< size_t > &labels, const size_t num_class)
Definition: svm_impl.h:85
T ComputeAccuracy(const arma::Mat< T > &points, const arma::Row< size_t > &responses)
Definition: svm_impl.h:292
T ComputeError(const arma::Mat< T > &points, const arma::Row< size_t > &responses)
Definition: svm_impl.h:282
void Classify(const arma::Mat< T > &inputs, arma::Row< size_t > &labels) const
Definition: svm_impl.h:212