pyhaya’s diary

機械学習系の記事をメインで書きます

AtCoder Beginners Contest 128 C. Switches を解いたのでメモ

久しくブログを更新していなかった気がするので、AtCoder の過去問を解いた記録でも書いてみる。AtCoder Beginners Contest (通称 ABC) 128 の C 問題を解いてみたら再帰関数の良い練習になったのでこれを紹介します。


問題文は下のリンクを参照。

atcoder.jp

解答環境

解く方針

問題文を読んで最初に思ったのが、変数が多いということ。N がスイッチの数で M が電球の数、k が各電球につながっているスイッチの数で、sが各電球につながっているスイッチの番号、pが各電球の点灯条件。。。頭ごっちゃになりそう。変数の持ち方を考えるが、あまり深く考えずに N, M は int、s は多次元ベクトルでまとめて、p もベクトルとして持つことにする。

N, M も大きくなっても 10 までなのでスイッチの状態を全列挙してそれぞれの状態でスイッチが全部つくかどうかを調べることにする。

コード

Step 1

ということで早速、解答の雛形を書いてみる。

#include <bits/stdc++.h>
typedef long long ll;
using namespace std;

int N, M;
vector<vector<int>> info;

int main() {
  // N: switch
  // M: light
  cin >> N >> M;
  vector<int> P(M);

  for (int i = 0; i < M; i++) {
    int k; cin >> k;
    vector<int> tmp(k);
    for (int j = 0; j < k; j++) {
      int a; cin >> a;
      tmp[j] = a - 1;
    }

    info.push_back(tmp);
  }
  for (int i = 0; i < M; i++) {
    cin >> P[i];
  }
  vector<int> init_state(N, 0); // present state of switches
  int res = calc(0, P, init_state);

  cout << res << endl;
}

このコードではデータの読み込みしか書いてない。具体的な計算は calc に丸投げしてる。書くことがあるとすれば電球の番号が1始まりで与えられていて扱いづらいので0始まりにしていることくらい (19 行目)。あとは N, M, info をグローバル変数にしていることくらいか。これは単純に calc にたくさん引数を与えるのが面倒という理由しか無い。

Step 2

calc 関数を書いていく。

int calc(int now_idx, vector<int> P, vector<int> now_state) {
  if (now_idx == N) {
    return is_valid(P, now_state);
  }

  vector<int> new_state;  // state that 'now_idx'th switch is on.
  for (int i = 0; i < now_state.size(); i++) {
    if(i == now_idx) new_state.push_back(1);
    else new_state.push_back(now_state[i]);
  }

  return calc(now_idx + 1, P, now_state) + calc(now_idx + 1, P, new_state);
}

よくある再帰関数の形。main() からは now_state として 0 のみを要素に持つベクトルが与えられる。これは全部のスイッチがoffである状態。スイッチの状態の総数は最初のスイッチが off の場合と on の場合の和になることを利用している。終端条件は、最後のスイッチの on/off を決めた時でこのときに return 1; とすれば単純にスイッチの状態の総数が返るようになる。しかしこの問題では状態に条件がついていて、このスイッチの状態で全ての電球が点灯していなければならない。なのでここでは単純に 1 を返さずにもうワンステップ噛ませる。この判定は is_valid に任せる。

Step 3
int is_valid(vector<int> P, vector<int> state) {
  // P: how to turn on the lights
  map<int, int> search;
  for (int i = 0; i < state.size(); i++) {
    if (state[i] == 1) search[i] = 1;
  }

  vector<int> P_copy;
  copy(P.begin(), P.end(), back_inserter(P_copy));

  for (int i = 0; i < M; i++) {
    for (int j = 0; j < info[i].size(); j++) {
      if (search[info[i][j]]) P_copy[i] = !P_copy[i];
    }
  }

  int count = 0;
  for (int i = 0; i < P_copy.size(); i++) {
    if (P_copy[i] == 0) count++;
  }

  return count == P.size();
}

各電球について、つながっているスイッチの中で on になっている数を数えてそれを2で割ったあまりが点灯条件を満たすか調べる。1つ1つの電球に対してつながっているスイッチの on/off をベクトルから調べるのは面倒なので最初に search というマップにして持つ。

あとは「on になっている数を数えてそれを2で割ったあまりが点灯条件を満たすか調べる」部分だが、これも簡単にかけないか考えると、P の値を スイッチが on であるたびに 0, 1 反転させて 0 になっていればその電球はついていることに気づく。どういうことかというと、ある電球が on になっているスイッチが奇数のときにつくとすると、この電球に対応する p の値は 1 である。この電球につながっている3つのスイッチが on になっているとして p を 3 回反転させると 0 になる。2回なら 1 になる。逆にスイッチの数が偶数のときに電球がつくとするとp = 0 で、2つのスイッチが on になっている状態では p を2回反転させると 0、3つだと 1 になる。つまり電球のつく条件がどちらでも on になっているスイッチの数だけ p の値を反転させるたときに p が 1になっていれば消えているし、0 になっていればついている。これをコードにしているのが 13 行目。

全体のコード

以上のコードをまとめると下のようになる。

#include <bits/stdc++.h>
typedef long long ll;
using namespace std;

int N, M;
vector<vector<int>> info;

int is_valid(vector<int> P, vector<int> state) {
  // P: how to turn on the lights
  map<int, int> search;
  for (int i = 0; i < state.size(); i++) {
    if (state[i] == 1) search[i] = 1;
  }

  vector<int> P_copy;
  copy(P.begin(), P.end(), back_inserter(P_copy));

  for (int i = 0; i < M; i++) {
    for (int j = 0; j < info[i].size(); j++) {
      if (search[info[i][j]]) P_copy[i] = !P_copy[i];
    }
  }

  int count = 0;
  for (int i = 0; i < P_copy.size(); i++) {
    if (P_copy[i] == 0) count++;
  }

  return count == P.size();
}

int calc(int now_idx, vector<int> P, vector<int> now_state) {
  if (now_idx == N) {
    return is_valid(P, now_state);
  }

  vector<int> new_state;
  for (int i = 0; i < now_state.size(); i++) {
    if(i == now_idx) new_state.push_back(1);
    else new_state.push_back(now_state[i]);
  }

  return calc(now_idx + 1, P, now_state) + calc(now_idx + 1, P, new_state);
}

int main() {
  // N: switch
  // M: light
  cin >> N >> M;
  vector<int> P(M);

  for (int i = 0; i < M; i++) {
    int k; cin >> k;
    vector<int> tmp(k);
    for (int j = 0; j < k; j++) {
      int a; cin >> a;
      tmp[j] = a - 1;
    }

    info.push_back(tmp);
  }
  for (int i = 0; i < M; i++) {
    cin >> P[i];
  }
  vector<int> init_state(N, 0); // present state of switches
  int res = calc(0, P, init_state);

  cout << res << endl;
}