9 #ifndef NONPARAMCLASS_IMPL_H
10 #define NONPARAMCLASS_IMPL_H
13 namespace classification {
20 const arma::Row<size_t>& labels,
21 const size_t& num_class ): k_(1), nclass_(num_class)
23 Train(inputs, labels);
28 const arma::Row<size_t>& labels,
29 const size_t& num_class,
30 const size_t& k ) : k_(k), nclass_(num_class)
32 Train(inputs, labels);
37 const arma::Row<size_t>& labels )
40 nuclass_ = arma::unique(labels).eval().n_elem;
42 unique_ = arma::unique(labels);
44 unique_ = arma::regspace<arma::Row<size_t>>(0,nclass_-1);
46 size_ = inputs.n_cols;
49 if ( k_ > inputs.n_cols )
55 const arma::Row<size_t>& labels,
56 const size_t num_class )
58 this ->nuclass_ = num_class;
59 this -> Train(inputs,labels);
64 arma::Row<size_t>& labels )
const
67 Classify(inputs, labels, temp);
72 arma::Row<size_t>& labels,
73 arma::Mat<T>& probs )
const
75 const size_t N = inputs.n_cols;
76 probs.resize(nclass_,N);
80 labels.fill(unique_(0));
81 probs.row(unique_(0)).fill(1);
87 mlpack::KNN knn(inputs_);
89 arma::Mat<size_t> neig;
90 arma::Mat<size_t> select;
93 knn.Search(inputs, k_, neig, dist);
94 arma::Col<size_t> unq;
98 for (
size_t j=0; j<N; j++ )
101 select = labels_(arma::conv_to<arma::uvec>::from(neig.col(j)));
103 auto count = arma::hist(select,unique_);
105 arma::conv_to<arma::Col<T>>::from(count);
106 probs.col(j) = ps / arma::accu(ps);
108 labels(j) = unique_(count.index_max());
115 const arma::Row<size_t>& responses )
const
117 arma::Row<size_t> predictions;
118 Classify(points,predictions);
119 arma::Row<size_t> temp = predictions - responses;
120 double total = responses.n_cols;
122 return (arma::accu(temp != 0))/total;
127 const arma::Row<size_t>& responses )
const
129 return (1. - ComputeError(points, responses))*100;
void Train(const arma::Mat< T > &inputs, const arma::Row< size_t > &labels)
void Classify(const arma::Mat< T > &inputs, arma::Row< size_t > &labels) const
T ComputeAccuracy(const arma::Mat< T > &points, const arma::Row< size_t > &responses) const
T ComputeError(const arma::Mat< T > &points, const arma::Row< size_t > &responses) const