|
| | OnevAll ()=default |
| |
| template<class... Args> |
| | OnevAll (const size_t &num_class, const Args &... args) |
| |
| template<class... Args> |
| | OnevAll (const arma::Mat< T > &inputs, const arma::Row< size_t > &labels, const size_t &num_class, const Args &... args) |
| |
| template<class... Args> |
| void | Train (const arma::Mat< T > &inputs, const arma::Row< size_t > &labels, const Args &... args) |
| |
| template<class... Args> |
| void | Train (const arma::Mat< T > &inputs, const arma::Row< size_t > &labels) |
| |
| void | Classify (const arma::Mat< T > &inputs, arma::Row< size_t > &preds) const |
| |
| void | Classify (const arma::Mat< T > &inputs, arma::Row< size_t > &preds, arma::Mat< T > &probs) const |
| |
| T | ComputeError (const arma::Mat< T > &inputs, const arma::Row< size_t > &labels) |
| |
| T | ComputeAccuracy (const arma::Mat< T > &inputs, const arma::Row< size_t > &labels) |
| |
| template<typename Archive > |
| void | serialize (Archive &ar, const unsigned int) |
| |
|
|
std::vector< MODEL > | models_ |
| |
template<class MODEL, class T = DTYPE>
class algo::classification::OnevAll< MODEL, T >
Definition at line 22 of file multiclass.h.
◆ OnevAll() [1/3]
template<class MODEL , class T = DTYPE>
◆ OnevAll() [2/3]
template<class MODEL , class T >
template<class... Args>
- Parameters
-
| args | : parameters for the model |
Definition at line 19 of file multiclass_impl.h.
22 models_.resize(num_class);
23 for(size_t i=0;i<num_class;i++)
24 models_[i] = MODEL(args...);
◆ OnevAll() [3/3]
template<class MODEL , class T >
template<class... Args>
- Parameters
-
| inputs | : X |
| labels | : y |
| args | : parameters for the model |
Definition at line 29 of file multiclass_impl.h.
34 unq_ = arma::unique(labels).eval();
38 unclass_ = unq_.n_elem;
39 models_.resize(unclass_);
41 for(size_t i=0;i<unclass_;i++)
42 models_[i] = MODEL(args...);
44 Train(inputs, labels, args...);
void Train(const arma::Mat< T > &inputs, const arma::Row< size_t > &labels, const Args &... args)
◆ Classify() [1/2]
template<class MODEL , class T >
- Parameters
-
Definition at line 89 of file multiclass_impl.h.
void Classify(const arma::Mat< T > &inputs, arma::Row< size_t > &preds) const
◆ Classify() [2/2]
template<class MODEL , class T >
void algo::classification::OnevAll< MODEL, T >::Classify(const arma::Mat< T > &inputs, arma::Row< size_t > &preds, arma::Mat< T > &probs) const
- Parameters
-
| inputs | : X* |
| preds | : y* |
| probs | : p* |
Definition at line 97 of file multiclass_impl.h.
101 probs.resize(nclass_,inputs.n_cols);
102 preds.resize(inputs.n_cols);
105 for(size_t i=0;i<unclass_;i++)
107 arma::Row<size_t> temp;
109 models_[i].Classify(inputs,temp,tprobs);
110 probs.row(unq_(i)) = tprobs.row(1);
112 preds = arma::conv_to<arma::Row<size_t>>::from(arma::index_max(arma::abs(probs),0));
116 arma::uvec cantdecide = arma::find(arma::sum(probs,0) == 0);
117 arma::uvec randomsel = arma::randi<arma::uvec>(cantdecide.n_elem,
118 arma::distr_param(0,nclass_-1));
119 probs.elem(randomsel+ cantdecide*nclass_).ones();
120 probs = probs.each_row() / arma::sum(probs,0);
◆ ComputeAccuracy()
template<class MODEL , class T >
- Parameters
-
Definition at line 138 of file multiclass_impl.h.
T ComputeError(const arma::Mat< T > &inputs, const arma::Row< size_t > &labels)
◆ ComputeError()
template<class MODEL , class T >
- Parameters
-
Definition at line 128 of file multiclass_impl.h.
131 arma::Row<size_t> predictions;
133 arma::Row<size_t> temp = predictions - labels;
134 return (arma::accu(temp != 0))/T(predictions.n_elem);
◆ serialize()
template<class MODEL , class T = DTYPE>
template<typename Archive >
Serialize the model.
Definition at line 100 of file multiclass.h.
103 ar ( cereal::make_nvp("nclass",nclass_),
104 cereal::make_nvp("unclass",unclass_),
105 cereal::make_nvp("unq",unq_),
106 cereal::make_nvp("models",models_),
107 cereal::make_nvp("oneclass",oneclass_));
◆ Train() [1/2]
template<class MODEL , class T >
template<class... Args>
- Parameters
-
| inputs | : X |
| labels | : y |
| args | : parameters for the model |
Definition at line 50 of file multiclass_impl.h.
53 unq_ = arma::unique(labels).eval();
57 unclass_ = unq_.n_elem;
60 for(size_t i=0;i<unclass_;i++)
62 auto binlabels = arma::conv_to<arma::Row<size_t>>::from(labels==unq_(i));
63 models_[i].Train(inputs, binlabels);
◆ Train() [2/2]
template<class MODEL , class T >
template<class... Args>
void algo::classification::OnevAll< MODEL, T >::Train(const arma::Mat< T > &inputs, const arma::Row< size_t > &labels, const Args &... args)
- Parameters
-
| inputs | : X |
| labels | : y |
| args | : parameters for the model |
Definition at line 72 of file multiclass_impl.h.
78 models_.resize(unq_.n_elem);
79 for(size_t i=0;i<unclass_;i++)
81 auto binlabels = arma::conv_to<arma::Row<size_t>>::from(labels==unq_(i));
82 MODEL model(inputs,binlabels,args...);
The documentation for this class was generated from the following files: