[AtCoder] ABC156 E – Roaming (500点)

2020年2月24日AtCoder剰余,二項係数,操作,数え上げ,500点,重複組合せ

問題概要

n 個の部屋に 1 ずつ人がいる。以下を k 回行ったとき、考えられる状態は何通りあるか \(10^9+7\) で割った余りで答えよ。

  • ある部屋 \(i\) にいた人が、\(i \neq j\) を満たす任意の部屋 \(j\) に移動する

問題へのリンク

制約

  • \(3 \leq n \leq 2 \times 10^5\)
  • \(2 \leq k\leq 10^9\)

考え方

何から考えたら良いか分からなくなりますが、「操作を k 回行う」ような問題では、操作で変化する量や変化しない量に注目すると見通しが立つことがあります。

今回では

  • 移動をしても全体の人数は変わらない
  • 移動をすると生じる可能性のある 0 人の部屋の数が増える

などが言えます。

前者は当たり前ですが、後者はどういうことかを説明しましょう。

例えば、はじめ3部屋あったとしたら、

  • 0回移動: (1,1,1) で0人の部屋は最大でも0個
  • 1回移動: (0,1,2) (0,2,1) (1,0,2) (1,2,0) (2,0,1) (2,1,0) で0人の部屋は最大でも1個
  • 2回移動: (1,1,1) (0,1,2) (0,2,1) (1,0,2) (1,2,0) (2,0,1) (2,1,0) (0,0,3) (3,0,0) で0人の部屋は最大でも2個

というようになります。これの太字部分を見ると分かりますが、「i 回移動した時、i-1 回移動した時に比べて、空き部屋が i 個の場合が増える」ことに気がつけばほぼ正解です。
( (1,1,1) だけは例外で、1回移動の時に含まれていません。そこに気がつくと、制約の \(2\leq k\) が親切に見えてきます。)

また、空き部屋は最大でも n-1 部屋までしか作れないので、k が n 以上になっているときはどれだけ移動しても k=n-1の時の値と変わらないです。

よって、以下の計算をすれば良いです。

  • \(\sum_{i=0}^{\min(n-1,k)}\) 空き部屋が \(i\) 個 ある場合の組み合わせ

解法

  • \(\sum_{i=0}^{\min(n-1,k)}\) 空き部屋が \(i\) 個 ある場合の組み合わせ

が答えになります。空き部屋が \(i\) 個 ある場合が何通りあるかは以下のように計算ができます。

$$ _nC_i \times {}_{n-i}H_i$$

意味としては、

  • \(_nC_i\) : n 個の中から 0人の部屋を i 個選ぶ
  • \( {}_{n-i}H_i\) : 残りの \(n-i\) 個の中から重複ありで \(i\) 人の移動先を決める

ということになります。気持ちとしては、剰余を考えないなら以下のような計算をすれば良いです。

long long ans = 1;
for(int i=1; i < k + 1; i++) {
    ans += nCk(n,i) * nHk(n - i, i);
}

重複組合せ \( {}_{n-i}H_i\) は二項係数を用いて以下のように計算することができます。

$$ {}_{x}H_y = {}_{x+y-1}C_y$$

以上より、二項係数を適切な前処理によって、\(O(1)\) で求めることができるならば、十分高速に計算可能です。

C++での実装例

以下のように、剰余の計算をするときは modint 構造体などを利用すると整数のようにプログラムができます。

#include <bits/stdc++.h>
#define LOOP(n) for (int _i = 0; _i < (n); _i++)
#define REP(i, n) for (int i = 0; i < (n); ++i)
#define RREP(i, n) for (int i = (n); i >= 0; --i)
#define FOR(i, r, n) for (int i = (r); i < (n); ++i)
#define ALL(obj) begin(obj), end(obj)
using namespace std;

/* ModInt
    modが絡む計算を整数の計算のようにするための構造体。
*/
template <int mod>
struct ModInt {
    int val;
    ModInt() : val(0) {}
    ModInt(long long x) : val(x >= 0 ? x % mod : (mod - (-x) % mod) % mod) {}
    int getmod() { return mod; }
    ModInt &operator+=(const ModInt &p) {
        if ((val += p.val) >= mod) {
            val -= mod;
        }
        return *this;
    }
    ModInt &operator-=(const ModInt &p) {
        if ((val += mod - p.val) >= mod) {
            val -= mod;
        }
        return *this;
    }
    ModInt &operator*=(const ModInt &p) {
        val = (int)(1LL * val * p.val % mod);
        return *this;
    }
    ModInt &operator/=(const ModInt &p) {
        *this *= p.inverse();
        return *this;
    }
    ModInt operator-() const { return ModInt(-val); }
    ModInt operator+(const ModInt &p) const { return ModInt(*this) += p; }
    ModInt operator-(const ModInt &p) const { return ModInt(*this) -= p; }
    ModInt operator*(const ModInt &p) const { return ModInt(*this) *= p; }
    ModInt operator/(const ModInt &p) const { return ModInt(*this) /= p; }
    bool operator==(const ModInt &p) const { return val == p.val; }
    bool operator!=(const ModInt &p) const { return val != p.val; }
    ModInt inverse() const {
        int a = val, b = mod, u = 1, v = 0, t;
        while (b > 0) {
            t = a / b;
            swap(a -= t * b, b);
            swap(u -= t * v, v);
        }
        return ModInt(u);
    }
    ModInt pow(long long n) const {
        ModInt ret(1), mul(val);
        while (n > 0) {
            if (n & 1) ret *= mul;
            mul *= mul;
            n >>= 1;
        }
        return ret;
    }
    friend ostream &operator<<(ostream &os, const ModInt &p) { return os << p.val; }
    friend istream &operator>>(istream &is, ModInt &a) {
        long long t;
        is >> t;
        a = ModInt<mod>(t);
        return (is);
    }
    static int get_mod() { return mod; }
};

/* Comb:modintで二項係数を計算する構造体
    前処理:O(n)
    二項係数の計算:O(1)
*/
template <class T>
struct Comb {
    vector<T> fact_, fact_inv_, inv_;
    Comb() {}
    Comb(int SIZE) : fact_(SIZE, 1), fact_inv_(SIZE, 1), inv_(SIZE, 1) { init(SIZE); }
    void init(int SIZE) {
        fact_.assign(SIZE, 1), fact_inv_.assign(SIZE, 1), inv_.assign(SIZE, 1);
        int MOD = fact_[0].getmod();
        for (int i = 2; i < SIZE; i++) {
            fact_[i] = fact_[i - 1] * i;
            inv_[i] = -inv_[MOD % i] * (MOD / i);
            fact_inv_[i] = fact_inv_[i - 1] * inv_[i];
        }
    }
    T nCk(int n, int k) {
        assert(!(n < k));
        assert(!(n < 0 || k < 0));
        return fact_[n] * fact_inv_[k] * fact_inv_[n - k];
    }
    T nHk(int n, int k) {
        assert(!(n < 0 || k < 0));
        return nCk(n + k - 1, k);
    }
    T fact(int n) {
        assert(!(n < 0));
        return fact_[n];
    }
    T fact_inv(int n) {
        assert(!(n < 0));
        return fact_inv_[n];
    }
    T inv(int n) {
        assert(!(n < 0));
        return inv_[n];
    }
};

const int MOD = 1000000007;  // if inv is needed, this shold be prime.
using modint = ModInt<MOD>;
Comb<modint> comb(1000000);

int main() {
    long long N, K;
    cin >> N >> K;

    if (K >= N) {
        K = N - 1;
    }

    modint ans = modint(1);

    FOR(k, 1, K + 1) { ans += comb.nCk(N, k) * comb.nHk(N - k, k); }

    cout << ans << endl;
    return 0;
}