19 class OPT=ens::StandardSGD,
20 class MET=mlpack::MeanSquaredError,
41 ANN (
const arma::Mat<O>& inputs,
42 const arma::Mat<O>& labels ) ;
50 ANN (
const arma::Mat<O>& inputs,
51 const arma::Mat<O>& labels,
54 template<
class... OptArgs>
55 ANN (
const arma::Mat<O>& inputs,
56 const arma::Mat<O>& labels,
57 const NET network,
bool early =
false,
const OptArgs&... args ) ;
59 template<
class... OptArgs>
60 ANN (
const arma::Mat<O>& inputs,
61 const arma::Row<size_t>& labels,
62 const NET network,
bool early =
false,
const OptArgs&... args ) ;
67 void Train(
const arma::Mat<O>& inputs,
const arma::Mat<O>& labels );
69 template<
class... OptArgs>
70 void Train(
const arma::Mat<O>& inputs,
const arma::Mat<O>& labels,
71 const NET network,
bool early =
false,
const OptArgs&... args );
73 template<
class... OptArgs>
74 void Train(
const arma::Mat<O>& inputs,
const arma::Row<size_t>& labels,
75 const NET network,
bool early =
false,
const OptArgs&... args );
77 void Train(
const arma::Mat<O>& inputs,
const arma::Row<size_t>& labels );
83 void Predict(
const arma::Mat<O>& inputs, arma::Mat<O>& preds );
89 void Classify(
const arma::Mat<O>& inputs, arma::Row<size_t>& preds );
96 O
ComputeError(
const arma::Mat<O>& inputs,
const arma::Mat<O>& labels );
98 arma::Mat<O> Parameters( );
100 template<
class Archive>
101 void serialize(Archive& ar)
103 ar( CEREAL_NVP(network_),
114 arma::Mat<O> _OneHotEncode (
const arma::Row<size_t>& labels,
115 const arma::Row<size_t>& ulabels );
121 arma::Row<size_t> _OneHotDecode (
const arma::Mat<O>& labels,
122 const arma::Row<size_t>& ulabels);
126 std::unique_ptr<OPT> opt_;
137 arma::Row<size_t> ulab_;
void Train(const arma::Mat< O > &inputs, const arma::Mat< O > &labels)
O ComputeError(const arma::Mat< O > &inputs, const arma::Mat< O > &labels)
void Classify(const arma::Mat< O > &inputs, arma::Row< size_t > &preds)
void Predict(const arma::Mat< O > &inputs, arma::Mat< O > &preds)