Using qsort in Cython to get a sorting index/permutation

Overview

There are a few questions similar to this one, but they are all slightly different. To be clear: if values is an array of integers, I want to find perm such that sorted_values (values sorted by some comparison operator) is given by

sorted_values[i] = values[perm[i]]
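To make the definition concrete, here is a minimal C sketch of a checker for that property. The function name is_sorting_permutation is hypothetical (it is not from the question), and it assumes ascending order on int values.

```c
#include <assert.h>
#include <stdbool.h>
#include <stdlib.h>

/* Hypothetical helper: returns true if perm is a permutation of 0..n-1
   and values[perm[0]], values[perm[1]], ... is non-decreasing. */
bool is_sorting_permutation(const int *values, const int *perm, size_t n) {
    assert(n == 0 || (values != NULL && perm != NULL));
    bool ok = true;
    /* seen[k] records whether index k has already appeared in perm */
    bool *seen = calloc(n ? n : 1, sizeof *seen);
    if (seen == NULL) return false;
    for (size_t i = 0; i < n && ok; i++) {
        int k = perm[i];
        if (k < 0 || (size_t)k >= n || seen[k]) {
            ok = false;                 /* out of range or repeated index */
        } else {
            seen[k] = true;
            /* perm[i - 1] was validated on the previous iteration */
            if (i > 0 && values[perm[i - 1]] > values[k]) ok = false;
        }
    }
    free(seen);
    return ok;
}
```

For values = {40, 10, 30, 20}, the permutation {1, 3, 2, 0} passes this check, while the identity {0, 1, 2, 3} does not.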

Step 1

How to do it in C? qsort requires a comparison function that tells it whether one element is greater than another. If we make values a global variable, we can exploit this comparison function to sort an array perm, initially set to 0:N-1 (where N is the length of values), by comparing values[perm[i]] against values[perm[j]] rather than perm[i] against perm[j]. See this link. Some example C code:

// sort_test.c
#include <stdio.h>
#include <stdlib.h>

int *values;

int cmpfunc (const void * a, const void * b) {
   return ( values[*(int*)a] - values[*(int*)b] );
}

int main () {
   int i;
   int n = 5;
   int *perm;

   // Assign memory
   values = (int *) malloc(n*sizeof(int));
   perm = (int *) malloc(n*sizeof(int));

   // Set values to random values between 0 and 99
   for (i=0; i<n; i++) values[i] = rand() % 100;

   // Set perm initially to 0:n-1
   for (i=0; i<n; i++) perm[i] = i;

   printf("Before sorting the list is: \n");
   for (i=0; i<n; i++) printf("%d ", values[i]);

   qsort(perm, n, sizeof(int), cmpfunc);

   printf("\nThe sorting permutation is: \n");
   for (i=0; i<n; i++) printf("%d ", perm[i]);

   free(values);
   free(perm);

   printf("\n");

   return(0);
}

Of course, the trick is defining values globally so that cmpfunc can see it.
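One caveat with the comparison function: returning a difference of two ints can overflow when the values are far apart (it is fine for the 0 to 99 range used here). A hedged sketch of the same global-variable trick with an overflow-safe comparison follows; the names sort_keys, cmp_by_key and argsort_int are made up for illustration.

```c
#include <assert.h>
#include <stdlib.h>

/* Hypothetical file-scope pointer that the comparator reads,
   playing the role of the global "values" in the question. */
static const int *sort_keys;

static int cmp_by_key(const void *a, const void *b) {
    int x = sort_keys[*(const int *)a];
    int y = sort_keys[*(const int *)b];
    /* yields -1, 0 or 1 without the overflow risk of x - y */
    return (x > y) - (x < y);
}

/* Fills perm with 0..n-1 and sorts it so that keys[perm[i]] is ascending. */
void argsort_int(const int *keys, int *perm, size_t n) {
    assert(n == 0 || (keys != NULL && perm != NULL));
    for (size_t i = 0; i < n; i++) perm[i] = (int)i;
    sort_keys = keys;
    qsort(perm, n, sizeof perm[0], cmp_by_key);
    sort_keys = NULL;   /* do not leave a dangling pointer behind */
}
```

Because the comparator depends on file-scope state, this sketch is not reentrant or thread-safe, and the relative order of equal keys is unspecified since qsort is not guaranteed to be stable.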

Step 2

How to do it in Cython? Unfortunately, I cannot get Cython to use the same trick with values declared globally. My best attempt is the following, based on the answer here. The difference is that that answer just sorts an array; it does not need the indexing/permutation.

# sort_test_c.pyx
cimport cython
from libc.stdlib cimport qsort

# Try declaring global variable for the sort function
cpdef long[:] values

cdef int cmpfunc (const void *a , const void *b) nogil:
    cdef long a_v = (<long *> a)[0] 
    cdef long b_v = (<long *> b)[0]
    return (values[a_v] - values[b_v]);

def sort(long[:] py_values, long[:] perm, int N):
    # Assign to global
    values = py_values

    # Make sure perm is 0:N-1
    for i in range(N):
        perm[i] = i

    # Perform sort
    qsort(&perm[0], N, perm.strides[0], &cmpfunc)

This can be compiled using

cythonize -i sort_test_c.pyx

and tested with the script

# sort_test.py
from sort_test_c import sort
import numpy as np

n = 5
values = np.random.randint(0, 100, n).astype(int)
perm = np.empty(n).astype(int)

sort(values, perm, n)

This, however, complains about our global variable values, i.e.

UnboundLocalError: local variable 'values' referenced before assignment
Exception ignored in: 'sort_test_c.cmpfunc

and the sorting permutation is not correct (unless the values are already ordered, in which case it is luck, since perm always comes back as the array 0:4). How can I fix this?

Antilles answered 15/2, 2018 at 9:42 Comment(4)
Global variables in Python and Cython are at module level rather than truly global. (I'm not convinced that this approach to the problem is the best, though.) - Pralltriller
I agree, I dislike the solution of having to use a global variable, but I can't see a way around it with the qsort API. - Antilles
A struct of values and indices looks to be the standard approach. Alternatively, you could sort an array of pointers instead and work out the indexes from that. - Pralltriller
The link is a much nicer way of doing it; for some reason I didn't think of using a structure. Now the compare function API makes more sense. - Antilles
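The struct-of-value-and-index idea from the comments can be sketched in C roughly as follows. This is only an illustration under assumptions: the type indexed_value and the functions cmp_indexed and argsort_struct are invented names, and the linked code may differ.

```c
#include <assert.h>
#include <stdlib.h>

/* Hypothetical pair type: each value travels with its original index. */
typedef struct {
    int value;
    int index;
} indexed_value;

static int cmp_indexed(const void *a, const void *b) {
    const indexed_value *x = a;
    const indexed_value *y = b;
    if (x->value != y->value)
        return (x->value > y->value) - (x->value < y->value);
    /* equal values: fall back to the original index for a deterministic order */
    return (x->index > y->index) - (x->index < y->index);
}

/* Writes the sorting permutation of values into perm.
   Returns 0 on success, -1 if the temporary buffer cannot be allocated. */
int argsort_struct(const int *values, int *perm, size_t n) {
    assert(n == 0 || (values != NULL && perm != NULL));
    if (n == 0) return 0;
    indexed_value *pairs = malloc(n * sizeof *pairs);
    if (pairs == NULL) return -1;
    for (size_t i = 0; i < n; i++) {
        pairs[i].value = values[i];
        pairs[i].index = (int)i;
    }
    qsort(pairs, n, sizeof *pairs, cmp_indexed);
    for (size_t i = 0; i < n; i++) perm[i] = pairs[i].index;
    free(pairs);
    return 0;
}
```

The comparator only looks at the two elements it is handed, so no global variable is needed. The cost is a temporary array of n pairs.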

Thanks for the code, it helped me a lot.
If you still need to make it work (after 6 years):

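cimport cython
cimport numpy as np  # for the np.ndarray argument type; needs the NumPy headers at build time
from libc.stdlib cimport qsort

# Module-level memoryview read by cmpfunc. C long is 32-bit on Windows and
# 64-bit on most 64-bit Linux/macOS builds, so the array dtype must match it.
cdef:
    cython.long[:] values

cdef int cmpfunc (const void *a , const void *b) noexcept nogil:
    cdef long a_v = (<long *> a)[0]
    cdef long b_v = (<long *> b)[0]
    return (values[a_v] - values[b_v])

# just pass a generic np.ndarray, sometimes cython is easier than expected
def sort(np.ndarray py_values, long[:] perm, int N):
    global values  # needed so the assignment targets the module-level values
    values = py_values # this creates the fast global memory view
    qsort(&perm[0], N, perm.strides[0], &cmpfunc)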


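# test script (assumes the module above was built as sort_test_c, as in the question)
import numpy as np
from sort_test_c import sort

n = 20
# the dtype must match C long: np.int32 on Windows, np.int64 where long is 64-bit
values = np.random.randint(0, 100, n).astype(np.int32)
perm = np.arange(len(values),dtype=np.int32)
sort(values, perm, n)
print(values)
print(values[perm])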
[18 24 57 38 24 14 17 52 24  8  3 67 69  7 51 98 66 66 82  8]
[ 3  7  8  8 14 17 18 24 24 24 38 51 52 57 66 66 67 69 82 98]
Hawkinson answered 19/7, 2024 at 2:31 Comment(0)
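For completeness, the other suggestion from the comments (sort an array of pointers and work out the indexes from them) might look like this in C. It is a sketch under assumptions; cmp_ptr and argsort_ptr are hypothetical names.

```c
#include <assert.h>
#include <stddef.h>
#include <stdlib.h>

/* Compares two pointers into the same int array by the values they point at. */
static int cmp_ptr(const void *a, const void *b) {
    const int *x = *(const int *const *)a;
    const int *y = *(const int *const *)b;
    if (*x != *y) return (*x > *y) - (*x < *y);
    /* equal values: both pointers are into one array, so comparing them is valid */
    return (x > y) - (x < y);
}

/* Writes the sorting permutation of values into perm.
   Returns 0 on success, -1 if the temporary buffer cannot be allocated. */
int argsort_ptr(const int *values, int *perm, size_t n) {
    assert(n == 0 || (values != NULL && perm != NULL));
    if (n == 0) return 0;
    const int **ptrs = malloc(n * sizeof *ptrs);
    if (ptrs == NULL) return -1;
    for (size_t i = 0; i < n; i++) ptrs[i] = values + i;
    qsort(ptrs, n, sizeof *ptrs, cmp_ptr);
    /* pointer difference from the array start recovers the original index */
    for (size_t i = 0; i < n; i++) perm[i] = (int)(ptrs[i] - values);
    free(ptrs);
    return 0;
}
```

Like the struct version, this avoids any global state. Each index is recovered as the offset of the sorted pointer from the start of values.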
