26 template<
class NET,
class OPT,
class MET,
class O>
28 const arma::Mat<O>& labels )
31 opt_ = std::make_unique<OPT>();
35 template<
class NET,
class OPT,
class MET,
class O>
37 const arma::Mat<O>& labels,
43 opt_ = std::make_unique<OPT>();
47 template<
class NET,
class OPT,
class MET,
class O>
48 template<
class... OptArgs>
50 const arma::Mat<O>& labels,
51 const NET network,
bool early,
const OptArgs&... args )
56 opt_ = std::make_unique<OPT>(args...);
60 template<
class NET,
class OPT,
class MET,
class O>
61 template<
class... OptArgs>
63 const arma::Row<size_t>& labels,
64 const NET network,
bool early,
const OptArgs&... args )
69 opt_ = std::make_unique<OPT>(args...);
73 template<
class NET,
class OPT,
class MET,
class O>
74 template<
class... OptArgs>
76 const arma::Row<size_t>& labels,
77 const NET network,
bool early,
78 const OptArgs&... args )
83 opt_ = std::make_unique<OPT>(args...);
87 template<
class NET,
class OPT,
class MET,
class O>
88 template<
class... OptArgs>
90 const arma::Mat<O>& labels,
91 const NET network,
bool early,
92 const OptArgs&... args )
97 opt_ = std::make_unique<OPT>(args...);
101 template<
class NET,
class OPT,
class MET,
class O>
103 const arma::Mat<O>& labels )
108 if (network_.Parameters().n_elem != 0)
111 network_.Train(inputs,labels,*opt_);
114 arma::Mat<O> inp,lab,val_inp,val_lab;
116 mlpack::data::Split(inputs,labels,
117 inp,val_inp,lab,val_lab,0.2);
119 auto func = [&](
const arma::Mat<O>& dummy)
121 if (MET::NeedsMinimization)
122 return MET::Evaluate(*
this,val_inp,val_lab);
124 return -MET::Evaluate(*
this,val_inp,val_lab);
127 ens::EarlyStopAtMinLossType<arma::Mat<O>> stop(func,5);
128 network_.Train(inp,lab,*opt_,stop);
133 template<
class NET,
class OPT,
class MET,
class O>
135 const arma::Row<size_t>& labels )
138 this -> ulab_ = arma::unique(labels);
140 auto convlabels = _OneHotEncode(labels,ulab_);
145 if (network_.Parameters().n_elem != 0)
149 network_.Train(inputs,convlabels,*opt_);
153 arma::Mat<O> inp,lab,val_inp,val_lab;
155 mlpack::data::Split(inputs,convlabels,
156 inp,val_inp,lab,val_lab,0.2);
158 auto val_lab_ = _OneHotDecode(val_lab,ulab_);
159 auto func = [&](
const arma::Mat<O>& dummy )
161 if (MET::NeedsMinimization)
162 return MET::Evaluate(*
this,val_inp,val_lab_);
164 return -MET::Evaluate(*
this,val_inp,val_lab_);
167 ens::EarlyStopAtMinLossType<arma::Mat<O>> stop(func,5);
168 network_.Train(inp,lab,*opt_,stop);
172 template<
class NET,
class OPT,
class MET,
class O>
174 arma::Mat<O>& preds )
176 network_.Predict(inputs,preds);
179 template<
class NET,
class OPT,
class MET,
class O>
181 arma::Row<size_t>& preds )
184 network_.Predict(inputs,temp);
185 preds = _OneHotDecode(temp,ulab_);
188 template<
class NET,
class OPT,
class MET,
class O>
190 const arma::Mat<O>& labels )
193 network_.Predict(inputs,preds);
194 return MET::Evaluate(preds, labels)/preds.n_elem;
197 template<
class NET,
class OPT,
class MET,
class O>
199 const arma::Row<size_t>& labels,
200 const arma::Row<size_t>& ulabels )
203 return arma::Mat<O>(ulabels.n_elem, labels.n_elem).
204 each_col( [&](arma::vec& col){col(ulabels[labels[i++]])=1.;} );
207 template<
class NET,
class OPT,
class MET,
class O>
208 arma::Row<size_t> ANN<NET,OPT,MET,O>::_OneHotDecode(
209 const arma::Mat<O>& labels,
210 const arma::Row<size_t>& ulabels )
212 return ulabels.cols(arma::index_max(labels,0));
215 template<
class NET,
class OPT,
class MET,
class O>
216 arma::Mat<O> ANN<NET,OPT,MET,O>::Parameters ( )
218 return network_.Parameters();
void Train(const arma::Mat< O > &inputs, const arma::Mat< O > &labels)
O ComputeError(const arma::Mat< O > &inputs, const arma::Mat< O > &labels)
void Classify(const arma::Mat< O > &inputs, arma::Row< size_t > &preds)
void Predict(const arma::Mat< O > &inputs, arma::Mat< O > &preds)