Learning Curve Plus Plus (LCPP)
multiclass.h
Go to the documentation of this file.
1 /**
2  * @file multiclass.h
3  * @author Ozgur Taylan Turan
4  *
5  * MultiClass Classifier for binary classifiers
6  *
7  * TODO:
8  * One vs One classifier
9  *
10  */
11 
12 #ifndef MULTICLASS_H
13 #define MULTICLASS_H
14 
15 namespace algo {
16 namespace classification {
17 //-----------------------------------------------------------------------------
18 // OnevAll : one-vs-rest classifier. This classifier trains a model for every
19 // label separately and uses the best prediction from those for prediction.
20 //-----------------------------------------------------------------------------
21 template<class MODEL, class T=DTYPE>
22 class OnevAll
23 {
24 public:
25  /**
26  * Non-working model
    * Default-constructed state: no models are held, the class counters are
    * zero and the one-class flag is false.
27  */
28  OnevAll ( ) = default;
29 
30  /**
    * @param num_class : number of classes of the problem
31  * @param args : parameters for the model
32  */
33  template<class... Args>
34  OnevAll ( const size_t& num_class, const Args&... args );
35 
36 
37  /**
38  * @param inputs : X
39  * @param labels : y
    * @param num_class : number of classes of the problem
40  * @param args : parameters for the model
41  */
42  template<class... Args>
43  OnevAll ( const arma::Mat<T>& inputs,
44  const arma::Row<size_t>& labels,
45  const size_t& num_class,
46  const Args&... args );
47 
48  /**
49  * @param inputs : X
50  * @param labels : y
51  * @param args : parameters for the model
52  */
53  template<class... Args>
54  void Train ( const arma::Mat<T>& inputs,
55  const arma::Row<size_t>& labels,
56  const Args&... args );
57 
58  /**
59  * @param inputs : X
60  * @param labels : y
61  * NOTE(review): Args is not used in this signature, so no model parameters
    * are passed through this overload; kept as declared because the
    * out-of-line definition has to match.
62  */
63  template<class... Args>
64  void Train ( const arma::Mat<T>& inputs,
65  const arma::Row<size_t>& labels );
66  /**
67  * @param inputs : X*
68  * @param preds : y* (output, written through the reference)
69  */
70  void Classify ( const arma::Mat<T>& inputs,
71  arma::Row<size_t>& preds ) const ;
72  /**
73  * @param inputs : X*
74  * @param preds : y* (output, written through the reference)
75  * @param probs : p* (output, written through the reference)
76  */
77  void Classify ( const arma::Mat<T>& inputs,
78  arma::Row<size_t>& preds,
79  arma::Mat<T>& probs ) const ;
80  /**
81  * @param inputs : X
82  * @param labels : y
    * @return error as a T; NOTE(review): the exact metric is defined outside
    * this header - confirm against the implementation.
83  */
84  T ComputeError ( const arma::Mat<T>& inputs,
85  const arma::Row<size_t>& labels );
86 
87  /**
88  * @param inputs : X
89  * @param labels : y
    * @return accuracy as a T; NOTE(review): the exact metric is defined
    * outside this header - confirm against the implementation.
90  */
91  T ComputeAccuracy ( const arma::Mat<T>& inputs,
92  const arma::Row<size_t>& labels );
93 
    // Underlying binary models; public and part of the serialized state.
94  std::vector<MODEL> models_;
95 
96  /**
97  * Serialize the model.
    * Passes every data member to the archive as a cereal name-value pair.
    * @param ar : cereal archive
98  */
99  template<typename Archive>
100  void serialize ( Archive& ar,
101  const unsigned int /* version */ )
102  {
103  ar ( cereal::make_nvp("nclass",nclass_),
104  cereal::make_nvp("unclass",unclass_),
105  cereal::make_nvp("unq",unq_),
106  cereal::make_nvp("models",models_),
107  cereal::make_nvp("oneclass",oneclass_));
108  }
109 
110 private:
     // The counters are zero-initialized so that a default-constructed object
     // never exposes indeterminate values (e.g. when it is serialized).
111  size_t nclass_ = 0; // number of classes of the problem
112  size_t unclass_ = 0; // unique number of classes observed
113  arma::Row<size_t> unq_; // unique classes observed
114  bool oneclass_ = false; // This is to trigger if you have only one class input
115 };
116 
117 } // namespace classification
118 } // namespace algo
119 
120 #include "multiclass_impl.h"
121 
122 #endif
T ComputeAccuracy(const arma::Mat< T > &inputs, const arma::Row< size_t > &labels)
void Classify(const arma::Mat< T > &inputs, arma::Row< size_t > &preds) const
T ComputeError(const arma::Mat< T > &inputs, const arma::Row< size_t > &labels)
void Train(const arma::Mat< T > &inputs, const arma::Row< size_t > &labels, const Args &... args)
void serialize(Archive &ar, const unsigned int)
Definition: multiclass.h:100