added episilon for equality check
major layout changes in the network
This commit is contained in:
251
output_graphs.py
251
output_graphs.py
@@ -2,256 +2,7 @@ import re
|
|||||||
import pandas as pd
|
import pandas as pd
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
text = r"""INFO:__main__:step= 0 loss=1.3149878 acc=0.9093018
|
text = r"""
|
||||||
INFO:__main__:step= 100 loss=0.0089776 acc=0.9993286
|
|
||||||
INFO:__main__:step= 200 loss=0.0088239 acc=0.9996948
|
|
||||||
INFO:__main__:step= 300 loss=0.0075553 acc=0.9996948
|
|
||||||
INFO:__main__:step= 400 loss=0.0065352 acc=0.9995728
|
|
||||||
INFO:__main__:step= 500 loss=0.0053752 acc=0.9997559
|
|
||||||
INFO:__main__:step= 600 loss=0.0043060 acc=0.9998169
|
|
||||||
INFO:__main__:step= 700 loss=0.0045364 acc=0.9996338
|
|
||||||
INFO:__main__:step= 800 loss=0.0037988 acc=0.9996948
|
|
||||||
INFO:__main__:step= 900 loss=0.0037188 acc=0.9998779
|
|
||||||
INFO:__main__:step= 1000 loss=0.0034959 acc=0.9996338
|
|
||||||
INFO:__main__:step= 1100 loss=0.0032190 acc=0.9998169
|
|
||||||
INFO:__main__:step= 1200 loss=0.0033895 acc=1.0000000
|
|
||||||
INFO:__main__:step= 1300 loss=0.0031267 acc=0.9998779
|
|
||||||
INFO:__main__:step= 1400 loss=0.0028533 acc=0.9999390
|
|
||||||
INFO:__main__:step= 1500 loss=0.0024374 acc=0.9998779
|
|
||||||
INFO:__main__:step= 1600 loss=0.0025314 acc=0.9998779
|
|
||||||
INFO:__main__:step= 1700 loss=0.0024157 acc=1.0000000
|
|
||||||
INFO:__main__:step= 1800 loss=0.0019437 acc=1.0000000
|
|
||||||
INFO:__main__:step= 1900 loss=0.0019343 acc=0.9997559
|
|
||||||
INFO:__main__:step= 2000 loss=0.0020423 acc=0.9996948
|
|
||||||
INFO:__main__:step= 2100 loss=0.0025607 acc=0.9999390
|
|
||||||
INFO:__main__:step= 2200 loss=0.0020113 acc=0.9998779
|
|
||||||
INFO:__main__:step= 2300 loss=0.0017351 acc=0.9999390
|
|
||||||
INFO:__main__:step= 2400 loss=0.0017944 acc=0.9998169
|
|
||||||
INFO:__main__:step= 2500 loss=0.0016655 acc=0.9998169
|
|
||||||
INFO:__main__:step= 2600 loss=0.0016493 acc=0.9997559
|
|
||||||
INFO:__main__:step= 2700 loss=0.0016888 acc=1.0000000
|
|
||||||
INFO:__main__:step= 2800 loss=0.0017318 acc=0.9998779
|
|
||||||
INFO:__main__:step= 2900 loss=0.0016884 acc=0.9999390
|
|
||||||
INFO:__main__:step= 3000 loss=0.0012935 acc=0.9998779
|
|
||||||
INFO:__main__:step= 3100 loss=0.0011316 acc=1.0000000
|
|
||||||
INFO:__main__:step= 3200 loss=0.0015423 acc=0.9998169
|
|
||||||
INFO:__main__:step= 3300 loss=0.0008634 acc=0.9999390
|
|
||||||
INFO:__main__:step= 3400 loss=0.0005173 acc=1.0000000
|
|
||||||
INFO:__main__:step= 3500 loss=0.0005730 acc=0.9999390
|
|
||||||
INFO:__main__:step= 3600 loss=0.0007605 acc=1.0000000
|
|
||||||
INFO:__main__:step= 3700 loss=0.0006299 acc=0.9999390
|
|
||||||
INFO:__main__:step= 3800 loss=0.0006295 acc=0.9999390
|
|
||||||
INFO:__main__:step= 3900 loss=0.0003615 acc=1.0000000
|
|
||||||
INFO:__main__:step= 4000 loss=0.0004475 acc=1.0000000
|
|
||||||
INFO:__main__:step= 4100 loss=0.0005678 acc=0.9998779
|
|
||||||
INFO:__main__:step= 4200 loss=0.0003502 acc=1.0000000
|
|
||||||
INFO:__main__:step= 4300 loss=0.0005562 acc=0.9998779
|
|
||||||
INFO:__main__:step= 4400 loss=0.0005300 acc=0.9998779
|
|
||||||
INFO:__main__:step= 4500 loss=0.0004442 acc=0.9998779
|
|
||||||
INFO:__main__:step= 4600 loss=0.0002595 acc=1.0000000
|
|
||||||
INFO:__main__:step= 4700 loss=0.0003925 acc=1.0000000
|
|
||||||
INFO:__main__:step= 4800 loss=0.0003808 acc=1.0000000
|
|
||||||
INFO:__main__:step= 4900 loss=0.0003306 acc=0.9999390
|
|
||||||
INFO:__main__:step= 5000 loss=0.0003823 acc=0.9998779
|
|
||||||
INFO:__main__:step= 5100 loss=0.0002463 acc=1.0000000
|
|
||||||
INFO:__main__:step= 5200 loss=0.0003248 acc=0.9999390
|
|
||||||
INFO:__main__:step= 5300 loss=0.0002426 acc=1.0000000
|
|
||||||
INFO:__main__:step= 5400 loss=0.0002643 acc=1.0000000
|
|
||||||
INFO:__main__:step= 5500 loss=0.0003434 acc=1.0000000
|
|
||||||
INFO:__main__:step= 5600 loss=0.0003985 acc=0.9999390
|
|
||||||
INFO:__main__:step= 5700 loss=0.0004590 acc=0.9997559
|
|
||||||
INFO:__main__:step= 5800 loss=0.0002166 acc=1.0000000
|
|
||||||
INFO:__main__:step= 5900 loss=0.0002622 acc=0.9999390
|
|
||||||
INFO:__main__:step= 6000 loss=0.0003202 acc=1.0000000
|
|
||||||
INFO:__main__:step= 6100 loss=0.0003421 acc=1.0000000
|
|
||||||
INFO:__main__:step= 6200 loss=0.0004393 acc=0.9997559
|
|
||||||
INFO:__main__:step= 6300 loss=0.0002363 acc=1.0000000
|
|
||||||
INFO:__main__:step= 6400 loss=0.0000994 acc=1.0000000
|
|
||||||
INFO:__main__:step= 6500 loss=0.0001811 acc=1.0000000
|
|
||||||
INFO:__main__:step= 6600 loss=0.0003322 acc=0.9998779
|
|
||||||
INFO:__main__:step= 6700 loss=0.0002741 acc=0.9999390
|
|
||||||
INFO:__main__:step= 6800 loss=0.0002755 acc=0.9999390
|
|
||||||
INFO:__main__:step= 6900 loss=0.0001762 acc=0.9999390
|
|
||||||
INFO:__main__:step= 7000 loss=0.0002272 acc=0.9998779
|
|
||||||
INFO:__main__:step= 7100 loss=0.0001781 acc=1.0000000
|
|
||||||
INFO:__main__:step= 7200 loss=0.0002126 acc=1.0000000
|
|
||||||
INFO:__main__:step= 7300 loss=0.0002117 acc=0.9998779
|
|
||||||
INFO:__main__:step= 7400 loss=0.0001614 acc=1.0000000
|
|
||||||
INFO:__main__:step= 7500 loss=0.0002344 acc=1.0000000
|
|
||||||
INFO:__main__:step= 7600 loss=0.0003103 acc=0.9998169
|
|
||||||
INFO:__main__:step= 7700 loss=0.0001102 acc=1.0000000
|
|
||||||
INFO:__main__:step= 7800 loss=0.0001416 acc=0.9999390
|
|
||||||
INFO:__main__:step= 7900 loss=0.0001438 acc=0.9999390
|
|
||||||
INFO:__main__:step= 8000 loss=0.0002020 acc=0.9999390
|
|
||||||
INFO:__main__:step= 8100 loss=0.0001185 acc=1.0000000
|
|
||||||
INFO:__main__:step= 8200 loss=0.0001286 acc=0.9999390
|
|
||||||
INFO:__main__:step= 8300 loss=0.0001579 acc=0.9999390
|
|
||||||
INFO:__main__:step= 8400 loss=0.0002156 acc=0.9998779
|
|
||||||
INFO:__main__:step= 8500 loss=0.0001326 acc=1.0000000
|
|
||||||
INFO:__main__:step= 8600 loss=0.0001097 acc=0.9999390
|
|
||||||
INFO:__main__:step= 8700 loss=0.0000575 acc=1.0000000
|
|
||||||
INFO:__main__:step= 8800 loss=0.0001199 acc=1.0000000
|
|
||||||
INFO:__main__:step= 8900 loss=0.0001446 acc=0.9999390
|
|
||||||
INFO:__main__:step= 9000 loss=0.0002343 acc=0.9999390
|
|
||||||
INFO:__main__:step= 9100 loss=0.0000858 acc=1.0000000
|
|
||||||
INFO:__main__:step= 9200 loss=0.0001535 acc=1.0000000
|
|
||||||
INFO:__main__:step= 9300 loss=0.0001014 acc=1.0000000
|
|
||||||
INFO:__main__:step= 9400 loss=0.0000798 acc=1.0000000
|
|
||||||
INFO:__main__:step= 9500 loss=0.0001623 acc=1.0000000
|
|
||||||
INFO:__main__:step= 9600 loss=0.0000767 acc=1.0000000
|
|
||||||
INFO:__main__:step= 9700 loss=0.0002726 acc=0.9998779
|
|
||||||
INFO:__main__:step= 9800 loss=0.0001945 acc=0.9999390
|
|
||||||
INFO:__main__:step= 9900 loss=0.0002082 acc=0.9998779
|
|
||||||
INFO:__main__:step=10000 loss=0.0001320 acc=0.9999390
|
|
||||||
INFO:__main__:step=10100 loss=0.0002039 acc=0.9999390
|
|
||||||
INFO:__main__:step=10200 loss=0.0001236 acc=1.0000000
|
|
||||||
INFO:__main__:step=10300 loss=0.0001641 acc=0.9999390
|
|
||||||
INFO:__main__:step=10400 loss=0.0001063 acc=1.0000000
|
|
||||||
INFO:__main__:step=10500 loss=0.0001110 acc=1.0000000
|
|
||||||
INFO:__main__:step=10600 loss=0.0000836 acc=1.0000000
|
|
||||||
INFO:__main__:step=10700 loss=0.0001277 acc=0.9999390
|
|
||||||
INFO:__main__:step=10800 loss=0.0002018 acc=1.0000000
|
|
||||||
INFO:__main__:step=10900 loss=0.0001056 acc=1.0000000
|
|
||||||
INFO:__main__:step=11000 loss=0.0001680 acc=1.0000000
|
|
||||||
INFO:__main__:step=11100 loss=0.0001366 acc=0.9999390
|
|
||||||
INFO:__main__:step=11200 loss=0.0000372 acc=1.0000000
|
|
||||||
INFO:__main__:step=11300 loss=0.0001248 acc=0.9999390
|
|
||||||
INFO:__main__:step=11400 loss=0.0000712 acc=1.0000000
|
|
||||||
INFO:__main__:step=11500 loss=0.0001172 acc=1.0000000
|
|
||||||
INFO:__main__:step=11600 loss=0.0000921 acc=1.0000000
|
|
||||||
INFO:__main__:step=11700 loss=0.0000951 acc=1.0000000
|
|
||||||
INFO:__main__:step=11800 loss=0.0000610 acc=1.0000000
|
|
||||||
INFO:__main__:step=11900 loss=0.0000803 acc=1.0000000
|
|
||||||
INFO:__main__:step=12000 loss=0.0000788 acc=1.0000000
|
|
||||||
INFO:__main__:step=12100 loss=0.0001272 acc=1.0000000
|
|
||||||
INFO:__main__:step=12200 loss=0.0000690 acc=1.0000000
|
|
||||||
INFO:__main__:step=12300 loss=0.0001702 acc=1.0000000
|
|
||||||
INFO:__main__:step=12400 loss=0.0001313 acc=1.0000000
|
|
||||||
INFO:__main__:step=12500 loss=0.0000308 acc=1.0000000
|
|
||||||
INFO:__main__:step=12600 loss=0.0000845 acc=1.0000000
|
|
||||||
INFO:__main__:step=12700 loss=0.0000732 acc=1.0000000
|
|
||||||
INFO:__main__:step=12800 loss=0.0000183 acc=1.0000000
|
|
||||||
INFO:__main__:step=12900 loss=0.0000300 acc=1.0000000
|
|
||||||
INFO:__main__:step=13000 loss=0.0001123 acc=0.9999390
|
|
||||||
INFO:__main__:step=13100 loss=0.0000594 acc=1.0000000
|
|
||||||
INFO:__main__:step=13200 loss=0.0000668 acc=1.0000000
|
|
||||||
INFO:__main__:step=13300 loss=0.0000843 acc=0.9999390
|
|
||||||
INFO:__main__:step=13400 loss=0.0000407 acc=1.0000000
|
|
||||||
INFO:__main__:step=13500 loss=0.0000463 acc=1.0000000
|
|
||||||
INFO:__main__:step=13600 loss=0.0001134 acc=1.0000000
|
|
||||||
INFO:__main__:step=13700 loss=0.0000711 acc=1.0000000
|
|
||||||
INFO:__main__:step=13800 loss=0.0000646 acc=1.0000000
|
|
||||||
INFO:__main__:step=13900 loss=0.0000137 acc=1.0000000
|
|
||||||
INFO:__main__:step=14000 loss=0.0000803 acc=1.0000000
|
|
||||||
INFO:__main__:step=14100 loss=0.0001049 acc=1.0000000
|
|
||||||
INFO:__main__:step=14200 loss=0.0000583 acc=1.0000000
|
|
||||||
INFO:__main__:step=14300 loss=0.0000532 acc=1.0000000
|
|
||||||
INFO:__main__:step=14400 loss=0.0000281 acc=1.0000000
|
|
||||||
INFO:__main__:step=14500 loss=0.0000641 acc=1.0000000
|
|
||||||
INFO:__main__:step=14600 loss=0.0000408 acc=1.0000000
|
|
||||||
INFO:__main__:step=14700 loss=0.0000708 acc=1.0000000
|
|
||||||
INFO:__main__:step=14800 loss=0.0000410 acc=1.0000000
|
|
||||||
INFO:__main__:step=14900 loss=0.0000047 acc=1.0000000
|
|
||||||
INFO:__main__:step=15000 loss=0.0000676 acc=1.0000000
|
|
||||||
INFO:__main__:step=15100 loss=0.0001132 acc=0.9999390
|
|
||||||
INFO:__main__:step=15200 loss=0.0000244 acc=1.0000000
|
|
||||||
INFO:__main__:step=15300 loss=0.0000069 acc=1.0000000
|
|
||||||
INFO:__main__:step=15400 loss=0.0000572 acc=0.9999390
|
|
||||||
INFO:__main__:step=15500 loss=0.0001351 acc=1.0000000
|
|
||||||
INFO:__main__:step=15600 loss=0.0000896 acc=0.9999390
|
|
||||||
INFO:__main__:step=15700 loss=0.0000167 acc=1.0000000
|
|
||||||
INFO:__main__:step=15800 loss=0.0000382 acc=1.0000000
|
|
||||||
INFO:__main__:step=15900 loss=0.0000231 acc=1.0000000
|
|
||||||
INFO:__main__:step=16000 loss=0.0000428 acc=1.0000000
|
|
||||||
INFO:__main__:step=16100 loss=0.0000390 acc=1.0000000
|
|
||||||
INFO:__main__:step=16200 loss=0.0000236 acc=1.0000000
|
|
||||||
INFO:__main__:step=16300 loss=0.0001501 acc=0.9999390
|
|
||||||
INFO:__main__:step=16400 loss=0.0000269 acc=1.0000000
|
|
||||||
INFO:__main__:step=16500 loss=0.0000121 acc=1.0000000
|
|
||||||
INFO:__main__:step=16600 loss=0.0000089 acc=1.0000000
|
|
||||||
INFO:__main__:step=16700 loss=0.0000335 acc=1.0000000
|
|
||||||
INFO:__main__:step=16800 loss=0.0000302 acc=1.0000000
|
|
||||||
INFO:__main__:step=16900 loss=0.0000183 acc=1.0000000
|
|
||||||
INFO:__main__:step=17000 loss=0.0000311 acc=1.0000000
|
|
||||||
INFO:__main__:step=17100 loss=0.0000031 acc=1.0000000
|
|
||||||
INFO:__main__:step=17200 loss=0.0001091 acc=1.0000000
|
|
||||||
INFO:__main__:step=17300 loss=0.0000030 acc=1.0000000
|
|
||||||
INFO:__main__:step=17400 loss=0.0000742 acc=1.0000000
|
|
||||||
INFO:__main__:step=17500 loss=0.0000403 acc=1.0000000
|
|
||||||
INFO:__main__:step=17600 loss=0.0000163 acc=1.0000000
|
|
||||||
INFO:__main__:step=17700 loss=0.0000700 acc=1.0000000
|
|
||||||
INFO:__main__:step=17800 loss=0.0000477 acc=1.0000000
|
|
||||||
INFO:__main__:step=17900 loss=0.0000113 acc=1.0000000
|
|
||||||
INFO:__main__:step=18000 loss=0.0000013 acc=1.0000000
|
|
||||||
INFO:__main__:step=18100 loss=0.0000353 acc=1.0000000
|
|
||||||
INFO:__main__:step=18200 loss=0.0000010 acc=1.0000000
|
|
||||||
INFO:__main__:step=18300 loss=0.0000175 acc=1.0000000
|
|
||||||
INFO:__main__:step=18400 loss=0.0000156 acc=1.0000000
|
|
||||||
INFO:__main__:step=18500 loss=0.0000024 acc=1.0000000
|
|
||||||
INFO:__main__:step=18600 loss=0.0000125 acc=1.0000000
|
|
||||||
INFO:__main__:step=18700 loss=0.0001110 acc=0.9999390
|
|
||||||
INFO:__main__:step=18800 loss=0.0000066 acc=1.0000000
|
|
||||||
INFO:__main__:step=18900 loss=0.0000136 acc=1.0000000
|
|
||||||
INFO:__main__:step=19000 loss=0.0000629 acc=1.0000000
|
|
||||||
INFO:__main__:step=19100 loss=0.0000235 acc=1.0000000
|
|
||||||
INFO:__main__:step=19200 loss=0.0000301 acc=1.0000000
|
|
||||||
INFO:__main__:step=19300 loss=0.0000246 acc=1.0000000
|
|
||||||
INFO:__main__:step=19400 loss=0.0000824 acc=0.9999390
|
|
||||||
INFO:__main__:step=19500 loss=0.0000525 acc=1.0000000
|
|
||||||
INFO:__main__:step=19600 loss=0.0000315 acc=1.0000000
|
|
||||||
INFO:__main__:step=19700 loss=0.0000004 acc=1.0000000
|
|
||||||
INFO:__main__:step=19800 loss=0.0000337 acc=1.0000000
|
|
||||||
INFO:__main__:step=19900 loss=0.0000544 acc=1.0000000
|
|
||||||
INFO:__main__:step=20000 loss=0.0000134 acc=1.0000000
|
|
||||||
INFO:__main__:step=20100 loss=0.0000454 acc=1.0000000
|
|
||||||
INFO:__main__:step=20200 loss=0.0000668 acc=1.0000000
|
|
||||||
INFO:__main__:step=20300 loss=0.0000662 acc=1.0000000
|
|
||||||
INFO:__main__:step=20400 loss=0.0000337 acc=1.0000000
|
|
||||||
INFO:__main__:step=20500 loss=0.0000238 acc=1.0000000
|
|
||||||
INFO:__main__:step=20600 loss=0.0000206 acc=1.0000000
|
|
||||||
INFO:__main__:step=20700 loss=0.0000003 acc=1.0000000
|
|
||||||
INFO:__main__:step=20800 loss=0.0000557 acc=1.0000000
|
|
||||||
INFO:__main__:step=20900 loss=0.0000227 acc=1.0000000
|
|
||||||
INFO:__main__:step=21000 loss=0.0000002 acc=1.0000000
|
|
||||||
INFO:__main__:step=21100 loss=0.0000290 acc=1.0000000
|
|
||||||
INFO:__main__:step=21200 loss=0.0000373 acc=1.0000000
|
|
||||||
INFO:__main__:step=21300 loss=0.0000019 acc=1.0000000
|
|
||||||
INFO:__main__:step=21400 loss=0.0000635 acc=1.0000000
|
|
||||||
INFO:__main__:step=21500 loss=0.0000073 acc=1.0000000
|
|
||||||
INFO:__main__:step=21600 loss=0.0000388 acc=1.0000000
|
|
||||||
INFO:__main__:step=21700 loss=0.0000002 acc=1.0000000
|
|
||||||
INFO:__main__:step=21800 loss=0.0000169 acc=1.0000000
|
|
||||||
INFO:__main__:step=21900 loss=0.0000031 acc=1.0000000
|
|
||||||
INFO:__main__:step=22000 loss=0.0000075 acc=1.0000000
|
|
||||||
INFO:__main__:step=22100 loss=0.0000001 acc=1.0000000
|
|
||||||
INFO:__main__:step=22200 loss=0.0000096 acc=1.0000000
|
|
||||||
INFO:__main__:step=22300 loss=0.0000068 acc=1.0000000
|
|
||||||
INFO:__main__:step=22400 loss=0.0000303 acc=1.0000000
|
|
||||||
INFO:__main__:step=22500 loss=0.0000005 acc=1.0000000
|
|
||||||
INFO:__main__:step=22600 loss=0.0000111 acc=1.0000000
|
|
||||||
INFO:__main__:step=22700 loss=0.0000023 acc=1.0000000
|
|
||||||
INFO:__main__:step=22800 loss=0.0000003 acc=1.0000000
|
|
||||||
INFO:__main__:step=22900 loss=0.0000424 acc=1.0000000
|
|
||||||
INFO:__main__:step=23000 loss=0.0000186 acc=1.0000000
|
|
||||||
INFO:__main__:step=23100 loss=0.0000004 acc=1.0000000
|
|
||||||
INFO:__main__:step=23200 loss=0.0000085 acc=1.0000000
|
|
||||||
INFO:__main__:step=23300 loss=0.0000350 acc=1.0000000
|
|
||||||
INFO:__main__:step=23400 loss=0.0000005 acc=1.0000000
|
|
||||||
INFO:__main__:step=23500 loss=0.0000538 acc=0.9999390
|
|
||||||
INFO:__main__:step=23600 loss=0.0000021 acc=1.0000000
|
|
||||||
INFO:__main__:step=23700 loss=0.0000365 acc=1.0000000
|
|
||||||
INFO:__main__:step=23800 loss=0.0000281 acc=1.0000000
|
|
||||||
INFO:__main__:step=23900 loss=0.0000091 acc=1.0000000
|
|
||||||
INFO:__main__:step=24000 loss=0.0000045 acc=1.0000000
|
|
||||||
INFO:__main__:step=24100 loss=0.0000023 acc=1.0000000
|
|
||||||
INFO:__main__:step=24200 loss=0.0000197 acc=1.0000000
|
|
||||||
INFO:__main__:step=24300 loss=0.0000013 acc=1.0000000
|
|
||||||
INFO:__main__:step=24400 loss=0.0000174 acc=1.0000000
|
|
||||||
INFO:__main__:step=24500 loss=0.0000380 acc=1.0000000
|
|
||||||
INFO:__main__:step=24600 loss=0.0000105 acc=1.0000000
|
|
||||||
INFO:__main__:step=24700 loss=0.0000001 acc=1.0000000
|
|
||||||
INFO:__main__:step=24800 loss=0.0000193 acc=1.0000000
|
|
||||||
INFO:__main__:step=24900 loss=0.0000280 acc=1.0000000
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
pattern = re.compile(r"step=\s*(\d+)\s+loss=([0-9.]+)\s+acc=([0-9.]+)")
|
pattern = re.compile(r"step=\s*(\d+)\s+loss=([0-9.]+)\s+acc=([0-9.]+)")
|
||||||
|
|||||||
@@ -3,32 +3,35 @@ from torch import nn
|
|||||||
|
|
||||||
# 2) Number "embedding" network: R -> R^d
|
# 2) Number "embedding" network: R -> R^d
|
||||||
class NumberEmbedder(nn.Module):
|
class NumberEmbedder(nn.Module):
|
||||||
def __init__(self, d=8):
|
def __init__(self, d=4, hidden=16):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.net = nn.Sequential(
|
self.net = nn.Sequential(
|
||||||
nn.Linear(1, 16),
|
nn.Linear(1, hidden),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
nn.Linear(16, d),
|
nn.Linear(hidden, d),
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return self.net(x)
|
return self.net(x)
|
||||||
|
|
||||||
# 3) Comparator head: takes (ea, eb) -> logit for "a > b"
|
# 3) Comparator head: takes (ea, eb, e) -> logit for "a > b"
|
||||||
class PairwiseComparator(nn.Module):
|
class PairwiseComparator(nn.Module):
|
||||||
def __init__(self, d=8):
|
def __init__(self, d=4, hidden=16, k=1.0):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.embed = NumberEmbedder(d)
|
self.log_k = nn.Parameter(torch.tensor([k]))
|
||||||
|
self.embed = NumberEmbedder(d, hidden)
|
||||||
self.head = nn.Sequential(
|
self.head = nn.Sequential(
|
||||||
nn.Linear(2 * d + 1, 16),
|
nn.Linear(d, hidden),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
nn.Linear(16, 1),
|
nn.Linear(hidden, hidden),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Linear(hidden, 1),
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, a, b):
|
def forward(self, a, b):
|
||||||
ea = self.embed(a)
|
# trying to force antisym here: h(a,b)=−h(b,a)
|
||||||
eb = self.embed(b)
|
phi = self.head(self.embed(a-b))
|
||||||
delta_ab = a - b
|
phi_neg = self.head(self.embed(b-a))
|
||||||
x = torch.cat([ea, eb, delta_ab], dim=-1)
|
logit = phi - phi_neg
|
||||||
|
|
||||||
return self.head(x) # logits
|
return (self.log_k ** 2) * logit
|
||||||
@@ -13,12 +13,12 @@ import pairwise_comp_nn as comp_nn
|
|||||||
DEVICE = torch.accelerator.current_accelerator() if torch.accelerator.is_available() else "cpu"
|
DEVICE = torch.accelerator.current_accelerator() if torch.accelerator.is_available() else "cpu"
|
||||||
|
|
||||||
# Valves
|
# Valves
|
||||||
DIMENSIONS = 1
|
DIMENSIONS = 2
|
||||||
TRAIN_STEPS = 10000
|
TRAIN_STEPS = 5000
|
||||||
TRAIN_BATCHSZ = 8192
|
TRAIN_BATCHSZ = 8192
|
||||||
TRAIN_PROGRESS = 100
|
TRAIN_PROGRESS = 10
|
||||||
BATCH_LOWER = -512.0
|
BATCH_LOWER = -100.0
|
||||||
BATCH_UPPER = 512.0
|
BATCH_UPPER = 100.0
|
||||||
DO_VERBOSE_EARLY_TRAIN = False
|
DO_VERBOSE_EARLY_TRAIN = False
|
||||||
MODEL_PATH = "./files/pwcomp.model"
|
MODEL_PATH = "./files/pwcomp.model"
|
||||||
LOGGING_PATH = "./files/output.log"
|
LOGGING_PATH = "./files/output.log"
|
||||||
@@ -46,17 +46,15 @@ def plt_embeddings(model: comp_nn.PairwiseComparator):
|
|||||||
|
|
||||||
for i in range(embeddings.shape[1]):
|
for i in range(embeddings.shape[1]):
|
||||||
plt.plot(xs.squeeze(), embeddings[:, i], label=f"dim {i}")
|
plt.plot(xs.squeeze(), embeddings[:, i], label=f"dim {i}")
|
||||||
plt.legend()
|
|
||||||
plt.savefig(EMBED_CHART_PATH)
|
|
||||||
plt.show()
|
|
||||||
|
|
||||||
|
plt.legend()
|
||||||
|
plt.savefig(EMBED_CHART_PATH)
|
||||||
|
#plt.show()
|
||||||
csv_data = list(zip(xs.squeeze().tolist(), embeddings.tolist()))
|
csv_data = list(zip(xs.squeeze().tolist(), embeddings.tolist()))
|
||||||
with open(file=EMBEDDINGS_DATA, mode="w", newline='') as f:
|
with open(file=EMBEDDINGS_DATA, mode="w", newline='') as f:
|
||||||
csv_file = csv.writer(f)
|
csv_file = csv.writer(f)
|
||||||
csv_file.writerows(csv_data)
|
csv_file.writerows(csv_data)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def get_torch_info():
|
def get_torch_info():
|
||||||
log.info("PyTorch Version: %s", torch.__version__)
|
log.info("PyTorch Version: %s", torch.__version__)
|
||||||
log.info("HIP Version: %s", torch.version.hip)
|
log.info("HIP Version: %s", torch.version.hip)
|
||||||
@@ -78,13 +76,10 @@ def sample_batch(batch_size: int, low=BATCH_LOWER, high=BATCH_UPPER):
|
|||||||
a = (high - low) * torch.rand(batch_size, 1) + low
|
a = (high - low) * torch.rand(batch_size, 1) + low
|
||||||
b = (high - low) * torch.rand(batch_size, 1) + low
|
b = (high - low) * torch.rand(batch_size, 1) + low
|
||||||
|
|
||||||
# train for if a > b
|
epsi = 1e-4
|
||||||
y = (a > b).float()
|
y = torch.where(a > b + epsi, 1.0,
|
||||||
|
torch.where(a < b - epsi, 0.0, 0.5))
|
||||||
|
|
||||||
# removed but left for my notes; it seems training for equality hurts classifing results that are ~eq
|
|
||||||
# when trained only on "if a > b => y", the model produces more accurate results when classifing if things are equal (~.5 prob).
|
|
||||||
# eq = (a == b).float()
|
|
||||||
# y = gt + 0.5 * eq
|
|
||||||
return a, b, y
|
return a, b, y
|
||||||
|
|
||||||
def training_entry():
|
def training_entry():
|
||||||
@@ -150,7 +145,7 @@ def infer_entry():
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
probs = torch.sigmoid(model(a, b))
|
probs = torch.sigmoid(model(a, b))
|
||||||
|
|
||||||
log.info(f"Output probabilities for {pairs.count} pairs")
|
log.info(f"Output probabilities for {pairs.__len__()} pairs")
|
||||||
for (x, y), p in zip(pairs, probs):
|
for (x, y), p in zip(pairs, probs):
|
||||||
log.info(f"P({x} > {y}) = {p.item():.3f}")
|
log.info(f"P({x} > {y}) = {p.item():.3f}")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user