// SetDiff: collects the elements of `check` that are passed over while
// walking `check` and `with` side by side. Both containers are asserted to be
// sorted. NOTE(review): the declarations of `result`, `i`, `j`, the index
// increments and the return statement are not visible in this excerpt.
23 T
SetDiff(
const T& check,
const T& with )
// Precondition: both inputs must report is_sorted().
25 assert ( check.is_sorted() && with.is_sorted() &&
26 "For this method I assumed you sorted your stuff...");
// Merge-style walk over both containers while neither is exhausted.
31 while (i < check.n_elem && j < with.n_elem)
// check[i] is smaller than the current `with` element: keep it.
33 if (check[i] < with[j])
// The result grows by exactly one element per kept value.
35 result.resize(result.n_elem+1);
36 result[result.n_elem-1] = check[i];
// NOTE(review): the body of this branch (and of the equal case) is not
// visible here; presumably it advances `j` -- confirm against the full file.
39 else if (check[i] > with[j])
// `with` is exhausted (or the walk ended): every remaining element of
// `check` is appended to the result.
49 while (i < check.n_elem)
51 result.resize(result.n_elem+1);
52 result[result.n_elem-1] = check[i];
// Migrate (row-label variant): moves N randomly chosen columns, together
// with their labels, from the test containers to the end of the train
// containers. NOTE(review): the signature line carrying the function name,
// `train_inp` and the `N` parameter is not visible in this excerpt.
69 template<
typename T,
typename U>
71 arma::Row<U>& train_lab,
72 arma::Mat<T>& test_inp,
73 arma::Row<U>& test_lab,
// Shape checks: inputs and labels agree in count, train and test share the
// same number of rows, and at least N test elements are available.
76 assert ( ( train_inp.n_cols == train_lab.n_elem &&
77 train_inp.n_rows == test_inp.n_rows &&
78 test_inp.n_cols == test_lab.n_elem &&
79 test_lab.n_elem >= N) &&
80 "Requested element number is bigger than what you have.");
// Make room for N extra columns / labels at the tail of the train set.
82 train_inp.resize(train_inp.n_rows, train_inp.n_cols+N);
83 train_lab.resize(train_lab.n_cols+N);
// Draw N column indices from a random permutation of the test columns.
84 arma::uvec idx = arma::randperm(test_inp.n_cols, N);
// Copy the selected test columns into the newly added train columns.
85 train_inp.tail_cols(N) = test_inp.cols(idx);
86 train_lab.tail_cols(N) = test_lab.cols(idx);
// Remove the migrated columns from the test set.
87 test_lab.shed_cols(idx);
88 test_inp.shed_cols(idx);
// Migrate (matrix-label variant): moves N randomly chosen columns, together
// with the matching label columns, from the test containers to the end of
// the train containers. NOTE(review): the line declaring `N` is not visible
// in this excerpt.
91 template<
typename T,
typename U>
92 void Migrate ( arma::Mat<T>& train_inp,
93 arma::Mat<U>& train_lab,
94 arma::Mat<T>& test_inp,
95 arma::Mat<U>& test_lab,
// Shape checks before touching anything.
// NOTE(review): labels are a matrix here, so `test_lab.n_elem` is
// n_rows * n_cols; comparing it with `test_inp.n_cols` and with N looks like
// it only holds for single-row labels -- `n_cols` may have been intended.
98 assert ( ( test_inp.n_cols == test_lab.n_elem &&
99 train_inp.n_rows == test_inp.n_rows &&
100 test_lab.n_rows == train_lab.n_rows &&
101 test_lab.n_elem >= N ) &&
102 "Requested element number is bigger than what you have.");
// Make room for N extra columns at the tail of the train containers.
104 train_inp.resize(train_inp.n_rows, train_inp.n_cols+N);
105 train_lab.resize(train_lab.n_rows, train_lab.n_cols+N);
// Draw N column indices from a random permutation of the test columns.
106 arma::uvec idx = arma::randperm(test_inp.n_cols, N);
// Copy the selected test columns into the newly added train columns.
107 train_inp.tail_cols(N) = test_inp.cols(idx);
108 train_lab.tail_cols(N) = test_lab.cols(idx);
// Remove the migrated columns from the test set.
109 test_lab.shed_cols(idx);
110 test_inp.shed_cols(idx);
// Dataset-level migrate (header not visible in this excerpt): forwards the
// inputs_ / labels_ members of both datasets to the container-level Migrate.
122 Migrate(trainset.inputs_,trainset.labels_,testset.inputs_,testset.labels_,N);
// Refresh the cached sizes from the new column counts.
124 trainset.size_ = trainset.inputs_.n_cols;
125 testset.size_ = testset.inputs_.n_cols;
// Refresh the cached class counts from the distinct label values.
130 trainset.num_class_ = arma::unique(trainset.labels_).eval().n_cols;
131 testset.num_class_ = arma::unique(testset.labels_).eval().n_cols;
// Migrate (column-vector variant): moves N randomly chosen elements from
// `testset` to the end of `trainset`. NOTE(review): the line declaring `N`
// is not visible in this excerpt.
135 template<
typename T=arma::uword>
136 void Migrate ( arma::Col<T>& trainset,
137 arma::Col<T>& testset,
// At least N elements must be available in the test vector.
140 assert ( testset.n_elem >= N &&
141 "Requested element number is bigger than what you have.");
// Grow the train vector by N slots, then fill them with the selection.
143 trainset.resize(trainset.n_elem+N);
144 arma::uvec idx = arma::randperm(testset.n_elem, N);
145 trainset.tail(N) = testset.rows(idx);
// Remove the migrated elements from the test vector.
146 testset.shed_rows(idx);
161 template<
typename T,
typename U>
162 void Split (
const arma::Mat<T>& input,
163 const arma::Row<U>& inputLabel,
164 arma::Mat<T>& trainData,
165 arma::Mat<T>& testData,
166 arma::Row<U>& trainLabel,
167 arma::Row<U>& testLabel,
168 const size_t trainNum )
170 const arma::uvec order =
171 arma::shuffle(arma::regspace<arma::uvec>(0, input.n_cols - 1));
173 trainData = input.cols(order.rows(0,trainNum-1));
174 trainLabel = inputLabel.cols(order.rows(0,trainNum-1));
176 testData = input.cols(order.rows(trainNum,input.n_cols-1));
177 testLabel = inputLabel.cols(order.rows(trainNum,input.n_cols-1));
181 template<
typename T,
typename U>
182 void Split (
const arma::Mat<T>& input,
183 const arma::Mat<U>& inputLabel,
184 arma::Mat<T>& trainData,
185 arma::Mat<T>& testData,
186 arma::Mat<U>& trainLabel,
187 arma::Mat<U>& testLabel,
188 const size_t trainNum )
191 const arma::uvec order =
192 arma::shuffle(arma::regspace<arma::uvec>(0, input.n_cols - 1));
194 trainData = input.cols(order.rows(0,trainNum-1));
195 trainLabel = inputLabel.cols(order.rows(0,trainNum-1));
197 testData = input.cols(order.rows(trainNum,input.n_cols-1));
198 testLabel = inputLabel.cols(order.rows(trainNum,input.n_cols-1));
// Split (matrix, no labels): shuffles the column indices of `input` and
// assigns the first `trainNum` shuffled columns to `trainData` and the
// remaining columns to `testData`.
// NOTE(review): the template header for `T` is not visible in this excerpt.
208 void Split (
const arma::Mat<T>& input,
209 arma::Mat<T>& trainData,
210 arma::Mat<T>& testData,
211 const size_t trainNum )
// Random ordering of the indices 0 .. n_cols-1.
213 const arma::uvec order =
214 arma::shuffle(arma::regspace<arma::uvec>(0, input.n_cols - 1));
// head/tail take counts: trainNum indices for train, the rest for test.
216 trainData = input.cols( order.head(trainNum) );
217 testData = input.cols( order.tail(input.n_cols-trainNum) );
// Split (row vector, no labels): shuffles the element indices of `input`
// and assigns the first `trainNum` shuffled elements to `trainData` and the
// remaining elements to `testData`.
// NOTE(review): the template header for `T` is not visible in this excerpt.
222 void Split (
const arma::Row<T>& input,
223 arma::Row<T>& trainData,
224 arma::Row<T>& testData,
225 const size_t trainNum )
// Random ordering of the indices 0 .. n_cols-1.
227 const arma::uvec order =
228 arma::shuffle(arma::regspace<arma::uvec>(0, input.n_cols - 1));
// head/tail take counts: trainNum indices for train, the rest for test.
230 trainData = input.cols( order.head(trainNum) );
231 testData = input.cols( order.tail(input.n_cols-trainNum) );
// Split (column vector, no labels): shuffles the element indices of `input`
// and assigns the first `trainNum` shuffled elements to `trainData` and the
// remaining elements to `testData`.
// NOTE(review): the template header for `T` is not visible in this excerpt.
236 void Split (
const arma::Col<T>& input,
237 arma::Col<T>& trainData,
238 arma::Col<T>& testData,
239 const size_t trainNum )
// Random ordering of the indices 0 .. n_rows-1.
241 const arma::uvec order =
242 arma::shuffle(arma::regspace<arma::uvec>(0, input.n_rows- 1));
// head/tail take counts: trainNum indices for train, the rest for test.
244 trainData = input.rows( order.head(trainNum) );
245 testData = input.rows( order.tail(input.n_rows-trainNum) );
// Tuple-returning Split (row labels): runs the output-parameter Split on
// local containers and returns them as a tuple typed
// (Mat<T>, Mat<T>, Row<U>, Row<U>).
// NOTE(review): the line with the function name and `input` parameter, the
// end of the Split call and one make_tuple argument are not visible here.
255 template<
typename T,
typename U>
256 std::tuple<arma::Mat<T>, arma::Mat<T>, arma::Row<U>, arma::Row<U>>
258 const arma::Row<U>& inputLabel,
259 const size_t trainNum)
// Local outputs that the output-parameter overload fills in.
261 arma::Mat<T> trainData;
262 arma::Mat<T> testData;
263 arma::Row<U> trainLabel;
264 arma::Row<U> testLabel;
266 Split(input, inputLabel, trainData, testData, trainLabel, testLabel,
// The locals are moved into the returned tuple.
269 return std::make_tuple(std::move(trainData),
271 std::move(trainLabel),
272 std::move(testLabel));
// Tuple-returning Split (matrix labels): runs the output-parameter Split on
// local containers and returns them as a tuple typed
// (Mat<T>, Mat<T>, Mat<U>, Mat<U>).
// NOTE(review): the end of the Split call and one make_tuple argument are
// not visible in this excerpt.
275 template<
typename T,
typename U>
276 std::tuple<arma::Mat<T>, arma::Mat<T>, arma::Mat<U>, arma::Mat<U>>
277 Split (
const arma::Mat<T>& input,
278 const arma::Mat<U>& inputLabel,
279 const size_t trainNum )
// Local outputs that the output-parameter overload fills in.
281 arma::Mat<T> trainData;
282 arma::Mat<T> testData;
283 arma::Mat<U> trainLabel;
284 arma::Mat<U> testLabel;
286 Split(input, inputLabel, trainData, testData, trainLabel, testLabel,
// The locals are moved into the returned tuple.
289 return std::make_tuple(std::move(trainData),
291 std::move(trainLabel),
292 std::move(testLabel));
// Tuple-returning Split (data only): runs the output-parameter Split on
// local matrices and returns (trainData, testData).
// NOTE(review): the line with the function name and `input` parameter is
// not visible in this excerpt.
301 std::tuple<arma::Mat<T>, arma::Mat<T>>
303 const size_t trainNum)
// Local outputs that the output-parameter overload fills in.
305 arma::Mat<T> trainData;
306 arma::Mat<T> testData;
307 Split(input, trainData, testData, trainNum);
// The locals are moved into the returned tuple.
309 return std::make_tuple(std::move(trainData),
310 std::move(testData));
// Dataset-level Split by train count (header not visible in this excerpt).
323 const size_t trainNum )
// Both outputs start as copies of the source dataset; their inputs_ and
// labels_ are then overwritten by the container-level Split below.
326 trainset = dataset; testset = dataset;
328 Split(dataset.inputs_, dataset.labels_,
329 trainset.inputs_, testset.inputs_,
330 trainset.labels_, testset.labels_, trainNum);
// Update() is called with each dataset's own new contents; presumably it
// refreshes derived members such as sizes -- confirm against Dataset.
332 trainset.Update(trainset.inputs_,trainset.labels_);
333 testset.Update(testset.inputs_,testset.labels_);
// Dataset-level Split by test ratio (name and parameter list not visible in
// this excerpt); delegates the actual split to mlpack::data::Split.
343 template<
typename T,
class O=DTYPE>
// Both outputs start as copies of the source dataset; their inputs_ and
// labels_ are then overwritten by the split below.
350 trainset = dataset; testset = dataset;
352 mlpack::data::Split(dataset.inputs_, dataset.labels_,
353 trainset.inputs_, testset.inputs_,
354 trainset.labels_, testset.labels_, testRatio);
// Update() is called with each dataset's own new contents; presumably it
// refreshes derived members such as sizes -- confirm against Dataset.
356 trainset.Update(trainset.inputs_,trainset.labels_);
357 testset.Update(testset.inputs_,testset.labels_);
// StratifiedSplit (by train count): splits `input` / `inputLabel` into
// train and test parts so that each label value contributes
// floor(count * testRatio + 1e-6) columns to the test part, where
// testRatio = 1 - trainNum / inputLabel.n_elem.
// Labels are used directly as indices into the per-label count vectors.
// NOTE(review): the function name line, the declarations of `testSize`,
// `testIdx`, `trainIdx`, the index increments and the branch structure
// (including how `shuffleData` selects between the two loops) are not
// visible in this excerpt.
375 template<
typename T,
typename LabelsType,
376 typename = std::enable_if_t<arma::is_arma_type<LabelsType>::value> >
378 const LabelsType& inputLabel,
379 arma::Mat<T>& trainData,
380 arma::Mat<T>& testData,
381 LabelsType& trainLabel,
382 LabelsType& testLabel,
383 const size_t trainNum,
384 const bool shuffleData =
true)
// typeCheck is true when the labels are an arma::Row or an arma::Col.
386 const bool typeCheck = (arma::is_Row<LabelsType>::value)
387 || (arma::is_Col<LabelsType>::value);
// NOTE(review): the condition guarding this throw is not visible; the
// message mentions only `arma::Row<>` although typeCheck also accepts Col.
389 throw std::runtime_error(
"data::Split(): when stratified sampling is done, "
390 "labels must have type `arma::Row<>`!");
391 mlpack::util::CheckSameSizes(input, inputLabel,
"data::Split()");
// Fraction of the elements that is left for the test part.
393 double testRatio = double(1) - double(trainNum)/double(inputLabel.n_elem);
396 size_t trainSize = 0;
// Per-label totals and per-label counts already assigned to the test part;
// both are indexed by label value, hence maxLabel+1 entries.
398 arma::uvec labelCounts;
399 arma::uvec testLabelCounts;
400 typename LabelsType::elem_type maxLabel = inputLabel.max();
402 labelCounts.zeros(maxLabel+1);
403 testLabelCounts.zeros(maxLabel+1);
// Count how often each label occurs.
405 for (
typename LabelsType::elem_type label : inputLabel)
406 ++labelCounts[label];
// Derive the overall output sizes from the per-label quotas. The 1e-6 is
// presumably a guard against products landing just below an integer.
408 for (arma::uword labelCount : labelCounts)
410 testSize += floor(labelCount * testRatio+1e-6);
411 trainSize += labelCount - floor(labelCount * testRatio+1e-6);
// Allocate the outputs with their final sizes.
414 trainData.set_size(input.n_rows, trainSize);
415 testData.set_size(input.n_rows, testSize);
416 trainLabel.set_size(inputLabel.n_rows, trainSize);
417 testLabel.set_size(inputLabel.n_rows, testSize);
// Pass over the columns in a random order.
421 arma::uvec order = arma::shuffle(
422 arma::linspace<arma::uvec>(0, input.n_cols - 1, input.n_cols));
424 for (arma::uword i : order)
426 typename LabelsType::elem_type label = inputLabel[i];
// While this label's test quota is not yet filled, the column goes to the
// test part.
427 if (testLabelCounts[label] < floor(labelCounts[label] * testRatio+1e-6))
429 testLabelCounts[label] += 1;
430 testData.col(testIdx) = input.col(i);
431 testLabel[testIdx] = inputLabel[i];
// Otherwise (branch keyword not visible) the column goes to the train part.
436 trainData.col(trainIdx) = input.col(i);
437 trainLabel[trainIdx] = inputLabel[i];
// Same assignment logic, but visiting the columns in their original order.
444 for (arma::uword i = 0; i < input.n_cols; i++)
446 typename LabelsType::elem_type label = inputLabel[i];
447 if (testLabelCounts[label] < floor(labelCounts[label] * testRatio+1e-6))
449 testLabelCounts[label] += 1;
450 testData.col(testIdx) = input.col(i);
451 testLabel[testIdx] = inputLabel[i];
456 trainData.col(trainIdx) = input.col(i);
457 trainLabel[trainIdx] = inputLabel[i];
// StratifiedSplit (by test ratio): splits `input` / `inputLabel` into train
// and test parts so that each label value contributes
// floor(count * testRatio + 1e-6) columns to the test part.
// Labels are used directly as indices into the per-label count vectors.
// NOTE(review): the function name line, the declarations of `testSize`,
// `testIdx`, `trainIdx`, the index increments and the branch structure
// (including how `shuffleData` selects between the two loops) are not
// visible in this excerpt.
478 template<
typename T,
typename LabelsType,
479 typename = std::enable_if_t<arma::is_arma_type<LabelsType>::value> >
481 const LabelsType& inputLabel,
482 arma::Mat<T>& trainData,
483 arma::Mat<T>& testData,
484 LabelsType& trainLabel,
485 LabelsType& testLabel,
486 const double testRatio,
487 const bool shuffleData =
true)
// typeCheck is true when the labels are an arma::Row or an arma::Col.
489 const bool typeCheck = (arma::is_Row<LabelsType>::value)
490 || (arma::is_Col<LabelsType>::value);
// NOTE(review): the condition guarding this throw is not visible; the
// message mentions only `arma::Row<>` although typeCheck also accepts Col.
492 throw std::runtime_error(
"data::Split(): when stratified sampling is done, "
493 "labels must have type `arma::Row<>`!");
494 mlpack::util::CheckSameSizes(input, inputLabel,
"data::Split()");
498 size_t trainSize = 0;
// Per-label totals and per-label counts already assigned to the test part;
// both are indexed by label value, hence maxLabel+1 entries.
500 arma::uvec labelCounts;
501 arma::uvec testLabelCounts;
502 typename LabelsType::elem_type maxLabel = inputLabel.max();
504 labelCounts.zeros(maxLabel+1);
505 testLabelCounts.zeros(maxLabel+1);
// Count how often each label occurs.
507 for (
typename LabelsType::elem_type label : inputLabel)
508 ++labelCounts[label];
// Derive the overall output sizes from the per-label quotas. The 1e-6 is
// presumably a guard against products landing just below an integer.
510 for (arma::uword labelCount : labelCounts)
512 testSize += floor(labelCount * testRatio+1e-6);
513 trainSize += labelCount - floor(labelCount * testRatio+1e-6);
// Allocate the outputs with their final sizes.
516 trainData.set_size(input.n_rows, trainSize);
517 testData.set_size(input.n_rows, testSize);
518 trainLabel.set_size(inputLabel.n_rows, trainSize);
519 testLabel.set_size(inputLabel.n_rows, testSize);
// Pass over the columns in a random order.
523 arma::uvec order = arma::shuffle(
524 arma::linspace<arma::uvec>(0, input.n_cols - 1, input.n_cols));
526 for (arma::uword i : order)
528 typename LabelsType::elem_type label = inputLabel[i];
// While this label's test quota is not yet filled, the column goes to the
// test part.
529 if (testLabelCounts[label] < floor(labelCounts[label] * testRatio+1e-6))
531 testLabelCounts[label] += 1;
532 testData.col(testIdx) = input.col(i);
533 testLabel[testIdx] = inputLabel[i];
// Otherwise (branch keyword not visible) the column goes to the train part.
538 trainData.col(trainIdx) = input.col(i);
539 trainLabel[trainIdx] = inputLabel[i];
// Same assignment logic, but visiting the columns in their original order.
546 for (arma::uword i = 0; i < input.n_cols; i++)
548 typename LabelsType::elem_type label = inputLabel[i];
549 if (testLabelCounts[label] < floor(labelCounts[label] * testRatio+1e-6))
551 testLabelCounts[label] += 1;
552 testData.col(testIdx) = input.col(i);
553 testLabel[testIdx] = inputLabel[i];
558 trainData.col(trainIdx) = input.col(i);
559 trainLabel[trainIdx] = inputLabel[i];
// Dataset-level StratifiedSplit by train count (header not visible in this
// excerpt).
578 const size_t trainNum )
// Runtime type check on the dataset type.
// NOTE(review): one line of this condition is not visible. As shown, the
// typeid comparison is joined with `||`; if the message literal ends up as
// an `||` operand the assert can never fire -- confirm against the full file.
580 assert ( (
typeid(T) ==
typeid(
Dataset<arma::Row<size_t>>) ||
582 "StratifiedSplit can only be used for classification dataset type...");
// Both outputs start as copies of the source dataset; their inputs_ and
// labels_ are then overwritten by the split call whose first line is not
// visible here.
584 trainset = dataset; testset = dataset;
587 trainset.inputs_, testset.inputs_,
588 trainset.labels_, testset.labels_, trainNum);
// Update() is called with each dataset's own new contents; presumably it
// refreshes derived members such as sizes -- confirm against Dataset.
590 trainset.Update(trainset.inputs_,trainset.labels_);
591 testset.Update(testset.inputs_,testset.labels_);
// Dataset-level StratifiedSplit by test ratio (header not visible in this
// excerpt); delegates to mlpack::data::StratifiedSplit.
606 const double testRatio )
// Runtime type check on the dataset type.
// NOTE(review): one line of this condition is not visible. As shown, the
// typeid comparison is joined with `||`; if the message literal ends up as
// an `||` operand the assert can never fire -- confirm against the full file.
609 assert ( (
typeid(T) ==
typeid(
Dataset<arma::Row<size_t>>) ||
611 "StratifiedSplit can only be used for classification dataset type...");
// Both outputs start as copies of the source dataset; their inputs_ and
// labels_ are then overwritten by the split below.
613 trainset = dataset; testset = dataset;
615 mlpack::data::StratifiedSplit(dataset.inputs_, dataset.labels_,
616 trainset.inputs_, testset.inputs_,
617 trainset.labels_, testset.labels_, testRatio);
// Update() is called with each dataset's own new contents; presumably it
// refreshes derived members such as sizes -- confirm against Dataset.
619 trainset.Update(trainset.inputs_,trainset.labels_);
620 testset.Update(testset.inputs_,testset.labels_);
// Tuple-returning StratifiedSplit (row labels): runs the output-parameter
// StratifiedSplit on local containers and returns them as a tuple typed
// (Mat<T>, Mat<T>, Row<U>, Row<U>).
// NOTE(review): the end of the StratifiedSplit call and one make_tuple
// argument are not visible in this excerpt.
623 template<
typename T,
typename U>
624 std::tuple<arma::Mat<T>, arma::Mat<T>, arma::Row<U>, arma::Row<U>>
625 StratifiedSplit (
const arma::Mat<T>& input,
626 const arma::Row<U>& inputLabel,
627 const size_t trainNum)
// Local outputs that the output-parameter overload fills in.
629 arma::Mat<T> trainData;
630 arma::Mat<T> testData;
631 arma::Row<U> trainLabel;
632 arma::Row<U> testLabel;
634 StratifiedSplit(input, inputLabel, trainData, testData, trainLabel, testLabel,
// The locals are moved into the returned tuple.
637 return std::make_tuple(std::move(trainData),
639 std::move(trainLabel),
640 std::move(testLabel));
void Split(const arma::Mat< T > &input, const arma::Row< U > &inputLabel, arma::Mat< T > &trainData, arma::Mat< T > &testData, arma::Row< U > &trainLabel, arma::Row< U > &testLabel, const size_t trainNum)
void StratifiedSplit(const arma::Mat< T > &input, const LabelsType &inputLabel, arma::Mat< T > &trainData, arma::Mat< T > &testData, LabelsType &trainLabel, LabelsType &testLabel, const size_t trainNum, const bool shuffleData=true)
T SetDiff(const T &check, const T &with)
void Migrate(arma::Mat< T > &train_inp, arma::Row< U > &train_lab, arma::Mat< T > &test_inp, arma::Row< U > &test_lab, const size_t N)