sklearn-porter

Repository: https://github.com/nok/sklearn-porter

DecisionTreeClassifier

Documentation: sklearn.tree.DecisionTreeClassifier


In [1]:
import sys
sys.path.append('../../../../..')

Load data


In [2]:
from sklearn.datasets import load_iris

iris_data = load_iris()

X = iris_data.data
y = iris_data.target

print(X.shape, y.shape)


((150, 4), (150,))

Train classifier


In [3]:
from sklearn.tree import tree

clf = tree.DecisionTreeClassifier()
clf.fit(X, y)


Out[3]:
DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,
            max_features=None, max_leaf_nodes=None,
            min_impurity_decrease=0.0, min_impurity_split=None,
            min_samples_leaf=1, min_samples_split=2,
            min_weight_fraction_leaf=0.0, presort=False, random_state=None,
            splitter='best')

Transpile classifier


In [4]:
from sklearn_porter import Porter

porter = Porter(clf, language='go')
output = porter.export(embed_data=True)

print(output)


package main

import (
	"os"
	"fmt"
	"strconv"
)

func predict(features []float64) int {
	var classes [3]float64
		
	if features[2] <= 2.449999988079071 {
		classes[0] = 50
		classes[1] = 0
		classes[2] = 0
	} else {
		if features[3] <= 1.75 {
			if features[2] <= 4.950000047683716 {
				if features[3] <= 1.6500000357627869 {
					classes[0] = 0
					classes[1] = 47
					classes[2] = 0
				} else {
					classes[0] = 0
					classes[1] = 0
					classes[2] = 1
				}
			} else {
				if features[3] <= 1.550000011920929 {
					classes[0] = 0
					classes[1] = 0
					classes[2] = 3
				} else {
					if features[2] <= 5.450000047683716 {
						classes[0] = 0
						classes[1] = 2
						classes[2] = 0
					} else {
						classes[0] = 0
						classes[1] = 0
						classes[2] = 1
					}
				}
			}
		} else {
			if features[2] <= 4.8500001430511475 {
				if features[1] <= 3.100000023841858 {
					classes[0] = 0
					classes[1] = 0
					classes[2] = 2
				} else {
					classes[0] = 0
					classes[1] = 1
					classes[2] = 0
				}
			} else {
				classes[0] = 0
				classes[1] = 0
				classes[2] = 43
			}
		}
	}

    var index int = 0
	for i := 0; i < len(classes); i++ {
	    if classes[i] > classes[index] {
	        index = i
	    }
	}
	return index
}

func main() {

	// Features:
	var features []float64
	for _, arg := range os.Args[1:] {
		if n, err := strconv.ParseFloat(arg, 64); err == nil {
			features = append(features, n)
		}
	}

	// Prediction:
	var estimation = predict(features)
	fmt.Printf("%d\n", estimation)

}