穷举递归和回溯算法
在一般的递归函数中,如二分查找、反转文件等,在每个决策点只需要调用一个递归(比如在二分查找,在每个节点我们只需要选择递归左子树或者右子树),在这样的递归调用中,递归调用形成了一个线性结构,而算法的性能取决于调用函数的栈深度。比如对于反转文件,调用栈的深度等于文件的大小;再比如二分查找,递归深度为O(nlogn),这两类递归调用都非常高效。
现在考虑子集问题或者全排列问题,在每一个决策点我们不在只是选择一个分支进行递归调用,而是要尝试所有的分支进行递归调用。在每一个决策点有多种选择,而这多种选择的每一个选择又导致了更多的选择,直到我们碰到base case。这样的话,随着递归调用的深入,穷举递归(exhaustive recursion)算法时间复杂度就会很高。比如:在每个决策点我们需要确定选择哪个一个字母或者在当前位置选择下一步去哪个城市(TSP)。那么我们有办法避免代价高昂的穷举递归吗?答案是视情况而定。在有些情况下,我们没有办法,必须穷举递归,比如我们需要找到全局的最优解。然而在更多的情况下我们只希望找到满意解,在每个决策点,我们选择只选择一条递归调用路径,希望它能够成功,如果我们最终发现,可以得到一个满意解,OK我们不再遍历其他的情况了。否则如果这次尝试没有成功,我们退回决策点,换一个选择尝试,这就是回溯算法。值得说明的是,关于回溯的深度,我们只需要向上回溯到最近的决策点,该决策点满足还有其他的选择没有尝试。随着回溯的向上攀升,最终我们可能回到初始状态,这时候其实我们已经穷举递归了所有的情况,那么该问题是不可解的。
典型问题回顾
上面说的是不是很抽象?我也觉得,但是没办法,严谨还是要有的,说的再多不如来看几个例子来得实在,毕竟我们学习它是为了解决实际问题的。
【经典穷举问题】穷举所有的排列
问题描述:给定一个字符串,重排列后输出所有可能的排列。
在每个决策点,我们需要在剩余待处理的字符串中,选择一个字母,假设剩余字符串长度为k,那么在每个决策点我们有k种选择,我们对每个选择都尝试一次,每次选择一个后,更新当前已经字符串和剩余字符串。当剩余字符串为空时,我们到达base case,输出当前选择的字符串即可。伪代码及C++代码如下:
1 // Permutation Problem 2 // If you have no more characters left to rearrage, print the current permutation 3 // for (every possible choice among the characters left to rearrage) 4 // { 5 // Make a choice and add that character to the permutation so far 6 // Use recursion to rearrage the remaing letters 7 // } 8 // 9 void RecursivePermutation(string sofar, string remain) 10 { 11 if (remain == "") {cout << sofar << endl; reutrn;} 12 13 for (size_t i = 0; i < remain.size(); ++i) 14 { 15 string sofar2 = sofar + remain[i]; 16 string remain2 = remain.substr(0, i) + remain.substr(i+1); 17 RecursivePermutation(sofar2, remain2); 18 } 19 }
在这个问题中,我们尝试了所有可能的选择,属于穷举递归,总共有n!中排列方法。这是一个非常经典的模式,是许多递归算法的核心,比如猜字谜问题,数独问题,最优化匹配问题,调度问题等都可以通过这种模式解决。
【经典穷举问题】子集问题
问题描述:给定一个集合,列出该集合的所有子集
对于每一个决策点,我们从剩余的集合中选择一个元素后,有两种选择,子集包括该元素或者不包括该元素,这样每次递归一步的话,剩余集合中的元素就会减少一个,直到剩余集合为空,我们到达base case。伪代码及C++代码如下:
1 // Subset Problem 2 // 3 // If there are no more elements remaining, print current subset 4 // Consider the next element of those remaining 5 // Try adding it to current subset and use recursion to build subsets from here 6 // Try not adding it to current subset and use recursion to build subsets from here 7 void RecursiveSubset(string sofar, string remain) 8 { 9 // base case 10 if (remain == "") { cout << sofar << endl; return; } 11 12 char ch = remain[0]; 13 string remain2 = remain.substr(1); 14 RecursiveSubset(sofar, remain2); // choose first element 15 RecursiveSubset(sofar + ch, remain2); // not choose first element 16 }
这是另外一个穷举递归的典型例子。每次递归调用问题规模减少一个,然而会产生两个新的递归调用,因而时间复杂度为O(2^n)。这也是个经典问题,需要牢记解决该类问题的pattern,其他与之类似的问题还有最优填充问题、集合划分问题、最长公共子列问题(longest shared subsequence)等。
这两个问题看起来很像,实际上差别很大,属于不同的两类问题。在permutation问题中,我们在每次决策点是要选择一个字母包含到当前子串中,我们有n中选择(假设剩余子串长度为n),每一次选择后递归调用一次,因而有n个规模为n-1的子问题,即T(n) = n T(n-1)。而对于subset问题,我们在每个决策点对于字母的选择只能是剩余子串的首字母,而我们决策的过程为选择or not选择(这是一个问题,哈哈),我们拿走一个字母后,做了两次递归调用(对比permutation问题,我们拿下一个字母后只进行了一次递归调用),因此T(n) = 2 * T(n-1)。
总结说来:permutation问题拿走一个字母后,递归调用一次,我们的决策点是有n个字母可以拿;而subset问题是拿走一个字母后,进行了两次递归调用,我们的决策点是包括还是不包括该拿下的字母,请仔细体味两者的区别。
递归回溯
在permutation问题和subset问题中,我们探索了每一种可能性。在每一个决策点,我们对每一个可能的选择进行尝试,知道我们穷举了我们所有可能的选择。这样以来时间复杂度就会很高,尤其是如果我们有许多决策点,并且在每一个决策点我们又有许多选择的时候。而在回溯算法中,我们尝试一种选择,如果满足了条件,我们不再进行其他的选择。这种算法的一般的伪代码模式如下:
1 bool Solve(configuration conf) 2 { 3 if (no more choice) 4 return (conf is goal state); 5 6 for (all available choices) 7 { 8 try choice c; 9 10 ok = solve(conf with choice c made); 11 if (ok) 12 return true; 13 else 14 unmake c; 15 } 16 17 retun false; 18 }
写回溯函数的忠告是:将有关格局configuration的细节从函数中拿出去(这些细节包括,在每一个决策点有哪些选择,做出选择,判断是否成功等等),放到helper函数中,从而使得主体函数尽可能的简洁清晰,这有助我们确保回溯算法的正确性,同时有助于开发和调试。
我们先看第一个例子,从permutation问题中变异而来。问题是给定一个字符串,问是否能够通过重新排列组合一个合法的单词?这个问题不需要穷举所有情况,只需要找到一个合法单词即可,因而可用回溯算法加快效率。如果能够构成合法单词,我们return该单词;否则返回空串。问题的base case是检查字典中是否包含该单词。每次我们做出选择之后递归调用,判断做出当前选择之后能否成功,如果能,不再尝试其他可能;如果不能,我们换一个别的选择。代码如下:
1 string FindWord(string sofar, string rest, Dict& dict) 2 { 3 // Base Case 4 if (sofar.empty()) 5 { 6 return (dict.containWords(sofar)? sofar : ""); 7 } 8 9 for (int i = 0; i < rest.size(); ++i) 10 { 11 // make a choice 12 string sofar2 = sofar + rest[i]; 13 string rest2 = rest.substr(0, i) + rest.substr(i+1); 14 String found = FindWord(sofar2, rest2, dict); 15 16 // if find answer 17 if (!found.empty()) return found; 18 // else continue next loop, make an alternative choice 19 } 20 21 return "";
我们可以对这个算法进行进一步剪枝来早些避免进入“死胡同”。例如,如果输入字符串是"zicquzcal",一旦你发现了前缀"zc"你就没有必要再进行进一步的选择,因为字典中没有以“zc”开头的单词。具体说来,在base case中需要加入另一种终止条件,如果sofar不是有效前缀,直接返回“”。
【经典回溯问题1】八皇后问题
问题是要求在8x8的国际象棋盘上放8个queue,要求不冲突。(即任何两个queue不同行,不同列,不同对角线)。按照前面的基本范式,我们可以给出如下的伪代码及C++代码::
#include <iostream> #include <vector> using namespace std; // Start in the leftmose column // // If all queens are placed, return true // else for (every possible choice among the rows in this column) // if the queue can be placed safely there, // make that choice and then recursively check if this choice lead a solution // if successful, return true // else, remove queue and try another choice in this colunm // if all rows have been tried and nothing worked, return false to trigger backtracking const int NUM_QUEUE = 4; const int BOARD_SIZE = 4; typedef vector<vector<int> > Grid; void PlaceQueue(Grid& grid, int row, int col); void RemoveQueue(Grid& grid, int row, int col); bool IsSafe(Grid& grid, int row, int col); bool NQueue(Grid& grid, int curcol); void PrintSolution(const Grid& grid); int main() { vector<vector<int> > grid(BOARD_SIZE, vector<int>(BOARD_SIZE, 0)); if (NQueue(grid, 0)) { cout << "Find Solution" << endl; PrintSolution(grid); } else { cout << "Cannot Find Solution" << endl; } return 0; } void PlaceQueue(Grid& grid, int row, int col) { grid[row][col] = 1; } void RemoveQueue(Grid& grid, int row, int col) { grid[row][col] = 0; } bool IsSafe(Grid& grid, int row, int col) { int i = 0; int j = 0; // check row for (j = 0; j < BOARD_SIZE; ++j) { if (j != col && grid[row][j] == 1) return false; } // check col for (i = 0; i < BOARD_SIZE; ++i) { if (i != row && grid[i][col] == 1) return false; } // check left upper diag for (i = row - 1, j = col - 1; i >= 0 && j >= 0; i--, j--) { if (grid[i][j] == 1) return false; } // check left lower diag for (i = row + 1, j = col - 1; i < BOARD_SIZE && j >= 0; i++, j--) { if (grid[i][j] == 1) return false; } return true; } bool NQueue(Grid& grid, int curcol) { // Base case if (curcol == BOARD_SIZE) { return true; } for (int i = 0; i < BOARD_SIZE;++i) { if (IsSafe(grid, i, curcol)) { // try a choice PlaceQueue(grid, i, curcol); // if this choice lead a solution, return bool success = NQueue(grid, curcol + 1); if (success) return true; // else unmake this choice, try an alternative choice else RemoveQueue(grid, i, curcol); } } return false; } void PrintSolution(const Grid& grid) { for (int i = 0; i < BOARD_SIZE; ++i) { for (int j = 0; j < BOARD_SIZE; ++j) { cout << grid[i][j] << " "; } cout << endl; } cout << endl; }
【经典回溯问题2】数独问题
数独问题可以描述为在空格内填写1-9的数字,要求每一行每一列每一个3*3的子数独内的数字1-9出现一次且仅出现一次。一般数独问题会实现填写一些数字以保证解的唯一性,从而使得不需要暴力破解,只是使用逻辑推理就可以完成。这一次让我们尝试用计算机暴力回溯来得到一个解。解决数独问题的伪代码及C++代码如下:
#include <iostream> #include <string> #include <vector> #include <algorithm> #include <iterator> #include <cstdio> using namespace std; // Base Case: if cannot find any empty cell, return true // Find an unsigned cell (x, y) // for digit from 1 to 9 // if there is not conflict for digit at (x, y) // assign (x, y) as digit and Recursively check if this lead to a solution // if success, return true // else remove the digit at (x, y) and try another digit // if all digits have been tried and still have not worked out, return false to trigger backtracking const int GRID_SIZE = 9; const int SUB_GRID_SIZE = 3; typedef vector<vector<int> > Grid; bool IsSafe(const Grid& grid, int x, int y, int num); bool FindEmptyCell(const Grid& grid, int& x, int& y); bool Sudoku(Grid& grid); void PrintSolution(const Grid& grid); int main() { freopen("sudoku.in", "r", stdin); vector<vector<int> > grid(GRID_SIZE, vector<int>(GRID_SIZE, 0)); for (int i = 0; i < GRID_SIZE; ++i) { for (int j = 0; j < GRID_SIZE; ++j) { cin >> grid[i][j]; } } if (Sudoku(grid)) { cout << "Find Solution " << endl; PrintSolution(grid); cout << endl; } else { cout << "Solution does not exist" << endl; } return 0; } bool Sudoku(Grid& grid) { // base case int x = 0; int y = 0; if (!FindEmptyCell(grid, x, y)) return true; // for all the number for (int num = 1; num <= 9; ++num) { if (IsSafe(grid, x, y, num)) { // try one choice grid[x][y] = num; // if this choice lead to a solution if (Sudoku(grid)) return true; // otherwise, try an alternative choice else grid[x][y] = 0; } } return false; } bool IsSafe(const Grid& grid, int x, int y, int num) { // check the current row for (int j = 0; j < grid[x].size(); ++j) { if (j != y && grid[x][j] == num) return false; } // check current col for (int i = 0; i < grid.size(); ++i) { if (i != x && grid[i][y] == num) return false; } // check the subgrid int ii = x / 3; int jj = y / 3; for (int i = ii * SUB_GRID_SIZE; i < (ii+1) * SUB_GRID_SIZE; ++i) { for (int j = jj * SUB_GRID_SIZE; j < (jj+1) * SUB_GRID_SIZE; ++j) { if (i != x || j != y) { if (grid[i][j] == num) return false; } } } return true; } // Find next Empty Cell bool FindEmptyCell(const Grid& grid, int& x, int& y) { for (int i = 0; i < GRID_SIZE; ++i) { for (int j = 0; j < GRID_SIZE; ++j) { if (grid[i][j] == 0) { x = i; y = j; return true; } } } return false; } void PrintSolution(const Grid& grid) { for (int i = 0; i < GRID_SIZE; ++i) { for (int j = 0; j < GRID_SIZE; ++j) { cout << grid[i][j] << " "; } cout << "\n"; } cout << endl; }
【经典回溯问题3】迷宫搜索问题
该问题在实现给定一些黑白方块构成的迷宫,其中黑块表示该方块不能通过,白块表示该方块可以通过,并且给定迷宫的入口和期待的出口,要求找到一条连接入口和出口的路径。有了前面的题目的铺垫,套路其实都是一样的。在当前位置,对于周围的所有方块,判断可行性,对于每一个可行的方块,就是我们当前所有可能的choices;尝试一个choice,递归的判断是否能够导致一个solution,如果可以,return true;否则,尝试另一个choice。如果所有的choice都不能导致一个成功解,return false。剩下的就是递归终止的条件,当前所在位置如果等于目标位置,递归结束,return true。C++代码如下:
#include <iostream> #include <string> #include <vector> using namespace std; const int BOARD_SIZE = 4; enum GridState {Gray, White, Green}; const int DIRECTION_NUM = 2; const int dx[DIRECTION_NUM] = {0, 1}; const int dy[DIRECTION_NUM] = {1, 0}; typedef vector<vector<GridState> > Grid; bool IsSafe(Grid& grid, int x, int y); bool SolveRatMaze(Grid& grid, int curx, int cury); void PrintSolution(const Grid& grid); int main() { vector<vector<GridState> > grid(BOARD_SIZE, vector<GridState>(BOARD_SIZE, White)); for (int j = 1; j < BOARD_SIZE; ++j) grid[0][j] = Gray; grid[1][2] = Gray; grid[2][0] = Gray; grid[2][2] = Gray; grid[2][3] = Gray; // Place the init position grid[0][0] = Green; bool ok = SolveRatMaze(grid, 0, 0); if (ok) { cout << "Found Solution" << endl; PrintSolution(grid); } else { cout << "Solution does not exist" << endl; } return 0; } bool SolveRatMaze(Grid& grid, int curx, int cury) { // base case if (curx == BOARD_SIZE - 1 && cury == BOARD_SIZE - 1) return true; // for every choice for (int i = 0; i < DIRECTION_NUM; ++i) { int nextx = curx + dx[i]; int nexty = cury + dy[i]; if (IsSafe(grid, nextx, nexty)) { // try a choice grid[nextx][nexty] = Green; // check whether lead to a solution bool success = SolveRatMaze(grid, nextx, nexty); // if yes, return true if (success) return true; // no, try an alternative choice, backtracking else grid[nextx][nexty] = White; } } // try every choice, still cannot find a solution return false; } bool IsSafe(Grid& grid, int x, int y) { return grid[x][y] == White; } void PrintSolution(const Grid& grid) { for (int i = 0; i < BOARD_SIZE; ++i) { for (int j = 0; j < BOARD_SIZE; ++j) { cout << grid[i][j] << " "; } cout << "\n"; } cout << endl; }
本文小结
递归回溯算法想明白了其实很简单,因为大部分工作递归过程已经帮我们做了。再重复一下,递归回溯算法的基本模式:识别出当前格局,识别出当前格局所有可能的choice,尝试一个choice,递归的检查是否导致了一个solution,如果是,直接return true;否则尝试另一个choice。如果尝试了所有的choice,都不能导致一个解,return false从而触发回溯过程。剩下的就是在函数的一开始定义递归终止条件,这个需要具体问题具体分析,一般情况下是,当前格局等于目标格局,递归终止,return false。
在理解了递归回溯算法的思想后,记住经典的permutation问题和子集问题,剩下就是多加练习和思考,基本没有太难的问题。在geekforgeeks网站有一个回溯算法集合Backtracking,题目很经典过一遍基本就没什么问题了。
参考文献
[1] Exhaustive recursion and backtracking
[2] www.geeksforgeeks.org-Backtracking
[3] Backtracking algorithms "CIS 680: DATA STRUCTURES: Chapter 19: Backtracking Algorithms"