In [105]:
import gbdt
import numpy
import matplotlib.pyplot as plt

In [106]:
x = numpy.arange(0.1, 10, 0.1)
y = numpy.log(x)

In [107]:
data = gbdt.DataLoader.from_dict(bucketized_float_cols={'x': list(x)}, raw_float_cols={'y': list(y)})

In [104]:
config = {'loss_func': 'mse',
          'num_trees': 1,
          'num_leaves': 2,
          'example_sampling_rate': 0.8,
          'feature_sampling_rate': 0.8,
          'shrinkage' : 0.1}

In [66]:
forest = gbdt.train(data,
                    y=list(data['y']),
                    features=['x'],
                    config= 
                    {'loss_func': 'mse',
                     'num_trees': 20,
                     'num_leaves': 2,
                     'example_sampling_rate': 0.8,
                     'feature_sampling_rate': 0.8,
                     'shrinkage' : 0.1})

In [67]:
plt.plot(x, y, x, forest.predict(data), 'g-')


Out[67]:
[<matplotlib.lines.Line2D at 0x107de9518>,
 <matplotlib.lines.Line2D at 0x107de96d8>]

In [68]:
plt.show()



In [97]:
x1 = list(x) + [float('nan')]* 20
y1 = list(y) + [1.3] * 20

In [101]:
data1 = gbdt.DataLoader.from_dict(bucketized_float_cols={'x': list(x1)}, raw_float_cols={'y': list(y1)})
forest1 = gbdt.train(data1,
                    y=list(data1['y']),
                    features=['x'],
                    config= 
                    {'loss_func': 'mse',
                     'num_trees': 100,
                     'num_leaves': 2,
                     'example_sampling_rate': 0.8,
                     'feature_sampling_rate': 0.8,
                     'shrinkage' : 0.1})

In [102]:
plt.plot(x1, y1, x1, forest1.predict(data1), 'g-')
plt.show()



In [103]:
forest1.predict(data)


Out[103]:
[-2.1622843207223923,
 -1.5598169812947162,
 -1.1581779667249066,
 -0.8770570966953528,
 -0.6615576621625223,
 -0.5217339660521247,
 -0.3288728933330276,
 -0.1963320913273492,
 -0.16821528779837536,
 0.02421998144563986,
 0.16804252290603472,
 0.1839968212225358,
 0.2812133984916727,
 0.2812133984916727,
 0.4365549156718771,
 0.521354538788728,
 0.54544018132583,
 0.5680432777080568,
 0.5680432777080568,
 0.7478791224930319,
 0.7985112342212233,
 0.8023997616910492,
 0.8023997616910492,
 0.9166337546848808,
 0.9166337546848808,
 0.9166337546848808,
 1.0632810068273102,
 1.0632810068273102,
 1.0685323273719405,
 1.0685323273719405,
 1.188003471303091,
 1.188003471303091,
 1.192467789114744,
 1.2329856552605634,
 1.2519683245991473,
 1.3061442304824595,
 1.3100741105808993,
 1.333042480815493,
 1.333042480815493,
 1.3886119879971375,
 1.4044044039401342,
 1.4044044039401342,
 1.5036864786598017,
 1.5036864786598017,
 1.5036864786598017,
 1.5364076898622443,
 1.5364076898622443,
 1.5364076898622443,
 1.6125054371732404,
 1.6125054371732404,
 1.638700872026675,
 1.650193864312314,
 1.650193864312314,
 1.650193864312314,
 1.719175260346674,
 1.719175260346674,
 1.719175260346674,
 1.780328033324622,
 1.780328033324622,
 1.78669169551722,
 1.78669169551722,
 1.8406030461846967,
 1.8458137583875214,
 1.8458137583875214,
 1.8458137583875214,
 1.8890299486884032,
 1.8890299486884032,
 1.9311497143426095,
 1.9311497143426095,
 1.9413450573993032,
 1.9596074006476556,
 1.9596074006476556,
 1.9927581382871722,
 1.9927581382871722,
 1.9927581382871722,
 2.022393818195269,
 2.022393818195269,
 2.048996358826116,
 2.048996358826116,
 2.0735515302003478,
 2.0735515302003478,
 2.082894477127411,
 2.102722992487543,
 2.102722992487543,
 2.122319732392498,
 2.122319732392498,
 2.137961284672201,
 2.1510707532361266,
 2.164899107076053,
 2.1740947418220458,
 2.1817031655591563,
 2.193747473276744,
 2.201437209332653,
 2.201437209332653,
 2.209340184352186,
 2.209340184352186,
 2.209340184352186,
 2.209340184352186,
 2.209340184352186,
 1.3061442304824595,
 1.3061442304824595,
 1.3061442304824595,
 1.3061442304824595,
 1.3061442304824595,
 1.3061442304824595,
 1.3061442304824595,
 1.3061442304824595,
 1.3061442304824595,
 1.3061442304824595,
 1.3061442304824595,
 1.3061442304824595,
 1.3061442304824595,
 1.3061442304824595,
 1.3061442304824595,
 1.3061442304824595,
 1.3061442304824595,
 1.3061442304824595,
 1.3061442304824595,
 1.3061442304824595]

In [ ]: