|
| | LDC () |
| |
| | LDC (const size_t &num_class, const double &lambda) |
| |
| | LDC (const arma::Mat< T > &inputs, const arma::Row< size_t > &labels, const size_t &num_class) |
| |
| | LDC (const arma::Mat< T > &inputs, const arma::Row< size_t > &labels, const size_t &num_class, const double &lambda) |
| |
| | LDC (const arma::Mat< T > &inputs, const arma::Row< size_t > &labels, const size_t &num_class, const double &lambda, const arma::Row< T > &priors) |
| |
| void | Train (const arma::Mat< T > &inputs, const arma::Row< size_t > &labels) |
| |
| void | Train (const arma::Mat< T > &inputs, const arma::Row< size_t > &labels, const size_t num_class) |
| |
| void | Classify (const arma::Mat< T > &inputs, arma::Row< size_t > &labels) const |
| |
| void | Classify (const arma::Mat< T > &inputs, arma::Row< size_t > &labels, arma::Mat< T > &scores) const |
| |
| T | ComputeError (const arma::Mat< T > &points, const arma::Row< size_t > &responses) const |
| |
| T | ComputeAccuracy (const arma::Mat< T > &points, const arma::Row< size_t > &responses) const |
| |
| template<typename Archive > |
| void | serialize (Archive &ar, const unsigned int) |
| |
|
|
size_t | dim_ |
| |
|
size_t | num_class_ |
| |
|
size_t | size_ |
| |
|
double | lambda_ |
| |
|
double | jitter_ = 1.e-8 |
| |
|
std::map< size_t, arma::Row< T > > | means_ |
| |
|
std::map< size_t, arma::Mat< T > > | covs_ |
| |
|
arma::Mat< T > | cov_ |
| |
|
arma::Mat< T > | mean_ |
| |
|
arma::Row< size_t > | unique_ |
| |
|
arma::Row< size_t > | class_ |
| |
|
arma::Row< T > | priors_ |
| |
template<class T = DTYPE>
class algo::classification::LDC< T >
Definition at line 19 of file paramclass.h.
◆ LDC() [1/5]
template<class T = DTYPE>
◆ LDC() [2/5]
template<class T = DTYPE>
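- Parameters
-
| num_class | : number of classes |
| lambda | : regularization value added (together with a small jitter) to the diagonal of each class covariance during training |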
Definition at line 31 of file paramclass.h.
31 : num_class_(num_class),
◆ LDC() [3/5]
◆ LDC() [4/5]
template<class T >
| algo::classification::LDC< T >::LDC |
( |
const arma::Mat< T > & |
inputs, |
|
|
const arma::Row< size_t > & |
labels, |
|
|
const size_t & |
num_class, |
|
|
const double & |
lambda |
|
) |
| |
◆ LDC() [5/5]
template<class T >
| algo::classification::LDC< T >::LDC |
( |
const arma::Mat< T > & |
inputs, |
|
|
const arma::Row< size_t > & |
labels, |
|
|
const size_t & |
num_class, |
|
|
const double & |
lambda, |
|
|
const arma::Row< T > & |
priors |
|
) |
| |
◆ Classify() [1/2]
- Parameters
-
| inputs | : X* |
| labels | : y* |
Definition at line 94 of file paramclass_impl.h.
void Classify(const arma::Mat< T > &inputs, arma::Row< size_t > &labels) const
◆ Classify() [2/2]
template<class T >
| void algo::classification::LDC< T >::Classify |
( |
const arma::Mat< T > & |
inputs, |
|
|
arma::Row< size_t > & |
labels, |
|
|
arma::Mat< T > & |
scores |
|
) |
| const |
- Parameters
-
| inputs | : X* |
| labels | : y* |
| scores | : scores* (named probs in the implementation listing below) |
Definition at line 102 of file paramclass_impl.h.
106 const size_t N = inputs.n_cols;
108 probs.resize(num_class_,N);
109 if ( unique_.n_elem == 1 )
111 labels.fill(unique_(0));
112 probs.row(unique_(0)).fill(1.);
116 #pragma omp parallel for
117 for ( size_t n=0; n<inputs.n_cols; n++ )
119 for ( size_t c=0; c<unique_.n_elem; c++ )
121 probs(class_(unique_(c)),n) = std::log(priors_(c))
122 - 0.5*arma::dot(means_.at(unique_(c))*
123 cov_, means_.at(unique_(c)))
124 + arma::dot(inputs.col(n).t()*cov_,means_.at(unique_(c)));
126 labels(n) = class_(probs.col(n).index_max());
129 probs = arma::exp(probs.each_row() - arma::max(probs,0));
130 probs = probs.each_row()/arma::sum(probs,0);
◆ ComputeAccuracy()
template<class T >
| T algo::classification::LDC< T >::ComputeAccuracy |
( |
const arma::Mat< T > & |
points, |
|
|
const arma::Row< size_t > & |
responses |
|
) |
| const |
Calculates the classification accuracy of the model on points, measured against the true labels in responses.
- Parameters
-
| points | : X |
| responses | : y (true labels) |
Definition at line 146 of file paramclass_impl.h.
T ComputeError(const arma::Mat< T > &points, const arma::Row< size_t > &responses) const
◆ ComputeError()
template<class T >
| T algo::classification::LDC< T >::ComputeError |
( |
const arma::Mat< T > & |
points, |
|
|
const arma::Row< size_t > & |
responses |
|
) |
| const |
Calculates the error rate: the fraction of points whose predicted label differs from the corresponding entry of responses.
- Parameters
-
| points | : X |
| responses | : y (true labels) |
Definition at line 136 of file paramclass_impl.h.
139 arma::Row<size_t> predictions;
141 arma::Row<size_t> temp = predictions - responses;
142 return (arma::accu(temp != 0))/T(predictions.n_elem);
◆ serialize()
template<class T = DTYPE>
template<typename Archive >
Serialize the model.
Definition at line 119 of file paramclass.h.
122 ar (cereal::make_nvp("dim",dim_),
123 cereal::make_nvp("num_class",num_class_),
124 cereal::make_nvp("size",size_),
125 cereal::make_nvp("lambda",lambda_),
126 cereal::make_nvp("means",means_),
127 cereal::make_nvp("covs",covs_),
128 cereal::make_nvp("mean",mean_),
129 cereal::make_nvp("cov",cov_),
130 cereal::make_nvp("class",class_),
131 cereal::make_nvp("unique",unique_),
132 cereal::make_nvp("priors",priors_));
◆ Train() [1/2]
- Parameters
-
| inputs | : X |
| labels | : y |
Definition at line 47 of file paramclass_impl.h.
50 class_ = arma::regspace<arma::Row<size_t>>(0,1,num_class_);
51 priors_ = get_prior<T>(labels, num_class_);
54 size_ = inputs.n_cols;
55 unique_ = arma::unique(labels);
56 if (unique_.n_elem != 1)
58 cov_.resize(dim_,dim_);
61 arma::Row<size_t>::iterator it = unique_.begin();
62 arma::Row<size_t>::iterator end = unique_.end();
68 auto extract = extract_class(inputs, labels, *it);
69 inx = std::get<0>(extract);
70 means_[*it] = arma::conv_to<arma::Row<T>>::from(arma::mean(inx,1));
71 if ( inx.n_cols == 1 )
72 covs_[*it] = arma::eye<arma::Mat<T>>(dim_,dim_);
75 covs_[*it] = arma::cov(inx.t());
76 covs_[*it].diag() += jitter_+lambda_;
80 cov_ = arma::pinv(cov_) / num_class_;
Referenced by algo::classification::LDC< T >::LDC().
◆ Train() [2/2]
template<class T >
| void algo::classification::LDC< T >::Train |
( |
const arma::Mat< T > & |
inputs, |
|
|
const arma::Row< size_t > & |
labels, |
|
|
const size_t |
num_class |
|
) |
| |
- Parameters
-
| inputs | : X |
| labels | : y |
| num_class | : number of classes |
Definition at line 85 of file paramclass_impl.h.
89 this->num_class_ = num_class;
90 this->Train(inputs,labels);
The documentation for this class was generated from the following files:
- paramclass.h
- paramclass_impl.h