Learning Curve Plus Plus (LCPP)
nn_impl.h
Go to the documentation of this file.
1 /**
2  * @file nn_impl.h
3  * @author Ozgur Taylan Turan
4  *
5  * Simple neural network wrapper for using learning curve generation and
6  * hyper-parameter tuning.
7  *
8  */
9 #ifndef NN_IMPL_H
10 #define NN_IMPL_H
11 
12 namespace algo {
13 
14 //-----------------------------------------------------------------------------
15 // ANN
16 //-----------------------------------------------------------------------------
17 /* template<class NET,class OPT,class MET,class O> */
18 /* template<class... OptArgs> */
19 /* ANN<NET,OPT,MET,O>::ANN ( NET* network, bool early, const OptArgs&... args ) */
20 /* { */
21 /* network_ = network; */
22 /* early_ = early; */
23 /* opt_ = std::make_unique<OPT>(args...); */
24 /* } */
25 
26 template<class NET,class OPT,class MET,class O>
27 ANN<NET,OPT,MET,O>::ANN ( const arma::Mat<O>& inputs,
28  const arma::Mat<O>& labels )
29 {
30  early_ = false;
31  opt_ = std::make_unique<OPT>();
32  Train(inputs,labels);
33 }
34 ///////////////////////////////////////////////////////////////////////////////
35 template<class NET,class OPT,class MET,class O>
36 ANN<NET,OPT,MET,O>::ANN ( const arma::Mat<O>& inputs,
37  const arma::Mat<O>& labels,
38  const NET network )
39 
40 {
41  network_ = network;
42  early_ = false;
43  opt_ = std::make_unique<OPT>();
44  Train(inputs,labels);
45 }
46 ///////////////////////////////////////////////////////////////////////////////
47 template<class NET,class OPT,class MET,class O>
48 template<class... OptArgs>
49 ANN<NET,OPT,MET,O>::ANN ( const arma::Mat<O>& inputs,
50  const arma::Mat<O>& labels,
51  const NET network, bool early, const OptArgs&... args )
52 
53 {
54  network_ = network;
55  early_ = early;
56  opt_ = std::make_unique<OPT>(args...);
57  Train(inputs,labels);
58 }
59 ///////////////////////////////////////////////////////////////////////////////
60 template<class NET,class OPT,class MET,class O>
61 template<class... OptArgs>
62 ANN<NET,OPT,MET,O>::ANN ( const arma::Mat<O>& inputs,
63  const arma::Row<size_t>& labels,
64  const NET network, bool early, const OptArgs&... args )
65 
66 {
67  network_ = network;
68  early_ = early;
69  opt_ = std::make_unique<OPT>(args...);
70  Train(inputs,labels);
71 }
72 ///////////////////////////////////////////////////////////////////////////////
73 template<class NET,class OPT,class MET,class O>
74 template<class... OptArgs>
75 void ANN<NET,OPT,MET,O>::Train ( const arma::Mat<O>& inputs,
76  const arma::Row<size_t>& labels,
77  const NET network, bool early,
78  const OptArgs&... args )
79 
80 {
81  network_ = network;
82  early_ = early;
83  opt_ = std::make_unique<OPT>(args...);
84  Train(inputs,labels);
85 }
86 ///////////////////////////////////////////////////////////////////////////////
87 template<class NET,class OPT,class MET,class O>
88 template<class... OptArgs>
89 void ANN<NET,OPT,MET,O>::Train ( const arma::Mat<O>& inputs,
90  const arma::Mat<O>& labels,
91  const NET network, bool early,
92  const OptArgs&... args )
93 
94 {
95  network_ = network;
96  early_ = early;
97  opt_ = std::make_unique<OPT>(args...);
98  Train(inputs,labels);
99 }
100 ///////////////////////////////////////////////////////////////////////////////
101 template<class NET,class OPT,class MET,class O>
102 void ANN<NET,OPT,MET,O>::Train( const arma::Mat<O>& inputs,
103  const arma::Mat<O>& labels )
104 {
105  // Safety Net for learning curve generation from scratch,
106  // but you might want to start from a trained model. So future modification
107  // might be needed...
108  if (network_.Parameters().n_elem != 0)
109  network_.Reset();
110  if (!early_)
111  network_.Train(inputs,labels,*opt_);
112  else
113  {
114  arma::Mat<O> inp,lab,val_inp,val_lab;
115 
116  mlpack::data::Split(inputs,labels,
117  inp,val_inp,lab,val_lab,0.2);
118 
119  auto func = [&](const arma::Mat<O>& dummy)
120  {
121  if (MET::NeedsMinimization)
122  return MET::Evaluate(*this,val_inp,val_lab);
123  else
124  return -MET::Evaluate(*this,val_inp,val_lab);
125  };
126 
127  ens::EarlyStopAtMinLossType<arma::Mat<O>> stop(func,5);
128  network_.Train(inp,lab,*opt_,stop);
129 
130  }
131 }
132 ///////////////////////////////////////////////////////////////////////////////
133 template<class NET,class OPT,class MET,class O>
134 void ANN<NET,OPT,MET,O>::Train( const arma::Mat<O>& inputs,
135  const arma::Row<size_t>& labels )
136 {
137  // Get the unique labels
138  this -> ulab_ = arma::unique(labels);
139  // OneHotEncode the labels for the classifier network
140  auto convlabels = _OneHotEncode(labels,ulab_);
141 
142  // Safety Net for learning curve generation from scratch,
143  // but you might want to start from a trained model. So future modification
144  // might be needed...
145  if (network_.Parameters().n_elem != 0)
146  network_.Reset();
147  if (!early_)
148  {
149  network_.Train(inputs,convlabels,*opt_);
150  }
151  else
152  {
153  arma::Mat<O> inp,lab,val_inp,val_lab;
154 
155  mlpack::data::Split(inputs,convlabels,
156  inp,val_inp,lab,val_lab,0.2);
157 
158  auto val_lab_ = _OneHotDecode(val_lab,ulab_);
159  auto func = [&]( const arma::Mat<O>& dummy )
160  {
161  if (MET::NeedsMinimization)
162  return MET::Evaluate(*this,val_inp,val_lab_);
163  else
164  return -MET::Evaluate(*this,val_inp,val_lab_);
165  };
166 
167  ens::EarlyStopAtMinLossType<arma::Mat<O>> stop(func,5);
168  network_.Train(inp,lab,*opt_,stop);
169  }
170 }
171 ///////////////////////////////////////////////////////////////////////////////
172 template<class NET,class OPT,class MET,class O>
173 void ANN<NET,OPT,MET,O>::Predict( const arma::Mat<O>& inputs,
174  arma::Mat<O>& preds )
175 {
176  network_.Predict(inputs,preds);
177 }
178 ///////////////////////////////////////////////////////////////////////////////
179 template<class NET,class OPT,class MET,class O>
180 void ANN<NET,OPT,MET,O>::Classify( const arma::Mat<O>& inputs,
181  arma::Row<size_t>& preds )
182 {
183  arma::Mat<O> temp;
184  network_.Predict(inputs,temp);
185  preds = _OneHotDecode(temp,ulab_);
186 }
187 ///////////////////////////////////////////////////////////////////////////////
188 template<class NET,class OPT,class MET,class O>
189 O ANN<NET,OPT,MET,O>::ComputeError( const arma::Mat<O>& inputs,
190  const arma::Mat<O>& labels )
191 {
192  arma::Mat<O> preds;
193  network_.Predict(inputs,preds);
194  return MET::Evaluate(preds, labels)/preds.n_elem;
195 }
196 ///////////////////////////////////////////////////////////////////////////////
197 template<class NET,class OPT,class MET,class O>
199  const arma::Row<size_t>& labels,
200  const arma::Row<size_t>& ulabels )
201 {
202  size_t i=0;
203  return arma::Mat<O>(ulabels.n_elem, labels.n_elem).
204  each_col( [&](arma::vec& col){col(ulabels[labels[i++]])=1.;} );
205 }
206 ///////////////////////////////////////////////////////////////////////////////
207 template<class NET,class OPT,class MET,class O>
208 arma::Row<size_t> ANN<NET,OPT,MET,O>::_OneHotDecode(
209  const arma::Mat<O>& labels,
210  const arma::Row<size_t>& ulabels )
211 {
212  return ulabels.cols(arma::index_max(labels,0));
213 }
214 ///////////////////////////////////////////////////////////////////////////////
215 template<class NET,class OPT,class MET,class O>
216 arma::Mat<O> ANN<NET,OPT,MET,O>::Parameters ( )
217 {
218  return network_.Parameters();
219 }
220 
221 } // namespace algo
222 
223 #endif
224 
Definition: nn.h:23
ANN()
Definition: nn.h:28
void Train(const arma::Mat< O > &inputs, const arma::Mat< O > &labels)
Definition: nn_impl.h:102
O ComputeError(const arma::Mat< O > &inputs, const arma::Mat< O > &labels)
Definition: nn_impl.h:189
void Classify(const arma::Mat< O > &inputs, arma::Row< size_t > &preds)
Definition: nn_impl.h:180
void Predict(const arma::Mat< O > &inputs, arma::Mat< O > &preds)
Definition: nn_impl.h:173