How to count number of inversions in an array using segment trees

I know that this problem can be solved with a modified merge sort, and I have already coded that. Now I want to solve it using a segment tree. Basically, if we traverse the array from right to left, then for each element we have to count how many values are greater than the current value. How can this be achieved with a segment tree?

What type of information do we have to store in a segment tree node?

Please provide code if possible.

Whittemore answered 12/9, 2013 at 7:51

Let me explain step by step with an example:

arr      :  4 3 7 1
position :  0 1 2 3

First, sort the array in descending order as {value, index} pairs.

arr      :  7 4 3 1
index    :  2 0 1 3
position :  0 1 2 3

Then iterate over the sorted array from left to right. For each element arr[i]:

Query the range [0, arr[i].index] to count the greater numbers on its left side, and store the result at the corresponding index of the output array.

After each query, insert that index: increment its leaf and update every segment tree node that covers it.

This way each query counts only greater values in positions 0 to index - 1, because only values greater than arr[i] have been inserted so far. (An equal value may already be present, but only one with a larger index, which lies outside the queried range.)

The C++ implementation below should make this clearer.

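#include <algorithm>
#include <functional>
#include <utility>
#include <vector>
using namespace std;

class SegmentTree {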

    vector<int> segmentNode;

public:
    void init(int n) {
        int N = 4 * n; // 4 * n is a safe upper bound on the node count (root at index 1)
        segmentNode.resize(N, 0);
    }

    // Increment the count at position indx and refresh the sums of the nodes covering it.
    void insert(int node, int left, int right, const int indx) {
        if(indx < left or indx > right) {
            return;
        }
        if(left == right and indx == left) {
            segmentNode[node]++;
            return;
        }
        int leftNode = node << 1;
        int rightNode = leftNode | 1;
        int mid = left + (right - left) / 2;

        insert(leftNode, left, mid, indx);
        insert(rightNode, mid + 1, right, indx);

        segmentNode[node] = segmentNode[leftNode] + segmentNode[rightNode];
    }

    // Return how many inserted positions fall inside [L, R].
    int query(int node, int left, int right, const int L, const int R) {
        if(left > R or right < L) {
            return 0;
        }
        if(left >= L and right <= R) {
            return segmentNode[node];
        }

        int leftNode = node << 1;
        int rightNode = leftNode | 1;
        int mid = left + (right - left) / 2;

        return query(leftNode, left, mid, L, R) + query(rightNode, mid + 1, right, L, R);
    }

};

vector<int> countGreater(vector<int>& nums) {
    vector<int> result;
    if(nums.empty()) {
        return result;
    }
    int n = (int)nums.size();
    vector<pair<int, int> > data(n);
    for(int i = 0; i < n; ++i) {
        data[i] = pair<int, int>(nums[i], i);
    }
    sort(data.begin(), data.end(), greater<pair<int, int> >());
    result.resize(n);
    SegmentTree segmentTree;
    segmentTree.init(n);
    for(int i = 0; i < n; ++i) {
        result[data[i].second] = segmentTree.query(1, 0, n - 1, 0, data[i].second);
        segmentTree.insert(1, 0, n - 1, data[i].second);
    }
    return result;
}

// Input : 4 3 7 1
// Output: 0 1 0 3

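This one is easy, but it is not as obvious as other typical segment tree problems. Simulating it with pen and paper on an arbitrary input will help. Summing the output array gives the total number of inversions (0 + 1 + 0 + 3 = 4 for the input above).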
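When simulating by hand, a quadratic brute-force reference can help confirm the expected per-index counts. The sketch below is only an illustration: the name countGreaterBruteForce is hypothetical and not part of the answer's code, and it is meant for checking small inputs rather than for performance.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical reference helper (not from the answer above): for each index j,
// counts the elements to its left that are strictly greater than nums[j].
// It checks every pair, so it runs in O(n^2) and is only useful for verification.
std::vector<int> countGreaterBruteForce(const std::vector<int>& nums) {
    std::vector<int> result(nums.size(), 0);
    for (std::size_t j = 0; j < nums.size(); ++j) {
        for (std::size_t i = 0; i < j; ++i) {
            if (nums[i] > nums[j]) {
                ++result[j];
            }
        }
    }
    return result;
}
```

For the input 4 3 7 1 this yields 0 1 0 3, the same per-index counts as the segment tree version.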

There are other O(n log n) approaches using a BST, a Fenwick tree, or merge sort.
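As a rough illustration of the Fenwick tree alternative, here is a minimal sketch. It is not the code from this answer: the function name countInversionsFenwick is made up for the example, and it assumes the values are first compressed to ranks so that the tree can be indexed by value.

```cpp
#include <algorithm>
#include <vector>

// Hypothetical sketch: total inversion count with a Fenwick tree over value ranks.
long long countInversionsFenwick(const std::vector<int>& nums) {
    // Coordinate compression: map each value to its 1-based rank among distinct values.
    std::vector<int> sorted(nums);
    std::sort(sorted.begin(), sorted.end());
    sorted.erase(std::unique(sorted.begin(), sorted.end()), sorted.end());
    const int m = static_cast<int>(sorted.size());

    std::vector<int> bit(m + 1, 0);  // Fenwick tree of counts, indexed by rank
    long long inversions = 0;
    int seen = 0;                    // number of elements processed so far
    for (int x : nums) {
        const int rank = static_cast<int>(
            std::lower_bound(sorted.begin(), sorted.end(), x) - sorted.begin()) + 1;
        // Prefix sum: earlier elements whose value is <= x.
        int notGreater = 0;
        for (int i = rank; i > 0; i -= i & -i) {
            notGreater += bit[i];
        }
        // The remaining earlier elements are strictly greater than x.
        inversions += seen - notGreater;
        // Record x in the tree.
        for (int i = rank; i <= m; i += i & -i) {
            bit[i] += 1;
        }
        ++seen;
    }
    return inversions;
}
```

Calling countInversionsFenwick({4, 3, 7, 1}) returns 4, which matches the sum of the per-index counts 0 1 0 3. Equal values are not counted as inversions because the prefix sum includes the rank of x itself.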

Tereus answered 22/4, 2016 at 7:40

This can be solved quite simply. Build an empty segment tree of size n with sum as the operation. Then go through the elements of the permutation p from left to right. A one in a leaf of the segment tree means that the corresponding value has already been visited. On reaching the i-th element p[i], query the sum over [p[i], n] in the segment tree: this counts the elements to the left that are greater than p[i]. Finally, put a one at position p[i]. The total time is O(n log n).
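The steps above could be sketched in C++ roughly as follows. This is one possible reading of the description, not the author's code: the function name countInversionsPermutation is hypothetical, the input is assumed to be a permutation of 1..n (checked with an assert), and a compact bottom-up array layout is used for the sum segment tree.

```cpp
#include <cassert>
#include <vector>

// Hypothetical sketch: counts inversions of a permutation p of 1..n with a
// bottom-up sum segment tree indexed by value.
long long countInversionsPermutation(const std::vector<int>& p) {
    const int n = static_cast<int>(p.size());
    // Leaves live at indices n..2n-1; the leaf for value v is tree[n + v - 1].
    // Every internal node i holds tree[2i] + tree[2i + 1].
    std::vector<int> tree(2 * n, 0);
    long long inversions = 0;
    for (int value : p) {
        assert(value >= 1 && value <= n);  // assumption: p is a permutation of 1..n
        // Sum over values [value, n], i.e. the half-open leaf range [lo, hi).
        int lo = n + value - 1;
        int hi = 2 * n;
        for (; lo < hi; lo >>= 1, hi >>= 1) {
            if (lo & 1) inversions += tree[lo++];
            if (hi & 1) inversions += tree[--hi];
        }
        // Put a one at position value: bump the leaf and all of its ancestors.
        for (int i = n + value - 1; i >= 1; i >>= 1) {
            tree[i] += 1;
        }
    }
    return inversions;
}
```

For the permutation 3 2 4 1 this returns 4. An array with arbitrary values would first have to be mapped to ranks 1..n, which this sketch does not do.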

Mcclintock answered 30/8, 2020 at 12:44