Learning Curve Plus Plus (LCPP)
dataset_impl.h
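/**
 * @file dataset_impl.h
 * @author Ozgur Taylan Turan
 *
 * Implementation of the dataset interfaces: toy data generators
 * (data::Dataset: Linear, Sine, Banana, Dipping, Gaussian) and OpenML
 * dataset / study retrieval (data::oml::Dataset, data::oml::Collect).
 */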
10 #ifndef DATASET_IMPL_H
11 #define DATASET_IMPL_H
12 
13 namespace data {
14 
15 //-----------------------------------------------------------------------------
16 // Dataset
17 //-----------------------------------------------------------------------------
18 template<class LABEL,class T>
19 Dataset<LABEL,T>::Dataset ( size_t dim, size_t seed ) :
20  dimension_(dim),seed_(seed) { };
21 
22 template<class LABEL,class T>
23 Dataset<LABEL,T>::Dataset ( const arma::Mat<T>& inputs,
24  const LABEL& labels )
25 {
26  inputs_ = inputs;
27  labels_ = labels;
28  this->_update_info();
29 };
30 
31 //-----------------------------------------------------------------------------
32 // Dataset::Linear
33 //-----------------------------------------------------------------------------
34 template<class LABEL,class T>
35 void Dataset<LABEL,T>::Linear ( const size_t N, const T noise_std )
36 {
37  mlpack::RandomSeed(seed_.value());
38  inputs_ = arma::randn(dimension_,N);
39  labels_ = arma::ones(1,dimension_) * inputs_ +
40  arma::randn(1,N,arma::distr_param(T(0),T(noise_std)));
41  this->_update_info();
42 }
43 //-----------------------------------------------------------------------------
44 // Dataset::Sine
45 //-----------------------------------------------------------------------------
46 template<class LABEL,class T>
47 void Dataset<LABEL,T>::Sine ( const size_t N, const T noise_std )
48 {
49  mlpack::RandomSeed(seed_.value());
50  inputs_ = arma::randn(dimension_,N);
51  labels_ = arma::sin(arma::ones(1,dimension_) * inputs_) +
52  arma::randn(1,N,arma::distr_param(T(0),T(noise_std)));
53  this->_update_info();
54 }
55 //-----------------------------------------------------------------------------
56 // Dataset::Banana
57 //-----------------------------------------------------------------------------
58 template<class LABEL,class T>
59 void Dataset<LABEL,T>::Banana ( const size_t N, const T delta )
60 {
61  if (dimension_ != 2)
62  WARNING("Dataset::Banana requires dimension to be 2,\
63  overwriting dimension_!");
64  mlpack::RandomSeed(seed_.value());
65  double r = 5.;
66  double s = 1.0;
67  arma::Mat<T> i1, i2, temp;
68 
69  temp = 0.125*M_PI + 1.25*M_PI*
70  arma::randu<arma::Mat<T>>(1,N, arma::distr_param(0.,1.));
71 
72  i1 = arma::join_cols(r*arma::sin(temp),r*arma::cos(temp));
73  i1 += s*arma::randu<arma::Mat<T>>(2,N, arma::distr_param(0.,1.));
74 
75  temp = 0.375*M_PI - 1.25*M_PI*
76  arma::randu<arma::Mat<T>>(1,N, arma::distr_param(0.,1.));
77 
78  i2 = arma::join_cols(r*arma::sin(temp),r*arma::cos(temp));
79  i2 += s*arma::randu<arma::Mat<T>>(2,N, arma::distr_param(0.,1.));
80  i2 -= 0.75*r;
81 
82  i2 += delta;
83 
84  inputs_ = arma::join_rows(i1,i2);
85 
86  labels_.set_size(2*N);
87  if constexpr ( std::is_same<LABEL,arma::Row<size_t>>::value )
88  labels_.subvec(0, N-1).zeros();
89  else if constexpr ( std::is_same<LABEL,arma::Row<int>>::value )
90  labels_.subvec(0, N-1).fill(-1);
91  labels_.subvec(N, 2*N - 1).ones();
92 
93  this->_update_info();
94 }
95 //-----------------------------------------------------------------------------
96 // Dataset::Dipping
97 //-----------------------------------------------------------------------------
98 template<class LABEL,class T>
99 void Dataset<LABEL,T>::Dipping ( const size_t N, const T r, const T noise_std )
100 
101 {
102  arma::Mat<T> x1(dimension_, N, arma::fill::randn);
103  x1.each_row() /= arma::sqrt(arma::sum(arma::pow(x1,2),0));
104 
105  if ( r != 1 )
106  x1 *= r;
107 
108  arma::Mat<T> cov(dimension_, dimension_, arma::fill::eye);
109  arma::Col<T> mean(dimension_);
110  mean.zeros(); cov *= 0.1;
111  arma::Mat<T> x2 = arma::mvnrnd(mean, cov, N);
112 
113  inputs_ = arma::join_rows(x1,x2);
114 
115  if ( noise_std > 0 )
116  inputs_ += arma::randn<arma::Mat<T>>(dimension_,2*N,
117  arma::distr_param(0., noise_std));
118 
119  labels_.set_size(2*N);
120  if constexpr ( std::is_same<LABEL,arma::Row<size_t>>::value )
121  labels_.subvec(0, N-1).zeros();
122  else if constexpr ( std::is_same<LABEL,arma::Row<int>>::value )
123  labels_.subvec(0, N-1).fill(-1);
124  labels_.subvec(N, 2*N - 1).ones();
125 
126  this->_update_info();
127 }
128 
129 //-----------------------------------------------------------------------------
130 // Dataset::Gaussian
131 //-----------------------------------------------------------------------------
132 template<class LABEL,class T>
133 void Dataset<LABEL,T>::Gaussian ( const size_t N,
134  const arma::Row<T>& means )
135 
136 {
137  size_t n_class = means.n_elem;
138 
139  inputs_.set_size(dimension_, N*n_class);
140  labels_.set_size(N*n_class);
141  if constexpr ( std::is_same<LABEL,arma::Row<size_t>>::value )
142  labels_.subvec(0, N-1).zeros();
143  else if constexpr ( std::is_same<LABEL,arma::Row<int>>::value )
144  {
145  assert ( means.n_elem == 2 &&
146  "To use Row<int> you need binary classification problem");
147  labels_.subvec(0, N-1).fill(-1);
148  }
149 
150  for (size_t i = 0; i < n_class ; ++i)
151  {
152  inputs_.cols(i * N, (i + 1) * N - 1) = means(i)
153  + arma::randn<arma::Mat<T>>(dimension_, N);
154 
155  if (i>0)
156  labels_.subvec(N, 2*N - 1).ones();
157  }
158 
159  this->_update_info();
160 }
161 //-----------------------------------------------------------------------------
162 // Dataset::_update_info
163 //-----------------------------------------------------------------------------
164 template<class LABEL,class T>
165 void Dataset<LABEL,T>::_update_info( )
166 {
167  size_ = inputs_.n_cols;
168  dimension_ = inputs_.n_rows;
169  if constexpr ( std::is_same<LABEL,arma::Row<size_t>>::value )
170  num_class_ = arma::unique(labels_).eval().n_elem;
171  else if constexpr ( std::is_same<LABEL,arma::Row<int>>::value )
172  num_class_ = 2;
173 }
174 //-----------------------------------------------------------------------------
175 // Dataset:: Update
176 //-----------------------------------------------------------------------------
177 template<class LABEL,class T>
178 void Dataset<LABEL,T>::Update ( const arma::Mat<T>& inputs,
179  const LABEL& labels )
180 {
181  this->inputs_ = inputs; this->labels_ = labels; this->_update_info();
182 }
183 
184 } // namespace data
185 
186 namespace data::oml
187 {
188 //-----------------------------------------------------------------------------
189 // Dataset
190 //-----------------------------------------------------------------------------
191 template<class LTYPE,class T>
192 Dataset<LTYPE,T>::Dataset( const size_t& id, const std::filesystem::path& path ) :
193  id_(id), path_(path)
194 {
195  std::filesystem::create_directories(filepath_);
196  std::filesystem::create_directories(metapath_);
197 
198  meta_url_ = "https://www.openml.org/api/v1/data/"
199  + std::to_string(id);
200 
201  metafile_ = metapath_/(std::to_string(id)+".meta");
202  file_ = filepath_ / (std::to_string(id) + ".arff");
203 
204  if (!std::filesystem::exists(metafile_))
205  this->_fetchmetadata();
206 
207  down_url_ = _getdownurl(_readmetadata());
208 
209  if (!std::filesystem::exists(file_))
210  {
211  this->_download();
212  this->_fetchmetadata();
213  this->_load();
214  }
215  else
216  {
217  WARNING("Dataset " << id_ << " is already present.");
218  this->_load();
219  }
220 }
221 ///////////////////////////////////////////////////////////////////////////////
222 template<class LTYPE,class T>
223 void Dataset<LTYPE,T>::Update ( const arma::Mat<T>& inputs,
224  const arma::Row<LTYPE>& labels )
225 {
226  inputs_ = inputs; labels_ = labels; this->_update_info();
227 }
228 ///////////////////////////////////////////////////////////////////////////////
229 template<class LTYPE,class T>
230 Dataset<LTYPE,T>::Dataset( const size_t& id ) :
231  Dataset( id,DATASET_PATH/"openml") { }
232 ///////////////////////////////////////////////////////////////////////////////
233 template<class LTYPE,class T>
234 bool Dataset<LTYPE,T>::_download( )
235 {
236  CURL* curl;
237  CURLcode res;
238  // Initialize CURL
239  curl_global_init(CURL_GLOBAL_DEFAULT);
240  curl = curl_easy_init();
241 
242  // Open file to write the downloaded data
243  FILE* fp = fopen(file_.c_str(), "wb");
244  if (!fp)
245  {
246  ERR("Could not open file for writing: " << file_);
247  curl_easy_cleanup(curl);
248  curl_global_cleanup();
249  return false;
250  }
251 
252  // Set CURL options
253  LOG(down_url_.c_str());
254  curl_easy_setopt(curl, CURLOPT_URL, down_url_.c_str());
255  curl_easy_setopt(curl, CURLOPT_WRITEDATA, fp);
256  // There is redirecting need in openml side
257  curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);
258  // Let's not stay in a never ending loop
259  curl_easy_setopt(curl, CURLOPT_MAXREDIRS, 5L);
260  // Perform the request
261  res = curl_easy_perform(curl);
262 
263  // Check for errors
264  if(res != CURLE_OK)
265  {
266  ERR("curl_easy_perform() failed: "<< curl_easy_strerror(res));
267  fclose(fp);
268  curl_easy_cleanup(curl);
269  curl_global_cleanup();
270  return false;
271  }
272 
273  // Cleanup
274  fclose(fp);
275  curl_easy_cleanup(curl);
276  curl_global_cleanup();
277 
278  LOG("Dataset " << id_ << " downloaded to " << file_ << ".");
279  return true;
280 }
281 ///////////////////////////////////////////////////////////////////////////////
282 template<class LTYPE,class T>
283 std::string Dataset<LTYPE,T>::_fetchmetadata()
284 {
285 
286  // Function to fetch metadata from OpenMLdd
287  CURL* curl;
288  CURLcode res;
289  std::string readBuffer;
290 
291  curl_global_init(CURL_GLOBAL_DEFAULT);
292  curl = curl_easy_init();
293  if(curl)
294  {
295  curl_easy_setopt(curl, CURLOPT_URL, meta_url_.c_str());
296  curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, utils::WriteCallback);
297  curl_easy_setopt(curl, CURLOPT_WRITEDATA, &readBuffer);
298  res = curl_easy_perform(curl);
299  if(res != CURLE_OK)
300  ERR("curl_easy_perform() failed: " << curl_easy_strerror(res));
301  curl_easy_cleanup(curl);
302  }
303  curl_global_cleanup();
304 
305 
306  if(res == CURLE_OK)
307  {
308  // Save readBuffer to a text file
309  std::ofstream outFile(metafile_);
310  if (outFile.is_open())
311  {
312  outFile << readBuffer;
313  outFile.close();
314  }
315  else
316  ERR("Unable to open file for writing.");
317  }
318  else
319  ERR("Not Saving metadata.");
320  return readBuffer;
321 }
322 ///////////////////////////////////////////////////////////////////////////////
323 template<class LTYPE,class T>
324 std::string Dataset<LTYPE,T>::_readmetadata()
325 {
326  std::ifstream infile(metafile_);
327  if (!infile.is_open())
328  {
329  ERR("Unable to open file for reading: " << metafile_ );
330  return "";
331  }
332  std::stringstream buffer;
333  buffer << infile.rdbuf();
334  infile.close();
335  return buffer.str();
336 }
337 ///////////////////////////////////////////////////////////////////////////////
338 template<class LTYPE,class T>
339 std::string Dataset<LTYPE,T>::_gettargetname( const std::string& metadata )
340 {
341  // Define a regular expression to find the <default_target_value> element
342  std::regex re(R"(<oml:default_target_attribute>(.*?)</oml:default_target_attribute>)");
343  std::smatch match;
344 
345  // Search for the pattern in the XML data
346  if (regex_search(metadata, match, re) && match.size() > 1)
347  {
348  return match.str(1); // Return the matched content
349  }
350  else
351  {
352  WARNING("Cannot find the target name using 'class' instead!!!");
353  return "class";
354  }
355 }
356 ///////////////////////////////////////////////////////////////////////////////
357 template<class LTYPE,class T>
358 std::string Dataset<LTYPE,T>::_getdownurl( const std::string& metadata )
359 {
360  // Define a regular expression to find the <default_target_value> element
361  std::regex re(R"(<oml:url>(.*?)</oml:url>)");
362 
363  std::smatch match;
364 
365  // Search for the pattern in the XML data
366  if (regex_search(metadata, match, re) && match.size() > 1)
367  {
368  return match.str(1); // Return the matched content
369  }
370  else
371  {
372  WARNING("Probably something went wrong with meta data fetch!!!");
373  return "";
374  }
375 }
376 ///////////////////////////////////////////////////////////////////////////////
377 template<class LTYPE,class T>
378 int Dataset<LTYPE,T>::_findlabel ( const std::string& targetname )
379 {
380  std::ifstream file(file_);
381  std::string line;
382  int index = 0;
383 
384  if (!file.is_open())
385  {
386  ERR("Error opening file: " << file_ );
387  return -1; // Error code for file opening failure
388  }
389 
390  while (std::getline(file, line))
391  {
392  line.erase(0, line.find_first_not_of(" \t")); // Trim leading whitespace
393  if (line.find("@ATTRIBUTE") == 0 || line.find("@attribute") == 0)
394  {
395  size_t start = line.find(' ') + 1; // Skip "@attribute"
396  // Find the first space/tab after the attribute name
397  size_t end = line.find_first_of(" \t", start);
398  if (end == std::string::npos)
399  end = line.length();
400 
401  // Extract the attribute name
402  std::string name = line.substr(start, end - start);
403 
404  if (name == targetname || name == "'"+targetname+"'")
405  return index; // Attribute found, return its index
406 
407  ++index; // Increment index for each attribute line
408  }
409  }
410  return -1; // Attribute not found
411 }
412 ///////////////////////////////////////////////////////////////////////////////
413 template<class LTYPE,class T>
414 bool Dataset<LTYPE,T>::_iscateg(const arma::Row<T>& row)
415 {
416  std::set<int> distinctValues;
417 
418  // Iterate over the array
419  for (size_t i = 0; i < row.n_elem; ++i)
420  {
421  T value = row(i);
422 
423  // Check if the value is a whole number
424  if (value == static_cast<int>(value))
425  distinctValues.insert(static_cast<int>(value));
426  else
427  // If any value is not a whole number, it's not categorical
428  return false;
429  }
430 
431  // If we have only a few distinct values, consider it categorical
432  return distinctValues.size() < row.n_elem;
433 }
434 ///////////////////////////////////////////////////////////////////////////////
435 template<class LTYPE,class T>
436 arma::Row<size_t> Dataset<LTYPE,T>::_convcateg(const arma::Row<T>& row)
437 {
438  std::unordered_map<T, size_t> valueToIndex;
439  size_t categoryIndex = 0;
440 
441  // Create a mapping of unique values to integer indices
442  for (size_t i = 0; i < row.n_elem; ++i)
443  {
444  T value = row(i);
445  // If the value has not been seen before, assign a new category
446  if (valueToIndex.find(value) == valueToIndex.end())
447  valueToIndex[value] = categoryIndex++;
448  }
449 
450  // Create a new Row<size_t> to store the mapped categorical values
451  arma::Row<size_t> categoricalRow(row.n_elem);
452 
453  // Map each original value to its corresponding categorical index
454  for (size_t i = 0; i < row.n_elem; ++i)
455  categoricalRow(i) = valueToIndex[row(i)];
456 
457  return categoricalRow;
458 }
459 ///////////////////////////////////////////////////////////////////////////////
460 template<class LTYPE, class T>
461 arma::Row<size_t> Dataset<LTYPE,T>::_procrow(const arma::Row<T>& row)
462 {
463  if (_iscateg(row))
464  {
465  // Return the row as is if it's already categorical
466  // Convert it to a size_t type (though it might not be necessary
467  // if it's already integers)
468  arma::Row<size_t> categoricalRow(row.n_elem);
469  for (size_t i = 0; i < row.n_elem; ++i)
470  categoricalRow(i) = static_cast<size_t>(row(i));
471  return categoricalRow;
472  }
473  else
474  // Convert real values into categorical values
475  return _convcateg(row);
476 }
477 ///////////////////////////////////////////////////////////////////////////////
478 template<class LTYPE,class T>
479 void Dataset<LTYPE,T>::Save( const std::string& filename )
480 {
481  std::ofstream file(filename, std::ios::binary);
482  if (!file)
483  ERR("\rCannot open file for writing: " << filename << std::flush);
484 
485  cereal::BinaryOutputArchive archive(file);
486  archive(cereal::make_nvp("Dataset", *this)); // Serialize the current object
487  LOG("\rDataset object saved to " << filename << std::flush);
488 
489 }
490 ///////////////////////////////////////////////////////////////////////////////
491 template<class LTYPE,class T>
492 std::shared_ptr<Dataset<LTYPE,T>> Dataset<LTYPE,T>::Load
493  ( const std::string& filename )
494 {
495  std::ifstream file(filename, std::ios::binary);
496  if (!file)
497  {
498  ERR("\rError: Cannot open file for reading: " << filename);
499  return nullptr;
500  }
501  cereal::BinaryInputArchive archive(file);
502  auto dataset = std::make_shared<Dataset<LTYPE,T>>();
503  archive(cereal::make_nvp("Dataset", *dataset));// Deserialize into a new object
504  LOG("\rDataset loaded from " << filename);
505  return dataset;
506 }
507 ///////////////////////////////////////////////////////////////////////////////
508 template<class LTYPE,class T>
509 void Dataset<LTYPE,T>::_load( )
510 {
511  int idx = -1;
512  arma::Mat<DTYPE> data;
513  mlpack::data::DatasetInfo info;
514  mlpack::data::Load(file_.c_str(), data, info);
515  idx =_findlabel(_gettargetname(_readmetadata()));
516 
517  if (idx<0)
518  throw std::runtime_error("Cannot find the label!");
519 
520  if constexpr (std::is_same<LTYPE, size_t>::value)
521  labels_ = _procrow(data.row(idx));
522  else
523  labels_ = data.row(idx);
524  /* labels_ = _procrow(data.row(idx)); */
525  data.shed_row(idx);
526  inputs_ = data;
527  this->_update_info();
528 }
529 ///////////////////////////////////////////////////////////////////////////////
530 template<class LTYPE,class T>
531 void Dataset<LTYPE,T>::_update_info( )
532 {
533  dimension_ = inputs_.n_rows;
534  size_ = inputs_.n_cols;
535  if (std::is_same<LTYPE,size_t>::value)
536  num_class_ = (arma::unique(labels_).eval()).n_elem;
537 }
538 ///////////////////////////////////////////////////////////////////////////////
539 //-----------------------------------------------------------------------------
540 // Collect
541 //-----------------------------------------------------------------------------
542 template<class T>
543 Collect<T>::Collect( const size_t& id ) : Collect( id, DATASET_PATH )
544 {
545  std::filesystem::create_directories(metapath_);
546  std::filesystem::create_directories(filespath_);
547 };
548 ///////////////////////////////////////////////////////////////////////////////
549 template<class T>
550 Collect<T>::Collect( const arma::Row<size_t>& ids ) : size_(ids.n_elem),
551  keys_(ids),
552  path_(DATASET_PATH)
553 {
554  std::filesystem::create_directories(metapath_);
555  std::filesystem::create_directories(filespath_);
556 };
557 ///////////////////////////////////////////////////////////////////////////////
558 template<class T>
559 Collect<T>::Collect( const size_t& id, const std::filesystem::path& path ) :
560  id_(id), path_(path)
561 {
562  std::filesystem::create_directories(metapath_);
563  std::filesystem::create_directories(filespath_);
564  url_ = "https://www.openml.org/api/v1/json/study/" + std::to_string(id);
565  keys_ = _getkeys();
566  size_ = keys_.n_elem;
567 }
568 ///////////////////////////////////////////////////////////////////////////////
569 template<class T>
570 Dataset<T> Collect<T>::GetNext ( )
571 {
572  return Dataset<T>(keys_[counter_++],filespath_);
573 }
574 ///////////////////////////////////////////////////////////////////////////////
575 template<class T>
576 Dataset<T> Collect<T>::GetID ( const size_t& id )
577 {
578  if (arma::any(arma::find(keys_ == id)))
579  return Dataset<T>(id);
580  else
581  {
582  ERR("Collect: cannot find the dataset, giving you the next instead...");
583  return GetNext();
584  }
585 }
586 ///////////////////////////////////////////////////////////////////////////////
587 template<class T>
588 arma::Row<size_t> Collect<T>::_getkeys()
589 {
590 
591  // Function to fetch metadata from OpenML
592  CURL* curl;
593  CURLcode res;
594  std::string readBuffer;
595 
596  curl_global_init(CURL_GLOBAL_DEFAULT);
597  curl = curl_easy_init();
598  if(curl)
599  {
600  curl_easy_setopt(curl, CURLOPT_URL, url_.c_str());
601  curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, utils::WriteCallback);
602  curl_easy_setopt(curl, CURLOPT_WRITEDATA, &readBuffer);
603  res = curl_easy_perform(curl);
604  if(res != CURLE_OK)
605  ERR("Collect:curl_easy_perform() failed: " << curl_easy_strerror(res));
606  curl_easy_cleanup(curl);
607  }
608  curl_global_cleanup();
609 
610 
611  if(res == CURLE_OK)
612  {
613  // Save readBuffer to a text file
614  std::ofstream outFile(metafile_);
615  if (outFile.is_open())
616  {
617  outFile << readBuffer;
618  outFile.close();
619  }
620  else
621  ERR("Collect:Unable to open file for writing.");
622  }
623  else
624  ERR("Collect:Not Saving metadata.");
625  std::vector<size_t> dataIds;
626 
627  // Regex to match the array inside "data_id": [...]
628  // I will use data_id because I want to control the splits...
629  std::regex arrayRegex(R"("data_id"\s*:\s*\[([^\]]+)\])");
630  std::smatch match;
631 
632  // If the regex finds a match for the "data_id" array
633  if (std::regex_search(readBuffer, match, arrayRegex))
634  {
635  // The first captured group (the numbers inside the brackets)
636  std::string dataIdArrayStr = match[1].str();
637 
638  // Regex to find individual numbers in the array
639  std::regex numberRegex(R"(\d+)");
640  auto numbersBegin = std::sregex_iterator(dataIdArrayStr.begin(),
641  dataIdArrayStr.end(),
642  numberRegex);
643  auto numbersEnd = std::sregex_iterator();
644 
645  // Iterate over each match (number) and add it to the vector
646  for (std::sregex_iterator i = numbersBegin; i != numbersEnd; ++i)
647  {
648  int dataId = std::stoi((*i).str());
649  dataIds.push_back(dataId);
650  }
651  }
652  return arma::conv_to<arma::Row<size_t>>::from(dataIds);
653 }
654 
655 } // namespace oml
656 
657 
658 
659 #endif