Loading [MathJax]/extensions/tex2jax.js

[AtCoder] ABC154 E – Almost Everywhere Zero (500点)

2020年2月10日AtCoder500点,桁DP,場合分け,動的計画法,数え上げ,DP

問題へのリンク

問題概要

1 以上 N 以下の整数で、0 でない数字がちょうど K 個あるものの個数を求めよ。

制約

\begin{align}
&1 \leq N \leq 10^{100} \\
&1 \leq K \leq 3
\end{align}

考え方

入力の N が非常に大きい問題なので、32bitや64bitの整数型では収まりません。このような場合は、文字列として入力を受け取り、「桁ごとに考える」のが良いです。

また、「1 以上 N 以下の整数について、条件を満たす数」を求めるような問題は、桁DPでうまく計算できることが多いです。

K の値に応じて場合分けをして、ゴリ押しで計算をする方法もありますが、計算が面倒なのでオススメできません。

今回は以下のようなDPを考えましょう。

dp[ i ][ smaller ][ k ] := i 桁目以降で 0 以外の数を使用できるのが残り k 個である数の個数。i 桁目までの部分について、 smaller が true なら N より小さく、false なら N と等しい数であるとする。

解き方

以下のように dp を更新しましょう。

dp[ i ][true][ k ]について

  • dp[i+1][true][ k ] (次の桁が0の時)
  • dp[i+1][true][ k+1 ] (次の桁が0以外の時)

dp[ i ][false][ k ] について

  • dp[i+1][true][ k ] (次の桁が0の時)
  • dp[i+1][true][k-1]*(N[i] – '1’) (次の桁が0でもN[i]でも無い時)
  • dp[i+1][true][k-1] (次の桁がN[i]の時)

実装例

dp が3次元配列となるので、実際にループを回して計算しようとすると頭を使う必要があります。

メモ化再帰で計算することにしました。

Plain text
Copy to clipboard
Open code in new window
EnlighterJS 3 Syntax Highlighter
#include <bits/stdc++.h>
#define LOOP(n) for (int _i = 0; _i < (n); _i++)
#define REP(i, n) for (int i = 0; i < (n); ++i)
#define FOR(i, r, n) for (int i = (r); i < (n); ++i)
#define ALL(obj) begin(obj), end(obj)
using namespace std;
using ll = long long;
string N;
int K;
ll dp[110][2][4];
ll rec(int i = 0, bool smaller = false, int k = K) {
if (k == 0) return 1;
if (i >= N.size()) return 0;
ll &ret = dp[i][(int)smaller][k];
if (ret > 0) return ret;
ret = 0;
if (smaller) {
ret += rec(i + 1, true, k); // 次の桁が0の時
ret += rec(i + 1, true, k - 1) * 9LL; // 次の桁が0以外の時
} else {
if (N[i] == '0') { // 次の桁が0しかありえない時
ret += rec(i + 1, false, k);
} else {
ret += rec(i + 1, true, k); // 次の桁が0の時
ret += rec(i + 1, true, k - 1) * (N[i] - '1'); // 次の桁が0でもN[i]でも無い時
ret += rec(i + 1, false, k - 1); // 次の桁がN[i]の時
}
}
return ret;
}
int main() {
cin >> N;
cin >> K;
cout << rec() << endl;
return 0;
}
#include <bits/stdc++.h> #define LOOP(n) for (int _i = 0; _i < (n); _i++) #define REP(i, n) for (int i = 0; i < (n); ++i) #define FOR(i, r, n) for (int i = (r); i < (n); ++i) #define ALL(obj) begin(obj), end(obj) using namespace std; using ll = long long; string N; int K; ll dp[110][2][4]; ll rec(int i = 0, bool smaller = false, int k = K) { if (k == 0) return 1; if (i >= N.size()) return 0; ll &ret = dp[i][(int)smaller][k]; if (ret > 0) return ret; ret = 0; if (smaller) { ret += rec(i + 1, true, k); // 次の桁が0の時 ret += rec(i + 1, true, k - 1) * 9LL; // 次の桁が0以外の時 } else { if (N[i] == '0') { // 次の桁が0しかありえない時 ret += rec(i + 1, false, k); } else { ret += rec(i + 1, true, k); // 次の桁が0の時 ret += rec(i + 1, true, k - 1) * (N[i] - '1'); // 次の桁が0でもN[i]でも無い時 ret += rec(i + 1, false, k - 1); // 次の桁がN[i]の時 } } return ret; } int main() { cin >> N; cin >> K; cout << rec() << endl; return 0; }
#include <bits/stdc++.h>
#define LOOP(n) for (int _i = 0; _i < (n); _i++)
#define REP(i, n) for (int i = 0; i < (n); ++i)
#define FOR(i, r, n) for (int i = (r); i < (n); ++i)
#define ALL(obj) begin(obj), end(obj)
using namespace std;
using ll = long long;

string N;
int K;

ll dp[110][2][4];

ll rec(int i = 0, bool smaller = false, int k = K) {
    if (k == 0) return 1;
    if (i >= N.size()) return 0;

    ll &ret = dp[i][(int)smaller][k];
    if (ret > 0) return ret;
    ret = 0;

    if (smaller) {
        ret += rec(i + 1, true, k);            // 次の桁が0の時
        ret += rec(i + 1, true, k - 1) * 9LL;  // 次の桁が0以外の時
    } else {
        if (N[i] == '0') {  // 次の桁が0しかありえない時
            ret += rec(i + 1, false, k);
        } else {
            ret += rec(i + 1, true, k);                     // 次の桁が0の時
            ret += rec(i + 1, true, k - 1) * (N[i] - '1');  // 次の桁が0でもN[i]でも無い時
            ret += rec(i + 1, false, k - 1);                // 次の桁がN[i]の時
        }
    }
    return ret;
}

int main() {
    cin >> N;
    cin >> K;
    cout << rec() << endl;
    return 0;
}

DPの別解

以下のようなDPを考えることもできます。

dp[ i ][ smaller ][ k ] := i 桁目までで 0 以外の数を使用したのが k 個である数の個数。i 桁目までの部分について、 smaller が true なら N より小さく、false なら N と等しい数であるとする。

このように考えると、DPの遷移は以下のような場合3つについて考えれば良いです。(参考:桁DP)

  • dp[ i ][true] からはdp[ i+1 ][ture]にのみ遷移( i桁目まででNより小さいなら i+1桁目をどのように選んでもNより小さい)
  • dp[ i ][false] からdp[ i+1 ][ture]へ遷移( i桁目までNと同じで、 i+1桁目はNより小さい数の時)
  • dp[ i ][false] からdp[ i+1 ][false]へ遷移( i桁目までNと同じで、 i+1桁目もNと同じ数の時)
Plain text
Copy to clipboard
Open code in new window
EnlighterJS 3 Syntax Highlighter
#include <bits/stdc++.h>
#define LOOP(n) for (int _i = 0; _i < (n); _i++)
#define REP(i, n) for (int i = 0; i < (n); ++i)
#define RREP(i, n) for (int i = n; i >= 0; --i)
#define FOR(i, r, n) for (int i = (r); i < (n); ++i)
#define ALL(obj) begin(obj), end(obj)
using namespace std;
using ll = long long;
using ull = unsigned long long;
const int MOD = 1e9 + 7;
int K;
string N;
ll dp[10005][2][5];
int main() {
cin >> N >> K;
int n = N.size();
dp[0][0][0] = 1;
REP(i, n) {
REP(k, K + 1) {
// i桁目まででNより小さいならi+1桁目は何でも良い
dp[i + 1][1][k + 1] += dp[i][1][k] * 9; // i+1桁目が0以外
dp[i + 1][1][k] += dp[i][1][k]; // i+1桁目が0
int ni = (N[i] - '0');
// i桁目までNと同じで、i+1桁目はNより小さい数の時
if (ni > 0) {
dp[i + 1][1][k + 1] += dp[i][0][k] * (ni - 1); // i+1桁目が0以外
dp[i + 1][1][k] += dp[i][0][k]; // i+1桁目が0
}
// i桁目までNと同じで、i+1桁目もNと同じ数の時
if (ni > 0) {
dp[i + 1][0][k + 1] = dp[i][0][k]; // i+1桁目が0以外
} else {
dp[i + 1][0][k] = dp[i][0][k]; // i+1桁目が0
}
}
}
cout << dp[n][0][K] + dp[n][1][K] << endl;
return 0;
}
#include <bits/stdc++.h> #define LOOP(n) for (int _i = 0; _i < (n); _i++) #define REP(i, n) for (int i = 0; i < (n); ++i) #define RREP(i, n) for (int i = n; i >= 0; --i) #define FOR(i, r, n) for (int i = (r); i < (n); ++i) #define ALL(obj) begin(obj), end(obj) using namespace std; using ll = long long; using ull = unsigned long long; const int MOD = 1e9 + 7; int K; string N; ll dp[10005][2][5]; int main() { cin >> N >> K; int n = N.size(); dp[0][0][0] = 1; REP(i, n) { REP(k, K + 1) { // i桁目まででNより小さいならi+1桁目は何でも良い dp[i + 1][1][k + 1] += dp[i][1][k] * 9; // i+1桁目が0以外 dp[i + 1][1][k] += dp[i][1][k]; // i+1桁目が0 int ni = (N[i] - '0'); // i桁目までNと同じで、i+1桁目はNより小さい数の時 if (ni > 0) { dp[i + 1][1][k + 1] += dp[i][0][k] * (ni - 1); // i+1桁目が0以外 dp[i + 1][1][k] += dp[i][0][k]; // i+1桁目が0 } // i桁目までNと同じで、i+1桁目もNと同じ数の時 if (ni > 0) { dp[i + 1][0][k + 1] = dp[i][0][k]; // i+1桁目が0以外 } else { dp[i + 1][0][k] = dp[i][0][k]; // i+1桁目が0 } } } cout << dp[n][0][K] + dp[n][1][K] << endl; return 0; }
#include <bits/stdc++.h>
#define LOOP(n) for (int _i = 0; _i < (n); _i++)
#define REP(i, n) for (int i = 0; i < (n); ++i)
#define RREP(i, n) for (int i = n; i >= 0; --i)
#define FOR(i, r, n) for (int i = (r); i < (n); ++i)
#define ALL(obj) begin(obj), end(obj)
using namespace std;
using ll = long long;
using ull = unsigned long long;

const int MOD = 1e9 + 7;

int K;
string N;
ll dp[10005][2][5];

int main() {
    cin >> N >> K;
    int n = N.size();

    dp[0][0][0] = 1;

    REP(i, n) {
        REP(k, K + 1) {
            // i桁目まででNより小さいならi+1桁目は何でも良い
            dp[i + 1][1][k + 1] += dp[i][1][k] * 9;  // i+1桁目が0以外
            dp[i + 1][1][k] += dp[i][1][k];          // i+1桁目が0

            int ni = (N[i] - '0');

            // i桁目までNと同じで、i+1桁目はNより小さい数の時
            if (ni > 0) {
                dp[i + 1][1][k + 1] += dp[i][0][k] * (ni - 1);  // i+1桁目が0以外
                dp[i + 1][1][k] += dp[i][0][k];                 // i+1桁目が0
            }

            //  i桁目までNと同じで、i+1桁目もNと同じ数の時
            if (ni > 0) {
                dp[i + 1][0][k + 1] = dp[i][0][k];  // i+1桁目が0以外
            } else {
                dp[i + 1][0][k] = dp[i][0][k];  // i+1桁目が0
            }
        }
    }

    cout << dp[n][0][K] + dp[n][1][K] << endl;

    return 0;
}

おまけ(場合分けでゴリ押し)

例えば N=314159 とした時には

  • 1~99999
  • 100000~299999
  • 300000~314159

のように場合分けして、さらに K の値で場合分けをすることで、それぞれの場合で条件を満たす数を直接求めることが可能になります。

コーナーケースを色々と考える必要がありかなり面倒ですが、高速に求めることができるのが利点です。

先に、「N に出現する0以外の数」を上位の桁から 3 つ求めておくと、後で計算が少し楽になります。

Plain text
Copy to clipboard
Open code in new window
EnlighterJS 3 Syntax Highlighter
#include <bits/stdc++.h>
#define LOOP(n) for (int _i = 0; _i < (n); _i++)
#define REP(i, n) for (int i = 0; i < (n); ++i)
#define FOR(i, r, n) for (int i = (r); i < (n); ++i)
#define ALL(obj) begin(obj), end(obj)
using namespace std;
using ll = long long;
int main() {
string N;
cin >> N;
int K;
cin >> K;
ll n = (ll)N.size();
vector<int> keta, val; // 「N に出現する0以外の数」を上位の桁から 3 つ
REP(i, n) {
if (N[i] != '0') {
keta.push_back(n - i - 1);
val.push_back((int)(N[i] - '0'));
}
}
REP(i, 3) { // 3つ以上ない場合も考えて 0 を詰めておく
keta.push_back(0);
val.push_back(0);
}
ll ans = 0;
REP(i, n - 1) { // 1~n-1 桁の数について
if (K == 1) {
ans += 9;
} else if (K == 2) {
ans += 9 * 9 * i;
} else if (K == 3) {
ans += 9 * 9 * 9 * (i * (i - 1) / 2);
}
}
// n 桁の数について
if (K == 1) {
ans += val[0]; // 先頭は 1~val[0]通り
} else if (K == 2) {
ans += (val[0] - 1) * 9 * (n - 1); // 先頭を 1~(val[0]-1) のどれかにする時
ans += val[1]; // 先頭を val[0] にして、2つ目を keta[1]+1 桁目で使う時
ans += 9 * keta[1]; // 先頭を val[0] にして、2つ目を 1~keta[1] 桁目のどこかで使う時
} else if (K == 3) {
ans += (val[0] - 1) * 9 * 9 * (n - 1) * (n - 2) / 2; // 先頭を 1~(val[0]-1) にする時
ans += (val[1] - 1) * 9
* (keta[1]); // 先頭をval[0]にして、2つ目を keta[1]+1 桁目で1~(val[1]-1) のどれかにする時
ans += val[2]; // 先頭をval[0]、2つ目を keta[1]+1 桁目で val[1] にして、3つ目を keta[2]+1 桁目で使う時
ans += 9 * keta[2]; // 先頭をval[0]、2つ目を keta[1]+1 桁目で val[1] にして、3つ目を 1~keta[2]
// 桁目のどこかで使う時
ans += 9 * 9 * (keta[1]) * ((keta[1]) - 1) / 2; // 先頭をval[0]、2つ目を keta[2]+1 桁目で使う時
}
cout << ans << endl;
return 0;
}
#include <bits/stdc++.h> #define LOOP(n) for (int _i = 0; _i < (n); _i++) #define REP(i, n) for (int i = 0; i < (n); ++i) #define FOR(i, r, n) for (int i = (r); i < (n); ++i) #define ALL(obj) begin(obj), end(obj) using namespace std; using ll = long long; int main() { string N; cin >> N; int K; cin >> K; ll n = (ll)N.size(); vector<int> keta, val; // 「N に出現する0以外の数」を上位の桁から 3 つ REP(i, n) { if (N[i] != '0') { keta.push_back(n - i - 1); val.push_back((int)(N[i] - '0')); } } REP(i, 3) { // 3つ以上ない場合も考えて 0 を詰めておく keta.push_back(0); val.push_back(0); } ll ans = 0; REP(i, n - 1) { // 1~n-1 桁の数について if (K == 1) { ans += 9; } else if (K == 2) { ans += 9 * 9 * i; } else if (K == 3) { ans += 9 * 9 * 9 * (i * (i - 1) / 2); } } // n 桁の数について if (K == 1) { ans += val[0]; // 先頭は 1~val[0]通り } else if (K == 2) { ans += (val[0] - 1) * 9 * (n - 1); // 先頭を 1~(val[0]-1) のどれかにする時 ans += val[1]; // 先頭を val[0] にして、2つ目を keta[1]+1 桁目で使う時 ans += 9 * keta[1]; // 先頭を val[0] にして、2つ目を 1~keta[1] 桁目のどこかで使う時 } else if (K == 3) { ans += (val[0] - 1) * 9 * 9 * (n - 1) * (n - 2) / 2; // 先頭を 1~(val[0]-1) にする時 ans += (val[1] - 1) * 9 * (keta[1]); // 先頭をval[0]にして、2つ目を keta[1]+1 桁目で1~(val[1]-1) のどれかにする時 ans += val[2]; // 先頭をval[0]、2つ目を keta[1]+1 桁目で val[1] にして、3つ目を keta[2]+1 桁目で使う時 ans += 9 * keta[2]; // 先頭をval[0]、2つ目を keta[1]+1 桁目で val[1] にして、3つ目を 1~keta[2] // 桁目のどこかで使う時 ans += 9 * 9 * (keta[1]) * ((keta[1]) - 1) / 2; // 先頭をval[0]、2つ目を keta[2]+1 桁目で使う時 } cout << ans << endl; return 0; }
#include <bits/stdc++.h>
#define LOOP(n) for (int _i = 0; _i < (n); _i++)
#define REP(i, n) for (int i = 0; i < (n); ++i)
#define FOR(i, r, n) for (int i = (r); i < (n); ++i)
#define ALL(obj) begin(obj), end(obj)
using namespace std;
using ll = long long;

int main() {
    string N;
    cin >> N;
    int K;
    cin >> K;
    ll n = (ll)N.size();

    vector<int> keta, val;  // 「N に出現する0以外の数」を上位の桁から 3 つ
    REP(i, n) {
        if (N[i] != '0') {
            keta.push_back(n - i - 1);
            val.push_back((int)(N[i] - '0'));
        }
    }
    REP(i, 3) {  // 3つ以上ない場合も考えて 0 を詰めておく
        keta.push_back(0);
        val.push_back(0);
    }

    ll ans = 0;
    REP(i, n - 1) {  // 1~n-1 桁の数について
        if (K == 1) {
            ans += 9;
        } else if (K == 2) {
            ans += 9 * 9 * i;
        } else if (K == 3) {
            ans += 9 * 9 * 9 * (i * (i - 1) / 2);
        }
    }

    // n 桁の数について
    if (K == 1) {
        ans += val[0];  // 先頭は 1~val[0]通り
    } else if (K == 2) {
        ans += (val[0] - 1) * 9 * (n - 1);  // 先頭を 1~(val[0]-1) のどれかにする時
        ans += val[1];                      // 先頭を val[0] にして、2つ目を keta[1]+1 桁目で使う時
        ans += 9 * keta[1];  // 先頭を val[0] にして、2つ目を 1~keta[1] 桁目のどこかで使う時
    } else if (K == 3) {
        ans += (val[0] - 1) * 9 * 9 * (n - 1) * (n - 2) / 2;  // 先頭を 1~(val[0]-1) にする時
        ans += (val[1] - 1) * 9
               * (keta[1]);  // 先頭をval[0]にして、2つ目を keta[1]+1 桁目で1~(val[1]-1) のどれかにする時
        ans += val[2];  // 先頭をval[0]、2つ目を keta[1]+1 桁目で val[1] にして、3つ目を keta[2]+1 桁目で使う時
        ans += 9 * keta[2];  // 先頭をval[0]、2つ目を keta[1]+1 桁目で val[1] にして、3つ目を 1~keta[2]
                             // 桁目のどこかで使う時
        ans += 9 * 9 * (keta[1]) * ((keta[1]) - 1) / 2;  // 先頭をval[0]、2つ目を keta[2]+1 桁目で使う時
    }

    cout << ans << endl;
    return 0;
}