10 #ifndef DATASET_IMPL_H
11 #define DATASET_IMPL_H
// Constructor: records the requested dimension and random seed.
//   dim  - value stored in dimension_
//   seed - value stored in seed_
// The body is empty, so no inputs or labels are produced here.
18 template<
class LABEL,
class T>
19 Dataset<LABEL,T>::Dataset (
size_t dim,
size_t seed ) :
20 dimension_(dim),seed_(seed) { };
// Constructor overload whose first parameter is an input matrix.
// NOTE(review): the remaining parameters and the body are not visible in this
// view, so nothing further is documented here.
22 template<
class LABEL,
class T>
23 Dataset<LABEL,T>::Dataset (
const arma::Mat<T>& inputs,
// Linear: generates N samples with a noisy linear target.
//   N         - number of samples (columns of inputs_) to draw
//   noise_std - standard deviation given to distr_param for the additive
//               Gaussian noise on the labels
// Writes inputs_ (dimension_ x N, standard normal) and labels_ (1 x N).
34 template<
class LABEL,
class T>
35 void Dataset<LABEL,T>::Linear (
const size_t N,
const T noise_std )
// Seed the random number generator with the stored seed before drawing.
// NOTE(review): seed_.value() presumably requires seed_ to hold a value — confirm.
37 mlpack::RandomSeed(seed_.value());
38 inputs_ = arma::randn(dimension_,N);
// A row of ones times inputs_ sums the features of each sample; zero-mean
// Gaussian noise is then added per sample.
39 labels_ = arma::ones(1,dimension_) * inputs_ +
40 arma::randn(1,N,arma::distr_param(T(0),T(noise_std)));
// Sine: generates N samples with a noisy sinusoidal target.
//   N         - number of samples (columns of inputs_) to draw
//   noise_std - standard deviation given to distr_param for the additive
//               Gaussian noise on the labels
// Writes inputs_ (dimension_ x N, standard normal) and labels_ (1 x N).
46 template<
class LABEL,
class T>
47 void Dataset<LABEL,T>::Sine (
const size_t N,
const T noise_std )
// Seed the random number generator with the stored seed before drawing.
// NOTE(review): seed_.value() presumably requires seed_ to hold a value — confirm.
49 mlpack::RandomSeed(seed_.value());
50 inputs_ = arma::randn(dimension_,N);
// The label is the sine of the per-sample feature sum (ones * inputs_),
// plus zero-mean Gaussian noise.
51 labels_ = arma::sin(arma::ones(1,dimension_) * inputs_) +
52 arma::randn(1,N,arma::distr_param(T(0),T(noise_std)));
// Banana: builds a two-class, two-dimensional dataset from two arcs.
//   N     - number of samples drawn per class (inputs_ ends up 2 x 2N)
//   delta - NOTE(review): not referenced in the visible lines; confirm its role
// Labels: the first N samples get 0 (LABEL = arma::Row<size_t>) or -1
// (LABEL = arma::Row<int>); the last N samples get 1.
// NOTE(review): r and s are used below but their definitions are not visible
// in this view.
58 template<
class LABEL,
class T>
59 void Dataset<LABEL,T>::Banana (
const size_t N,
const T delta )
// Tell the user that this generator overrides the configured dimension.
62 WARNING(
"Dataset::Banana requires dimension to be 2,\
63 overwriting dimension_!");
// Seed the random number generator with the stored seed before drawing.
64 mlpack::RandomSeed(seed_.value());
67 arma::Mat<T> i1, i2, temp;
// First class: angles 0.125*pi + 1.25*pi*u with u uniform in [0, 1].
69 temp = 0.125*M_PI + 1.25*M_PI*
70 arma::randu<arma::Mat<T>>(1,N, arma::distr_param(0.,1.));
// Points at radius r on those angles (sin on row 0, cos on row 1), then a
// uniform offset scaled by s is added to both coordinates.
72 i1 = arma::join_cols(r*arma::sin(temp),r*arma::cos(temp));
73 i1 += s*arma::randu<arma::Mat<T>>(2,N, arma::distr_param(0.,1.));
// Second class: angles 0.375*pi - 1.25*pi*u, built the same way.
75 temp = 0.375*M_PI - 1.25*M_PI*
76 arma::randu<arma::Mat<T>>(1,N, arma::distr_param(0.,1.));
78 i2 = arma::join_cols(r*arma::sin(temp),r*arma::cos(temp));
79 i2 += s*arma::randu<arma::Mat<T>>(2,N, arma::distr_param(0.,1.));
// Columns 0..N-1 hold the first class, columns N..2N-1 the second.
84 inputs_ = arma::join_rows(i1,i2);
86 labels_.set_size(2*N);
// Pick the "negative" label value from the label container type.
// NOTE(review): no other LABEL type is handled in the visible lines.
87 if constexpr ( std::is_same<LABEL,arma::Row<size_t>>::value )
88 labels_.subvec(0, N-1).zeros();
89 else if constexpr ( std::is_same<LABEL,arma::Row<int>>::value )
90 labels_.subvec(0, N-1).fill(-1);
91 labels_.subvec(N, 2*N - 1).ones();
// Dipping: builds a two-class dataset of unit-norm points versus a Gaussian
// blob around the origin.
//   N         - number of samples drawn per class (inputs_ is dimension_ x 2N)
//   r         - NOTE(review): not referenced in the visible lines; confirm
//   noise_std - standard deviation of the Gaussian noise added to all inputs
// Labels: the first N samples get 0 (LABEL = arma::Row<size_t>) or -1
// (LABEL = arma::Row<int>); the last N samples get 1.
98 template<
class LABEL,
class T>
99 void Dataset<LABEL,T>::Dipping (
const size_t N,
const T r,
const T noise_std )
// First class: standard-normal columns, each divided by its Euclidean norm.
102 arma::Mat<T> x1(dimension_, N, arma::fill::randn);
103 x1.each_row() /= arma::sqrt(arma::sum(arma::pow(x1,2),0));
// Second class: multivariate normal with zero mean and covariance 0.1 * I.
108 arma::Mat<T> cov(dimension_, dimension_, arma::fill::eye);
109 arma::Col<T> mean(dimension_);
110 mean.zeros(); cov *= 0.1;
111 arma::Mat<T> x2 = arma::mvnrnd(mean, cov, N);
113 inputs_ = arma::join_rows(x1,x2);
// Perturb every sample of both classes with zero-mean Gaussian noise.
116 inputs_ += arma::randn<arma::Mat<T>>(dimension_,2*N,
117 arma::distr_param(0., noise_std));
119 labels_.set_size(2*N);
// Pick the "negative" label value from the label container type.
120 if constexpr ( std::is_same<LABEL,arma::Row<size_t>>::value )
121 labels_.subvec(0, N-1).zeros();
122 else if constexpr ( std::is_same<LABEL,arma::Row<int>>::value )
123 labels_.subvec(0, N-1).fill(-1);
124 labels_.subvec(N, 2*N - 1).ones();
// Refresh the cached size / dimension / class-count fields.
126 this->_update_info();
// Gaussian: draws N samples per class from unit-variance normals whose
// per-class mean is taken from `means`.
//   N     - number of samples per class
//   means - one scalar mean per class; means.n_elem is the class count
// inputs_ is dimension_ x (N * n_class); labels_ has N * n_class entries.
// NOTE(review): the visible lines only assign labels in [0, 2N); confirm how
// classes beyond the second (or a single class) are labelled.
132 template<
class LABEL,
class T>
133 void Dataset<LABEL,T>::Gaussian (
const size_t N,
134 const arma::Row<T>& means )
137 size_t n_class = means.n_elem;
139 inputs_.set_size(dimension_, N*n_class);
140 labels_.set_size(N*n_class);
// The first block of N labels is 0 for size_t labels and -1 for int labels.
141 if constexpr ( std::is_same<LABEL,arma::Row<size_t>>::value )
142 labels_.subvec(0, N-1).zeros();
143 else if constexpr ( std::is_same<LABEL,arma::Row<int>>::value )
// The -1 / +1 encoding is only accepted for exactly two classes.
145 assert ( means.n_elem == 2 &&
146 "To use Row<int> you need binary classification problem");
147 labels_.subvec(0, N-1).fill(-1);
150 for (
size_t i = 0; i < n_class ; ++i)
// Class i fills columns [i*N, (i+1)*N - 1]; the scalar means(i) is added to
// every entry of a standard-normal block.
152 inputs_.cols(i * N, (i + 1) * N - 1) = means(i)
153 + arma::randn<arma::Mat<T>>(dimension_, N);
156 labels_.subvec(N, 2*N - 1).ones();
// Refresh the cached size / dimension / class-count fields.
159 this->_update_info();
// _update_info: recomputes the cached metadata from the stored data.
// size_ is the number of columns of inputs_, dimension_ the number of rows.
// For arma::Row<size_t> labels, num_class_ is the count of distinct labels.
// NOTE(review): the body of the arma::Row<int> branch is not visible here.
164 template<
class LABEL,
class T>
165 void Dataset<LABEL,T>::_update_info( )
167 size_ = inputs_.n_cols;
168 dimension_ = inputs_.n_rows;
169 if constexpr ( std::is_same<LABEL,arma::Row<size_t>>::value )
170 num_class_ = arma::unique(labels_).eval().n_elem;
171 else if constexpr ( std::is_same<LABEL,arma::Row<int>>::value )
// Update: replaces the stored inputs and labels with copies of the arguments
// and then refreshes the cached metadata via _update_info().
//   inputs - new input matrix, copied into inputs_
//   labels - new label container, copied into labels_
177 template<
class LABEL,
class T>
178 void Dataset<LABEL,T>::Update (
const arma::Mat<T>& inputs,
179 const LABEL& labels )
181 this->inputs_ = inputs; this->labels_ = labels; this->_update_info();
// Constructor for a dataset identified by an OpenML id and backed by files
// under a local path.
//   id   - numeric dataset id; used to build the metadata URL and file names
//   path - NOTE(review): the member-initializer list is not visible, so how
//          `path` maps onto filepath_ / metapath_ cannot be confirmed here
// Creates the storage directories, derives "<id>.meta" and "<id>.arff",
// fetches metadata when the meta file is missing, and extracts the download
// URL from the stored metadata.
191 template<
class LTYPE,
class T>
192 Dataset<LTYPE,T>::Dataset(
const size_t&
id,
const std::filesystem::path& path ) :
195 std::filesystem::create_directories(filepath_);
196 std::filesystem::create_directories(metapath_);
198 meta_url_ =
"https://www.openml.org/api/v1/data/"
199 + std::to_string(
id);
201 metafile_ = metapath_/(std::to_string(
id)+
".meta");
202 file_ = filepath_ / (std::to_string(
id) +
".arff");
// Only hit the network for metadata when no cached copy exists.
204 if (!std::filesystem::exists(metafile_))
205 this->_fetchmetadata();
207 down_url_ = _getdownurl(_readmetadata());
// Data file missing.
// NOTE(review): only a metadata re-fetch is visible on this path; the download
// call itself is not visible in this view — confirm.
209 if (!std::filesystem::exists(file_))
212 this->_fetchmetadata();
// NOTE(review): presumably the branch taken when the file already exists.
217 WARNING(
"Dataset " << id_ <<
" is already present.");
// Update: replaces the stored inputs and labels with copies of the arguments
// and then refreshes the cached metadata via _update_info().
//   inputs - new input matrix, copied into inputs_
//   labels - new label row, copied into labels_
222 template<
class LTYPE,
class T>
223 void Dataset<LTYPE,T>::Update (
const arma::Mat<T>& inputs,
224 const arma::Row<LTYPE>& labels )
226 inputs_ = inputs; labels_ = labels; this->_update_info();
// Convenience constructor: delegates to the (id, path) constructor with
// DATASET_PATH / "openml" as the storage path.
//   id - numeric dataset id forwarded unchanged
229 template<
class LTYPE,
class T>
230 Dataset<LTYPE,T>::Dataset(
const size_t&
id ) :
231 Dataset( id,DATASET_PATH/
"openml") { }
// _download: fetches down_url_ with libcurl and writes the response into file_.
// Returns a bool.
// NOTE(review): the return statements, the declarations of `curl` / `res`, and
// the conditions guarding the error branches are not visible in this view.
// NOTE(review): no fclose(fp) is visible; confirm the file handle is closed on
// every path.
233 template<
class LTYPE,
class T>
234 bool Dataset<LTYPE,T>::_download( )
239 curl_global_init(CURL_GLOBAL_DEFAULT);
240 curl = curl_easy_init();
// Destination file, opened for binary writing.
243 FILE* fp = fopen(file_.c_str(),
"wb");
// Error path (presumably when fopen failed): report and release curl.
246 ERR(
"Could not open file for writing: " << file_);
247 curl_easy_cleanup(curl);
248 curl_global_cleanup();
253 LOG(down_url_.c_str());
254 curl_easy_setopt(curl, CURLOPT_URL, down_url_.c_str());
// No write callback is set in the visible lines, so fp is handed to libcurl's
// default writer.
255 curl_easy_setopt(curl, CURLOPT_WRITEDATA, fp);
// Follow HTTP redirects, but at most five of them.
257 curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);
259 curl_easy_setopt(curl, CURLOPT_MAXREDIRS, 5L);
261 res = curl_easy_perform(curl);
// Error path (presumably when res reports a failure): report and release curl.
266 ERR(
"curl_easy_perform() failed: "<< curl_easy_strerror(res));
268 curl_easy_cleanup(curl);
269 curl_global_cleanup();
// Normal path: release curl and log where the file was stored.
275 curl_easy_cleanup(curl);
276 curl_global_cleanup();
278 LOG(
"Dataset " << id_ <<
" downloaded to " << file_ <<
".");
// _fetchmetadata: requests meta_url_ with libcurl, collects the response body
// in a string, and writes that string to metafile_.
// Returns a std::string.
// NOTE(review): the return statement and the conditions guarding the error
// branches are not visible in this view.
282 template<
class LTYPE,
class T>
283 std::string Dataset<LTYPE,T>::_fetchmetadata()
// Accumulates the HTTP response body through utils::WriteCallback.
289 std::string readBuffer;
291 curl_global_init(CURL_GLOBAL_DEFAULT);
292 curl = curl_easy_init();
295 curl_easy_setopt(curl, CURLOPT_URL, meta_url_.c_str());
296 curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, utils::WriteCallback);
297 curl_easy_setopt(curl, CURLOPT_WRITEDATA, &readBuffer);
298 res = curl_easy_perform(curl);
// Error report (presumably when res signals a failure).
300 ERR(
"curl_easy_perform() failed: " << curl_easy_strerror(res));
301 curl_easy_cleanup(curl);
303 curl_global_cleanup();
// Persist the response so later runs can read it from disk.
309 std::ofstream outFile(metafile_);
310 if (outFile.is_open())
312 outFile << readBuffer;
316 ERR(
"Unable to open file for writing.");
319 ERR(
"Not Saving metadata.");
// _readmetadata: opens metafile_ and reads its whole content into a string
// stream. Reports an error when the file cannot be opened.
// Returns a std::string.
// NOTE(review): the return statements are not visible in this view; presumably
// the buffered file content is returned — confirm.
323 template<
class LTYPE,
class T>
324 std::string Dataset<LTYPE,T>::_readmetadata()
326 std::ifstream infile(metafile_);
327 if (!infile.is_open())
329 ERR(
"Unable to open file for reading: " << metafile_ );
// Read the complete file in one go.
332 std::stringstream buffer;
333 buffer << infile.rdbuf();
// _gettargetname: searches the metadata text for the content of the
// <oml:default_target_attribute> element.
//   metadata - metadata document as a string
// Returns a std::string.
// NOTE(review): the return statements are not visible; the warning text
// suggests "class" is used when no match is found — confirm.
338 template<
class LTYPE,
class T>
339 std::string Dataset<LTYPE,T>::_gettargetname(
const std::string& metadata )
// One capture group holding the element's text (non-greedy).
342 std::regex re(R
"(<oml:default_target_attribute>(.*?)</oml:default_target_attribute>)");
// match.size() > 1 means the capture group is present.
346 if (regex_search(metadata, match, re) && match.size() > 1)
352 WARNING(
"Cannot find the target name using 'class' instead!!!");
// _getdownurl: searches the metadata text for the content of the <oml:url>
// element.
//   metadata - metadata document as a string
// Returns a std::string.
// NOTE(review): the return statements are not visible in this view, including
// what is returned when no URL is found.
357 template<
class LTYPE,
class T>
358 std::string Dataset<LTYPE,T>::_getdownurl(
const std::string& metadata )
// One capture group holding the element's text (non-greedy).
361 std::regex re(R
"(<oml:url>(.*?)</oml:url>)");
// match.size() > 1 means the capture group is present.
366 if (regex_search(metadata, match, re) && match.size() > 1)
372 WARNING(
"Probably something went wrong with meta data fetch!!!");
// _findlabel: scans file_ line by line for attribute declarations and looks
// for the one whose name equals `targetname`.
//   targetname - attribute name to look for; also matched when the name in the
//                file is wrapped in single quotes
// Returns an int.
// NOTE(review): the counter and the return statements are not visible in this
// view; presumably the attribute's position is returned and a sentinel when it
// is not found — confirm.
377 template<
class LTYPE,
class T>
378 int Dataset<LTYPE,T>::_findlabel (
const std::string& targetname )
380 std::ifstream file(file_);
386 ERR(
"Error opening file: " << file_ );
390 while (std::getline(file, line))
// Drop leading blanks and tabs so the keyword test below starts at column 0.
392 line.erase(0, line.find_first_not_of(
" \t"));
// Only the all-upper and all-lower spellings of the keyword are recognised.
393 if (line.find(
"@ATTRIBUTE") == 0 || line.find(
"@attribute") == 0)
// The name starts right after the first space on the line...
395 size_t start = line.find(
' ') + 1;
// ...and ends at the next blank or tab.
397 size_t end = line.find_first_of(
" \t", start);
398 if (end == std::string::npos)
402 std::string name = line.substr(start, end - start);
404 if (name == targetname || name ==
"'"+targetname+
"'")
// _iscateg: heuristic test for whether a row looks categorical.
//   row - values to inspect
// Collects the distinct integer-valued entries and returns true when there
// are fewer distinct values than elements in the row.
// NOTE(review): the definition of `value` and the handling of non-integer
// entries are not visible in this view.
413 template<
class LTYPE,
class T>
414 bool Dataset<LTYPE,T>::_iscateg(
const arma::Row<T>& row)
416 std::set<int> distinctValues;
419 for (
size_t i = 0; i < row.n_elem; ++i)
// An entry counts as integer-valued when casting to int leaves it unchanged.
424 if (value ==
static_cast<int>(value))
425 distinctValues.insert(
static_cast<int>(value));
432 return distinctValues.size() < row.n_elem;
// _convcateg: maps arbitrary values to consecutive category indices.
//   row - values to encode
// Each distinct value receives the next free index (starting at 0) in order
// of first appearance; the returned row holds the index of every element.
// NOTE(review): the definition of `value` in the first loop is not visible;
// presumably it is row(i) — confirm.
435 template<
class LTYPE,
class T>
436 arma::Row<size_t> Dataset<LTYPE,T>::_convcateg(
const arma::Row<T>& row)
438 std::unordered_map<T, size_t> valueToIndex;
439 size_t categoryIndex = 0;
// Pass 1: assign an index to every value not seen before.
442 for (
size_t i = 0; i < row.n_elem; ++i)
446 if (valueToIndex.find(value) == valueToIndex.end())
447 valueToIndex[value] = categoryIndex++;
451 arma::Row<size_t> categoricalRow(row.n_elem);
// Pass 2: translate every element through the lookup table.
454 for (
size_t i = 0; i < row.n_elem; ++i)
455 categoricalRow(i) = valueToIndex[row(i)];
457 return categoricalRow;
// _procrow: converts a row of values into a row of size_t labels.
//   row - values to convert
// Two return paths are visible: an element-wise static_cast to size_t, and a
// re-encoding through _convcateg(row).
// NOTE(review): the condition choosing between the two paths is not visible in
// this view.
460 template<
class LTYPE,
class T>
461 arma::Row<size_t> Dataset<LTYPE,T>::_procrow(
const arma::Row<T>& row)
// Path 1: values are cast to size_t one by one.
468 arma::Row<size_t> categoricalRow(row.n_elem);
469 for (
size_t i = 0; i < row.n_elem; ++i)
470 categoricalRow(i) =
static_cast<size_t>(row(i));
471 return categoricalRow;
// Path 2: values are mapped to consecutive category indices.
475 return _convcateg(row);
// Save: serialises this object into `filename` as a cereal binary archive.
//   filename - path of the output file, opened in binary mode
// NOTE(review): the condition guarding the error message is not visible;
// presumably it fires when the file could not be opened — confirm.
478 template<
class LTYPE,
class T>
479 void Dataset<LTYPE,T>::Save(
const std::string& filename )
481 std::ofstream file(filename, std::ios::binary);
483 ERR(
"\rCannot open file for writing: " << filename << std::flush);
// The object is stored under the name "Dataset".
485 cereal::BinaryOutputArchive archive(file);
486 archive(cereal::make_nvp(
"Dataset", *
this));
487 LOG(
"\rDataset object saved to " << filename << std::flush);
// Load: reads a Dataset from a cereal binary archive stored in `filename`.
//   filename - path of the input file, opened in binary mode
// A default-constructed Dataset is created through std::make_shared and then
// filled from the archive entry named "Dataset".
// Returns a std::shared_ptr to a Dataset.
// NOTE(review): the return statements and the condition guarding the error
// message are not visible in this view.
491 template<
class LTYPE,
class T>
492 std::shared_ptr<Dataset<LTYPE,T>> Dataset<LTYPE,T>::Load
493 (
const std::string& filename )
495 std::ifstream file(filename, std::ios::binary);
498 ERR(
"\rError: Cannot open file for reading: " << filename);
501 cereal::BinaryInputArchive archive(file);
502 auto dataset = std::make_shared<Dataset<LTYPE,T>>();
503 archive(cereal::make_nvp(
"Dataset", *dataset));
504 LOG(
"\rDataset loaded from " << filename);
// _load: reads file_ with mlpack, locates the label row named by the stored
// metadata, fills labels_ from it, and refreshes the cached metadata.
// Throws std::runtime_error("Cannot find the label!").
// NOTE(review): the condition for the throw, the declaration of `idx`, and the
// assignment of inputs_ are not visible in this view.
508 template<
class LTYPE,
class T>
509 void Dataset<LTYPE,T>::_load( )
512 arma::Mat<DTYPE> data;
513 mlpack::data::DatasetInfo info;
514 mlpack::data::Load(file_.c_str(), data, info);
// Resolve the target attribute name from the metadata, then its row index.
515 idx =_findlabel(_gettargetname(_readmetadata()));
518 throw std::runtime_error(
"Cannot find the label!");
// size_t labels go through _procrow; any other label type takes the raw row.
520 if constexpr (std::is_same<LTYPE, size_t>::value)
521 labels_ = _procrow(data.row(idx));
523 labels_ = data.row(idx);
527 this->_update_info();
// _update_info: recomputes the cached metadata from the stored data.
// dimension_ is the number of rows of inputs_, size_ the number of columns.
// When LTYPE is size_t, num_class_ is the count of distinct labels.
530 template<
class LTYPE,
class T>
531 void Dataset<LTYPE,T>::_update_info( )
533 dimension_ = inputs_.n_rows;
534 size_ = inputs_.n_cols;
// Ordinary `if` on a compile-time constant (not `if constexpr`).
535 if (std::is_same<LTYPE,size_t>::value)
536 num_class_ = (arma::unique(labels_).eval()).n_elem;
// Constructor: delegates to the (id, path) constructor with DATASET_PATH and
// then creates the metadata and file directories.
//   id - numeric id forwarded to the delegated constructor
543 Collect<T>::Collect(
const size_t&
id ) : Collect( id, DATASET_PATH )
545 std::filesystem::create_directories(metapath_);
546 std::filesystem::create_directories(filespath_);
// Constructor from an explicit list of dataset ids.
//   ids - row of dataset ids; its element count initialises size_
// NOTE(review): the remainder of the member-initializer list is not visible in
// this view.
550 Collect<T>::Collect(
const arma::Row<size_t>& ids ) : size_(ids.n_elem),
554 std::filesystem::create_directories(metapath_);
555 std::filesystem::create_directories(filespath_);
// Constructor for a collection identified by a numeric id and a storage path.
//   id   - used to build the OpenML study URL
//   path - NOTE(review): the member-initializer list is not visible, so how
//          `path` maps onto metapath_ / filespath_ cannot be confirmed here
// Creates the storage directories, builds url_, and sets size_ to the number
// of elements in keys_.
// NOTE(review): the statement that fills keys_ is not visible in this view.
559 Collect<T>::Collect(
const size_t&
id,
const std::filesystem::path& path ) :
562 std::filesystem::create_directories(metapath_);
563 std::filesystem::create_directories(filespath_);
564 url_ =
"https://www.openml.org/api/v1/json/study/" + std::to_string(
id);
566 size_ = keys_.n_elem;
// GetNext: builds a Dataset for the key at the current counter position,
// stored under filespath_, and advances counter_ by one.
// NOTE(review): no check of counter_ against the number of keys is visible.
570 Dataset<T> Collect<T>::GetNext ( )
572 return Dataset<T>(keys_[counter_++],filespath_);
// GetID: returns the Dataset for `id` when the id passes the membership test
// on keys_; otherwise reports an error.
//   id - dataset id to look up
// NOTE(review): arma::find yields the indices of the matching elements and
// arma::any tests those indices for non-zero values, so a match located only
// at index 0 looks like it would be treated as "not found" — confirm.
// NOTE(review): the dataset is constructed with the single-argument
// constructor rather than with filespath_ as in GetNext — confirm intent.
// NOTE(review): the return on the error path is not visible in this view.
576 Dataset<T> Collect<T>::GetID (
const size_t&
id )
578 if (arma::any(arma::find(keys_ ==
id)))
579 return Dataset<T>(
id);
582 ERR(
"Collect: cannot find the dataset, giving you the next instead...");
// _getkeys: requests url_ with libcurl, stores the response in metafile_, and
// extracts the dataset ids listed in the response's "data_id" array.
// Returns the ids as an arma::Row<size_t>; the row is built from an empty
// vector when the "data_id" array is not found.
// NOTE(review): the declarations of `curl`, `res` and `match`, and the
// conditions guarding the error branches, are not visible in this view.
588 arma::Row<size_t> Collect<T>::_getkeys()
// Accumulates the HTTP response body through utils::WriteCallback.
594 std::string readBuffer;
596 curl_global_init(CURL_GLOBAL_DEFAULT);
597 curl = curl_easy_init();
600 curl_easy_setopt(curl, CURLOPT_URL, url_.c_str());
601 curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, utils::WriteCallback);
602 curl_easy_setopt(curl, CURLOPT_WRITEDATA, &readBuffer);
603 res = curl_easy_perform(curl);
// Error report (presumably when res signals a failure).
605 ERR(
"Collect:curl_easy_perform() failed: " << curl_easy_strerror(res));
606 curl_easy_cleanup(curl);
608 curl_global_cleanup();
// Keep a copy of the raw response on disk.
614 std::ofstream outFile(metafile_);
615 if (outFile.is_open())
617 outFile << readBuffer;
621 ERR(
"Collect:Unable to open file for writing.");
624 ERR(
"Collect:Not Saving metadata.");
625 std::vector<size_t> dataIds;
// Capture group 1 is the text between the brackets of the "data_id" array.
629 std::regex arrayRegex(R
"("data_id"\s*:\s*\[([^\]]+)\])");
633 if (std::regex_search(readBuffer, match, arrayRegex))
636 std::string dataIdArrayStr = match[1].str();
// Every run of digits inside the array text is one id.
639 std::regex numberRegex(R
"(\d+)");
640 auto numbersBegin = std::sregex_iterator(dataIdArrayStr.begin(),
641 dataIdArrayStr.end(),
643 auto numbersEnd = std::sregex_iterator();
646 for (std::sregex_iterator i = numbersBegin; i != numbersEnd; ++i)
// Parsed as int, then stored in the size_t vector.
648 int dataId = std::stoi((*i).str());
649 dataIds.push_back(dataId);
652 return arma::conv_to<arma::Row<size_t>>::from(dataIds);