[AtCoder] ABC152 D – Handstand 2 (400点)

AtCoder数え上げ, 桁の数が等しい, 400点

問題へのリンク

問題

N以下の整数の組(A, B)について、Aの先頭とBの末尾の数が等しく、Aの末尾とBの先頭の数が等しいのは何通りあるか?

制約

$$1\leq N \leq 2 \times 10^5$$

考え方1(先頭と末尾を決めてしまう)

Aの先頭と末尾をそれぞれ i, j と決めてしまうと、

c[i][j] := N以下の先頭が i で末尾が j となる数が何通りか

を使うことで素早く計算をすることができます。

例えば、Aの先頭が1で末尾が3のときは、Bは先頭が3で末尾が1となっているはずです。つまり、

c[1][3] × c[3][1] 通り

だけの組ができます。これらをすべてについて計算してあげれば良いです。以下の計算結果が答えになります。

$$ \sum_{i=1}^{9}\sum_{j=1}^{9} c[i][j]$$

プログラム例

#include <bits/stdc++.h>
#define REP(i, n) for (int i = 0; i < (n); i++)
#define RREP(i, n) for (int i = (n); i >= 0; i--)
#define FOR(i, m, n) for (int i = (m); i < (n); i++)
#define ALL(obj) begin(obj), end(obj)
using namespace std;
using ll = long long;
using ull = unsigned long long;
const int INF = 2100100100;
const int MOD = 1e9 + 7;
ll N;
ll nt, nb;
ll c[10][10];
int main() {
    cin >> N;
    FOR(i, 1, N + 1) {
        string s = to_string(i);
        int size = s.size();
        int t = s[0] - '0';
        int b = s[s.size() - 1] - '0';
        c[t][b]++;
    }
    ll ans = 0;
    FOR(i, 1, 10) {
        FOR(j, 1, 10) { ans += c[i][j] * c[j][i]; }
    }
    cout << ans << endl;
    return 0;
}

考え方2(Aを固定してしまう)

Aを固定すると、対応するBの数は実はO(logN) 程度で計算できてしまいます。全体としてはO(NlogN)です。

しかし、この計算は場合分けの実装が手間になってしまうのでおすすめできません。

Aに対応するBが、1桁のときの数、2桁の時の数、3桁以上の時の数で場合分けして考えなくてはいけません。

Bが3桁以上で、Bの先頭がNの先頭と同じ時、Bの最後尾がNの最後尾より大きいかどうかでも場合分けする必要があります。

プログラム例

#include <bits/stdc++.h>
#define REP(i, n) for (int i = 0; i < (n); i++)
#define RREP(i, n) for (int i = (n); i >= 0; i--)
#define FOR(i, m, n) for (int i = (m); i < (n); i++)
#define ALL(obj) begin(obj), end(obj)
using namespace std;
using ll = long long;
using ull = unsigned long long;
ll N;
ll nt, nb;
ll count(ll t, ll b, ll size) {
    ll ret = 0;
    if (t == b) {
        if (t <= N) {
            ret++;
        }
    }
    if (b * 10 + t <= N) {
        ret++;
    }
    ll n_size = 3;
    ll tmp = 10;
    while (n_size < size) {
        ret += tmp;
        tmp *= 10;
        n_size++;
    }
    if (size >= 3) {
        tmp *= 10;
        ll now = t;
        now += b * tmp;
        if (now <= N) {
            if (b < nt) {
                ret += tmp / 10;
            } else if (b == nt) {
                if (t <= nb) {
                    ll mid = (N - nt * tmp) / 10;
                    ret += mid + 1;
                } else if (t > nb) {
                    ll mid = (N - nt * tmp) / 10;
                    ret += mid;
                }
            }
        }
    }
    return ret;
}
int main() {
    cin >> N;
    string s = to_string(N);
    int size = s.size();
    nt = s[0] - '0';
    nb = s[s.size() - 1] - '0';
    ll ans = 0;
    FOR(a, 1, N + 1) {
        string s = to_string(a);
        int t = s[0] - '0';
        int b = s[s.size() - 1] - '0';
        if (t != 0 && b != 0) {
            ans += count(t, b, size);
        }
    }
    cout << ans << endl;
    return 0;
}