Learning Curve Plus Plus (LCPP)
nonparamclass_impl.h
Go to the documentation of this file.
1 /**
2  * @file nonparamclass_impl.h
3  * @author Ozgur Taylan Turan
4  *
5  * Nonparametric Classifiers
6  *
7  */
8 
9 #ifndef NONPARAMCLASS_IMPL_H
10 #define NONPARAMCLASS_IMPL_H
11 
12 namespace algo {
13 namespace classification {
14 
15 //-----------------------------------------------------------------------------
16 // Nearest Neighbour Classifier
17 //-----------------------------------------------------------------------------
18 template<class T>
19 NNC<T>::NNC ( const arma::Mat<T>& inputs,
20  const arma::Row<size_t>& labels,
21  const size_t& num_class ): k_(1), nclass_(num_class)
22 {
23  Train(inputs, labels);
24 }
25 ///////////////////////////////////////////////////////////////////////////////
26 template<class T>
27 NNC<T>::NNC ( const arma::Mat<T>& inputs,
28  const arma::Row<size_t>& labels,
29  const size_t& num_class,
30  const size_t& k ) : k_(k), nclass_(num_class)
31 {
32  Train(inputs, labels);
33 }
34 ///////////////////////////////////////////////////////////////////////////////
35 template<class T>
36 void NNC<T>::Train ( const arma::Mat<T>& inputs,
37  const arma::Row<size_t>& labels )
38 {
39  dim_ = inputs.n_rows;
40  nuclass_ = arma::unique(labels).eval().n_elem;
41  if (nuclass_ == 1)
42  unique_ = arma::unique(labels);
43  else
44  unique_ = arma::regspace<arma::Row<size_t>>(0,nclass_-1);
45 
46  size_ = inputs.n_cols;
47  inputs_ = inputs;
48  labels_ = labels;
49  if ( k_ > inputs.n_cols )
50  k_ = inputs.n_cols-1;
51 }
52 ///////////////////////////////////////////////////////////////////////////////
53 template<class T>
54 void NNC<T>::Train ( const arma::Mat<T>& inputs,
55  const arma::Row<size_t>& labels,
56  const size_t num_class )
57 {
58  this ->nuclass_ = num_class;
59  this -> Train(inputs,labels);
60 }
61 ///////////////////////////////////////////////////////////////////////////////
62 template<class T>
63 void NNC<T>::Classify ( const arma::Mat<T>& inputs,
64  arma::Row<size_t>& labels ) const
65 {
66  arma::Mat<T> temp;
67  Classify(inputs, labels, temp);
68 }
69 ///////////////////////////////////////////////////////////////////////////////
70 template<class T>
71 void NNC<T>::Classify ( const arma::Mat<T>& inputs,
72  arma::Row<size_t>& labels,
73  arma::Mat<T>& probs ) const
74 {
75  const size_t N = inputs.n_cols;
76  probs.resize(nclass_,N);
77  labels.resize(N);
78  if ( nuclass_ == 1 )
79  {
80  labels.fill(unique_(0));
81  probs.row(unique_(0)).fill(1);
82  }
83  else
84  {
85  // Potentially much faster implementation
86 
87  mlpack::KNN knn(inputs_);
88  arma::Mat<T> dist;
89  arma::Mat<size_t> neig;
90  arma::Mat<size_t> select;
91 
92  // Find the nns to all the samples
93  knn.Search(inputs, k_, neig, dist);
94  arma::Col<size_t> unq;
95 
96 
97  /* #pragma omp parallel for */
98  for ( size_t j=0; j<N; j++ )
99  {
100  // select the labels of the nns per sample
101  select = labels_(arma::conv_to<arma::uvec>::from(neig.col(j)));
102  // count how many classes does one sample has per sample
103  auto count = arma::hist(select,unique_);
104  auto ps =
105  arma::conv_to<arma::Col<T>>::from(count);
106  probs.col(j) = ps / arma::accu(ps);
107  // assign the maximum number of seen class
108  labels(j) = unique_(count.index_max());
109  }
110  }
111 }
112 ///////////////////////////////////////////////////////////////////////////////
113 template<class T>
114 T NNC<T>::ComputeError ( const arma::Mat<T>& points,
115  const arma::Row<size_t>& responses ) const
116 {
117  arma::Row<size_t> predictions;
118  Classify(points,predictions);
119  arma::Row<size_t> temp = predictions - responses;
120  double total = responses.n_cols;
121 
122  return (arma::accu(temp != 0))/total;
123 }
124 ///////////////////////////////////////////////////////////////////////////////
125 template<class T>
126 T NNC<T>::ComputeAccuracy ( const arma::Mat<T>& points,
127  const arma::Row<size_t>& responses ) const
128 {
129  return (1. - ComputeError(points, responses))*100;
130 }
131 ///////////////////////////////////////////////////////////////////////////////
132 } // namespace classification
133 } // namespace algo
134 #endif
135 
void Train(const arma::Mat< T > &inputs, const arma::Row< size_t > &labels)
void Classify(const arma::Mat< T > &inputs, arma::Row< size_t > &labels) const
T ComputeAccuracy(const arma::Mat< T > &points, const arma::Row< size_t > &responses) const
T ComputeError(const arma::Mat< T > &points, const arma::Row< size_t > &responses) const