16 namespace classification {
21 template<
class MODEL,
class T=DTYPE>
/// Constructor taking the number of classes plus an arbitrary pack of
/// additional arguments.
///
/// @param num_class  Class count, passed by const reference.
/// @param args       Extra arguments of any types, each taken by const
///                   reference. Presumably forwarded to the underlying MODEL
///                   objects -- TODO confirm against the implementation.
33 template<
class... Args>
34 OnevAll (
const size_t& num_class,
const Args&... args );
/// Constructor taking a data set, its labels, the number of classes and an
/// arbitrary pack of additional arguments.
///
/// NOTE(review): presumably this trains the classifier immediately on the
/// supplied data -- confirm against the implementation.
///
/// @param inputs     Input data matrix with element type T.
/// @param labels     Row vector of size_t labels for the inputs.
/// @param num_class  Class count, passed by const reference.
/// @param args       Extra arguments of any types, each taken by const
///                   reference (presumably model hyper-parameters -- verify).
42 template<
class... Args>
43 OnevAll (
const arma::Mat<T>& inputs,
44 const arma::Row<size_t>& labels,
45 const size_t& num_class,
46 const Args&... args );
/// Training entry point that accepts additional arguments.
///
/// @param inputs  Input data matrix with element type T.
/// @param labels  Row vector of size_t labels for the inputs.
/// @param args    Extra arguments of any types, each taken by const
///                reference (presumably passed on to each MODEL -- verify).
53 template<
class... Args>
54 void Train (
const arma::Mat<T>& inputs,
55 const arma::Row<size_t>& labels,
56 const Args&... args );
/// Training entry point that takes only the data and its labels.
///
/// NOTE(review): the template parameter pack Args does not appear in the
/// parameter list, so it cannot be deduced from the call arguments. Confirm
/// whether the template header is intentional before changing it; the
/// out-of-line definition must keep a matching header.
///
/// @param inputs  Input data matrix with element type T.
/// @param labels  Row vector of size_t labels for the inputs.
63 template<
class... Args>
64 void Train (
const arma::Mat<T>& inputs,
65 const arma::Row<size_t>& labels );
/// Classification that produces labels only. Declared const, so it does not
/// modify the classifier object.
///
/// @param inputs  Input data matrix with element type T.
/// @param preds   Output parameter (non-const reference) receiving size_t
///                labels; presumably one entry per input point -- TODO
///                confirm the data layout.
70 void Classify (
const arma::Mat<T>& inputs,
71 arma::Row<size_t>& preds )
const ;
/// Classification that produces labels together with a second matrix output.
/// Declared const, so it does not modify the classifier object.
///
/// @param inputs  Input data matrix with element type T.
/// @param preds   Output parameter (non-const reference) receiving size_t
///                labels.
/// @param probs   Output parameter (non-const reference) of element type T;
///                presumably per-class probabilities or scores for each
///                input -- TODO confirm shape and meaning.
77 void Classify (
const arma::Mat<T>& inputs,
78 arma::Row<size_t>& preds,
79 arma::Mat<T>& probs )
const ;
85 const arma::Row<size_t>& labels );
92 const arma::Row<size_t>& labels );
/// Collection of underlying MODEL objects (presumably one binary model per
/// class -- verify). Archived under the name "models".
94 std::vector<MODEL> models_;
/// Serialization hook, templated on the archive type.
///
/// Hands five members to the archive as name-value pairs, in this order:
/// "nclass" (nclass_), "unclass" (unclass_), "unq" (unq_),
/// "models" (models_) and "oneclass" (oneclass_). The names are part of the
/// stored format, so renaming them affects previously saved archives.
99 template<
typename Archive>
103 ar ( cereal::make_nvp(
"nclass",nclass_),
104 cereal::make_nvp(
"unclass",unclass_),
105 cereal::make_nvp(
"unq",unq_),
106 cereal::make_nvp(
"models",models_),
107 cereal::make_nvp(
"oneclass",oneclass_));
/// Row vector of size_t values archived as "unq" (presumably the distinct
/// labels seen during training -- verify).
113 arma::Row<size_t> unq_;
/// Boolean flag archived as "oneclass"; defaults to false. Presumably set
/// when the training labels contain a single class -- TODO confirm.
114 bool oneclass_ =
false;
T ComputeAccuracy(const arma::Mat< T > &inputs, const arma::Row< size_t > &labels)
void Classify(const arma::Mat< T > &inputs, arma::Row< size_t > &preds) const
T ComputeError(const arma::Mat< T > &inputs, const arma::Row< size_t > &labels)
void Train(const arma::Mat< T > &inputs, const arma::Row< size_t > &labels, const Args &... args)
void serialize(Archive &ar, const unsigned int)