Wednesday, October 10, 2012

Zero-Suppressed Binary Decision Diagrams and Polynomials

Lately I have been reading a lot of articles about polynomial representations.
I was interested in finding an efficient representation for polynomials, to be used in my AKS implementation.

Most of these representations are quite straightforward: arrays of coefficients, arrays of (coefficient, exponent) pairs, or simply integers in which each group of n bits represents a coefficient.

I decided to ask on StackOverflow how I could optimize my current polynomial implementation, and a commenter pointed out that Minato wrote an article in which he described a new way to represent polynomials.

I decided to give this idea a try and, even though I don't think it will make my AKS faster, I found it really interesting.

Minato's representation is really different from all the "classic" representations. The idea is to represent the polynomial as a DAG (Directed Acyclic Graph), in particular a special form of BDD (Binary Decision Diagram), which is called Zero-Suppressed BDD.
The result is that it is possible to represent polynomials with millions of terms using only thousands of nodes (actually it could be even fewer, depending on the polynomial).

Since operations such as addition or equality have a complexity proportional to the size of the graph, and not to the number of terms, the representation seems promising (it can be thousands of times faster if the polynomials are particularly "regular").

Brief description of ZBDDs


Each node of a ZBDD has a label \(v\) and two outgoing edges, the 1-edge and the 0-edge (which will be drawn with a dashed line in the images). The 1-edge is connected to a child node, which is called the high child, while the 0-edge connects to the low child. So we can represent a node \(A\) as a triplet \((v, low, high)\).

The ZBDD is a rooted DAG, so there is exactly one node, the root, that has indegree 0 and from which every other node in the graph can be reached. The order of the labels is fixed, so that on any path from the root the labels always appear in the same order.
In a ZBDD there are also two special nodes, called sinks, which are labelled \(\top\) and \(\bot\). The former is also called the true sink, while the latter is called the false sink. These two nodes have outdegree 0.

The graph must also obey two reduction rules:

  1. In a graph \(G\) there must not be two nodes \(A\) and \(B\) such that \(v_A = v_B\) and \(low_A = low_B\) and \(high_A = high_B\) [no isomorphic subgraphs]
  2. Every node whose high child is the false sink must be removed (attaching its low child to its parent in its place).

It's quite simple to obey these rules: we just need a procedure, which I'll call ZUNIQUE, that controls node creation. Whenever we want to create a node \((v, low, high)\) we call ZUNIQUE. If the high child is the false sink it returns the low child; otherwise it checks whether an identical node already exists and returns it in that case, creating a new node only when it does not.

An example of a ZBDD is the following:

Operations on these graphs are quite easy to write in a recursive fashion.
Let us suppose that we want to apply a binary operation \(\diamond\) (which I'll call meld) between two ZBDDs. All we need to know are its results on the sinks: \(\top \diamond \top, \top \diamond \bot, \bot \diamond \top, \bot \diamond \bot \).

An algorithm, MELD(\(F\),\(G\)), to apply such a generic operation to two ZBDDs \(F\) and \(G\) can be described by these steps:

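  1. If \(F\) and \(G\) are both sinks, return \(F \diamond G \), since it's a base case.
  2. If \(v_F = v_G\) then return ZUNIQUE(\(v_F\), MELD(\(low_F\),\(low_G\)), MELD(\(high_F\), \(high_G\)))
  3. else if \(v_F < v_G\) (a sink counts as larger than any label), then return ZUNIQUE(\(v_F\), MELD(\(low_F\), \(G\)), MELD(\(high_F\), \(\bot\)))
  4. otherwise return ZUNIQUE(\(v_G\), MELD(\(F\), \(low_G\)), MELD(\(\bot\), \(high_G\)))

In steps 3 and 4 the missing node is paired with \(\bot\) because, by the second reduction rule, a label that does not appear in a graph stands for a node whose high child is the false sink.
If we replace the base case operations we can produce the result of any boolean binary operation for which \(\bot \diamond \bot = \bot\), such as AND, OR, XOR.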
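
The recursion can be sketched in a few lines of Python. This is only an illustration under some assumptions: nodes are plain tuples `(v, low, high)`, the sinks are the booleans `True` and `False`, a sink is given the pseudo-label `INF`, and there is no memoization. When a label is absent from one operand, the sketch pairs the other operand's high child with the false sink, because under zero suppression an absent node means "high child is the false sink".

```python
import operator

TOP, BOT = True, False      # sinks (booleans are an assumption of this sketch)
INF = float("inf")          # pseudo-label of a sink: larger than any variable
_unique = {}


def zunique(v, low, high):
    if high is BOT:                      # zero-suppression rule
        return low
    return _unique.setdefault((v, low, high), (v, low, high))


def label(f):
    return INF if isinstance(f, bool) else f[0]


def meld(f, g, op):
    """Combine two ZBDDs with a boolean operator where op(False, False) is False."""
    vf, vg = label(f), label(g)
    if vf == vg == INF:                  # both are sinks: base case
        return op(f, g)
    if vf == vg:
        return zunique(vf, meld(f[1], g[1], op), meld(f[2], g[2], op))
    if vf < vg:                          # the label v_F does not appear in g
        return zunique(vf, meld(f[1], g, op), meld(f[2], BOT, op))
    return zunique(vg, meld(f, g[1], op), meld(BOT, g[2], op))


x1 = zunique(1, BOT, TOP)                # the single combination {1}
x2 = zunique(2, BOT, TOP)                # the single combination {2}
union = meld(x1, x2, operator.or_)
inter = meld(x1, x2, operator.and_)
diff = meld(union, x2, operator.xor)
```

Here `union` holds both combinations, `inter` is the false sink, and `diff` is `x1` again. A real implementation would cache the results of `meld` so that shared subgraphs are processed only once.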


Polynomial representation

Minato had a simple yet brilliant idea. Let us consider the polynomial \(x^4 + x^3 + x\). Since any natural number \(n\) can be written uniquely as a sum of distinct powers of two, every exponent can be split accordingly (here \(3 = 2 + 1\)), and we can rewrite the polynomial as \(x^4 + x^2 \cdot x^1 + x^1\).
Now we can consider \(x^4, x^2\) and \(x^1\) as three different boolean variables. Factoring out \(x^1\) we get \(x^4 + x^1\cdot(x^2 + 1)\); if we now replace sums with 0-edges and products with 1-edges we obtain the following ZBDD:

If we consider the paths from the root \(x^1\) to the sinks we can recover the original polynomial in this way: whenever a path leaves a node through its 1-edge, multiply by that node's label; whenever it leaves through the 0-edge, simply skip that label. Then sum the results over all the paths.

Writing \(-\) for a 1-edge and \(\cdots\) for a 0-edge, we can see that we have the path \(x^1 - x^2 - \top \), which is replaced by \(x^1 \cdot x^2 \cdot 1 = x^3\); then there is the path \(x^1 - x^2 \cdots \top\), which is replaced by \(x^1 \cdot 1 = x^1\); and there is also the path \(x^1 \cdots x^4 - \top \), which yields \(x^4\). The path \(x^1\cdots x^4 \cdots \bot\) yields \(0\). So if we sum these results together we get \(x^4 + x^3 + x^1 + 0\), which is our polynomial.
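
The path-reading rule can be checked with a small sketch. The graph for \(x^4 + x^1\cdot(x^2 + 1)\) is built by hand; representing nodes as tuples `(exponent, low, high)` and sinks as booleans is an assumption of this example.

```python
TOP, BOT = True, False

# The label e stands for the boolean variable x^e.
x4 = (4, BOT, TOP)
x2 = (2, TOP, TOP)
root = (1, x4, x2)          # x^4 + x^1 * (x^2 + 1)


def exponents(node, acc=0):
    """Yield the exponent of each term, one per path ending in the true sink."""
    if node is BOT:
        return                              # this path contributes 0
    if node is TOP:
        yield acc
        return
    e, low, high = node
    yield from exponents(low, acc)          # 0-edge: skip the label
    yield from exponents(high, acc + e)     # 1-edge: multiply by x^e


terms = sorted(exponents(root))             # the exponents 1, 3 and 4
```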

To represent the integer coefficients we can use the same trick.
Every natural number can be written uniquely as a sum of powers of two, so every power of two can be treated as a boolean variable, just like the \(x^i\) before.
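
As a rough illustration, a single term \(c \cdot x^e\) can be split into combinations of boolean variables as follows. The helper names are made up for this sketch, and the labelling follows the convention of the Python code at the end of this post, where `('2', k)` stands for \(2^k\) and `('x', k)` for \(x^k\) with \(k\) a power of two.

```python
def binary_parts(n):
    """Distinct powers of two that sum to n, e.g. 7 -> [1, 2, 4]."""
    return [1 << k for k in range(n.bit_length()) if (n >> k) & 1]


def term_to_combinations(coeff, exp):
    """Split coeff * x**exp into sets of boolean variables, one set per
    power of two contained in the coefficient."""
    xs = [("x", e) for e in binary_parts(exp)]
    combos = []
    for k in range(coeff.bit_length()):
        if (coeff >> k) & 1:                 # coeff contains 2**k
            # 2**k is itself a product of the variables 2**(2**j)
            twos = [("2", e) for e in binary_parts(k)]
            combos.append(frozenset(twos + xs))
    return combos


# 24 * x^7 = (2^1 * 2^2 + 2^4) * x^1 * x^2 * x^4
combos = term_to_combinations(24, 7)
```

Each set in `combos` corresponds to one path to the true sink in the resulting ZBDD.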

Is the representation compact?

You may wonder if this funny representation is compact and/or efficient.
For example if we take the polynomial \(24x^7 + 4x^6 + 3x^3 + 16x^2 + 15x\), the resulting ZBDD is:

And it does not seem so compact. It has 15 nodes but it represents a polynomial of degree 7, which has only 8 coefficients (three of them zero). But if we consider a polynomial such as \(257 x^{55}+769 x^{54}+8 x^{52}+257 x^{43}+769 x^{42}+8 x^{40}+257 x^{23}+769 x^{22}+8 x^{20}+257 x^{11}+769 x^{10}+8 x^8\), then we obtain the following ZBDD:

This ZBDD contains only 12 nodes and represents a polynomial of degree 55, with 12 terms. This polynomial would require 56 "slots" as an array of coefficients, or 12 pairs as an array of coefficient-exponent pairs, but in the latter case the operations would be slower. Also, we can modify the graph slightly to obtain something like this:

It represents \(257 x^{55}+769 x^{54}+8 x^{52}+257 x^{43}+769 x^{42}+8 x^{40}+16x^{37}+16x^{33}+16x^{32}+257 x^{23}\)
\(+769 x^{22}+8 x^{20}+257 x^{11}+769 x^{10}+8 x^8+16x^5+16x+16\), a polynomial with 18 terms using only 15 nodes.

This may seem a small gain, but it becomes really important when we want to deal with big polynomials. For example the polynomial \(\prod_{k=1}^{8}{(x_k + 1)^8}\), which has \(9^8 = 43046721\) terms, can be represented with only \(26279\) nodes, which is about four times the square root of the number of terms.
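
The arithmetic behind that comparison is easy to check. The node count is the figure quoted above and is simply taken as given here.

```python
import math

terms = 9 ** 8                  # each factor (x_k + 1)^8 contributes 9 terms
root = math.isqrt(terms)        # exact integer square root
ratio = 26279 / root            # quoted node count relative to sqrt(terms)
```

With these numbers `root` is 6561 and `ratio` is slightly above 4.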

Operations on Polynomials


We will now see that it's pretty easy to devise algorithms that compute the sum or product of two polynomials represented as ZBDDs.

First of all we have to devise an algorithm that computes the product of a polynomial and a variable. The algorithm MULVAR(\(F\),\(v\)) is pretty straightforward if we think carefully about this example:

  
Basically we have to notice that, to multiply by \(v\) a polynomial whose root is labelled \(v\), we simply have to swap its children and multiply the new low child (which previously was the high child) by \(v^2\). This allows us to write a simple recursive function, like the following (a sink counts as having a label larger than any variable):

  1. If \(v < v_F\) then return ZUNIQUE(\(v\), false-sink, \(F\))
  2. Else if \(v = v_F\) then return ZUNIQUE(\(v\), MULVAR(\(high_F\), \(v^2\)), \(low_F\))
  3. Else return ZUNIQUE(\(v_F\), MULVAR(\(low_F\), \(v\)), MULVAR(\(high_F\), \(v\)))
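
A possible sketch of these three steps for a single variable \(x\). It assumes that the label `e` (a power of two) stands for \(x^e\), so that squaring a variable doubles its label, that nodes are plain tuples and that sinks are booleans. Since tuples compare structurally, this toy `zunique` needs no lookup table.

```python
TOP, BOT = True, False
INF = float("inf")              # pseudo-label of a sink


def zunique(v, low, high):
    return low if high is BOT else (v, low, high)


def label(f):
    return INF if isinstance(f, bool) else f[0]


def mulvar(f, v):
    """Multiply the polynomial f by the variable x^v (v a power of two)."""
    vf = label(f)
    if v < vf:                                   # step 1: v becomes the new root
        return zunique(v, BOT, f)
    if v == vf:                                  # step 2: swap, then square
        return zunique(v, mulvar(f[2], 2 * v), f[1])
    return zunique(vf, mulvar(f[1], v), mulvar(f[2], v))   # step 3


x1 = mulvar(TOP, 1)     # x
x2 = mulvar(x1, 1)      # x^2, a single node labelled 2
x3 = mulvar(x2, 1)      # x^3, the nodes 1 and 2 chained by a 1-edge
x4 = mulvar(x3, 1)      # x^4, a single node labelled 4
```

Multiplying the false sink by any variable gives the false sink again, because step 1 builds a node whose high child is the false sink and `zunique` suppresses it.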

Addition

If we consider two polynomials \(F\) and \(G\), we can easily see that if their graphs do not share any subgraph, then their sum is their union (which we can compute as \(F\) OR \(G\), using the ZBDD algorithm).
If there are common subgraphs, the union would count them only once instead of twice.
So what we have to do is to write \(F + G\) as \(F \oplus G + 2 \times (F \wedge G) \), where \(\oplus\) is the XOR of two ZBDDs, \(\wedge\) is their AND, and we can compute \(2 \times F\) by multiplying \(F\) by the variable \(2\). The remaining sum is computed in the same way, until the AND is empty.
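
The same identity holds for ordinary binary integers, where XOR adds the bits without carrying and AND selects the positions that produce a carry. The following sketch is only an analogy on Python integers, not ZBDD code.

```python
def add(a, b):
    """Add two non-negative integers using only XOR, AND and doubling."""
    common = a & b              # bits set in both numbers: they produce carries
    if common == 0:
        return a ^ b            # nothing in common: the sum is just the union
    return add(a ^ b, 2 * common)


total = add(13, 11)             # 13 + 11
```

The recursion stops when the two operands have nothing in common, which mirrors the case in which the AND of the two ZBDDs is the false sink.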


Multiplication

Let \(F\) and \(G\) be polynomials such that \(v_F = v_G = v\), then we can compute their product using the relation:
\[
F \times G = (low_F \times low_G) + (high_F \times high_G) \times v^2 + ((high_F \times low_G) + (low_F \times high_G)) \times v
\]

We can already compute the expression, since it uses only addition and multiplication by a variable (plus the recursive calls).

What happens if we have \(F \times G\) and \(v_F \neq v_G\)? Suppose \(v_F < v_G\); then we simply multiply \(low_F\) and \(high_F\) by \(G\), returning ZUNIQUE(\(v_F\), \(low_F \times G\), \(high_F \times G\)), and we are done.
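
Both cases follow from reading a node as "low part plus label times high part", which is how the paths of the graph were interpreted earlier. For \(v_F = v_G = v\):
\[
F = low_F + v \times high_F, \qquad G = low_G + v \times high_G
\]
\[
F \times G = (low_F \times low_G) + v \times ((high_F \times low_G) + (low_F \times high_G)) + v^2 \times (high_F \times high_G)
\]
For \(v_F < v_G\), only \(F\) is decomposed:
\[
F \times G = (low_F \times G) + v_F \times (high_F \times G)
\]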

An implementation in Python

Writing an implementation in Python is quite straightforward. It takes only about 130 lines of code. The only thing that we should decide is where to put the "caching system". We could create a "Polynomial" factory that returns unique polynomials, just like ZUNIQUE, while leaving the opportunity to create more factories, or we can just provide a single factory.

My implementation uses a single factory, which is actually provided as the __new__ method. All operations can be cached using a simple decorator (to avoid rewriting boilerplate code every time).


def memoize(name=None, symmetric=True):
    def cached(meth):
        if name is None:
            name_ = meth.__name__.strip('_')
        else:
            name_ = name
        
        def decorator(self, other):
            try:
                return self.CACHE[name_][(self, other)]
            except KeyError:
                result = meth(self,other)
                self.CACHE[name_][(self,other)] = result
                if symmetric:
                    self.CACHE[name_][(other,self)] = result
                return result
        return decorator
    return cached



class Poly(object):
    
    CACHE = {
        'new': {},
        'and': {},
        'or': {},
        'xor': {},
        'add': {},
        'mul': {},
    }
    
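    def __new__(cls, label, low=None, high=None):
        # Zero-suppression rule: never create a node whose high child is the false sink.
        if high is not None and high.label is False:
            return low
        
        key = (label, low, high)
        try:
            return cls.CACHE['new'][key]
        except KeyError:
            # object.__new__ must not receive the extra arguments on Python 3.
            node = cls.CACHE['new'][key] = object.__new__(cls)
            node.label, node.low, node.high = label, low, high
            return node
            
    def __init__(self, label, low=None, high=None):
        # Attributes are set in __new__: Python calls __init__ on whatever Poly
        # __new__ returns, so assigning here would overwrite a returned low child.
        pass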
    
    @memoize()
    def __and__(self, other):
        if other.is_terminal():
            f, g = other, self
        else:
            f, g = self, other
        
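        if f.is_terminal():
            if f.label is False:
                return f
            # The true sink holds only the constant term: keep it if g has it too.
            return g if g.is_terminal() else f & g.low
        elif f.label == g.label:
            return Poly(f.label, f.low & g.low, f.high & g.high)
        return f.low & g if f.label < g.label else f & g.low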
    
    @memoize()
    def __or__(self, other):
        if other.is_terminal():
            f, g = other, self
        else:
            f, g = self, other
        
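        if f.is_terminal():
            if f.label is False:
                return g
            # The true sink adds the constant term along the chain of low children.
            return f if g.is_terminal() else Poly(g.label, f | g.low, g.high)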
        elif f.label == g.label:
            return Poly(f.label, f.low | g.low, f.high | g.high)
        return (Poly(f.label, f.low | g, f.high) if f.label < g.label else
                Poly(g.label, f | g.low, g.high))
    
    @memoize()
    def __xor__(self, other):
        if other.is_terminal():
            f, g = other, self
        else:
            f, g = self, other
        
        if f.is_terminal():
            return (Poly(f.label ^ g.label) if g.is_terminal() else
                    Poly(g.label, f ^ (g.low), g.high))
        elif f.label == g.label:
            return Poly(f.label, f.low ^ g.low, f.high ^ g.high)
        return (Poly(f.label, f.low ^ g, f.high) if f.label < g.label else
                Poly(g.label, f ^ g.low, g.high))
    
    @memoize()
    def __mul__(self, other):
        if isinstance(other, Poly):
            if other.is_terminal():
                f, g = other, self
            else:
                f, g = self, other
            
            if f.is_terminal():
                return g if f.label is True else f
            elif f.label == g.label:
                return (f.low * g.low +
                        (f.high * g.high) * (f.label[0], f.label[1]*2) +
                        (f.high * g.low + f.low * g.high) * f.label)
            
            return (Poly(f.label, f.low * g, f.high * g) if f.label < g.label else
                    Poly(g.label, f * g.low, f * g.high))
        else:
            if self.is_terminal():
                return (Poly(other, Poly(False), Poly(True)) if self.label is True
                        else self)
            elif self.label < other:
                return Poly(self.label, self.low * other, self.high * other)
            elif self.label == other:
                return Poly(self.label, self.high * (other[0], other[1]*2), self.low)
            
            return Poly(other, Poly(False), self)
    
    @memoize()
    def __add__(self, other):
        # F + G = (F xor G) + 2 * (F and G); this also covers 1 + 1 = 2,
        # so the sinks need no special case.
        xor = self ^ other
        intersect = self & other
        if intersect.is_terminal() and not intersect.label:
            return xor
        
        return xor + (intersect * ('2', 1))
    
    def is_terminal(self):
        return self.low is self.high is None