Learning Curve Plus Plus (LCPP)
algo::ANN< NET, OPT, MET, O > Class Template Reference

Public Member Functions

 ANN ()
 
 ANN (const arma::Mat< O > &inputs, const arma::Mat< O > &labels)
 
 ANN (const arma::Mat< O > &inputs, const arma::Mat< O > &labels, const NET network)
 
template<class... OptArgs>
 ANN (const arma::Mat< O > &inputs, const arma::Mat< O > &labels, const NET network, bool early=false, const OptArgs &... args)
 
template<class... OptArgs>
 ANN (const arma::Mat< O > &inputs, const arma::Row< size_t > &labels, const NET network, bool early=false, const OptArgs &... args)
 
void Train (const arma::Mat< O > &inputs, const arma::Mat< O > &labels)
 
template<class... OptArgs>
void Train (const arma::Mat< O > &inputs, const arma::Mat< O > &labels, const NET network, bool early=false, const OptArgs &... args)
 
template<class... OptArgs>
void Train (const arma::Mat< O > &inputs, const arma::Row< size_t > &labels, const NET network, bool early=false, const OptArgs &... args)
 
void Train (const arma::Mat< O > &inputs, const arma::Row< size_t > &labels)
 
void Predict (const arma::Mat< O > &inputs, arma::Mat< O > &preds)
 
void Classify (const arma::Mat< O > &inputs, arma::Row< size_t > &preds)
 
O ComputeError (const arma::Mat< O > &inputs, const arma::Mat< O > &labels)
 
arma::Mat< O > Parameters ()
 
template<class Archive >
void serialize (Archive &ar)
 

Detailed Description

template<class NET, class OPT = ens::StandardSGD, class MET = mlpack::MeanSquaredError, class O = DTYPE>
class algo::ANN< NET, OPT, MET, O >

Definition at line 22 of file nn.h.

Constructor & Destructor Documentation

◆ ANN() [1/3]

template<class NET , class OPT = ens::StandardSGD, class MET = mlpack::MeanSquaredError, class O = DTYPE>
algo::ANN< NET, OPT, MET, O >::ANN ( )
inline

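Default constructor. Its body is empty and it performs no training, so the resulting model is non-working until it is trained.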

Definition at line 28 of file nn.h.

28 { };

◆ ANN() [2/3]

template<class NET , class OPT , class MET , class O >
algo::ANN< NET, OPT, MET, O >::ANN ( const arma::Mat< O > &  inputs,
const arma::Mat< O > &  labels 
)
Parameters
inputs: input data, X
labels: labels of the input, y

Definition at line 27 of file nn_impl.h.

29 {
30  early_ = false;
31  opt_ = std::make_unique<OPT>();
32  Train(inputs,labels);
33 }
void Train(const arma::Mat< O > &inputs, const arma::Mat< O > &labels)
Definition: nn_impl.h:102

◆ ANN() [3/3]

template<class NET , class OPT , class MET , class O >
algo::ANN< NET, OPT, MET, O >::ANN ( const arma::Mat< O > &  inputs,
const arma::Mat< O > &  labels,
const NET  network 
)
Parameters
inputs: input data, X
labels: labels of the input, y
network: network object (taken by value and copied into the model)

Definition at line 36 of file nn_impl.h.

40 {
41  network_ = network;
42  early_ = false;
43  opt_ = std::make_unique<OPT>();
44  Train(inputs,labels);
45 }

Member Function Documentation

◆ Classify()

template<class NET , class OPT , class MET , class O >
void algo::ANN< NET, OPT, MET, O >::Classify ( const arma::Mat< O > &  inputs,
arma::Row< size_t > &  preds 
)
Parameters
inputs: input data, X
preds: prediction of labels of the input, \hat{y}

Definition at line 180 of file nn_impl.h.

182 {
183  arma::Mat<O> temp;
184  network_.Predict(inputs,temp);
185  preds = _OneHotDecode(temp,ulab_);
186 }

◆ ComputeError()

template<class NET , class OPT , class MET , class O >
O algo::ANN< NET, OPT, MET, O >::ComputeError ( const arma::Mat< O > &  inputs,
const arma::Mat< O > &  labels 
)
Parameters
inputs: input data, X
labels: labels of the input, y

Definition at line 189 of file nn_impl.h.

191 {
192  arma::Mat<O> preds;
193  network_.Predict(inputs,preds);
194  return MET::Evaluate(preds, labels)/preds.n_elem;
195 }

◆ Predict()

template<class NET , class OPT , class MET , class O >
void algo::ANN< NET, OPT, MET, O >::Predict ( const arma::Mat< O > &  inputs,
arma::Mat< O > &  preds 
)
Parameters
inputs: input data, X
preds: prediction of labels of the input, \hat{y}

Definition at line 173 of file nn_impl.h.

175 {
176  network_.Predict(inputs,preds);
177 }

◆ Train()

template<class NET , class OPT , class MET , class O >
void algo::ANN< NET, OPT, MET, O >::Train ( const arma::Mat< O > &  inputs,
const arma::Mat< O > &  labels 
)
Parameters
inputs: input data, X
labels: labels of the input, y

Definition at line 102 of file nn_impl.h.

104 {
105  // Safety Net for learning curve generation from scratch,
106  // but you might want to start from a trained model. So future modification
107  // might be needed...
108  if (network_.Parameters().n_elem != 0)
109  network_.Reset();
110  if (!early_)
111  network_.Train(inputs,labels,*opt_);
112  else
113  {
114  arma::Mat<O> inp,lab,val_inp,val_lab;
115 
116  mlpack::data::Split(inputs,labels,
117  inp,val_inp,lab,val_lab,0.2);
118 
119  auto func = [&](const arma::Mat<O>& dummy)
120  {
121  if (MET::NeedsMinimization)
122  return MET::Evaluate(*this,val_inp,val_lab);
123  else
124  return -MET::Evaluate(*this,val_inp,val_lab);
125  };
126 
127  ens::EarlyStopAtMinLossType<arma::Mat<O>> stop(func,5);
128  network_.Train(inp,lab,*opt_,stop);
129 
130  }
131 }

The documentation for this class was generated from the following files: