|
| | ANN () |
| |
| | ANN (const arma::Mat< O > &inputs, const arma::Mat< O > &labels) |
| |
| | ANN (const arma::Mat< O > &inputs, const arma::Mat< O > &labels, const NET network) |
| |
|
template<class... OptArgs> |
| | ANN (const arma::Mat< O > &inputs, const arma::Mat< O > &labels, const NET network, bool early=false, const OptArgs &... args) |
| |
|
template<class... OptArgs> |
| | ANN (const arma::Mat< O > &inputs, const arma::Row< size_t > &labels, const NET network, bool early=false, const OptArgs &... args) |
| |
| void | Train (const arma::Mat< O > &inputs, const arma::Mat< O > &labels) |
| |
|
template<class... OptArgs> |
| void | Train (const arma::Mat< O > &inputs, const arma::Mat< O > &labels, const NET network, bool early=false, const OptArgs &... args) |
| |
|
template<class... OptArgs> |
| void | Train (const arma::Mat< O > &inputs, const arma::Row< size_t > &labels, const NET network, bool early=false, const OptArgs &... args) |
| |
| void | Train (const arma::Mat< O > &inputs, const arma::Row< size_t > &labels) |
| |
| void | Predict (const arma::Mat< O > &inputs, arma::Mat< O > &preds) |
| |
| void | Classify (const arma::Mat< O > &inputs, arma::Row< size_t > &preds) |
| |
| O | ComputeError (const arma::Mat< O > &inputs, const arma::Mat< O > &labels) |
| |
| arma::Mat< O > | Parameters () |
| |
|
template<class Archive > |
| void | serialize (Archive &ar) |
| |
template<class NET, class OPT = ens::StandardSGD, class MET = mlpack::MeanSquaredError, class O = DTYPE>
class algo::ANN< NET, OPT, MET, O >
Definition at line 22 of file nn.h.
◆ ANN() [1/3]
template<class NET , class OPT = ens::StandardSGD, class MET = mlpack::MeanSquaredError, class O = DTYPE>
Default constructor. Produces a non-working model (no data, labels, or network are supplied).
Definition at line 28 of file nn.h.
◆ ANN() [2/3]
template<class NET , class OPT , class MET , class O >
| algo::ANN< NET, OPT, MET, O >::ANN |
( |
const arma::Mat< O > & |
inputs, |
|
|
const arma::Mat< O > & |
labels |
|
) |
| |
- Parameters
-
| inputs | : input data, X |
| labels | : labels of the input, y |
Definition at line 27 of file nn_impl.h.
31 opt_ = std::make_unique<OPT>();
void Train(const arma::Mat< O > &inputs, const arma::Mat< O > &labels)
◆ ANN() [3/3]
template<class NET , class OPT , class MET , class O >
| algo::ANN< NET, OPT, MET, O >::ANN |
( |
const arma::Mat< O > & |
inputs, |
|
|
const arma::Mat< O > & |
labels, |
|
|
const NET |
network |
|
) |
| |
- Parameters
-
| inputs | : input data, X |
| labels | : labels of the input, y |
| network | : network object |
Definition at line 36 of file nn_impl.h.
43 opt_ = std::make_unique<OPT>();
◆ Classify()
template<class NET , class OPT , class MET , class O >
| void algo::ANN< NET, OPT, MET, O >::Classify |
( |
const arma::Mat< O > & |
inputs, |
|
|
arma::Row< size_t > & |
preds |
|
) |
| |
- Parameters
-
| inputs | : input data, X |
| preds | : predicted class labels for the input (decoded from the network output), \hat{y} |
Definition at line 180 of file nn_impl.h.
184 network_.Predict(inputs,temp);
185 preds = _OneHotDecode(temp,ulab_);
◆ ComputeError()
template<class NET , class OPT , class MET , class O >
| O algo::ANN< NET, OPT, MET, O >::ComputeError |
( |
const arma::Mat< O > & |
inputs, |
|
|
const arma::Mat< O > & |
labels |
|
) |
| |
- Parameters
-
| inputs | : input data, X |
| labels | : labels of the input, y |
Definition at line 189 of file nn_impl.h.
193 network_.Predict(inputs,preds);
194 return MET::Evaluate(preds, labels)/preds.n_elem;
◆ Predict()
template<class NET , class OPT , class MET , class O >
| void algo::ANN< NET, OPT, MET, O >::Predict |
( |
const arma::Mat< O > & |
inputs, |
|
|
arma::Mat< O > & |
preds |
|
) |
| |
- Parameters
-
| inputs | : input data, X |
| preds | : raw network predictions for the input, \hat{y} |
Definition at line 173 of file nn_impl.h.
176 network_.Predict(inputs,preds);
◆ Train()
template<class NET , class OPT , class MET , class O >
| void algo::ANN< NET, OPT, MET, O >::Train |
( |
const arma::Mat< O > & |
inputs, |
|
|
const arma::Mat< O > & |
labels |
|
) |
| |
- Parameters
-
| inputs | : input data, X |
| labels | : labels of the input, y |
Definition at line 102 of file nn_impl.h.
108 if (network_.Parameters().n_elem != 0)
111 network_.Train(inputs,labels,*opt_);
114 arma::Mat<O> inp,lab,val_inp,val_lab;
116 mlpack::data::Split(inputs,labels,
117 inp,val_inp,lab,val_lab,0.2);
119 auto func = [&](
const arma::Mat<O>& dummy)
121 if (MET::NeedsMinimization)
122 return MET::Evaluate(*
this,val_inp,val_lab);
124 return -MET::Evaluate(*
this,val_inp,val_lab);
127 ens::EarlyStopAtMinLossType<arma::Mat<O>> stop(func,5);
128 network_.Train(inp,lab,*opt_,stop);
The documentation for this class was generated from the following files: