train_and_test



In [1]:
# Imports and notebook-wide plotting configuration.
# NOTE(review): `tf` is not referenced in any visible cell; presumably kept so the
# agent's TensorFlow model is available — confirm before removing.
import tensorflow as tf
from agent.main import Agent
from emulator.main import Account
# NOTE(review): star import. The training cells below use TARGET_STEP_SIZE and
# TRAIN_STEP_SIZE, which are not defined anywhere in this notebook, so they
# presumably come from here — consider importing those names explicitly.
from params import *

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
# Global seaborn style for every figure in the notebook.
sns.set_style('whitegrid')
%matplotlib inline

完全随机探索填充经验库


In [2]:
# Build the environment and the agent, then pre-fill the experience cache by
# playing two complete episodes before any training step is taken.
env = Account()
agent = Agent()

for episode in range(2):
    state = env.reset()
    done = False
    while not done:
        # NOTE(review): the heading describes this fill as fully random
        # exploration; confirm that the second argument (0) really selects that.
        action = agent.get_stochastic_policy(state, 0)
        next_state, reward, done = env.step(action)
        agent.update_cache(state, action, reward, next_state, done)
        state = next_state

# Report how many transitions ended up in the cache (reaches into the wrapped
# `agent.agent` object, which exposes the cache directly).
print(len(agent.agent.cache))


2048

训练模型


In [3]:
# Short warm-up training run. Every step is cached; the target network is
# refreshed on a global-step cadence and the eval network on a per-episode one.
NUM_EPISODES = 10

episodes_train = []
global_step = 0  # step counter across all episodes, drives update_target()
for episode in range(NUM_EPISODES):
    state = env.reset()
    episode_step = 0  # step counter within the current episode
    done = False
    while not done:
        global_step += 1
        episode_step += 1

        action = agent.get_stochastic_policy(state)
        next_state, reward, done = env.step(action)
        agent.update_cache(state, action, reward, next_state, done)
        state = next_state

        if global_step % TARGET_STEP_SIZE == 0:
            agent.update_target()

        # Update the eval network every TRAIN_STEP_SIZE steps and always on
        # the episode's final step.
        if episode_step % TRAIN_STEP_SIZE == 0 or done:
            agent.update_eval()

    # Episode finished: log the final account value and keep its plot data.
    print(episode, env.A.total_value)
    episodes_train.append(env.plot_data())

# Plot the first column of the last episode's data, then checkpoint the model.
tmp = env.plot_data()
tmp.iloc[:, 0].plot(figsize=(16, 6))
agent.save_model()


0 106121.908286
1 141751.696138
2 182852.38054
3 224314.293184
4 226946.288765
5 236927.351461
6 279614.945652
7 277995.882632
8 298934.630267
9 319255.850856

测试算法


In [4]:
# One line per warm-up episode: the `value` series of each episode side by side,
# with the columns relabelled 0..n-1 so the legend shows the episode number.
episode_value = pd.concat(
    [ep["value"] for ep in episodes_train],
    axis=1,
    ignore_index=True,
)
episode_value.plot(figsize=(12, 9))


Out[4]:
<matplotlib.axes._subplots.AxesSubplot at 0x240d5b4ba58>

In [5]:
# Evaluation rollout after warm-up: reset the environment, let the agent act for
# a fixed 1445 steps, then plot the first column of the recorded data.
# NOTE(review): unlike the training loops this never checks `done`; presumably
# deliberate, so the rollout continues past where training episodes stop and
# into the out-of-sample rows sliced later (row 1203 onward). Confirm that
# env.step is valid after an episode reports done.
# NOTE(review): this uses the same get_stochastic_policy call as training, so
# it may still explore; confirm whether a greedy evaluation was intended.
state = env.reset()
for i in range(1445):
    next_state, reward, done = env.step(agent.get_stochastic_policy(state))
    state = next_state

tmp = env.plot_data()
tmp.iloc[:, 0].plot(figsize=(16, 6))


Out[5]:
<matplotlib.axes._subplots.AxesSubplot at 0x240efce3828>

继续训练


In [6]:
# Continue training from the checkpoint saved by the warm-up cell, this time for
# a long run, with each reward divided by REWARD_SCALE before it is cached.
agent.restore_model()

NUM_EPISODES = 1000
# Divisor applied to every environment reward before update_cache().
# NOTE(review): the cache-fill and warm-up cells store *unscaled* rewards, and
# the restored network was trained on them. Until those older cache entries are
# replaced, the cache mixes two reward scales. Confirm the cache's eviction
# behaviour, or apply the same scaling everywhere update_cache is called.
REWARD_SCALE = 100

episodes_train = []
global_step = 0  # step counter across all episodes, drives update_target()
for episode in range(NUM_EPISODES):
    state = env.reset()
    episode_step = 0  # step counter within the current episode
    done = False
    while not done:
        global_step += 1
        episode_step += 1

        action = agent.get_stochastic_policy(state)
        next_state, reward, done = env.step(action)
        reward /= REWARD_SCALE
        agent.update_cache(state, action, reward, next_state, done)
        state = next_state

        if global_step % TARGET_STEP_SIZE == 0:
            agent.update_target()

        # Update the eval network every TRAIN_STEP_SIZE steps and always on
        # the episode's final step.
        if episode_step % TRAIN_STEP_SIZE == 0 or done:
            agent.update_eval()

    # Episode finished: log the final account value and keep its plot data.
    print(episode, env.A.total_value)
    episodes_train.append(env.plot_data())

# Plot the first column of the last episode's data, then checkpoint the model.
tmp = env.plot_data()
tmp.iloc[:, 0].plot(figsize=(16, 6))
agent.save_model()


INFO:tensorflow:Restoring parameters from model/ddqn.ckpt
0 320402.347367
1 304900.034478
2 272539.111529
3 250054.418468
4 259823.338869
5 277531.613587
6 287847.1194
7 281054.240175
8 278023.549652
9 277316.871988
10 282134.929649
11 291791.585449
12 269313.904078
13 311963.160286
14 312424.002299
15 313252.527881
16 340209.100206
17 341799.967912
18 344522.009712
19 352221.085145
20 344221.088569
21 344647.779756
22 351333.469251
23 361599.221968
24 364219.915866
25 368448.785999
26 367666.009463
27 361841.012964
28 367226.101384
29 379429.173693
30 361337.394851
31 384089.88138
32 360175.002011
33 392497.111432
34 375594.060018
35 379160.988622
36 394515.618809
37 366515.798306
38 377776.773868
39 378897.698827
40 417979.144669
41 411063.635102
42 399831.663201
43 405718.597394
44 407288.146497
45 412860.185726
46 425286.363438
47 421682.799415
48 413614.012883
49 396936.938854
50 407945.50062
51 428248.408997
52 407772.038783
53 396449.843145
54 426292.590981
55 410134.886957
56 419938.18617
57 437055.866077
58 433814.118858
59 415276.637005
60 432367.386821
61 426397.160196
62 434348.10499
63 433161.139458
64 423493.987876
65 442033.383001
66 443997.162079
67 443218.242607
68 418228.490307
69 424876.964784
70 446070.046874
71 423473.553568
72 439278.693212
73 438803.98055
74 428016.395231
75 408352.628373
76 448364.742044
77 440637.877645
78 446208.892011
79 427396.271033
80 432280.561965
81 425634.428153
82 438089.538367
83 432492.348316
84 457264.401632
85 417136.467315
86 423227.334578
87 440380.661669
88 441701.708169
89 438261.926971
90 432398.398648
91 425663.735627
92 449606.834521
93 442671.319889
94 454618.361562
95 444179.054852
96 432972.24817
97 450654.871241
98 449177.083832
99 436736.708722
100 446922.66632
101 457463.663439
102 442536.447897
103 450074.850947
104 451825.639422
105 435778.596567
106 433421.973336
107 450409.78312
108 439925.671276
109 446510.33382
110 440624.488583
111 442625.533572
112 443037.890704
113 446247.07138
114 458018.56136
115 446259.406006
116 453237.193125
117 463448.263773
118 451934.469615
119 434855.424632
120 453342.437122
121 457495.415601
122 441879.840739
123 433993.621548
124 445784.69551
125 445971.888427
126 438716.644136
127 459055.572103
128 447169.456499
129 446569.662083
130 452707.074295
131 452614.408838
132 444531.426921
133 439234.547475
134 452405.625074
135 444715.558065
136 457455.456508
137 453437.2851
138 454220.054437
139 452134.455466
140 453166.97534
141 461060.063128
142 461189.66859
143 447567.684462
144 455327.744814
145 447435.382426
146 444701.154834
147 467476.707011
148 449103.746397
149 463386.391218
150 452901.530304
151 466299.947316
152 460187.757663
153 445933.434343
154 447323.208728
155 456105.398191
156 455983.895036
157 458878.642906
158 453366.055548
159 452752.701409
160 468670.073552
161 436583.543454
162 451027.628663
163 462874.470303
164 453374.687925
165 447818.085267
166 452631.412117
167 446433.804775
168 447127.936052
169 463497.449107
170 464409.260146
171 463377.789309
172 457802.317803
173 465449.68333
174 454552.594592
175 450272.043624
176 463409.690031
177 445806.535053
178 460858.832074
179 454836.733049
180 443567.093569
181 453886.663538
182 447270.500678
183 455731.146681
184 456792.927964
185 461515.960848
186 447615.534997
187 457379.230363
188 466328.61502
189 459236.727911
190 440698.753266
191 447468.182795
192 456379.812625
193 466460.288924
194 465651.721316
195 453961.010426
196 446753.538505
197 455633.282518
198 462694.641965
199 457200.279949
200 466784.058111
201 462457.427427
202 452294.635627
203 465867.568432
204 448279.842769
205 445566.511824
206 460770.636129
207 443194.292426
208 449918.64407
209 443678.878865
210 467357.590399
211 463397.166078
212 448375.664743
213 469455.378361
214 448748.223741
215 441617.586939
216 461736.52629
217 469014.882552
218 455570.851504
219 443118.712791
220 448403.806349
221 461688.876774
222 450213.29787
223 462953.318812
224 463028.835066
225 459254.907668
226 468805.993866
227 447436.812824
228 466241.81411
229 455931.719947
230 453459.219808
231 451465.075142
232 458008.030835
233 455926.196491
234 456767.655531
235 468812.632481
236 465338.181757
237 469086.095288
238 462851.180188
239 467672.546282
240 466418.347778
241 468465.99535
242 468557.702852
243 468371.897876
244 452155.580149
245 450526.015638
246 464894.391501
247 443973.835051
248 464864.217345
249 466411.700776
250 469587.281678
251 463125.583556
252 451137.845925
253 455644.756604
254 464407.258687
255 467368.969113
256 462859.176693
257 463252.093104
258 465520.468654
259 435735.676593
260 465420.384826
261 465096.736438
262 463775.306214
263 454767.400005
264 460915.59146
265 461462.004896
266 468981.626902
267 454307.423783
268 464498.068717
269 468106.289795
270 451027.958464
271 453909.441786
272 467100.930881
273 463877.478219
274 451571.007457
275 455804.826621
276 455799.677895
277 466189.324728
278 467572.116056
279 468792.246803
280 459446.644255
281 451790.430304
282 469581.907372
283 468771.587168
284 465495.590521
285 466841.420966
286 464456.217652
287 464092.345205
288 472459.58716
289 466571.953127
290 462947.825361
291 475636.282824
292 463757.337626
293 463246.574208
294 449816.794314
295 475183.156117
296 456006.411465
297 461425.793597
298 465385.588287
299 471190.933127
300 466701.447323
301 467471.047212
302 461213.335463
303 446934.140296
304 475989.921113
305 477284.542165
306 449566.423279
307 464047.314701
308 454930.529012
309 455667.549404
310 460056.375474
311 473774.232336
312 472946.539887
313 458273.611462
314 470822.031873
315 462401.39485
316 454541.378225
317 467015.367642
318 450011.541969
319 456349.120508
320 464236.24907
321 477832.412463
322 468484.7737
323 445743.52586
324 462262.149628
325 482838.675563
326 468600.657593
327 468003.366095
328 463575.136824
329 472205.558885
330 465170.902502
331 464677.298294
332 462809.593763
333 457056.303848
334 463438.057632
335 470781.13822
336 465034.060371
337 462416.968213
338 459554.047895
339 467983.590432
340 463522.161867
341 458528.05354
342 460287.019973
343 441101.068902
344 448942.158389
345 442337.216237
346 446154.044189
347 458401.372682
348 454262.566231
349 444034.10425
350 460361.172413
351 457579.372753
352 454477.330522
353 455055.033268
354 468524.085891
355 470535.309386
356 440941.559589
357 464519.345775
358 472141.552142
359 469187.035475
360 460737.21542
361 456356.850576
362 449549.04127
363 454818.064673
364 469860.911909
365 452113.630393
366 459967.067423
367 461446.290115
368 466140.650171
369 459817.555159
370 466065.96866
371 464249.640177
372 468048.067753
373 457936.656877
374 461286.503492
375 458028.10537
376 456973.027631
377 456605.404678
378 469255.615987
379 469832.743605
380 460167.540486
381 472763.185878
382 476902.469009
383 451311.491815
384 463292.711684
385 463024.104828
386 461087.090306
387 471974.272636
388 473798.586287
389 454229.364431
390 443793.676678
391 453677.148747
392 457257.058831
393 471407.670313
394 472902.415271
395 463198.501523
396 476861.754126
397 470305.202614
398 482123.123086
399 471299.665493
400 456933.026757
401 446806.077673
402 471107.506307
403 475303.434511
404 434855.812328
405 464373.638324
406 447474.467553
407 445853.708392
408 452699.423775
409 463776.947536
410 451547.092527
411 439611.383141
412 465281.569039
413 455190.289618
414 465876.391241
415 461265.11876
416 455211.032615
417 462305.719756
418 463107.610898
419 472762.830863
420 483665.10275
421 467475.896499
422 462710.742389
423 462829.961044
424 456657.705946
425 466277.282376
426 450572.692628
427 465519.147601
428 466978.041572
429 467250.93547
430 471833.701946
431 457893.080767
432 454692.750415
433 461373.293662
434 454192.179421
435 467266.1525
436 464850.15145
437 465075.908874
438 466366.68151
439 458299.84896
440 458081.312921
441 448771.082827
442 455583.935766
443 471005.52692
444 468315.705104
445 459902.752673
446 458314.959086
447 452927.121237
448 458610.360665
449 464105.88827
450 460126.005627
451 460797.841158
452 473036.122195
453 464937.446368
454 455282.891872
455 473766.64187
456 454530.335909
457 465696.30375
458 467886.213448
459 463238.138259
460 457221.822054
461 471282.373023
462 454239.862228
463 471110.638799
464 476184.280926
465 453611.190705
466 464564.346083
467 470986.245948
468 459839.82369
469 455505.839944
470 451775.990547
471 475051.493705
472 469633.827985
473 465391.145125
474 471936.346717
475 444628.917696
476 459225.076028
477 464680.820655
478 459935.691754
479 461636.553073
480 461966.551888
481 461982.249437
482 466990.811626
483 471322.236533
484 455971.653333
485 452889.424557
486 455061.433452
487 461354.107263
488 458098.21082
489 468873.898368
490 460565.196552
491 478766.918774
492 448328.714536
493 465740.223385
494 464422.017542
495 457653.663375
496 454856.814319
497 454882.669375
498 481761.993151
499 457145.191697
500 465329.812551
501 465147.048132
502 473471.762572
503 465288.920244
504 453227.482259
505 462311.842245
506 458635.92095
507 468402.16095
508 465968.085162
509 465161.234752
510 468699.068289
511 470045.811538
512 461409.533971
513 474097.64963
514 462442.39902
515 454182.43379
516 471827.212677
517 466038.740784
518 473358.122946
519 467735.132336
520 458232.527677
521 478210.153661
522 459762.342889
523 458654.305699
524 473680.687138
525 470392.3164
526 472443.336549
527 463205.347129
528 474853.641024
529 475104.063941
530 465639.198299
531 473514.177956
532 461024.609622
533 464330.214435
534 456937.079324
535 465633.447593
536 469662.126085
537 447181.096133
538 458468.263956
539 457297.952107
540 468826.747182
541 471991.644445
542 465361.511505
543 466987.098987
544 464725.773136
545 470786.025867
546 460983.533518
547 468815.453564
548 460423.06793
549 474773.848746
550 451772.25795
551 465573.14603
552 477383.36729
553 464142.574863
554 466200.830338
555 454218.658487
556 459339.316757
557 454916.310906
558 453878.252787
559 459978.549008
560 457129.321007
561 460834.015056
562 465716.699428
563 457153.670336
564 449405.914035
565 463869.118368
566 449283.302767
567 464171.18548
568 459197.799076
569 471311.69136
570 454087.308611
571 461298.418219
572 464667.946853
573 457555.979313
574 460850.572652
575 433423.261859
576 456718.582463
577 464153.106926
578 437581.283044
579 448148.929793
580 460374.422215
581 455060.414404
582 451208.611586
583 473470.361642
584 463679.469216
585 453912.361933
586 452756.522146
587 470120.862308
588 473014.515588
589 471552.049267
590 463116.596737
591 469480.120603
592 450332.036977
593 483867.468502
594 469566.024271
595 467145.013081
596 461690.038329
597 460483.107956
598 476541.673716
599 468247.318523
600 473265.892619
601 456453.686556
602 463709.709257
603 467599.19501
604 471496.362372
605 463754.97029
606 469383.345809
607 465564.566585
608 475548.229916
609 463735.941717
610 473300.513548
611 473742.847777
612 484548.137187
613 464809.716409
614 466639.312812
615 472125.227421
616 467339.098413
617 463107.828337
618 456953.041571
619 478368.384187
620 463660.471383
621 463094.397446
622 463589.99292
623 476173.720059
624 466405.878987
625 456465.609043
626 467778.338693
627 459726.089055
628 445549.723428
629 479225.339229
630 470997.301463
631 478417.60648
632 450966.920004
633 479136.071889
634 466448.398028
635 472046.075094
636 471178.718641
637 460036.553403
638 464776.838471
639 458833.09709
640 469614.373049
641 463085.069915
642 458264.499186
643 466350.570531
644 470370.011895
645 457394.891067
646 476751.663105
647 474340.886397
648 455215.92878
649 467463.782251
650 478146.077087
651 468621.764754
652 456973.934385
653 458398.611436
654 466546.362066
655 468536.641421
656 474793.232526
657 483098.871931
658 462166.257144
659 475145.12244
660 461511.485706
661 463278.488689
662 468733.407742
663 469334.050426
664 463044.475959
665 453449.203052
666 464098.035622
667 464529.663933
668 463352.846129
669 446285.310859
670 462229.934644
671 478725.200763
672 470531.660778
673 469149.635699
674 467860.661189
675 459395.997327
676 478820.963799
677 456919.42689
678 469812.249914
679 464304.90966
680 461242.333991
681 475341.752575
682 456699.057393
683 458340.124609
684 471321.770958
685 468934.253835
686 451962.848981
687 471491.904398
688 472663.771471
689 459922.05011
690 451860.993402
691 468668.767526
692 460786.591743
693 459854.746607
694 464469.641158
695 457079.429137
696 473486.635647
697 476573.912347
698 467309.771285
699 470412.082253
700 460684.452888
701 463289.411319
702 475759.998195
703 469558.66874
704 478638.909768
705 459485.567464
706 480538.970242
707 468437.972244
708 468450.032553
709 456420.610078
710 469012.702281
711 476977.104962
712 463973.61332
713 454142.015121
714 452572.988108
715 472821.745112
716 475083.870395
717 479637.02767
718 479722.538163
719 456950.776827
720 452539.820966
721 462072.44208
722 464941.869659
723 476105.888606
724 474591.235285
725 466804.018381
726 465139.247616
727 471492.164886
728 467075.550863
729 474420.529456
730 449924.43615
731 464595.51356
732 470107.510254
733 462793.394505
734 471613.773827
735 468727.408676
736 454365.296571
737 469954.005854
738 477228.467563
739 456942.512737
740 472240.498552
741 468669.033774
742 468480.147564
743 469296.920539
744 473870.841422
745 477198.814281
746 464708.944684
747 467497.980488
748 476633.065102
749 459104.580213
750 458274.54014
751 467719.774249
752 469179.689826
753 454907.252518
754 464406.202521
755 466766.864236
756 457885.589766
757 453217.248832
758 464031.637107
759 447421.806911
760 467934.788736
761 464176.318598
762 450788.009118
763 461224.078594
764 467775.76488
765 470869.524774
766 475006.828179
767 473778.105072
768 468609.373008
769 471453.852082
770 451478.881359
771 471626.136547
772 449289.192616
773 458811.179128
774 466133.31415
775 451501.886701
776 471710.137159
777 460685.909067
778 456155.445783
779 462248.872921
780 448202.822265
781 462385.041719
782 468157.146315
783 458624.496896
784 456143.428991
785 470408.246198
786 465771.973541
787 461855.899076
788 451224.124379
789 475193.456372
790 466882.537553
791 461895.927861
792 462328.226892
793 469317.025853
794 453788.742366
795 470858.235185
796 449826.04644
797 454706.883123
798 455854.796471
799 459496.797565
800 455045.677036
801 463264.924962
802 455829.796391
803 476061.64637
804 462503.605925
805 453297.559266
806 473377.792514
807 454108.099865
808 465490.560657
809 451257.742335
810 442621.371204
811 459448.530808
812 475422.554501
813 462587.408172
814 455656.929625
815 466566.607883
816 476498.431012
817 475815.903672
818 461539.890191
819 462204.974983
820 465254.250661
821 458069.147303
822 454539.524647
823 460652.154353
824 458004.224431
825 459211.051316
826 456564.011399
827 444732.992101
828 463475.902266
829 468761.051571
830 453680.081578
831 477300.769339
832 468289.585598
833 468394.05748
834 461794.300563
835 460034.138784
836 453389.655622
837 455553.330171
838 469687.602536
839 453923.199446
840 466937.729318
841 471262.913817
842 447057.657641
843 467207.130523
844 462092.736225
845 462627.794763
846 461176.284655
847 475582.59889
848 466650.068991
849 453745.423812
850 462634.149512
851 457385.106497
852 471276.392721
853 460934.553182
854 464164.198211
855 448067.814242
856 471651.308871
857 465080.987735
858 464406.17473
859 446338.288472
860 448734.807665
861 458944.126961
862 457704.598569
863 469807.064455
864 467445.654292
865 447539.899197
866 467769.391736
867 460935.525299
868 456404.199164
869 462267.492519
870 453147.005673
871 470105.719308
872 459690.804004
873 448694.26105
874 465365.475206
875 456024.045008
876 456174.457515
877 468490.584762
878 466773.894805
879 455446.100441
880 472646.947875
881 464462.529399
882 462755.612646
883 471936.611241
884 460158.900092
885 469989.578543
886 470838.373891
887 467486.8631
888 472408.602724
889 470812.303903
890 447186.961492
891 473020.543979
892 455579.444146
893 469235.508263
894 463155.873952
895 477724.22665
896 464809.410933
897 454604.007732
898 462683.018607
899 445679.404367
900 467001.426939
901 463289.554271
902 451355.848393
903 458343.062223
904 469121.322138
905 459773.52443
906 456325.253758
907 453414.38919
908 460312.962325
909 458741.43782
910 461642.54254
911 453691.436142
912 468826.300118
913 467420.212308
914 468278.311726
915 462568.103049
916 454469.413137
917 462841.549561
918 470635.899152
919 461273.07183
920 466655.26501
921 468389.486059
922 470114.162761
923 473808.717654
924 461929.946285
925 467662.797064
926 463104.285282
927 474954.944934
928 449927.22153
929 457892.308123
930 446755.310856
931 449369.815008
932 461142.418618
933 462964.165414
934 487480.521018
935 458194.298201
936 453167.020836
937 460831.221871
938 453151.194853
939 476157.028955
940 457305.026828
941 469707.697614
942 478375.460981
943 462670.134824
944 463219.23228
945 457774.308006
946 463138.217711
947 451426.193906
948 474363.562636
949 464322.293526
950 456600.056674
951 473759.128415
952 466623.019406
953 475928.874759
954 470779.174556
955 457838.229127
956 471683.63536
957 472964.54117
958 459248.168594
959 459366.537163
960 454955.724996
961 465285.960778
962 455717.753974
963 460784.597082
964 452196.751479
965 464719.682172
966 471505.50143
967 457827.861139
968 468178.074963
969 463880.521412
970 441393.949381
971 469204.158308
972 478050.401017
973 449756.541994
974 474673.368776
975 470400.066137
976 454455.380016
977 465717.727468
978 463146.224841
979 471929.487202
980 462143.091192
981 465331.089078
982 469471.433754
983 461753.671675
984 463248.861588
985 448542.735201
986 465996.896959
987 460785.096845
988 462788.290028
989 474709.041854
990 478698.24399
991 464035.752248
992 470365.510843
993 458052.141043
994 477615.763539
995 462922.728786
996 454920.67128
997 470217.569216
998 462820.239009
999 474566.987489

In [7]:
# In-sample figure: the `value` series of every episode of the long training run
# drawn side by side. The legend is suppressed because there is one line per
# episode. The figure is written to disk.
episode_value = pd.concat(
    [ep["value"] for ep in episodes_train],
    axis=1,
    ignore_index=True,
)
episode_value.plot(figsize=(12, 9), legend=False)
plt.savefig("样本内.png")



In [8]:
# Evaluation rollout after the long training run: reset the environment, let the
# agent act for a fixed 1445 steps, then plot the first column of the recorded
# data. `tmp` is consumed by the out-of-sample cells that follow.
# NOTE(review): unlike the training loops this never checks `done`; presumably
# deliberate, so the rollout continues past where training episodes stop and
# into the out-of-sample rows (row 1203 onward). Confirm that env.step is valid
# after an episode reports done.
# NOTE(review): this uses the same get_stochastic_policy call as training, so
# it may still explore; confirm whether a greedy evaluation was intended.
state = env.reset()
for i in range(1445):
    next_state, reward, done = env.step(agent.get_stochastic_policy(state))
    state = next_state

tmp = env.plot_data()
tmp.iloc[:, 0].plot(figsize=(16, 6))


Out[8]:
<matplotlib.axes._subplots.AxesSubplot at 0x24095b1e6d8>

In [9]:
# Out-of-sample figure: the first column of the evaluation rollout from row 1203
# onward, written to disk.
out_of_sample = tmp.iloc[1203:, 0]
out_of_sample.plot(figsize=(16, 6))
plt.savefig("样本外.png")



In [10]:
# Win rate over the out-of-sample rows (1203 onward): the share of steps whose
# reward is strictly positive. Zero-reward and NaN-reward steps count as
# non-wins but stay in the denominator.
test_data = tmp.iloc[1203:]

ratio = (test_data['reward'] > 0).mean()
print("胜率:", ratio)


胜率: 0.5

In [ ]:


In [ ]:


In [ ]: