DEV Community

dhavalraj007
dhavalraj007

Posted on • Edited on

Understanding Backtracking/Recursion

recursive tree/choice tree
Abstract Pseudo code to help understand the gist of backtracking

bool Solve(configuration conf)
{
    if (no more choices) // BASE CASE
       return (conf is goal state);

    for (all available choices)
    {
        try one choice c;
        // recursively solve after making choice
        ok = Solve(updated conf with choice c made);
        if (ok)
           return true;
        else
           unmake choice c;
    }

return false; // tried all choices, no soln found
}
Enter fullscreen mode Exit fullscreen mode

When you need to find all possible solutions

void Solve(configuration conf,solutions& sols)
{
    if (no more choices) // BASE CASE
    {
       if (conf is goal state)
           sols.add(conf);
       return;
    }

    for (all available choices)
    {
        try one choice c;
        // recursively solve after making choice
        Solve(updated conf with choice c made, sols);
        unmake choice c;
    }

return; // tried all choices; every solution found has been added to sols
}
Enter fullscreen mode Exit fullscreen mode

Examples:

N-Queens

Problem statement

Given an N×N grid, place N queens on it so that no queen attacks another.

Input:

N = 4
Enter fullscreen mode Exit fullscreen mode

Output:

0  1  0  0
0  0  0  1
1  0  0  0
0  0  1  0
Enter fullscreen mode Exit fullscreen mode

Pseudo Code

bool solve(grid,queens)
{
    //if(no more choices) //conf is goal state // BASE CASE
    if(queens==0)           
        return true;

    //for (all available choices)
    for(all safe grid cell c)
    {
        //try one choice c;
        // recursively solve after making choice
        put queen at cell c;
        // ok = Solve(updated conf with choice c made);
        ok = solve(grid,queens-1);

        if(ok)
            return true;
        else
            remove queen from cell c;   //unmake choice c;
    }

    return false;       // tried all choices, no soln found
}
Enter fullscreen mode Exit fullscreen mode

Cpp code

// Returns true when cell (i, j) is attacked by a queen already on `grid`
// (cells holding 1): same row, same column, or either diagonal.
// A queen standing on (i, j) itself also makes the cell "attacked".
bool is_Atttacked(int i, int j, vector<vector<int>>& grid)
{
    int N = grid.size();
    // same row
    for (int ele : grid[i])
    {
        if (ele == 1)
            return true;
    }

    // same column
    for (int row = 0; row < N; row++)
    {
        if (grid[row][j] == 1)
            return true;
    }

    // main diagonal: every (x, y) with x - y == i - j
    for (int y = 0; y < N; y++)
    {
        int x = (i - j) + y;
        if (x >= 0 and x < N)
        {
            if (grid[x][y] == 1)
                return true;
        }
    }

    // anti-diagonal: every (x, y) with x + y == i + j
    for (int y = 0; y < N; y++)
    {
        int x = (i + j) - y;
        if (x >= 0 and x < N)
        {
            if (grid[x][y] == 1)
                return true;
        }
    }

    return false;
}

//main recursive function
// Places N more mutually non-attacking queens on the square `grid`
// (1 = queen, 0 = empty). Returns true and leaves the queens in place on
// success; on failure every cell it set is reset to 0.
//
// `start` is the row-major index (i * size + j) of the first cell that may
// receive the next queen. Queens are interchangeable, so trying cells only in
// increasing order visits each set of placements once instead of once per
// ordering of its queens. The first solution found is the same as with the
// unordered search.
bool NQueens(vector<vector<int>>& grid, int N, int start = 0)
{
    if (N == 0)
        return true;

    const int size = grid.size();
    for (int cell = start; cell < size * size; cell++)
    {
        const int i = cell / size;
        const int j = cell % size;

        // Two queens can never share a row, so once fewer rows than queens
        // remain (rows i..size-1), no later cell can lead to a solution.
        if (size - i < N)
            break;

        if (is_Atttacked(i, j, grid))
            continue;

        grid[i][j] = 1;                       // try this cell
        if (NQueens(grid, N - 1, cell + 1))
            return true;
        grid[i][j] = 0;                       // unmake the choice
    }
    return false;
}
Enter fullscreen mode Exit fullscreen mode

Knights Tour

Problem statement

Given an empty N×N board with a knight on its top-left square, move the knight according to the rules of chess so that it visits every square exactly once. Print, for each square, the step number at which the knight visits it.

Input:

N = 5
Enter fullscreen mode Exit fullscreen mode

Output:

0  5  14 9  20 
13 8  19 4  15 
18 1  6  21 10 
7  12 23 16 3 
24 17 2  11 22
Enter fullscreen mode Exit fullscreen mode

Pseudo code

//grid filled with -1 and grid[0][0] = 0 , steps = 1 , x=y=0
bool solve(grid,steps,x,y)
{
    if(steps==grid.size()*grid.size())   // every cell visited
    {
        print(grid);
        return true;
    }

    for(all possible/safe moves from (x,y))
    {
        grid[move.x][move.y] = steps;
        ok = solve(grid,steps+1,move.x,move.y);
        if(ok)
            return true;
        else
            grid[move.x][move.y] = -1;
    }
    return false;
}
Enter fullscreen mode Exit fullscreen mode

Cpp code

#include <iostream>
#include <vector>

using namespace std;


// Prints the board one row per line (each value followed by a space),
// then an empty line. Rows are visited by reference, so nothing is copied.
void print(const vector<vector<int>> &grid)
{
    for (const auto &row : grid)
    {
        for (int ele : row)
        {
            cout << ele << " ";
        }
        cout << '\n';
    }
    cout << '\n';
}


// Non-zero when (x, y) lies on the square board and has not been visited yet
// (unvisited cells hold -1; visited cells hold their step number).
int isSafe(int x, int y,vector<vector<int>>& grid )
{
    const bool onBoard = x >= 0 && x < grid.size() && y >= 0 && y < grid.size();
    if (!onBoard)
        return false;
    return grid[x][y] == -1;
}

// Depth-first search for one knight's tour on a single shared board.
// Cells hold -1 (unvisited) or the step at which they were visited.
// Prints the board and returns true once every cell is filled; on failure
// each cell it marked is reset to -1 before returning false.
static bool extendTour(vector<vector<int>> &board, int step, int x, int y,
                       const vector<int> &xMove, const vector<int> &yMove)
{
    if (step == board.size() * board.size())
    {
        print(board);
        return true;
    }
    for (int m = 0; m < 8; m++)     // for all possible moves from (x,y)
    {
        const int nx = x + xMove[m];
        const int ny = y + yMove[m];
        if (!isSafe(nx, ny, board))
            continue;

        board[nx][ny] = step;
        if (extendTour(board, step + 1, nx, ny, xMove, yMove))
            return true;
        board[nx][ny] = -1;         // unmake the move
    }
    return false;
}

// Searches for a knight's tour that continues from (x, y) with step `steps`,
// prints the first complete tour and returns true (false if none exists).
// `grid` is still taken by value, so the caller's board is never modified.
// The search then works on that one copy instead of copying the whole board
// again at every recursive step.
bool solve(vector<vector<int>> grid,int steps,int x,int y,vector<int> &xMove,vector<int> &yMove)
{
    return extendTour(grid, steps, x, y, xMove, yMove);
}


int main() {
    const int N = 5;

    // Knight moves as (row, column) offsets; entry m of xMove pairs with entry m of yMove.
    vector<int> xMove{ 2, 1, -1, -2, -2, -1, 1, 2 };
    vector<int> yMove{ 1, 2, 2, 1, -1, -2, -2, -1 };

    // Empty board (-1 everywhere) with the knight on the top-left square at step 0.
    vector<vector<int>> grid(N, vector<int>(N, -1));
    grid[0][0] = 0;

    // Continue from step 1 at (0, 0); the first complete tour is printed.
    solve(grid, 1, 0, 0, xMove, yMove);
}
Enter fullscreen mode Exit fullscreen mode

Cpp code for all possible solutions

#include <iostream>
#include <vector>

using namespace std;

// Prints the board one row per line (each value followed by a space),
// then an empty line. Rows are visited by reference, so nothing is copied.
void print(const vector<vector<int>> &grid)
{
    for (const auto &row : grid)
    {
        for (int ele : row)
        {
            cout << ele << " ";
        }
        cout << '\n';
    }
    cout << '\n';
}


// Returns 1 if (x, y) is on the board and still unvisited (-1), otherwise 0.
int isSafe(int x, int y,vector<vector<int>>& grid )
{
    if (x < 0 || y < 0)
        return 0;
    if (x >= grid.size() || y >= grid.size())
        return 0;
    return grid[x][y] == -1 ? 1 : 0;
}


// Depth-first enumeration of knight's tours on a single shared board.
// Cells hold -1 (unvisited) or the step at which they were visited; every
// completed tour is printed, and each move is undone after it is explored.
static void enumerateTours(vector<vector<int>> &board, int step, int x, int y,
                           const vector<int> &xMove, const vector<int> &yMove)
{
    if (step == board.size() * board.size())
    {
        print(board);
        return;
    }
    for (int m = 0; m < 8; m++)     // for all possible moves from (x,y)
    {
        const int nx = x + xMove[m];
        const int ny = y + yMove[m];
        if (isSafe(nx, ny, board))
        {
            board[nx][ny] = step;
            enumerateTours(board, step + 1, nx, ny, xMove, yMove);
            board[nx][ny] = -1;     // unmake the move
        }
    }
}

// Prints every knight's tour that continues from (x, y) with step `steps`.
// `grid` is still taken by value, so the caller's board is never modified.
// The search then works on that one copy instead of copying the whole board
// again at every recursive step.
void solve(vector<vector<int>> grid,int steps,int x,int y,vector<int> &xMove,vector<int> &yMove)
{
    enumerateTours(grid, steps, x, y, xMove, yMove);
}


int main() {
    const int N = 5;

    // Knight moves as (row, column) offsets; entry m of xMove pairs with entry m of yMove.
    vector<int> xMove{ 2, 1, -1, -2, -2, -1, 1, 2 };
    vector<int> yMove{ 1, 2, 2, 1, -1, -2, -2, -1 };

    // Empty board (-1 everywhere) with the knight on the top-left square at step 0.
    vector<vector<int>> grid(N, vector<int>(N, -1));
    grid[0][0] = 0;

    // Continue from step 1 at (0, 0); every complete tour is printed.
    solve(grid, 1, 0, 0, xMove, yMove);
}
Enter fullscreen mode Exit fullscreen mode

Find all Paths in a Maze/Rat in a Maze

Problem Statement:

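Given an N×N binary matrix (1 = open cell, 0 = wall), find and print every path from the top-left corner to the bottom-right corner that moves up, down, left or right through open cells and never visits a cell twice.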

Input:

1 0 0 0
1 1 1 1
0 1 0 1
1 1 1 1
Enter fullscreen mode Exit fullscreen mode

Output:

1 0 0 0
1 1 0 0
0 1 0 0
0 1 1 1

1 0 0 0
1 1 1 1
0 0 0 1
0 0 0 1
Enter fullscreen mode Exit fullscreen mode

Cpp code for all Possible Paths

#include <iostream>
#include <vector>

using namespace std;

// Prints the matrix one row per line (each value followed by a space), then
// an empty line. Each row is printed in full using its own length, so
// non-square input no longer reads past the end of a row.
void print(const vector<vector<int>>& grid)
{
    for (const auto& row : grid)
    {
        for (int cell : row)
        {
            cout << cell << " ";
        }
        cout << '\n';
    }
    cout << '\n';
}


// True when (x, y) is inside the maze, is an open cell (1), and is not already
// on the current path (`ans` marks path cells with 1). The column is checked
// against the length of row x, so rectangular mazes no longer index out of
// bounds. For square mazes the result is unchanged.
bool isValid(vector<vector<int>> &maze,vector<vector<int>> &ans,int x,int y)
{
    if (x < 0 or x >= static_cast<int>(maze.size()))
        return false;
    if (y < 0 or y >= static_cast<int>(maze[x].size()))
        return false;
    return maze[x][y] == 1 and ans[x][y] == 0;
}

// Prints every path from (x, y) to the bottom-right cell of `maze` that moves
// through open cells (1) without revisiting a cell.
// ans:  cells on the current path are 1. It is passed by value, so each call
//       owns its copy.
// x, y: the current cell, already marked in `ans` by the caller.
void findPath(vector<vector<int>> &maze,vector<vector<int>> ans,int x,int y)
{
    // An empty maze has no target cell; indexing ans[size() - 1] would be UB.
    if (maze.empty() or maze.back().empty())
        return;

    // Target (bottom-right cell) reached: the current path is complete.
    if (ans[maze.size() - 1][maze.back().size() - 1] == 1)
    {
        print(ans);
        return;
    }

    // Neighbour offsets, in the original exploration order: down, up, right, left.
    const int dx[4] = { 1, -1, 0, 0 };
    const int dy[4] = { 0, 0, 1, -1 };

    for (int d = 0; d < 4; d++)
    {
        const int nx = x + dx[d];
        const int ny = y + dy[d];
        if (isValid(maze, ans, nx, ny))
        {
            ans[nx][ny] = 1;              // extend the path
            findPath(maze, ans, nx, ny);
            ans[nx][ny] = 0;              // backtrack
        }
    }
}

int main() {
    vector<vector<int>> maze = {
                    { 1, 0, 0, 0 },
                    { 1, 1, 1, 1 },
                    { 0, 1, 0, 1 },
                    { 1, 1, 1, 1 } };

    // Path markers: a zero matrix with exactly the maze's shape.
    vector<vector<int>> ans;
    for (const auto& row : maze)
        ans.emplace_back(row.size(), 0);

    // A missing or blocked start cell means there is no path at all.
    // Marking it visited unconditionally would print "paths" that start in a wall.
    if (maze.empty() or maze[0].empty() or maze[0][0] != 1)
        return 0;

    ans[0][0] = 1;   // the start cell is always part of the path
    findPath(maze, ans, 0, 0);
}
Enter fullscreen mode Exit fullscreen mode

Find all possible combinations satisfying given constraints

Problem Statement

Given a number N, find all arrangements of 2N numbers in which every number from 1 to N appears exactly twice and the two copies of each number k have exactly k numbers between them.

Input:

N = 3
Enter fullscreen mode Exit fullscreen mode

Output:

 3 1 2 1 3 2
 2 3 1 2 1 3
Enter fullscreen mode Exit fullscreen mode

In this problem the recursion tree can be built in two ways. In the first, the choices are the N numbers: you pick one of them to place at position 0 of the array and then recurse on the remaining positions (this needs some way to avoid placing a number more than once). In the second, the choices are the 2N positions: you take one number, choose where its first copy goes (its second copy then goes n+1 positions later), and recurse on the remaining N-1 numbers. The second way is simpler, so that is the one we implement.

Pseudo Code

void solve(array,n)
{
    if(n becomes 0)
    {
        print array
        return;
    }

    for(all indices i in array)
    {
        if(inserting n in array at index i is safe)  // safe means array[i] and array[i+n+1] are both free
        {
            array[i] = n;       // put n at i and i+n+1
            array[i+n+1] = n;

            solve(array,n-1);   //recur for all n-1 numbers

            array[i] = -1;      //unmake choice
            array[i+n+1] = -1;
        }
    }
    return; //no soln found
}
Enter fullscreen mode Exit fullscreen mode

Cpp code

#include <iostream>
#include <vector>

using namespace std;

// Prints the values on one line, each followed by a space.
// Taken by const reference: printing never needs its own copy.
void print(const vector<int>& ar)
{
    for (int e : ar)
    {
        cout << e << " ";
    }
    cout << '\n';
}

// True when n can be placed at index i and at index i + n + 1, which puts
// exactly n numbers between the two copies. Both positions must be inside the
// array and still free (-1).
// Taken by const reference: the old by-value parameter copied the whole array
// on every probe.
bool is_safe(const vector<int>& array,int n,int i)
{
    return (i<array.size() and i+n+1<array.size() and array[i]==-1 and array[i+n+1]==-1);
}


// Places the numbers n, n-1, ..., 1 into the free (-1) slots of `array` so that
// the two copies of each number k have exactly k numbers between them,
// printing every completed arrangement. `array` is a per-call copy.
void solve(vector<int> array,int n)
{
    if (n == 0)
    {
        print(array);   // every number has been placed
        return;
    }

    int first = 0;
    while (first < array.size())
    {
        const int second = first + n + 1;
        if (is_safe(array, n, first))
        {
            array[first] = array[second] = n;     // place both copies of n
            solve(array, n - 1);
            array[first] = array[second] = -1;    // and take them back out
        }
        ++first;
    }
}

int main() {
    // N = 3 reproduces the example above (3 1 2 1 3 2 / 2 3 1 2 1 3).
    int N = 3;
    vector<int> v(2*N,-1);   // 2N slots, -1 marks a free slot
    solve(v,N);
}
Enter fullscreen mode Exit fullscreen mode

K - Partitions

Problem Statement

Given an array of integers Partition the array into K subsets having equal sum.

Input:

k = 3
array = 7 3 5 12 2 1 5 3 8 4 6 4
Enter fullscreen mode Exit fullscreen mode

Output:

7 3 5 2 3 
12 8 
1 5 4 6 4 
Enter fullscreen mode Exit fullscreen mode

In this problem, each element must go into one of the k partitions. So the choices are the k partitions, and after each choice we recurse on the remaining elements.
Pseudo code

bool solve(partitions,sums,array,index,k)
{
    if(index is past the end of array and all the sums are equal)
    {
        print partitions;
        return true;
    }
    else if(index is past the end of array and the sums are not all equal)
        return false;

    //choices for array[index] element
    for(all the part in partitions)
    {
        part.add(array[index])
        corresponding sum in sums += array[index]
        //recurse for next element
        ok = solve(partitions,sums,array,index+1,k);    
        if(ok)
            return true;
        else
        {
            part.pop()
            corresponding sum in sums -= array[index];
        }
    }
    return false;   //all choices tried no solutions.
}
Enter fullscreen mode Exit fullscreen mode

Cpp code

#include <iostream>
#include <algorithm>
#include <vector>

using namespace std;

// Prints each partition on its own line (each value followed by a space),
// then an empty line. The parameter and each row are taken by reference,
// so nothing is copied.
void print(const vector<vector<int>>& grid)
{
    for (const auto& row : grid)
    {
        for (int ele : row)
        {
            cout << ele << " ";
        }
        cout << '\n';
    }
    cout << '\n';
}

// Prints the values on one line, each followed by a space.
void print(const vector<int>& v)
{
    for (int ele : v)
        cout << ele << " ";
    cout << '\n';   // runs once, after the loop (the old indentation suggested otherwise)
}

// Assigns vec[index], vec[index+1], ... each to one of the k partitions and
// prints the first assignment in which all k partition sums are equal.
// parts/sums are taken by value, so each call works on its own copy.
// vec is only read, so it is taken by const reference (no copy per call).
// Returns true once a solution has been printed, false if none exists.
bool solve(vector<vector<int>> parts,vector<int> sums,const vector<int>& vec,int index,int k)
{
    // if reached the end and all partition sums are equal
    if(index == vec.size() and all_of(begin(sums),end(sums),[&sums](int val){return val==sums[0];}))
    {
        print(parts);
        return true;
    }
    // if reached the end but the partition sums are not all equal
    else if(index>=vec.size())
    {
        return false;
    }

    for(int i=0;i<k;i++)
    {
        // Two partitions that are both still empty with the same sum are
        // interchangeable. An earlier one (j < i) was already tried from this
        // exact state and failed (success would have returned), so i would
        // fail the same way. Skipping it leaves the printed result unchanged.
        bool same_as_earlier = false;
        for (int j = 0; j < i and !same_as_earlier; j++)
            same_as_earlier = parts[i].empty() and parts[j].empty() and sums[i] == sums[j];
        if (same_as_earlier)
            continue;

        parts[i].push_back(vec[index]);
        sums[i] += vec[index];
        if (solve(parts, sums, vec, index + 1, k))
            return true;
        parts[i].pop_back();          // unmake the choice
        sums[i] -= vec[index];
    }

    return false;
}

int main() {
    const int k = 3;                     // number of equal-sum subsets
    vector<int> vec{7, 3, 5, 12, 2, 1, 5, 3, 8, 4, 6, 4};

    vector<vector<int>> parts(k);        // k empty partitions
    vector<int> sums(k, 0);              // running sum of each partition

    solve(parts, sums, vec, 0, k);       // start with the first element
}
Enter fullscreen mode Exit fullscreen mode

Print All Combinations with repetitions allowed

Problem Statement

Given an array and an integer k, print all of its k-element combinations with repetitions allowed.
update: I have since written a separate post with a more general take on this problem; check it out here.

Input:

k = 2
array = { 1, 2, 3 }
Enter fullscreen mode Exit fullscreen mode

Output

1 1  1 2  1 3  2 2  2 3  3 3  
Enter fullscreen mode Exit fullscreen mode

Input:

k = 3
array = { 1, 2, 3, 4 }
Enter fullscreen mode Exit fullscreen mode

Output

1 1 1  1 1 2  1 1 3  1 1 4  1 2 2  1 2 3  1 2 4  1 3 3   1 3 4  1 4 4  2 2 2  2 2 3  2 2 4  2 3 3  2 3 4  2 4 4   3 3 3  3 3 4  3 4 4  4 4 4
Enter fullscreen mode Exit fullscreen mode

First of all, notice that this is the standard permutations-and-combinations setting. See the figure below.
Permutation and combination choosing diagram
Our problem ignores order and allows repetitions, so the third formula applies (if you want to know why, check out this page). That formula gives the number of answers to our problem, but it does not tell us how to generate them.
Notice the pattern in the first column of the third table. To solve this problem with recursion, we first ask what choices we are making. We are building a smaller array of k elements from the larger input array, and for each index of that smaller array we choose which element to put there. At index 0 of the k-element array we can put any element of the input array. At the next index (that is, in the child calls) we can only use the element chosen by the parent node or the elements after it. This is a little confusing, so look at the recursion tree for the first input above.
Recursion tree
Here, if you look at the node 2_, the next level can only contain 2 or elements after it. This avoids reversed duplicates: 12 and 21 are the same combination.

Cpp Code

#include <iostream>
#include <vector>

using namespace std;

// Prints the values on one line, each followed by a space.
// Taken by const reference: printing never needs its own copy.
void print(const vector<int>& data)
{
    for (int e : data)
    {
        cout << e << " ";
    }
    cout << '\n';
}

//data index is the index for the actual data and out_index is for output k-element array
// Prints every out.size()-element combination (with repetition, order
// ignored) of the values in `data`, in lexicographic order of their indices.
// data_index: smallest index of `data` allowed in slot out_index. Passing the
//             chosen index (not index + 1) down the recursion allows repeats.
// out:        output buffer, passed by value so each call owns its copy.
//             Slots [0, out_index) are already filled.
// Each slot is now always overwritten. The old `== -1` guard was always true
// for the intended input, but kept a stale value if `out` arrived prefilled.
void solve(const vector<int>& data, int data_index, vector<int> out, int out_index)
{
    if (out_index == out.size())
    {
        print(out);
        return;
    }

    for (int i = data_index; i < data.size(); i++)
    {
        out[out_index] = data[i];               // choose data[i] for this slot
        solve(data, i, out, out_index + 1);     // data[i] may be chosen again
    }
}

int main() {
    const vector<int> data{ 1, 2, 3, 4 };
    const int r = 3;                          // length of each combination
    solve(data, 0, vector<int>(r, -1), 0);    // -1 marks a not-yet-filled slot
}
Enter fullscreen mode Exit fullscreen mode

Print all combinations of numbers from 1 to n having sum n

Problem statement

Print all combinations of numbers from 1 to n having sum n

Input:

n = 3
Enter fullscreen mode Exit fullscreen mode

Output:

1 1 1
1 2
3
Enter fullscreen mode Exit fullscreen mode

Input:

n = 4
Enter fullscreen mode Exit fullscreen mode

Output:

1 1 1 1
1 1 2
1 3
2 2
4
Enter fullscreen mode Exit fullscreen mode

If you look at this problem carefully and have read my permutations and combinations post, you may already recognise the solution.
Here repetitions are allowed and order does not matter, and the elements are chosen from the set 1...n, where n is also the required sum. What we do not know is r, the length of the combination. (Remember that the recursion tree in the PNC (permutations and combinations) post has intermediate nodes, and these are partial combinations. For example, if you take all the nodes of the tree for the fourth type of combination, you get the power set, the set of all subsets.) We decrease the remaining sum as we go deeper, and at each node we check whether it has reached 0. If it has, that partial combination is one of our solutions. To ignore order, the for loop starts from the parent node's value. (Again, if this is unclear, please check out my PNC post.)
Cpp code

#include <iostream>
#include <vector>


using namespace std;

// Prints the values on one line, each followed by a space.
// Taken by const reference: printing never needs its own copy.
void print(const vector<int>& data)
{
    for (int e : data)
    {
        cout << e << " ";
    }
    cout << '\n';
}


// Prints every multiset of integers from [i, n], listed in non-decreasing
// order, whose elements add up to `sum`. `out` holds the numbers chosen so far.
void solve(int i,int n,int sum,vector<int> out)
{
    if (sum == 0)
    {
        print(out);
        return;
    }
    if (sum < 0)
    {
        return;
    }

    // Candidates larger than the remaining sum can only overshoot, so stop
    // there instead of recursing into calls that just return. Output is unchanged.
    for (int x = i; x <= n && x <= sum; x++)
    {
        out.push_back(x);
        solve(x, n, sum - x, out);   // start at x again: repetitions are allowed
        out.pop_back();
    }
}


int main() {
    const int target = 4;
    // Parts may range over 1..target and must add up to target.
    solve(1, target, target, {});
}
Enter fullscreen mode Exit fullscreen mode

Print all triplets in an array with a sum less than or equal to a given target

Problem Statement

Print all triplets in an array v with a sum less than or equal to a given target

Input:

target = 10
v = { 2, 7, 4, 9, 5, 1, 3 }
Enter fullscreen mode Exit fullscreen mode

Output:

2 7 1
2 4 1
2 4 3
2 5 1
2 5 3
2 1 3
4 5 1
4 1 3
5 1 3
Enter fullscreen mode Exit fullscreen mode

In this problem r is fixed at 3. And n is the size of the array. And the set to choose values from is array v. In this problem the order does not matter and repetitions are also not allowed. So this is the fourth type of combination.
Cpp code

#include <iostream>
#include <algorithm>
#include <vector>

using namespace std;

// Prints the values on one line, each followed by a space.
// Taken by const reference: printing never needs its own copy.
void print(const vector<int>& data)
{
    for (int e : data)
    {
        cout << e << " ";
    }
    cout << '\n';
}

// Prints every combination of r more elements of `v`, taken at indices > i in
// index order, whose total (including `sum`) is at most `target`.
// out: elements chosen so far.  i: index of the last chosen element (-1 = none).
// v is only read, so it is taken by const reference (no copy per call).
void solve(const vector<int>& v, int target, int r, int sum, vector<int> out = {}, int i = -1)
{
    if (r == 0)
    {
        // The sum is judged only once the triplet is complete. With negative
        // elements, a partial sum above target can still come back down, so
        // the old early "sum > target" cut dropped valid triplets.
        if (sum <= target)
            print(out);
        return;
    }

    for (int x = i + 1; x < v.size(); x++)
    {
        out.push_back(v[x]);
        solve(v, target, r - 1, sum + v[x], out, x);
        out.pop_back();
    }
}


int main()
{
    const vector<int> v{ 2, 7, 4, 9, 5, 1, 3 };
    const int target = 10;
    const int tripletSize = 3;
    solve(v, target, tripletSize, 0);   // start from an empty selection with sum 0
    return 0;
}
Enter fullscreen mode Exit fullscreen mode

Find all combinations of non-overlapping substrings of a string

Problem statement

Find all combinations of non-overlapping parenthesized substrings of a string

Input:

"ABC"
Enter fullscreen mode Exit fullscreen mode

Output:

(ABC)
(AB) (C)
(A) (BC)
(A) (B) (C)
Enter fullscreen mode Exit fullscreen mode

To solve this question, notice where you put the splits (the ") (" separators), design the recursion tree, and then implement it.

Cpp code

// Prints every way to split `data` into consecutive non-empty pieces, each
// wrapped in parentheses, e.g. "ABC" -> (ABC), (AB) (C), (A) (BC), (A) (B) (C).
// index: position in `data` of the next gap where a split may be inserted.
// n:     one more than the number of gaps still to decide. The caller passes
//        the string length; n == 1 means every gap has been decided.
void solve(string data,int index,int n)
{
    // An empty string has nothing to split. Without this guard n would go
    // 0, -1, -2, ... and the recursion would never terminate.
    if (n < 1)
        return;

    if (n == 1)
    {
        cout <<"(" << data <<")" << endl;
        return;
    }
    n--;
    solve(data, index + 1, n);      // choice 1: no split at this gap
    data.insert(index, ") (");      // choice 2: split at this gap
    solve(data, index + 4, n);      // +4: skip the 3 inserted chars plus one letter
    // no need to undo the insert: this is the last choice and data is a local copy
}

int main()
{
    const string text = "ABC";
    // Begin at gap 1 (between the first and second characters).
    solve(text, 1, static_cast<int>(text.length()));
    return 0;
}
Enter fullscreen mode Exit fullscreen mode

// to be continued

Top comments (0)