In [1]:
import tensorflow as tf
import numpy as np

In [2]:
# Read the dataset. With unpack=True each *column* of the file becomes a *row*
# of `xy`, so the columns of `xy` are the individual samples.
xy = np.loadtxt('./logistic.dat', unpack=True, dtype='float32')

# Every row except the last is an input row (row 0 is the all-ones bias row, see
# Out[6]); the last row holds the 0/1 labels.
x_data, y_data = xy[:-1], xy[-1]

In [3]:
# Inputs: X is the design matrix (one column per sample, bias row included) and
# Y the matching 0/1 labels. Shapes are left unconstrained, as before.
X = tf.placeholder(tf.float32)
Y = tf.placeholder(tf.float32)

# One weight per row of x_data; since row 0 of x_data is the constant-1 bias
# row, W[0, 0] plays the role of the intercept. The explicit op seed makes a
# fresh-kernel "Run All" start from the same initial weights.
INIT_SEED = 1
W = tf.Variable(tf.random_uniform([1, len(x_data)], -1.0, 1.0, seed=INIT_SEED))

# Logits, one per sample (shape [1, n_samples]).
h = tf.matmul(W, X)
# Predicted P(y = 1 | x). tf.sigmoid is the same function as 1 / (1 + exp(-h))
# but avoids the deprecated tf.div.
hypothesis = tf.sigmoid(h)

# Mean binary cross-entropy computed directly from the logits.
# The hand-written form -mean(Y*log(p) + (1-Y)*log(1-p)) becomes NaN as soon as
# p saturates to exactly 0 or 1 in float32 (0 * log(0) = 0 * -inf); the fused op
# is numerically stable. The op requires labels and logits to share a shape, so
# the (possibly 1-D) labels are reshaped to the logits' shape.
cost = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.reshape(Y, tf.shape(h)),
        logits=h,
    )
)

In [4]:
# Learning rate (alpha). It is a fixed hyperparameter, so a plain Python float is
# used instead of a trainable tf.Variable that the optimizer would also consider.
a = 0.1
TRAIN_STEPS = 10000  # number of full-batch gradient-descent updates
LOG_EVERY = 100      # print progress every LOG_EVERY steps

optimizer = tf.train.GradientDescentOptimizer(a)
# Goal: minimize cost with respect to the weights W only.
train = optimizer.minimize(cost, var_list=[W])

# Built after `minimize` so it also covers any variables the optimizer creates.
init = tf.global_variables_initializer()

# The session is intentionally left open: later cells reuse `sess` for inference.
sess = tf.Session()
sess.run(init)

feed = {X: x_data, Y: y_data}
for step in range(TRAIN_STEPS):
    sess.run(train, feed_dict=feed)
    if step % LOG_EVERY == 0:
        # Values are reported *after* this step's update; a single run fetches both.
        step_cost, step_W = sess.run([cost, W], feed_dict=feed)
        print((step, step_cost, step_W))


(0, 0.69244617, array([[-0.44698417,  0.56535155, -0.24234931]], dtype=float32))
(100, 0.50953078, array([[-1.38380301,  0.13240102,  0.32498088]], dtype=float32))
(200, 0.45830154, array([[-2.08717704,  0.18951525,  0.43338704]], dtype=float32))
(300, 0.42690668, array([[-2.63873649,  0.24241963,  0.51071304]], dtype=float32))
(400, 0.40619192, array([[-3.08693218,  0.28342423,  0.5751515 ]], dtype=float32))
(500, 0.39171207, array([[-3.46173072,  0.31579974,  0.63050359]], dtype=float32))
(600, 0.38112172, array([[-3.78229165,  0.34195262,  0.67896569]], dtype=float32))
(700, 0.37309316, array([[-4.06140661,  0.36349553,  0.72202319]], dtype=float32))
(800, 0.36682844, array([[-4.30796003,  0.3815338 ,  0.76073122]], dtype=float32))
(900, 0.36182287, array([[-4.52833557,  0.39684749,  0.79586625]], dtype=float32))
(1000, 0.35774407, array([[-4.72725534,  0.41000184,  0.82801503]], dtype=float32))
(1100, 0.35436466, array([[-4.90830612,  0.42141655,  0.85763252]], dtype=float32))
(1200, 0.35152474, array([[-5.07426119,  0.43140942,  0.88507682]], dtype=float32))
(1300, 0.34910873, array([[-5.22731829,  0.44022602,  0.91063714]], dtype=float32))
(1400, 0.34703133, array([[-5.36923027,  0.44805807,  0.93454748]], dtype=float32))
(1500, 0.34522817, array([[-5.50143433,  0.45505893,  0.95700264]], dtype=float32))
(1600, 0.3436501, array([[-5.62510014,  0.46135098,  0.97816318]], dtype=float32))
(1700, 0.3422586, array([[-5.74121475,  0.46703479,  0.9981665 ]], dtype=float32))
(1800, 0.34102347, array([[-5.85059977,  0.47219223,  1.01712859]], dtype=float32))
(1900, 0.33992052, array([[-5.95395756,  0.4768917 ,  1.03514922]], dtype=float32))
(2000, 0.33893022, array([[-6.05188656,  0.48119035,  1.0523144 ]], dtype=float32))
(2100, 0.33803657, array([[-6.14490986,  0.48513663,  1.06870055]], dtype=float32))
(2200, 0.33722654, array([[-6.23346138,  0.48877025,  1.08437133]], dtype=float32))
(2300, 0.33648923, array([[-6.31793976,  0.49212685,  1.09938562]], dtype=float32))
(2400, 0.33581567, array([[-6.3986845 ,  0.49523577,  1.11379421]], dtype=float32))
(2500, 0.3351979, array([[-6.47599888,  0.49812335,  1.12764263]], dtype=float32))
(2600, 0.33462963, array([[-6.55015087,  0.50081158,  1.14097166]], dtype=float32))
(2700, 0.33410522, array([[-6.62137747,  0.50332004,  1.15381753]], dtype=float32))
(2800, 0.33361992, array([[-6.68989229,  0.50566566,  1.16621315]], dtype=float32))
(2900, 0.33316949, array([[-6.75588512,  0.50786376,  1.17818773]], dtype=float32))
(3000, 0.33275071, array([[-6.81953096,  0.50992733,  1.18976903]], dtype=float32))
(3100, 0.33236015, array([[-6.88098288,  0.51186824,  1.20098031]], dtype=float32))
(3200, 0.33199522, array([[-6.94038153,  0.51369691,  1.21184433]], dtype=float32))
(3300, 0.33165357, array([[-6.99785328,  0.51542217,  1.22238111]], dtype=float32))
(3400, 0.33133298, array([[-7.05351686,  0.51705277,  1.23260939]], dtype=float32))
(3500, 0.33103171, array([[-7.10747766,  0.51859623,  1.24254584]], dtype=float32))
(3600, 0.33074808, array([[-7.15983343,  0.52005899,  1.25220668]], dtype=float32))
(3700, 0.33048058, array([[-7.21067429,  0.52144706,  1.26160645]], dtype=float32))
(3800, 0.33022794, array([[-7.26008224,  0.52276659,  1.27075779]], dtype=float32))
(3900, 0.32998896, array([[-7.30813313,  0.52402157,  1.27967405]], dtype=float32))
(4000, 0.32976261, array([[-7.3548975 ,  0.52521658,  1.28836632]], dtype=float32))
(4100, 0.32954785, array([[-7.40043831,  0.52635616,  1.29684496]], dtype=float32))
(4200, 0.329344, array([[-7.44481897,  0.52744395,  1.30512035]], dtype=float32))
(4300, 0.32915005, array([[-7.48809433,  0.52848315,  1.31320179]], dtype=float32))
(4400, 0.32896546, array([[-7.53031635,  0.5294773 ,  1.32109773]], dtype=float32))
(4500, 0.32878956, array([[-7.57153368,  0.53042877,  1.32881641]], dtype=float32))
(4600, 0.32862175, array([[-7.61179066,  0.53134042,  1.3363651 ]], dtype=float32))
(4700, 0.32846141, array([[-7.65113068,  0.53221482,  1.34375119]], dtype=float32))
(4800, 0.32830825, array([[-7.68959332,  0.53305334,  1.35098195]], dtype=float32))
(4900, 0.3281616, array([[-7.72721624,  0.53385907,  1.35806274]], dtype=float32))
(5000, 0.32802114, array([[-7.76403427,  0.53463352,  1.36500013]], dtype=float32))
(5100, 0.32788661, array([[-7.80007935,  0.53537858,  1.37179911]], dtype=float32))
(5200, 0.32775748, array([[-7.83538294,  0.53609544,  1.37846541]], dtype=float32))
(5300, 0.32763347, array([[-7.86997509,  0.53678596,  1.38500416]], dtype=float32))
(5400, 0.32751432, array([[-7.9038825 ,  0.53745162,  1.39141965]], dtype=float32))
(5500, 0.32739973, array([[-7.93713045,  0.53809381,  1.39771616]], dtype=float32))
(5600, 0.32728961, array([[-7.96974373,  0.53871316,  1.40389836]], dtype=float32))
(5700, 0.32718346, array([[-8.00174618,  0.53931171,  1.40997016]], dtype=float32))
(5800, 0.3270812, array([[-8.03315735,  0.53988969,  1.41593492]], dtype=float32))
(5900, 0.32698256, array([[-8.06400204,  0.54044861,  1.4217968 ]], dtype=float32))
(6000, 0.32688746, array([[-8.09429455,  0.54098898,  1.42755878]], dtype=float32))
(6100, 0.32679567, array([[-8.12406158,  0.54151243,  1.43322492]], dtype=float32))
(6200, 0.32670701, array([[-8.15331268,  0.54201853,  1.43879759]], dtype=float32))
(6300, 0.32662126, array([[-8.18206978,  0.54250938,  1.44427955]], dtype=float32))
(6400, 0.32653835, array([[-8.21034813,  0.5429849 ,  1.44967449]], dtype=float32))
(6500, 0.32645813, array([[-8.23816299,  0.54344571,  1.45498502]], dtype=float32))
(6600, 0.32638043, array([[-8.26552963,  0.54389322,  1.46021259]], dtype=float32))
(6700, 0.3263053, array([[-8.29245853,  0.54432726,  1.46536064]], dtype=float32))
(6800, 0.32623246, array([[-8.31896877,  0.54474878,  1.47043145]], dtype=float32))
(6900, 0.32616177, array([[-8.34506989,  0.54515815,  1.47542739]], dtype=float32))
(7000, 0.32609329, array([[-8.37077332,  0.54555601,  1.48035014]], dtype=float32))
(7100, 0.32602689, array([[-8.39609051,  0.54594284,  1.48520148]], dtype=float32))
(7200, 0.32596231, array([[-8.42103291,  0.54631871,  1.48998427]], dtype=float32))
(7300, 0.32589969, array([[-8.44561195,  0.5466845 ,  1.49469995]], dtype=float32))
(7400, 0.32583877, array([[-8.4698391 ,  0.54704064,  1.49935043]], dtype=float32))
(7500, 0.32577959, array([[-8.49372387,  0.54738718,  1.50393748]], dtype=float32))
(7600, 0.32572216, array([[-8.5172739 ,  0.54772496,  1.50846291]], dtype=float32))
(7700, 0.32566613, array([[-8.54049873,  0.54805344,  1.51292789]], dtype=float32))
(7800, 0.32561168, array([[-8.56340981,  0.54837406,  1.51733458]], dtype=float32))
(7900, 0.32555869, array([[-8.58601093,  0.54868668,  1.52168405]], dtype=float32))
(8000, 0.32550713, array([[-8.60831165,  0.54899096,  1.52597797]], dtype=float32))
(8100, 0.32545692, array([[-8.6303196 ,  0.54928809,  1.53021693]], dtype=float32))
(8200, 0.32540795, array([[-8.65204525,  0.54957771,  1.5344038 ]], dtype=float32))
(8300, 0.32536015, array([[-8.67349529,  0.54986113,  1.53853846]], dtype=float32))
(8400, 0.32531372, array([[-8.69467068,  0.55013716,  1.54262269]], dtype=float32))
(8500, 0.32526836, array([[-8.71558189,  0.55040652,  1.5466578 ]], dtype=float32))
(8600, 0.32522407, array([[-8.73623848,  0.55067021,  1.55064499]], dtype=float32))
(8700, 0.32518086, array([[-8.7566433 ,  0.55092788,  1.55458498]], dtype=float32))
(8800, 0.32513872, array([[-8.7767992 ,  0.55117893,  1.55847919]], dtype=float32))
(8900, 0.32509759, array([[-8.7967205 ,  0.5514251 ,  1.56232858]], dtype=float32))
(9000, 0.32505739, array([[-8.81640244,  0.55166554,  1.56613398]], dtype=float32))
(9100, 0.32501811, array([[-8.83585739,  0.5519008 ,  1.56989622]], dtype=float32))
(9200, 0.32497969, array([[-8.85509396,  0.55213112,  1.5736177 ]], dtype=float32))
(9300, 0.32494223, array([[-8.8741045 ,  0.55235606,  1.57729733]], dtype=float32))
(9400, 0.32490548, array([[-8.89290237,  0.55257595,  1.58093667]], dtype=float32))
(9500, 0.32486966, array([[-8.9114933 ,  0.55279183,  1.58453691]], dtype=float32))
(9600, 0.32483456, array([[-8.92988014,  0.55300307,  1.588099  ]], dtype=float32))
(9700, 0.32480025, array([[-8.94806671,  0.55321014,  1.59162307]], dtype=float32))
(9800, 0.32476673, array([[-8.96605682,  0.55341274,  1.59511042]], dtype=float32))
(9900, 0.32473382, array([[-8.98385429,  0.5536111 ,  1.59856176]], dtype=float32))

In [5]:
# Classify two new points. Each sample is a single column laid out like x_data:
# [bias = 1, input 1, input 2]. A probability above 0.5 means class 1.
for probe in ([[1], [2], [2]], [[1], [5], [5]]):
    probability = sess.run(hypothesis, feed_dict={X: probe})
    print(probability > 0.5)


[[False]]
[[ True]]

In [6]:
# Design matrix: one column per training sample. Row 0 is the constant bias
# input (all ones in the output below); the remaining rows are the inputs.
x_data


Out[6]:
array([[ 1.,  1.,  1.,  1.,  1.,  1.],
       [ 2.,  3.,  3.,  5.,  7.,  2.],
       [ 1.,  2.,  5.,  5.,  5.,  5.]], dtype=float32)

In [7]:
# Binary 0/1 labels, one per column (sample) of x_data.
y_data


Out[7]:
array([ 0.,  0.,  0.,  1.,  1.,  1.], dtype=float32)

In [8]:
# Raw logits W·X for every training sample; a positive value means
# hypothesis > 0.5 (predicted class 1). Evaluating the existing `h` tensor via
# feed_dict avoids adding a new matmul op to the graph each time this cell runs.
sess.run(h, feed_dict={X: x_data})


Out[8]:
array([[-6.29173708, -4.13599014,  0.66983986,  1.7774477 ,  2.88505507,
         0.11603642]], dtype=float32)

In [ ]: