proconlib

This documentation is automatically generated by competitive-verifier/competitive-verifier

View the Project on GitHub anqooqie/proconlib

:heavy_check_mark: tests/segmented_sieve/comprehensive.test.cpp

Depends on

Code

// competitive-verifier: STANDALONE

#include <algorithm>
#include <iostream>
#include <map>
#include <vector>
#include "tools/assert_that.hpp"
#include "tools/segmented_sieve.hpp"

using ll = long long;

// Reference primality test by trial division.
// Returns true exactly when n >= 2 and no d with d * d <= n divides n.
bool naive_is_prime(const ll n) {
  bool prime = n >= 2;
  for (ll d = 2; prime && d * d <= n; ++d) {
    prime = n % d != 0;
  }
  return prime;
}

// Reference factorisation by trial division.
// Returns the prime factors of n with multiplicity, in ascending order.
std::vector<ll> naive_prime_factors(ll n) {
  std::vector<ll> factors;
  for (ll d = 2; d * d <= n; ++d) {
    // Strip every copy of d before moving on to the next candidate
    for (; n % d == 0; n /= d) {
      factors.push_back(d);
    }
  }
  // Whatever is left above 1 is a single remaining factor
  if (n > 1) {
    factors.push_back(n);
  }
  return factors;
}

// Reference factorisation by trial division.
// Returns a map from each distinct prime factor of n to its exponent.
std::map<ll, ll> naive_distinct_prime_factors(ll n) {
  std::map<ll, ll> exponents;
  for (ll d = 2; d * d <= n; ++d) {
    if (n % d != 0) continue;
    ll& e = exponents[d];
    do {
      ++e;
      n /= d;
    } while (n % d == 0);
  }
  // Whatever is left above 1 is one more factor with exponent 1
  if (n > 1) {
    exponents[n] += 1;
  }
  return exponents;
}

std::vector<ll> naive_divisors(const ll n) {
  std::vector<ll> result;
  for (ll d = 1; d * d <= n; ++d) {
    if (n % d == 0) {
      result.push_back(d);
      if (d != n / d) result.push_back(n / d);
    }
  }
  std::ranges::sort(result);
  return result;
}

// Reference divisor count by trial division.
// Returns the number of positive divisors of n, or 0 when n <= 0.
ll naive_divisor_count(const ll n) {
  ll count = 0;
  if (n > 0) {
    for (ll d = 1; d * d <= n; ++d) {
      if (n % d != 0) continue;
      // d itself, plus its cofactor when that is a different number
      count += 1;
      if (d != n / d) count += 1;
    }
  }
  return count;
}

void test_sieve(const ll L, const ll R) {
  tools::segmented_sieve sieve(L, R);
  const ll sq = sieve.sqrt_r();

  // sqrt_r, l, r
  assert_that(sq * sq <= R && (sq + 1) * (sq + 1) > R);
  assert_that(sieve.l() == L);
  assert_that(sieve.r() == R);

  // is_prime: small range [1, sqrt_r]
  for (ll n = 1; n <= sq; ++n) {
    assert_that(sieve.is_prime(n) == naive_is_prime(n));
  }

  // is_prime: large range [L, R]
  for (ll n = L; n <= R; ++n) {
    assert_that(sieve.is_prime(n) == naive_is_prime(n));
  }

  // prime_factor_range: small range
  for (ll n = 1; n <= sq; ++n) {
    std::vector<ll> got(sieve.prime_factor_range(n).begin(), sieve.prime_factor_range(n).end());
    std::vector<ll> expected = naive_prime_factors(n);
    std::ranges::sort(got);
    assert_that(got == expected);
  }

  // prime_factor_range: large range
  for (ll n = L; n <= R; ++n) {
    std::vector<ll> got(sieve.prime_factor_range(n).begin(), sieve.prime_factor_range(n).end());
    std::vector<ll> expected = naive_prime_factors(n);
    std::ranges::sort(got);
    assert_that(got == expected);
  }

  // distinct_prime_factor_range: small range
  for (ll n = 1; n <= sq; ++n) {
    std::map<ll, ll> expected = naive_distinct_prime_factors(n);
    std::map<ll, ll> got;
    for (const auto& [p, q, pq] : sieve.distinct_prime_factor_range(n)) {
      got[p] = q;
      ll power = 1;
      for (ll i = 0; i < q; ++i) power *= p;
      assert_that(pq == power);
    }
    assert_that(got == expected);
  }

  // distinct_prime_factor_range: large range
  for (ll n = L; n <= R; ++n) {
    std::map<ll, ll> expected = naive_distinct_prime_factors(n);
    std::map<ll, ll> got;
    for (const auto& [p, q, pq] : sieve.distinct_prime_factor_range(n)) {
      got[p] = q;
      ll power = 1;
      for (ll i = 0; i < q; ++i) power *= p;
      assert_that(pq == power);
    }
    assert_that(got == expected);
  }

  // divisors and sorted_divisors: small range
  for (ll n = 1; n <= sq; ++n) {
    std::vector<ll> expected = naive_divisors(n);
    assert_that(sieve.sorted_divisors(n) == expected);
    std::vector<ll> unsorted = sieve.divisors(n);
    std::ranges::sort(unsorted);
    assert_that(unsorted == expected);
  }

  // divisors and sorted_divisors: large range
  for (ll n = L; n <= R; ++n) {
    std::vector<ll> expected = naive_divisors(n);
    assert_that(sieve.sorted_divisors(n) == expected);
    std::vector<ll> unsorted = sieve.divisors(n);
    std::ranges::sort(unsorted);
    assert_that(unsorted == expected);
  }

  // divisor_counts
  {
    auto [small_dc, large_dc] = sieve.divisor_counts();
    assert_that(static_cast<ll>(small_dc.size()) == sq + 1);
    assert_that(static_cast<ll>(large_dc.size()) == R - L + 1);
    for (ll i = 0; i <= sq; ++i) {
      assert_that(small_dc[i] == naive_divisor_count(i));
    }
    for (ll n = L; n <= R; ++n) {
      assert_that(large_dc[n - L] == naive_divisor_count(n));
    }
  }
}

void test_prime_range(const ll L, const ll R) {
  tools::segmented_sieve sieve(L, R);
  const ll sq = sieve.sqrt_r();
  const bool overlapping = sq + 1 >= L;

  // Helper to get expected primes in [lo, hi]
  auto expected_primes = [&](ll lo, ll hi) {
    std::vector<ll> result;
    for (ll n = lo; n <= hi; ++n) {
      if (naive_is_prime(n)) result.push_back(n);
    }
    return result;
  };

  if (overlapping) {
    // Entirely small
    if (sq >= 2) {
      std::vector<ll> got;
      for (const auto p : sieve.prime_range(1, std::min(sq, R))) got.push_back(p);
      assert_that(got == expected_primes(1, std::min(sq, R)));
    }

    // Entirely large (if range extends beyond sqrt_r)
    if (R > sq) {
      std::vector<ll> got;
      for (const auto p : sieve.prime_range(sq + 1, R)) got.push_back(p);
      assert_that(got == expected_primes(sq + 1, R));
    }

    // Cross-boundary
    if (sq >= 2 && R > sq) {
      std::vector<ll> got;
      for (const auto p : sieve.prime_range(2, R)) got.push_back(p);
      assert_that(got == expected_primes(2, R));
    }

    // Full range
    {
      std::vector<ll> got;
      for (const auto p : sieve.prime_range(1, R)) got.push_back(p);
      assert_that(got == expected_primes(1, R));
    }

    // Subranges around sqrt_r boundary
    if (sq >= 3 && R > sq + 2) {
      std::vector<ll> got;
      for (const auto p : sieve.prime_range(sq - 2, sq + 3)) got.push_back(p);
      assert_that(got == expected_primes(sq - 2, sq + 3));
    }
  } else {
    // Disjoint ranges: can only query within [1, sqrt_r] or [L, R]
    if (sq >= 2) {
      std::vector<ll> got;
      for (const auto p : sieve.prime_range(1, sq)) got.push_back(p);
      assert_that(got == expected_primes(1, sq));
    }
    {
      std::vector<ll> got;
      for (const auto p : sieve.prime_range(L, R)) got.push_back(p);
      assert_that(got == expected_primes(L, R));
    }
  }
}

int main() {
  std::cin.tie(nullptr);
  std::ios_base::sync_with_stdio(false);

  // Edge cases: very small R (small_primes may be empty)
  test_sieve(1, 1);
  test_sieve(1, 2);
  test_sieve(1, 3);
  test_sieve(1, 4);

  // L=1, moderate R (overlapping ranges)
  test_sieve(1, 100);
  test_sieve(1, 200);

  // L>1, still overlapping with small range
  test_sieve(2, 100);
  test_sieve(10, 200);

  // L > sqrt_r (disjoint ranges)
  test_sieve(50, 60);
  test_sieve(9990, 10010);

  // Large values
  test_sieve(999999999990LL, 1000000000000LL);

  // Single element ranges
  test_sieve(1, 1);
  test_sieve(7, 7);
  test_sieve(97, 97);
  test_sieve(100, 100);
  test_sieve(999999999989LL, 999999999989LL);

  // prime_range tests (separate because of constraint differences)
  test_prime_range(1, 1);
  test_prime_range(1, 2);
  test_prime_range(1, 3);
  test_prime_range(1, 4);
  test_prime_range(1, 100);
  test_prime_range(1, 200);
  test_prime_range(2, 100);
  test_prime_range(10, 200);
  test_prime_range(50, 60);
  test_prime_range(9990, 10010);
  test_prime_range(999999999990LL, 1000000000000LL);

  return 0;
}
#line 1 "tests/segmented_sieve/comprehensive.test.cpp"
// competitive-verifier: STANDALONE

#include <algorithm>
#include <iostream>
#include <map>
#include <vector>
#line 1 "tools/assert_that.hpp"



#line 5 "tools/assert_that.hpp"
#include <cstdlib>

// assert_that_impl(cond, file, line, func): when `cond` evaluates to false,
// writes "<file>:<line>: <func>: Assertion `<cond>' failed." to std::cerr and
// terminates the process via std::exit(EXIT_FAILURE). There is no NDEBUG check,
// so the assertion is always active.
// `cond` is parenthesized before negation so that a direct caller passing an
// expression such as `a || b` gets `!(a || b)` rather than `!a || b`.
#define assert_that_impl(cond, file, line, func) do {\
  if (!(cond)) {\
    std::cerr << file << ':' << line << ": " << func << ": Assertion `" << #cond << "' failed." << '\n';\
    std::exit(EXIT_FAILURE);\
  }\
} while (false)
// assert_that(expr): checks `expr` at the call site. The arguments are wrapped
// in parentheses so that expressions containing top-level commas still reach
// assert_that_impl as a single macro argument.
#define assert_that(...) assert_that_impl((__VA_ARGS__), __FILE__, __LINE__, __func__)


#line 1 "tools/segmented_sieve.hpp"



#line 5 "tools/segmented_sieve.hpp"
#include <cassert>
#include <cstddef>
#include <iterator>
#include <ranges>
#include <tuple>
#include <utility>
#line 1 "tools/block_ceil.hpp"



#line 5 "tools/block_ceil.hpp"
#include <type_traits>
#line 1 "tools/ceil.hpp"



#line 1 "tools/non_bool_integral.hpp"



#include <concepts>
#line 1 "tools/integral.hpp"



#line 1 "tools/is_integral.hpp"



#line 5 "tools/is_integral.hpp"

namespace tools {
  // Trait reporting whether T counts as an integral type. The primary template
  // inherits its answer from std::is_integral<T>.
  template <typename T>
  struct is_integral : std::is_integral<T> {};

  // Shorthand for tools::is_integral<T>::value.
  template <typename T>
  inline constexpr bool is_integral_v = tools::is_integral<T>::value;
}


#line 5 "tools/integral.hpp"

namespace tools {
  // Concept satisfied by every T for which tools::is_integral_v<T> is true.
  template <typename T>
  concept integral = tools::is_integral_v<T>;
}


#line 7 "tools/non_bool_integral.hpp"

namespace tools {
  // Concept satisfied by tools::integral types other than (possibly
  // cv-qualified) bool.
  template <typename T>
  concept non_bool_integral = tools::integral<T> && !std::same_as<std::remove_cv_t<T>, bool>;
}


#line 1 "tools/is_unsigned.hpp"



#line 5 "tools/is_unsigned.hpp"

namespace tools {
  // Trait reporting whether T counts as an unsigned type. The primary template
  // inherits its answer from std::is_unsigned<T>.
  template <typename T>
  struct is_unsigned : std::is_unsigned<T> {};

  // Shorthand for tools::is_unsigned<T>::value.
  template <typename T>
  inline constexpr bool is_unsigned_v = tools::is_unsigned<T>::value;
}


#line 8 "tools/ceil.hpp"

namespace tools {
  // Integer division rounded towards positive infinity: ceil(x / y).
  // Requires y != 0. The result has type std::common_type_t<M, N>; when that
  // type is unsigned and the quotient would not be positive, 0 is returned.
  template <tools::non_bool_integral M, tools::non_bool_integral N>
  constexpr std::common_type_t<M, N> ceil(const M x, const N y) noexcept {
    assert(y != 0);

    // True when x and y are both positive or both negative, i.e. x / y > 0
    const bool quotient_positive = (y >= 0) ? (x > 0) : !(x >= 0);
    if (quotient_positive) {
      // Step x one unit towards zero first so that truncation followed by +1
      // rounds up without overshooting exact multiples.
      return (y >= 0) ? (x - 1) / y + 1 : (x + 1) / y + 1;
    }

    // Non-positive quotient: truncation towards zero already rounds up
    if constexpr (tools::is_unsigned_v<std::common_type_t<M, N>>) {
      return 0;
    } else {
      return x / y;
    }
  }
}


#line 7 "tools/block_ceil.hpp"

namespace tools {
  // Rounds x up to a multiple of y: returns tools::ceil(x, y) * y.
  // Requires y > 0.
  template <typename M, typename N>
  constexpr std::common_type_t<M, N> block_ceil(const M x, const N y) noexcept {
    assert(y > 0);
    const auto blocks = tools::ceil(x, y);
    return blocks * y;
  }
}


#line 1 "tools/floor_sqrt.hpp"



#line 5 "tools/floor_sqrt.hpp"

namespace tools {

  // Returns the largest integer s with s * s <= n. Requires n >= 0.
  // Every comparison is written as `s <= n / s`, so s * s is never computed.
  template <typename T>
  T floor_sqrt(const T n) {
    assert(n >= 0);

    // Grow the upper bound by doubling until hi * hi > n
    T hi = 1;
    while (hi <= n / hi) {
      hi *= 2;
    }

    // Binary search keeping lo * lo <= n < hi * hi
    T lo = 0;
    while (hi - lo > 1) {
      const T probe = lo + (hi - lo) / 2;
      (probe <= n / probe ? lo : hi) = probe;
    }

    return lo;
  }
}


#line 14 "tools/segmented_sieve.hpp"

namespace tools {
  // Factorisation tables for two ranges at once: the "small" range
  // [1, sqrt_r] with sqrt_r = floor_sqrt(r), and the "large" range [l, r].
  // Queries (is_prime, prime factors, divisors, ...) accept any n in either range.
  //
  // Index encoding shared by the iterators below: a non-negative index refers
  // to the small tables, a negative index i refers to entry ~i of the large
  // tables, and the index 1 marks the end of a factor chain.
  class segmented_sieve {
    // Primes found in [2, sqrt_r], in increasing order.
    std::vector<int> m_small_primes;
    // m_small_factors[n] = (p, q, p^q) where p is the smallest prime factor of
    // n and q its exponent in n. Entries 0 and 1 stay value-initialized.
    std::vector<std::tuple<int, int, int>> m_small_factors;
    // Values n > 1 in [l, r] that no small prime divides, in increasing order.
    std::vector<long long> m_large_primes;
    // Factor records (p, q, p^q, next). Slot n - l holds the first record of n;
    // `next` is the index of the following record of the same n (appended past
    // slot r - l), or ~1 when there is none.
    std::vector<std::tuple<long long, int, long long, int>> m_large_factors;
    long long m_l;
    long long m_r;

  public:
    // View over the prime factors of one number, repeated according to
    // multiplicity.
    class prime_factor_view : public std::ranges::view_interface<prime_factor_view> {
      tools::segmented_sieve const *m_parent;
      long long m_n;

    public:
      class iterator {
        tools::segmented_sieve const *m_parent;
        // Current position (see the index encoding on segmented_sieve). On the
        // small side this is the cofactor still to be factorised.
        int m_i;
        // How many copies of the current prime have already been yielded.
        int m_j;

      public:
        using difference_type = std::ptrdiff_t;
        using value_type = long long;
        using reference = long long;
        using pointer = const long long*;
        using iterator_category = std::input_iterator_tag;

        iterator() = default;
        iterator(tools::segmented_sieve const * const parent, const int i, const int j) : m_parent(parent), m_i(i), m_j(j) {
        }

        // Yields the prime stored in the current record.
        reference operator*() const {
          if (this->m_i >= 0) {
            return std::get<0>(this->m_parent->m_small_factors[this->m_i]);
          } else {
            return std::get<0>(this->m_parent->m_large_factors[~this->m_i]);
          }
        }
        // Yields the same prime until its exponent is used up, then moves to
        // the next record: divide by p^q on the small side, follow the `next`
        // link on the large side.
        iterator& operator++() {
          if (this->m_i >= 0) {
            ++this->m_j;
            if (this->m_j >= std::get<1>(this->m_parent->m_small_factors[this->m_i])) {
              this->m_i /= std::get<2>(this->m_parent->m_small_factors[this->m_i]);
              this->m_j = 0;
            }
          } else {
            ++this->m_j;
            if (this->m_j >= std::get<1>(this->m_parent->m_large_factors[~this->m_i])) {
              // A stored link of ~1 turns into index 1, which equals end()
              this->m_i = ~std::get<3>(this->m_parent->m_large_factors[~this->m_i]);
              this->m_j = 0;
            }
          }
          return *this;
        }
        iterator operator++(int) {
          const auto self = *this;
          ++*this;
          return self;
        }
        // Iterators may only be compared when they belong to the same sieve.
        friend bool operator==(const iterator lhs, const iterator rhs) {
          assert(lhs.m_parent == rhs.m_parent);
          return lhs.m_i == rhs.m_i && lhs.m_j == rhs.m_j;
        }
        friend bool operator!=(const iterator lhs, const iterator rhs) {
          return !(lhs == rhs);
        }
      };

      prime_factor_view() = default;
      prime_factor_view(tools::segmented_sieve const * const parent, const long long n) : m_parent(parent), m_n(n) {
      }

      // Starts in the small table when n <= sqrt_r, otherwise at slot n - l of
      // the large table (stored as its bitwise complement).
      iterator begin() const {
        return iterator(this->m_parent, this->m_n <= this->m_parent->sqrt_r() ? this->m_n : ~(this->m_n - this->m_parent->m_l), 0);
      };
      // Index 1 is the common end position for both tables.
      iterator end() const {
        return iterator(this->m_parent, 1, 0);
      }
    };

    // View over the distinct prime factors of one number; each element is the
    // tuple (p, q, p^q).
    class distinct_prime_factor_view : public std::ranges::view_interface<distinct_prime_factor_view> {
      tools::segmented_sieve const *m_parent;
      long long m_n;

    public:
      class iterator {
        tools::segmented_sieve const *m_parent;
        // Current position (see the index encoding on segmented_sieve).
        int m_i;

      public:
        using difference_type = std::ptrdiff_t;
        using value_type = std::tuple<long long, long long, long long>;
        using reference = std::tuple<long long, long long, long long>;
        using pointer = const std::tuple<long long, long long, long long>*;
        using iterator_category = std::input_iterator_tag;

        iterator() = default;
        iterator(tools::segmented_sieve const * const parent, const int i) : m_parent(parent), m_i(i) {
        }

        // Yields (p, q, p^q) for the current record; the `next` link of a
        // large record is dropped.
        reference operator*() const {
          if (this->m_i >= 0) {
            return this->m_parent->m_small_factors[this->m_i];
          } else {
            [[maybe_unused]] const auto& [p, q, pq, next_i] = this->m_parent->m_large_factors[~this->m_i];
            return value_type(p, q, pq);
          }
        }
        // Moves to the next distinct prime: divide by p^q on the small side,
        // follow the `next` link on the large side.
        iterator& operator++() {
          if (this->m_i >= 0) {
            this->m_i /= std::get<2>(this->m_parent->m_small_factors[this->m_i]);
          } else {
            this->m_i = ~std::get<3>(this->m_parent->m_large_factors[~this->m_i]);
          }
          return *this;
        }
        iterator operator++(int) {
          const auto self = *this;
          ++*this;
          return self;
        }
        // Iterators may only be compared when they belong to the same sieve.
        friend bool operator==(const iterator lhs, const iterator rhs) {
          assert(lhs.m_parent == rhs.m_parent);
          return lhs.m_i == rhs.m_i;
        }
        friend bool operator!=(const iterator lhs, const iterator rhs) {
          return !(lhs == rhs);
        }
      };

      distinct_prime_factor_view() = default;
      distinct_prime_factor_view(tools::segmented_sieve const * const parent, const long long n) : m_parent(parent), m_n(n) {
      }

      // Starts in the small table when n <= sqrt_r, otherwise at slot n - l of
      // the large table (stored as its bitwise complement).
      iterator begin() const {
        return iterator(this->m_parent, this->m_n <= this->m_parent->sqrt_r() ? this->m_n : ~(this->m_n - this->m_parent->m_l));
      };
      // Index 1 is the common end position for both tables.
      iterator end() const {
        return iterator(this->m_parent, 1);
      }
    };

    // View over the primes in a closed interval, walking m_small_primes first
    // and m_large_primes afterwards.
    class prime_view : public std::ranges::view_interface<prime_view> {
      tools::segmented_sieve const *m_parent;
      // Encoded start and end positions: a non-negative value indexes
      // m_small_primes, a negative value i refers to m_large_primes[~i].
      int m_begin;
      int m_end;

    public:
      class iterator {
        tools::segmented_sieve const *m_parent;
        // Encoded position, same convention as m_begin / m_end.
        int m_i;

      public:
        using difference_type = std::ptrdiff_t;
        using value_type = long long;
        using reference = long long;
        using pointer = const long long*;
        using iterator_category = std::input_iterator_tag;

        iterator() = default;
        iterator(tools::segmented_sieve const * const parent, const int i) : m_parent(parent), m_i(i) {
        }

        reference operator*() const {
          if (this->m_i >= 0) {
            return this->m_parent->m_small_primes[this->m_i];
          } else {
            return this->m_parent->m_large_primes[~this->m_i];
          }
        }
        iterator& operator++() {
          if (this->m_i >= 0) {
            ++this->m_i;
            // Past the last small prime: continue at the first large prime
            // that is greater than sqrt_r.
            if (this->m_i >= std::ssize(this->m_parent->m_small_primes)) {
              this->m_i = ~std::distance(this->m_parent->m_large_primes.begin(), std::ranges::upper_bound(this->m_parent->m_large_primes, this->m_parent->sqrt_r()));
            }
          } else {
            // Decrementing the complemented value advances the real index by one
            --this->m_i;
          }
          return *this;
        }
        iterator operator++(int) {
          const auto self = *this;
          ++*this;
          return self;
        }
        // Iterators may only be compared when they belong to the same sieve.
        friend bool operator==(const iterator lhs, const iterator rhs) {
          assert(lhs.m_parent == rhs.m_parent);
          return lhs.m_i == rhs.m_i;
        }
        friend bool operator!=(const iterator lhs, const iterator rhs) {
          return !(lhs == rhs);
        }
      };

      prime_view() = default;
      // m_begin: position of the first prime >= l, taken from the small list
      // when l does not exceed the largest small prime, else from the large list.
      // m_end: position of the first prime > r, taken from the small list when
      // r is below the largest small prime, else from the large list.
      prime_view(tools::segmented_sieve const * const parent, const long long l, const long long r) :
        m_parent(parent),
        m_begin(
          !parent->m_small_primes.empty() && l <= parent->m_small_primes.back()
            ? std::distance(parent->m_small_primes.begin(), std::ranges::lower_bound(parent->m_small_primes, l))
            : ~std::distance(parent->m_large_primes.begin(), std::ranges::lower_bound(parent->m_large_primes, l))
        ),
        m_end(
          !parent->m_small_primes.empty() && r < parent->m_small_primes.back()
            ? std::distance(parent->m_small_primes.begin(), std::ranges::upper_bound(parent->m_small_primes, r))
            : ~std::distance(parent->m_large_primes.begin(), std::ranges::upper_bound(parent->m_large_primes, r))
        ) {
      }

      iterator begin() const {
        return iterator(this->m_parent, this->m_begin);
      };
      iterator end() const {
        return iterator(this->m_parent, this->m_end);
      }
    };

    segmented_sieve() = default;
    // Builds the tables for [1, floor_sqrt(r)] and [l, r]. Requires 1 <= l <= r.
    // Note that the tables are sized in the initializer list, before the
    // assertion in the body runs.
    segmented_sieve(const long long l, const long long r) : m_small_factors(tools::floor_sqrt(r) + 1), m_large_factors(r - l + 1), m_l(l), m_r(r) {
      assert(1 <= l && l <= r);

      // Small range: n is prime when its record is still empty on arrival.
      // Then, for every small prime *it not exceeding the smallest prime factor
      // of n, fill in the record of n * *it, whose smallest prime factor is *it.
      for (int n = 2; n <= this->sqrt_r(); ++n) {
        if (!std::get<0>(this->m_small_factors[n])) {
          this->m_small_primes.push_back(n);
          this->m_small_factors[n] = {n, 1, n};
        }
        for (auto it = this->m_small_primes.begin(); it != this->m_small_primes.end() && *it <= std::get<0>(this->m_small_factors[n]) && n * *it <= this->sqrt_r(); ++it) {
          std::get<0>(this->m_small_factors[n * *it]) = *it;
          if (*it < std::get<0>(this->m_small_factors[n])) {
            // *it does not divide n, so it appears exactly once in n * *it
            std::get<1>(this->m_small_factors[n * *it]) = 1;
            std::get<2>(this->m_small_factors[n * *it]) = *it;
          } else {
            // *it is already the smallest prime factor of n: extend its power
            std::get<1>(this->m_small_factors[n * *it]) = std::get<1>(this->m_small_factors[n]) + 1;
            std::get<2>(this->m_small_factors[n * *it]) = std::get<2>(this->m_small_factors[n]) * *it;
          }
        }
      }

      // rem[n - l]: the part of n that has not been divided out yet.
      std::vector<long long> rem(r - l + 1);
      for (long long n = l; n <= r; ++n) {
        rem[n - l] = n;
      }
      // last[n - l]: index of the most recent record of n, or -1 if none yet.
      std::vector<int> last(r - l + 1, -1);

      // Large range: for each small prime p, visit the multiples of p in [l, r]
      // and record p together with its full exponent.
      for (const auto p : this->m_small_primes) {
        for (long long n = tools::block_ceil(l, p); n <= r; n += p) {
          int curr;
          if (last[n - l] >= 0) {
            // Not the first factor of n: append a record and link it from the
            // previous one
            curr = this->m_large_factors.size();
            this->m_large_factors.emplace_back(p, 0, 1, ~1);
            std::get<3>(this->m_large_factors[last[n - l]]) = curr;
          } else {
            // First factor of n: use the slot reserved for n
            curr = n - l;
            this->m_large_factors[curr] = {p, 0, 1, ~1};
          }
          do {
            rem[n - l] /= p;
            ++std::get<1>(this->m_large_factors[curr]);
            std::get<2>(this->m_large_factors[curr]) *= p;
          } while (rem[n - l] % p == 0);
          last[n - l] = curr;
        }
      }
      for (long long n = l; n <= r; ++n) {
        if (last[n - l] >= 0) {
          // A leftover cofactor above 1 becomes one more record with exponent 1
          if (rem[n - l] > 1) {
            std::get<3>(this->m_large_factors[last[n - l]]) = this->m_large_factors.size();
            this->m_large_factors.emplace_back(rem[n - l], 1, rem[n - l], ~1);
          }
        } else {
          // No small prime divides n: n > 1 is recorded as a large prime
          if (n > 1) {
            this->m_large_primes.push_back(n);
            this->m_large_factors[n - l] = {n, 1, n, ~1};
          }
        }
      }
    }

    // Upper end of the small range; the small table has sqrt_r() + 1 entries.
    long long sqrt_r() const {
      return this->m_small_factors.size() - 1;
    }

    // Lower end of the large range, as passed to the constructor.
    long long l() const {
      return this->m_l;
    }

    // Upper end of the large range, as passed to the constructor.
    long long r() const {
      return this->m_r;
    }

    // Returns whether n is prime. Requires n in [1, sqrt_r] or in [l, r].
    bool is_prime(const long long n) const {
      assert((1 <= n && n <= this->sqrt_r()) || (this->m_l <= n && n <= this->m_r));
      if (n <= this->sqrt_r()) {
        return n >= 2 && std::get<0>(this->m_small_factors[n]) == n;
      } else {
        return std::get<0>(this->m_large_factors[n - this->m_l]) == n;
      }
    }

    // Returns a view over the prime factors of n, repeated by multiplicity.
    // Requires n in [1, sqrt_r] or in [l, r].
    prime_factor_view prime_factor_range(const long long n) const {
      assert((1 <= n && n <= this->sqrt_r()) || (this->m_l <= n && n <= this->m_r));
      return prime_factor_view(this, n);
    }

    // Returns a view over (p, q, p^q) for each distinct prime factor p of n.
    // Requires n in [1, sqrt_r] or in [l, r].
    distinct_prime_factor_view distinct_prime_factor_range(const long long n) const {
      assert((1 <= n && n <= this->sqrt_r()) || (this->m_l <= n && n <= this->m_r));
      return distinct_prime_factor_view(this, n);
    }

    // Returns a view over the primes in [l, r] (the arguments, not the members).
    // When the small and large ranges leave a gap (sqrt_r + 1 < m_l), the query
    // must lie entirely inside [1, sqrt_r] or entirely inside [m_l, m_r];
    // otherwise any 1 <= l <= r <= m_r is accepted.
    prime_view prime_range(const long long l, const long long r) const {
      #ifndef NDEBUG
      if (this->sqrt_r() + 1 < this->l()) {
        assert((1 <= l && l <= r && r <= this->sqrt_r()) || (this->m_l <= l && l <= r && r <= this->m_r));
      } else {
        assert(1 <= l && l <= r && r <= this->m_r);
      }
      #endif
      return prime_view(this, l, r);
    }

    // Returns all positive divisors of n, in no particular sorted order.
    // Requires n in [1, sqrt_r] or in [l, r].
    std::vector<long long> divisors(const long long n) const {
      assert((1 <= n && n <= this->sqrt_r()) || (this->m_l <= n && n <= this->m_r));

      std::vector<long long> D{1};
      for ([[maybe_unused]] const auto& [p, q, pq] : this->distinct_prime_factor_range(n)) {
        // Multiply every divisor found before this prime by p, p^2, ..., p^q
        const int end = D.size();
        for (long long e = 1, pe = 1; e <= q; ++e) {
          pe *= p;
          for (int i = 0; i < end; ++i) {
            D.push_back(D[i] * pe);
          }
        }
      }

      return D;
    }

    // Same as divisors(n), sorted in ascending order.
    std::vector<long long> sorted_divisors(const long long n) const {
      auto D = this->divisors(n);
      std::ranges::sort(D);
      return D;
    }

    // Returns divisor counts as (small, large): small[i] for i in [0, sqrt_r]
    // (small[0] is 0) and large[n - l] for n in [l, r].
    std::pair<std::vector<long long>, std::vector<long long>> divisor_counts() const {
      // dp[i] = (exponent of the smallest prime factor of i, plus one;
      //          divisor count of i with that prime removed entirely)
      std::vector<std::pair<int, int>> dp(this->sqrt_r() + 1);
      dp[0] = std::make_pair(0, 0);
      dp[1] = std::make_pair(1, 1);
      for (int i = 2; i <= this->sqrt_r(); ++i) {
        const auto& prev = dp[i / std::get<0>(this->m_small_factors[i])];
        if (std::get<0>(this->m_small_factors[i / std::get<0>(this->m_small_factors[i])]) == std::get<0>(this->m_small_factors[i])) {
          // Same smallest prime as the quotient: its exponent grows by one
          dp[i] = std::make_pair(prev.first + 1, prev.second);
        } else {
          // New smallest prime with exponent 1; the quotient is the rest
          dp[i] = std::make_pair(2, prev.first * prev.second);
        }
      }

      std::vector<long long> small(this->sqrt_r() + 1);
      for (int i = 0; i <= this->sqrt_r(); ++i) {
        small[i] = dp[i].first * dp[i].second;
      }

      // Large range: product of (q + 1) over the distinct prime factors
      std::vector<long long> large(this->m_r - this->m_l + 1);
      for (long long n = this->m_l; n <= this->m_r; ++n) {
        large[n - this->m_l] = 1;
        for ([[maybe_unused]] const auto& [p, q, pq] : this->distinct_prime_factor_range(n)) {
          large[n - this->m_l] *= q + 1;
        }
      }

      return {small, large};
    }
  };
}


#line 9 "tests/segmented_sieve/comprehensive.test.cpp"

using ll = long long;

// Reference primality test by trial division.
// Returns true exactly when n >= 2 and no d with d * d <= n divides n.
bool naive_is_prime(const ll n) {
  bool prime = n >= 2;
  for (ll d = 2; prime && d * d <= n; ++d) {
    prime = n % d != 0;
  }
  return prime;
}

// Reference factorisation by trial division.
// Returns the prime factors of n with multiplicity, in ascending order.
std::vector<ll> naive_prime_factors(ll n) {
  std::vector<ll> factors;
  for (ll d = 2; d * d <= n; ++d) {
    // Strip every copy of d before moving on to the next candidate
    for (; n % d == 0; n /= d) {
      factors.push_back(d);
    }
  }
  // Whatever is left above 1 is a single remaining factor
  if (n > 1) {
    factors.push_back(n);
  }
  return factors;
}

// Reference factorisation by trial division.
// Returns a map from each distinct prime factor of n to its exponent.
std::map<ll, ll> naive_distinct_prime_factors(ll n) {
  std::map<ll, ll> exponents;
  for (ll d = 2; d * d <= n; ++d) {
    if (n % d != 0) continue;
    ll& e = exponents[d];
    do {
      ++e;
      n /= d;
    } while (n % d == 0);
  }
  // Whatever is left above 1 is one more factor with exponent 1
  if (n > 1) {
    exponents[n] += 1;
  }
  return exponents;
}

std::vector<ll> naive_divisors(const ll n) {
  std::vector<ll> result;
  for (ll d = 1; d * d <= n; ++d) {
    if (n % d == 0) {
      result.push_back(d);
      if (d != n / d) result.push_back(n / d);
    }
  }
  std::ranges::sort(result);
  return result;
}

// Reference divisor count by trial division.
// Returns the number of positive divisors of n, or 0 when n <= 0.
ll naive_divisor_count(const ll n) {
  ll count = 0;
  if (n > 0) {
    for (ll d = 1; d * d <= n; ++d) {
      if (n % d != 0) continue;
      // d itself, plus its cofactor when that is a different number
      count += 1;
      if (d != n / d) count += 1;
    }
  }
  return count;
}

void test_sieve(const ll L, const ll R) {
  tools::segmented_sieve sieve(L, R);
  const ll sq = sieve.sqrt_r();

  // sqrt_r, l, r
  assert_that(sq * sq <= R && (sq + 1) * (sq + 1) > R);
  assert_that(sieve.l() == L);
  assert_that(sieve.r() == R);

  // is_prime: small range [1, sqrt_r]
  for (ll n = 1; n <= sq; ++n) {
    assert_that(sieve.is_prime(n) == naive_is_prime(n));
  }

  // is_prime: large range [L, R]
  for (ll n = L; n <= R; ++n) {
    assert_that(sieve.is_prime(n) == naive_is_prime(n));
  }

  // prime_factor_range: small range
  for (ll n = 1; n <= sq; ++n) {
    std::vector<ll> got(sieve.prime_factor_range(n).begin(), sieve.prime_factor_range(n).end());
    std::vector<ll> expected = naive_prime_factors(n);
    std::ranges::sort(got);
    assert_that(got == expected);
  }

  // prime_factor_range: large range
  for (ll n = L; n <= R; ++n) {
    std::vector<ll> got(sieve.prime_factor_range(n).begin(), sieve.prime_factor_range(n).end());
    std::vector<ll> expected = naive_prime_factors(n);
    std::ranges::sort(got);
    assert_that(got == expected);
  }

  // distinct_prime_factor_range: small range
  for (ll n = 1; n <= sq; ++n) {
    std::map<ll, ll> expected = naive_distinct_prime_factors(n);
    std::map<ll, ll> got;
    for (const auto& [p, q, pq] : sieve.distinct_prime_factor_range(n)) {
      got[p] = q;
      ll power = 1;
      for (ll i = 0; i < q; ++i) power *= p;
      assert_that(pq == power);
    }
    assert_that(got == expected);
  }

  // distinct_prime_factor_range: large range
  for (ll n = L; n <= R; ++n) {
    std::map<ll, ll> expected = naive_distinct_prime_factors(n);
    std::map<ll, ll> got;
    for (const auto& [p, q, pq] : sieve.distinct_prime_factor_range(n)) {
      got[p] = q;
      ll power = 1;
      for (ll i = 0; i < q; ++i) power *= p;
      assert_that(pq == power);
    }
    assert_that(got == expected);
  }

  // divisors and sorted_divisors: small range
  for (ll n = 1; n <= sq; ++n) {
    std::vector<ll> expected = naive_divisors(n);
    assert_that(sieve.sorted_divisors(n) == expected);
    std::vector<ll> unsorted = sieve.divisors(n);
    std::ranges::sort(unsorted);
    assert_that(unsorted == expected);
  }

  // divisors and sorted_divisors: large range
  for (ll n = L; n <= R; ++n) {
    std::vector<ll> expected = naive_divisors(n);
    assert_that(sieve.sorted_divisors(n) == expected);
    std::vector<ll> unsorted = sieve.divisors(n);
    std::ranges::sort(unsorted);
    assert_that(unsorted == expected);
  }

  // divisor_counts
  {
    auto [small_dc, large_dc] = sieve.divisor_counts();
    assert_that(static_cast<ll>(small_dc.size()) == sq + 1);
    assert_that(static_cast<ll>(large_dc.size()) == R - L + 1);
    for (ll i = 0; i <= sq; ++i) {
      assert_that(small_dc[i] == naive_divisor_count(i));
    }
    for (ll n = L; n <= R; ++n) {
      assert_that(large_dc[n - L] == naive_divisor_count(n));
    }
  }
}

void test_prime_range(const ll L, const ll R) {
  tools::segmented_sieve sieve(L, R);
  const ll sq = sieve.sqrt_r();
  const bool overlapping = sq + 1 >= L;

  // Helper to get expected primes in [lo, hi]
  auto expected_primes = [&](ll lo, ll hi) {
    std::vector<ll> result;
    for (ll n = lo; n <= hi; ++n) {
      if (naive_is_prime(n)) result.push_back(n);
    }
    return result;
  };

  if (overlapping) {
    // Entirely small
    if (sq >= 2) {
      std::vector<ll> got;
      for (const auto p : sieve.prime_range(1, std::min(sq, R))) got.push_back(p);
      assert_that(got == expected_primes(1, std::min(sq, R)));
    }

    // Entirely large (if range extends beyond sqrt_r)
    if (R > sq) {
      std::vector<ll> got;
      for (const auto p : sieve.prime_range(sq + 1, R)) got.push_back(p);
      assert_that(got == expected_primes(sq + 1, R));
    }

    // Cross-boundary
    if (sq >= 2 && R > sq) {
      std::vector<ll> got;
      for (const auto p : sieve.prime_range(2, R)) got.push_back(p);
      assert_that(got == expected_primes(2, R));
    }

    // Full range
    {
      std::vector<ll> got;
      for (const auto p : sieve.prime_range(1, R)) got.push_back(p);
      assert_that(got == expected_primes(1, R));
    }

    // Subranges around sqrt_r boundary
    if (sq >= 3 && R > sq + 2) {
      std::vector<ll> got;
      for (const auto p : sieve.prime_range(sq - 2, sq + 3)) got.push_back(p);
      assert_that(got == expected_primes(sq - 2, sq + 3));
    }
  } else {
    // Disjoint ranges: can only query within [1, sqrt_r] or [L, R]
    if (sq >= 2) {
      std::vector<ll> got;
      for (const auto p : sieve.prime_range(1, sq)) got.push_back(p);
      assert_that(got == expected_primes(1, sq));
    }
    {
      std::vector<ll> got;
      for (const auto p : sieve.prime_range(L, R)) got.push_back(p);
      assert_that(got == expected_primes(L, R));
    }
  }
}

int main() {
  std::cin.tie(nullptr);
  std::ios_base::sync_with_stdio(false);

  // Edge cases: very small R (small_primes may be empty)
  test_sieve(1, 1);
  test_sieve(1, 2);
  test_sieve(1, 3);
  test_sieve(1, 4);

  // L=1, moderate R (overlapping ranges)
  test_sieve(1, 100);
  test_sieve(1, 200);

  // L>1, still overlapping with small range
  test_sieve(2, 100);
  test_sieve(10, 200);

  // L > sqrt_r (disjoint ranges)
  test_sieve(50, 60);
  test_sieve(9990, 10010);

  // Large values
  test_sieve(999999999990LL, 1000000000000LL);

  // Single element ranges
  test_sieve(1, 1);
  test_sieve(7, 7);
  test_sieve(97, 97);
  test_sieve(100, 100);
  test_sieve(999999999989LL, 999999999989LL);

  // prime_range tests (separate because of constraint differences)
  test_prime_range(1, 1);
  test_prime_range(1, 2);
  test_prime_range(1, 3);
  test_prime_range(1, 4);
  test_prime_range(1, 100);
  test_prime_range(1, 200);
  test_prime_range(2, 100);
  test_prime_range(10, 200);
  test_prime_range(50, 60);
  test_prime_range(9990, 10010);
  test_prime_range(999999999990LL, 1000000000000LL);

  return 0;
}
Back to top page