Andrew Dalke: Faster parity calculation

Datetime: 2016-08-23 03:12:14    Topic: Python

In the previous essay I needed to determine the parity of a permutation. I used a Shell sort and counted the number of swaps needed to order the list. The parity is even (or "0") if the number of swaps is even, otherwise it's odd (or "1"). The final code was:

def parity_shell(values):
    # Simple Shell sort; while O(N^2), we only deal with at most 4 values 
    values = list(values)
    N = len(values)
    num_swaps = 0
    for i in range(N-1):
        for j in range(i+1, N):
            if values[i] > values[j]:
                values[i], values[j] = values[j], values[i]
                num_swaps += 1
    return num_swaps % 2

I chose this implementation because it's easy to understand, and any failure case is easily found. However, it's not fast.

It's tempting to use a better sorting method. The Shell sort takes quadratic time in the number of elements, while others take O(N log N) time in the asymptotic case.

However, asymptotic analysis is pointless here. The code will only ever be called for tetrahedral chirality, so it will only ever receive 3 terms (if there is a chiral hydrogen) or 4 terms.

Sorting networks

The first time I worked on this problem, I used a sorting network. A sorting network works on a fixed number of elements. It uses a predetermined sequence of pairwise comparisons, each followed by a swap if needed. Sorting networks are often used where code branches are expensive, like in hardware or on a GPU. A sorting network always performs the same comparisons regardless of the input, so it can help minimize timing side-channel attacks, where the time to sort may give some insight into what is being sorted.

No general algorithm is known for finding an optimal sorting network for an arbitrary number of elements N. There are non-optimal algorithms like Bose-Nelson and Batcher's odd–even mergesort, and optimal solutions are known for up to N=10.

John M. Gamble has a CGI script which will generate a sorting network for a given number of elements and choice of algorithm. For N=4 it generates:

N=4 elements: SWAP(0, 1); SWAP(2, 3); SWAP(0, 2); SWAP(1, 3); SWAP(1, 2);

Each SWAP(i, j) compares elements i and j of an array and exchanges them in place if they are out of order. Here's one way to turn those instructions into a 4-element sort for Python:

def sort4(data):
  if data[1] < data[0]:  # SWAP(0, 1)
    data[0], data[1] = data[1], data[0]
  
  if data[3] < data[2]:  # SWAP(2, 3)
    data[2], data[3] = data[3], data[2]
  
  if data[2] < data[0]:  # SWAP(0, 2)
    data[0], data[2] = data[2], data[0]
  
  if data[3] < data[1]:  # SWAP(1, 3)
    data[1], data[3] = data[3], data[1]
  
  if data[2] < data[1]:  # SWAP(1, 2)
    data[1], data[2] = data[2], data[1]

As a test, I'll sort every permutation of four values and make sure the result is sorted. I could write the test cases out manually, but it's easier to use the permutations() function from Python's itertools module. Here is an example with 3 values:

>>> import itertools
>>> list(itertools.permutations([1, 2, 3]))
[(1, 2, 3), (1, 3, 2), (2, 1, 3), (2, 3, 1), (3, 1, 2), (3, 2, 1)]

Here's the test, which confirms that the function sorts correctly:

>>> for permutation in itertools.permutations([0, 1, 2, 3]):
...   permutation = list(permutation) # Convert the tuple to a list so sort4() can swap elements
...   sort4(permutation)
...   if permutation != [0, 1, 2, 3]:
...     print("ERROR:", permutation)
... 
>>>

I think it's obvious how to turn this into a parity function by adding a swap counter. If the input array must not be modified, the parity function needs to make a copy of the array first, which is what parity_shell() does.
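Here is a minimal sketch of that idea, assuming the values are distinct. Instead of unrolling the network by hand, it keeps Gamble's N=4 network as data: a list of (i, j) pairs. NETWORK4 and parity_network4() are hypothetical names, not code from the previous essay. Each swap the network performs is a single transposition, so the parity of the swap count equals the parity of the permutation:

```python
# The N=4 network from John M. Gamble's generator, as (i, j) comparator pairs.
NETWORK4 = [(0, 1), (2, 3), (0, 2), (1, 3), (1, 2)]

def parity_network4(values):
    """Return 0 (even) or 1 (odd) for a sequence of 4 distinct values.

    Runs the sorting network on a copy and counts the swaps. Each swap
    is one transposition, so the swap count has the permutation's parity.
    """
    data = list(values)  # copy so the caller's sequence is left unchanged
    num_swaps = 0
    for i, j in NETWORK4:
        if data[j] < data[i]:  # SWAP(i, j): exchange only if out of order
            data[i], data[j] = data[j], data[i]
            num_swaps += 1
    return num_swaps % 2
```

For example, parity_network4((1, 0, 2, 3)) returns 1, because only the first comparator swaps.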

No need to sort

A sorting network always does D comparisons, but not all of them are needed for every input. The reason is simple. If you think of the network as a decision tree, where each comparison is a branch, then D comparisons always give 2^D leaves. This must be at least as large as N!, where N is the number of elements in the list. But N! is not a power of 2 for N>2, so some leaves will be unused.
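The counting argument is easy to check numerically. This sketch uses hypothetical helper names. min_comparisons() finds the smallest D with 2**D >= N!. That is a lower bound on the worst-case number of comparisons for any decision tree that tells all N! orderings apart. unused_leaves() counts the leaves left over when a fixed run of D comparisons is viewed as a tree:

```python
import math

def min_comparisons(n):
    """Smallest D with 2**D >= n!, i.e. ceil(log2(n!)).

    Uses exact integer arithmetic: for any integer x >= 1,
    (x - 1).bit_length() == ceil(log2(x)).
    """
    return (math.factorial(n) - 1).bit_length()

def unused_leaves(n, num_comparisons):
    """Leaves not needed when num_comparisons always-run comparisons sort n items."""
    return 2 ** num_comparisons - math.factorial(n)

for n in range(2, 9):
    print(n, math.factorial(n), min_comparisons(n))
```

For N=4 the bound is 5 comparisons, the same as the 5-comparator network above, and unused_leaves(4, 5) is 32 - 24 = 8.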

I would like to minimize the number of comparisons. I would also like to avoid modifying the array in place by actually sorting it.

The key realization is that there's no need to sort in order to determine the parity. For example, if there are only two elements in the list, then the parity is as simple as testing:

def two_element_parity(x):
  assert len(x) == 2
  # Odd (1) only when the two elements are out of order.
  return int(x[0] > x[1])

The three-element parity is a bit harder to work out by hand:

def three_element_parity(x):
  assert len(x) == 3
  if x[0] < x[1]:
    if x[1] < x[2]:
      return 0      # 1, 2, 3
    elif x[0] < x[2]:
      return 1      # 1, 3, 2
    else:
      return 0      # 2, 3, 1
  elif x[0] < x[2]:
    return 1        # 2, 1, 3
  elif x[1] < x[2]:
    return 0        # 3, 1, 2
  else:
    return 1        # 3, 2, 1

It was complicated enough that it took several attempts to get right. I fixed it with the help of the following test code, which uses parity_shell() as a reference because I'm confident that it gives the correct values. (A useful development technique is to write something that you know works, even if it's slow, so you can use it to test more complicated code which better fits your needs.)

The test code is:

import itertools

def test_three_element_parity():
  for x in itertools.permutations([1,2,3]):
    p1 = parity_shell(x)
    p2 = three_element_parity(x)
    if p1 != p2:
      print("MISMATCH", x, p1, p2)
    else:
      print("Match", x, p1, p2)

which gives the output:

>>> test_three_element_parity()
Match (1, 2, 3) 0 0
Match (1, 3, 2) 1 1
Match (2, 1, 3) 1 1
Match (2, 3, 1) 0 0
Match (3, 1, 2) 0 0
Match (3, 2, 1) 1 1

A debugging technique

As I said, it took a couple of iterations to get correct code. Sometimes I wasn't sure which branch had produced a given 0 or 1. During development I added a second field to each return value to serve as a tag. The code looked like:

def three_element_parity(x):
  assert len(x) == 3
  if x[0] < x[1]:
    if x[1] < x[2]:
      return 0,1      # 1, 2, 3
    elif x[0] < x[2]:
      return 1,2      # 1, 3, 2
    else:
      return 0,3      # 2, 3, 1
  …

which meant I could see which path was in error. Here's one of the debug outputs using that technique:

>>> test_three_element_parity()
MISMATCH (1, 2, 3) 0 (0, 0)
MISMATCH (1, 3, 2) 1 (1, 1)
MISMATCH (2, 1, 3) 1 (1, 3)
MISMATCH (2, 3, 1) 0 (0, 2)
MISMATCH (3, 1, 2) 0 (0, 4)
MISMATCH (3, 2, 1) 1 (1, 5)

Of course the "MISMATCH" label is now misleading and I need to compare the values by eye, but with so few elements that's fine. For more complicated code I would modify the test code as well.

Brute force solution

The last time I worked on this problem, I turned the sorting network for N=4 into a decision tree. With 5 comparisons there are 2^5 = 32 terminal nodes, but only N! = 4! = 24 of them will be used. I pruned them by hand, which is feasible with only 32 leaves.

I thought this time I would come up with some clever way to handle this. I pulled out Knuth's "The Art of Computer Programming" for a pointer, since it covers optimal sorting and sorting networks in depth. Oddly, "parity" wasn't in the index.

There's probably some interesting pattern I could use to figure out which code paths to use, but N is small, so I decided to brute force it.

I'll start with all of the possible permutations:

>>> permutations = list(itertools.permutations(range(3)))
>>> permutations
[(0, 1, 2), (0, 2, 1), (1, 0, 2), (1, 2, 0), (2, 0, 1), (2, 1, 0)]

I want to build a decision tree where each leaf contains only one permutation. Each decision is made by choosing two indices to use for the comparison test. I'll go through the permutations. If a permutation's values at those indices are in order, I'll put it into the "lt_permutations" list ("lt" is short for "less than"). Otherwise it goes into the "gt_permutations" list.

For now, I'll assume the first pair of indices to compare is (0, 1):

>>> lt_permutations = []
>>> gt_permutations = []
>>> for permutation in permutations:
...   if permutation[0] < permutation[1]:
...     lt_permutations.append(permutation)
...   else:
...     gt_permutations.append(permutation)
... 
>>> lt_permutations
[(0, 1, 2), (0, 2, 1), (1, 2, 0)]
>>> gt_permutations
[(1, 0, 2), (2, 0, 1), (2, 1, 0)]

I'll turn the above into a utility function:

def partition_permutations(i, j, permutations):
    lt_permutations = []
    gt_permutations = []
    for permutation in permutations:
        if permutation[i] < permutation[j]:
            lt_permutations.append(permutation)
        else:
            gt_permutations.append(permutation)
    return lt_permutations, gt_permutations

and use it to further partition the 'lt_permutations' on the pair (1, 2):

>>> lt_permutations2, gt_permutations2 = partition_permutations(1, 2, lt_permutations)
>>> lt_permutations2
[(0, 1, 2)]
>>> gt_permutations2
[(0, 2, 1), (1, 2, 0)]

The lt_permutations2 list contains only one element, so that branch is done. Next I'll partition gt_permutations2 using the index pair (0, 2):

>>> lt_permutations3, gt_permutations3 = partition_permutations(0, 2, gt_permutations2)
>>> lt_permutations3
[(0, 2, 1)]
>>> gt_permutations3
[(1, 2, 0)]

Each partitioning corresponds to another if-statement, and partitioning continues until there is only one element in the branch. I want to use the above information to make a decision tree which looks like:

def parity3(data):
  if data[0] < data[1]:
    if data[1] < data[2]:
      return 0 # parity of (0, 1, 2)
    else:
      if data[0] < data[2]:
        return 1 # parity of (0, 2, 1)
      else:
        return 0 # parity of (1, 2, 0)
  ...

Partition scoring

In the previous section I partitioned using the successive pairs (0, 1), (1, 2) and (0, 2). These are pretty obvious. What should I use for N=4 or higher? In truth, I could likely use the same swap pairs as the sorting network, but I decided to continue with brute force.

For N items there are N*(N-1)/2 possible swap pairs:

>>> n = 4
>>> swap_pairs = [(i, j) for i in range(n-1) for j in range(i+1, n)]
>>> swap_pairs
[(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]
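The same pairs, in the same order, also come from itertools.combinations(), which yields each (i, j) with i < j in lexicographic order. A small sketch:

```python
import itertools

n = 4
# Every (i, j) with i < j, in the same order as the nested comprehension.
swap_pairs = list(itertools.combinations(range(n), 2))
print(swap_pairs)
# There are n*(n-1)/2 of them.
assert len(swap_pairs) == n * (n - 1) // 2
```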

I decided to pick the pair which comes closest to splitting the set of permutations in half. For each pair, I partition the given permutations. The pair's score is the absolute difference between the sizes of the "less than" and "greater than" subsets, and lower is better.

def get_partition_score(swap_pair, permutations):
    i, j = swap_pair
    num_lt = num_gt = 0
    for permutation in permutations:
        if permutation[i] < permutation[j]:
            num_lt += 1
        else:
            num_gt += 1
    return abs(num_lt - num_gt)

I'll create all permutations for 4 terms and score each of the pairs on the results:

>>> permutations = list(itertools.permutations(range(n)))
>>> len(permutations)
24
>>> permutations[:3]
[(0, 1, 2, 3), (0, 1, 3, 2), (0, 2, 1, 3)]
>>> for swap_pair in swap_pairs:
...   print(swap_pair, get_partition_score(swap_pair, permutations))
... 
(0, 1) 0
(0, 2) 0
(0, 3) 0
(1, 2) 0
(1, 3) 0
(2, 3) 0

Every score is 0, which means all pairs are equally good. I'll use the first, which is (0, 1), and partition into two sets of 12 each:

>>> lt_permutations, gt_permutations = partition_permutations(0, 1, permutations)
>>> lt_permutations
[(0, 1, 2, 3), (0, 1, 3, 2), (0, 2, 1, 3), (0, 2, 3, 1), (0, 3, 1, 2),
(0, 3, 2, 1), (1, 2, 0, 3), (1, 2, 3, 0), (1, 3, 0, 2), (1, 3, 2, 0),
(2, 3, 0, 1), (2, 3, 1, 0)]
>>> len(lt_permutations), len(gt_permutations)
(12, 12)

I'll redo the scoring with the "less than" permutations:

>>> permutations = lt_permutations
>>> for swap_pair in swap_pairs:
...   print(swap_pair, get_partition_score(swap_pair, permutations))
... 
(0, 1) 12
(0, 2) 4
(0, 3) 4
(1, 2) 4
(1, 3) 4
(2, 3) 0

Obviously it does no good to use (0, 1) again, because those values are all in order. Most of the other pairs are also partially sorted, so using them leads to a lopsided 4/8 split. But (2, 3) gives a perfect split, so I'll use it for the next partitioning. I'll again select the "less than" subset and re-score:

>>> lt_permutations, gt_permutations = partition_permutations(2, 3, permutations)
>>> lt_permutations
[(0, 1, 2, 3), (0, 2, 1, 3), (0, 3, 1, 2), (1, 2, 0, 3), (1, 3, 0, 2), (2, 3, 0, 1)]
>>> gt_permutations
[(0, 1, 3, 2), (0, 2, 3, 1), (0, 3, 2, 1), (1, 2, 3, 0), (1, 3, 2, 0), (2, 3, 1, 0)]
>>> permutations = lt_permutations
>>> for swap_pair in swap_pairs:
...   print(swap_pair, get_partition_score(swap_pair, permutations))
... 
(0, 1) 6
(0, 2) 0
(0, 3) 4
(1, 2) 4
(1, 3) 0
(2, 3) 6

Repeat this process until only one permutation is left, and use the parity of that permutation as the return value.
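Before generating any code, the same greedy recursion can report how many comparisons the finished tree will need. This is a sketch, not part of the original program. tree_stats() and greedy_tree_stats() are hypothetical helpers that reuse get_partition_score() and partition_permutations(). They return the worst-case depth and the total depth summed over all leaves. The sketch skips the pair deletion. That doesn't change the result, because an already-used pair always has the worst score:

```python
import itertools

def get_partition_score(swap_pair, permutations):
    i, j = swap_pair
    num_lt = num_gt = 0
    for permutation in permutations:
        if permutation[i] < permutation[j]:
            num_lt += 1
        else:
            num_gt += 1
    return abs(num_lt - num_gt)

def partition_permutations(i, j, permutations):
    lt_permutations = []
    gt_permutations = []
    for permutation in permutations:
        if permutation[i] < permutation[j]:
            lt_permutations.append(permutation)
        else:
            gt_permutations.append(permutation)
    return lt_permutations, gt_permutations

def tree_stats(permutations, swap_pairs):
    """Return (max_depth, total_depth) for the greedy decision tree.

    max_depth is the worst-case number of comparisons. total_depth is the
    number of comparisons summed over every leaf, so dividing it by
    len(permutations) gives the average for uniformly random input.
    """
    if len(permutations) == 1:
        return 0, 0
    i, j = min(swap_pairs, key=lambda pair: get_partition_score(pair, permutations))
    lt_permutations, gt_permutations = partition_permutations(i, j, permutations)
    lt_max, lt_total = tree_stats(lt_permutations, swap_pairs)
    gt_max, gt_total = tree_stats(gt_permutations, swap_pairs)
    # Every permutation below this node pays for this node's comparison once.
    return 1 + max(lt_max, gt_max), lt_total + gt_total + len(permutations)

def greedy_tree_stats(n):
    permutations = list(itertools.permutations(range(n)))
    swap_pairs = list(itertools.combinations(range(n), 2))
    return tree_stats(permutations, swap_pairs)

for n in range(2, 6):
    print(n, greedy_tree_stats(n))
```

For n=4, greedy_tree_stats(4) returns (5, 112). No path needs more than 5 comparisons, and the average over all 24 permutations is 112/24, about 4.67, while the sorting network always does 5. For n=3 it returns (3, 16).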

Code generation

I'll combine the above code into a program which generates Python code to compute the parity of a list with N distinct items. It uses recursion. The main entry point is "generate_parity_function()", which sets up the data for the recursive function "_generate_comparison()". That function identifies the best pair of indices to compare, then calls itself to process each side.

The recursion stops when only one permutation is left in the list. Then there's nothing more to do but compute the parity of that permutation and use it as the return value for that case.

import itertools

def parity_shell(values):
    # Simple Shell sort; while O(N^2), we only deal with at most 4 values 
    values = list(values)
    N = len(values)
    num_swaps = 0
    for i in range(N-1):
        for j in range(i+1, N):
            if values[i] > values[j]:
                values[i], values[j] = values[j], values[i]
                num_swaps += 1
    return num_swaps % 2

def get_partition_score(swap_pair, permutations):
    i, j = swap_pair
    num_lt = num_gt = 0
    for permutation in permutations:
        if permutation[i] < permutation[j]:
            num_lt += 1
        else:
            num_gt += 1
    return abs(num_lt - num_gt)

def partition_permutations(i, j, permutations):
    lt_permutations = []
    gt_permutations = []
    for permutation in permutations:
        if permutation[i] < permutation[j]:
            lt_permutations.append(permutation)
        else:
            gt_permutations.append(permutation)
    return lt_permutations, gt_permutations

def generate_parity_function(n):
    print("def parity{}(data):".format(n))
    permutations = list(itertools.permutations(range(n)))
    swap_pairs = [(i, j) for i in range(n-1) for j in range(i+1, n)]
    _generate_comparison(permutations, swap_pairs, "  ")

def _generate_comparison(permutations, swap_pairs, indent):
    if len(permutations) == 1:
        parity = parity_shell(permutations[0])
        print(indent + "return {} # {} ".format(parity, permutations[0]))
        return
    
    swap_pair = min(swap_pairs, key=lambda x: get_partition_score(x, permutations))
    # Delete the swap pair because it can't be used again.
    # (Not strictly needed as it will always have the worst score.)
    del swap_pairs[swap_pairs.index(swap_pair)]

    # I could have a case where the lt subset has 0 elements while the
    # gt subset has 1 element. Rather than have the 'if' block do nothing,
    # I'll swap the comparison indices and swap branches.
    i, j = swap_pair
    lt_permutations, gt_permutations = partition_permutations(i, j, permutations)
    if not lt_permutations:
        lt_permutations, gt_permutations = gt_permutations, lt_permutations
        i, j = j, i
    
    print(indent + "if data[{i}] < data[{j}]:".format(i=i, j=j))
    # Need to copy the swap_pairs because the 'else' case may reuse a pair.
    _generate_comparison(lt_permutations, swap_pairs[:], indent+"  ")
    if gt_permutations:
        print(indent + "else:")
        _generate_comparison(gt_permutations, swap_pairs, indent+"  ")
    

if __name__ == "__main__":
    import sys
    n = 4
    if sys.argv[1:]:
        n = int(sys.argv[1])
    generate_parity_function(n)

The output for n=2 elements is the expected trivial case:

def parity2(data):
  if data[0] < data[1]:
    return 0 # (0, 1) 
  else:
    return 1 # (1, 0)

For n=3 it's a bit more complicated:

def parity3(data):
  if data[0] < data[1]:
    if data[0] < data[2]:
      if data[1] < data[2]:
        return 0 # (0, 1, 2) 
      else:
        return 1 # (0, 2, 1) 
    else:
      return 0 # (1, 2, 0) 
  else:
    if data[0] < data[2]:
      return 1 # (1, 0, 2) 
    else:
      if data[1] < data[2]:
        return 0 # (2, 0, 1) 
      else:
        return 1 # (2, 1, 0)

and for n=4 elements, well, you can see why I wrote a program to help generate the function:

def parity4(data):
  if data[0] < data[1]:
    if data[2] < data[3]:
      if data[0] < data[2]:
        if data[1] < data[2]:
          return 0 # (0, 1, 2, 3) 
        else:
          if data[1] < data[3]:
            return 1 # (0, 2, 1, 3) 
          else:
            return 0 # (0, 3, 1, 2) 
      else:
        if data[0] < data[3]:
          if data[1] < data[3]:
            return 0 # (1, 2, 0, 3) 
          else:
            return 1 # (1, 3, 0, 2) 
        else:
          return 0 # (2, 3, 0, 1) 
    else:
      if data[0] < data[3]:
        if data[1] < data[2]:
          if data[1] < data[3]:
            return 1 # (0, 1, 3, 2) 
          else:
            return 0 # (0, 2, 3, 1) 
        else:
          return 1 # (0, 3, 2, 1) 
      else:
        if data[0] < data[2]:
          if data[1] < data[2]:
            return 1 # (1, 2, 3, 0) 
          else:
            return 0 # (1, 3, 2, 0) 
        else:
          return 1 # (2, 3, 1, 0) 
  else:
    if data[2] < data[3]:
      if data[0] < data[3]:
        if data[0] < data[2]:
          return 1 # (1, 0, 2, 3) 
        else:
          if data[1] < data[2]:
            return 0 # (2, 0, 1, 3) 
          else:
            return 1 # (2, 1, 0, 3) 
      else:
        if data[1] < data[2]:
          return 1 # (3, 0, 1, 2) 
        else:
          if data[1] < data[3]:
            return 0 # (3, 1, 0, 2) 
          else:
            return 1 # (3, 2, 0, 1) 
    else:
      if data[0] < data[2]:
        if data[0] < data[3]:
          return 0 # (1, 0, 3, 2) 
        else:
          if data[1] < data[3]:
            return 1 # (2, 0, 3, 1) 
          else:
            return 0 # (2, 1, 3, 0) 
      else:
        if data[1] < data[2]:
          if data[1] < data[3]:
            return 0 # (3, 0, 2, 1) 
          else:
            return 1 # (3, 1, 2, 0) 
        else:
          return 0 # (3, 2, 1, 0)

The test code is essentially the same as "test_three_element_parity()", so I won't include it here.
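A version of that test generalized to any n might look like the following sketch. check_parity_function() is a hypothetical name. It reuses parity_shell() as the trusted reference and returns the mismatches instead of printing them:

```python
import itertools

def parity_shell(values):
    # Reference implementation from the start of the essay.
    values = list(values)
    N = len(values)
    num_swaps = 0
    for i in range(N-1):
        for j in range(i+1, N):
            if values[i] > values[j]:
                values[i], values[j] = values[j], values[i]
                num_swaps += 1
    return num_swaps % 2

def check_parity_function(func, n):
    """Compare func with parity_shell() on every permutation of range(n).

    Returns a list of (permutation, expected, actual) tuples, one per
    mismatch. An empty list means func agrees on all n! inputs.
    """
    mismatches = []
    for permutation in itertools.permutations(range(n)):
        expected = parity_shell(permutation)
        # Pass a list: the generated functions only index into it, but an
        # in-place variant of parity_shell() needs a mutable sequence.
        actual = func(list(permutation))
        if actual != expected:
            mismatches.append((permutation, expected, actual))
    return mismatches
```

With a generated function pasted into the same file, check_parity_function(parity4, 4) should return an empty list.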

Evaluation

I don't think it makes much sense to use this function beyond n=5 because there's so much code. Here's a table of the number of lines of code it generates for different values of n:

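# elements   # lines
----------   -------
    2              5
    3             17
    4             71
    5            359
    6          2,159
    7         15,119
    8        120,959
    9      1,088,639

This is factorial growth, which is what it should be. The generated function has one "def" line and one "return" line for each of the n! permutations. It also has an "if"/"else" pair of lines for each of the n! - 1 comparison nodes. That gives 3*n! - 1 lines in total. For my case, n=4, so 71 lines is not a problem.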

I wrote some timing code which makes 100,000 random selections from the possible permutations and compares the performance of the parityN() function with parity_shell(). To put them on a more even basis, I changed the parity_shell() implementation so it mutates the input values rather than making a temporary list. The timing code for parity5() looks like:

import itertools

# parity5() must also be defined in this file: paste in the code the
# generator above prints for n=5.

def parity_shell(values):
    # Simple Shell sort; while O(N^2), we only deal with at most 4 values 
    # values = list(values)  # disabled: sort the caller's list in place
    N = len(values)
    num_swaps = 0
    for i in range(N-1):
        for j in range(i+1, N):
            if values[i] > values[j]:
                values[i], values[j] = values[j], values[i]
                num_swaps += 1
    return num_swaps % 2


if __name__ == "__main__":
    import random
    import time
    permutations = list(itertools.permutations(range(5)))
    perms = [list(random.choice(permutations)) for i in range(100000)]
    t1 = time.perf_counter()
    p1 = [parity5(perm) for perm in perms]
    t2 = time.perf_counter()
    # parity_shell() sorts each list in place, so it has to run after parity5().
    p2 = [parity_shell(perm) for perm in perms]
    t3 = time.perf_counter()
    if p1 != p2:
        print("oops")
    print("parity5:", t2-t1, "parity_shell:", t3-t2)
    print("ratio:", (t2-t1)/(t3-t2))

The decision tree version is consistently 5-6x faster than the Shell sort version across all the sizes I tested.

A performance improvement

By the way, I was able to make it 9x faster by unpacking the values into local variables rather than indexing into the array each time. Here's the start of parity4() with that change:

def parity4(data):
  data0,data1,data2,data3 = data
  if data0 < data1:
    if data2 < data3:
      if data0 < data2:
        if data1 < data2:
          return 0 # (0, 1, 2, 3) 
        else:
          if data1 < data3:
            return 1 # (0, 2, 1, 3) 
          else:
            return 0 # (0, 3, 1, 2)
  # … additional code omitted …

It's easy to change the generator so it produces this version instead. Alternatively, you can convert the code I gave earlier with a bit of text replacement and hand-editing.
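Here is a sketch of the text-replacement route. It is not the author's code. use_local_variables() is a hypothetical helper, and it assumes the generator's output format: a def line followed by a body indented with two spaces. A regular expression turns each data[i] into data{i}, and one unpacking line is added after the header:

```python
import re

def use_local_variables(source, n):
    """Rewrite generated parityN() source so it unpacks data into locals.

    Assumes the generator's output format: a 'def parityN(data):' header
    line followed by a body indented with two spaces per level.
    """
    lines = source.splitlines()
    header, body = lines[0], lines[1:]
    names = ",".join("data{}".format(i) for i in range(n))
    if n == 1:
        names += ","  # 'data0, = data' unpacks a one-element sequence
    unpack = "  {} = data".format(names)
    # data[3] -> data3; the trailing comments hold only tuples, so they are unchanged.
    body = [re.sub(r"data\[(\d+)\]", r"data\1", line) for line in body]
    return "\n".join([header, unpack] + body) + "\n"

# Example: the generated parity2() from earlier, as a string.
source = """def parity2(data):
  if data[0] < data[1]:
    return 0 # (0, 1)
  else:
    return 1 # (1, 0)
"""
print(use_local_variables(source, 2))
```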

Andrew Dalke is an independent consultant focusing on software development for computational chemistry and biology. Need contract programming, help, or training? Contact me.

Copyright © 2001-2013 Andrew Dalke Scientific AB



