train_and_test-checkpoint



In [1]:
import tensorflow as tf
from agent.main import Agent
from emulator.main import Account
from params import *

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style('whitegrid')
%matplotlib inline

完全随机探索填充经验库


In [2]:
# Warm-up phase: roll out a couple of episodes without any training, purely to
# populate the agent's cache before learning starts.
WARMUP_EPISODES = 2

env = Account()
agent = Agent()
for _ in range(WARMUP_EPISODES):
    state = env.reset()
    done = False
    while not done:
        # NOTE(review): the second argument (0) presumably forces fully random
        # exploration, as the markdown heading of this section says — TODO confirm in Agent.
        action = agent.get_stochastic_policy(state, 0)
        next_state, reward, done = env.step(action)
        agent.update_cache(state, action, reward, next_state, done)
        state = next_state
# Report how many entries the cache now holds.
print(len(agent.agent.cache))


2048

训练模型


In [3]:
# First training run. Every step stores the transition via update_cache(),
# calls update_target() every TARGET_STEP_SIZE global steps, and calls
# update_eval() every TRAIN_STEP_SIZE steps within an episode plus once on the
# episode's final step.
NUM_EPISODES = 10

episodes_train = []  # one env.plot_data() frame per finished episode; read by later plotting cells
global_step = 0      # step counter across all episodes; drives the update_target() schedule
for episode in range(NUM_EPISODES):
    state = env.reset()
    episode_step = 0
    done = False
    while not done:
        global_step += 1
        episode_step += 1

        action = agent.get_stochastic_policy(state)
        next_state, reward, done = env.step(action)
        agent.update_cache(state, action, reward, next_state, done)
        state = next_state

        # NOTE(review): presumably syncs the target network from the eval network.
        if global_step % TARGET_STEP_SIZE == 0:
            agent.update_target()

        # Train periodically, and always on the last step so no episode tail is skipped.
        if episode_step % TRAIN_STEP_SIZE == 0 or done:
            agent.update_eval()

    # Episode over: log env.A.total_value and keep the episode's recorded data.
    print(episode, env.A.total_value)
    episodes_train.append(env.plot_data())

# Plot the first column of the final episode's data, then checkpoint the model.
tmp = env.plot_data()
tmp.iloc[:, 0].plot(figsize=(16, 6))
agent.save_model()


0 129285.409067
1 155325.939548
2 200299.421097
3 193038.696306
4 239591.793298
5 256011.023411
6 268853.004745
7 280514.332691
8 303613.812299
9 293471.265307

测试算法


In [4]:
# Overlay the "value" column of every recorded training episode, one line per episode.
value_curves = [episode_frame["value"] for episode_frame in episodes_train]
episode_value = pd.concat(value_curves, axis=1)
n_curves = episode_value.shape[1]
episode_value.columns = list(range(n_curves))
episode_value.plot(figsize=(12, 9))


Out[4]:
<matplotlib.axes._subplots.AxesSubplot at 0x29c372b0>

In [5]:
# Evaluation pass: step the environment a fixed number of times with the agent
# and plot the first column of the recorded data.
# NOTE(review): `done` is not checked. The pass always runs 1445 steps,
# presumably on purpose so it continues past the training segment into the rows
# from 1203 on that the next cell zooms into — TODO confirm against Account.
# NOTE(review): this is the same stochastic policy used for training, so the
# result is presumably not deterministic.
state = env.reset()
for _step in range(1445):
    chosen = agent.get_stochastic_policy(state)
    state, reward, done = env.step(chosen)

tmp = env.plot_data()
tmp.iloc[:, 0].plot(figsize=(16, 6))


Out[5]:
<matplotlib.axes._subplots.AxesSubplot at 0x1c483438>

In [6]:
# Zoom into rows 1203 onward of the evaluation data: the same slice that is
# later saved as the out-of-sample chart.
test_start = 1203
tmp.iloc[test_start:, 0].plot(figsize=(16, 6))


Out[6]:
<matplotlib.axes._subplots.AxesSubplot at 0x29c5de10>

继续训练


In [8]:
# Continue training from the saved checkpoint for NUM_EPISODES more episodes,
# appending each finished episode's plot_data() frame to `episodes_train`.
# Requires `env` and `agent` from the cache-filling cell.
agent.restore_model()

# Hidden-state guard: this cell restores from the checkpoint, so it may be run
# after a kernel restart without the first training cell. In that case
# `episodes_train` does not exist yet, and the first finished episode would
# raise NameError after a full episode of work. Start an empty list in that
# case; otherwise keep appending so the later plot covers every training run.
if "episodes_train" not in globals():
    episodes_train = []

NUM_EPISODES = 1000
global_step = 0  # step counter across all episodes of this run; drives update_target()
for episode in range(NUM_EPISODES):
    state = env.reset()
    episode_step = 0
    while True:
        global_step += 1
        episode_step += 1

        action = agent.get_stochastic_policy(state)
        next_state, reward, done = env.step(action)
        agent.update_cache(state, action, reward, next_state, done)
        state = next_state

        # NOTE(review): presumably syncs the target network from the eval network.
        if global_step % TARGET_STEP_SIZE == 0:
            agent.update_target()

        # Train every TRAIN_STEP_SIZE steps and always on the episode's final step.
        if episode_step % TRAIN_STEP_SIZE == 0 or done:
            agent.update_eval()

            if done:
                # Log env.A.total_value and keep the episode's recorded data.
                print(episode, env.A.total_value)
                episodes_train.append(env.plot_data())
                break

# Plot the first column of the final episode's data, then checkpoint the model.
tmp = env.plot_data()
tmp.iloc[:, 0].plot(figsize=(16, 6))
agent.save_model()


INFO:tensorflow:Restoring parameters from model/ddqn.ckpt
0 305969.019185
1 296885.090924
2 314495.755748
3 330705.099699
4 332554.716914
5 343580.953714
6 334767.373049
7 327727.206608
8 338477.252834
9 337379.98773
10 350099.2577
11 357298.914069
12 358732.902911
13 361747.443897
14 360049.668097
15 338844.886457
16 344843.116577
17 362760.397236
18 361822.958613
19 364521.842898
20 371185.367259
21 367965.464033
22 374923.232031
23 366483.865857
24 350404.503759
25 380067.713115
26 366588.391349
27 365717.288956
28 372432.269562
29 375173.692649
30 377192.045029
31 365639.161042
32 381244.18921
33 377702.21812
34 376915.32017
35 400433.023658
36 389607.865325
37 398013.023334
38 391225.882637
39 383281.706963
40 396791.555476
41 387641.950133
42 394821.473207
43 404463.675754
44 401372.359312
45 394467.783514
46 404079.806721
47 384219.657554
48 405922.023195
49 402139.881513
50 400465.174227
51 419912.301379
52 419657.729645
53 409316.354628
54 394882.44859
55 394036.050544
56 404042.133161
57 399020.853927
58 414650.120946
59 396378.594304
60 417921.17552
61 415157.490641
62 419866.082507
63 405532.592906
64 411711.666119
65 424957.853797
66 411718.370832
67 415283.364872
68 430478.420048
69 429313.971626
70 423889.110221
71 436811.262268
72 421149.733362
73 423355.300437
74 429259.530344
75 422964.459555
76 437642.906787
77 424922.444097
78 425435.42815
79 424942.752144
80 416568.934969
81 426946.277966
82 431574.631786
83 430834.920396
84 432323.045853
85 432137.026592
86 428510.50381
87 428528.80433
88 433594.86602
89 436111.007028
90 435797.854852
91 426700.831415
92 442062.815185
93 432712.185196
94 427518.256263
95 436669.478503
96 445653.430412
97 430711.19989
98 433094.99711
99 434821.454844
100 437343.711474
101 433720.047689
102 430366.490571
103 441860.44635
104 439247.024679
105 425863.201165
106 433280.042127
107 425997.354223
108 431169.384523
109 446897.872526
110 422125.14392
111 439490.480779
112 409620.937763
113 431264.647114
114 441993.986161
115 427023.022987
116 434984.816498
117 450080.209397
118 442696.219617
119 431905.842326
120 437605.088744
121 434520.951849
122 444525.001128
123 439947.594089
124 425374.263088
125 444018.85295
126 439670.877545
127 431870.594104
128 428474.264717
129 433487.402007
130 437949.192141
131 440339.165664
132 436560.340774
133 446874.167231
134 434911.511252
135 436809.518014
136 446920.393254
137 429208.743351
138 444693.246299
139 448375.689543
140 450535.81643
141 450618.807678
142 447686.213987
143 438517.702834
144 440758.514034
145 443682.65171
146 429483.969552
147 434433.153208
148 448438.331317
149 451514.214909
150 425865.493185
151 454480.021465
152 443878.860794
153 436714.871344
154 438810.704825
155 451173.306777
156 449382.478617
157 453333.74887
158 431577.264568
159 453235.093093
160 458646.126508
161 434142.049959
162 443724.028197
163 431851.28543
164 443272.93657
165 444937.305933
166 443495.648055
167 446832.577351
168 445902.473208
169 434594.229469
170 438180.254239
171 436619.999599
172 443237.849297
173 440458.73094
174 435093.365663
175 459964.752298
176 446610.46402
177 434394.341934
178 461715.136362
179 457349.675732
180 443949.704628
181 447153.925486
182 435110.736807
183 446657.989277
184 447997.753745
185 455878.138711
186 431697.028637
187 454177.024574
188 447284.084389
189 456459.395863
190 458537.137829
191 450535.479661
192 444974.969024
193 433777.53212
194 457669.467273
195 446343.945556
196 444802.168683
197 450051.383047
198 448890.306222
199 440200.125127
200 451832.547785
201 443573.403553
202 448082.581352
203 447734.540818
204 448977.608893
205 449671.598582
206 447031.946503
207 446666.72673
208 454852.452498
209 460983.084536
210 446747.939559
211 452481.561527
212 450290.922357
213 450395.909799
214 459906.133456
215 451959.779749
216 437705.216749
217 428194.647341
218 451497.161845
219 459133.757847
220 451662.58368
221 458328.826623
222 465794.568554
223 466389.988484
224 456803.125189
225 457950.659293
226 462148.118765
227 452847.531851
228 455933.556596
229 450081.41928
230 449572.598773
231 451200.769895
232 454039.519925
233 447377.305003
234 462991.440715
235 451262.796552
236 456112.879026
237 456131.646842
238 462023.633917
239 460990.592205
240 459715.559768
241 467680.700601
242 470312.491525
243 451542.833122
244 462615.651522
245 459717.071998
246 452236.895716
247 452287.340629
248 436968.690887
249 469986.012386
250 451940.945508
251 451080.263768
252 459409.057611
253 449746.238482
254 456116.251074
255 459981.986793
256 471858.469194
257 471906.796407
258 470941.580456
259 458479.44067
260 464100.51472
261 464090.463058
262 465886.348143
263 456211.496562
264 464038.839103
265 459804.447027
266 473599.938576
267 446413.266096
268 447707.679594
269 472548.554591
270 460275.27296
271 455888.857141
272 459194.041227
273 451452.404326
274 456021.145349
275 455611.904945
276 455311.230518
277 465329.09766
278 455182.293576
279 460067.348697
280 447487.875962
281 436575.149349
282 449851.154528
283 465510.178181
284 463055.587849
285 454795.427939
286 456179.089118
287 455447.503686
288 457460.44461
289 443099.879831
290 454036.836718
291 463794.623417
292 459549.463382
293 452264.350448
294 457454.395411
295 461206.267856
296 444866.336383
297 472710.861591
298 461296.221756
299 468588.4616
300 452814.959241
301 462266.150922
302 456467.744217
303 471796.650559
304 450891.569536
305 461825.033851
306 472943.95065
307 459128.951552
308 462648.990498
309 458672.688253
310 459396.446975
311 455500.205259
312 445471.330099
313 460652.365358
314 466248.920055
315 456877.618297
316 457989.232516
317 473607.695763
318 457568.895559
319 451989.564987
320 471827.345799
321 453828.640665
322 456644.111419
323 445868.752423
324 451750.590874
325 471652.818508
326 463757.377507
327 451894.226044
328 460236.421747
329 472874.944553
330 456236.689362
331 468522.49446
332 465063.012442
333 460867.204313
334 463274.300475
335 468385.658488
336 458072.7939
337 463480.186225
338 453695.924467
339 472658.534101
340 455295.970386
341 452135.137475
342 465928.691471
343 464812.455925
344 453833.939505
345 462148.695146
346 465091.066119
347 466958.962501
348 465739.571465
349 470269.160408
350 461922.34926
351 449649.511313
352 446890.442931
353 460355.553287
354 462470.498917
355 454263.32851
356 460756.963115
357 454243.226392
358 457617.967719
359 467013.545654
360 444962.80106
361 461769.130256
362 465558.726983
363 454638.593091
364 453000.716755
365 464863.880052
366 475287.057761
367 448318.208113
368 458855.638577
369 455524.578129
370 470391.39911
371 460386.453453
372 469749.897504
373 462923.718901
374 471395.795286
375 444596.380898
376 454230.96255
377 469860.86305
378 470836.924114
379 456472.831062
380 462720.345612
381 456560.582425
382 448341.095472
383 459531.289064
384 454321.828492
385 450489.551413
386 451536.570943
387 465140.617692
388 464669.746387
389 460862.047172
390 456676.339853
391 463096.661672
392 468407.45731
393 463989.529155
394 456302.291131
395 465032.597009
396 464651.968886
397 464408.528554
398 458302.47988
399 459922.987943
400 458161.751873
401 456473.721092
402 459202.857774
403 460090.368408
404 466282.507516
405 454560.997306
406 454459.032587
407 460741.402658
408 467457.874705
409 468142.713039
410 475953.704951
411 469048.710173
412 457955.221596
413 465549.864339
414 467043.509229
415 460749.779909
416 460881.682496
417 457082.090417
418 464925.219727
419 457345.817466
420 476816.150351
421 471389.582416
422 452357.342663
423 446385.983042
424 464421.919361
425 462892.182934
426 463529.736594
427 456561.161714
428 464564.808756
429 465481.387083
430 465022.686602
431 465720.554434
432 465082.797414
433 464915.40203
434 464823.368542
435 460647.826914
436 462385.795366
437 465548.37765
438 453587.085551
439 466041.738486
440 463475.14962
441 464150.753252
442 449593.110952
443 458244.283876
444 480422.361924
445 477714.870748
446 460247.491547
447 462830.689624
448 451475.270041
449 471316.303257
450 446932.777059
451 448353.591365
452 463680.475116
453 460670.76095
454 464899.780333
455 457926.182599
456 452629.770129
457 453708.981043
458 454244.871349
459 462445.491764
460 460921.002315
461 465313.506647
462 461741.931969
463 464352.288828
464 455842.418502
465 452420.311392
466 458686.441419
467 461502.176473
468 461709.672364
469 453031.635574
470 461541.674905
471 461536.586414
472 466448.006916
473 461545.029908
474 458833.592983
475 451760.068653
476 456523.718096
477 441693.305587
478 459384.202394
479 465440.9933
480 459503.926505
481 456338.259349
482 459357.536579
483 465319.568457
484 460424.947921
485 466601.890136
486 432381.367624
487 466274.17858
488 449208.976385
489 459195.194109
490 449422.018095
491 443932.265156
492 465923.391791
493 455198.30441
494 453528.734073
495 457353.399441
496 449229.348832
497 463850.210192
498 467786.42046
499 440610.480735
500 454907.788227
501 452877.652249
502 455073.231323
503 449837.443408
504 453314.91354
505 461759.720981
506 459453.382476
507 448617.212173
508 462330.651749
509 447064.919939
510 446329.565708
511 455730.828403
512 452245.132304
513 458222.664041
514 453746.926927
515 453254.315871
516 446503.054475
517 456586.512567
518 452798.91707
519 453611.391792
520 454431.243825
521 466405.486708
522 469967.323784
523 467140.63512
524 458998.196182
525 461826.101757
526 472105.549549
527 456292.468075
528 468214.079732
529 448270.427289
530 462453.806799
531 460936.621175
532 449115.284661
533 472331.331303
534 464527.121175
535 468145.843906
536 450372.912703
537 454419.494911
538 445169.469101
539 462999.849316
540 465758.902974
541 466609.955255
542 452219.288055
543 460953.896168
544 453815.46756
545 469454.494787
546 456822.462823
547 446304.357466
548 454443.827453
549 457341.519525
550 450530.698348
551 468273.093324
552 467013.656268
553 458249.050742
554 466914.314943
555 454172.217178
556 455915.184514
557 461886.313669
558 454024.953405
559 446669.273122
560 449966.931245
561 478212.624755
562 450426.951686
563 463908.16636
564 443690.480694
565 454350.015846
566 457599.894019
567 460594.836116
568 462662.740745
569 476626.957614
570 466487.521986
571 469449.28376
572 442361.775628
573 475978.160529
574 453333.030091
575 449476.399627
576 471704.435747
577 469737.734801
578 450416.799838
579 457290.751231
580 445421.500483
581 459498.983651
582 447352.476296
583 468394.421021
584 471578.836676
585 435985.486288
586 467576.963032
587 443060.456829
588 466897.113069
589 454228.370918
590 478330.698959
591 467581.018026
592 453877.509683
593 459496.524146
594 454959.528172
595 454811.488651
596 454314.773878
597 456901.618686
598 454733.496117
599 464372.20267
600 465246.153902
601 458836.8296
602 456768.949347
603 452966.062159
604 471072.20521
605 460024.168222
606 466166.431025
607 461166.414343
608 469658.4336
609 456479.207518
610 453423.870001
611 452968.448449
612 454470.250054
613 470477.495215
614 470492.9239
615 474744.917802
616 468483.624303
617 471222.458767
618 472720.541294
619 452847.5814
620 469291.978677
621 439832.295066
622 452322.854076
623 462174.604354
624 467298.867291
625 472917.240659
626 461745.496804
627 453238.542515
628 467673.387473
629 470405.703252
630 451519.989872
631 467835.229346
632 477657.198942
633 448819.15887
634 461888.3898
635 465771.246423
636 449820.35898
637 452031.001554
638 455768.485908
639 461421.846121
640 464193.33567
641 469333.575723
642 478279.538765
643 451321.278687
644 461263.121483
645 466623.299586
646 479283.054434
647 466143.215472
648 468932.891158
649 466577.358763
650 463103.406652
651 441144.653194
652 471427.567358
653 458994.704414
654 465613.230801
655 456761.478414
656 458531.749062
657 460864.691743
658 434037.109397
659 460741.51814
660 461271.099337
661 461157.614067
662 446981.201854
663 467479.080001
664 467193.044641
665 474024.012012
666 463697.301789
667 454336.014486
668 445448.296566
669 470158.553355
670 469054.610374
671 469113.401954
672 454172.329033
673 458548.33736
674 456140.25604
675 456389.972761
676 461065.986395
677 454463.757598
678 459574.160461
679 465989.471161
680 460928.002598
681 459245.576111
682 470923.060395
683 470214.920177
684 481691.615573
685 455911.553369
686 476576.047183
687 473888.551005
688 456261.445542
689 472793.900491
690 466074.839209
691 458029.506306
692 465668.367083
693 459512.688272
694 461419.201093
695 475998.598696
696 448839.805034
697 468497.180495
698 479020.321824
699 468026.443383
700 470369.719701
701 452980.828152
702 456379.993018
703 473816.289962
704 471366.217919
705 477346.940376
706 472568.854898
707 479525.894752
708 455603.366809
709 468235.737751
710 439610.825088
711 465203.820142
712 470745.715622
713 471897.487377
714 463453.453452
715 459763.97956
716 467214.057396
717 457087.97013
718 467303.173374
719 468150.6575
720 475937.83231
721 466900.38917
722 469398.870812
723 459187.712112
724 466506.684266
725 475438.244942
726 457421.41001
727 471703.18656
728 469125.435574
729 458105.47598
730 461538.699138
731 471808.771182
732 442469.882093
733 455809.166829
734 461959.378722
735 461331.321136
736 456085.910972
737 462638.323722
738 452098.150039
739 462413.859168
740 463808.476596
741 456741.101304
742 464424.515391
743 462866.613485
744 466479.336993
745 466186.058436
746 464300.088087
747 451898.537697
748 464004.112935
749 465118.320405
750 460529.394747
751 452491.736894
752 480870.048458
753 473703.086644
754 452135.948903
755 461846.267116
756 458941.248726
757 461302.392594
758 469588.53045
759 469564.439087
760 465315.097871
761 457375.088088
762 470295.300088
763 448746.114797
764 452816.727943
765 443945.04885
766 474490.006713
767 449869.869024
768 461594.111008
769 465489.733322
770 480257.825249
771 459617.041334
772 464882.029259
773 469139.748524
774 461840.06876
775 471222.727302
776 457636.330029
777 471555.425956
778 465878.765931
779 456639.945425
780 462666.319719
781 467264.413225
782 470257.806134
783 458147.878869
784 466082.655864
785 459445.70685
786 471782.224293
787 473857.665402
788 474506.891646
789 446567.042515
790 464621.440833
791 457873.413136
792 475431.279025
793 472994.263452
794 455404.631874
795 446723.973431
796 459299.943151
797 453882.758929
798 465453.960944
799 474737.108894
800 449931.698294
801 469876.365422
802 459645.362904
803 471808.221112
804 458523.715643
805 457828.656995
806 467217.643058
807 450035.982554
808 450923.283289
809 465407.56085
810 456417.655157
811 466258.831348
812 471258.52425
813 464772.902387
814 454764.393255
815 462819.648179
816 473028.67356
817 466288.922554
818 461208.90008
819 460108.786999
820 475749.070708
821 464941.18464
822 464525.871474
823 459480.05725
824 465899.697316
825 442585.603332
826 470717.746314
827 446593.314147
828 460813.148387
829 463779.766468
830 464796.437293
831 456235.419413
832 469174.889791
833 462192.658056
834 452536.133914
835 471946.152386
836 461136.952491
837 460674.5885
838 461623.124609
839 455486.612504
840 463472.753595
841 459857.48455
842 470230.063753
843 450238.986727
844 455955.00529
845 467365.886574
846 463742.923287
847 468777.265812
848 474223.360505
849 456195.036636
850 463304.323598
851 468562.695615
852 459034.557943
853 465347.085968
854 472036.227965
855 456490.265346
856 466401.882276
857 451377.392198
858 461538.5998
859 471399.374777
860 462746.554585
861 465314.85818
862 462120.275914
863 462713.376147
864 466365.146812
865 466185.839568
866 464831.954906
867 472123.64733
868 475324.387479
869 469502.978613
870 482549.982527
871 471105.962503
872 473026.461334
873 469684.34657
874 464524.546399
875 481446.487332
876 440686.797618
877 449017.890924
878 465985.745028
879 473603.636418
880 467378.527812
881 450754.025925
882 462650.397525
883 451664.844251
884 461401.770593
885 459141.317746
886 461827.004038
887 462899.114719
888 464910.95324
889 469345.727799
890 465270.808882
891 464805.127877
892 460052.626494
893 461660.196697
894 467951.278282
895 472420.304554
896 465781.320582
897 462749.37307
898 459336.648673
899 465225.546981
900 464042.348109
901 458793.192175
902 481503.08252
903 444940.820287
904 480521.995267
905 454366.534662
906 468466.769842
907 456896.579502
908 464268.488055
909 471536.836611
910 469290.153827
911 470679.633919
912 475064.287145
913 469105.73824
914 459103.902005
915 469603.906377
916 461764.777811
917 469787.338947
918 454725.840306
919 462657.95222
920 482888.965237
921 464059.949852
922 468592.428025
923 470099.510776
924 469236.882337
925 464309.585567
926 462587.804664
927 476099.128352
928 444318.044678
929 452511.448325
930 468095.55201
931 458802.865123
932 463960.75901
933 475150.335814
934 465492.669514
935 473444.144985
936 447318.818759
937 459075.828247
938 463739.983051
939 451518.115594
940 471713.763888
941 451808.200112
942 443575.261856
943 469528.739792
944 460573.817136
945 451658.370347
946 457639.056983
947 469931.41386
948 479724.970377
949 460636.777035
950 469874.573307
951 460913.104208
952 467181.401912
953 466149.66598
954 453216.825007
955 463391.822064
956 451697.783757
957 453245.892403
958 470268.466109
959 455609.640315
960 451701.154269
961 472574.744083
962 461804.061868
963 467446.96385
964 469438.832502
965 463149.788906
966 449154.111988
967 459544.287152
968 461652.810407
969 441619.811438
970 470233.88891
971 461997.551723
972 468958.333499
973 448154.490024
974 454609.849946
975 457621.205317
976 462241.776523
977 453148.277274
978 468314.988292
979 464901.620809
980 462216.668549
981 456527.211309
982 476560.293381
983 468170.399552
984 462856.456981
985 482559.548957
986 456634.838073
987 462253.711543
988 465066.80939
989 471896.961652
990 469983.965756
991 459935.129643
992 464247.599662
993 463017.220567
994 462710.737727
995 466962.490509
996 461029.759248
997 468450.51417
998 468140.728236
999 460322.689034

In [29]:
# In-sample chart: the "value" column of every episode recorded in
# episodes_train, one line each, written to an image file. The legend is
# suppressed because it would have one entry per episode.
value_curves = [episode_frame["value"] for episode_frame in episodes_train]
episode_value = pd.concat(value_curves, axis=1)
n_curves = episode_value.shape[1]
episode_value.columns = list(range(n_curves))
episode_value.plot(figsize=(12, 9), legend=False)
plt.savefig("样本内.png")



In [10]:
# Evaluation pass after the continued training: a fixed 1445-step rollout with
# the agent's stochastic policy, then a plot of the first recorded column.
# NOTE(review): `done` is ignored, presumably on purpose so the rollout reaches
# the rows from 1203 on used as out-of-sample below — TODO confirm.
EVAL_STEPS = 1445
state = env.reset()
for _ in range(EVAL_STEPS):
    state, reward, done = env.step(agent.get_stochastic_policy(state))

tmp = env.plot_data()
tmp.iloc[:, 0].plot(figsize=(16, 6))


Out[10]:
<matplotlib.axes._subplots.AxesSubplot at 0x7f00e5843c88>

In [28]:
# Out-of-sample chart: first column of the evaluation data from row 1203 on,
# written to an image file.
out_of_sample_curve = tmp.iloc[1203:, 0]
out_of_sample_curve.plot(figsize=(16, 6))
plt.savefig("样本外.png")



In [27]:
# "Win rate" over the out-of-sample rows (1203 onward): the fraction of rows
# whose reward is strictly positive. Rows with zero or NaN reward count as
# non-wins but remain in the denominator.
test_data = tmp.iloc[1203:]

is_win = test_data['reward'] > 0
ratio = is_win.mean()
print("胜率:", ratio)


胜率: 0.5413223140495868

In [ ]:


In [ ]:


In [ ]: