|
| | QDC () |
| |
| | QDC (const size_t &num_class, const double &lambda) |
| |
| | QDC (const arma::Mat< T > &inputs, const arma::Row< size_t > &labels, const size_t &num_class, const double &lambda) |
| |
| | QDC (const arma::Mat< T > &inputs, const arma::Row< size_t > &labels, const size_t &num_class, const double &lambda, const arma::Row< T > &priors) |
| |
| | QDC (const arma::Mat< T > &inputs, const arma::Row< size_t > &labels, const size_t &num_class) |
| |
| void | Train (const arma::Mat< T > &inputs, const arma::Row< size_t > &labels) |
| |
| void | Train (const arma::Mat< T > &inputs, const arma::Row< size_t > &labels, const size_t num_class) |
| |
| void | Classify (const arma::Mat< T > &inputs, arma::Row< size_t > &labels) const |
| |
| void | Classify (const arma::Mat< T > &inputs, arma::Row< size_t > &labels, arma::Mat< T > &scores) const |
| |
| T | ComputeError (const arma::Mat< T > &points, const arma::Row< size_t > &responses) const |
| |
| T | ComputeAccuracy (const arma::Mat< T > &points, const arma::Row< size_t > &responses) const |
| |
| template<typename Archive > |
| void | serialize (Archive &ar, const unsigned int) |
| |
template<class T = DTYPE>
class algo::classification::QDC< T >
Definition at line 162 of file paramclass.h.
◆ QDC() [1/5]
template<class T = DTYPE>
◆ QDC() [2/5]
template<class T = DTYPE>
- Parameters
-
| num_class | : number of classes |
| lambda | : regularization term added to the diagonal of each class covariance matrix |
Definition at line 175 of file paramclass.h.
175 : num_class_(num_class),
176 lambda_(lambda) { } ;
◆ QDC() [3/5]
template<class T >
algo::classification::QDC< T >::QDC(const arma::Mat< T > &inputs, const arma::Row< size_t > &labels, const size_t &num_class, const double &lambda)
- Parameters
-
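| inputs | : X, the training samples (one sample per column) |
| labels | : y, the class label of each training sample |
| num_class | : number of classes |
| lambda | : regularization term added to the diagonal of each class covariance matrix |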
Definition at line 175 of file paramclass_impl.h.
178 : num_class_(num_class), lambda_(lambda)
180 Train(inputs, labels);
void Train(const arma::Mat< T > &inputs, const arma::Row< size_t > &labels)
References algo::classification::QDC< T >::Train().
◆ QDC() [4/5]
template<class T >
algo::classification::QDC< T >::QDC(const arma::Mat< T > &inputs, const arma::Row< size_t > &labels, const size_t &num_class, const double &lambda, const arma::Row< T > &priors)
◆ QDC() [5/5]
◆ Classify() [1/2]
- Parameters
-
Definition at line 231 of file paramclass_impl.h.
void Classify(const arma::Mat< T > &inputs, arma::Row< size_t > &labels) const
◆ Classify() [2/2]
template<class T >
void algo::classification::QDC< T >::Classify(const arma::Mat< T > &inputs, arma::Row< size_t > &labels, arma::Mat< T > &scores) const
- Parameters
-
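| inputs | : X*, the samples to classify (one sample per column) |
| labels | : y* (output), the predicted label of each sample |
| scores | : (output) per-class probabilities, a num_class x N matrix whose columns are normalized to sum to 1 (named `probs` in the implementation) |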
Definition at line 239 of file paramclass_impl.h.
243 const size_t N = inputs.n_cols;
245 probs.resize(num_class_,N);
247 if ( num_class_ == 1 )
249 labels.fill(unique_(0));
250 probs.row(unique_(0)).fill(1.);
257 for (
size_t n=0; n<inputs.n_cols; n++ )
259 for (
size_t c=0; c<unique_.n_elem; c++ )
261 norm = inputs.col(n).t() - means_.at(unique_(c));
262 probs(class_(unique_(c)),n) = std::log(priors_(c))
263 - 0.5*(arma::det(covs_.at(unique_(c)))+inputs.n_rows*std::log(2*arma::datum::pi))
264 - 0.5* arma::dot(norm*icovs_.at(unique_(c)),norm);
266 labels(n) = class_(probs.col(n).index_max());
269 probs = arma::exp(probs.each_row() - arma::max(probs,0));
270 probs = probs.each_row()/arma::sum(probs,0);
◆ ComputeAccuracy()
template<class T >
T algo::classification::QDC< T >::ComputeAccuracy(const arma::Mat< T > &points, const arma::Row< size_t > &responses) const
Calculate the classification accuracy of the model on `points`, measured against the reference labels `responses`.
- Parameters
-
Definition at line 285 of file paramclass_impl.h.
T ComputeError(const arma::Mat< T > &points, const arma::Row< size_t > &responses) const
◆ ComputeError()
template<class T >
T algo::classification::QDC< T >::ComputeError(const arma::Mat< T > &points, const arma::Row< size_t > &responses) const
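Calculate the error rate: the fraction of `points` whose predicted label differs from the corresponding entry of `responses`.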
- Parameters
-
Definition at line 275 of file paramclass_impl.h.
278 arma::Row<size_t> predictions;
280 arma::Row<size_t> temp = predictions - responses;
281 return (arma::accu(temp != 0))/T(predictions.n_elem);
◆ serialize()
template<class T = DTYPE>
template<typename Archive >
Serialize the model.
Definition at line 262 of file paramclass.h.
265 ar ( cereal::make_nvp(
"dim",dim_),
266 cereal::make_nvp(
"num_class",num_class_),
267 cereal::make_nvp(
"size",size_),
268 cereal::make_nvp(
"lambda",lambda_),
269 cereal::make_nvp(
"means",means_),
270 cereal::make_nvp(
"covs",covs_),
271 cereal::make_nvp(
"icovs",icovs_),
272 cereal::make_nvp(
"unique",unique_),
273 cereal::make_nvp(
"class",class_),
274 cereal::make_nvp(
"priors",priors_));
◆ Train() [1/2]
- Parameters
-
Definition at line 184 of file paramclass_impl.h.
188 dim_ = inputs.n_rows;
189 size_ = inputs.n_cols;
190 unique_ = arma::unique(labels);
192 class_ = arma::regspace<arma::Row<size_t>>(0,1,num_class_);
193 priors_ = get_prior<T>(labels,num_class_);
195 arma::Row<size_t>::iterator it = unique_.begin();
196 arma::Row<size_t>::iterator end = unique_.end();
202 auto extract = extract_class(inputs, labels, *it);
204 inx = std::get<0>(extract);
205 means_[*it] = arma::conv_to<arma::Row<T>>::from(arma::mean(inx,1));
206 if ( inx.n_cols == 1 )
208 covs_[*it] = arma::eye<arma::Mat<T>>(dim_,dim_);
209 icovs_[*it] = arma::eye<arma::Mat<T>>(dim_,dim_);
213 covs_[*it] = arma::cov(inx.t());
214 covs_[*it].diag() += jitter_+lambda_;
215 icovs_[*it] = arma::pinv(covs_[*it]);
217 icovs_[*it] = arma::pinv(covs_[*it]);
Referenced by algo::classification::QDC< T >::QDC().
◆ Train() [2/2]
template<class T >
void algo::classification::QDC< T >::Train(const arma::Mat< T > &inputs, const arma::Row< size_t > &labels, const size_t num_class)
- Parameters
-
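| inputs | : X, the training samples (one sample per column) |
| labels | : y, the class label of each training sample |
| num_class | : number of classes; stored before delegating to Train(inputs, labels) |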
Definition at line 222 of file paramclass_impl.h.
226 this -> num_class_=num_class;
227 this ->
Train(inputs,labels);
The documentation for this class was generated from the following files: