Algorithm to count permutations with specific fixed points and relative value constraints

I'm trying to optimize an algorithm that counts permutations with specific constraints. Given integers n, t, a, b where:

  • n is the length of the permutation (a permutation of the numbers 1 to n)
  • t is the required number of fixed points (numbers in their original position)
  • a is the required number of elements less than their position
  • b is the required number of elements greater than their position

For example, with n=3, t=1, a=1, b=1:

The permutation [1,3,2] is valid because:

  • 1 is in its original position (counts toward t=1)
  • 2 is less than position 3 (counts toward a=1)
  • 3 is greater than position 2 (counts toward b=1)
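For small n the definition above can be checked directly by brute force. This is only a minimal reference sketch (the helper names `classify` and `count_by_brute_force` are made up for illustration and are not part of the code in question); it enumerates all n! permutations, so it is only practical for roughly n ≤ 9, but it is handy for validating a faster solution.

```python
from itertools import permutations


def classify(perm):
    """Return (t, a, b) for a permutation given as a tuple of the values 1..n."""
    t = a = b = 0
    for position, value in enumerate(perm, start=1):
        if value == position:
            t += 1      # fixed point
        elif value < position:
            a += 1      # element less than its position
        else:
            b += 1      # element greater than its position
    return t, a, b


def count_by_brute_force(n, t, a, b):
    """Count the permutations of 1..n whose statistics are exactly (t, a, b)."""
    return sum(1 for p in permutations(range(1, n + 1))
               if classify(p) == (t, a, b))


print(classify((1, 3, 2)))               # (1, 1, 1)
print(count_by_brute_force(3, 1, 1, 1))  # 3
```

For n=3, t=1, a=1, b=1 the three matching permutations are [1,3,2], [3,2,1] and [2,1,3].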

Current Issues:

  1. Memory usage is O(2^n * n) which becomes problematic for n > 20
  2. Runtime is approximately O(n * 2^n)
  3. Many redundant calculations in the DP transitions

Questions:

  1. Are there any known optimizations for this type of combinatorial counting?
  2. Could this be solved without using bitmask DP?
  3. Are there mathematical properties I could leverage to reduce the complexity?

I've considered:

  • Using inclusion-exclusion principle
  • Trying to find a mathematical formula
  • Using a different DP state representation

The constraints are:

  • 1 ≤ n ≤ 100
  • 0 ≤ t, a, b ≤ n
  • t + a + b must equal n

I can't get the code to run fast enough. Here is my current implementation:

    #include <iostream>
    #include <vector>

    using ll = long long;

    // combination[][] and the bitmask DP below only support n < MAXN.
    const int MAXN = 21;

    ll combination[MAXN][MAXN];

    // Pascal's triangle: combination[n][k] = C(n, k).
    void init_combination() {
        for (int n = 0; n < MAXN; ++n) {
            combination[n][0] = combination[n][n] = 1;
            for (int k = 1; k < n; ++k)
                combination[n][k] = combination[n - 1][k - 1] + combination[n - 1][k];
        }
    }

    int main() {
        int n, t, a, b;
        std::cin >> n >> t >> a >> b;

        if (t + a + b != n || t < 0 || a < 0 || b < 0 || t > n || a > n || b > n) {
            std::cout << 0 << std::endl;
            return 0;
        }

        init_combination();

        // Choose the t fixed points, then count arrangements of the remaining
        // n_prime elements that have no fixed point and exactly b excedances.
        int n_prime = n - t;
        ll fixed_points_ways = combination[n][t];

        // dp[mask][k]: ways to fill positions 0..popcount(mask)-1 with the
        // numbers in mask so that exactly k of them exceed their position.
        int total_states = 1 << n_prime;
        std::vector<std::vector<ll>> dp(total_states, std::vector<ll>(n_prime + 1, 0));
        dp[0][0] = 1;

        for (int mask = 0; mask < total_states; ++mask) {
            int pos = __builtin_popcount(mask);  // next position to fill
            if (pos == n_prime) continue;

            for (int num = 0; num < n_prime; ++num) {
                if (mask & (1 << num)) continue;  // number already used
                if (num == pos) continue;         // would create a fixed point

                int next_excedances = (num > pos) ? 1 : 0;
                // Stop early so that k + next_excedances stays inside the row.
                for (int k = 0; k + next_excedances <= n_prime; ++k) {
                    dp[mask | (1 << num)][k + next_excedances] += dp[mask][k];
                }
            }
        }

        ll total_permutations = dp[total_states - 1][b];
        ll result = fixed_points_ways * total_permutations;

        std::cout << result << std::endl;
        return 0;
    }

Erhard answered 31/10 at 12:15 Comment(4)
Are you asking for help in optimizing code that no one here has seen? - Bacciferous
@ScottHunter, I've added the code now. Sorry that it is a little messy; I was having trouble with the indentation. - Erhard
Can you provide any explanation of how this code works? - Bacciferous
Can you share some sample data? - Portable

I will explain my solution step by step so that it is easy to follow.

Step1: Cycle Permutation

We treat every permutation as a mapping f from the list 1 to N to the list P, where P is the permutation itself.

For example, the permutation 4, 1, 6, 2, 5, 3 is the mapping:

[1, 2, 3, 4, 5, 6] => [4, 1, 6, 2, 5, 3]

If we draw an edge from i to f(i) for every i, we get a graph.

For our example, the edges are 1 => 4, 2 => 1, 3 => 6, 4 => 2, 5 => 5, 6 => 3, which gives this graph: [image: graph of the permutation 4, 1, 6, 2, 5, 3]

So the graph splits into 3 separate components: {1, 4, 2}, {3, 6} and {5}.

As another example, take the permutation 4, 1, 2, 5, 6, 3:

[image: graph of the permutation 4, 1, 2, 5, 6, 3]

Here there is only 1 component: all the nodes of that permutation are connected.

We call a permutation a Cycle Permutation when all the nodes of its graph are connected.

We introduce Cycle Notation, which is:

(a b c .. x y z): a=>b, b=>c, .. x=>y, y=>z, z=>a

For example, the cycle notation of permutation 4, 1, 2, 5, 6, 3 is (1 4 5 6 3 2).

Rotating the entries of a cycle does not change the cycle:

(a b c .. x y z) = (b c .. x y z a) = (c .. x y z a b) = (z a b c .. x y)

Also, the cycle notation of permutation 4, 1, 6, 2, 5, 3 is (1 4 2)(3 6)(5).

So there is only 1 pair of parentheses in the cycle notation of a cycle permutation.
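The cycle notation can be computed mechanically by following the edges i => f(i). Here is a small sketch of that (the function name `cycles` is just for illustration; the permutation is passed as a list where entry i-1 holds f(i)):

```python
def cycles(perm):
    """Decompose a permutation of 1..n into cycles.

    perm[i - 1] is f(i). Each cycle is listed starting at its smallest
    element, and the cycles are ordered by that smallest element.
    """
    n = len(perm)
    seen = [False] * (n + 1)
    result = []
    for start in range(1, n + 1):
        if seen[start]:
            continue
        cycle = []
        current = start
        while not seen[current]:
            seen[current] = True
            cycle.append(current)
            current = perm[current - 1]  # follow the edge current => f(current)
        result.append(cycle)
    return result


print(cycles([4, 1, 6, 2, 5, 3]))  # [[1, 4, 2], [3, 6], [5]]
print(cycles([4, 1, 2, 5, 6, 3]))  # [[1, 4, 5, 6, 3, 2]]
```

A permutation is a cycle permutation exactly when this function returns a single list.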

Step2: Eulerian Number

We handle a simpler task first:

Among the cycle permutations of length N, how many have exactly A numbers that are less than their position?

Back to our cycle notation: to avoid counting the same cycle more than once, we fix node 1 at the beginning of the notation. The other N-1 numbers then fill the remaining slots:

(1 ? ? ... ?) # N-1 '?' in total

For any two adjacent entries, if the first is less than the second, then the first number is less than its position after mapping.

Since 1 is smaller than every other number, 1 is always less than its position after mapping. So the question becomes:

Among all arrangements of the N-1 remaining numbers, how many have exactly A-1 numbers that are greater than the number immediately before them?

The answer to that question is given by the Eulerian numbers:

In combinatorics, the Eulerian number E(n, k) is the number of permutations of the numbers 1 to n in which exactly k elements are greater than the previous element (permutations with k "ascents").

The first few rows of the Eulerian numbers:

[image: table of the first few rows of the Eulerian numbers]

Take E(3, 1) as an example; there are 4 matching arrangements:

1 3 2, 2 1 3, 2 3 1, 3 1 2

We leave more details to Wikipedia and continue.
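To make the numbers concrete, here is a short sketch that builds rows of Eulerian numbers with the standard recurrence E(n, k) = (n - k) * E(n-1, k-1) + (k + 1) * E(n-1, k) and cross-checks one row by counting ascents directly. The helper names `eulerian_row` and `ascents` are illustrative only.

```python
from itertools import permutations


def eulerian_row(n):
    """Return [E(n, 0), ..., E(n, n-1)] for n >= 1 via the standard recurrence."""
    row = [1]                      # E(1, 0) = 1
    for m in range(2, n + 1):
        prev = row + [0]           # pad so that prev[k] exists for k = m - 1
        row = [(m - k) * (prev[k - 1] if k > 0 else 0) + (k + 1) * prev[k]
               for k in range(m)]
    return row


def ascents(seq):
    """Number of elements that are greater than the element just before them."""
    return sum(1 for x, y in zip(seq, seq[1:]) if x < y)


for n in range(1, 6):
    print(n, eulerian_row(n))
# 1 [1]
# 2 [1, 1]
# 3 [1, 4, 1]
# 4 [1, 11, 11, 1]
# 5 [1, 26, 66, 26, 1]

# Brute-force cross-check for n = 4: tally permutations by number of ascents.
counts = [0] * 4
for p in permutations(range(1, 5)):
    counts[ascents(p)] += 1
print(counts == eulerian_row(4))   # True
```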

Step3: The original problem and the smaller problem

Now we come to the original problem. Consider the greatest number in the permutation; there are only 2 cases for it:

Case1: it is in its original position.
Case2: it is part of a cycle of length k (2 <= k <= n).

Take a permutation of length 6 as an example:

Case1: ? ? ? ? ? 6
Case2: ? ? ? 6 ? 4 (a cycle of length 2), ? ? 6 3 ? 4 (a cycle of length 3), ...

Notice that in Case2 we are left with a smaller problem, since only n-k '?' remain.

We define:

sol(n, t, a) is the answer of length n permutation with t fixed point and "a" numbers less than their position

In our cases, there are:

sol(n, t, a) is the sum of 
 sol(n-1, t-1, a) # case 1
 for 2<=i<=n and 0<=j<=i-1, C(n-1, i-1)* eulerian_number(i-1, j)*sol(n-i, t, a-j-1) # case 2

Explanation of C(n-1, i-1) * eulerian_number(i-1, j) * sol(n-i, t, a-j-1):

For a cycle of length i, the number n is already in the cycle, so we choose its other i-1 members from the remaining n-1 numbers. That gives C(n-1, i-1), where C denotes the binomial coefficient.

We then enumerate j, the number of less-than-position numbers contributed by the other i-1 members. That gives eulerian_number(i-1, j): out of the (i-1)! ways to arrange a cycle of length i, it counts the ones with j such numbers.

Finally, some number is always mapped into the position of the greatest number, and that number is necessarily less than its position. Together with j, the cycle uses up j+1 of the a, which leaves sol(n-i, t, a-j-1) for the remaining n-i numbers.
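The claim that a cycle of length i contributes j+1 less-than-position numbers in exactly eulerian_number(i-1, j) ways can be sanity-checked by brute force for small i. The sketch below is not part of the solution itself; it counts by the question's definition (a value smaller than its 1-based position), and the names `eulerian`, `is_single_cycle` and `check` are made up for this illustration.

```python
from itertools import permutations


def eulerian(n, k):
    """E(n, k): permutations of n elements with exactly k ascents, E(0, 0) = 1."""
    if n == 0:
        return 1 if k == 0 else 0
    if k < 0 or k >= n:
        return 0
    return (n - k) * eulerian(n - 1, k - 1) + (k + 1) * eulerian(n - 1, k)


def is_single_cycle(perm):
    """True if perm (values 1..n, perm[x - 1] = f(x)) consists of one n-cycle."""
    steps, current = 0, 1
    while True:
        current = perm[current - 1]
        steps += 1
        if current == 1:           # back at the start of the cycle through 1
            return steps == len(perm)


def check(i):
    """Tally i-cycles by their number of values less than their position and
    compare with the prediction {j + 1: E(i - 1, j)}."""
    tally = {}
    for perm in permutations(range(1, i + 1)):
        if is_single_cycle(perm):
            less = sum(1 for pos, val in enumerate(perm, start=1) if val < pos)
            tally[less] = tally.get(less, 0) + 1
    expected = {j + 1: eulerian(i - 1, j) for j in range(i - 1)}
    return tally == expected


print(all(check(i) for i in range(2, 8)))
```

For example, with i = 4 there are 6 cycles, split as 1, 4 and 1 cycles with 1, 2 and 3 less-than-position numbers, matching E(3, 0), E(3, 1) and E(3, 2).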

Putting all of the above together and adding memoization, we arrive at a solution with O(N^3) space and O(N^5) time.

You can check my code for details.

Appendix: Code

import itertools
import math

ed = {}
dp = {}


def eulerian_number(n, k):
    if ed.get((n, k)) is not None:
        return ed[(n, k)]
    if n == 0:
        if k == 0:
            ans = 1
        else:
            ans = 0
    else:
        ans = (n - k) * eulerian_number(n - 1, k - 1) + (k + 1) * eulerian_number(n - 1, k)
    ed[(n, k)] = ans
    return ans


def sol(n, t, a):
    p = (n, t, a)
    if dp.get(p) is not None:
        return dp[p]
    if n == 0:
        if a == 0 and t == 0:
            return 1
        else:
            return 0
    ans = sol(n - 1, t - 1, a)
    for i in range(2, n + 1):
        for j in range(0, i - 1):
            ans += sol(n - i, t, a - j - 1) * eulerian_number(i - 1, j) * math.comb(n - 1, i - 1)
    dp[p] = ans
    return ans

I wrote some tests:

def test():
    def _test(permutation_length):
        d = {}
        for numbers in itertools.permutations(list(range(permutation_length))):
            t = 0
            a = 0
            b = 0
            for index, n in enumerate(numbers):
                if n == index:
                    t += 1
                if n < index:
                    a += 1
                if n > index:
                    b += 1
            if d.get((t, a, b)) is None:
                d[(t, a, b)] = 0
            d[(t, a, b)] += 1
        test_pass = True
        for k in d:
            t, a, b = k
            ans = sol(permutation_length, t, a)
            if ans != d[k]:
                test_pass = False
                print('Error on {}: {}, answer: {}, expected: {}'.format(permutation_length, k, ans, d[k]))
        if test_pass:
            print('Passed on Permutation Length: {}'.format(permutation_length))
    for i in range(1, 8):
        _test(i)


test()

Output:

Passed on Permutation Length: 1
Passed on Permutation Length: 2
Passed on Permutation Length: 3
Passed on Permutation Length: 4
Passed on Permutation Length: 5
Passed on Permutation Length: 6
Passed on Permutation Length: 7

A Better Solution

We can choose the t fixed points first; then Case1 disappears from the formula.

This gives a solution with O(N^4) time (possibly reducible to O(N^3)) and O(N^2) space.

dp2 = {}  # memo table for _sol2, keyed by (n, a)


def sol2(n, t, a):

    def _sol2(n, a):
        p = (n, a)
        if dp2.get(p) is not None:
            return dp2[p]
        if n == 0:
            if a == 0:
                return 1
            else:
                return 0
        ans = 0
        for i in range(2, n + 1):
            for j in range(0, i - 1):
                ans += _sol2(n - i, a - j - 1) * eulerian_number(i - 1, j) * math.comb(n - 1, i - 1)
        dp2[p] = ans
        return ans

    return math.comb(n, t) * _sol2(n - t, a)
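
The same recurrence can also be written bottom-up, which avoids recursion and global memo dictionaries. The following is a self-contained sketch under that assumption, not a tuned implementation; the function name `count_permutations` and the table names `eul` and `f` are illustrative. It keeps the O(N^4) time and O(N^2) space of the version above.

```python
from math import comb


def count_permutations(n, t, a):
    """Choose the t fixed points first, then count fixed-point-free
    permutations of the remaining m = n - t elements with exactly a
    elements less than their position."""
    m = n - t
    if t < 0 or m < 0 or a < 0 or a > m:
        return 0

    # eul[i][k] = Eulerian number E(i, k), with E(0, 0) = 1.
    eul = [[0] * (m + 1) for _ in range(m + 1)]
    eul[0][0] = 1
    for i in range(1, m + 1):
        for k in range(i):
            left = eul[i - 1][k - 1] if k > 0 else 0
            eul[i][k] = (i - k) * left + (k + 1) * eul[i - 1][k]

    # f[s][c] = fixed-point-free permutations of s elements with exactly
    # c elements less than their position.
    f = [[0] * (a + 1) for _ in range(m + 1)]
    f[0][0] = 1
    for s in range(2, m + 1):
        for c in range(1, a + 1):
            total = 0
            for i in range(2, s + 1):           # length of the cycle holding the largest element
                ways = comb(s - 1, i - 1)       # choose the other i - 1 members
                for j in range(min(i - 1, c)):  # this cycle contributes j + 1 to c
                    total += ways * eul[i - 1][j] * f[s - i][c - j - 1]
            f[s][c] = total
    return comb(n, t) * f[m][a]


print(count_permutations(3, 1, 1))  # 3
```

Because Python integers are arbitrary precision, the result does not overflow for large n, unlike a 64-bit `long long`.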

Portable answered 3/11 at 15:55
