9 #ifndef PARAMCLASS_IMPL_H
10 #define PARAMCLASS_IMPL_H
13 namespace classification {
20 const arma::Row<size_t>& labels,
21 const size_t& num_class ) : num_class_(num_class), lambda_(0.)
23 Train(inputs, labels);
28 const arma::Row<size_t>& labels,
29 const size_t& num_class,
31 const arma::Row<T>& priors ) : num_class_(num_class),
32 lambda_(lambda), priors_(priors)
34 Train(inputs, labels);
39 const arma::Row<size_t>& labels,
40 const size_t& num_class,
41 const double& lambda ) : num_class_(num_class), lambda_(lambda)
43 Train(inputs, labels);
48 const arma::Row<size_t>& labels )
50 class_ = arma::regspace<arma::Row<size_t>>(0,1,num_class_);
51 priors_ = get_prior<T>(labels, num_class_);
54 size_ = inputs.n_cols;
55 unique_ = arma::unique(labels);
56 if (unique_.n_elem != 1)
58 cov_.resize(dim_,dim_);
61 arma::Row<size_t>::iterator it = unique_.begin();
62 arma::Row<size_t>::iterator end = unique_.end();
68 auto extract = extract_class(inputs, labels, *it);
69 inx = std::get<0>(extract);
70 means_[*it] = arma::conv_to<arma::Row<T>>::from(arma::mean(inx,1));
71 if ( inx.n_cols == 1 )
72 covs_[*it] = arma::eye<arma::Mat<T>>(dim_,dim_);
75 covs_[*it] = arma::cov(inx.t());
76 covs_[*it].diag() += jitter_+lambda_;
80 cov_ = arma::pinv(cov_) / num_class_;
86 const arma::Row<size_t>& labels,
87 const size_t num_class )
89 this->num_class_ = num_class;
90 this->Train(inputs,labels);
95 arma::Row<size_t>& labels )
const
98 Classify(inputs, labels, temp);
103 arma::Row<size_t>& labels,
104 arma::Mat<T>& probs )
const
106 const size_t N = inputs.n_cols;
108 probs.resize(num_class_,N);
109 if ( unique_.n_elem == 1 )
111 labels.fill(unique_(0));
112 probs.row(unique_(0)).fill(1.);
116 #pragma omp parallel for
117 for (
size_t n=0; n<inputs.n_cols; n++ )
119 for (
size_t c=0; c<unique_.n_elem; c++ )
121 probs(class_(unique_(c)),n) = std::log(priors_(c))
122 - 0.5*arma::dot(means_.at(unique_(c))*
123 cov_, means_.at(unique_(c)))
124 + arma::dot(inputs.col(n).t()*cov_,means_.at(unique_(c)));
126 labels(n) = class_(probs.col(n).index_max());
129 probs = arma::exp(probs.each_row() - arma::max(probs,0));
130 probs = probs.each_row()/arma::sum(probs,0);
137 const arma::Row<size_t>& responses )
const
139 arma::Row<size_t> predictions;
140 Classify(points,predictions);
141 arma::Row<size_t> temp = predictions - responses;
142 return (arma::accu(temp != 0))/T(predictions.n_elem);
147 const arma::Row<size_t>& responses )
const
149 return (1. - ComputeError(points, responses))*100;
157 const arma::Row<size_t>& labels,
158 const size_t& num_class ) : num_class_(num_class), lambda_(0.)
160 Train(inputs, labels);
165 const arma::Row<size_t>& labels,
166 const size_t& num_class,
167 const double& lambda,
168 const arma::Row<T>& priors ) : num_class_(num_class),
169 lambda_(lambda), priors_(priors)
171 Train(inputs, labels);
176 const arma::Row<size_t>& labels,
177 const size_t& num_class,
178 const double& lambda ) : num_class_(num_class), lambda_(lambda)
180 Train(inputs, labels);
185 const arma::Row<size_t>& labels )
188 dim_ = inputs.n_rows;
189 size_ = inputs.n_cols;
190 unique_ = arma::unique(labels);
192 class_ = arma::regspace<arma::Row<size_t>>(0,1,num_class_);
193 priors_ = get_prior<T>(labels,num_class_);
195 arma::Row<size_t>::iterator it = unique_.begin();
196 arma::Row<size_t>::iterator end = unique_.end();
202 auto extract = extract_class(inputs, labels, *it);
204 inx = std::get<0>(extract);
205 means_[*it] = arma::conv_to<arma::Row<T>>::from(arma::mean(inx,1));
206 if ( inx.n_cols == 1 )
208 covs_[*it] = arma::eye<arma::Mat<T>>(dim_,dim_);
209 icovs_[*it] = arma::eye<arma::Mat<T>>(dim_,dim_);
213 covs_[*it] = arma::cov(inx.t());
214 covs_[*it].diag() += jitter_+lambda_;
215 icovs_[*it] = arma::pinv(covs_[*it]);
217 icovs_[*it] = arma::pinv(covs_[*it]);
223 const arma::Row<size_t>& labels,
224 const size_t num_class )
226 this -> num_class_=num_class;
227 this -> Train(inputs,labels);
232 arma::Row<size_t>& labels )
const
235 Classify(inputs, labels, temp);
240 arma::Row<size_t>& labels,
241 arma::Mat<T>& probs )
const
243 const size_t N = inputs.n_cols;
245 probs.resize(num_class_,N);
247 if ( num_class_ == 1 )
249 labels.fill(unique_(0));
250 probs.row(unique_(0)).fill(1.);
257 for (
size_t n=0; n<inputs.n_cols; n++ )
259 for (
size_t c=0; c<unique_.n_elem; c++ )
261 norm = inputs.col(n).t() - means_.at(unique_(c));
262 probs(class_(unique_(c)),n) = std::log(priors_(c))
263 - 0.5*(arma::det(covs_.at(unique_(c)))+inputs.n_rows*std::log(2*arma::datum::pi))
264 - 0.5* arma::dot(norm*icovs_.at(unique_(c)),norm);
266 labels(n) = class_(probs.col(n).index_max());
269 probs = arma::exp(probs.each_row() - arma::max(probs,0));
270 probs = probs.each_row()/arma::sum(probs,0);
276 const arma::Row<size_t>& responses )
const
278 arma::Row<size_t> predictions;
279 Classify(points,predictions);
280 arma::Row<size_t> temp = predictions - responses;
281 return (arma::accu(temp != 0))/T(predictions.n_elem);
286 const arma::Row<size_t>& responses )
const
288 return (1. - ComputeError(points, responses))*100;
295 const arma::Row<size_t>& labels,
296 const size_t& num_class,
297 const double& shrink ) : shrink_(shrink), num_class_(num_class)
299 Train(inputs, labels);
304 const arma::Row<size_t>& labels,
305 const size_t& num_class ) : shrink_(0.), num_class_(num_class)
307 Train(inputs, labels);
312 const arma::Row<size_t>& labels )
314 dim_ = inputs.n_rows;
315 unique_ = arma::unique(labels);
317 size_ = inputs.n_cols;
318 arma::vec nk(num_class_);
319 parameters_.resize(inputs.n_rows, num_class_);
321 arma::Row<size_t>::iterator it = unique_.begin();
322 arma::Row<size_t>::iterator it_end = unique_.end();
327 for ( ;it!=it_end; ++it)
329 auto extract = extract_class(inputs, labels, *it);
330 index = std::get<1>(extract);
331 nk(counter) = index.n_rows;
332 parameters_.col(counter) = arma::mean(inputs.cols(index),1);
335 centroid_ = arma::mean(inputs, 1);
338 if (shrink_ > 0 && num_class_ != 1)
340 arma::Mat<T> nk = arma::ones<arma::Mat<T>>(num_class_,1);
341 arma::Mat<T> m = arma::sqrt(1./nk)-1/size_;
342 arma::uvec labs = arma::conv_to<arma::uvec>::from(labels);
343 arma::Mat<T> variance = arma::sum(
344 arma::pow(inputs - parameters_.cols(labs),2),1);
345 arma::Mat<T> s = arma::sqrt(variance/(size_-num_class_)).t();
346 arma::Mat<T> ms = m*s;
347 arma::inplace_trans(ms);
348 arma::Mat<T> devi = (parameters_.each_col() - centroid_) / ms;
349 arma::Mat<T> signs = arma::sign(devi);
350 arma::Mat<T> dev = arma::abs(devi) - shrink_;
351 dev = arma::clamp(dev, 0, arma::datum::inf);
353 arma::Mat<T> msd = dev % ms;
354 parameters_ = centroid_ + msd.each_col();
360 const arma::Row<size_t>& labels,
361 const size_t num_class )
363 this -> num_class_ = num_class;
364 this -> Train (inputs,labels);
369 arma::Row<size_t>& labels )
const
372 Classify(inputs,labels,temp);
377 arma::Row<size_t>& labels,
378 arma::Mat<T>& probs )
const
380 const size_t N = inputs.n_cols;
381 probs.resize(num_class_, N);
383 if ( unique_.n_elem == 1 )
385 labels.fill(unique_(0));
386 probs.row(unique_(0)).fill(1.);
390 arma::Mat<T> distances(num_class_, N);
391 distances.fill(arma::datum::inf);
392 for (
size_t j=0; j<N; j++ )
394 for (
size_t i=0; i<unique_.n_elem; i++ )
396 distances(unique_(i), j) = metric_.Evaluate(parameters_.col(unique_(i))
399 labels(j) = arma::index_min(distances.col(j));
401 if (unique_.n_elem == num_class_)
402 probs = distances.each_row() / arma::sum(distances,0);
405 for (arma::uword i = 0; i < distances.n_rows; ++i)
408 if (arma::is_finite(distances.row(i)))
411 probs.row(i) = distances.row(i) / arma::sum(distances, 0);
414 probs.elem(arma::find_nonfinite(probs)).fill(1/num_class_);
420 const arma::Row<size_t>& responses )
const
422 arma::Row<size_t> predictions;
424 Classify(points,predictions);
425 arma::Row<size_t> temp = predictions - responses;
426 return (arma::accu(temp != 0))/T(predictions.n_elem);
431 const arma::Row<size_t>& responses )
const
433 return (1. - ComputeError(points, responses))*100;
void Train(const arma::Mat< T > &inputs, const arma::Row< size_t > &labels)
void Classify(const arma::Mat< T > &inputs, arma::Row< size_t > &labels) const
T ComputeAccuracy(const arma::Mat< T > &points, const arma::Row< size_t > &responses) const
T ComputeError(const arma::Mat< T > &points, const arma::Row< size_t > &responses) const
void Train(const arma::Mat< T > &inputs, const arma::Row< size_t > &labels)
void Classify(const arma::Mat< T > &inputs, arma::Row< size_t > &labels) const
T ComputeAccuracy(const arma::Mat< T > &points, const arma::Row< size_t > &responses) const
T ComputeError(const arma::Mat< T > &points, const arma::Row< size_t > &responses) const
void Classify(const arma::Mat< T > &inputs, arma::Row< size_t > &labels) const
T ComputeError(const arma::Mat< T > &points, const arma::Row< size_t > &responses) const
void Train(const arma::Mat< T > &inputs, const arma::Row< size_t > &labels)
T ComputeAccuracy(const arma::Mat< T > &points, const arma::Row< size_t > &responses) const