Learning Curve Plus Plus (LCPP)
multiclass_impl.h
Go to the documentation of this file.
1 /**
2  * @file multiclass_impl.h
3  * @author Ozgur Taylan Turan
4  *
5  * MultiClass Classifier for binary classifiers
6  *
7  */
8 
9 #ifndef MULTICLASS_IMPL_H
10 #define MULTICLASS_IMPL_H
11 
12 namespace algo {
13 namespace classification {
14 //-----------------------------------------------------------------------------
15 // OnevAll
16 //-----------------------------------------------------------------------------
17 template<class MODEL, class T>
18 template<class... Args>
19 OnevAll<MODEL,T>::OnevAll( const size_t& num_class, const Args&... args ) :
20  nclass_(num_class)
21 {
22  models_.resize(num_class);
23  for(size_t i=0;i<num_class;i++)
24  models_[i] = MODEL(args...);
25 }
26 ///////////////////////////////////////////////////////////////////////////////
27 template<class MODEL, class T>
28 template<class... Args>
29 OnevAll<MODEL,T>::OnevAll( const arma::Mat<T>& inputs,
30  const arma::Row<size_t>& labels,
31  const size_t& num_class,
32  const Args&... args ) : nclass_(num_class)
33 {
34  unq_ = arma::unique(labels).eval();
35  if (unq_.n_elem == 1)
36  oneclass_ = true;
37 
38  unclass_ = unq_.n_elem;
39  models_.resize(unclass_);
40 
41  for(size_t i=0;i<unclass_;i++)
42  models_[i] = MODEL(args...);
43 
44  Train(inputs, labels, args...);
45 
46 }
47 ///////////////////////////////////////////////////////////////////////////////
48 template<class MODEL, class T>
49 template<class... Args>
50 void OnevAll<MODEL,T>::Train ( const arma::Mat<T>& inputs,
51  const arma::Row<size_t>& labels )
52 {
53  unq_ = arma::unique(labels).eval();
54  if (unq_.n_elem == 1)
55  oneclass_ = true;
56 
57  unclass_ = unq_.n_elem;
58  if (!oneclass_)
59  {
60  for(size_t i=0;i<unclass_;i++)
61  {
62  auto binlabels = arma::conv_to<arma::Row<size_t>>::from(labels==unq_(i));
63  models_[i].Train(inputs, binlabels);
64  }
65  }
66  else
67  return;
68 }
69 ///////////////////////////////////////////////////////////////////////////////
70 template<class MODEL, class T>
71 template<class... Args>
72 void OnevAll<MODEL,T>::Train ( const arma::Mat<T>& inputs,
73  const arma::Row<size_t>& labels,
74  const Args&... args )
75 {
76  if (!oneclass_)
77  {
78  models_.resize(unq_.n_elem);
79  for(size_t i=0;i<unclass_;i++)
80  {
81  auto binlabels = arma::conv_to<arma::Row<size_t>>::from(labels==unq_(i));
82  MODEL model(inputs,binlabels,args...);
83  models_[i] = model;
84  }
85  }
86 }
87 ///////////////////////////////////////////////////////////////////////////////
88 template<class MODEL, class T>
89 void OnevAll<MODEL,T>::Classify( const arma::Mat<T>& inputs,
90  arma::Row<size_t>& preds ) const
91 {
92  arma::Mat<T> probs;
93  Classify(inputs, preds, probs);
94 }
95 ///////////////////////////////////////////////////////////////////////////////
96 template<class MODEL, class T>
97 void OnevAll<MODEL,T>::Classify( const arma::Mat<T>& inputs,
98  arma::Row<size_t>& preds,
99  arma::Mat<T>& probs ) const
100 {
101  probs.resize(nclass_,inputs.n_cols);
102  preds.resize(inputs.n_cols);
103  if (!oneclass_)
104  {
105  for(size_t i=0;i<unclass_;i++)
106  {
107  arma::Row<size_t> temp;
108  arma::Mat<T> tprobs;
109  models_[i].Classify(inputs,temp,tprobs);
110  probs.row(unq_(i)) = tprobs.row(1);
111  }
112  preds = arma::conv_to<arma::Row<size_t>>::from(arma::index_max(arma::abs(probs),0));
113 
114  // We cannot decide and randomly assign a class
115 
116  arma::uvec cantdecide = arma::find(arma::sum(probs,0) == 0);
117  arma::uvec randomsel = arma::randi<arma::uvec>(cantdecide.n_elem,
118  arma::distr_param(0,nclass_-1));
119  probs.elem(randomsel+ cantdecide*nclass_).ones();
120  probs = probs.each_row() / arma::sum(probs,0);
121 
122  }
123  else
124  preds.fill(unq_[0]);
125 }
126 ///////////////////////////////////////////////////////////////////////////////
127 template<class MODEL, class T>
128 T OnevAll<MODEL,T>::ComputeError ( const arma::Mat<T>& inputs,
129  const arma::Row<size_t>& labels )
130 {
131  arma::Row<size_t> predictions;
132  Classify(inputs, predictions);
133  arma::Row<size_t> temp = predictions - labels;
134  return (arma::accu(temp != 0))/T(predictions.n_elem);
135 }
136 ///////////////////////////////////////////////////////////////////////////////
137 template<class MODEL, class T>
138 T OnevAll<MODEL,T>::ComputeAccuracy ( const arma::Mat<T>& inputs,
139  const arma::Row<size_t>& labels )
140 {
141  return 1.-ComputeError(inputs,labels);
142 }
143 ///////////////////////////////////////////////////////////////////////////////
144 
145 } // namespace classification
146 } // namespace algo
147 
148 #endif
T ComputeAccuracy(const arma::Mat< T > &inputs, const arma::Row< size_t > &labels)
void Classify(const arma::Mat< T > &inputs, arma::Row< size_t > &preds) const
T ComputeError(const arma::Mat< T > &inputs, const arma::Row< size_t > &labels)
void Train(const arma::Mat< T > &inputs, const arma::Row< size_t > &labels, const Args &... args)