sklearn-porter

Repository: https://github.com/nok/sklearn-porter

DecisionTreeClassifier

Documentation: sklearn.tree.DecisionTreeClassifier


In [1]:
import sys
sys.path.append('../../../../..')  # make the in-repo sklearn_porter package importable

Load data


In [2]:
from sklearn.datasets import load_iris

iris_data = load_iris()

X = iris_data.data
y = iris_data.target

print(X.shape, y.shape)


((150, 4), (150,))

Train classifier


In [3]:
from sklearn import tree

clf = tree.DecisionTreeClassifier()
clf.fit(X, y)


Out[3]:
DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,
            max_features=None, max_leaf_nodes=None,
            min_impurity_decrease=0.0, min_impurity_split=None,
            min_samples_leaf=1, min_samples_split=2,
            min_weight_fraction_leaf=0.0, presort=False, random_state=None,
            splitter='best')
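
A quick sanity check before transpiling: since the tree was fitted on the full iris dataset, it should reproduce the training labels almost perfectly. Something along these lines (plain scikit-learn calls, not part of the exported code) gives a baseline to compare the generated estimator against:

print(clf.score(X, y))       # mean accuracy on the training data
print(clf.predict(X[:3]))    # predictions for the first three samples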

Transpile classifier


In [4]:
from sklearn_porter import Porter

porter = Porter(clf, language='ruby')
output = porter.export(embed_data=True)

print(output)


class DecisionTreeClassifier
    def self.predict(features)
        classes = Array.new(3, 0)

        if features[2] <= 2.449999988079071
            classes[0] = 50 
            classes[1] = 0 
            classes[2] = 0 
        else
            if features[3] <= 1.75
                if features[2] <= 4.950000047683716
                    if features[3] <= 1.6500000357627869
                        classes[0] = 0 
                        classes[1] = 47 
                        classes[2] = 0 
                    else
                        classes[0] = 0 
                        classes[1] = 0 
                        classes[2] = 1 
                    end
                else
                    if features[3] <= 1.550000011920929
                        classes[0] = 0 
                        classes[1] = 0 
                        classes[2] = 3 
                    else
                        if features[2] <= 5.450000047683716
                            classes[0] = 0 
                            classes[1] = 2 
                            classes[2] = 0 
                        else
                            classes[0] = 0 
                            classes[1] = 0 
                            classes[2] = 1 
                        end
                    end
                end
            else
                if features[2] <= 4.8500001430511475
                    if features[1] <= 3.100000023841858
                        classes[0] = 0 
                        classes[1] = 0 
                        classes[2] = 2 
                    else
                        classes[0] = 0 
                        classes[1] = 1 
                        classes[2] = 0 
                    end
                else
                    classes[0] = 0 
                    classes[1] = 0 
                    classes[2] = 43 
                end
            end
        end
    
        pos = classes.each_with_index.select { |e, i| e == classes.max }.map(&:last)
        return pos.min
    end
end

if ARGV.length == 4

	# Features:
	features = ARGV.collect { |i| i.to_f }

	# Prediction:
	puts DecisionTreeClassifier.predict(features)

end
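
One way to exercise the transpiled estimator is to write the printed Ruby source to a file and call it from Python with a local Ruby interpreter, comparing the result against scikit-learn. The file name estimator.rb and the subprocess call below are assumptions made for this sketch, not something sklearn-porter requires:

import subprocess

# persist the transpiled Ruby source generated above
with open('estimator.rb', 'w') as f:
    f.write(output)

# the generated script expects the four iris features as command-line arguments
sample = X[0]
result = subprocess.check_output(
    ['ruby', 'estimator.rb'] + [str(v) for v in sample])

print('Ruby prediction:', int(result.strip()))
print('scikit-learn prediction:', clf.predict([sample])[0])

Both calls should print the same class index for the same sample, since the tree structure is embedded verbatim in the Ruby code.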