プログラムのお勉強メモ

プログラムの勉強メモです. Python, Rust, など.

ABC152_D(Handstand 2)

問題

考えたこと

  • n が  2 * 10^{5} オーダなので, n * n のオーダにならなければ良さそう
  • (1, 1) -> (1, 11) -> (1, 101)... と推移していって, xの種類 * yの種類を計算すれば良い
    • (101, 101) の場合, 101, 111, 121...191 の10種類なので 10 * 10 種類
    • (x, y) と (y, x) の両方があるので, 2倍しながら足していけば良さそう
    • x = y の時は2倍にしてはいけない
    • a > b の時は既に加算済なので加算しない
    • a と bの最大値(初期値101なら191)が最大値より大きい場合, 終端処理が必要
  • 実装はできそうだったのでトライ

実装

n = int(input())

ans = 0
a = 1


def check_max(x, n):
    if len(str(x)) <= 2:
        return 1
    elif len(str(x)) == 3:
        x_max = x + 90
    elif len(str(x)) == 4:
        x_max = x + 990
    elif len(str(x)) == 5:
        x_max = x + 9990
    else:
        x_max = x + 99990

    if x_max <= n:
        return 10**max(0, len(str(x))-2)
    else:
        res = n - int(str(n)[0]) * 10**(len(str(n))-1)
        res //= 10

        check = n // 10
        check *= 10

        if check + int(str(x)[-1]) <= n:
            return res+1
        else:
            return res


while True:
    c = str(a)

    if len(c) == 1:
        b = a
    else:
        if c[-1] == "0":
            continue
        else:
            b = int(c[-1]) * 10**(len(c)-1) + int(c[0])

    while b <= n:
        a_c = check_max(a, n)
        b_c = check_max(b, n)

        if a > b:
            count = 0
        elif a == b:
            count = a_c * b_c
        else:
            count = a_c * b_c * 2

        ans += count

        b = int(str(b)[0]) * 10**(len(str(b))) + int(str(b)[-1])

    if int(c[-1]) < 9:
        a += 1
        c = str(a)
    elif int(c[0]) < 9:
        a = (int(c[0]) + 1) * 10**(len(c) - 1) + 1
    else:
        a = 1 * 10**len(c) + 1
        c = str(a)

    if a > n:
        break

print(ans)

結果

  • バグで WA を出してしまったが, 無事に AC できた
    • テストケースの1個が間違っていたまま提出していたらしい
    • 完全に見間違えていた...勿体ない
  • AC できた一方で, こんなに実装重くなる!?と思ったため, AC 後に他に人の答えを見る
  • どうやらもっと簡単に解けたらしい

簡単な解法

  • count[x][y] を x を 先頭文字, y が末尾文字 の数値の出現数として置く
  • 先頭の数字と末尾の数字が何パターンあるか算出する(1~nまで回す)
  • 出現数 * 出現数 を 全パターンやれば全ての出現パターンを網羅できる
    • dp[x][y] * dp[y][x] を 全パターン加算すれば答えが得られる

簡単な解法による実装

n = int(input())

dp = [[0 for _ in range(10)] for _ in range(10)]

for i in range(n+1):
    a = str(i)[0]
    b = str(i)[-1]
    dp[int(a)][int(b)] += 1

ans = 0
for y in range(1, 10):
    for x in range(1, 10):
        ans += dp[y][x] * dp[x][y]

print(ans)

感想

  • 嘘でしょ...実装瞬殺なんだけど...
  • この発想はなかったな。他の問題に応用できる気がしない。。。