Learning Curve Plus Plus (LCPP)
sample.h
Go to the documentation of this file.
1 /**
2  * @file sample.h
3  * @author Ozgur Taylan Turan
4  *
5  * Samplers for datasets.
6  *
7  */
8 
9 #ifndef SAMPLE_H
10 #define SAMPLE_H
11 
12 namespace data {
13 //-----------------------------------------------------------------------------
14 // RandomSelect : Given a dataset, select N of them randomly without replacement
15 // and separate the rest.
16 //-----------------------------------------------------------------------------
17 /**
18  *
19  * @param dataset : dataset to be split
20  * @param size : size of the selection
21  * @param repeat : how many times repeat the process
22  */
23 template<class O=arma::uword>
25 {
26  template<class T>
27  std::pair<arma::Col<O>,arma::Col<O>> operator()
28  ( const arma::Col<O>& total, const size_t size, const T N )
29  {
30  arma::Col<O> a,b;
31  Split(total,a,b,N);
32  return std::pair<arma::Col<O>,arma::Col<O>>(a,b);
33  }
34 
35  void operator() ( const size_t size,
36  const arma::Row<size_t> Ns,
37  const size_t repeat,
38  std::vector<std::pair<arma::Col<O>,arma::Col<O>>>& collect,
39  const size_t seed = SEED )
40  {
41  arma::Col<O> total = arma::regspace<arma::Col<O>>(0,size-1);
42  mlpack::RandomSeed(seed);
43  O counter = 0;
44  collect.clear();
45  collect.resize(repeat*Ns.n_elem);
46  for (size_t j = 0; j < repeat; ++j)
47  for (size_t i = 0; i < Ns.n_elem; ++i)
48  collect.at(counter++) = (*this)(total, size, Ns[i]);
49  }
50 };
51 //-----------------------------------------------------------------------------
52 // Bootstrap : Given a dataset, select N of them randomly with replacement and
53 // separate the rest.
54 //-----------------------------------------------------------------------------
55 /**
56  * @param dataset : dataset to be split
57  * @param size : size of the selection
58  * @param repeat : how many times repeat the process
59  * @param collect : collection of your sets
60  * @param counter : for labeling your sets
61  */
62 template<class O=arma::uword>
63 struct Bootstrap
64 {
65  template<class T>
66  std::pair<arma::Col<O>,arma::Col<O>> operator()
67  ( const arma::Col<O> total, const size_t size, const T N )
68  {
69  auto sel = arma::randi<arma::uvec>(N, arma::distr_param(0,size-1));
70  sel = arma::sort(sel);
71  auto rest = data::SetDiff(total,sel);
72  return std::pair<arma::Col<O>,arma::Col<O>>(sel,rest);
73  }
74 
75  void operator() ( const size_t size,
76  const arma::Row<size_t> Ns,
77  const size_t repeat,
78  std::vector<std::pair<arma::Col<O>,arma::Col<O>>>& collect,
79  const size_t seed = SEED )
80  {
81  arma::Col<O> total = arma::regspace<arma::Col<O>>(0,size-1);
82  mlpack::RandomSeed(seed);
83  O counter = 0;
84  collect.clear();
85  collect.resize(repeat*Ns.n_elem);
86  for (size_t j = 0; j < repeat; ++j)
87  for (size_t i = 0; i < Ns.n_elem; ++i)
88  collect.at(counter++) = (*this)(total, size, Ns[i]);
89  }
90 };
91 //-----------------------------------------------------------------------------
92 // Additive : Given a dataset, select N of them randomly and keep on taking
93 // from the rest.
94 //-----------------------------------------------------------------------------
95 /**
96  * @param dataset : dataset to be split
97  * @param size : size of the selection
98  * @param repeat : how many times repeat the process
99  */
100 template<class O=arma::uword>
101 struct Additive
102 {
103  void operator() ( const size_t size,
104  const arma::Row<size_t> Ns,
105  const size_t repeat,
106  std::vector<std::pair<arma::Col<O>,arma::Col<O>>>& collect,
107  const size_t seed = SEED )
108  {
109  size_t counter = 0;
110  auto total = arma::regspace<arma::Col<O>>(0,size-1);
111  collect.clear();
112  collect.resize(repeat*Ns.n_elem);
113  mlpack::RandomSeed(seed);
114  for (size_t j = 0; j < repeat; ++j)
115  {
116  arma::Col<O> trainset,testset;
117  for (size_t i=0; i < Ns.n_elem; i++)
118  {
119  if (i == 0)
120  Split(total,trainset,testset,Ns[i]);
121  else
122  Migrate(trainset,testset,Ns[i]-Ns[i-1]);
123  collect.at(counter++) =
124  std::pair<arma::Col<O>,arma::Col<O>>(trainset,testset);
125  }
126  }
127  }
128 };
129 
130 } // namespace data
131 
132 #endif
void Split(const arma::Mat< T > &input, const arma::Row< U > &inputLabel, arma::Mat< T > &trainData, arma::Mat< T > &testData, arma::Row< U > &trainLabel, arma::Row< U > &testLabel, const size_t trainNum)
Definition: manip.h:162
T SetDiff(const T &check, const T &with)
Definition: manip.h:23
void Migrate(arma::Mat< T > &train_inp, arma::Row< U > &train_lab, arma::Mat< T > &test_inp, arma::Row< U > &test_lab, const size_t N)
Definition: manip.h:70