9 #ifndef MULTICLASS_IMPL_H
10 #define MULTICLASS_IMPL_H
13 namespace classification {
// Constructor that only allocates the ensemble: creates `num_class`
// sub-models, each constructed as MODEL(args...). No training happens here.
// NOTE(review): this listing is missing lines (signature and braces are not
// visible); comments describe only the statements shown.
17 template<
class MODEL,
class T>
18 template<
class... Args>
// One sub-model slot per requested class.
22 models_.resize(num_class);
23 for(
size_t i=0;i<num_class;i++)
24 models_[i] = MODEL(args...);
// Constructor taking training data. It does the following:
// - records the declared class count in nclass_;
// - creates one sub-model per distinct label found in `labels`;
// - trains the ensemble via Train(inputs, labels, args...).
// The `inputs` parameter is declared on a line not visible in this listing.
27 template<
class MODEL,
class T>
28 template<
class... Args>
30 const arma::Row<size_t>& labels,
31 const size_t& num_class,
32 const Args&... args ) : nclass_(num_class)
// Distinct label values present in the training data. Sub-model i is
// trained against label unq_(i).
34 unq_ = arma::unique(labels).eval();
// Number of distinct labels actually observed; may be smaller than nclass_.
38 unclass_ = unq_.n_elem;
39 models_.resize(unclass_);
// NOTE(review): every sub-model is built from args here, but the
// three-argument Train constructs fresh MODEL objects from the data. These
// constructions look redundant; confirm against the full Train body.
41 for(
size_t i=0;i<unclass_;i++)
42 models_[i] = MODEL(args...);
44 Train(inputs, labels, args...);
// Re-trains the existing sub-models one-vs-rest on (inputs, labels) by
// calling models_[i].Train. No new MODEL objects are constructed here.
// NOTE(review): declared with a parameter pack `Args`, but the visible
// parameter list ends at `labels`, so Args appears unused. Confirm.
48 template<
class MODEL,
class T>
49 template<
class... Args>
51 const arma::Row<size_t>& labels )
// Recompute the distinct labels from the new training data.
53 unq_ = arma::unique(labels).eval();
57 unclass_ = unq_.n_elem;
// NOTE(review): models_ is not resized in the visible lines. If `labels`
// has more distinct values than models_.size(), models_[i] would index
// out of range. Confirm whether a hidden line resizes it.
60 for(
size_t i=0;i<unclass_;i++)
// Binary target: 1 where the label equals unq_(i), 0 elsewhere.
62 auto binlabels = arma::conv_to<arma::Row<size_t>>::from(labels==unq_(i));
63 models_[i].Train(inputs, binlabels);
// Trains the one-vs-rest ensemble. Each sub-model is constructed directly
// from (inputs, binlabels, args...).
// NOTE(review): unq_ and unclass_ are read here but are not visibly
// recomputed from `labels`. Presumably hidden lines do that; otherwise they
// are stale from a previous call. Confirm.
70 template<
class MODEL,
class T>
71 template<
class... Args>
73 const arma::Row<size_t>& labels,
// NOTE(review): the resize uses unq_.n_elem while the loop bound is
// unclass_. These agree only if unclass_ == unq_.n_elem.
78 models_.resize(unq_.n_elem);
79 for(
size_t i=0;i<unclass_;i++)
// Binary target: 1 where the label equals unq_(i), 0 elsewhere.
81 auto binlabels = arma::conv_to<arma::Row<size_t>>::from(labels==unq_(i));
// NOTE(review): the store of `model` into models_[i] is not visible.
// Presumably it happens on a following line.
82 MODEL model(inputs,binlabels,args...);
// Label-only prediction. Forwards to the overload that also fills per-class
// probabilities, then discards them. The local `probs` declaration is not
// visible in this listing.
88 template<
class MODEL,
class T>
90 arma::Row<size_t>& preds )
const
93 Classify(inputs, preds, probs);
// Predicts a label and a per-class score column for each input column.
// Row unq_(i) of `probs` comes from sub-model i. `preds` holds, per column,
// the row index with the largest absolute score.
96 template<
class MODEL,
class T>
98 arma::Row<size_t>& preds,
99 arma::Mat<T>& probs )
const
// NOTE(review): resize() keeps existing elements and zero-fills only new
// ones. If the caller passes a reused, non-empty `probs`, rows for labels
// absent from unq_ keep stale values. probs.zeros(nclass_, inputs.n_cols)
// would avoid that.
101 probs.resize(nclass_,inputs.n_cols);
102 preds.resize(inputs.n_cols);
105 for(
size_t i=0;i<unclass_;i++)
107 arma::Row<size_t> temp;
// `tprobs` is declared on a line not visible here. Row 1 is presumably the
// binary model's score for the positive class (label == unq_(i)). Confirm
// against MODEL::Classify.
109 models_[i].Classify(inputs,temp,tprobs);
110 probs.row(unq_(i)) = tprobs.row(1);
// Per-column argmax of |probs|.
// NOTE(review): this runs before the tie-break below. Columns whose scores
// are all zero therefore get the first row index rather than the randomly
// chosen class, unless preds is recomputed on lines not visible here.
112 preds = arma::conv_to<arma::Row<size_t>>::from(arma::index_max(arma::abs(probs),0));
// Columns where every sub-model scored exactly 0.
116 arma::uvec cantdecide = arma::find(arma::sum(probs,0) == 0);
// For those columns, pick a random class row and set it to 1. The index
// is column-major (row + col * n_rows), and n_rows == nclass_.
117 arma::uvec randomsel = arma::randi<arma::uvec>(cantdecide.n_elem,
118 arma::distr_param(0,nclass_-1));
119 probs.elem(randomsel+ cantdecide*nclass_).ones();
// Normalise each column so its scores sum to 1.
120 probs = probs.each_row() / arma::sum(probs,0);
// Classification error: the fraction of samples whose predicted label
// differs from `labels`.
127 template<
class MODEL,
class T>
129 const arma::Row<size_t>& labels )
131 arma::Row<size_t> predictions;
132 Classify(inputs, predictions);
// The subtraction is done in unsigned size_t, so a mismatch may wrap
// around. Any mismatch is still nonzero, so the count below is correct.
// arma::accu(predictions != labels) would state the intent more directly.
133 arma::Row<size_t> temp = predictions - labels;
// NOTE(review): empty inputs give 0/0 here (NaN if T is floating-point).
134 return (arma::accu(temp != 0))/T(predictions.n_elem);
// Classification accuracy, computed as 1 - ComputeError(inputs, labels).
137 template<
class MODEL,
class T>
139 const arma::Row<size_t>& labels )
// `1.` is a double literal, so the subtraction is done in double before
// conversion to the return type (presumably T; the return type is not
// visible here).
141 return 1.-ComputeError(inputs,labels);
// Member signatures listed without bodies. These lines have no trailing
// semicolons, so they read as a documentation summary rather than valid
// declarations. NOTE(review): confirm they are not meant to be compiled.
T ComputeAccuracy(const arma::Mat< T > &inputs, const arma::Row< size_t > &labels)
void Classify(const arma::Mat< T > &inputs, arma::Row< size_t > &preds) const
T ComputeError(const arma::Mat< T > &inputs, const arma::Row< size_t > &labels)
void Train(const arma::Mat< T > &inputs, const arma::Row< size_t > &labels, const Args &... args)