Time Complexity: $\mathcal O(N \log N)$

We're asked to evaluate a polynomial at $N$ points, which suggests that our solution will involve the Fast Fourier Transform. Indeed, our solution is mostly based on the idea behind the Discrete Fourier Transform: that $P(x) = P_\text{even}(x^2) + x \cdot P_\text{odd}(x^2)$ for any polynomial $P$, where $P_\text{even}$ and $P_\text{odd}$ are the polynomials formed from the even-indexed and odd-indexed coefficients of $P$, respectively.
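As a quick sanity check of this identity, the sketch below splits a coefficient vector into its even- and odd-indexed halves and evaluates both sides on small integers. The helper names `eval` and `split_eval` are illustrative only and are not part of the original solution.

```cpp
#include <cassert>
#include <vector>
using namespace std;
typedef long long ll;

// Horner evaluation of c[0] + c[1] * x + c[2] * x^2 + ...
ll eval(const vector<ll> &c, ll x) {
	ll res = 0;
	for (int i = (int)c.size() - 1; i >= 0; i--) res = res * x + c[i];
	return res;
}

// Evaluates the same polynomial as P_even(x^2) + x * P_odd(x^2).
ll split_eval(const vector<ll> &c, ll x) {
	vector<ll> even, odd;
	for (size_t i = 0; i < c.size(); i++) {
		if (i % 2 == 0) even.push_back(c[i]);
		else odd.push_back(c[i]);
	}
	return eval(even, x * x) + x * eval(odd, x * x);
}
```

For example, with coefficients `{1, 2, 3, 4}` (that is, $1 + 2x + 3x^2 + 4x^3$) and $x = 2$, both `eval` and `split_eval` give $49$.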

Let $\texttt{solve}(k, P) = (P(q^{ki}) \bmod M)_{i = 0}^{N / k - 1}$, a sequence of $\frac{N}{k}$ values. Our answer will be $\texttt{solve}(1, W)$, and we will compute $\texttt{solve}(k, P)$ recursively in $\mathcal O\left(\frac{N}{k} \log \frac{N}{k}\right)$ time.
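To pin down what $\texttt{solve}(k, P)$ returns, here is a direct reference sketch that evaluates $P$ at each of the $\frac{N}{k}$ points with Horner's rule. It is deliberately the slow approach (one full evaluation per point), useful only for checking a faster implementation, and it is not the recursive method described here. The names `power_mod` and `solve_naive` are hypothetical, and the code assumes non-negative coefficients and an $M$ small enough that $M^2$ fits in a `long long`.

```cpp
#include <cassert>
#include <vector>
using namespace std;
typedef long long ll;

// Computes b^e mod m by repeated squaring (assumes m * m fits in long long).
ll power_mod(ll b, ll e, ll m) {
	ll res = 1 % m;
	b %= m;
	while (e > 0) {
		if (e & 1) res = res * b % m;
		b = b * b % m;
		e >>= 1;
	}
	return res;
}

// Returns (P(q^{k * i}) mod m) for i = 0, ..., n / k - 1.
// p holds the coefficients of P, lowest degree first.
vector<ll> solve_naive(int n, int k, const vector<ll> &p, ll q, ll m) {
	vector<ll> v(n / k);
	for (int i = 0; i < n / k; i++) {
		ll x = power_mod(q, (ll)k * i, m), val = 0;
		for (int j = (int)p.size() - 1; j >= 0; j--)
			val = (val * x + p[j] % m) % m;
		v[i] = val;
	}
	return v;
}
```

With the illustrative values $N = 4$, $q = 2$, $M = 5$, and coefficients `{1, 2, 3, 4}`, `solve_naive(4, 1, {1, 2, 3, 4}, 2, 5)` returns `{0, 4, 3, 2}`.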

First, precompute all $q^i \bmod M$ for $0 \leq i < N$. If $k = N$, then $P$ is a constant polynomial and can be evaluated in constant time. Otherwise, let $v = \texttt{solve}(k, P)$, $u_\text{even} = \texttt{solve}(2k, P_\text{even})$, and $u_\text{odd} = \texttt{solve}(2k, P_\text{odd})$. We have two cases when computing each $v[i]$:

  • Case 1: $i < \frac{N}{2k}$:
    $$\begin{aligned} v[i] &= P(q^{ki}) \bmod M\\ &= (P_\text{even}(q^{2ki}) + q^{ki} \cdot P_\text{odd}(q^{2ki})) \bmod M\\ &= (u_\text{even}[i] + q^{ki} \cdot u_\text{odd}[i]) \bmod M \end{aligned}$$
  • Case 2: $i \geq \frac{N}{2k}$. In this case, let $i = \frac{N}{2k} + j$, so that $ki = \frac{N}{2} + kj$. Using $q^N \equiv 1 \pmod M$ to drop the $q^N$ factor:
    $$\begin{aligned} v[i] &= P(q^{\frac{N}{2} + kj}) \bmod M\\ &= (P_\text{even}(q^N \cdot q^{2kj}) + q^{ki} \cdot P_\text{odd}(q^N \cdot q^{2kj})) \bmod M\\ &= (P_\text{even}(q^{2kj}) + q^{ki} \cdot P_\text{odd}(q^{2kj})) \bmod M\\ &= (u_\text{even}[j] + q^{ki} \cdot u_\text{odd}[j]) \bmod M \end{aligned}$$

Putting this all together, we get:

$$v[i] = \left(u_\text{even}\left[i \bmod \frac{N}{2k}\right] + q^{ki} \cdot u_\text{odd}\left[i \bmod \frac{N}{2k}\right]\right) \bmod M$$
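
As a small worked instance of this recurrence, take $N = 4$, $M = 5$, $q = 2$, and $W(x) = 1 + 2x + 3x^2 + 4x^3$. These values are chosen purely for illustration; note that $2^4 = 16 \equiv 1 \pmod 5$, and the powers $q^0, q^1, q^2, q^3$ are $1, 2, 4, 3$ modulo $5$. Here $W_\text{even}(y) = 1 + 3y$ and $W_\text{odd}(y) = 2 + 4y$, so with $k = 1$:

$$\begin{aligned} u_\text{even} &= (W_\text{even}(q^0), W_\text{even}(q^2)) \bmod 5 = (4, 13) \bmod 5 = (4, 3)\\ u_\text{odd} &= (W_\text{odd}(q^0), W_\text{odd}(q^2)) \bmod 5 = (6, 18) \bmod 5 = (1, 3)\\ v[0] &= (u_\text{even}[0] + q^0 \cdot u_\text{odd}[0]) \bmod 5 = (4 + 1 \cdot 1) \bmod 5 = 0\\ v[1] &= (u_\text{even}[1] + q^1 \cdot u_\text{odd}[1]) \bmod 5 = (3 + 2 \cdot 3) \bmod 5 = 4\\ v[2] &= (u_\text{even}[0] + q^2 \cdot u_\text{odd}[0]) \bmod 5 = (4 + 4 \cdot 1) \bmod 5 = 3\\ v[3] &= (u_\text{even}[1] + q^3 \cdot u_\text{odd}[1]) \bmod 5 = (3 + 3 \cdot 3) \bmod 5 = 2 \end{aligned}$$

Evaluating directly gives $W(1) = 10$, $W(2) = 49$, $W(4) = 313$, and $W(3) = 142$, which reduce to $0, 4, 3, 2$ modulo $5$, matching $v$.
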
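#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
int n;
ll m, q, a[1 << 20], q_pow[1 << 20];
// Evaluates the polynomial with coefficients a[idx], a[idx + k], ... at q^(k * i) for 0 <= i < n / k
vector<ll> dft(int k = 1, int idx = 0) {
	if (k == n) return {a[idx]};
	else {
		vector<ll> even = dft(2 * k, idx), odd = dft(2 * k, idx + k);
		vector<ll> res(n / k);
		for (int i = 0; i < n / k; i++)  // q_pow[j] is assumed to hold q^j mod m
			res[i] = (even[i % (n / (2 * k))] + q_pow[k * i] * odd[i % (n / (2 * k))]) % m;
		return res;
	}
}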
