Repository: https://github.com/nok/sklearn-porter
Documentation: sklearn.svm.NuSVC
In [1]:
import sys
# Add the directory five levels up to the module search path.
# NOTE(review): presumably this makes a local sklearn_porter checkout
# importable for the later cell — confirm against the repository layout.
sys.path.append('../../../../..')
In [2]:
from sklearn.datasets import load_iris

# Load the Iris dataset, then pull out the feature matrix and label vector
# in one step and report their shapes.
iris_data = load_iris()
X, y = iris_data.data, iris_data.target
print(X.shape, y.shape)
((150, 4), (150,))
In [3]:
from sklearn import svm
# Configure a NuSVC with an RBF kernel, gamma=0.001 and a fixed random_state.
clf = svm.NuSVC(gamma=0.001, kernel='rbf', random_state=0)
# Fit on the full X/y loaded above. This is the cell's last expression, so the
# value returned by fit() is what the notebook displays as the cell output.
clf.fit(X, y)
Out[3]:
NuSVC(cache_size=200, class_weight=None, coef0=0.0,
decision_function_shape='ovr', degree=3, gamma=0.001, kernel='rbf',
max_iter=-1, nu=0.5, probability=False, random_state=0, shrinking=True,
tol=0.001, verbose=False)
In [4]:
from sklearn_porter import Porter
# Wrap the fitted classifier in a Porter that targets Java.
porter = Porter(clf, language='java')
# export() yields the generated source text; printing it shows the Java class.
output = porter.export()
print(output)
class NuSVC {
private enum Kernel { LINEAR, POLY, RBF, SIGMOID }
private int nClasses;
private int nRows;
private int[] classes;
private double[][] vectors;
private double[][] coefficients;
private double[] intercepts;
private int[] weights;
private Kernel kernel;
private double gamma;
private double coef0;
private double degree;
public NuSVC (int nClasses, int nRows, double[][] vectors, double[][] coefficients, double[] intercepts, int[] weights, String kernel, double gamma, double coef0, double degree) {
this.nClasses = nClasses;
this.classes = new int[nClasses];
for (int i = 0; i < nClasses; i++) {
this.classes[i] = i;
}
this.nRows = nRows;
this.vectors = vectors;
this.coefficients = coefficients;
this.intercepts = intercepts;
this.weights = weights;
this.kernel = Kernel.valueOf(kernel.toUpperCase());
this.gamma = gamma;
this.coef0 = coef0;
this.degree = degree;
}
public int predict(double[] features) {
double[] kernels = new double[vectors.length];
double kernel;
switch (this.kernel) {
case LINEAR:
// <x,x'>
for (int i = 0; i < this.vectors.length; i++) {
kernel = 0.;
for (int j = 0; j < this.vectors[i].length; j++) {
kernel += this.vectors[i][j] * features[j];
}
kernels[i] = kernel;
}
break;
case POLY:
// (y<x,x'>+r)^d
for (int i = 0; i < this.vectors.length; i++) {
kernel = 0.;
for (int j = 0; j < this.vectors[i].length; j++) {
kernel += this.vectors[i][j] * features[j];
}
kernels[i] = Math.pow((this.gamma * kernel) + this.coef0, this.degree);
}
break;
case RBF:
// exp(-y|x-x'|^2)
for (int i = 0; i < this.vectors.length; i++) {
kernel = 0.;
for (int j = 0; j < this.vectors[i].length; j++) {
kernel += Math.pow(this.vectors[i][j] - features[j], 2);
}
kernels[i] = Math.exp(-this.gamma * kernel);
}
break;
case SIGMOID:
// tanh(y<x,x'>+r)
for (int i = 0; i < this.vectors.length; i++) {
kernel = 0.;
for (int j = 0; j < this.vectors[i].length; j++) {
kernel += this.vectors[i][j] * features[j];
}
kernels[i] = Math.tanh((this.gamma * kernel) + this.coef0);
}
break;
}
int[] starts = new int[this.nRows];
for (int i = 0; i < this.nRows; i++) {
if (i != 0) {
int start = 0;
for (int j = 0; j < i; j++) {
start += this.weights[j];
}
starts[i] = start;
} else {
starts[0] = 0;
}
}
int[] ends = new int[this.nRows];
for (int i = 0; i < this.nRows; i++) {
ends[i] = this.weights[i] + starts[i];
}
if (this.nClasses == 2) {
for (int i = 0; i < kernels.length; i++) {
kernels[i] = -kernels[i];
}
double decision = 0.;
for (int k = starts[1]; k < ends[1]; k++) {
decision += kernels[k] * this.coefficients[0][k];
}
for (int k = starts[0]; k < ends[0]; k++) {
decision += kernels[k] * this.coefficients[0][k];
}
decision += this.intercepts[0];
if (decision > 0) {
return 0;
}
return 1;
}
double[] decisions = new double[this.intercepts.length];
for (int i = 0, d = 0, l = this.nRows; i < l; i++) {
for (int j = i + 1; j < l; j++) {
double tmp = 0.;
for (int k = starts[j]; k < ends[j]; k++) {
tmp += this.coefficients[i][k] * kernels[k];
}
for (int k = starts[i]; k < ends[i]; k++) {
tmp += this.coefficients[j - 1][k] * kernels[k];
}
decisions[d] = tmp + this.intercepts[d];
d++;
}
}
int[] votes = new int[this.intercepts.length];
for (int i = 0, d = 0, l = this.nRows; i < l; i++) {
for (int j = i + 1; j < l; j++) {
votes[d] = decisions[d] > 0 ? i : j;
d++;
}
}
int[] amounts = new int[this.nClasses];
for (int i = 0, l = votes.length; i < l; i++) {
amounts[votes[i]] += 1;
}
int classVal = -1, classIdx = -1;
for (int i = 0, l = amounts.length; i < l; i++) {
if (amounts[i] > classVal) {
classVal = amounts[i];
classIdx= i;
}
}
return this.classes[classIdx];
}
public static void main(String[] args) {
if (args.length == 4) {
// Features:
double[] features = new double[args.length];
for (int i = 0, l = args.length; i < l; i++) {
features[i] = Double.parseDouble(args[i]);
}
// Parameters:
double[][] vectors = {{4.9, 3.0, 1.4, 0.2}, {4.6, 3.1, 1.5, 0.2}, {5.4, 3.9, 1.7, 0.4}, {5.0, 3.4, 1.5, 0.2}, {4.9, 3.1, 1.5, 0.1}, {5.4, 3.7, 1.5, 0.2}, {4.8, 3.4, 1.6, 0.2}, {5.7, 4.4, 1.5, 0.4}, {5.7, 3.8, 1.7, 0.3}, {5.1, 3.8, 1.5, 0.3}, {5.4, 3.4, 1.7, 0.2}, {5.1, 3.7, 1.5, 0.4}, {5.1, 3.3, 1.7, 0.5}, {4.8, 3.4, 1.9, 0.2}, {5.0, 3.0, 1.6, 0.2}, {5.0, 3.4, 1.6, 0.4}, {5.2, 3.5, 1.5, 0.2}, {4.7, 3.2, 1.6, 0.2}, {4.8, 3.1, 1.6, 0.2}, {5.4, 3.4, 1.5, 0.4}, {4.9, 3.1, 1.5, 0.2}, {5.1, 3.4, 1.5, 0.2}, {4.5, 2.3, 1.3, 0.3}, {5.0, 3.5, 1.6, 0.6}, {5.1, 3.8, 1.9, 0.4}, {4.8, 3.0, 1.4, 0.3}, {5.1, 3.8, 1.6, 0.2}, {5.3, 3.7, 1.5, 0.2}, {7.0, 3.2, 4.7, 1.4}, {6.4, 3.2, 4.5, 1.5}, {6.9, 3.1, 4.9, 1.5}, {5.5, 2.3, 4.0, 1.3}, {6.5, 2.8, 4.6, 1.5}, {5.7, 2.8, 4.5, 1.3}, {6.3, 3.3, 4.7, 1.6}, {4.9, 2.4, 3.3, 1.0}, {6.6, 2.9, 4.6, 1.3}, {5.2, 2.7, 3.9, 1.4}, {5.0, 2.0, 3.5, 1.0}, {5.9, 3.0, 4.2, 1.5}, {6.0, 2.2, 4.0, 1.0}, {6.1, 2.9, 4.7, 1.4}, {5.6, 2.9, 3.6, 1.3}, {6.7, 3.1, 4.4, 1.4}, {5.6, 3.0, 4.5, 1.5}, {5.8, 2.7, 4.1, 1.0}, {6.2, 2.2, 4.5, 1.5}, {5.6, 2.5, 3.9, 1.1}, {5.9, 3.2, 4.8, 1.8}, {6.1, 2.8, 4.0, 1.3}, {6.3, 2.5, 4.9, 1.5}, {6.1, 2.8, 4.7, 1.2}, {6.6, 3.0, 4.4, 1.4}, {6.8, 2.8, 4.8, 1.4}, {6.7, 3.0, 5.0, 1.7}, {6.0, 2.9, 4.5, 1.5}, {5.7, 2.6, 3.5, 1.0}, {5.5, 2.4, 3.8, 1.1}, {5.5, 2.4, 3.7, 1.0}, {5.8, 2.7, 3.9, 1.2}, {6.0, 2.7, 5.1, 1.6}, {5.4, 3.0, 4.5, 1.5}, {6.0, 3.4, 4.5, 1.6}, {6.7, 3.1, 4.7, 1.5}, {6.3, 2.3, 4.4, 1.3}, {5.6, 3.0, 4.1, 1.3}, {5.5, 2.5, 4.0, 1.3}, {5.5, 2.6, 4.4, 1.2}, {6.1, 3.0, 4.6, 1.4}, {5.8, 2.6, 4.0, 1.2}, {5.0, 2.3, 3.3, 1.0}, {5.6, 2.7, 4.2, 1.3}, {5.7, 3.0, 4.2, 1.2}, {5.7, 2.9, 4.2, 1.3}, {6.2, 2.9, 4.3, 1.3}, {5.1, 2.5, 3.0, 1.1}, {5.7, 2.8, 4.1, 1.3}, {5.8, 2.7, 5.1, 1.9}, {6.3, 2.9, 5.6, 1.8}, {4.9, 2.5, 4.5, 1.7}, {6.5, 3.2, 5.1, 2.0}, {6.4, 2.7, 5.3, 1.9}, {5.7, 2.5, 5.0, 2.0}, {5.8, 2.8, 5.1, 2.4}, {6.4, 3.2, 5.3, 2.3}, {6.5, 3.0, 5.5, 1.8}, {6.0, 2.2, 5.0, 1.5}, {5.6, 2.8, 4.9, 2.0}, {6.3, 2.7, 4.9, 1.8}, {6.2, 2.8, 4.8, 
1.8}, {6.1, 3.0, 4.9, 1.8}, {7.2, 3.0, 5.8, 1.6}, {6.3, 2.8, 5.1, 1.5}, {6.1, 2.6, 5.6, 1.4}, {6.4, 3.1, 5.5, 1.8}, {6.0, 3.0, 4.8, 1.8}, {6.9, 3.1, 5.4, 2.1}, {6.9, 3.1, 5.1, 2.3}, {5.8, 2.7, 5.1, 1.9}, {6.7, 3.0, 5.2, 2.3}, {6.3, 2.5, 5.0, 1.9}, {6.5, 3.0, 5.2, 2.0}, {6.2, 3.4, 5.4, 2.3}, {5.9, 3.0, 5.1, 1.8}};
double[][] coefficients = {{4.680538527007988, 4.680538527007988, 4.680538527007988, 4.680538527007988, 4.680538527007988, 4.680538527007988, 4.680538527007988, 0.0, 4.680538527007988, 0.0, 4.680538527007988, 4.680538527007988, 4.680538527007988, 4.680538527007988, 4.680538527007988, 4.680538527007988, 4.680538527007988, 4.680538527007988, 4.680538527007988, 4.680538527007988, 4.680538527007988, 4.680538527007988, 4.680538527007988, 4.680538527007988, 4.680538527007988, 4.680538527007988, 4.680538527007988, 0.0, -0.0, -0.0, -0.0, -4.680538527007988, -0.0, -0.0, -0.0, -4.680538527007988, -0.0, -4.680538527007988, -4.680538527007988, -4.680538527007988, -4.680538527007988, -0.0, -4.680538527007988, -0.0, -0.0, -4.680538527007988, -0.0, -4.680538527007988, -0.0, -4.680538527007988, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -4.680538527007988, -4.680538527007988, -4.680538527007988, -4.680538527007988, -0.0, -0.0, -0.0, -0.0, -0.0, -4.680538527007988, -4.680538527007988, -4.680538527007988, -0.0, -4.680538527007988, -4.680538527007988, -4.680538527007988, -4.680538527007988, -4.680538527007988, -4.680538527007988, -4.680538527007988, -4.680538527007988, -2.1228182659346366, -2.1228182659346366, -2.1228182659346366, -2.1228182659346366, -2.1228182659346366, -2.1228182659346366, -2.1228182659346366, -2.1228182659346366, -2.1228182659346366, -2.1228182659346366, -2.1228182659346366, -2.1228182659346366, -2.1228182659346366, -2.1228182659346366, -0.0, -2.1228182659346366, -2.1228182659346366, -2.1228182659346366, -2.1228182659346366, -0.0, -2.1228182659346366, -2.1228182659346366, -2.1228182659346366, -2.1228182659346366, -2.1228182659346366, -2.1228182659346366, -2.1228182659346366}, {0.0, 0.0, 2.1228182659346366, 2.1228182659346366, 2.1228182659346366, 2.1228182659346366, 2.1228182659346366, 2.1228182659346366, 2.1228182659346366, 2.1228182659346366, 2.1228182659346366, 2.1228182659346366, 2.1228182659346366, 2.1228182659346366, 2.1228182659346366, 2.1228182659346366, 
2.1228182659346366, 2.1228182659346366, 2.1228182659346366, 2.1228182659346366, 2.1228182659346366, 2.1228182659346366, 0.0, 2.1228182659346366, 2.1228182659346366, 2.1228182659346366, 2.1228182659346366, 2.1228182659346366, 47.52934177369389, 47.52934177369389, 47.52934177369389, 0.0, 47.52934177369389, 47.52934177369389, 47.52934177369389, 0.0, 47.52934177369389, 0.0, 0.0, 0.0, 0.0, 47.52934177369389, 0.0, 47.52934177369389, 47.52934177369389, 0.0, 47.52934177369389, 0.0, 47.52934177369389, 0.0, 47.52934177369389, 47.52934177369389, 47.52934177369389, 47.52934177369389, 47.52934177369389, 47.52934177369389, 0.0, 0.0, 0.0, 0.0, 47.52934177369389, 47.52934177369389, 47.52934177369389, 47.52934177369389, 47.52934177369389, 0.0, 0.0, 47.52934177369389, 47.52934177369389, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -47.52934177369389, -47.52934177369389, -47.52934177369389, -47.52934177369389, -47.52934177369389, -47.52934177369389, -47.52934177369389, -0.0, -47.52934177369389, -47.52934177369389, -47.52934177369389, -47.52934177369389, -47.52934177369389, -47.52934177369389, -47.52934177369389, -47.52934177369389, -47.52934177369389, -47.52934177369389, -47.52934177369389, -47.52934177369389, -47.52934177369389, -47.52934177369389, -47.52934177369389, -47.52934177369389, -47.52934177369389, -0.0, -47.52934177369389}};
double[] intercepts = {0.09572808365772528, 0.049757317370245795, -0.08418168966801846};
int[] weights = {28, 49, 27};
// Prediction:
NuSVC clf = new NuSVC(3, 3, vectors, coefficients, intercepts, weights, "rbf", 0.001, 0.0, 3);
int estimation = clf.predict(features);
System.out.println(estimation);
}
}
}
In [5]:
# Save classifier:
# with open('NuSVC.java', 'w') as f:
# f.write(output)
# Compile model:
# $ javac -cp . NuSVC.java
# Run classification:
# $ java NuSVC 1 2 3 4
Content source: nok/sklearn-porter
Similar notebooks: