sklearn-porter

Repository: https://github.com/nok/sklearn-porter

SVC

Documentation: [sklearn.svm.SVC](http://scikit-learn.org/stable/modules/generated/sklearn.svm.SVC.html)


In [1]:
import sys

# Put the directory five levels above the notebook's working directory on the
# import path (presumably the repository root, so the local sklearn_porter
# checkout can be imported -- the path is relative to the current directory).
sys.path.extend(['../../../../..'])

Load data


In [2]:
from sklearn.datasets import load_iris

# Iris dataset bundled with scikit-learn: feature matrix X and class labels y.
iris_data = load_iris()
X, y = iris_data.data, iris_data.target

print(X.shape, y.shape)


((150, 4), (150,))

Train classifier


In [3]:
from sklearn import svm

# RBF-kernel support vector classifier. The fit() call is the cell's last
# expression, so the notebook echoes the fitted estimator below.
clf = svm.SVC(
    kernel='rbf',
    C=1.0,
    gamma=0.001,
    random_state=0,
)
clf.fit(X, y)


Out[3]:
SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
  decision_function_shape='ovr', degree=3, gamma=0.001, kernel='rbf',
  max_iter=-1, probability=False, random_state=0, shrinking=True,
  tol=0.001, verbose=False)

Transpile classifier


In [4]:
from sklearn_porter import Porter

# Transpile the fitted classifier into a standalone PHP program.
porter = Porter(clf, language='php')
output = porter.export()  # generated PHP source as a single string

print(output)


<?php

/**
 * Support vector classifier exported from a fitted scikit-learn SVC.
 *
 * Stores the support vectors, dual coefficients and intercepts of the model
 * and reproduces its class prediction for a single sample in plain PHP.
 *
 * NOTE(review): every field is assigned as a dynamic property (none is
 * declared on the class); PHP 8.2+ emits a deprecation notice for this.
 */
class SVC {

    /**
     * @param int    $nClasses     Number of classes; labels are 0 .. $nClasses - 1.
     * @param int    $nRows        Number of per-class groups of support vectors;
     *                             $weights must hold at least this many entries.
     * @param array  $vectors      Support vectors (one feature array each), stored
     *                             contiguously group by group.
     * @param array  $coefficients Dual coefficients indexed [row][supportVector].
     * @param array  $intercepts   One intercept per pair of classes.
     * @param array  $weights      Number of support vectors in each group.
     * @param string $kernel       Kernel name, case-insensitive: linear, poly, rbf
     *                             or sigmoid.
     * @param float  $gamma        Kernel coefficient (poly, rbf, sigmoid).
     * @param float  $coef0        Independent term (poly, sigmoid).
     * @param int    $degree       Degree of the poly kernel.
     */
    public function __construct($nClasses, $nRows, $vectors, $coefficients, $intercepts, $weights, $kernel, $gamma, $coef0, $degree) {
        $this->nClasses = $nClasses;
        // Class labels are simply the integers 0 .. $nClasses - 1.
        $this->classes = array_fill(0, $nClasses, 0);
        for ($i = 0; $i < $nClasses; $i++) {
            $this->classes[$i] = $i;
        }
        $this->nRows = $nRows;

        $this->vectors = $vectors;
        $this->coefficients = $coefficients;
        $this->intercepts = $intercepts;
        $this->weights = $weights;

        // Upper-cased so predict() can match the kernel name case-insensitively.
        $this->kernel = strtoupper($kernel);
        $this->gamma = $gamma;
        $this->coef0 = $coef0;
        $this->degree = $degree;
    }

    /**
     * Predict the class of one sample.
     *
     * @param array $features Feature values of the sample, in the same column
     *                        order as each support vector.
     * @return int Predicted class index in 0 .. nClasses - 1.
     */
    public function predict($features) {
    
        // Kernel value between the sample and every support vector.
        // In the formulas below y is gamma, r is coef0 and d is degree.
        $kernels = array_fill(0, count($this->vectors), 0);
        // NOTE(review): bare expression with no effect on the result; reading the
        // still-undefined $kernel may raise an "undefined variable" notice.
        $kernel;
        switch ($this->kernel) {
            case 'LINEAR':
                // <x,x'>
                for ($i = 0; $i < count($this->vectors); $i++) {
                    $kernel = 0.;
                    for ($j = 0; $j < count($this->vectors[$i]); $j++) {
                        $kernel += $this->vectors[$i][$j] * $features[$j];
                    }
                    $kernels[$i] = $kernel;
                }
                break;
            case 'POLY':
                // (y<x,x'>+r)^d
                for ($i = 0; $i < count($this->vectors); $i++) {
                    $kernel = 0.;
                    for ($j = 0; $j < count($this->vectors[$i]); $j++) {
                        $kernel += $this->vectors[$i][$j] * $features[$j];
                    }
                    $kernels[$i] = pow(($this->gamma * $kernel) + $this->coef0, $this->degree);
                }
                break;
            case 'RBF':
                // exp(-y|x-x'|^2)
                for ($i = 0; $i < count($this->vectors); $i++) {
                    $kernel = 0.;
                    for ($j = 0; $j < count($this->vectors[$i]); $j++) {
                        $kernel += pow($this->vectors[$i][$j] - $features[$j], 2);
                    }
                    $kernels[$i] = exp(-$this->gamma * $kernel);
                }
                break;
            case 'SIGMOID':
                // tanh(y<x,x'>+r)
                for ($i = 0; $i < count($this->vectors); $i++) {
                    $kernel = 0.;
                    for ($j = 0; $j < count($this->vectors[$i]); $j++) {
                        $kernel += $this->vectors[$i][$j] * $features[$j];
                    }
                    $kernels[$i] = tanh(($this->gamma * $kernel) + $this->coef0);
                }
                break;
        }
        // NOTE(review): an unrecognised kernel name matches no case above and
        // leaves every kernel value at 0.
    
        // starts[i]: index of the first support vector of group i, i.e. the sum
        // of $weights[0 .. i-1].
        $starts = array_fill(0, $this->nRows, 0);
        for ($i = 0; $i < $this->nRows; $i++) {
            if ($i != 0) {
                $start = 0;
                for ($j = 0; $j < $i; $j++) {
                    $start += $this->weights[$j];
                }
                $starts[$i] = $start;
            } else {
                $starts[0] = 0;
            }
        }
    
        // ends[i]: one past the last support vector of group i.
        $ends = array_fill(0, $this->nRows, 0);
        for ($i = 0; $i < $this->nRows; $i++) {
            $ends[$i] = $this->weights[$i] + $starts[$i];
        }
    
        // Binary problem: a single decision value over all support vectors.
        if ($this->nClasses == 2) {
    
            // NOTE(review): the kernel values are negated and a positive decision
            // maps to class 0; presumably this reproduces the sign convention of
            // the exported coefficients/intercept -- confirm against the source
            // estimator's predictions.
            for ($i = 0; $i < count($kernels); $i++) {
                $kernels[$i] = -$kernels[$i];
            }
    
            $decision = 0.;
            for ($k = $starts[1]; $k < $ends[1]; $k++) {
                $decision += $kernels[$k] * $this->coefficients[0][$k];
            }
            for ($k = $starts[0]; $k < $ends[0]; $k++) {
                $decision += $kernels[$k] * $this->coefficients[0][$k];
            }
            $decision += $this->intercepts[0];
    
            if ($decision > 0) {
                return 0;
            }
            return 1;
    
        }
    
        // One-vs-one: one decision value per class pair (i, j) with i < j,
        // enumerated in that order. Support vectors of class j are weighted by
        // $coefficients[i] and those of class i by $coefficients[j - 1].
        $decisions = array_fill(0, count($this->intercepts), 0);
        for ($i = 0, $d = 0, $l = $this->nRows; $i < $l; $i++) {
            for ($j = $i + 1; $j < $l; $j++) {
                $tmp = 0.;
                for ($k = $starts[$j]; $k < $ends[$j]; $k++) {
                    $tmp += $this->coefficients[$i][$k] * $kernels[$k];
                }
                for ($k = $starts[$i]; $k < $ends[$i]; $k++) {
                    $tmp += $this->coefficients[$j - 1][$k] * $kernels[$k];
                }
                $decisions[$d] = $tmp + $this->intercepts[$d];
                $d++;
            }
        }
    
        // Pair (i, j) votes for class i when its decision is positive, else for j.
        $votes = array_fill(0, count($this->intercepts), 0);
        for ($i = 0, $d = 0, $l = $this->nRows; $i < $l; $i++) {
            for ($j = $i + 1; $j < $l; $j++) {
                $votes[$d] = $decisions[$d] > 0 ? $i : $j;
                $d++;
            }
        }
    
        // Tally the votes per class.
        $amounts = array_fill(0, $this->nClasses, 0);
        for ($i = 0, $l = count($votes); $i < $l; $i++) {
            $amounts[$votes[$i]] += 1;
        }
    
        // Majority vote; the strict '>' resolves ties in favour of the lowest
        // class index.
        $classVal = -1;
        $classIdx = -1;
        for ($i = 0, $l = count($amounts); $i < $l; $i++) {
            if ($amounts[$i] > $classVal) {
                $classVal = $amounts[$i];
                $classIdx = $i;
            }
        }
        return $this->classes[$classIdx];
    
    }

}

if ($argc > 1) {

    // Features:
    array_shift($argv);
    $features = $argv;

    // Parameters:
    $vectors = [[5.1, 3.5, 1.4, 0.2], [4.9, 3.0, 1.4, 0.2], [4.7, 3.2, 1.3, 0.2], [4.6, 3.1, 1.5, 0.2], [5.0, 3.6, 1.4, 0.2], [5.4, 3.9, 1.7, 0.4], [4.6, 3.4, 1.4, 0.3], [5.0, 3.4, 1.5, 0.2], [4.4, 2.9, 1.4, 0.2], [4.9, 3.1, 1.5, 0.1], [5.4, 3.7, 1.5, 0.2], [4.8, 3.4, 1.6, 0.2], [4.8, 3.0, 1.4, 0.1], [4.3, 3.0, 1.1, 0.1], [5.8, 4.0, 1.2, 0.2], [5.7, 4.4, 1.5, 0.4], [5.4, 3.9, 1.3, 0.4], [5.1, 3.5, 1.4, 0.3], [5.7, 3.8, 1.7, 0.3], [5.1, 3.8, 1.5, 0.3], [5.4, 3.4, 1.7, 0.2], [5.1, 3.7, 1.5, 0.4], [4.6, 3.6, 1.0, 0.2], [5.1, 3.3, 1.7, 0.5], [4.8, 3.4, 1.9, 0.2], [5.0, 3.0, 1.6, 0.2], [5.0, 3.4, 1.6, 0.4], [5.2, 3.5, 1.5, 0.2], [5.2, 3.4, 1.4, 0.2], [4.7, 3.2, 1.6, 0.2], [4.8, 3.1, 1.6, 0.2], [5.4, 3.4, 1.5, 0.4], [5.2, 4.1, 1.5, 0.1], [5.5, 4.2, 1.4, 0.2], [4.9, 3.1, 1.5, 0.2], [5.0, 3.2, 1.2, 0.2], [5.5, 3.5, 1.3, 0.2], [4.9, 3.6, 1.4, 0.1], [4.4, 3.0, 1.3, 0.2], [5.1, 3.4, 1.5, 0.2], [5.0, 3.5, 1.3, 0.3], [4.5, 2.3, 1.3, 0.3], [4.4, 3.2, 1.3, 0.2], [5.0, 3.5, 1.6, 0.6], [5.1, 3.8, 1.9, 0.4], [4.8, 3.0, 1.4, 0.3], [5.1, 3.8, 1.6, 0.2], [4.6, 3.2, 1.4, 0.2], [5.3, 3.7, 1.5, 0.2], [5.0, 3.3, 1.4, 0.2], [7.0, 3.2, 4.7, 1.4], [6.4, 3.2, 4.5, 1.5], [6.9, 3.1, 4.9, 1.5], [5.5, 2.3, 4.0, 1.3], [6.5, 2.8, 4.6, 1.5], [5.7, 2.8, 4.5, 1.3], [6.3, 3.3, 4.7, 1.6], [4.9, 2.4, 3.3, 1.0], [6.6, 2.9, 4.6, 1.3], [5.2, 2.7, 3.9, 1.4], [5.0, 2.0, 3.5, 1.0], [5.9, 3.0, 4.2, 1.5], [6.0, 2.2, 4.0, 1.0], [6.1, 2.9, 4.7, 1.4], [5.6, 2.9, 3.6, 1.3], [6.7, 3.1, 4.4, 1.4], [5.6, 3.0, 4.5, 1.5], [5.8, 2.7, 4.1, 1.0], [6.2, 2.2, 4.5, 1.5], [5.6, 2.5, 3.9, 1.1], [5.9, 3.2, 4.8, 1.8], [6.1, 2.8, 4.0, 1.3], [6.3, 2.5, 4.9, 1.5], [6.1, 2.8, 4.7, 1.2], [6.4, 2.9, 4.3, 1.3], [6.6, 3.0, 4.4, 1.4], [6.8, 2.8, 4.8, 1.4], [6.7, 3.0, 5.0, 1.7], [6.0, 2.9, 4.5, 1.5], [5.7, 2.6, 3.5, 1.0], [5.5, 2.4, 3.8, 1.1], [5.5, 2.4, 3.7, 1.0], [5.8, 2.7, 3.9, 1.2], [6.0, 2.7, 5.1, 1.6], [5.4, 3.0, 4.5, 1.5], [6.0, 3.4, 4.5, 1.6], [6.7, 3.1, 4.7, 1.5], [6.3, 2.3, 4.4, 1.3], [5.6, 3.0, 4.1, 1.3], [5.5, 2.5, 4.0, 1.3], 
[5.5, 2.6, 4.4, 1.2], [6.1, 3.0, 4.6, 1.4], [5.8, 2.6, 4.0, 1.2], [5.0, 2.3, 3.3, 1.0], [5.6, 2.7, 4.2, 1.3], [5.7, 3.0, 4.2, 1.2], [5.7, 2.9, 4.2, 1.3], [6.2, 2.9, 4.3, 1.3], [5.1, 2.5, 3.0, 1.1], [5.7, 2.8, 4.1, 1.3], [6.3, 3.3, 6.0, 2.5], [5.8, 2.7, 5.1, 1.9], [7.1, 3.0, 5.9, 2.1], [6.3, 2.9, 5.6, 1.8], [6.5, 3.0, 5.8, 2.2], [7.6, 3.0, 6.6, 2.1], [4.9, 2.5, 4.5, 1.7], [7.3, 2.9, 6.3, 1.8], [6.7, 2.5, 5.8, 1.8], [7.2, 3.6, 6.1, 2.5], [6.5, 3.2, 5.1, 2.0], [6.4, 2.7, 5.3, 1.9], [6.8, 3.0, 5.5, 2.1], [5.7, 2.5, 5.0, 2.0], [5.8, 2.8, 5.1, 2.4], [6.4, 3.2, 5.3, 2.3], [6.5, 3.0, 5.5, 1.8], [7.7, 3.8, 6.7, 2.2], [7.7, 2.6, 6.9, 2.3], [6.0, 2.2, 5.0, 1.5], [6.9, 3.2, 5.7, 2.3], [5.6, 2.8, 4.9, 2.0], [7.7, 2.8, 6.7, 2.0], [6.3, 2.7, 4.9, 1.8], [6.7, 3.3, 5.7, 2.1], [7.2, 3.2, 6.0, 1.8], [6.2, 2.8, 4.8, 1.8], [6.1, 3.0, 4.9, 1.8], [6.4, 2.8, 5.6, 2.1], [7.2, 3.0, 5.8, 1.6], [7.4, 2.8, 6.1, 1.9], [7.9, 3.8, 6.4, 2.0], [6.4, 2.8, 5.6, 2.2], [6.3, 2.8, 5.1, 1.5], [6.1, 2.6, 5.6, 1.4], [7.7, 3.0, 6.1, 2.3], [6.3, 3.4, 5.6, 2.4], [6.4, 3.1, 5.5, 1.8], [6.0, 3.0, 4.8, 1.8], [6.9, 3.1, 5.4, 2.1], [6.7, 3.1, 5.6, 2.4], [6.9, 3.1, 5.1, 2.3], [5.8, 2.7, 5.1, 1.9], [6.8, 3.2, 5.9, 2.3], [6.7, 3.3, 5.7, 2.5], [6.7, 3.0, 5.2, 2.3], [6.3, 2.5, 5.0, 1.9], [6.5, 3.0, 5.2, 2.0], [6.2, 3.4, 5.4, 2.3], [5.9, 3.0, 5.1, 1.8]];
    $coefficients = [[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -0.0, -1.0, -0.0, -1.0, -0.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -0.0, -0.0, -1.0, -1.0, -1.0, -0.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -0.0, -0.0, -1.0, -1.0, -1.0, -0.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0], [1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0]];
    $intercepts = [0.04357001185417175, 0.11042118072509766, -0.0031709671020507812];
    $weights = [50, 50, 50];

    // Prediction:
    $clf = new SVC(3, 3, $vectors, $coefficients, $intercepts, $weights, "rbf", 0.001, 0.0, 3);
    $prediction = $clf->predict($features);
    fwrite(STDOUT, $prediction);

}