package main
import (
"os"
"fmt"
"strconv"
)
// DecisionTreeClassifier is a binary decision tree stored as parallel
// arrays, where position i in every slice describes node i and node 0 is
// the root.
type DecisionTreeClassifier struct {
	// lChilds[i] is the node visited when the tested feature is <= the
	// node's threshold.
	lChilds []int
	// rChilds[i] is the node visited otherwise (including when the
	// comparison is false because the feature is NaN).
	rChilds []int
	// thresholds[i] is the split value of node i; the sentinel -2 marks a
	// leaf.
	thresholds []float64
	// indices[i] is the position in the feature vector tested at node i.
	indices []int
	// classes[i] holds one score per class for node i; at a leaf the class
	// with the largest score is the prediction. (In main these look like
	// per-class sample counts.)
	classes [][]int
}

// predict_ classifies features starting from the given node and returns
// the winning class index of the leaf that is reached. Ties are resolved
// in favour of the lowest class index.
func (c DecisionTreeClassifier) predict_(features []float64, node int) int {
	cur := node
	// Walk down until a node carrying the leaf sentinel is reached.
	for c.thresholds[cur] != -2 {
		// Keep the "<= goes left" form: a NaN feature must fall to the
		// right child exactly as before.
		if features[c.indices[cur]] <= c.thresholds[cur] {
			cur = c.lChilds[cur]
		} else {
			cur = c.rChilds[cur]
		}
	}

	scores := c.classes[cur]
	winner := 0
	for cls, score := range scores {
		// Strict comparison keeps the first class on ties.
		if score > scores[winner] {
			winner = cls
		}
	}
	return winner
}

// predict classifies features by descending the tree from the root.
func (c DecisionTreeClassifier) predict(features []float64) int {
	const root = 0
	return c.predict_(features, root)
}
// main classifies a single sample given on the command line.
//
// Each command-line argument is parsed, in order, as a float64 feature
// value. The hard-coded tree reads the features listed in indices, so at
// least max(indices)+1 values are required. The predicted class index is
// printed to stdout. If an argument is malformed or too few features are
// given, a message is written to stderr and the program exits with
// status 1.
func main() {
	// Features:
	var features []float64
	for i, arg := range os.Args[1:] {
		n, err := strconv.ParseFloat(arg, 64)
		if err != nil {
			// Skipping a bad value would shift every later feature into the
			// wrong position and silently yield a wrong prediction.
			fmt.Fprintf(os.Stderr, "feature %d: %v\n", i, err)
			os.Exit(1)
		}
		features = append(features, n)
	}

	// Parameters:
	lChilds := []int{1, -1, 3, 4, 5, -1, -1, 8, -1, 10, -1, -1, 13, 14, -1, -1, -1}
	rChilds := []int{2, -1, 12, 7, 6, -1, -1, 9, -1, 11, -1, -1, 16, 15, -1, -1, -1}
	thresholds := []float64{0.800000011921, -2.0, 1.75, 4.95000004768, 1.65000003576, -2.0, -2.0, 1.55000001192, -2.0, 6.94999980927, -2.0, -2.0, 4.85000014305, 5.95000004768, -2.0, -2.0, -2.0}
	indices := []int{3, -2, 3, 2, 3, -2, -2, 3, -2, 0, -2, -2, 2, 0, -2, -2, -2}
	classes := [][]int{{50, 50, 50}, {50, 0, 0}, {0, 50, 50}, {0, 49, 5}, {0, 47, 1}, {0, 47, 0}, {0, 0, 1}, {0, 2, 4}, {0, 0, 3}, {0, 2, 1}, {0, 2, 0}, {0, 0, 1}, {0, 1, 45}, {0, 1, 2}, {0, 1, 0}, {0, 0, 2}, {0, 0, 43}}

	// The tree indexes the feature slice directly; check up front that every
	// index it may read is present rather than panicking mid-traversal.
	required := 0
	for _, idx := range indices {
		if idx+1 > required {
			required = idx + 1
		}
	}
	if len(features) < required {
		fmt.Fprintf(os.Stderr, "expected at least %d features, got %d\n", required, len(features))
		os.Exit(1)
	}

	// Prediction:
	clf := DecisionTreeClassifier{lChilds, rChilds, thresholds, indices, classes}
	estimation := clf.predict(features)
	fmt.Printf("%d\n", estimation)
}