Learning Curve Plus Plus (LCPP)
algo::classification::OnevAll< MODEL, T > Class Template Reference

Public Member Functions

 OnevAll ()=default
 
template<class... Args>
 OnevAll (const size_t &num_class, const Args &... args)
 
template<class... Args>
 OnevAll (const arma::Mat< T > &inputs, const arma::Row< size_t > &labels, const size_t &num_class, const Args &... args)
 
template<class... Args>
void Train (const arma::Mat< T > &inputs, const arma::Row< size_t > &labels, const Args &... args)
 
template<class... Args>
void Train (const arma::Mat< T > &inputs, const arma::Row< size_t > &labels)
 
void Classify (const arma::Mat< T > &inputs, arma::Row< size_t > &preds) const
 
void Classify (const arma::Mat< T > &inputs, arma::Row< size_t > &preds, arma::Mat< T > &probs) const
 
T ComputeError (const arma::Mat< T > &inputs, const arma::Row< size_t > &labels)
 
T ComputeAccuracy (const arma::Mat< T > &inputs, const arma::Row< size_t > &labels)
 
template<typename Archive >
void serialize (Archive &ar, const unsigned int)
 

Public Attributes

std::vector< MODEL > models_
 

Detailed Description

template<class MODEL, class T = DTYPE>
class algo::classification::OnevAll< MODEL, T >

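One-vs-all (one-vs-rest) multiclass classifier. For each distinct label c seen during training, one MODEL instance is trained on binary labels: 1 where the label equals c, 0 elsewhere. At classification time, each sub-model contributes a score for its own class, taken from row 1 of its probability output. The class with the largest absolute score is predicted.

Definition at line 22 of file multiclass.h.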

Constructor & Destructor Documentation

◆ OnevAll() [1/3]

template<class MODEL , class T = DTYPE>
algo::classification::OnevAll< MODEL, T >::OnevAll ( )
default

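Default constructor. It leaves models_ empty, so the object is a non-working model. Before calling Classify, either use one of the other constructors or populate the state, e.g. by loading a saved model through serialize().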

◆ OnevAll() [2/3]

template<class MODEL , class T >
template<class... Args>
algo::classification::OnevAll< MODEL, T >::OnevAll ( const size_t &  num_class,
const Args &...  args 
)
Parameters
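num_class: total number of classes; one MODEL instance is constructed per class
args: arguments forwarded to each MODEL constructor

Note: this constructor only sets nclass_ and models_. The set of class labels (unq_, unclass_, oneclass_) is determined later by Train(inputs, labels), which retrains these sub-models in place.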

Definition at line 19 of file multiclass_impl.h.

19  :
20  nclass_(num_class)
21 {
22  models_.resize(num_class);
23  for(size_t i=0;i<num_class;i++)
24  models_[i] = MODEL(args...);
25 }

◆ OnevAll() [3/3]

template<class MODEL , class T >
template<class... Args>
algo::classification::OnevAll< MODEL, T >::OnevAll ( const arma::Mat< T > &  inputs,
const arma::Row< size_t > &  labels,
const size_t &  num_class,
const Args &...  args 
)
Parameters
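inputs: X, training data with one sample per column
labels: y, the class label of each sample. Labels are used directly as row indices of the probability matrix produced by Classify, so every label must be smaller than num_class.
num_class: total number of classes, which is also the number of rows of that probability matrix
args: arguments forwarded to each MODEL constructor

One sub-model is created for each distinct label present in labels. Each one is then trained immediately via Train(inputs, labels, args...).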

Definition at line 29 of file multiclass_impl.h.

32  : nclass_(num_class)
33 {
34  unq_ = arma::unique(labels).eval();
35  if (unq_.n_elem == 1)
36  oneclass_ = true;
37 
38  unclass_ = unq_.n_elem;
39  models_.resize(unclass_);
40 
41  for(size_t i=0;i<unclass_;i++)
42  models_[i] = MODEL(args...);
43 
44  Train(inputs, labels, args...);
45 
46 }
void Train(const arma::Mat< T > &inputs, const arma::Row< size_t > &labels, const Args &... args)

Member Function Documentation

◆ Classify() [1/2]

template<class MODEL , class T >
void algo::classification::OnevAll< MODEL, T >::Classify ( const arma::Mat< T > &  inputs,
arma::Row< size_t > &  preds 
) const
Parameters
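inputs: X*, samples to classify (one per column)
preds: y*, output; resized to inputs.n_cols and filled with the predicted class labels. The per-class scores are computed but discarded; use the three-argument overload to obtain them.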

Definition at line 89 of file multiclass_impl.h.

91 {
92  arma::Mat<T> probs;
93  Classify(inputs, preds, probs);
94 }
void Classify(const arma::Mat< T > &inputs, arma::Row< size_t > &preds) const

◆ Classify() [2/2]

template<class MODEL , class T >
void algo::classification::OnevAll< MODEL, T >::Classify ( const arma::Mat< T > &  inputs,
arma::Row< size_t > &  preds,
arma::Mat< T > &  probs 
) const
Parameters
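inputs: X*, samples to classify (one per column)
preds: y*, output; the predicted class label of each sample, i.e. the index of the row of probs with the largest absolute value in that column
probs: p*, output; resized to nclass_ x inputs.n_cols, where nclass_ is the num_class given at construction. Row unq_(i) receives row 1 of sub-model i's probability output. Each column is then divided by its sum.

Notes:
- For columns whose scores sum to zero, a randomly chosen class is marked with 1 in probs. However, preds for those columns is not updated to match.
- If training saw only one distinct label, every prediction is that label, and probs is resized but not filled.
- Rows of probs for classes absent from the training labels are not written by this function.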

Definition at line 97 of file multiclass_impl.h.

100 {
101  probs.resize(nclass_,inputs.n_cols);
102  preds.resize(inputs.n_cols);
103  if (!oneclass_)
104  {
105  for(size_t i=0;i<unclass_;i++)
106  {
107  arma::Row<size_t> temp;
108  arma::Mat<T> tprobs;
109  models_[i].Classify(inputs,temp,tprobs);
110  probs.row(unq_(i)) = tprobs.row(1);
111  }
112  preds = arma::conv_to<arma::Row<size_t>>::from(arma::index_max(arma::abs(probs),0));
113 
114  // We cannot decide and randomly assign a class
115 
116  arma::uvec cantdecide = arma::find(arma::sum(probs,0) == 0);
117  arma::uvec randomsel = arma::randi<arma::uvec>(cantdecide.n_elem,
118  arma::distr_param(0,nclass_-1));
119  probs.elem(randomsel+ cantdecide*nclass_).ones();
120  probs = probs.each_row() / arma::sum(probs,0);
121 
122  }
123  else
124  preds.fill(unq_[0]);
125 }

◆ ComputeAccuracy()

template<class MODEL , class T >
T algo::classification::OnevAll< MODEL, T >::ComputeAccuracy ( const arma::Mat< T > &  inputs,
const arma::Row< size_t > &  labels 
)
Parameters
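inputs: X, samples to evaluate
labels: y, true class labels
Returns: the fraction of samples classified correctly, computed as 1 - ComputeError(inputs, labels).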

Definition at line 138 of file multiclass_impl.h.

140 {
141  return 1.-ComputeError(inputs,labels);
142 }
T ComputeError(const arma::Mat< T > &inputs, const arma::Row< size_t > &labels)

◆ ComputeError()

template<class MODEL , class T >
T algo::classification::OnevAll< MODEL, T >::ComputeError ( const arma::Mat< T > &  inputs,
const arma::Row< size_t > &  labels 
)
Parameters
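inputs: X, samples to evaluate
labels: y, true class labels
Returns: the fraction of samples whose predicted label (obtained from Classify) differs from the corresponding entry of labels.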

Definition at line 128 of file multiclass_impl.h.

130 {
131  arma::Row<size_t> predictions;
132  Classify(inputs, predictions);
133  arma::Row<size_t> temp = predictions - labels;
134  return (arma::accu(temp != 0))/T(predictions.n_elem);
135 }

◆ serialize()

template<class MODEL , class T = DTYPE>
template<typename Archive >
void algo::classification::OnevAll< MODEL, T >::serialize ( Archive &  ar,
const unsigned int   
)
inline

Serialize the model with cereal, saving or loading nclass_, unclass_, unq_, models_ and oneclass_ as named values.

Definition at line 100 of file multiclass.h.

102  {
103  ar ( cereal::make_nvp("nclass",nclass_),
104  cereal::make_nvp("unclass",unclass_),
105  cereal::make_nvp("unq",unq_),
106  cereal::make_nvp("models",models_),
107  cereal::make_nvp("oneclass",oneclass_));
108  }

◆ Train() [1/2]

template<class MODEL , class T >
template<class... Args>
void algo::classification::OnevAll< MODEL, T >::Train ( const arma::Mat< T > &  inputs,
const arma::Row< size_t > &  labels 
)
Parameters
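inputs: X, training data
labels: y, class labels
(No model parameters: this overload retrains the existing sub-models in place by calling models_[i].Train.)

Recomputes the distinct labels (unq_, unclass_). It then trains sub-model i on binary labels that mark the samples of class unq_(i).

Note:
- models_ is not resized here, so it must already contain at least as many sub-models as there are distinct labels.
- oneclass_ is only ever set to true here and is never reset to false.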

Definition at line 50 of file multiclass_impl.h.

52 {
53  unq_ = arma::unique(labels).eval();
54  if (unq_.n_elem == 1)
55  oneclass_ = true;
56 
57  unclass_ = unq_.n_elem;
58  if (!oneclass_)
59  {
60  for(size_t i=0;i<unclass_;i++)
61  {
62  auto binlabels = arma::conv_to<arma::Row<size_t>>::from(labels==unq_(i));
63  models_[i].Train(inputs, binlabels);
64  }
65  }
66  else
67  return;
68 }

◆ Train() [2/2]

template<class MODEL , class T >
template<class... Args>
void algo::classification::OnevAll< MODEL, T >::Train ( const arma::Mat< T > &  inputs,
const arma::Row< size_t > &  labels,
const Args &...  args 
)
Parameters
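inputs: X, training data
labels: y, class labels
args: arguments forwarded, together with the data, to the MODEL(inputs, binlabels, args...) constructor. Each sub-model is rebuilt and trained this way.

Note: this overload does not recompute the label set. It relies on unq_, unclass_ and oneclass_ having been set beforehand, as the training constructor does.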

Definition at line 72 of file multiclass_impl.h.

75 {
76  if (!oneclass_)
77  {
78  models_.resize(unq_.n_elem);
79  for(size_t i=0;i<unclass_;i++)
80  {
81  auto binlabels = arma::conv_to<arma::Row<size_t>>::from(labels==unq_(i));
82  MODEL model(inputs,binlabels,args...);
83  models_[i] = model;
84  }
85  }
86 }

The documentation for this class was generated from the following files: