In [14]:
import matplotlib.pyplot as plt, numpy, seaborn
%matplotlib inline

In [10]:
# Per-batch loss log. Each non-blank line holds a whitespace-separated
# "<batch> <loss>" pair (the file is parsed with str.split(), not as CSV).
VALIDATION_LOG_PATH = '/Users/alger/data/Crowdastro/cnn_validation_output_0_12_05_17.csv'

batches = []
losses = []
with open(VALIDATION_LOG_PATH) as f:
    for row in f:
        fields = row.split()
        if not fields:
            # Skip blank lines (e.g. a trailing empty line); unpacking an
            # empty list below would otherwise raise ValueError.
            continue
        batch, loss = fields
        batches.append(int(batch))
        losses.append(float(loss))

In [11]:
# Raw loss for every logged batch; alpha makes overlapping points visible.
fig, ax = plt.subplots()
ax.scatter(batches, losses, alpha=0.5)
ax.set(xlabel='Batch', ylabel='Loss', title='Loss per batch')
plt.show()



In [37]:
# Number of consecutive log rows grouped into one "epoch".
# NOTE(review): 40 comes from the original grouping; confirm it matches the
# number of batches per epoch used in training.
BATCHES_PER_EPOCH = 40

epochs = []
std_losses = []
mean_losses = []
ebvs = []  # (epoch, index within epoch, loss) triples
for i, start in enumerate(range(0, len(losses), BATCHES_PER_EPOCH)):
    # Slice losses by position, so the grouping does not depend on the batch
    # numbers recorded in the log being exactly 0, 1, 2, ... in file order.
    epoch_losses = losses[start:start + BATCHES_PER_EPOCH]
    epochs.append(i)
    mean_losses.append(numpy.mean(epoch_losses))
    std_losses.append(numpy.std(epoch_losses))
    ebvs.extend((i, j, loss) for j, loss in enumerate(epoch_losses))

In [66]:
import pandas

# Long-format table: one row per (epoch, index-within-epoch) loss value.
data = pandas.DataFrame(data=ebvs, columns=['Epoch', 'Batch', 'Value'])

# seaborn.tsplot was deprecated and later removed from seaborn; lineplot is its
# replacement. With x='Epoch', it aggregates all per-batch values at each epoch
# into a mean line with a confidence band, as tsplot did across units.
fig, ax = plt.subplots()
seaborn.lineplot(data=data, x='Epoch', y='Value', ax=ax)
ax.set(xlabel='Epoch', ylabel='Loss', title='Mean loss per epoch')
plt.show()


Out[66]:
<matplotlib.axes._subplots.AxesSubplot at 0x11c00b8d0>

In [63]:
# Rows for every epoch after epoch 400.
data.loc[data['Epoch'].gt(400)]


Out[63]:
Epoch Batch Value
16040 401 0 0.053508
16041 401 1 0.006960
16042 401 2 0.000262
16043 401 3 0.001344
16044 401 4 0.004116
16045 401 5 0.002104
16046 401 6 0.034309
16047 401 7 0.004492
16048 401 8 0.000298
16049 401 9 0.000767
16050 401 10 0.000352
16051 401 11 0.211647
16052 401 12 0.033747
16053 401 13 0.000996
16054 401 14 0.010047
16055 401 15 0.001878
16056 401 16 0.052937
16057 401 17 0.000392
16058 401 18 0.012180
16059 401 19 0.012755
16060 401 20 0.000268
16061 401 21 0.042466
16062 401 22 0.000409
16063 401 23 0.014569
16064 401 24 0.173037
16065 401 25 0.003160
16066 401 26 0.014222
16067 401 27 0.007383
16068 401 28 0.009715
16069 401 29 0.011253
... ... ... ...
39970 999 10 0.000836
39971 999 11 0.000043
39972 999 12 0.000006
39973 999 13 0.000031
39974 999 14 0.000022
39975 999 15 0.000725
39976 999 16 0.000546
39977 999 17 0.000229
39978 999 18 0.001127
39979 999 19 0.000697
39980 999 20 0.000140
39981 999 21 0.000125
39982 999 22 0.000027
39983 999 23 0.016327
39984 999 24 0.043670
39985 999 25 0.000016
39986 999 26 0.000006
39987 999 27 0.000028
39988 999 28 0.005499
39989 999 29 0.000165
39990 999 30 0.000471
39991 999 31 0.002593
39992 999 32 0.000363
39993 999 33 0.000918
39994 999 34 0.000005
39995 999 35 0.000244
39996 999 36 0.000523
39997 999 37 0.000014
39998 999 38 0.000002
39999 999 39 0.000018

23960 rows × 3 columns


In [ ]: