Learning Curve Plus Plus (LCPP)
nn.h
Go to the documentation of this file.
1 /**
2  * @file nn.h
3  * @author Ozgur Taylan Turan
4  *
5  * Simple neural network wrapper for use in learning-curve generation and
6  * hyper-parameter tuning.
7  *
8  * TODO:
9  * Pre-trained model loading.
10  *
11  *
12  */
13 #ifndef NN_H
14 #define NN_H
15 
16 namespace algo {
17 
18 template<class NET,
19  class OPT=ens::StandardSGD,
20  class MET=mlpack::MeanSquaredError,
21  class O=DTYPE>
22 class ANN
23 {
24 public:
25  /**
26  * Non-working model
27  */
28  ANN ( ) { };
29 
30  /**
31  * @param network : pointer to the network object
32  * @param args : optimizer arguments in the given order
33  */
34  /* template<class... OptArgs> */
35  /* ANN ( NET* network, bool early = false, const OptArgs&... args ); */
36 
37  /**
38  * @param inputs : input data, X
39  * @param labels : labels of the input, y
40  */
41  ANN ( const arma::Mat<O>& inputs,
42  const arma::Mat<O>& labels ) ;
43 
44  /**
45  * @param inputs : input data, X
46  * @param labels : labels of the input, y
47  * @param network : network object
48  */
49 
50  ANN ( const arma::Mat<O>& inputs,
51  const arma::Mat<O>& labels,
52  const NET network ) ;
53 
54  template<class... OptArgs>
55  ANN ( const arma::Mat<O>& inputs,
56  const arma::Mat<O>& labels,
57  const NET network, bool early = false, const OptArgs&... args ) ;
58 
59  template<class... OptArgs>
60  ANN ( const arma::Mat<O>& inputs,
61  const arma::Row<size_t>& labels,
62  const NET network, bool early = false, const OptArgs&... args ) ;
63  /**
64  * @param inputs : input data, X
65  * @param labels : labels of the input, y
66  */
67  void Train( const arma::Mat<O>& inputs, const arma::Mat<O>& labels );
68 
69  template<class... OptArgs>
70  void Train( const arma::Mat<O>& inputs, const arma::Mat<O>& labels,
71  const NET network, bool early = false, const OptArgs&... args );
72 
73  template<class... OptArgs>
74  void Train( const arma::Mat<O>& inputs, const arma::Row<size_t>& labels,
75  const NET network, bool early = false, const OptArgs&... args );
76 
77  void Train( const arma::Mat<O>& inputs, const arma::Row<size_t>& labels );
78 
79  /**
80  * @param inputs : input data, X
81  * @param preds : prediction of labels of the input, \hat{y}
82  */
83  void Predict( const arma::Mat<O>& inputs, arma::Mat<O>& preds );
84 
85  /**
86  * @param inputs : input data, X
87  * @param preds : prediction of labels of the input, \hat{y}
88  */
89  void Classify( const arma::Mat<O>& inputs, arma::Row<size_t>& preds );
90 
91 
92  /**
93  * @param inputs : input data, X
94  * @param labels : labels of the input, y
95  */
96  O ComputeError( const arma::Mat<O>& inputs, const arma::Mat<O>& labels );
97 
98  arma::Mat<O> Parameters( );
99 
100  template<class Archive>
101  void serialize(Archive& ar)
102  {
103  ar( CEREAL_NVP(network_),
104  CEREAL_NVP(early_),
105  CEREAL_NVP(ulab_) );
106  }
107 
108 
109 private:
110  /**
111  * @param labels : labels to be encoded
112  * @param map : expected labels
113  */
114  arma::Mat<O> _OneHotEncode ( const arma::Row<size_t>& labels,
115  const arma::Row<size_t>& ulabels );
116 
117  /**
118  * @param labels : labels to be decoded
119  * @param map : expected labels
120  */
121  arma::Row<size_t> _OneHotDecode ( const arma::Mat<O>& labels,
122  const arma::Row<size_t>& ulabels);
123 
124  // A unique pointer for the optimizer. This is needed because of design
125  // choices in ensmallen optimizers based on SGD.
126  std::unique_ptr<OPT> opt_; // optimizer pointer
127 
128  // A predefined network pointer. We have to define the network outside with
129  // this construction. This gives more freedom on how you want to create your
130  // network. Moreover, it gives you the flexibility to
131  /* NET* network_; // network pointer */
132  NET network_; // network pointer
133 
134  bool early_; // Early stopping flag
135 
136  // ONLY VALID FOR CLASSIFICATION PROBLEMS
137  arma::Row<size_t> ulab_; // Just a container for the potential unique labels
138 
139 };
140 
141 } // namespace algo
142 
143 #include "nn_impl.h"
144 
145 #endif
Definition: nn.h:23
ANN()
Definition: nn.h:28
void Train(const arma::Mat< O > &inputs, const arma::Mat< O > &labels)
Definition: nn_impl.h:102
O ComputeError(const arma::Mat< O > &inputs, const arma::Mat< O > &labels)
Definition: nn_impl.h:189
void Classify(const arma::Mat< O > &inputs, arma::Row< size_t > &preds)
Definition: nn_impl.h:180
void Predict(const arma::Mat< O > &inputs, arma::Mat< O > &preds)
Definition: nn_impl.h:173