Experimental code to produce an ABC-synthesizable netlist from a scikit-learn random forest. Untested; proceed with caution. author: schatter@google.com
In [0]:
#%tensorflow_version 1.x
import os
import re
import numpy as np
# TF is used only to read MNIST data
import tensorflow as tf
import sklearn
from sklearn.ensemble import RandomForestClassifier
print('The scikit-learn version is {}.'.format(sklearn.__version__))
print('The TF version is {}.'.format(tf.__version__))
In [0]:
def tree_predict(tree, node, x):
    assert node >= 0
    left = tree.children_left
    right = tree.children_right  # -1 is sentinel for none
    feats = tree.feature  # -2 is sentinel for none
    thresh = tree.threshold
    values = tree.value
    if feats[node] == -2:  # leaf node
        assert left[node] == -1
        assert right[node] == -1
        return values[node] / values[node].sum()
    else:
        assert left[node] != -1
        assert right[node] != -1
        # note: we are int'ing the threshold since we don't think it matters
        # as the features are all ints anyway
        if x[feats[node]] <= int(thresh[node]):
            return tree_predict(tree, left[node], x)
        else:
            return tree_predict(tree, right[node], x)

def forest_predict(model, x, debug=False):
    res = tree_predict(model.estimators_[0].tree_, 0, x)
    for estimator in model.estimators_[1:]:
        res += tree_predict(estimator.tree_, 0, x)
    if debug:
        print(res.reshape(-1).astype(np.int32))
    return res.reshape(-1).argmax()

def predict_all(model, examples):
    # predicted class for each example
    return np.array([forest_predict(model, example) for example in examples])
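The threshold-flooring trick above is easy to sanity-check: for an integer x and a real threshold t >= 0, x <= t holds exactly when x <= int(t). A minimal sketch (my addition, assuming uint8 pixel features as in MNIST):
In [0]:
# Sanity sketch (assumption: features are integer-valued uint8 pixels):
# flooring the split threshold can never flip a comparison on integers.
xs = np.arange(256)
for t in [0.5, 10.0, 127.5, 254.9]:
    assert ((xs <= t) == (xs <= int(t))).all()
print('flooring thresholds is decision-equivalent for integer features')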
In [0]:
def generate(name, randomize_labels, nverify=1000):
    (tx, ty), (vx, vy) = tf.keras.datasets.mnist.load_data()
    if randomize_labels:
        ty = np.roll(ty, 127)  # np.random.permutation(ty)
    tx = tx.reshape(60000, -1)
    vx = vx.reshape(10000, -1)
    # note: we turn off bootstrap so that samples are taken without
    # resampling; as a result sample weights are always 1 and inference
    # is simpler
    # m = RandomForestClassifier(n_estimators=10, bootstrap=False, random_state=0)
    # TODO: tiny tree
    m = RandomForestClassifier(n_estimators=2, max_depth=3, bootstrap=False,
                               random_state=0)
    m.fit(tx, ty)
    print(m)
    print("name = {}, ta = {}, va = {}".format(name, m.score(tx, ty),
                                               m.score(vx, vy)))
    nverify = min(60000, nverify)
    mine = predict_all(m, tx[:nverify])
    golden = m.predict(tx[:nverify])
    assert (mine == golden).all()
    # print(np.arange(nverify)[mine != golden])
    print("verified")
    # write_model(m, name)
    # print("done writing {}".format(name))
    return m

mreal = generate('real', randomize_labels=False)
mrand = generate('random', randomize_labels=True)
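Before generating Verilog, it can help to eyeball the flat tree arrays that tree_predict above and dump_tree below consume; a minimal inspection sketch (my addition, using the mreal model just fitted):
In [0]:
# Inspection sketch: these are public attributes of sklearn's tree_ object.
t = mreal.estimators_[0].tree_
print('node_count    :', t.node_count)
print('children_left :', t.children_left)   # -1 at leaves
print('children_right:', t.children_right)  # -1 at leaves
print('feature       :', t.feature)         # -2 at leaves
print('threshold     :', t.threshold)
print('value.shape   :', t.value.shape)     # (n_nodes, n_outputs, n_classes)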
In [0]:
def dump_tree(tree, node, tree_id, n_classes_y, file):
    assert node >= 0
    left = tree.children_left
    right = tree.children_right  # -1 is sentinel for none
    feats = tree.feature  # -2 is sentinel for none
    thresh = tree.threshold
    values = tree.value
    for i in range(n_classes_y):
        print(' wire [7:0] n_{}_{}_{};'.format(tree_id, node, i), file=file)
    if feats[node] == -2:  # leaf node
        assert left[node] == -1
        assert right[node] == -1
        # print(' wire [7:0] n{};'.format(node), file=file)
        # tree.value has shape (n_nodes, n_outputs, n_classes); the extra
        # middle axis exists to support multi-output problems
        assert values[node].shape == (1, n_classes_y)
        class_probabilities = (values[node] / values[node].sum())[0]
        for i in range(n_classes_y):
            p_float = class_probabilities[i]
            p_fixed = int(p_float * 255. + 0.5)  # round to 8-bit fixed point
            print(' assign n_{}_{}_{} = 8\'h{:x}; // {}'.format(
                tree_id, node, i, p_fixed, p_float), file=file)
        return
    else:
        assert left[node] != -1
        assert right[node] != -1
        # note: we are int'ing the threshold since we don't think it matters
        # as the features are all ints anyway
        dump_tree(tree, left[node], tree_id, n_classes_y, file=file)
        dump_tree(tree, right[node], tree_id, n_classes_y, file=file)
        # for i in range(n_classes_y):
        #     print(' wire [7:0] n{}_{};'.format(node, i), file=file)
        print(' wire c_{}_{};'.format(tree_id, node), file=file)
        assert 0. <= thresh[node]
        assert thresh[node] < 255.
        threshold = int(thresh[node])
        print(' assign c_{}_{} = x{} <= 8\'h{:x};'.format(
            tree_id, node, feats[node], threshold), file=file)
        for i in range(n_classes_y):
            print(' assign n_{}_{}_{} = c_{}_{} ? n_{}_{}_{} : n_{}_{}_{};'.format(
                tree_id, node, i,
                tree_id, node,
                tree_id, left[node], i,
                tree_id, right[node], i),
                file=file)

def dump_verilog(model, width_x, n_classes_y):
    with open('output.v', 'w') as f:
        print("module forest(", file=f)
        for i in range(width_x):
            print(" input wire [7:0] x{}{}".format(i, ','), file=f)
        for i in range(n_classes_y):
            print(" output wire [15:0] y{}{}".format(
                i, ',' if i < n_classes_y - 1 else ''), file=f)
        print(" );", file=f)
        for i, estimator in enumerate(model.estimators_):
            print(' // dumping tree {}'.format(i), file=f)
            dump_tree(estimator.tree_, node=0, tree_id=i,
                      n_classes_y=n_classes_y, file=f)
            for c in range(n_classes_y):
                print(' wire [15:0] s_{}_{};'.format(i, c), file=f)
                print(' wire [15:0] e_{}_{};'.format(i, c), file=f)
                # zero-extend the 8-bit root value to 16 bits via concatenation
                print(' assign e_{}_{} = {} 8\'h0, n_{}_0_{} {};'.format(
                    i, c, '{', i, c, '}'), file=f)
                if i > 0:
                    print(' assign s_{}_{} = s_{}_{} + e_{}_{};'.format(
                        i, c, i - 1, c, i, c), file=f)
                else:
                    print(' assign s_{}_{} = e_{}_{};'.format(i, c, i, c), file=f)
        for c in range(n_classes_y):
            print(' assign y{} = s_{}_{};'.format(
                c, len(model.estimators_) - 1, c), file=f)
        print("endmodule", file=f)

dump_verilog(mreal, width_x=784, n_classes_y=10)
!head output.v
# verilator can take 3 minutes to lint the resulting Verilog file if 10 trees
# and unlimited depth are used!
# !verilator output.v --lint-only
# !abc/abc -c "%read output.v; %blast; &ps; &put; write test_syn.v"
#!cat test_syn.v
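The leaf constants above are 8-bit fixed point, rounded to the nearest multiple of 1/255 by int(p_float * 255. + 0.5); a minimal sketch (my addition) of the worst-case quantization error this introduces per leaf:
In [0]:
# Quantization-error sketch for the 8-bit fixed-point leaf probabilities.
ps = np.linspace(0., 1., 10001)
p_fixed = np.floor(ps * 255. + 0.5)  # same rounding as dump_tree
err = np.abs(p_fixed / 255. - ps)
print('max abs error:', err.max())   # bounded by 1/(2*255) ~ 0.00196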
In [0]:
# ABC limitations:
# - `read` silently fails, whereas `%read` works
# - if a PO (primary output) is not driven, an assertion fails in %blast
# Verilator limitations:
# - sometimes, when the input is bad, verilator may get stuck!