sklearn-porter

Repository: https://github.com/nok/sklearn-porter

NuSVC

Documentation: sklearn.svm.NuSVC


In [1]:
import sys
# Add the directory five levels above the working directory to the module
# search path. Presumably this is the repository root, so that the local
# `sklearn_porter` package imported further down resolves to the checkout
# rather than an installed copy — TODO confirm against the repo layout.
sys.path.append('../../../../..')

Load data


In [2]:
from sklearn.datasets import load_iris

# Load the Iris dataset bundled with scikit-learn and split it into the
# feature matrix (X) and the class-label vector (y).
iris_data = load_iris()
X, y = iris_data.data, iris_data.target

# Sanity check: report the dimensions of both arrays.
print(X.shape, y.shape)


((150, 4), (150,))

Train classifier


In [3]:
from sklearn import svm

# Nu-support vector classifier with an RBF kernel; the seed is fixed so the
# fitted model (and therefore the exported Java source) is reproducible.
clf = svm.NuSVC(
    kernel='rbf',
    gamma=0.001,
    random_state=0,
)
# Keep the fit call as the last expression so the estimator repr is displayed.
clf.fit(X, y)


Out[3]:
NuSVC(cache_size=200, class_weight=None, coef0=0.0,
   decision_function_shape='ovr', degree=3, gamma=0.001, kernel='rbf',
   max_iter=-1, nu=0.5, probability=False, random_state=0, shrinking=True,
   tol=0.001, verbose=False)

Transpile classifier


In [4]:
from sklearn_porter import Porter

# Wrap the fitted estimator and request Java as the target language.
porter = Porter(clf, language='java')
# `export()` returns the generated source as a string (it is printed below and
# can be written to NuSVC.java, see the last cell).
output = porter.export()

# Show the generated Java class.
print(output)


class NuSVC {

    private enum Kernel { LINEAR, POLY, RBF, SIGMOID }

    private int nClasses;
    private int nRows;
    private int[] classes;
    private double[][] vectors;
    private double[][] coefficients;
    private double[] intercepts;
    private int[] weights;
    private Kernel kernel;
    private double gamma;
    private double coef0;
    private double degree;

    public NuSVC (int nClasses, int nRows, double[][] vectors, double[][] coefficients, double[] intercepts, int[] weights, String kernel, double gamma, double coef0, double degree) {
        this.nClasses = nClasses;
        this.classes = new int[nClasses];
        for (int i = 0; i < nClasses; i++) {
            this.classes[i] = i;
        }
        this.nRows = nRows;

        this.vectors = vectors;
        this.coefficients = coefficients;
        this.intercepts = intercepts;
        this.weights = weights;

        this.kernel = Kernel.valueOf(kernel.toUpperCase());
        this.gamma = gamma;
        this.coef0 = coef0;
        this.degree = degree;
    }

    public int predict(double[] features) {
    
        double[] kernels = new double[vectors.length];
        double kernel;
        switch (this.kernel) {
            case LINEAR:
                // <x,x'>
                for (int i = 0; i < this.vectors.length; i++) {
                    kernel = 0.;
                    for (int j = 0; j < this.vectors[i].length; j++) {
                        kernel += this.vectors[i][j] * features[j];
                    }
                    kernels[i] = kernel;
                }
                break;
            case POLY:
                // (y<x,x'>+r)^d
                for (int i = 0; i < this.vectors.length; i++) {
                    kernel = 0.;
                    for (int j = 0; j < this.vectors[i].length; j++) {
                        kernel += this.vectors[i][j] * features[j];
                    }
                    kernels[i] = Math.pow((this.gamma * kernel) + this.coef0, this.degree);
                }
                break;
            case RBF:
                // exp(-y|x-x'|^2)
                for (int i = 0; i < this.vectors.length; i++) {
                    kernel = 0.;
                    for (int j = 0; j < this.vectors[i].length; j++) {
                        kernel += Math.pow(this.vectors[i][j] - features[j], 2);
                    }
                    kernels[i] = Math.exp(-this.gamma * kernel);
                }
                break;
            case SIGMOID:
                // tanh(y<x,x'>+r)
                for (int i = 0; i < this.vectors.length; i++) {
                    kernel = 0.;
                    for (int j = 0; j < this.vectors[i].length; j++) {
                        kernel += this.vectors[i][j] * features[j];
                    }
                    kernels[i] = Math.tanh((this.gamma * kernel) + this.coef0);
                }
                break;
        }
    
        int[] starts = new int[this.nRows];
        for (int i = 0; i < this.nRows; i++) {
            if (i != 0) {
                int start = 0;
                for (int j = 0; j < i; j++) {
                    start += this.weights[j];
                }
                starts[i] = start;
            } else {
                starts[0] = 0;
            }
        }
    
        int[] ends = new int[this.nRows];
        for (int i = 0; i < this.nRows; i++) {
            ends[i] = this.weights[i] + starts[i];
        }
    
        if (this.nClasses == 2) {
    
            for (int i = 0; i < kernels.length; i++) {
                kernels[i] = -kernels[i];
            }
    
            double decision = 0.;
            for (int k = starts[1]; k < ends[1]; k++) {
                decision += kernels[k] * this.coefficients[0][k];
            }
            for (int k = starts[0]; k < ends[0]; k++) {
                decision += kernels[k] * this.coefficients[0][k];
            }
            decision += this.intercepts[0];
    
            if (decision > 0) {
                return 0;
            }
            return 1;
    
        }
    
        double[] decisions = new double[this.intercepts.length];
        for (int i = 0, d = 0, l = this.nRows; i < l; i++) {
            for (int j = i + 1; j < l; j++) {
                double tmp = 0.;
                for (int k = starts[j]; k < ends[j]; k++) {
                    tmp += this.coefficients[i][k] * kernels[k];
                }
                for (int k = starts[i]; k < ends[i]; k++) {
                    tmp += this.coefficients[j - 1][k] * kernels[k];
                }
                decisions[d] = tmp + this.intercepts[d];
                d++;
            }
        }
    
        int[] votes = new int[this.intercepts.length];
        for (int i = 0, d = 0, l = this.nRows; i < l; i++) {
            for (int j = i + 1; j < l; j++) {
                votes[d] = decisions[d] > 0 ? i : j;
                d++;
            }
        }
    
        int[] amounts = new int[this.nClasses];
        for (int i = 0, l = votes.length; i < l; i++) {
            amounts[votes[i]] += 1;
        }
    
        int classVal = -1, classIdx = -1;
        for (int i = 0, l = amounts.length; i < l; i++) {
            if (amounts[i] > classVal) {
                classVal = amounts[i];
                classIdx= i;
            }
        }
        return this.classes[classIdx];
    
    }

    public static void main(String[] args) {
        if (args.length == 4) {

            // Features:
            double[] features = new double[args.length];
            for (int i = 0, l = args.length; i < l; i++) {
                features[i] = Double.parseDouble(args[i]);
            }

            // Parameters:
            double[][] vectors = {{4.9, 3.0, 1.4, 0.2}, {4.6, 3.1, 1.5, 0.2}, {5.4, 3.9, 1.7, 0.4}, {5.0, 3.4, 1.5, 0.2}, {4.9, 3.1, 1.5, 0.1}, {5.4, 3.7, 1.5, 0.2}, {4.8, 3.4, 1.6, 0.2}, {5.7, 4.4, 1.5, 0.4}, {5.7, 3.8, 1.7, 0.3}, {5.1, 3.8, 1.5, 0.3}, {5.4, 3.4, 1.7, 0.2}, {5.1, 3.7, 1.5, 0.4}, {5.1, 3.3, 1.7, 0.5}, {4.8, 3.4, 1.9, 0.2}, {5.0, 3.0, 1.6, 0.2}, {5.0, 3.4, 1.6, 0.4}, {5.2, 3.5, 1.5, 0.2}, {4.7, 3.2, 1.6, 0.2}, {4.8, 3.1, 1.6, 0.2}, {5.4, 3.4, 1.5, 0.4}, {4.9, 3.1, 1.5, 0.2}, {5.1, 3.4, 1.5, 0.2}, {4.5, 2.3, 1.3, 0.3}, {5.0, 3.5, 1.6, 0.6}, {5.1, 3.8, 1.9, 0.4}, {4.8, 3.0, 1.4, 0.3}, {5.1, 3.8, 1.6, 0.2}, {5.3, 3.7, 1.5, 0.2}, {7.0, 3.2, 4.7, 1.4}, {6.4, 3.2, 4.5, 1.5}, {6.9, 3.1, 4.9, 1.5}, {5.5, 2.3, 4.0, 1.3}, {6.5, 2.8, 4.6, 1.5}, {5.7, 2.8, 4.5, 1.3}, {6.3, 3.3, 4.7, 1.6}, {4.9, 2.4, 3.3, 1.0}, {6.6, 2.9, 4.6, 1.3}, {5.2, 2.7, 3.9, 1.4}, {5.0, 2.0, 3.5, 1.0}, {5.9, 3.0, 4.2, 1.5}, {6.0, 2.2, 4.0, 1.0}, {6.1, 2.9, 4.7, 1.4}, {5.6, 2.9, 3.6, 1.3}, {6.7, 3.1, 4.4, 1.4}, {5.6, 3.0, 4.5, 1.5}, {5.8, 2.7, 4.1, 1.0}, {6.2, 2.2, 4.5, 1.5}, {5.6, 2.5, 3.9, 1.1}, {5.9, 3.2, 4.8, 1.8}, {6.1, 2.8, 4.0, 1.3}, {6.3, 2.5, 4.9, 1.5}, {6.1, 2.8, 4.7, 1.2}, {6.6, 3.0, 4.4, 1.4}, {6.8, 2.8, 4.8, 1.4}, {6.7, 3.0, 5.0, 1.7}, {6.0, 2.9, 4.5, 1.5}, {5.7, 2.6, 3.5, 1.0}, {5.5, 2.4, 3.8, 1.1}, {5.5, 2.4, 3.7, 1.0}, {5.8, 2.7, 3.9, 1.2}, {6.0, 2.7, 5.1, 1.6}, {5.4, 3.0, 4.5, 1.5}, {6.0, 3.4, 4.5, 1.6}, {6.7, 3.1, 4.7, 1.5}, {6.3, 2.3, 4.4, 1.3}, {5.6, 3.0, 4.1, 1.3}, {5.5, 2.5, 4.0, 1.3}, {5.5, 2.6, 4.4, 1.2}, {6.1, 3.0, 4.6, 1.4}, {5.8, 2.6, 4.0, 1.2}, {5.0, 2.3, 3.3, 1.0}, {5.6, 2.7, 4.2, 1.3}, {5.7, 3.0, 4.2, 1.2}, {5.7, 2.9, 4.2, 1.3}, {6.2, 2.9, 4.3, 1.3}, {5.1, 2.5, 3.0, 1.1}, {5.7, 2.8, 4.1, 1.3}, {5.8, 2.7, 5.1, 1.9}, {6.3, 2.9, 5.6, 1.8}, {4.9, 2.5, 4.5, 1.7}, {6.5, 3.2, 5.1, 2.0}, {6.4, 2.7, 5.3, 1.9}, {5.7, 2.5, 5.0, 2.0}, {5.8, 2.8, 5.1, 2.4}, {6.4, 3.2, 5.3, 2.3}, {6.5, 3.0, 5.5, 1.8}, {6.0, 2.2, 5.0, 1.5}, {5.6, 2.8, 4.9, 2.0}, {6.3, 2.7, 4.9, 1.8}, {6.2, 
2.8, 4.8, 1.8}, {6.1, 3.0, 4.9, 1.8}, {7.2, 3.0, 5.8, 1.6}, {6.3, 2.8, 5.1, 1.5}, {6.1, 2.6, 5.6, 1.4}, {6.4, 3.1, 5.5, 1.8}, {6.0, 3.0, 4.8, 1.8}, {6.9, 3.1, 5.4, 2.1}, {6.9, 3.1, 5.1, 2.3}, {5.8, 2.7, 5.1, 1.9}, {6.7, 3.0, 5.2, 2.3}, {6.3, 2.5, 5.0, 1.9}, {6.5, 3.0, 5.2, 2.0}, {6.2, 3.4, 5.4, 2.3}, {5.9, 3.0, 5.1, 1.8}};
            double[][] coefficients = {{4.680538527007988, 4.680538527007988, 4.680538527007988, 4.680538527007988, 4.680538527007988, 4.680538527007988, 4.680538527007988, 0.0, 4.680538527007988, 0.0, 4.680538527007988, 4.680538527007988, 4.680538527007988, 4.680538527007988, 4.680538527007988, 4.680538527007988, 4.680538527007988, 4.680538527007988, 4.680538527007988, 4.680538527007988, 4.680538527007988, 4.680538527007988, 4.680538527007988, 4.680538527007988, 4.680538527007988, 4.680538527007988, 4.680538527007988, 0.0, -0.0, -0.0, -0.0, -4.680538527007988, -0.0, -0.0, -0.0, -4.680538527007988, -0.0, -4.680538527007988, -4.680538527007988, -4.680538527007988, -4.680538527007988, -0.0, -4.680538527007988, -0.0, -0.0, -4.680538527007988, -0.0, -4.680538527007988, -0.0, -4.680538527007988, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -4.680538527007988, -4.680538527007988, -4.680538527007988, -4.680538527007988, -0.0, -0.0, -0.0, -0.0, -0.0, -4.680538527007988, -4.680538527007988, -4.680538527007988, -0.0, -4.680538527007988, -4.680538527007988, -4.680538527007988, -4.680538527007988, -4.680538527007988, -4.680538527007988, -4.680538527007988, -4.680538527007988, -2.1228182659346366, -2.1228182659346366, -2.1228182659346366, -2.1228182659346366, -2.1228182659346366, -2.1228182659346366, -2.1228182659346366, -2.1228182659346366, -2.1228182659346366, -2.1228182659346366, -2.1228182659346366, -2.1228182659346366, -2.1228182659346366, -2.1228182659346366, -0.0, -2.1228182659346366, -2.1228182659346366, -2.1228182659346366, -2.1228182659346366, -0.0, -2.1228182659346366, -2.1228182659346366, -2.1228182659346366, -2.1228182659346366, -2.1228182659346366, -2.1228182659346366, -2.1228182659346366}, {0.0, 0.0, 2.1228182659346366, 2.1228182659346366, 2.1228182659346366, 2.1228182659346366, 2.1228182659346366, 2.1228182659346366, 2.1228182659346366, 2.1228182659346366, 2.1228182659346366, 2.1228182659346366, 2.1228182659346366, 2.1228182659346366, 2.1228182659346366, 
2.1228182659346366, 2.1228182659346366, 2.1228182659346366, 2.1228182659346366, 2.1228182659346366, 2.1228182659346366, 2.1228182659346366, 0.0, 2.1228182659346366, 2.1228182659346366, 2.1228182659346366, 2.1228182659346366, 2.1228182659346366, 47.52934177369389, 47.52934177369389, 47.52934177369389, 0.0, 47.52934177369389, 47.52934177369389, 47.52934177369389, 0.0, 47.52934177369389, 0.0, 0.0, 0.0, 0.0, 47.52934177369389, 0.0, 47.52934177369389, 47.52934177369389, 0.0, 47.52934177369389, 0.0, 47.52934177369389, 0.0, 47.52934177369389, 47.52934177369389, 47.52934177369389, 47.52934177369389, 47.52934177369389, 47.52934177369389, 0.0, 0.0, 0.0, 0.0, 47.52934177369389, 47.52934177369389, 47.52934177369389, 47.52934177369389, 47.52934177369389, 0.0, 0.0, 47.52934177369389, 47.52934177369389, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -47.52934177369389, -47.52934177369389, -47.52934177369389, -47.52934177369389, -47.52934177369389, -47.52934177369389, -47.52934177369389, -0.0, -47.52934177369389, -47.52934177369389, -47.52934177369389, -47.52934177369389, -47.52934177369389, -47.52934177369389, -47.52934177369389, -47.52934177369389, -47.52934177369389, -47.52934177369389, -47.52934177369389, -47.52934177369389, -47.52934177369389, -47.52934177369389, -47.52934177369389, -47.52934177369389, -47.52934177369389, -0.0, -47.52934177369389}};
            double[] intercepts = {0.09572808365772528, 0.049757317370245795, -0.08418168966801846};
            int[] weights = {28, 49, 27};

            // Prediction:
            NuSVC clf = new NuSVC(3, 3, vectors, coefficients, intercepts, weights, "rbf", 0.001, 0.0, 3);
            int estimation = clf.predict(features);
            System.out.println(estimation);

        }
    }
}

Run classification in Java


In [5]:
# Save classifier:
# with open('NuSVC.java', 'w') as f:
#     f.write(output)

# Compile model:
# $ javac -cp . NuSVC.java

# Run classification:
# $ java NuSVC 1 2 3 4