Learning Curve Plus Plus (LCPP)
algo::classification::NMC< T > Class Template Reference

Public Member Functions

 NMC ()
 
 NMC (const size_t &num_classes)
 
 NMC (const arma::Mat< T > &inputs, const arma::Row< size_t > &labels, const size_t &num_class)
 
 NMC (const arma::Mat< T > &inputs, const arma::Row< size_t > &labels, const size_t &num_class, const double &shrink)
 
void Train (const arma::Mat< T > &inputs, const arma::Row< size_t > &labels)
 
void Train (const arma::Mat< T > &inputs, const arma::Row< size_t > &labels, const size_t num_class)
 
void Classify (const arma::Mat< T > &inputs, arma::Row< size_t > &labels) const
 
void Classify (const arma::Mat< T > &inputs, arma::Row< size_t > &labels, arma::Mat< T > &scores) const
 
T ComputeError (const arma::Mat< T > &points, const arma::Row< size_t > &responses) const
 
T ComputeAccuracy (const arma::Mat< T > &points, const arma::Row< size_t > &responses) const
 
const arma::Mat< T > & Parameters () const
 
arma::Mat< T > & Parameters ()
 
template<typename Archive >
void serialize (Archive &ar, const unsigned int)
 

Detailed Description

template<class T = DTYPE>
class algo::classification::NMC< T >

Definition at line 301 of file paramclass.h.

Constructor & Destructor Documentation

◆ NMC() [1/4]

template<class T = DTYPE>
algo::classification::NMC< T >::NMC ( )
inline

Default constructor: creates an empty, untrained (non-working) model.

Definition at line 308 of file paramclass.h.

308 { } ;

◆ NMC() [2/4]

template<class T = DTYPE>
algo::classification::NMC< T >::NMC ( const size_t &  num_classes)
Parameters
num_classes: number of classes

◆ NMC() [3/4]

template<class T >
algo::classification::NMC< T >::NMC ( const arma::Mat< T > &  inputs,
const arma::Row< size_t > &  labels,
const size_t &  num_class 
)
Parameters
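inputs: X, training samples, one sample per column
labels: y, class label of each column of inputs
num_class: number of classes
This overload trains immediately, without shrinkage (shrink is set to 0).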

Definition at line 303 of file paramclass_impl.h.

305  : shrink_(0.), num_class_(num_class)
306 {
307  Train(inputs, labels);
308 }
void Train(const arma::Mat< T > &inputs, const arma::Row< size_t > &labels)

References algo::classification::NMC< T >::Train().

+ Here is the call graph for this function:

◆ NMC() [4/4]

template<class T >
algo::classification::NMC< T >::NMC ( const arma::Mat< T > &  inputs,
const arma::Row< size_t > &  labels,
const size_t &  num_class,
const double &  shrink 
)
Parameters
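inputs: X, training samples, one sample per column
labels: y, class label of each column of inputs
num_class: number of classes
shrink: s, shrinkage threshold subtracted from the absolute standardized deviations of the class means from the overall centroid; applied only when shrink > 0 and num_class != 1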

Definition at line 294 of file paramclass_impl.h.

297  : shrink_(shrink), num_class_(num_class)
298 {
299  Train(inputs, labels);
300 }

References algo::classification::NMC< T >::Train().

+ Here is the call graph for this function:

Member Function Documentation

◆ Classify() [1/2]

template<class T >
void algo::classification::NMC< T >::Classify ( const arma::Mat< T > &  inputs,
arma::Row< size_t > &  labels 
) const
Parameters
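inputs: X*, points to classify, one point per column
labels: y*, output; resized to inputs.n_cols and filled with the predicted label of each point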

Definition at line 368 of file paramclass_impl.h.

370 {
371  arma::Mat<T> temp;
372  Classify(inputs,labels,temp);
373 }
void Classify(const arma::Mat< T > &inputs, arma::Row< size_t > &labels) const

◆ Classify() [2/2]

template<class T >
void algo::classification::NMC< T >::Classify ( const arma::Mat< T > &  inputs,
arma::Row< size_t > &  labels,
arma::Mat< T > &  scores 
) const
Parameters
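inputs: X*, points to classify, one point per column
labels: y*, output; each entry is the class whose centroid is nearest to the point according to metric_.Evaluate
scores: output, num_class x inputs.n_cols (called probs in the implementation). When every class was seen in training, each column holds that point's distances to the class centroids divided by their sum, so a smaller score means a closer centroid.
NOTE(review): non-finite scores are replaced with fill(1/num_class_). num_class_ is set from a size_t, so this is integer division and fills 0 rather than 1/num_class whenever num_class > 1; it should be T(1)/num_class_.
NOTE(review): when some classes are absent from training, their distance rows stay at +inf. Every column sum is then +inf, so the finite rows are divided by +inf and become 0.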

Definition at line 376 of file paramclass_impl.h.

379 {
380  const size_t N = inputs.n_cols;
381  probs.resize(num_class_, N);
382  labels.resize(N);
383  if ( unique_.n_elem == 1 )
384  {
385  labels.fill(unique_(0));
386  probs.row(unique_(0)).fill(1.);
387  }
388  else
389  {
390  arma::Mat<T> distances(num_class_, N);
391  distances.fill(arma::datum::inf);
392  for ( size_t j=0; j<N; j++ )
393  {
394  for ( size_t i=0; i<unique_.n_elem; i++ )
395  {
396  distances(unique_(i), j) = metric_.Evaluate(parameters_.col(unique_(i))
397  ,inputs.col(j));
398  }
399  labels(j) = arma::index_min(distances.col(j));
400  }
401  if (unique_.n_elem == num_class_)
402  probs = distances.each_row() / arma::sum(distances,0);
403  else
404  {
405  for (arma::uword i = 0; i < distances.n_rows; ++i)
406  {
407  // Check if the row contains any infinite values
408  if (arma::is_finite(distances.row(i)))
409  // Perform the division for this row only if all values are finite
410  // and sum is nonzero
411  probs.row(i) = distances.row(i) / arma::sum(distances, 0);
412  }
413  }
414  probs.elem(arma::find_nonfinite(probs)).fill(1/num_class_);
415  }
416 }

◆ ComputeAccuracy()

template<class T >
T algo::classification::NMC< T >::ComputeAccuracy ( const arma::Mat< T > &  points,
const arma::Row< size_t > &  responses 
) const

Calculates the classification accuracy as a percentage: (1 - ComputeError(points, responses)) * 100.

Parameters
points: X*, points to classify, one point per column
responses: y, true label of each point

Definition at line 430 of file paramclass_impl.h.

432 {
433  return (1. - ComputeError(points, responses))*100;
434 }
T ComputeError(const arma::Mat< T > &points, const arma::Row< size_t > &responses) const

◆ ComputeError()

template<class T >
T algo::classification::NMC< T >::ComputeError ( const arma::Mat< T > &  points,
const arma::Row< size_t > &  responses 
) const

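Calculates the error rate: the fraction (between 0 and 1) of points whose predicted label differs from the corresponding entry of responses.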

Parameters
points: X*, points to classify, one point per column
responses: y, true label of each point

Definition at line 419 of file paramclass_impl.h.

421 {
422  arma::Row<size_t> predictions;
423 
424  Classify(points,predictions);
425  arma::Row<size_t> temp = predictions - responses;
426  return (arma::accu(temp != 0))/T(predictions.n_elem);
427 }

◆ serialize()

template<class T = DTYPE>
template<typename Archive >
void algo::classification::NMC< T >::serialize ( Archive &  ar,
const unsigned int   
)
inline

Serializes the model state (parameters, dim, num_class, centroid, unique, shrink and size) as cereal name-value pairs.

Definition at line 391 of file paramclass.h.

393  {
394  ar ( cereal::make_nvp("parameters",parameters_),
395  cereal::make_nvp("dim",dim_),
396  cereal::make_nvp("num_class",num_class_),
397  cereal::make_nvp("centroid",centroid_),
398  cereal::make_nvp("unique",unique_),
399  cereal::make_nvp("shrink",shrink_),
400  cereal::make_nvp("size",size_) );
401  }

◆ Train() [1/2]

template<class T >
void algo::classification::NMC< T >::Train ( const arma::Mat< T > &  inputs,
const arma::Row< size_t > &  labels 
)
Parameters
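inputs: X, training samples, one sample per column
labels: y, class label of each column of inputs
Computes the mean of each class present in labels and the overall centroid of inputs. If shrink > 0 and num_class != 1, the class means are then replaced by the overall centroid plus soft-thresholded (shrunken) deviations.
NOTE(review): each class mean is stored at column position `counter`, i.e. in the order of the unique labels. Classify and the shrinkage step index parameters_ by label value instead, so the two agree only when the labels are exactly 0..k-1.
NOTE(review): in the shrinkage step, the local `nk` is redeclared as a matrix of ones, shadowing the per-class counts. The expression `sqrt(1./nk) - 1/size_` also differs from the usual shrunken-centroid factor sqrt(1/n_k - 1/n). If size_ is an integer type, `1/size_` is integer division and evaluates to 0 whenever size_ > 1.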

Definition at line 311 of file paramclass_impl.h.

313 {
314  dim_ = inputs.n_rows;
315  unique_ = arma::unique(labels);
316  /* num_class_ = unique_.n_cols; */
317  size_ = inputs.n_cols;
318  arma::vec nk(num_class_);
319  parameters_.resize(inputs.n_rows, num_class_);
320  arma::uvec index;
321  arma::Row<size_t>::iterator it = unique_.begin();
322  arma::Row<size_t>::iterator it_end = unique_.end();
323  size_t counter =0;
324  // you should just iterate over the unique labels here instead of starting
325  // from 0
326 
327  for ( ;it!=it_end; ++it)
328  {
329  auto extract = extract_class(inputs, labels, *it);
330  index = std::get<1>(extract);
331  nk(counter) = index.n_rows;
332  parameters_.col(counter) = arma::mean(inputs.cols(index),1);
333  counter++;
334  }
335  centroid_ = arma::mean(inputs, 1);
336 
337  // Just for the shrinkage part
338  if (shrink_ > 0 && num_class_ != 1)
339  {
340  arma::Mat<T> nk = arma::ones<arma::Mat<T>>(num_class_,1);
341  arma::Mat<T> m = arma::sqrt(1./nk)-1/size_;
342  arma::uvec labs = arma::conv_to<arma::uvec>::from(labels);
343  arma::Mat<T> variance = arma::sum(
344  arma::pow(inputs - parameters_.cols(labs),2),1);
345  arma::Mat<T> s = arma::sqrt(variance/(size_-num_class_)).t();
346  arma::Mat<T> ms = m*s;
347  arma::inplace_trans(ms);
348  arma::Mat<T> devi = (parameters_.each_col() - centroid_) / ms;
349  arma::Mat<T> signs = arma::sign(devi);
350  arma::Mat<T> dev = arma::abs(devi) - shrink_;
351  dev = arma::clamp(dev, 0, arma::datum::inf);
352  dev %= signs;
353  arma::Mat<T> msd = dev % ms;
354  parameters_ = centroid_ + msd.each_col();
355  }
356 }

Referenced by algo::classification::NMC< T >::NMC().

+ Here is the caller graph for this function:

◆ Train() [2/2]

template<class T >
void algo::classification::NMC< T >::Train ( const arma::Mat< T > &  inputs,
const arma::Row< size_t > &  labels,
const size_t  num_class 
)
Parameters
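inputs: X, training samples, one sample per column
labels: y, class label of each column of inputs
num_class: number of classes; stored in the model before delegating to Train(inputs, labels)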

Definition at line 359 of file paramclass_impl.h.

362 {
363  this -> num_class_ = num_class;
364  this -> Train (inputs,labels);
365 }

The documentation for this class was generated from the following files: