13 namespace classification {
// LDC: classifier class template (element type T, default DTYPE).
// NOTE(review): this extract shows only a fragment of the declaration. The
// `class` line, access specifiers and the declarations of dim_, num_class_,
// size_, lambda_, mean_ and cov_ (all used in serialize below) are not shown.
18 template<
class T=DTYPE>
// Default constructor: initializes lambda_ to 0.
26 LDC ( ) : lambda_(0.) { };
// Construct from a class count and lambda. Stores num_class in num_class_.
// NOTE(review): the remainder of this initializer list and the body are not
// visible in this extract.
31 LDC (
const size_t& num_class,
const double& lambda ) : num_class_(num_class),
// Construct from inputs, labels and class count (definition not shown).
39 LDC (
const arma::Mat<T>& inputs,
40 const arma::Row<size_t>& labels,
41 const size_t& num_class );
// As above, additionally taking lambda (definition not shown).
49 LDC (
const arma::Mat<T>& inputs,
50 const arma::Row<size_t>& labels,
51 const size_t& num_class,
52 const double& lambda );
// Overload that also takes explicit class priors (definition not shown).
// NOTE(review): one source line between num_class and priors is absent from
// this extract. The analogous QDC overload takes `lambda` there; confirm
// against the real header.
60 LDC (
const arma::Mat<T>& inputs,
61 const arma::Row<size_t>& labels,
62 const size_t& num_class,
64 const arma::Row<T>& priors );
// Train on inputs/labels (definition not shown).
70 void Train (
const arma::Mat<T>& inputs,
71 const arma::Row<size_t>& labels );
// Train on inputs/labels with an explicit class count (definition not shown).
78 void Train (
const arma::Mat<T>& inputs,
79 const arma::Row<size_t>& labels,
80 const size_t num_class );
// Write predicted labels for inputs into `labels` (const member).
86 void Classify (
const arma::Mat<T>& inputs,
87 arma::Row<size_t>& labels )
const;
// As above, additionally writing per-class scores into `scores`.
94 void Classify (
const arma::Mat<T>& inputs,
95 arma::Row<size_t>& labels,
96 arma::Mat<T>& scores )
const;
// Tails of the ComputeError / ComputeAccuracy(points, responses) declarations;
// their opening lines are missing from this extract.
104 const arma::Row<size_t>& responses )
const;
113 const arma::Row<size_t>& responses )
const;
// cereal serialization of the model state as named values. The serialize
// signature line itself is not visible here.
// NOTE(review): jitter_ is not archived, so a loaded object keeps whatever
// value it already had (presumably the in-class default 1.e-8).
118 template<
typename Archive>
122 ar (cereal::make_nvp(
"dim",dim_),
123 cereal::make_nvp(
"num_class",num_class_),
124 cereal::make_nvp(
"size",size_),
125 cereal::make_nvp(
"lambda",lambda_),
126 cereal::make_nvp(
"means",means_),
127 cereal::make_nvp(
"covs",covs_),
128 cereal::make_nvp(
"mean",mean_),
129 cereal::make_nvp(
"cov",cov_),
130 cereal::make_nvp(
"class",class_),
131 cereal::make_nvp(
"unique",unique_),
132 cereal::make_nvp(
"priors",priors_));
// Small constant, defaults to 1.e-8 (presumably added for numerical
// stability; usage not visible here).
143 double jitter_ = 1.e-8;
// Row vectors / matrices keyed by a size_t id (presumably the class label).
146 std::map<size_t, arma::Row<T>> means_;
147 std::map<size_t, arma::Mat<T>> covs_;
152 arma::Row<size_t> unique_;
153 arma::Row<size_t> class_;
154 arma::Row<T> priors_;
// QDC: classifier class template (element type T, default DTYPE).
// NOTE(review): only a fragment of the declaration is visible. The `class`
// line, access specifiers and declarations of dim_, num_class_, size_ and
// lambda_ (used below) are not shown.
161 template<
class T=DTYPE>
// Default constructor: initializes lambda_ to 0.
169 QDC ( ) : lambda_(0.) { };
// Construct from a class count and lambda. Stores both and does nothing else.
175 QDC (
const size_t& num_class,
const double& lambda ) : num_class_(num_class),
176 lambda_(lambda) { } ;
// Construct from inputs, labels, class count and lambda (definition not shown).
184 QDC (
const arma::Mat<T>& inputs,
185 const arma::Row<size_t>& labels,
186 const size_t& num_class,
187 const double& lambda );
// As above, additionally taking explicit class priors.
195 QDC (
const arma::Mat<T>& inputs,
196 const arma::Row<size_t>& labels,
197 const size_t& num_class,
198 const double& lambda,
199 const arma::Row<T>& priors );
// Construct from inputs, labels and class count only (definition not shown).
206 QDC (
const arma::Mat<T>& inputs,
207 const arma::Row<size_t>& labels,
208 const size_t& num_class );
// Train on inputs/labels (definition not shown).
214 void Train (
const arma::Mat<T>& inputs,
215 const arma::Row<size_t>& labels );
// Train on inputs/labels with an explicit class count.
222 void Train (
const arma::Mat<T>& inputs,
223 const arma::Row<size_t>& labels,
224 const size_t num_class );
// Write predicted labels for inputs into `labels` (const member).
230 void Classify (
const arma::Mat<T>& inputs,
231 arma::Row<size_t>& labels )
const;
// As above, additionally writing per-class scores into `scores`.
237 void Classify (
const arma::Mat<T>& inputs,
238 arma::Row<size_t>& labels,
239 arma::Mat<T>& scores )
const;
// Tails of the ComputeError / ComputeAccuracy(points, responses) declarations;
// their opening lines are missing from this extract.
247 const arma::Row<size_t>& responses )
const;
256 const arma::Row<size_t>& responses )
const;
// cereal serialization of the model state as named values. The serialize
// signature line itself is not visible here.
// NOTE(review): jitter_ is not archived (same as in LDC).
261 template<
typename Archive>
265 ar ( cereal::make_nvp(
"dim",dim_),
266 cereal::make_nvp(
"num_class",num_class_),
267 cereal::make_nvp(
"size",size_),
268 cereal::make_nvp(
"lambda",lambda_),
269 cereal::make_nvp(
"means",means_),
270 cereal::make_nvp(
"covs",covs_),
271 cereal::make_nvp(
"icovs",icovs_),
272 cereal::make_nvp(
"unique",unique_),
273 cereal::make_nvp(
"class",class_),
274 cereal::make_nvp(
"priors",priors_));
// Small constant, defaults to 1.e-8 (usage not visible here).
285 double jitter_ = 1.e-8;
// Per-id rows / matrices keyed by size_t (presumably class label). icovs_
// presumably holds inverse covariances; TODO confirm against Train.
287 std::map<size_t, arma::Row<T>> means_;
288 std::map<size_t, arma::Mat<T>> covs_;
289 std::map<size_t, arma::Mat<T>> icovs_;
291 arma::Row<size_t> unique_;
292 arma::Row<size_t> class_;
293 arma::Row<T> priors_;
// NMC: classifier class template (element type T, default DTYPE).
// NOTE(review): only a fragment of the declaration is visible. The `class`
// line, access specifiers and declarations of dim_, num_class_, shrink_ and
// size_ (used in serialize) are not shown.
300 template<
class T=DTYPE>
// Construct from a class count (definition not shown).
// NOTE(review): this single-argument constructor is not `explicit`, so a
// size_t converts implicitly to NMC. Also, the parameter is named
// num_classes here but num_class everywhere else.
313 NMC (
const size_t& num_classes );
// Construct from inputs, labels and class count (definition not shown).
320 NMC (
const arma::Mat<T>& inputs,
321 const arma::Row<size_t>& labels,
322 const size_t& num_class );
// As above, additionally taking a `shrink` value (definition not shown).
331 NMC (
const arma::Mat<T>& inputs,
332 const arma::Row<size_t>& labels,
333 const size_t& num_class,
334 const double& shrink );
// Train on inputs/labels (definition not shown).
339 void Train (
const arma::Mat<T>& inputs,
340 const arma::Row<size_t>& labels );
// Train on inputs/labels with an explicit class count.
346 void Train (
const arma::Mat<T>& inputs,
347 const arma::Row<size_t>& labels,
348 const size_t num_class );
// Write predicted labels for inputs into `labels` (const member).
354 void Classify (
const arma::Mat<T>& inputs,
355 arma::Row<size_t>& labels )
const;
// As above, additionally writing per-class scores into `scores`.
362 void Classify (
const arma::Mat<T>& inputs,
363 arma::Row<size_t>& labels,
364 arma::Mat<T>& scores )
const;
// Tails of the ComputeError / ComputeAccuracy(points, responses) declarations;
// their opening lines are missing from this extract.
372 const arma::Row<size_t>& responses )
const;
380 const arma::Row<size_t>& responses )
const;
// Read-only access to parameters_.
382 const arma::Mat<T>& Parameters()
const {
return parameters_; }
// Mutable access to parameters_ (callers can modify the model in place).
384 arma::Mat<T>& Parameters() {
return parameters_; }
// cereal serialization of the model state as named values. The serialize
// signature line itself is not visible here. metric_ is not archived.
390 template<
typename Archive>
394 ar ( cereal::make_nvp(
"parameters",parameters_),
395 cereal::make_nvp(
"dim",dim_),
396 cereal::make_nvp(
"num_class",num_class_),
397 cereal::make_nvp(
"centroid",centroid_),
398 cereal::make_nvp(
"unique",unique_),
399 cereal::make_nvp(
"shrink",shrink_),
400 cereal::make_nvp(
"size",size_) );
410 arma::Mat<T> centroid_;
411 arma::Mat<T> parameters_;
413 arma::Row<size_t> unique_;
// Distance metric object (mlpack::EuclideanDistance); not serialized.
415 mlpack::EuclideanDistance metric_;
421 template<
class T=DTYPE>
422 std::tuple< arma::Mat<T>,
423 arma::uvec > extract_class (
const arma::Mat<T>& inputs,
424 const arma::Row<size_t>& labels,
425 const size_t& label_id )
427 arma::uvec index = arma::find(labels == label_id);
428 return std::make_tuple(inputs.cols(index), index);
434 template<
class T=DTYPE>
435 arma::Row<T> get_prior (
const arma::Row<size_t>& labels,
436 const size_t& num_class )
438 auto unq = arma::regspace<arma::Row<size_t>>(0,1,num_class);
439 auto prior = arma::conv_to<arma::Row<T>>::from(arma::hist(labels, unq));
440 return prior / labels.n_cols;
// NOTE(review): the lines below are a flat list of member signatures (no
// bodies, no class qualification, no terminating semicolons). They repeat
// declarations that appear earlier in this extract (LDC/QDC/NMC Train,
// Classify, serialize, ComputeError/ComputeAccuracy and constructors). They
// look like documentation-tool residue rather than compilable code; confirm
// against the real header before relying on them.
void Train(const arma::Mat< T > &inputs, const arma::Row< size_t > &labels)
LDC(const size_t &num_class, const double &lambda)
void Classify(const arma::Mat< T > &inputs, arma::Row< size_t > &labels) const
T ComputeAccuracy(const arma::Mat< T > &points, const arma::Row< size_t > &responses) const
void serialize(Archive &ar, const unsigned int)
T ComputeError(const arma::Mat< T > &points, const arma::Row< size_t > &responses) const
void serialize(Archive &ar, const unsigned int)
void Train(const arma::Mat< T > &inputs, const arma::Row< size_t > &labels)
void Classify(const arma::Mat< T > &inputs, arma::Row< size_t > &labels) const
NMC(const size_t &num_classes)
T ComputeAccuracy(const arma::Mat< T > &points, const arma::Row< size_t > &responses) const
T ComputeError(const arma::Mat< T > &points, const arma::Row< size_t > &responses) const
QDC(const size_t &num_class, const double &lambda)
void serialize(Archive &ar, const unsigned int)
void Classify(const arma::Mat< T > &inputs, arma::Row< size_t > &labels) const
T ComputeError(const arma::Mat< T > &points, const arma::Row< size_t > &responses) const
void Train(const arma::Mat< T > &inputs, const arma::Row< size_t > &labels)
T ComputeAccuracy(const arma::Mat< T > &points, const arma::Row< size_t > &responses) const