proconlib

This documentation is automatically generated by competitive-verifier/competitive-verifier

View the Project on GitHub anqooqie/proconlib

:heavy_check_mark: tests/auxiliary_tree.test.cpp

Depends on

Code

// competitive-verifier: PROBLEM https://onlinejudge.u-aizu.ac.jp/problems/0439

#include <iostream>
#include <vector>
#include <utility>
#include "tools/auxiliary_tree.hpp"
#include "tools/rerooting_dp.hpp"
#include "tools/monoid.hpp"

int main() {
  std::cin.tie(nullptr);
  std::ios_base::sync_with_stdio(false);

  int N;
  std::cin >> N;
  std::vector<int> c(N);
  for (auto&& c_i : c) {
    std::cin >> c_i;
    --c_i;
  }
  tools::auxiliary_tree tree(N);
  for (int i = 0; i < N - 1; ++i) {
    int s, t;
    std::cin >> s >> t;
    --s, --t;
    tree.add_edge(s, t);
  }
  tree.build(0);

  std::vector<std::vector<int>> huts(N);
  for (int v = 0; v < N; ++v) {
    huts[c[v]].push_back(v);
  }

  std::vector<int> answers(N);
  std::vector<int> tree2aux(N);
  for (int color = 0; color < N; ++color) {
    if (huts[color].empty()) continue;

    const auto aux = tree.query(huts[color].begin(), huts[color].end());
    std::vector<int> aux2tree(aux.size());
    {
      int aux_v = 0;
      for (const auto tree_v : aux.vertices()) {
        tree2aux[tree_v] = aux_v;
        aux2tree[aux_v] = tree_v;
        ++aux_v;
      }
    }

    std::vector<int> w;
    const auto f_ve = [&](const auto& v, const auto e) {
      return (v.second ? 0 : v.first) + w[e];
    };
    const auto f_ev = [&](const auto e, const auto v) {
      return std::make_pair(e, c[aux2tree[v]] == color);
    };
    tools::rerooting_dp<std::pair<int, bool>, tools::monoid::min<int>, decltype(f_ve), decltype(f_ev)> dp(aux.size(), f_ve, f_ev);
    for (const auto tree_v : aux.vertices()) {
      if (tree_v == aux.root()) continue;
      dp.add_edge(tree2aux[tree_v], tree2aux[aux.parent(tree_v)]);
      const auto lca = tree.lca(tree_v, aux.parent(tree_v));
      w.push_back(tree.depth(tree_v) + tree.depth(aux.parent(tree_v)) - 2 * tree.depth(lca));
    }
    const auto partial_answers = dp.query();

    for (int aux_v = 0; std::cmp_less(aux_v, aux.size()); ++aux_v) {
      if (partial_answers[aux_v].second) {
        answers[aux2tree[aux_v]] = partial_answers[aux_v].first;
      }
    }
  }

  for (const auto answer : answers) {
    std::cout << answer << '\n';
  }
  return 0;
}
#line 1 "tests/auxiliary_tree.test.cpp"
// competitive-verifier: PROBLEM https://onlinejudge.u-aizu.ac.jp/problems/0439

#include <iostream>
#include <vector>
#include <utility>
#line 1 "tools/auxiliary_tree.hpp"



#include <cstddef>
#line 7 "tools/auxiliary_tree.hpp"
#include <algorithm>
#include <stack>
#include <limits>
#include <iterator>
#include <type_traits>
#line 1 "tools/lca.hpp"



#include <cstdint>
#line 7 "tools/lca.hpp"
#include <cassert>
#include <numeric>
#line 13 "tools/lca.hpp"
#include <tuple>
#line 1 "tools/ceil.hpp"



#line 1 "tools/is_integral.hpp"



#line 5 "tools/is_integral.hpp"

namespace tools {
  // Library-level integral trait; its value is inherited from ::std::is_integral<T>.
  template <class T>
  struct is_integral : ::std::is_integral<T> {};

  // Shorthand for ::tools::is_integral<T>::value.
  template <class T>
  inline constexpr bool is_integral_v = ::tools::is_integral<T>::value;
}


#line 1 "tools/is_unsigned.hpp"



#line 5 "tools/is_unsigned.hpp"

namespace tools {
  // Library-level unsignedness trait; its value is inherited from ::std::is_unsigned<T>.
  template <class T>
  struct is_unsigned : ::std::is_unsigned<T> {};

  // Shorthand for ::tools::is_unsigned<T>::value.
  template <class T>
  inline constexpr bool is_unsigned_v = ::tools::is_unsigned<T>::value;
}


#line 8 "tools/ceil.hpp"

namespace tools {
  // Returns x / y rounded toward positive infinity for non-bool integral operands.
  // y must be non-zero (checked with assert). The result has the common type of
  // M and N; when that type is unsigned, a quotient that is not positive yields 0.
  template <typename M, typename N> requires (
    ::tools::is_integral_v<M> && !::std::is_same_v<::std::remove_cv_t<M>, bool> &&
    ::tools::is_integral_v<N> && !::std::is_same_v<::std::remove_cv_t<N>, bool>)
  constexpr ::std::common_type_t<M, N> ceil(const M x, const N y) noexcept {
    assert(y != 0);

    // Both operands strictly positive: truncation rounds down, so shift by one.
    if (y >= 0 && x > 0) {
      return (x - 1) / y + 1;
    }
    // Both operands strictly negative: the quotient is positive again.
    if (y < 0 && x < 0) {
      return (x + 1) / y + 1;
    }

    // Remaining cases have a quotient <= 0, where truncating division already rounds up.
    if constexpr (::tools::is_unsigned_v<::std::common_type_t<M, N>>) {
      return 0;
    } else {
      return x / y;
    }
  }
}


#line 1 "tools/less_by.hpp"



namespace tools {

  // Comparator that orders values by the key a user-supplied callable extracts
  // from them: less_by(f)(a, b) is f(a) < f(b).
  template <class F>
  class less_by {
    F m_key_of;

  public:
    // Stores a copy of the key-extraction callable.
    less_by(const F& selector) : m_key_of(selector) {
    }

    // True when the key of `lhs` compares less than the key of `rhs`.
    template <class T>
    bool operator()(const T& lhs, const T& rhs) const {
      return m_key_of(lhs) < m_key_of(rhs);
    }
  };
}


#line 1 "tools/ceil_log2.hpp"



#line 1 "tools/bit_width.hpp"



#include <bit>
#line 1 "tools/is_signed.hpp"



#line 5 "tools/is_signed.hpp"

namespace tools {
  // Library-level signedness trait; its value is inherited from ::std::is_signed<T>.
  template <class T>
  struct is_signed : ::std::is_signed<T> {};

  // Shorthand for ::tools::is_signed<T>::value.
  template <class T>
  inline constexpr bool is_signed_v = ::tools::is_signed<T>::value;
}


#line 1 "tools/make_unsigned.hpp"



#line 5 "tools/make_unsigned.hpp"

namespace tools {
  // Library-level counterpart of ::std::make_unsigned; the member `type` is inherited from it.
  template <class T>
  struct make_unsigned : ::std::make_unsigned<T> {};

  // Shorthand for ::tools::make_unsigned<T>::type.
  template <class T>
  using make_unsigned_t = typename ::tools::make_unsigned<T>::type;
}


#line 10 "tools/bit_width.hpp"

namespace tools {
  template <typename T>
  constexpr int bit_width(T) noexcept;

  template <typename T>
  constexpr int bit_width(const T x) noexcept {
    static_assert(::tools::is_integral_v<T> && !::std::is_same_v<::std::remove_cv_t<T>, bool>);
    if constexpr (::tools::is_signed_v<T>) {
      assert(x >= 0);
      return ::tools::bit_width<::tools::make_unsigned_t<T>>(x);
    } else {
      return ::std::bit_width(x);
    }
  }
}


#line 6 "tools/ceil_log2.hpp"

namespace tools {
  // Smallest k with 2^k >= x, computed as the bit width of x - 1.
  // x must be positive (checked with assert).
  template <typename T>
  constexpr T ceil_log2(T x) noexcept {
    assert(x > 0);
    const auto predecessor = x - 1;
    return ::tools::bit_width(predecessor);
  }
}


#line 1 "tools/floor_log2.hpp"



#line 6 "tools/floor_log2.hpp"

namespace tools {
  // Largest k with 2^k <= x, computed as the bit width of x minus one.
  // x must be positive (checked with assert).
  template <typename T>
  constexpr T floor_log2(T x) noexcept {
    assert(x > 0);
    const int width = ::tools::bit_width(x);
    return width - 1;
  }
}


#line 1 "tools/pow2.hpp"



#line 6 "tools/pow2.hpp"

namespace tools {

  // 2 raised to the power x for unsigned T: the value 1 shifted left by x.
  template <typename T, ::std::enable_if_t<::std::is_unsigned_v<T>, ::std::nullptr_t> = nullptr>
  constexpr T pow2(const T x) {
    constexpr T one = 1;
    return one << x;
  }

  // 2 raised to the power x for signed T. The shift is carried out in the
  // corresponding unsigned type and the result is converted back to T.
  template <typename T, ::std::enable_if_t<::std::is_signed_v<T>, ::std::nullptr_t> = nullptr>
  constexpr T pow2(const T x) {
    using unsigned_type = ::std::make_unsigned_t<T>;
    return static_cast<T>(static_cast<unsigned_type>(1) << static_cast<unsigned_type>(x));
  }
}


#line 19 "tools/lca.hpp"

namespace tools {
  // Lowest-common-ancestor queries on a rooted tree.
  //
  // Usage: construct with the vertex count, call add_edge() n - 1 times, call
  // build(root) once, then call depth() / query().
  //
  // build() records an Euler tour of the tree, splits it into blocks of
  // m_block_size entries, and prepares
  //   - a sparse table holding the shallowest vertex of runs of whole blocks, and
  //   - a lookup table answering "position of the shallowest entry" inside one
  //     block, keyed by the block's up/down pattern of consecutive depths.
  class lca {
    using u32 = ::std::uint32_t;
    // Adjacency lists of the undirected tree.
    ::std::vector<::std::vector<u32>> m_graph;
    // m_depth[v]: depth of v below the root; empty until build() has run.
    ::std::vector<u32> m_depth;
    // Euler tour: the vertex visited at each step.
    ::std::vector<u32> m_tour;
    // m_in[v]: index in m_tour of the first visit of v.
    ::std::vector<u32> m_in;
    u32 m_block_size;
    // m_sparse_table[h][b]: shallowest vertex among blocks [b, b + 2^h).
    ::std::vector<::std::vector<u32>> m_sparse_table;
    // m_lookup_table[p][l][r]: offset of the shallowest entry in [l, r) of a block with pattern p.
    ::std::vector<::std::vector<::std::vector<u32>>> m_lookup_table;
    // m_patterns[b]: bit i is set when the depth increases from entry i to entry i + 1 of block b.
    ::std::vector<u32> m_patterns;

    // build() fills m_depth, so a non-empty m_depth means the structure is ready.
    bool built() const {
      return !this->m_depth.empty();
    }

    // Number of blocks the tour is split into (the last one may be partial).
    u32 nblocks() const {
      return ::tools::ceil(this->m_tour.size(), this->m_block_size);
    }

    // Comparator ordering vertices by depth (shallower first).
    auto less_by_depth() const {
      return ::tools::less_by([&](const auto v) { return this->m_depth[v]; });
    }

  public:
    lca() = default;

    // Creates a tree with n isolated vertices (n >= 1); edges are added afterwards.
    explicit lca(const ::std::size_t n) : m_graph(n) {
      assert(n >= 1);
    }

    // Number of vertices.
    ::std::size_t size() const {
      return this->m_graph.size();
    }

    // Adds the undirected edge {u, v}. Only allowed before build().
    void add_edge(const ::std::size_t u, const ::std::size_t v) {
      assert(!this->built());
      assert(u < this->size());
      assert(v < this->size());
      assert(u != v);
      this->m_graph[u].push_back(v);
      this->m_graph[v].push_back(u);
    }

    // Roots the tree at r and precomputes every table used by query().
    // Must be called exactly once, after exactly size() - 1 edges were added.
    void build(const ::std::size_t r) {
      assert(!this->built());
      // An out-of-range root would index past the end of m_depth below.
      assert(r < this->size());
      assert(::std::accumulate(this->m_graph.begin(), this->m_graph.end(), static_cast<::std::size_t>(0), [](const auto sum, const auto& neighbors) { return sum + neighbors.size(); }) == 2 * (this->size() - 1));

      // max() marks "not visited yet".
      this->m_depth.assign(this->size(), ::std::numeric_limits<u32>::max());
      this->m_tour.resize(2 * this->size() - 1);
      this->m_in.resize(this->size());

      // Iterative DFS. Before descending into a child, the current vertex is
      // pushed again so that it reappears in the tour once the child's subtree is done.
      u32 t = 0;
      ::std::stack<::std::pair<u32, u32>> stack;
      stack.emplace(r, 0);
      while (!stack.empty()) {
        const auto [here, depth] = stack.top();
        stack.pop();
        this->m_tour[t] = here;
        if (this->m_depth[here] == ::std::numeric_limits<u32>::max()) {
          this->m_depth[here] = depth;
          this->m_in[here] = t;
          for (const auto next : this->m_graph[here]) {
            if (this->m_depth[next] == ::std::numeric_limits<u32>::max()) {
              stack.emplace(here, depth);
              stack.emplace(next, depth + 1);
            }
          }
        }
        ++t;
      }

      // Drop the final tour entry when the tree has more than one vertex.
      if (this->size() > 1) {
        this->m_tour.pop_back();
      }

      this->m_block_size = ::std::max<u32>(1, ::tools::ceil(::tools::ceil_log2(this->m_tour.size()), 2));
      this->m_sparse_table.resize(::tools::floor_log2(this->nblocks()) + 1);
      this->m_sparse_table[0].resize(this->nblocks());
      // Level 0: shallowest vertex of every complete block.
      for (u32 b = 0; (b + 1) * this->m_block_size <= this->m_tour.size(); ++b) {
        const auto l = b * this->m_block_size;
        const auto r = ::std::min<u32>(l + this->m_block_size, this->m_tour.size());
        this->m_sparse_table[0][b] = *::std::min_element(this->m_tour.begin() + l, this->m_tour.begin() + r, this->less_by_depth());
      }
      // Level h combines two level h - 1 entries that are 2^(h - 1) blocks apart.
      for (u32 h = 1; h < this->m_sparse_table.size(); ++h) {
        this->m_sparse_table[h].resize(this->nblocks() + UINT32_C(1) - (UINT32_C(1) << h));
        for (u32 b = 0; b < this->m_sparse_table[h].size(); ++b) {
          this->m_sparse_table[h][b] = ::std::min(this->m_sparse_table[h - 1][b], this->m_sparse_table[h - 1][b + (UINT32_C(1) << (h - 1))], this->less_by_depth());
        }
      }

      // One lookup table per possible up/down pattern of a block.
      this->m_lookup_table.resize(::tools::pow2(this->m_block_size - 1));
      for (u32 p = 0; p < this->m_lookup_table.size(); ++p) {
        this->m_lookup_table[p].resize(this->m_block_size + 1);
        for (u32 l = 0; l <= this->m_block_size; ++l) {
          this->m_lookup_table[p][l].resize(this->m_block_size + 1);
        }

        // Relative depths implied by pattern p: a set bit means +1, a clear bit -1.
        // Starting at m_block_size keeps the unsigned values from wrapping.
        ::std::vector<u32> partial_sum(this->m_block_size);
        partial_sum[0] = this->m_block_size;
        for (u32 i = 1; i < this->m_block_size; ++i) {
          partial_sum[i] = partial_sum[i - 1] - UINT32_C(1) + (((p >> (i - 1)) & UINT32_C(1)) << 1);
        }

        // Extend each range [l, r) one position at a time, keeping the offset of the minimum.
        for (u32 l = 0; l < this->m_block_size; ++l) {
          this->m_lookup_table[p][l][l + 1] = l;
          for (u32 r = l + 2; r <= this->m_block_size; ++r) {
            this->m_lookup_table[p][l][r] = ::std::min(this->m_lookup_table[p][l][r - 1], r - 1, ::tools::less_by([&](const auto i) { return partial_sum[i]; }));
          }
        }
      }

      // Encode the up/down pattern of every block, including a partial last one.
      this->m_patterns.resize(this->nblocks());
      for (u32 b = 0; b * this->m_block_size < this->m_tour.size(); ++b) {
        const auto l = b * this->m_block_size;
        const auto r = ::std::min<u32>(l + this->m_block_size, this->m_tour.size());
        this->m_patterns[b] = 0;
        for (u32 i = l; i + 1 < r; ++i) {
          this->m_patterns[b] |= static_cast<u32>(this->m_depth[this->m_tour[i]] < this->m_depth[this->m_tour[i + 1]]) << (i - l);
        }
      }
    }

    // Depth of v below the root. Requires build().
    ::std::size_t depth(const ::std::size_t v) const {
      assert(this->built());
      assert(v < this->size());
      return this->m_depth[v];
    }

    // Lowest common ancestor of u and v. Requires build().
    ::std::size_t query(::std::size_t u, ::std::size_t v) const {
      assert(this->built());
      assert(u < this->size());
      assert(v < this->size());

      // Order the pair so that u is the vertex first visited earlier in the tour.
      ::std::tie(u, v) = ::std::minmax({u, v}, ::tools::less_by([&](const auto w) { return this->m_in[w]; }));

      // The answer is the shallowest vertex in the tour range [l, r).
      const auto l = this->m_in[u];
      const auto r = this->m_in[v] + UINT32_C(1);
      // Whole blocks inside the range are [bl, br).
      const auto bl = ::tools::ceil(l, this->m_block_size);
      const auto br = r / this->m_block_size;
      u32 lca;
      if (br < bl) {
        // The range lies strictly inside block br: a single table lookup suffices.
        lca = this->m_tour[br * this->m_block_size + this->m_lookup_table[this->m_patterns[br]][l % this->m_block_size][r % this->m_block_size]];
      } else {
        lca = u;
        if (bl < br) {
          // Whole blocks: two overlapping sparse-table entries cover [bl, br).
          const auto h = ::tools::floor_log2(br - bl);
          lca = ::std::min(this->m_sparse_table[h][bl], this->m_sparse_table[h][br - (UINT32_C(1) << h)], this->less_by_depth());
        }
        if (l < bl * this->m_block_size) {
          // Partial block on the left: from l to the end of block bl - 1.
          lca = ::std::min(lca, this->m_tour[(bl - UINT32_C(1)) * this->m_block_size + this->m_lookup_table[this->m_patterns[bl - 1]][l % this->m_block_size][this->m_block_size]], this->less_by_depth());
        }
        if (br * this->m_block_size < r) {
          // Partial block on the right: from the start of block br up to r.
          lca = ::std::min(lca, this->m_tour[br * this->m_block_size + this->m_lookup_table[this->m_patterns[br]][0][r % this->m_block_size]], this->less_by_depth());
        }
      }

      return lca;
    }

    // for tools::auxiliary_tree
    // Index of the first visit of v in the Euler tour. Requires build().
    ::std::size_t internal_in(const ::std::size_t v) const {
      assert(this->built());
      assert(v < this->size());
      return this->m_in[v];
    }
  };
}


#line 1 "tools/less_by_first.hpp"



#line 5 "tools/less_by_first.hpp"

namespace tools {

  // Comparator for std::pair values that looks only at the `first` members
  // and compares them with operator<; the `second` members are ignored.
  class less_by_first {
  public:
    template <class T1, class T2>
    bool operator()(const ::std::pair<T1, T2>& lhs, const ::std::pair<T1, T2>& rhs) const {
      const T1& left_key = lhs.first;
      const T1& right_key = rhs.first;
      return left_key < right_key;
    }
  };
}


#line 15 "tools/auxiliary_tree.hpp"

namespace tools {
  // Builds auxiliary (compressed) trees of a rooted tree: for a set X of
  // vertices, the tree on X together with the lowest common ancestors pushed
  // during construction, in which every vertex's parent is one of its ancestors
  // in the original tree.
  //
  // Usage: construct with the vertex count, add_edge() n - 1 times, build(root)
  // once, then call query() as often as needed.
  class auxiliary_tree {
    ::tools::lca m_lca;

  public:
    auxiliary_tree() = default;

    // Creates a tree with n isolated vertices; edges are added afterwards.
    explicit auxiliary_tree(const ::std::size_t n) : m_lca(n) {
    }

    // Number of vertices of the original tree.
    ::std::size_t size() const {
      return this->m_lca.size();
    }

    // Adds the undirected edge {u, v} to the original tree.
    void add_edge(const ::std::size_t u, const ::std::size_t v) {
      this->m_lca.add_edge(u, v);
    }

    // Roots the original tree at r and prepares the LCA structure.
    void build(const ::std::size_t r) {
      this->m_lca.build(r);
    }

    // Depth of v in the original tree.
    ::std::size_t depth(const ::std::size_t v) const {
      return this->m_lca.depth(v);
    }

    // Lowest common ancestor of u and v in the original tree.
    ::std::size_t lca(const ::std::size_t u, const ::std::size_t v) const {
      return this->m_lca.query(u, v);
    }

    // One auxiliary tree. Vertices are identified by their ids in the original tree.
    class query_result {
      // (vertex, parent) pairs sorted by vertex; the root's parent is size_t max.
      ::std::vector<::std::pair<::std::size_t, ::std::size_t>> m_parent;
      // m_children[i]: children of the vertex stored at m_parent[i].
      ::std::vector<::std::vector<::std::size_t>> m_children;
      ::std::size_t m_root;

      // Builds the auxiliary tree of the vertices in [begin, end).
      // The range must not be empty; repeated vertices are counted once.
      template <typename InputIterator>
      query_result(const ::tools::auxiliary_tree& tree, const InputIterator begin, const InputIterator end) {
        ::std::vector<::std::size_t> X(begin, end);
        assert(!X.empty());
        ::std::sort(X.begin(), X.end(), ::tools::less_by([&](const auto v) { return tree.m_lca.internal_in(v); }));
        // Distinct vertices have distinct first-visit indices, so equal vertices are
        // adjacent after the sort. Without this step a repeated vertex would be
        // pushed twice below and end up recorded as its own parent.
        X.erase(::std::unique(X.begin(), X.end()), X.end());

        // The stack holds a chain of vertices of increasing depth whose parents
        // are not fixed yet; its bottom is the current candidate for the root.
        ::std::stack<::std::size_t> stack;
        auto it = X.begin();
        stack.push(*(it++));
        for (; it != X.end(); ++it) {
          const auto w = tree.lca(stack.top(), *it);
          // Everything deeper than w is finished. Its parent is the next stack
          // entry if that one is still deeper than w, and w otherwise.
          while (!stack.empty() && tree.depth(w) < tree.depth(stack.top())) {
            const auto u = stack.top();
            stack.pop();
            this->m_parent.emplace_back(u, w);
            if (!stack.empty() && tree.depth(w) < tree.depth(stack.top())) {
              this->m_parent.back().second = stack.top();
            }
          }
          if (stack.empty() || stack.top() != w) {
            stack.push(w);
          }
          stack.push(*it);
        }
        // Unwind the remaining chain; the bottom element is the root.
        while (!stack.empty()) {
          const auto u = stack.top();
          stack.pop();
          if (stack.empty()) {
            this->m_parent.emplace_back(u, ::std::numeric_limits<::std::size_t>::max());
            this->m_root = u;
          } else {
            this->m_parent.emplace_back(u, stack.top());
          }
        }

        // Sorting by vertex id lets parent() and children() use binary search.
        ::std::sort(this->m_parent.begin(), this->m_parent.end(), ::tools::less_by_first{});

        this->m_children.resize(this->m_parent.size());
        for (const auto& [v, p] : this->m_parent) {
          if (v != this->m_root) {
            const auto it = ::std::lower_bound(this->m_parent.begin(), this->m_parent.end(), ::std::make_pair(p, ::std::numeric_limits<::std::size_t>::max()), ::tools::less_by_first{});
            assert(it != this->m_parent.end());
            assert(it->first == p);
            this->m_children[::std::distance(this->m_parent.begin(), it)].push_back(v);
          }
        }
      }

    public:
      // Range over the vertices of the auxiliary tree, in increasing id order.
      class vertices_iterable {
        query_result const *m_qr = nullptr;

      public:
        // Input iterator yielding original-tree vertex ids.
        class iterator {
          // Initialized so that default-constructed iterators can be compared.
          query_result const *m_qr = nullptr;
          ::std::size_t m_i = 0;

        public:
          using difference_type = ::std::ptrdiff_t;
          using value_type = ::std::size_t;
          using reference = const ::std::size_t&;
          using pointer = const ::std::size_t*;
          using iterator_category = ::std::input_iterator_tag;

          iterator() = default;
          iterator(query_result const * const qr, const ::std::size_t i) : m_qr(qr), m_i(i) {
          }

          reference operator*() const {
            return this->m_qr->m_parent[this->m_i].first;
          }
          iterator& operator++() {
            ++this->m_i;
            return *this;
          }
          iterator operator++(int) {
            const auto self = *this;
            ++*this;
            return self;
          }
          // Only iterators into the same query_result may be compared.
          friend bool operator==(const iterator& lhs, const iterator& rhs) {
            assert(lhs.m_qr == rhs.m_qr);
            return lhs.m_i == rhs.m_i;
          }
          friend bool operator!=(const iterator& lhs, const iterator& rhs) {
            return !(lhs == rhs);
          }
        };

        vertices_iterable() = default;
        vertices_iterable(query_result const * const qr) : m_qr(qr) {
        }

        iterator begin() const {
          return iterator(this->m_qr, 0);
        }
        iterator end() const {
          return iterator(this->m_qr, this->m_qr->m_parent.size());
        }
      };

      query_result() = default;

      // Number of vertices of the auxiliary tree.
      ::std::size_t size() const {
        return this->m_parent.size();
      }

      // The vertices of the auxiliary tree. The returned range refers to *this.
      vertices_iterable vertices() const {
        return vertices_iterable(this);
      }

      // Root of the auxiliary tree.
      ::std::size_t root() const {
        return this->m_root;
      }

      // Parent of v in the auxiliary tree; size_t max for the root.
      // v must be a vertex of the auxiliary tree.
      ::std::size_t parent(const ::std::size_t v) const {
        const auto it = ::std::lower_bound(this->m_parent.begin(), this->m_parent.end(), ::std::make_pair(v, ::std::numeric_limits<::std::size_t>::max()), ::tools::less_by_first{});
        assert(it != this->m_parent.end());
        assert(it->first == v);
        return it->second;
      }

      // Children of v in the auxiliary tree. v must be a vertex of the auxiliary tree.
      const ::std::vector<::std::size_t>& children(const ::std::size_t v) const {
        const auto it = ::std::lower_bound(this->m_parent.begin(), this->m_parent.end(), ::std::make_pair(v, ::std::numeric_limits<::std::size_t>::max()), ::tools::less_by_first{});
        assert(it != this->m_parent.end());
        assert(it->first == v);
        return this->m_children[::std::distance(this->m_parent.begin(), it)];
      }

      friend ::tools::auxiliary_tree;
    };

    // Auxiliary tree of the vertices in [begin, end); the range must not be empty.
    template <typename InputIterator>
    query_result query(const InputIterator begin, const InputIterator end) const {
      return query_result(*this, begin, end);
    }

    // Convenience overload taking the vertex set as a vector of integers.
    template <typename Z, ::std::enable_if_t<::std::is_integral_v<Z>, ::std::nullptr_t> = nullptr>
    query_result query(const ::std::vector<Z>& X) const {
      return this->query(X.begin(), X.end());
    }
  };
}


#line 1 "tools/rerooting_dp.hpp"



#line 11 "tools/rerooting_dp.hpp"

namespace tools {
  // Rerooting DP on a tree: for every vertex, computes the DP value obtained
  // when that vertex is taken as the root.
  //
  // Template parameters:
  //   R    - result type attached to a vertex.
  //   M    - monoid providing `T`, `op(T, T)` and `e()`; it folds the values
  //          arriving at a vertex over its incident edges.
  //   F_VE - callable (R, edge id) -> M::T, carries a vertex result across an edge.
  //   F_EV - callable (M::T, vertex id) -> R, turns the folded value into a vertex result.
  template <typename R, typename M, typename F_VE, typename F_EV>
  class rerooting_dp {
  private:
    // m_edges[e] stores u ^ v, so the other endpoint is recovered by XOR-ing
    // with the endpoint already known.
    ::std::vector<::std::size_t> m_edges;
    // m_graph[v]: ids of the edges incident to v.
    ::std::vector<::std::vector<::std::size_t>> m_graph;
    F_VE m_f_ve;
    F_EV m_f_ev;

    // Per-vertex working state of one query() run, relative to the tree rooted at vertex 0.
    class vertex {
    private:
      const ::tools::rerooting_dp<R, M, F_VE, F_EV> *m_self;

    public:
      ::std::size_t id;
      // Position of the parent edge inside m_graph[id]; never set for the root.
      ::std::size_t neighbor_id_of_parent;
      // Positions of the child edges inside m_graph[id].
      ::std::vector<::std::size_t> neighbor_ids_of_children;
      // Value arriving from the parent side; M::e() until the top-down pass sets it.
      typename M::T parent_dp;
      // children_dp[i]: value arriving from the i-th child.
      ::std::vector<typename M::T> children_dp;
      // children_dp_cumsum1[i]: fold of children_dp[0, i).
      ::std::vector<typename M::T> children_dp_cumsum1;
      // children_dp_cumsum2[i]: fold of children_dp[i, child_size()).
      ::std::vector<typename M::T> children_dp_cumsum2;

      vertex() = default;
      vertex(const vertex&) = default;
      vertex(vertex&&) = default;
      ~vertex() = default;
      vertex& operator=(const vertex&) = default;
      vertex& operator=(vertex&&) = default;

      explicit vertex(const ::tools::rerooting_dp<R, M, F_VE, F_EV> * const self, const ::std::size_t id) :
        m_self(self), id(id), parent_dp(M::e()) {
      }

      ::std::size_t parent_edge_id() const {
        return this->m_self->m_graph[this->id][this->neighbor_id_of_parent];
      }
      ::std::size_t parent_vertex_id() const {
        return this->m_self->m_edges[this->parent_edge_id()] ^ this->id;
      }
      ::std::size_t child_size() const {
        return this->neighbor_ids_of_children.size();
      }
      ::std::size_t child_edge_id(const ::std::size_t child_number) const {
        return this->m_self->m_graph[this->id][this->neighbor_ids_of_children[child_number]];
      }
      ::std::size_t child_vertex_id(const ::std::size_t child_number) const {
        return this->m_self->m_edges[this->child_edge_id(child_number)] ^ this->id;
      }
      // Result with this vertex as the root: the parent side and all children are folded.
      R dp_as_root() const {
        return this->m_self->m_f_ev(M::op(this->parent_dp, this->children_dp_cumsum1.back()), this->id);
      }
      // Result of the subtree below this vertex: only the children are folded.
      R dp_excluding_parent() const {
        return this->m_self->m_f_ev(this->children_dp_cumsum1.back(), this->id);
      }
      // Result seen from one child: everything except that child is folded.
      R dp_excluding_child(const ::std::size_t excluded_child_number) const {
        return this->m_self->m_f_ev(M::op(this->parent_dp, M::op(this->children_dp_cumsum1[excluded_child_number], this->children_dp_cumsum2[excluded_child_number + 1])), this->id);
      }
    };

  public:
    rerooting_dp() = default;
    rerooting_dp(const ::tools::rerooting_dp<R, M, F_VE, F_EV>&) = default;
    rerooting_dp(::tools::rerooting_dp<R, M, F_VE, F_EV>&&) = default;
    ~rerooting_dp() = default;
    ::tools::rerooting_dp<R, M, F_VE, F_EV>& operator=(const ::tools::rerooting_dp<R, M, F_VE, F_EV>&) = default;
    ::tools::rerooting_dp<R, M, F_VE, F_EV>& operator=(::tools::rerooting_dp<R, M, F_VE, F_EV>&&) = default;

    // Creates a tree with n isolated vertices (n >= 1) and stores copies of the two callables.
    rerooting_dp(const ::std::size_t n, const F_VE& f_ve, const F_EV& f_ev) : m_graph(n), m_f_ve(f_ve), m_f_ev(f_ev) {
      assert(n >= 1);
    }

    // Number of vertices.
    ::std::size_t size() const {
      return this->m_graph.size();
    }

    // Adds the undirected edge {u, v} and returns its id; ids are 0, 1, ... in insertion order.
    ::std::size_t add_edge(const ::std::size_t u, const ::std::size_t v) {
      // Out-of-range endpoints would index past m_graph. A self-loop would be
      // stored as u ^ v == 0 and make every endpoint lookup return the vertex itself.
      assert(u < this->size());
      assert(v < this->size());
      assert(u != v);
      this->m_graph[u].push_back(this->m_edges.size());
      this->m_graph[v].push_back(this->m_edges.size());
      this->m_edges.push_back(u ^ v);
      return this->m_edges.size() - 1;
    }

    // Returns, for every vertex v, the result with v taken as the root.
    // Exactly size() - 1 edges must have been added.
    ::std::vector<R> query() const {
      assert(this->m_edges.size() + 1 == this->size());

      const int PRE_VERTEX = 1;
      const int POST_EDGE = 2;
      const int POST_VERTEX = 3;
      const ::std::size_t INVALID = ::std::numeric_limits<::std::size_t>::max();

      ::std::vector<vertex> vertices;
      vertices.reserve(this->size());
      for (::std::size_t i = 0; i < this->size(); ++i) {
        vertices.emplace_back(this, i);
      }

      // Pass 1: iterative DFS from vertex 0 computing the bottom-up values.
      // Stack entries are (event type, vertex id, child number).
      ::std::stack<::std::tuple<int, ::std::size_t, ::std::size_t>> stack;
      ::std::vector<bool> will_visit(this->size(), false);
      stack.emplace(PRE_VERTEX, 0, INVALID);
      will_visit[0] = true;
      while (!stack.empty()) {
        const int type = ::std::get<0>(stack.top());
        if (type == PRE_VERTEX) {

          const ::std::size_t vertex_id = ::std::get<1>(stack.top());
          stack.pop();

          vertex& v = vertices[vertex_id];
          // Pushed first, so it is handled after all children and their edges.
          stack.emplace(POST_VERTEX, vertex_id, INVALID);
          for (::std::size_t neighbor_id = 0; neighbor_id < this->m_graph[vertex_id].size(); ++neighbor_id) {
            const ::std::size_t child_vertex_id = this->m_edges[this->m_graph[vertex_id][neighbor_id]] ^ vertex_id;
            if (will_visit[child_vertex_id]) {
              // The only neighbor already scheduled is the parent.
              v.neighbor_id_of_parent = neighbor_id;
            } else {
              v.neighbor_ids_of_children.push_back(neighbor_id);
              // POST_EDGE lies below the child's PRE_VERTEX, so it runs once the
              // child's subtree is complete.
              stack.emplace(POST_EDGE, vertex_id, v.child_size() - 1);
              stack.emplace(PRE_VERTEX, child_vertex_id, INVALID);
              will_visit[child_vertex_id] = true;
            }
          }
          v.children_dp.resize(v.child_size());

        } else if (type == POST_EDGE) {

          const ::std::size_t vertex_id = ::std::get<1>(stack.top());
          const ::std::size_t child_number = ::std::get<2>(stack.top());
          stack.pop();

          // Carry the finished child's subtree result across the connecting edge.
          vertex& v = vertices[vertex_id];
          const vertex& c = vertices[v.child_vertex_id(child_number)];
          v.children_dp[child_number] = this->m_f_ve(c.dp_excluding_parent(), v.child_edge_id(child_number));

        } else { // POST_VERTEX

          const ::std::size_t vertex_id = ::std::get<1>(stack.top());
          stack.pop();

          vertex& v = vertices[vertex_id];

          // Prefix folds of the children values.
          v.children_dp_cumsum1.reserve(v.child_size() + 1);
          v.children_dp_cumsum1.push_back(M::e());
          for (::std::size_t child_number = 0; child_number < v.child_size(); ++child_number) {
            v.children_dp_cumsum1.push_back(M::op(v.children_dp_cumsum1.back(), v.children_dp[child_number]));
          }

          // Suffix folds, built back to front and then reversed into index order.
          v.children_dp_cumsum2.reserve(v.child_size() + 1);
          v.children_dp_cumsum2.push_back(M::e());
          for (::std::size_t child_number = v.child_size(); child_number --> 0;) {
            v.children_dp_cumsum2.push_back(M::op(v.children_dp[child_number], v.children_dp_cumsum2.back()));
          }
          ::std::reverse(v.children_dp_cumsum2.begin(), v.children_dp_cumsum2.end());

        }
      }

      // Pass 2: top-down, hand every child the value of "everything except its own subtree".
      stack.emplace(PRE_VERTEX, 0, INVALID);
      while (!stack.empty()) {
        const ::std::size_t vertex_id = ::std::get<1>(stack.top());
        stack.pop();

        const vertex& v = vertices[vertex_id];
        for (::std::size_t child_number = 0; child_number < v.child_size(); ++child_number) {
          vertex& c = vertices[v.child_vertex_id(child_number)];
          c.parent_dp = this->m_f_ve(v.dp_excluding_child(child_number), c.parent_edge_id());
          stack.emplace(PRE_VERTEX, c.id, INVALID);
        }
      }

      ::std::vector<R> result;
      result.reserve(this->size());
      for (const vertex& v : vertices) {
        result.push_back(v.dp_as_root());
      }
      return result;
    }
  };
}


#line 1 "tools/monoid.hpp"



#line 1 "tools/gcd.hpp"



#line 6 "tools/gcd.hpp"

namespace tools {
  // Greatest common divisor of m and n; forwards to ::std::gcd and returns
  // the result in the common type of M and N.
  template <typename M, typename N>
  constexpr auto gcd(const M m, const N n) -> ::std::common_type_t<M, N> {
    return ::std::gcd(m, n);
  }
}


#line 9 "tools/monoid.hpp"

namespace tools {
  // Monoid policies: each struct exposes the element type `T`, the binary
  // operation `op` and the identity element `e()`.
  namespace monoid {
    template <typename M, M ...dummy>
    struct max;

    // Maximum over a built-in arithmetic type. The identity is the lowest
    // integer for integral M and negative infinity otherwise.
    template <typename M>
    struct max<M> {
      static_assert(::std::is_arithmetic_v<M>, "M must be a built-in arithmetic type.");

      using T = M;
      static T op(const T a, const T b) {
        return a < b ? b : a;
      }
      static T e() {
        if constexpr (!::std::is_integral_v<M>) {
          return -::std::numeric_limits<M>::infinity();
        } else {
          return ::std::numeric_limits<M>::min();
        }
      }
    };

    // Maximum over an integral type with a caller-chosen identity E;
    // operands are asserted to be at least E.
    template <typename M, M E>
    struct max<M, E> {
      static_assert(::std::is_integral_v<M>, "M must be a built-in integral type.");

      using T = M;
      static T op(const T a, const T b) {
        assert(E <= a);
        assert(E <= b);
        return a < b ? b : a;
      }
      static T e() {
        return E;
      }
    };

    template <typename M, M ...dummy>
    struct min;

    // Minimum over a built-in arithmetic type. The identity is the largest
    // integer for integral M and positive infinity otherwise.
    template <typename M>
    struct min<M> {
      static_assert(::std::is_arithmetic_v<M>, "M must be a built-in arithmetic type.");

      using T = M;
      static T op(const T a, const T b) {
        return b < a ? b : a;
      }
      static T e() {
        if constexpr (!::std::is_integral_v<M>) {
          return ::std::numeric_limits<M>::infinity();
        } else {
          return ::std::numeric_limits<M>::max();
        }
      }
    };

    // Minimum over an integral type with a caller-chosen identity E;
    // operands are asserted to be at most E.
    template <typename M, M E>
    struct min<M, E> {
      static_assert(::std::is_integral_v<M>, "M must be a built-in integral type.");

      using T = M;
      static T op(const T a, const T b) {
        assert(a <= E);
        assert(b <= E);
        return b < a ? b : a;
      }
      static T e() {
        return E;
      }
    };

    // Product; the identity is T(1). Arithmetic operands are taken by value,
    // other types by const reference.
    template <typename M>
    struct multiplies {
    private:
      using VR = ::std::conditional_t<::std::is_arithmetic_v<M>, const M, const M&>;

    public:
      using T = M;
      static T op(VR a, VR b) {
        return a * b;
      }
      static T e() {
        return T(1);
      }
    };

    // For bool the product is logical AND, with identity true.
    template <>
    struct multiplies<bool> {
      using T = bool;
      static T op(const bool a, const bool b) {
        return a && b;
      }
      static T e() {
        return true;
      }
    };

    // Greatest common divisor via ::tools::gcd; the identity is T(0).
    template <typename M>
    struct gcd {
    private:
      static_assert(!::std::is_arithmetic_v<M> || (::std::is_integral_v<M> && !::std::is_same_v<M, bool>), "If M is a built-in arithmetic type, it must be integral except for bool.");
      using VR = ::std::conditional_t<::std::is_arithmetic_v<M>, const M, const M&>;

    public:
      using T = M;
      static T op(VR a, VR b) {
        return ::tools::gcd(a, b);
      }
      static T e() {
        return T(0);
      }
    };

    // Keeps the left operand unless it equals the identity E, in which case
    // the right operand is returned.
    template <typename M, M E>
    struct update {
      static_assert(::std::is_integral_v<M>, "M must be a built-in integral type.");

      using T = M;
      static T op(const T a, const T b) {
        if (a == E) {
          return b;
        }
        return a;
      }
      static T e() {
        return E;
      }
    };
  }
}


#line 9 "tests/auxiliary_tree.test.cpp"

int main() {
  std::cin.tie(nullptr);
  std::ios_base::sync_with_stdio(false);

  int N;
  std::cin >> N;
  std::vector<int> c(N);
  for (auto&& c_i : c) {
    std::cin >> c_i;
    --c_i;
  }
  tools::auxiliary_tree tree(N);
  for (int i = 0; i < N - 1; ++i) {
    int s, t;
    std::cin >> s >> t;
    --s, --t;
    tree.add_edge(s, t);
  }
  tree.build(0);

  std::vector<std::vector<int>> huts(N);
  for (int v = 0; v < N; ++v) {
    huts[c[v]].push_back(v);
  }

  std::vector<int> answers(N);
  std::vector<int> tree2aux(N);
  for (int color = 0; color < N; ++color) {
    if (huts[color].empty()) continue;

    const auto aux = tree.query(huts[color].begin(), huts[color].end());
    std::vector<int> aux2tree(aux.size());
    {
      int aux_v = 0;
      for (const auto tree_v : aux.vertices()) {
        tree2aux[tree_v] = aux_v;
        aux2tree[aux_v] = tree_v;
        ++aux_v;
      }
    }

    std::vector<int> w;
    const auto f_ve = [&](const auto& v, const auto e) {
      return (v.second ? 0 : v.first) + w[e];
    };
    const auto f_ev = [&](const auto e, const auto v) {
      return std::make_pair(e, c[aux2tree[v]] == color);
    };
    tools::rerooting_dp<std::pair<int, bool>, tools::monoid::min<int>, decltype(f_ve), decltype(f_ev)> dp(aux.size(), f_ve, f_ev);
    for (const auto tree_v : aux.vertices()) {
      if (tree_v == aux.root()) continue;
      dp.add_edge(tree2aux[tree_v], tree2aux[aux.parent(tree_v)]);
      const auto lca = tree.lca(tree_v, aux.parent(tree_v));
      w.push_back(tree.depth(tree_v) + tree.depth(aux.parent(tree_v)) - 2 * tree.depth(lca));
    }
    const auto partial_answers = dp.query();

    for (int aux_v = 0; std::cmp_less(aux_v, aux.size()); ++aux_v) {
      if (partial_answers[aux_v].second) {
        answers[aux2tree[aux_v]] = partial_answers[aux_v].first;
      }
    }
  }

  for (const auto answer : answers) {
    std::cout << answer << '\n';
  }
  return 0;
}

Test cases

Env Name Status Elapsed Memory
g++ 00_sample_00.in :heavy_check_mark: AC 5 ms 4 MB
g++ 01_small_00.in :heavy_check_mark: AC 5 ms 4 MB
g++ 01_small_01.in :heavy_check_mark: AC 4 ms 4 MB
g++ 02_corner_00.in :heavy_check_mark: AC 4 ms 4 MB
g++ 02_corner_01.in :heavy_check_mark: AC 4 ms 4 MB
g++ 02_corner_02.in :heavy_check_mark: AC 4 ms 4 MB
g++ 04_rand_00.in :heavy_check_mark: AC 4 ms 4 MB
g++ 04_rand_01.in :heavy_check_mark: AC 4 ms 4 MB
g++ 04_rand_02.in :heavy_check_mark: AC 4 ms 4 MB
g++ 04_rand_03.in :heavy_check_mark: AC 5 ms 4 MB
g++ 04_rand_04.in :heavy_check_mark: AC 4 ms 4 MB
g++ 04_rand_05.in :heavy_check_mark: AC 4 ms 4 MB
g++ 04_rand_06.in :heavy_check_mark: AC 5 ms 4 MB
g++ 04_rand_07.in :heavy_check_mark: AC 5 ms 4 MB
g++ 05_large_00.in :heavy_check_mark: AC 6 ms 4 MB
g++ 05_large_01.in :heavy_check_mark: AC 6 ms 4 MB
g++ 05_large_02.in :heavy_check_mark: AC 6 ms 4 MB
g++ 05_large_03.in :heavy_check_mark: AC 5 ms 4 MB
g++ 06_huge_00.in :heavy_check_mark: AC 17 ms 5 MB
g++ 06_huge_01.in :heavy_check_mark: AC 14 ms 5 MB
g++ 07_maximum_00.in :heavy_check_mark: AC 168 ms 16 MB
g++ 07_maximum_01.in :heavy_check_mark: AC 120 ms 17 MB
g++ 08_extreme_00.in :heavy_check_mark: AC 388 ms 30 MB
g++ 08_extreme_01.in :heavy_check_mark: AC 306 ms 29 MB
g++ 08_extreme_02.in :heavy_check_mark: AC 236 ms 30 MB
g++ 10_long_01.in :heavy_check_mark: AC 239 ms 29 MB
g++ 10_long_02.in :heavy_check_mark: AC 237 ms 29 MB
g++ 10_long_03.in :heavy_check_mark: AC 230 ms 29 MB
g++ 11_long_01.in :heavy_check_mark: AC 212 ms 28 MB
g++ 12_long_01.in :heavy_check_mark: AC 216 ms 29 MB
g++ 20_star_01.in :heavy_check_mark: AC 187 ms 29 MB
g++ 20_star_02.in :heavy_check_mark: AC 200 ms 29 MB
g++ 20_star_03.in :heavy_check_mark: AC 205 ms 28 MB
g++ 20_star_04.in :heavy_check_mark: AC 194 ms 27 MB
g++ 20_star_05.in :heavy_check_mark: AC 220 ms 28 MB
g++ 21_star_01.in :heavy_check_mark: AC 203 ms 29 MB
g++ 21_star_02.in :heavy_check_mark: AC 244 ms 28 MB
g++ 21_star_03.in :heavy_check_mark: AC 296 ms 29 MB
g++ 21_star_04.in :heavy_check_mark: AC 353 ms 28 MB
g++ 21_star_05.in :heavy_check_mark: AC 359 ms 29 MB
Back to top page