Learning Curve Plus Plus (LCPP)
algo::classification::QDC< T > Class Template Reference

Public Member Functions

 QDC ()
 
 QDC (const size_t &num_class, const double &lambda)
 
 QDC (const arma::Mat< T > &inputs, const arma::Row< size_t > &labels, const size_t &num_class, const double &lambda)
 
 QDC (const arma::Mat< T > &inputs, const arma::Row< size_t > &labels, const size_t &num_class, const double &lambda, const arma::Row< T > &priors)
 
 QDC (const arma::Mat< T > &inputs, const arma::Row< size_t > &labels, const size_t &num_class)
 
void Train (const arma::Mat< T > &inputs, const arma::Row< size_t > &labels)
 
void Train (const arma::Mat< T > &inputs, const arma::Row< size_t > &labels, const size_t num_class)
 
void Classify (const arma::Mat< T > &inputs, arma::Row< size_t > &labels) const
 
void Classify (const arma::Mat< T > &inputs, arma::Row< size_t > &labels, arma::Mat< T > &scores) const
 
T ComputeError (const arma::Mat< T > &points, const arma::Row< size_t > &responses) const
 
T ComputeAccuracy (const arma::Mat< T > &points, const arma::Row< size_t > &responses) const
 
template<typename Archive >
void serialize (Archive &ar, const unsigned int)
 

Detailed Description

template<class T = DTYPE>
class algo::classification::QDC< T >

Definition at line 162 of file paramclass.h.

Constructor & Destructor Documentation

◆ QDC() [1/5]

template<class T = DTYPE>
algo::classification::QDC< T >::QDC ( )
inline

Default constructor: creates an untrained (non-working) model with the regularization lambda_ set to 0.

Definition at line 169 of file paramclass.h.

169 : lambda_(0.) { };

◆ QDC() [2/5]

template<class T = DTYPE>
algo::classification::QDC< T >::QDC ( const size_t &  num_class,
const double &  lambda 
)
inline
Parameters
num_class: number of classes
lambda: regularization

Definition at line 175 of file paramclass.h.

175  : num_class_(num_class),
176  lambda_(lambda) { } ;

◆ QDC() [3/5]

template<class T >
algo::classification::QDC< T >::QDC ( const arma::Mat< T > &  inputs,
const arma::Row< size_t > &  labels,
const size_t &  num_class,
const double &  lambda 
)
Parameters
inputs: X
labels: y
num_class: number of classes
lambda: regularization

Definition at line 175 of file paramclass_impl.h.

178  : num_class_(num_class), lambda_(lambda)
179 {
180  Train(inputs, labels);
181 }
void Train(const arma::Mat< T > &inputs, const arma::Row< size_t > &labels)

References algo::classification::QDC< T >::Train().

+ Here is the call graph for this function:

◆ QDC() [4/5]

template<class T >
algo::classification::QDC< T >::QDC ( const arma::Mat< T > &  inputs,
const arma::Row< size_t > &  labels,
const size_t &  num_class,
const double &  lambda,
const arma::Row< T > &  priors 
)
Parameters
inputs: X
labels: y
num_class: number of classes
lambda: regularization
priors: known priors

Definition at line 164 of file paramclass_impl.h.

168  : num_class_(num_class),
169  lambda_(lambda), priors_(priors)
170 {
171  Train(inputs, labels);
172 }

References algo::classification::QDC< T >::Train().

+ Here is the call graph for this function:

◆ QDC() [5/5]

template<class T >
algo::classification::QDC< T >::QDC ( const arma::Mat< T > &  inputs,
const arma::Row< size_t > &  labels,
const size_t &  num_class 
)
Parameters
inputs: X
labels: y
num_class: number of classes

Definition at line 156 of file paramclass_impl.h.

158  : num_class_(num_class), lambda_(0.)
159 {
160  Train(inputs, labels);
161 }

References algo::classification::QDC< T >::Train().

+ Here is the call graph for this function:

Member Function Documentation

◆ Classify() [1/2]

template<class T >
void algo::classification::QDC< T >::Classify ( const arma::Mat< T > &  inputs,
arma::Row< size_t > &  labels 
) const
Parameters
inputs: X*
labels: y*

Definition at line 231 of file paramclass_impl.h.

233 {
234  arma::Mat<T> temp;
235  Classify(inputs, labels, temp);
236 }
void Classify(const arma::Mat< T > &inputs, arma::Row< size_t > &labels) const

◆ Classify() [2/2]

template<class T >
void algo::classification::QDC< T >::Classify ( const arma::Mat< T > &  inputs,
arma::Row< size_t > &  labels,
arma::Mat< T > &  scores 
) const
Parameters
inputs: X*
labels: y*
scores: class scores for X* (the function body refers to this argument as probs)

Definition at line 239 of file paramclass_impl.h.

242 {
243  const size_t N = inputs.n_cols;
244  labels.resize(N);
245  probs.resize(num_class_,N);
246 
247  if ( num_class_ == 1 )
248  {
249  labels.fill(unique_(0));
250  probs.row(unique_(0)).fill(1.);
251  }
252  else
253  {
254  arma::Row<T> norm;
255 
256  /* #pragma omp parallel for */
257  for ( size_t n=0; n<inputs.n_cols; n++ )
258  {
259  for ( size_t c=0; c<unique_.n_elem; c++ )
260  {
261  norm = inputs.col(n).t() - means_.at(unique_(c));
262  probs(class_(unique_(c)),n) = std::log(priors_(c))
263  - 0.5*(arma::det(covs_.at(unique_(c)))+inputs.n_rows*std::log(2*arma::datum::pi))
264  - 0.5* arma::dot(norm*icovs_.at(unique_(c)),norm);
265  }
266  labels(n) = class_(probs.col(n).index_max());
267 
268  }
269  probs = arma::exp(probs.each_row() - arma::max(probs,0));
270  probs = probs.each_row()/arma::sum(probs,0);
271  }
272 }

◆ ComputeAccuracy()

template<class T >
T algo::classification::QDC< T >::ComputeAccuracy ( const arma::Mat< T > &  points,
const arma::Row< size_t > &  responses 
) const

Calculate the accuracy as a percentage, i.e. (1 - error rate) * 100.

Parameters
points: X*
responses: y*

Definition at line 285 of file paramclass_impl.h.

287 {
288  return (1. - ComputeError(points, responses))*100;
289 }
T ComputeError(const arma::Mat< T > &points, const arma::Row< size_t > &responses) const

◆ ComputeError()

template<class T >
T algo::classification::QDC< T >::ComputeError ( const arma::Mat< T > &  points,
const arma::Row< size_t > &  responses 
) const

Calculate the error rate: the fraction of points whose predicted label differs from the given response.

Parameters
points: X*
responses: y*

Definition at line 275 of file paramclass_impl.h.

277 {
278  arma::Row<size_t> predictions;
279  Classify(points,predictions);
280  arma::Row<size_t> temp = predictions - responses;
281  return (arma::accu(temp != 0))/T(predictions.n_elem);
282 }

◆ serialize()

template<class T = DTYPE>
template<typename Archive >
void algo::classification::QDC< T >::serialize ( Archive &  ar,
const unsigned int   
)
inline

Serialize the model.

Definition at line 262 of file paramclass.h.

264  {
265  ar ( cereal::make_nvp("dim",dim_),
266  cereal::make_nvp("num_class",num_class_),
267  cereal::make_nvp("size",size_),
268  cereal::make_nvp("lambda",lambda_),
269  cereal::make_nvp("means",means_),
270  cereal::make_nvp("covs",covs_),
271  cereal::make_nvp("icovs",icovs_),
272  cereal::make_nvp("unique",unique_),
273  cereal::make_nvp("class",class_),
274  cereal::make_nvp("priors",priors_));
275  }

◆ Train() [1/2]

template<class T >
void algo::classification::QDC< T >::Train ( const arma::Mat< T > &  inputs,
const arma::Row< size_t > &  labels 
)
Parameters
inputs: X
labels: y

Definition at line 184 of file paramclass_impl.h.

186 {
187 
188  dim_ = inputs.n_rows;
189  size_ = inputs.n_cols;
190  unique_ = arma::unique(labels);
191 
192  class_ = arma::regspace<arma::Row<size_t>>(0,1,num_class_);
193  priors_ = get_prior<T>(labels,num_class_);
194 
195  arma::Row<size_t>::iterator it = unique_.begin();
196  arma::Row<size_t>::iterator end = unique_.end();
197 
198  arma::Mat<T> inx;
199 
200  for(; it!=end; it++)
201  {
202  auto extract = extract_class(inputs, labels, *it);
203 
204  inx = std::get<0>(extract);
205  means_[*it] = arma::conv_to<arma::Row<T>>::from(arma::mean(inx,1));
206  if ( inx.n_cols == 1 )
207  {
208  covs_[*it] = arma::eye<arma::Mat<T>>(dim_,dim_);
209  icovs_[*it] = arma::eye<arma::Mat<T>>(dim_,dim_);
210  }
211  else
212  {
213  covs_[*it] = arma::cov(inx.t());
214  covs_[*it].diag() += jitter_+lambda_;
215  icovs_[*it] = arma::pinv(covs_[*it]);
216  }
217  icovs_[*it] = arma::pinv(covs_[*it]);
218  }
219 }

Referenced by algo::classification::QDC< T >::QDC().

+ Here is the caller graph for this function:

◆ Train() [2/2]

template<class T >
void algo::classification::QDC< T >::Train ( const arma::Mat< T > &  inputs,
const arma::Row< size_t > &  labels,
const size_t  num_class 
)
Parameters
inputs: X
labels: y
num_class: number of classes

Definition at line 222 of file paramclass_impl.h.

225 {
226  this -> num_class_=num_class;
227  this -> Train(inputs,labels);
228 }

The documentation for this class was generated from the following files: