robust_statistics.h 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. // Copyright 2023 Google LLC
  2. // SPDX-License-Identifier: Apache-2.0
  3. //
  4. // Licensed under the Apache License, Version 2.0 (the "License");
  5. // you may not use this file except in compliance with the License.
  6. // You may obtain a copy of the License at
  7. //
  8. // http://www.apache.org/licenses/LICENSE-2.0
  9. //
  10. // Unless required by applicable law or agreed to in writing, software
  11. // distributed under the License is distributed on an "AS IS" BASIS,
  12. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. // See the License for the specific language governing permissions and
  14. // limitations under the License.
  15. #ifndef HIGHWAY_HWY_ROBUST_STATISTICS_H_
  16. #define HIGHWAY_HWY_ROBUST_STATISTICS_H_
  17. #include <algorithm> // std::sort, std::find_if
  18. #include <limits>
  19. #include <utility> // std::pair
  20. #include <vector>
  21. #include "hwy/base.h"
  22. namespace hwy {
  23. namespace robust_statistics {
  24. // Sorts integral values in ascending order (e.g. for Mode). About 3x faster
  25. // than std::sort for input distributions with very few unique values.
  26. template <class T>
  27. void CountingSort(T* values, size_t num_values) {
  28. // Unique values and their frequency (similar to flat_map).
  29. using Unique = std::pair<T, int>;
  30. std::vector<Unique> unique;
  31. for (size_t i = 0; i < num_values; ++i) {
  32. const T value = values[i];
  33. const auto pos =
  34. std::find_if(unique.begin(), unique.end(),
  35. [value](const Unique u) { return u.first == value; });
  36. if (pos == unique.end()) {
  37. unique.push_back(std::make_pair(value, 1));
  38. } else {
  39. ++pos->second;
  40. }
  41. }
  42. // Sort in ascending order of value (pair.first).
  43. std::sort(unique.begin(), unique.end());
  44. // Write that many copies of each unique value to the array.
  45. T* HWY_RESTRICT p = values;
  46. for (const auto& value_count : unique) {
  47. std::fill(p, p + value_count.second, value_count.first);
  48. p += value_count.second;
  49. }
  50. HWY_ASSERT(p == values + num_values);
  51. }
  52. // @return i in [idx_begin, idx_begin + half_count) that minimizes
  53. // sorted[i + half_count] - sorted[i].
  54. template <typename T>
  55. size_t MinRange(const T* const HWY_RESTRICT sorted, const size_t idx_begin,
  56. const size_t half_count) {
  57. T min_range = std::numeric_limits<T>::max();
  58. size_t min_idx = 0;
  59. for (size_t idx = idx_begin; idx < idx_begin + half_count; ++idx) {
  60. HWY_ASSERT(sorted[idx] <= sorted[idx + half_count]);
  61. const T range = sorted[idx + half_count] - sorted[idx];
  62. if (range < min_range) {
  63. min_range = range;
  64. min_idx = idx;
  65. }
  66. }
  67. return min_idx;
  68. }
  69. // Returns an estimate of the mode by calling MinRange on successively
  70. // halved intervals. "sorted" must be in ascending order. This is the
  71. // Half Sample Mode estimator proposed by Bickel in "On a fast, robust
  72. // estimator of the mode", with complexity O(N log N). The mode is less
  73. // affected by outliers in highly-skewed distributions than the median.
  74. // The averaging operation below assumes "T" is an unsigned integer type.
  75. template <typename T>
  76. T ModeOfSorted(const T* const HWY_RESTRICT sorted, const size_t num_values) {
  77. size_t idx_begin = 0;
  78. size_t half_count = num_values / 2;
  79. while (half_count > 1) {
  80. idx_begin = MinRange(sorted, idx_begin, half_count);
  81. half_count >>= 1;
  82. }
  83. const T x = sorted[idx_begin + 0];
  84. if (half_count == 0) {
  85. return x;
  86. }
  87. HWY_ASSERT(half_count == 1);
  88. const T average = (x + sorted[idx_begin + 1] + 1) / 2;
  89. return average;
  90. }
  // Computes the Half Sample Mode of "values" (see ModeOfSorted).
  // Side effect: "values" is reordered into ascending order by CountingSort,
  // which is the precondition ModeOfSorted relies on.
  template <typename T>
  T Mode(T* values, const size_t num_values) {
    T* const sorted = values;
    CountingSort(sorted, num_values);
    const T mode = ModeOfSorted(sorted, num_values);
    return mode;
  }
  // Convenience overload for built-in arrays: the element count is taken from
  // the array type. Side effect: sorts "values" in place.
  template <typename T, size_t N>
  T Mode(T (&values)[N]) {
    T* const first = values;  // Array-to-pointer decay; same as &values[0].
    return Mode(first, N);
  }
  101. // Returns the median value. Side effect: sorts "values".
  102. template <typename T>
  103. T Median(T* values, const size_t num_values) {
  104. HWY_ASSERT(num_values != 0);
  105. std::sort(values, values + num_values);
  106. const size_t half = num_values / 2;
  107. // Odd count: return middle
  108. if (num_values % 2) {
  109. return values[half];
  110. }
  111. // Even count: return average of middle two.
  112. return (values[half] + values[half - 1] + 1) / 2;
  113. }
  114. // Returns a robust measure of variability.
  115. template <typename T>
  116. T MedianAbsoluteDeviation(const T* values, const size_t num_values,
  117. const T median) {
  118. HWY_ASSERT(num_values != 0);
  119. std::vector<T> abs_deviations;
  120. abs_deviations.reserve(num_values);
  121. for (size_t i = 0; i < num_values; ++i) {
  122. const int64_t abs = ScalarAbs(static_cast<int64_t>(values[i]) -
  123. static_cast<int64_t>(median));
  124. abs_deviations.push_back(static_cast<T>(abs));
  125. }
  126. return Median(abs_deviations.data(), num_values);
  127. }
  128. } // namespace robust_statistics
  129. } // namespace hwy
  130. #endif // HIGHWAY_HWY_ROBUST_STATISTICS_H_