In [1]:
import tensorflow as tf
import numpy as np
In [2]:
# Load the training set. With unpack=True each column of logistic.dat becomes a
# row of `xy`: every column except the last is a feature, the last is the label.
# ndmin=2 keeps `xy` 2-D even when the file holds a single sample, so x_data is
# always (n_features, n_samples) and y_data is always 1-D.
DATA_PATH = './logistic.dat'
xy = np.loadtxt(DATA_PATH, unpack=True, dtype='float32', ndmin=2)
x_data = xy[:-1]
y_data = xy[-1]
# The cross-entropy cost defined later assumes binary (0/1) labels.
assert np.all((y_data == 0) | (y_data == 1)), 'logistic.dat labels must be 0 or 1'
In [3]:
# Logistic-regression graph (TF1 graph mode). Columns of x_data are samples and
# rows are features; its first row is all ones in the Out[6] display, so the
# first weight presumably acts as the bias term.
SEED = 0  # fixed seed so the initial weights (and therefore the run) are reproducible
X = tf.placeholder(tf.float32, shape=[len(x_data), None])  # (n_features, n_samples)
Y = tf.placeholder(tf.float32)  # labels, one per sample (column of X)
W = tf.Variable(tf.random_uniform([1, len(x_data)], -1.0, 1.0, seed=SEED))
h = tf.matmul(W, X)  # logits, shape (1, n_samples)
hypothesis = tf.sigmoid(h)  # predicted P(y = 1 | x)
# Mean binary cross-entropy computed directly from the logits. It is equal to
# -mean(Y*log(hypothesis) + (1-Y)*log(1-hypothesis)), but never evaluates
# log(0), which the explicit form hits (-> inf/NaN) once float32 sigmoid saturates.
# Y is reshaped to (1, n_samples) so labels and logits have the same shape.
cost = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.reshape(Y, [1, -1]), logits=h)
)
In [4]:
# Full-batch gradient descent on the logistic cost, logging progress periodically.
# The learning rate is a plain Python float: wrapping it in tf.Variable made it a
# trainable variable that minimize() then tried (and failed) to differentiate.
LEARNING_RATE = 0.1  # alpha
TRAINING_STEPS = 10000
LOG_EVERY = 100

optimizer = tf.train.GradientDescentOptimizer(LEARNING_RATE)
train = optimizer.minimize(cost)  # goal is to minimize cost
init = tf.global_variables_initializer()

# Deliberately left open: later cells reuse this session to query the trained model.
sess = tf.Session()
sess.run(init)

train_feed = {X: x_data, Y: y_data}  # identical full-batch feed every step
for step in range(TRAINING_STEPS):
    sess.run(train, feed_dict=train_feed)
    if step % LOG_EVERY == 0:
        # Fetch cost and weights in a single run call (post-update values).
        step_cost, step_W = sess.run([cost, W], feed_dict=train_feed)
        print((step, step_cost, step_W))
(0, 0.69244617, array([[-0.44698417, 0.56535155, -0.24234931]], dtype=float32))
(100, 0.50953078, array([[-1.38380301, 0.13240102, 0.32498088]], dtype=float32))
(200, 0.45830154, array([[-2.08717704, 0.18951525, 0.43338704]], dtype=float32))
(300, 0.42690668, array([[-2.63873649, 0.24241963, 0.51071304]], dtype=float32))
(400, 0.40619192, array([[-3.08693218, 0.28342423, 0.5751515 ]], dtype=float32))
(500, 0.39171207, array([[-3.46173072, 0.31579974, 0.63050359]], dtype=float32))
(600, 0.38112172, array([[-3.78229165, 0.34195262, 0.67896569]], dtype=float32))
(700, 0.37309316, array([[-4.06140661, 0.36349553, 0.72202319]], dtype=float32))
(800, 0.36682844, array([[-4.30796003, 0.3815338 , 0.76073122]], dtype=float32))
(900, 0.36182287, array([[-4.52833557, 0.39684749, 0.79586625]], dtype=float32))
(1000, 0.35774407, array([[-4.72725534, 0.41000184, 0.82801503]], dtype=float32))
(1100, 0.35436466, array([[-4.90830612, 0.42141655, 0.85763252]], dtype=float32))
(1200, 0.35152474, array([[-5.07426119, 0.43140942, 0.88507682]], dtype=float32))
(1300, 0.34910873, array([[-5.22731829, 0.44022602, 0.91063714]], dtype=float32))
(1400, 0.34703133, array([[-5.36923027, 0.44805807, 0.93454748]], dtype=float32))
(1500, 0.34522817, array([[-5.50143433, 0.45505893, 0.95700264]], dtype=float32))
(1600, 0.3436501, array([[-5.62510014, 0.46135098, 0.97816318]], dtype=float32))
(1700, 0.3422586, array([[-5.74121475, 0.46703479, 0.9981665 ]], dtype=float32))
(1800, 0.34102347, array([[-5.85059977, 0.47219223, 1.01712859]], dtype=float32))
(1900, 0.33992052, array([[-5.95395756, 0.4768917 , 1.03514922]], dtype=float32))
(2000, 0.33893022, array([[-6.05188656, 0.48119035, 1.0523144 ]], dtype=float32))
(2100, 0.33803657, array([[-6.14490986, 0.48513663, 1.06870055]], dtype=float32))
(2200, 0.33722654, array([[-6.23346138, 0.48877025, 1.08437133]], dtype=float32))
(2300, 0.33648923, array([[-6.31793976, 0.49212685, 1.09938562]], dtype=float32))
(2400, 0.33581567, array([[-6.3986845 , 0.49523577, 1.11379421]], dtype=float32))
(2500, 0.3351979, array([[-6.47599888, 0.49812335, 1.12764263]], dtype=float32))
(2600, 0.33462963, array([[-6.55015087, 0.50081158, 1.14097166]], dtype=float32))
(2700, 0.33410522, array([[-6.62137747, 0.50332004, 1.15381753]], dtype=float32))
(2800, 0.33361992, array([[-6.68989229, 0.50566566, 1.16621315]], dtype=float32))
(2900, 0.33316949, array([[-6.75588512, 0.50786376, 1.17818773]], dtype=float32))
(3000, 0.33275071, array([[-6.81953096, 0.50992733, 1.18976903]], dtype=float32))
(3100, 0.33236015, array([[-6.88098288, 0.51186824, 1.20098031]], dtype=float32))
(3200, 0.33199522, array([[-6.94038153, 0.51369691, 1.21184433]], dtype=float32))
(3300, 0.33165357, array([[-6.99785328, 0.51542217, 1.22238111]], dtype=float32))
(3400, 0.33133298, array([[-7.05351686, 0.51705277, 1.23260939]], dtype=float32))
(3500, 0.33103171, array([[-7.10747766, 0.51859623, 1.24254584]], dtype=float32))
(3600, 0.33074808, array([[-7.15983343, 0.52005899, 1.25220668]], dtype=float32))
(3700, 0.33048058, array([[-7.21067429, 0.52144706, 1.26160645]], dtype=float32))
(3800, 0.33022794, array([[-7.26008224, 0.52276659, 1.27075779]], dtype=float32))
(3900, 0.32998896, array([[-7.30813313, 0.52402157, 1.27967405]], dtype=float32))
(4000, 0.32976261, array([[-7.3548975 , 0.52521658, 1.28836632]], dtype=float32))
(4100, 0.32954785, array([[-7.40043831, 0.52635616, 1.29684496]], dtype=float32))
(4200, 0.329344, array([[-7.44481897, 0.52744395, 1.30512035]], dtype=float32))
(4300, 0.32915005, array([[-7.48809433, 0.52848315, 1.31320179]], dtype=float32))
(4400, 0.32896546, array([[-7.53031635, 0.5294773 , 1.32109773]], dtype=float32))
(4500, 0.32878956, array([[-7.57153368, 0.53042877, 1.32881641]], dtype=float32))
(4600, 0.32862175, array([[-7.61179066, 0.53134042, 1.3363651 ]], dtype=float32))
(4700, 0.32846141, array([[-7.65113068, 0.53221482, 1.34375119]], dtype=float32))
(4800, 0.32830825, array([[-7.68959332, 0.53305334, 1.35098195]], dtype=float32))
(4900, 0.3281616, array([[-7.72721624, 0.53385907, 1.35806274]], dtype=float32))
(5000, 0.32802114, array([[-7.76403427, 0.53463352, 1.36500013]], dtype=float32))
(5100, 0.32788661, array([[-7.80007935, 0.53537858, 1.37179911]], dtype=float32))
(5200, 0.32775748, array([[-7.83538294, 0.53609544, 1.37846541]], dtype=float32))
(5300, 0.32763347, array([[-7.86997509, 0.53678596, 1.38500416]], dtype=float32))
(5400, 0.32751432, array([[-7.9038825 , 0.53745162, 1.39141965]], dtype=float32))
(5500, 0.32739973, array([[-7.93713045, 0.53809381, 1.39771616]], dtype=float32))
(5600, 0.32728961, array([[-7.96974373, 0.53871316, 1.40389836]], dtype=float32))
(5700, 0.32718346, array([[-8.00174618, 0.53931171, 1.40997016]], dtype=float32))
(5800, 0.3270812, array([[-8.03315735, 0.53988969, 1.41593492]], dtype=float32))
(5900, 0.32698256, array([[-8.06400204, 0.54044861, 1.4217968 ]], dtype=float32))
(6000, 0.32688746, array([[-8.09429455, 0.54098898, 1.42755878]], dtype=float32))
(6100, 0.32679567, array([[-8.12406158, 0.54151243, 1.43322492]], dtype=float32))
(6200, 0.32670701, array([[-8.15331268, 0.54201853, 1.43879759]], dtype=float32))
(6300, 0.32662126, array([[-8.18206978, 0.54250938, 1.44427955]], dtype=float32))
(6400, 0.32653835, array([[-8.21034813, 0.5429849 , 1.44967449]], dtype=float32))
(6500, 0.32645813, array([[-8.23816299, 0.54344571, 1.45498502]], dtype=float32))
(6600, 0.32638043, array([[-8.26552963, 0.54389322, 1.46021259]], dtype=float32))
(6700, 0.3263053, array([[-8.29245853, 0.54432726, 1.46536064]], dtype=float32))
(6800, 0.32623246, array([[-8.31896877, 0.54474878, 1.47043145]], dtype=float32))
(6900, 0.32616177, array([[-8.34506989, 0.54515815, 1.47542739]], dtype=float32))
(7000, 0.32609329, array([[-8.37077332, 0.54555601, 1.48035014]], dtype=float32))
(7100, 0.32602689, array([[-8.39609051, 0.54594284, 1.48520148]], dtype=float32))
(7200, 0.32596231, array([[-8.42103291, 0.54631871, 1.48998427]], dtype=float32))
(7300, 0.32589969, array([[-8.44561195, 0.5466845 , 1.49469995]], dtype=float32))
(7400, 0.32583877, array([[-8.4698391 , 0.54704064, 1.49935043]], dtype=float32))
(7500, 0.32577959, array([[-8.49372387, 0.54738718, 1.50393748]], dtype=float32))
(7600, 0.32572216, array([[-8.5172739 , 0.54772496, 1.50846291]], dtype=float32))
(7700, 0.32566613, array([[-8.54049873, 0.54805344, 1.51292789]], dtype=float32))
(7800, 0.32561168, array([[-8.56340981, 0.54837406, 1.51733458]], dtype=float32))
(7900, 0.32555869, array([[-8.58601093, 0.54868668, 1.52168405]], dtype=float32))
(8000, 0.32550713, array([[-8.60831165, 0.54899096, 1.52597797]], dtype=float32))
(8100, 0.32545692, array([[-8.6303196 , 0.54928809, 1.53021693]], dtype=float32))
(8200, 0.32540795, array([[-8.65204525, 0.54957771, 1.5344038 ]], dtype=float32))
(8300, 0.32536015, array([[-8.67349529, 0.54986113, 1.53853846]], dtype=float32))
(8400, 0.32531372, array([[-8.69467068, 0.55013716, 1.54262269]], dtype=float32))
(8500, 0.32526836, array([[-8.71558189, 0.55040652, 1.5466578 ]], dtype=float32))
(8600, 0.32522407, array([[-8.73623848, 0.55067021, 1.55064499]], dtype=float32))
(8700, 0.32518086, array([[-8.7566433 , 0.55092788, 1.55458498]], dtype=float32))
(8800, 0.32513872, array([[-8.7767992 , 0.55117893, 1.55847919]], dtype=float32))
(8900, 0.32509759, array([[-8.7967205 , 0.5514251 , 1.56232858]], dtype=float32))
(9000, 0.32505739, array([[-8.81640244, 0.55166554, 1.56613398]], dtype=float32))
(9100, 0.32501811, array([[-8.83585739, 0.5519008 , 1.56989622]], dtype=float32))
(9200, 0.32497969, array([[-8.85509396, 0.55213112, 1.5736177 ]], dtype=float32))
(9300, 0.32494223, array([[-8.8741045 , 0.55235606, 1.57729733]], dtype=float32))
(9400, 0.32490548, array([[-8.89290237, 0.55257595, 1.58093667]], dtype=float32))
(9500, 0.32486966, array([[-8.9114933 , 0.55279183, 1.58453691]], dtype=float32))
(9600, 0.32483456, array([[-8.92988014, 0.55300307, 1.588099 ]], dtype=float32))
(9700, 0.32480025, array([[-8.94806671, 0.55321014, 1.59162307]], dtype=float32))
(9800, 0.32476673, array([[-8.96605682, 0.55341274, 1.59511042]], dtype=float32))
(9900, 0.32473382, array([[-8.98385429, 0.5536111 , 1.59856176]], dtype=float32))
In [5]:
# Classify two new points, each given as a (3, 1) column whose leading 1 mirrors
# the all-ones first row of x_data. True means predicted P(y = 1) > 0.5.
for query in ([[1], [2], [2]], [[1], [5], [5]]):
    print(sess.run(hypothesis, feed_dict={X: query}) > 0.5)
[[False]]
[[ True]]
In [6]:
x_data  # feature matrix: one row per feature (file column), one column per sample
Out[6]:
array([[ 1., 1., 1., 1., 1., 1.],
[ 2., 3., 3., 5., 7., 2.],
[ 1., 2., 5., 5., 5., 5.]], dtype=float32)
In [7]:
y_data  # labels taken from the last column of logistic.dat, one per column of x_data
Out[7]:
array([ 0., 0., 0., 1., 1., 1.], dtype=float32)
In [8]:
# Raw logits W @ x_data for every training sample (sign decides the predicted class).
# Reuses the existing `h` node: calling tf.matmul here would add a new op to the
# default graph each time the cell is executed.
sess.run(h, feed_dict={X: x_data})
Out[8]:
array([[-6.29173708, -4.13599014, 0.66983986, 1.7774477 , 2.88505507,
0.11603642]], dtype=float32)
In [ ]:
Content source: dobestan/data-science-school
Similar notebooks: