Experimental code to produce an ABC-synthesizable netlist from a scikit-learn random forest. Untested; proceed with caution. author: schatter@google.com


In [0]:
#%tensorflow_version 1.x
import os
import re
import numpy as np
# TF is used only to read MNIST data
import tensorflow as tf
import sklearn
from sklearn.ensemble import RandomForestClassifier

print('The scikit-learn version is {}.'.format(sklearn.__version__))
print('The TF version is {}.'.format(tf.__version__))


The scikit-learn version is 0.21.3.
The TF version is 2.0.0.

In [0]:
def tree_predict(tree, node, x):
    """Recursively evaluate one sklearn decision tree on a single example."""
    assert node >= 0

    left   = tree.children_left
    right  = tree.children_right  # -1 is the sentinel for "no child"
    feats  = tree.feature         # -2 is the sentinel for "no feature" (leaf)
    thresh = tree.threshold
    values = tree.value

    if feats[node] == -2:  # leaf node
        assert left[node] == -1
        assert right[node] == -1
        return values[node] / values[node].sum()
    else:
        assert left[node] != -1
        assert right[node] != -1
        # truncating the threshold to an int is safe because the features are
        # integers: sklearn picks thresholds like 127.5, and for integer x,
        # x <= 127.5 iff x <= 127
        if x[feats[node]] <= int(thresh[node]):
            return tree_predict(tree, left[node], x)
        else:
            return tree_predict(tree, right[node], x)

    
def forest_predict(model, x, debug=False):
    # sum the per-tree class-probability vectors; sklearn averages them
    # instead, but summing preserves the argmax, which is all we need
    res = tree_predict(model.estimators_[0].tree_, 0, x)
    for estimator in model.estimators_[1:]:
        res += tree_predict(estimator.tree_, 0, x)

    if debug:
        print(res.reshape(-1).astype(np.int32))
    return res.reshape(-1).argmax()


def predictions(model, examples):
    # per-example predicted classes (not an accuracy score)
    return np.array([forest_predict(model, example) for example in examples])
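
The -1 and -2 sentinels above are sklearn's TREE_LEAF and TREE_UNDEFINED constants. A minimal sketch (my addition, using a toy stump rather than an MNIST tree) of the flat parallel-array layout that tree_predict walks:

In [0]:
# Hedged illustration: inspect the parallel arrays of a tiny fitted tree.
from sklearn.tree import DecisionTreeClassifier

toy = DecisionTreeClassifier(max_depth=1, random_state=0)
toy.fit([[0], [1], [2], [3]], [0, 0, 1, 1])
t = toy.tree_
print('children_left ', t.children_left)   # [ 1 -1 -1]: node 0 splits; 1 and 2 are leaves
print('children_right', t.children_right)  # [ 2 -1 -1]
print('feature       ', t.feature)         # [ 0 -2 -2]: -2 marks the leaves
print('threshold     ', t.threshold)       # [ 1.5 -2. -2. ]: root tests x0 <= 1.5
print('value         ', t.value)           # per-node (1, n_classes) class counts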

In [0]:
def generate(name, randomize_labels, nverify=1000):
    (tx, ty), (vx, vy) = tf.keras.datasets.mnist.load_data()

    if randomize_labels:
        # roll the labels so they no longer match the images
        # (np.random.permutation(ty) would work too)
        ty = np.roll(ty, 127)

    tx = tx.reshape(60000, -1)
    vx = vx.reshape(10000, -1)

    # note: bootstrap is off so each tree trains on the full dataset without
    # resampling; sample weights are then always 1, which keeps inference simple
    # m = RandomForestClassifier(n_estimators=10, bootstrap=False, random_state=0)
    # TODO: tiny tree
    m = RandomForestClassifier(n_estimators=2, max_depth=3, bootstrap=False, random_state=0)
    m.fit(tx, ty)
  
    print(m)
    print("name = {}, ta = {}, va = {}".format(name, m.score(tx, ty), 
                                            m.score(vx, vy)))  

    nverify = min(60000, nverify)
  
    mine   = predictions(m, tx[:nverify])
    golden = m.predict(tx[:nverify])
  
    assert (mine == golden).all()
    # print(np.arange(nverify)[mine != golden])
    print("verified")

    # write_model(m, name)
    # print("done writing {}".format(name))
  
    return m
    
mreal = generate('real',   randomize_labels=False)
mrand = generate('random', randomize_labels=True)


RandomForestClassifier(bootstrap=False, class_weight=None, criterion='gini',
                       max_depth=3, max_features='auto', max_leaf_nodes=None,
                       min_impurity_decrease=0.0, min_impurity_split=None,
                       min_samples_leaf=1, min_samples_split=2,
                       min_weight_fraction_leaf=0.0, n_estimators=2,
                       n_jobs=None, oob_score=False, random_state=0, verbose=0,
                       warm_start=False)
name = real, ta = 0.4905, va = 0.5057
verified
RandomForestClassifier(bootstrap=False, class_weight=None, criterion='gini',
                       max_depth=3, max_features='auto', max_leaf_nodes=None,
                       min_impurity_decrease=0.0, min_impurity_split=None,
                       min_samples_leaf=1, min_samples_split=2,
                       min_weight_fraction_leaf=0.0, n_estimators=2,
                       n_jobs=None, oob_score=False, random_state=0, verbose=0,
                       warm_start=False)
name = random, ta = 0.11363333333333334, va = 0.1142
verified
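
With bootstrap off, each tree sees every training example exactly once, so tree.value at the leaves holds plain integer class counts. A quick hedged check (my addition):

In [0]:
# Hedged check: with bootstrap=False, leaf values should be whole-number counts.
v = mreal.estimators_[0].tree_.value
print(np.array_equal(v, np.round(v)))  # expect True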

In [0]:
def dump_tree(tree, node, tree_id, n_classes_y, file):
    
    assert node >= 0
  
    left   = tree.children_left
    right  = tree.children_right # -1 is sentinel for none
    feats  = tree.feature # -2 is sentinel for none
    thresh = tree.threshold
    values = tree.value

    for i in range(n_classes_y):
        print('    wire [7:0] n_{}_{}_{};'.format(tree_id, node, i), file=file)
    
    if feats[node] == -2: # leaf node
        assert left[node] == -1
        assert right[node] == -1
        # tree.value has shape (n_nodes, n_outputs, n_classes); the extra
        # leading dimension exists to support multi-output trees
        assert values[node].shape == (1, n_classes_y)
        class_probabilities = (values[node] / values[node].sum())[0]

        for i in range(n_classes_y):
            p_float = class_probabilities[i]
            # quantize to unsigned 8-bit fixed point (0..255), rounding to nearest
            p_fixed = int(p_float * 255. + 0.5)
            print('    assign n_{}_{}_{} = 8\'h{:x}; // {}'.format(tree_id, node, i, p_fixed, p_float), file=file)
        return
    else:
        assert left[node] != -1
        assert right[node] != -1
        dump_tree(tree, left[node],  tree_id, n_classes_y, file=file)
        dump_tree(tree, right[node], tree_id, n_classes_y, file=file)

        print('    wire c_{}_{};'.format(tree_id, node), file=file)

        assert 0. <= thresh[node]
        assert thresh[node] < 255.
        # truncating the threshold to an int is safe because the features are
        # integers: for integer x, x <= 127.5 iff x <= 127
        threshold = int(thresh[node])

        print('    assign c_{}_{} = x{} <= 8\'h{:x};'.format(tree_id, node, feats[node], threshold), file=file)
        
        for i in range(n_classes_y):        
            print('    assign n_{}_{}_{} = c_{}_{} ? n_{}_{}_{} : n_{}_{}_{};'.format(
                    tree_id, node, i, 
                    tree_id, node, 
                    tree_id, left[node], i, 
                    tree_id, right[node], i), 
                file=file)


def dump_verilog(model, width_x, n_classes_y):
    with open('output.v', 'w') as f:
        print("module forest(", file=f)
        for i in range(width_x):
            # every input gets a trailing comma because the outputs follow
            print("    input wire [7:0] x{},".format(i), file=f)
        for i in range(n_classes_y):
            print("    output wire [15:0] y{}{}".format(i, ',' if i < n_classes_y - 1 else ''), file=f)       
        print("    );", file=f)
      
        for i, estimator in enumerate(model.estimators_):
            print('    // dumping tree {}'.format(i), file=f)
            dump_tree(estimator.tree_, node=0, tree_id=i, n_classes_y=n_classes_y, file=f)    

            for c in range(n_classes_y):
                # e_* zero-extends the tree's 8-bit root probability to 16 bits;
                # s_* is the running sum of those contributions across trees
                print('    wire [15:0] s_{}_{};'.format(i, c), file=f)
                print('    wire [15:0] e_{}_{};'.format(i, c), file=f)
                print('    assign e_{}_{} = {{8\'h0, n_{}_0_{}}};'.format(i, c, i, c), file=f)
                if i > 0:
                    print('    assign s_{}_{} = s_{}_{} + e_{}_{};'.format(i, c, i - 1, c, i, c), file=f)
                else:
                    print('    assign s_{}_{} = e_{}_{};'.format(i, c, i, c), file=f)

        for c in range(n_classes_y):
            print('    assign y{} = s_{}_{};'.format(c, len(model.estimators_) - 1, c), file=f)
            
        print("endmodule", file=f)
        
    
dump_verilog(mreal, width_x=784, n_classes_y=10)
!head output.v
# verilator can take ~3 minutes to lint the resulting Verilog file if 10 trees and unlimited depth are used!
# !verilator output.v --lint-only
# !abc/abc -c "%read output.v; %blast; &ps; &put; write test_syn.v"
#!cat test_syn.v


module forest(
    input wire [7:0] x0,
    input wire [7:0] x1,
    input wire [7:0] x2,
    input wire [7:0] x3,
    input wire [7:0] x4,
    input wire [7:0] x5,
    input wire [7:0] x6,
    input wire [7:0] x7,
    input wire [7:0] x8,
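
The netlist computes, per class, a 16-bit sum of 8-bit quantized leaf probabilities. A minimal sketch (my addition, mirroring tree_predict's traversal and dump_tree's int(p * 255. + 0.5) quantization) to cross-check that the fixed-point argmax agrees with the float model:

In [0]:
# Hedged sanity check: emulate the netlist's fixed-point arithmetic in Python.
def tree_predict_fixed(tree, node, x):
    left, right = tree.children_left, tree.children_right
    feats, thresh, values = tree.feature, tree.threshold, tree.value
    if feats[node] == -2:  # leaf: quantize exactly as dump_tree does
        probs = (values[node] / values[node].sum()).reshape(-1)
        return np.array([int(p * 255. + 0.5) for p in probs], dtype=np.int32)
    if x[feats[node]] <= int(thresh[node]):
        return tree_predict_fixed(tree, left[node], x)
    return tree_predict_fixed(tree, right[node], x)

def forest_predict_fixed(model, x):
    # a 16-bit accumulator cannot overflow here: the sum is <= n_estimators * 255
    res = sum(tree_predict_fixed(e.tree_, 0, x) for e in model.estimators_)
    return res.argmax()

(tx, _), _ = tf.keras.datasets.mnist.load_data()
tx = tx.reshape(60000, -1)
fixed = np.array([forest_predict_fixed(mreal, ex) for ex in tx[:100]])
print((fixed == mreal.predict(tx[:100])).mean())  # expect ~1.0; quantization may flip near-ties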

In [0]:
# ABC limitations:
# - `read` silently fails on this file, whereas `%read` works
# - if a primary output (PO) is left undriven, an assertion fails in `%blast`
# Verilator limitations:
# - on malformed input, verilator can sometimes hang rather than report an error!
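
Since an undriven PO makes %blast assert, a quick hedged check (my addition, using the already-imported re) that every declared y output in output.v has a driver:

In [0]:
# Hedged helper: confirm every y<c> output port has a matching assign.
with open('output.v') as f:
    src = f.read()
declared = set(re.findall(r'output wire \[15:0\] y(\d+)', src))
driven   = set(re.findall(r'assign y(\d+) =', src))
print('undriven POs:', sorted(declared - driven))  # expect []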