CS Academy Random Nim Generator

リンク : https://csacademy.com/contest/archive/task/random_nim_generator/statement/

問題概要

N個の0より大きくK以下の数字の入った配列を作ります。作った配列のすべてをxorすると0より大きくなるものは何通りあるか求めよ。

解説

まず、普通にdpをしてみようと考える。dp[i][j]をi番目の配列までのxorがjの通り数とする。すると
dp[i+1][j]=sum{{k=0,1,....K}{l=0,1,...K}if(kl==j)dp[i][k]}となる。しかしこれではO(NK3)となり到底間に合わない。そこでアダマール変換をする。アダマール変換とは、簡単に説明すると、C[i] = sum{jk=i} A[j]B[k] をO(NlogN)でやるものである。これを用いると、O(NKlogK)となる。これでもまだ足りない。ここで、アダマール変換は毎回同じことをしているかつ結合法則が成り立つため、繰り返し二乗法のようにアダマール変換をする。するとO(logNK*logK)となり間に合う。

namespace FWT {
    LL mod = MOD;
    void init(LL _mod) {
        mod = _mod;
    }
    LL modpow(LL a, LL n = mod - 2) {
        LL r = 1;
        while (n) r = r * ((n % 2) ? a : 1) % mod, a = a * a%mod, n >>= 1;
        return r;
    }
    void FWT(vector<LL>&a) {
        int n = a.size();
        for (int d = 1; d < n; d <<= 1)
            for (int m = d << 1, i = 0; i < n; i += m)
                for (int j = 0; j < d; j++)
                {
                    LL x = a[i + j], y = a[i + j + d];
                    a[i + j] = (x + y) % mod, a[i + j + d] = (x - y + mod) % mod;
                }
    }
    void UFWT(vector<LL>&a) {
        int n = a.size();
        LL rev = modpow(2);
        for (int d = 1; d < n; d <<= 1)
            for (int m = d << 1, i = 0; i < n; i += m)
                for (int j = 0; j < d; j++)
                {
                    LL x = a[i + j], y = a[i + j + d];
                    a[i + j] = 1LL * (x + y)*rev%mod, a[i + j + d] = (1LL * (x - y + mod)%mod*rev) % mod;
                }
    }
    vector<LL> solve(vector<LL>&a, vector<LL>b){
        int n = a.size();
        FWT(a);
        FWT(b);
        for (int i = 0; i<n; i++) a[i] = 1LL * a[i] * b[i] % mod;
        UFWT(a);
        return a;
    }
};
LL n, k,ans;
vector<LL>dp, ret;
int main() {
    cin >> n >> k;
    FWT::init(30011);
    dp.resize(1 << 16);
    ret.resize(1 << 16);
    rep(i, k + 1)dp[i] = 1;
    ret[0] = 1;
    while (n > 0) {
        if (n & 1) {
            FWT::solve(ret, dp);
        }
        FWT::solve(dp, dp);
        n >>= 1;
    }
    rep(i, 1 << 16)
        if(i)
        ans += ret[i];
    cout << ans%30011 << endl;
    return 0;
}