pyhaya’s diary

機械学習系の記事をメインで書きます

AtCoder Beginners Contest (ABC) 011 C: 123引き算 を解いた

競技プログラミング初心者が初心者向けに問題の解説を行います。

使用環境

問題文

今回挑戦するのは、次の問題です。

あなたは、友人から、一人用のゲームを紹介されました。

最初に、数字 N が与えられます。1, 2, 3 の中から好きな数字を選び、 与えられた数字に対し、引き算を行う、という処理を行うことできます。この処理は100回まで行うことが可能であり、最終的に数字を0にすることが目標のゲームです。しかし、計算途中でなってはいけないNG数字が3つ与えられており、 この数字に一時的にでもなってしまった瞬間、このゲームは失敗となります。 NG数字がN と同じ場合も失敗となります。

あなたは、このゲームが、目標達成可能なゲームとなっているか調べたいです。

目標達成可能な場合はYES、そうでない場合はNOと出力してください。

入力は以下の形式で標準入力から与えられる。

N
NG1
NG2
NG3

1行目には、最初に与えられる数字N(1≦N≦300)が与えられる。
2行目には、1番目のNG数字NG1(1≦NG1≦300)が与えられる。
3行目には、2番目のNG数字NG2(1≦NG2≦300)が与えられる。
4行目には、3番目のNG数字NG3(1≦NG3≦300) が与えられる。

例えば N=2NG1=1, NG2=7, NG3=15の時には1を引くとNG1に一致してしまうのでダメで、2を引けば無事0になるので出力は"YES"となります。

単純に考える

単純に考えると、この問題は再帰関数を使って全探索で解くことができる気がします。

#include <iostream>
using namespace std;

int ng1, ng2, ng3;
bool solve(int res, int k) {    // res : 残り    k : 引き算を行った回数
        // 残りが0になっていればOK
	if (res == 0) {
		return true;
	}
        // 0を行き過ぎていたり、NGの数に一致してしまったらダメ
	else if (res < 0 || res == ng1 || res == ng2 || res == ng3) {
		return false;
	}
        // 引き算を100回以上したらダメ
	else if (k >= 100) {
		return false;
	}
        // それ以外だったら1引いた場合と2引いた場合と3引いた場合を計算してみて
        // どれか一つでも成功したらOK
	else {
		return solve(res - 1, k + 1) || solve(res - 2, k + 1) || solve(res - 3, k + 1);
	}

}

int main() {
	int n; cin >> n;
	cin >> ng1 >> ng2 >> ng3;

	if (solve(n, 0)) cout << "YES" << endl;
	else cout << "NO" << endl;
}

しかし、これはNが大きくなると時間的に全然間に合わなくなります。なぜなら、

return solve(res - 1, k + 1) || solve(res - 2, k + 1) || solve(res - 3, k + 1);

を見ると、1つの計算のために3種類の計算をすることになっているので、再帰が深くなると指数関数的に計算量が増えていくからです。今回、Nは300まで行きますから、単純計算で、最大の計算量は3^{300}にもなります。

満点解法

制限時間内にプログラムが愁傷するためには、視点を変える必要があります。ある数mに到達するまでの引き算の必要最低限の回数を計算していって、0になるまでに必要な回数が100回以下であればよいと考えます。

mに到達するまでの引き算の回数の最小値をdp[m]とします。配列dpには最初には十分大きな数を入れておきます。そしてNは初期値なのでdp[N]=0をセットします。後はNから順番に小さいほうに進んでいき、配列の要素を更新していきます。更新の様子を図示すると下のようになります。
f:id:pyhaya:20181113000910p:plain

これをコードに落とし込むと、

#include <iostream>
#include <algorithm>
using namespace std;

int INF = 1000000;
int ng1, ng2, ng3;
int dp[305];

int main() {
	int n; cin >> n;
	cin >> ng1 >> ng2 >> ng3;

	for (int i = 0; i < n; i++) {
		dp[i] = INF;
	}
	dp[n] = 0;

	for (int i = n; i != 0; i--) {
		if (i == ng1 || i == ng2 || i == ng3) {
			continue;
		}
		for (int j = 1; j < 4; j++) {
			dp[i - j] = min(dp[i - j], dp[i] + 1);
		}
	}

	if (dp[0] <= 100) cout << "YES" << endl;
	else cout << "NO" << endl;
}

これで時間内になります。