Here is an example of how a single non-ensembled model can achieve high ranking scores using XGBoost, which is an optimized distributed gradient boosting library designed to be highly efficient, flexible and portable.
Based on 8124 instances and 22 attributes such as odor, habitat, and color, we can classify mushrooms as edible or poisonous with high accuracy. The few cases of unknown edibility are not recommended for eating and are grouped with the poisonous class.
The Agaricus genus contains the most widely consumed and best-known mushrooms today, but there are poisonous ones among them as well. The dataset consists of 8124 observations from the Agaricus and Lepiota family (6513 for training and 1611 for testing). It is a multivariate dataset with 22 descriptive attributes, and each observation is labeled as one of 2 classes: edible or poisonous.
In [34]:
using XGBoost, DataFrames, Gadfly, GLM
include("$(Pkg.dir())/MLDemos/src/xgboost/mushroom.jl");
path = "$(Pkg.dir())/MLDemos/";
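The data files use the LibSVM text format. Each line represents a single instance: the first token is the instance label, and every following 'index:value' pair gives a feature index and that feature's value. In the first line of the example below, '1' is the instance label, '2' and '9' are feature indices, and each listed feature has the value '1'. Features that are not listed are zero.
Example: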
1 2:1 9:1 10:1 20:1 29:1 33:1 35:1 39:1 40:1 52:1 57:1 64:1 68:1 76:1 85:1 87:1 91:1 94:1 101:1 104:1 116:1 123:1
0 2:1 9:1 19:1 20:1 22:1 33:1 35:1 38:1 40:1 52:1 55:1 64:1 68:1 76:1 85:1 87:1 91:1 94:1 101:1 105:1 115:1 119:1
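To make the format concrete, here is a minimal sketch of how one such line could be parsed by hand. It assumes Julia 1.x, and `parse_libsvm_line` is a hypothetical helper written for illustration; it is not part of XGBoost.jl, and it is not how `readlibsvm` is implemented.

```julia
# Parse one LibSVM-formatted line into (label, Dict(index => value)).
# `parse_libsvm_line` is a hypothetical helper, not an XGBoost.jl function.
function parse_libsvm_line(line::AbstractString)
    tokens = split(strip(line))
    label = parse(Float64, tokens[1])          # first token is the label
    features = Dict{Int,Float64}()
    for tok in tokens[2:end]
        idx, val = split(tok, ':')             # each remaining token is index:value
        features[parse(Int, idx)] = parse(Float64, val)
    end
    return label, features
end

label, features = parse_libsvm_line("1 2:1 9:1 10:1 20:1")
println(label)                                  # 1.0
println(sort(collect(keys(features))))          # [2, 9, 10, 20]
```

Any index that does not appear in the dictionary corresponds to a feature whose value is zero for that instance.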
In [40]:
attribute_dict = Dict(1=>"cap-shape: bell",2=>"cap-shape: conical",3=>"cap-shape: convex",4=>"cap-shape: flat",5=>"cap-shape: knobbed",6=>"cap-shape: sunken",
7=>"cap-surface: fibrous",8=>"cap-surface: grooves",9=>"cap-surface: scaly",10=>"cap-surface: smooth",
11=>"cap-color: brown",12=>"cap-color: buff",13=>"cap-color: cinnamon",14=>"cap-color: gray",15=>"cap-color: green", 16=>"cap-color: pink",17=>"cap-color: purple",18=>"cap-color: red",19=>"cap-color: white",20=>"cap-color: yellow",
21=>"bruises?: yes",22=>"bruises?: no",
23=>"odor: almond",24=>"odor: anise",25=>"odor: creosote",26=>"odor: fishy",27=>"odor: foul", 28=>"odor: musty",29=>"odor: none",30=>"odor: pungent",31=>"odor: spicy",
32=>"gill-attachment: attached",33=>"gill-attachment: descending",34=>"gill-attachment: free",35=>"gill-attachment: notched",
36=>"gill-spacing: close",37=>"gill-spacing: crowded",38=>"gill-spacing: distant",
39=>"gill-size: broad",40=>"gill-size: narrow",
41=>"gill-color: black",42=>"gill-color: brown",43=>"gill-color: buff",44=>"gill-color: chocolate",45=>"gill-color: gray", 46=>"gill-color: green",47=>"gill-color: orange",48=>"gill-color: pink",49=>"gill-color: purple",50=>"gill-color: red", 51=>"gill-color: white",52=>"gill-color: yellow",
53=>"stalk-shape: enlarging",54=>"stalk-shape: tapering",
55=>"stalk-root: bulbous",56=>"stalk-root: club",57=>"stalk-root: cup",58=>"stalk-root: equal", 59=>"stalk-root: rhizomorphs",60=>"stalk-root: rooted",61=>"stalk-root: missing",
62=>"stalk-surface-above-ring: fibrous",63=>"stalk-surface-above-ring: scaly",64=>"stalk-surface-above-ring: silky",65=>"stalk-surface-above-ring: smooth",
66=>"stalk-color-below-ring: brown",67=>"stalk-color-below-ring: buff",68=>"stalk-color-below-ring: cinnamon",69=>"stalk-color-below-ring: gray",70=>"stalk-color-below-ring: orange", 71=>"stalk-color-below-ring: pink",72=>"stalk-color-below-ring: red",73=>"stalk-color-below-ring: white",74=>"stalk-color-below-ring: yellow",
75=>"veil-type: partial",76=>"veil-type: universal",
77=>"veil-color: brown",78=>"veil-color: orange",79=>"veil-color: white",80=>"veil-color: yellow",
81=>"ring-number: none",82=>"ring-number: one",83=>"ring-number: two",
84=>"ring-type: cobwebby",85=>"ring-type: evanescent",86=>"ring-type: flaring",87=>"ring-type: large", 88=>"ring-type: none",89=>"ring-type: pendant",90=>"ring-type: sheathing",91=>"ring-type: zone",
92=>"spore-print-color: black",93=>"spore-print-color: brown",94=>"spore-print-color: buff",95=>"spore-print-color: chocolate",96=>"spore-print-color: green", 97=>"spore-print-color: orange",98=>"spore-print-color: purple",99=>"spore-print-color: white",100=>"spore-print-color: yellow",
101=>"population: abundant",102=>"population: clustered",103=>"population: numerous", 104=>"population: scattered",105=>"population: several",106=>"population: solitary",
107=>"habitat: grasses",108=>"habitat: leaves",109=>"habitat: meadows",110=>"habitat: paths", 111=>"habitat: urban",112=>"habitat: waste",113=>"habitat: woods")
In [4]:
train_X, train_Y = readlibsvm("$(path)data/mushroom/agaricus.txt.train", (6513, 126));
test_X, test_Y = readlibsvm("$(path)data/mushroom/agaricus.txt.test", (1611, 126));
In [52]:
num_round = 2;
print("training xgboost with dense matrix\n");
@time bst1 = xgboost(train_X, num_round, label = train_Y, eta=1, max_depth=2, objective="binary:logistic");
In [8]:
print("training xgboost with sparse matrix\n");
sptrain = sparse(train_X);
param = ["max_depth"=>2, "eta"=>1, "objective"=>"binary:logistic"]
@time bst = xgboost(sptrain, num_round, label = train_Y, param=param)
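The sparse and dense matrices hold the same data; only the storage differs. Each instance in the example lines above sets 22 of the 126 columns, so most entries are zero. The sketch below illustrates the idea on a tiny made-up matrix using the `SparseArrays` standard library. It assumes Julia 1.x and does not call XGBoost.

```julia
using SparseArrays  # standard library on Julia 1.x

# A tiny 0/1 feature matrix standing in for the mushroom features (made-up values).
dense = [0.0 1.0 0.0 0.0;
         0.0 0.0 0.0 1.0;
         1.0 0.0 0.0 0.0]

sp = sparse(dense)

# Only the nonzero entries are stored in the sparse representation.
println("stored entries: ", nnz(sp), " of ", length(dense))   # stored entries: 3 of 12

# Converting back recovers the same matrix, so a model sees identical data.
println(Matrix(sp) == dense)                                   # true
```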
In [51]:
print("training xgboost with DMatrix\n")
dtrain = DMatrix(train_X, label = train_Y)
println(num_round)
@time bst = xgboost(dtrain, num_round, eta = 1, objective = "binary:logistic")
In [7]:
preds1 = predict(bst1, test_X)
print("test-error=", sum((preds1 .> 0.5) .!= test_Y) / float(size(preds1)[1]), "\n")
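The test error above is the fraction of instances whose thresholded prediction disagrees with the label. A small self-contained sketch of that calculation on made-up numbers follows; `classification_error` is a hypothetical helper name, and the toy vectors are not taken from the mushroom data.

```julia
# Fraction of predictions whose thresholded class differs from the true label.
# `classification_error` is a hypothetical helper, not part of XGBoost.jl.
function classification_error(probs, labels; threshold = 0.5)
    predicted = probs .> threshold              # true means class 1
    return sum(predicted .!= labels) / length(probs)
end

toy_probs  = [0.9, 0.2, 0.7, 0.4]   # made-up predicted probabilities
toy_labels = [1.0, 0.0, 0.0, 0.0]   # made-up true labels
println(classification_error(toy_probs, toy_labels))   # 0.25
```

Only the third toy instance is misclassified (0.7 is above the threshold but the label is 0), which gives 1 error out of 4.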
In [75]:
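# Inspect the mushroom in row n of the test set (1 <= n <= 1611).
# Predictions are probabilities of class 1. The XGBoost demo version of this dataset
# encodes 1 = poisonous and 0 = edible; verify the encoding of your copy before trusting the result.
n = 3
attribs = find(test_X[n, :])   # indices of the nonzero features (findall(!iszero, ...) on Julia 0.7 and later)
@show round(Int, preds1[n])    # preds1 holds the predictions computed above
describe_mushroom(n);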
In [67]:
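function describe_mushroom(n)
    # Look up the nonzero feature indices of row n (findall(!iszero, ...) on Julia 0.7 and later).
    for idx in find(test_X[n, :])
        # attribute_dict covers indices 1 to 113 while the data has 126 columns,
        # so fall back to a placeholder instead of raising a KeyError.
        println(get(attribute_dict, idx, "unknown attribute (index $idx)"))
    end
end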
Solving the same problem, now using linear models instead of trees:
In [35]:
param_lm = Dict("booster"=>"gblinear", "eta"=>1, "silent"=>0,
"objective"=>"binary:logistic", "alpha"=>0.0001, "lambda"=>1)
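With the "gblinear" booster the model is a weighted sum of the features, and the "binary:logistic" objective passes that sum through the logistic function to obtain a probability. "alpha" and "lambda" are the L1 and L2 regularization weights. The sketch below illustrates these roles with made-up weights; the function names are hypothetical, the code assumes Julia 1.x, and the exact scaling of the penalty inside XGBoost may differ.

```julia
# Logistic (sigmoid) function: maps a raw score to a probability in (0, 1).
sigmoid(z) = 1 / (1 + exp(-z))

# Prediction of a linear logistic model for one instance, given the indices
# of its active (value 1) features. Features without a weight contribute 0.
function linear_predict(weights::Dict{Int,Float64}, bias::Float64, active::Vector{Int})
    score = bias + sum(get(weights, i, 0.0) for i in active)
    return sigmoid(score)
end

# Illustrative penalty: alpha scales the L1 term, lambda scales the L2 term.
penalty(weights, alpha, lambda) =
    alpha * sum(abs, values(weights)) + 0.5 * lambda * sum(abs2, values(weights))

w = Dict(2 => 1.5, 9 => -0.5, 29 => -2.0)    # hypothetical weights
p = linear_predict(w, 0.0, [2, 9, 10])       # score = 1.5 - 0.5 + 0.0 = 1.0
println(round(p, digits = 4))                # 0.7311
println(penalty(w, 0.0001, 1))
```

A small "alpha" such as 0.0001 barely shrinks the weights, while "lambda" = 1 applies a noticeably stronger L2 pull toward zero.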
In [18]:
dtrain = DMatrix("../data/mushroom/agaricus.txt.train")
dtest = DMatrix("../data/mushroom/agaricus.txt.test")
watchlist = [(dtest,"eval"), (dtrain,"train")]
num_round = 4
In [36]:
bst = xgboost(dtrain, num_round, param=param_lm, watchlist=watchlist)
In [22]:
preds_glm = predict(bst, dtest)
In [23]:
labels = get_info(dtest, "label")
In [24]:
print("test-error=", sum((preds_glm .> 0.5) .!= labels) / float(size(preds_glm)[1]), "\n")