Compare commits: bf15e7b6c8...trunk
19 Commits
| SHA1 |
| --- |
| 9af36e7145 |
| 23fddbe5b9 |
| acbccebb2c |
| edf8d46123 |
| 4f500e8b4c |
| 921e24b451 |
| d46712ff53 |
| cd72cd7052 |
| 755161c152 |
| 1d70935b64 |
| 0e2098ceec |
| cfef24921d |
| 9ea8ef3458 |
| 6e31865a84 |
| 997303028e |
| c3fbc44a34 |
| 5e5ad1bc20 |
| 0d6a92823a |
| cfcec07b9c |
.gitignore (vendored, 4 changed lines)

```diff
@@ -1,4 +1,6 @@
__pycache__/
files/
.venv/
*.png
*.pth
*.model
*.log
```
README.md

````diff
@@ -25,11 +25,11 @@ source .venv/bin/activate
For instructions on installing PyTorch, refer to the [PyTorch Home page]

```bash
pip3 install numpy
# use the nvidia CUDA or CPU only packages if required.
# I use the ROCm packages, so the repo uses the ROCm packages.
pip3 install numpy pandas matplotlib
# I use the ROCm packages
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/rocm6.4
pip3 install pandas matplotlib
# if you need the CPU only package
# pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu
```

## Running the code
````
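After installing, a quick way to confirm which build is active (a suggested check, not part of the original README; `torch.version.hip` is `None` on non-ROCm builds):

```bash
python3 -c "import torch; print(torch.__version__, torch.version.hip, torch.cuda.is_available())"
```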
output_graphs.py (deleted, 286 lines)

```diff
@@ -1,286 +0,0 @@
import re
import pandas as pd
import matplotlib.pyplot as plt

text = r"""INFO:__main__:step= 0 loss=1.3149878 acc=0.9093018
INFO:__main__:step= 100 loss=0.0061767 acc=0.9990845
INFO:__main__:step= 200 loss=0.0035383 acc=0.9995117
INFO:__main__:step= 300 loss=0.0024436 acc=0.9998169
INFO:__main__:step= 400 loss=0.0018407 acc=0.9997559
INFO:__main__:step= 500 loss=0.0015498 acc=0.9999390
INFO:__main__:step= 600 loss=0.0011456 acc=0.9998779
INFO:__main__:step= 700 loss=0.0012764 acc=0.9998169
INFO:__main__:step= 800 loss=0.0009753 acc=0.9999390
INFO:__main__:step= 900 loss=0.0008533 acc=0.9998169
INFO:__main__:step= 1000 loss=0.0009977 acc=0.9998779
INFO:__main__:step= 1100 loss=0.0006450 acc=1.0000000
INFO:__main__:step= 1200 loss=0.0009882 acc=0.9999390
INFO:__main__:step= 1300 loss=0.0007516 acc=0.9998169
INFO:__main__:step= 1400 loss=0.0006381 acc=0.9999390
INFO:__main__:step= 1500 loss=0.0005117 acc=1.0000000
INFO:__main__:step= 1600 loss=0.0006124 acc=0.9999390
INFO:__main__:step= 1700 loss=0.0005856 acc=0.9999390
INFO:__main__:step= 1800 loss=0.0004643 acc=1.0000000
INFO:__main__:step= 1900 loss=0.0005289 acc=1.0000000
INFO:__main__:step= 2000 loss=0.0006825 acc=0.9998779
INFO:__main__:step= 2100 loss=0.0006999 acc=0.9999390
INFO:__main__:step= 2200 loss=0.0004827 acc=0.9999390
INFO:__main__:step= 2300 loss=0.0004767 acc=1.0000000
INFO:__main__:step= 2400 loss=0.0005529 acc=0.9998779
INFO:__main__:step= 2500 loss=0.0005013 acc=0.9999390
INFO:__main__:step= 2600 loss=0.0004871 acc=0.9998779
INFO:__main__:step= 2700 loss=0.0006320 acc=0.9999390
INFO:__main__:step= 2800 loss=0.0005142 acc=1.0000000
INFO:__main__:step= 2900 loss=0.0004747 acc=0.9999390
INFO:__main__:step= 3000 loss=0.0003393 acc=1.0000000
INFO:__main__:step= 3100 loss=0.0002169 acc=0.9999390
INFO:__main__:step= 3200 loss=0.0004685 acc=1.0000000
INFO:__main__:step= 3300 loss=0.0006188 acc=1.0000000
INFO:__main__:step= 3400 loss=0.0002341 acc=1.0000000
INFO:__main__:step= 3500 loss=0.0002824 acc=1.0000000
INFO:__main__:step= 3600 loss=0.0004709 acc=1.0000000
INFO:__main__:step= 3700 loss=0.0004435 acc=0.9999390
INFO:__main__:step= 3800 loss=0.0004570 acc=0.9999390
INFO:__main__:step= 3900 loss=0.0002688 acc=1.0000000
INFO:__main__:step= 4000 loss=0.0003271 acc=1.0000000
INFO:__main__:step= 4100 loss=0.0003988 acc=0.9999390
INFO:__main__:step= 4200 loss=0.0002737 acc=1.0000000
INFO:__main__:step= 4300 loss=0.0004687 acc=1.0000000
INFO:__main__:step= 4400 loss=0.0004002 acc=0.9999390
INFO:__main__:step= 4500 loss=0.0003822 acc=0.9998779
INFO:__main__:step= 4600 loss=0.0002028 acc=1.0000000
INFO:__main__:step= 4700 loss=0.0003800 acc=1.0000000
INFO:__main__:step= 4800 loss=0.0003447 acc=1.0000000
INFO:__main__:step= 4900 loss=0.0003252 acc=0.9999390
INFO:__main__:step= 5000 loss=0.0003673 acc=0.9999390
INFO:__main__:step= 5100 loss=0.0002339 acc=1.0000000
INFO:__main__:step= 5200 loss=0.0003250 acc=0.9999390
INFO:__main__:step= 5300 loss=0.0002328 acc=1.0000000
INFO:__main__:step= 5400 loss=0.0003054 acc=1.0000000
INFO:__main__:step= 5500 loss=0.0003867 acc=1.0000000
INFO:__main__:step= 5600 loss=0.0004573 acc=0.9999390
INFO:__main__:step= 5700 loss=0.0005030 acc=0.9999390
INFO:__main__:step= 5800 loss=0.0002435 acc=1.0000000
INFO:__main__:step= 5900 loss=0.0003278 acc=1.0000000
INFO:__main__:step= 6000 loss=0.0003972 acc=0.9999390
INFO:__main__:step= 6100 loss=0.0004111 acc=0.9999390
INFO:__main__:step= 6200 loss=0.0004616 acc=0.9998779
INFO:__main__:step= 6300 loss=0.0002833 acc=1.0000000
INFO:__main__:step= 6400 loss=0.0001403 acc=1.0000000
INFO:__main__:step= 6500 loss=0.0002128 acc=1.0000000
INFO:__main__:step= 6600 loss=0.0003678 acc=1.0000000
INFO:__main__:step= 6700 loss=0.0003675 acc=0.9999390
INFO:__main__:step= 6800 loss=0.0003362 acc=0.9999390
INFO:__main__:step= 6900 loss=0.0002456 acc=1.0000000
INFO:__main__:step= 7000 loss=0.0002907 acc=1.0000000
INFO:__main__:step= 7100 loss=0.0002552 acc=1.0000000
INFO:__main__:step= 7200 loss=0.0003215 acc=1.0000000
INFO:__main__:step= 7300 loss=0.0002414 acc=0.9999390
INFO:__main__:step= 7400 loss=0.0002210 acc=0.9999390
INFO:__main__:step= 7500 loss=0.0003406 acc=1.0000000
INFO:__main__:step= 7600 loss=0.0003976 acc=0.9999390
INFO:__main__:step= 7700 loss=0.0001889 acc=1.0000000
INFO:__main__:step= 7800 loss=0.0001913 acc=1.0000000
INFO:__main__:step= 7900 loss=0.0002028 acc=0.9999390
INFO:__main__:step= 8000 loss=0.0002912 acc=1.0000000
INFO:__main__:step= 8100 loss=0.0001934 acc=1.0000000
INFO:__main__:step= 8200 loss=0.0001729 acc=0.9999390
INFO:__main__:step= 8300 loss=0.0002534 acc=1.0000000
INFO:__main__:step= 8400 loss=0.0002508 acc=1.0000000
INFO:__main__:step= 8500 loss=0.0002167 acc=1.0000000
INFO:__main__:step= 8600 loss=0.0001678 acc=0.9999390
INFO:__main__:step= 8700 loss=0.0001330 acc=1.0000000
INFO:__main__:step= 8800 loss=0.0002283 acc=1.0000000
INFO:__main__:step= 8900 loss=0.0001854 acc=1.0000000
INFO:__main__:step= 9000 loss=0.0003707 acc=1.0000000
INFO:__main__:step= 9100 loss=0.0001784 acc=0.9999390
INFO:__main__:step= 9200 loss=0.0002114 acc=1.0000000
INFO:__main__:step= 9300 loss=0.0002016 acc=1.0000000
INFO:__main__:step= 9400 loss=0.0001510 acc=1.0000000
INFO:__main__:step= 9500 loss=0.0002751 acc=1.0000000
INFO:__main__:step= 9600 loss=0.0001933 acc=1.0000000
INFO:__main__:step= 9700 loss=0.0002801 acc=0.9999390
INFO:__main__:step= 9800 loss=0.0002744 acc=1.0000000
INFO:__main__:step= 9900 loss=0.0002888 acc=0.9998779
INFO:__main__:step=10000 loss=0.0002251 acc=0.9999390
INFO:__main__:step=10100 loss=0.0002925 acc=0.9999390
INFO:__main__:step=10200 loss=0.0002304 acc=1.0000000
INFO:__main__:step=10300 loss=0.0002787 acc=0.9999390
INFO:__main__:step=10400 loss=0.0002299 acc=0.9999390
INFO:__main__:step=10500 loss=0.0002260 acc=1.0000000
INFO:__main__:step=10600 loss=0.0002000 acc=1.0000000
INFO:__main__:step=10700 loss=0.0002608 acc=0.9999390
INFO:__main__:step=10800 loss=0.0002861 acc=1.0000000
INFO:__main__:step=10900 loss=0.0001996 acc=1.0000000
INFO:__main__:step=11000 loss=0.0002830 acc=0.9999390
INFO:__main__:step=11100 loss=0.0002845 acc=0.9999390
INFO:__main__:step=11200 loss=0.0001409 acc=1.0000000
INFO:__main__:step=11300 loss=0.0001962 acc=0.9999390
INFO:__main__:step=11400 loss=0.0002022 acc=1.0000000
INFO:__main__:step=11500 loss=0.0003032 acc=1.0000000
INFO:__main__:step=11600 loss=0.0002062 acc=1.0000000
INFO:__main__:step=11700 loss=0.0002120 acc=1.0000000
INFO:__main__:step=11800 loss=0.0001484 acc=1.0000000
INFO:__main__:step=11900 loss=0.0001639 acc=1.0000000
INFO:__main__:step=12000 loss=0.0001864 acc=1.0000000
INFO:__main__:step=12100 loss=0.0002334 acc=1.0000000
INFO:__main__:step=12200 loss=0.0001641 acc=1.0000000
INFO:__main__:step=12300 loss=0.0003251 acc=0.9998779
INFO:__main__:step=12400 loss=0.0002605 acc=1.0000000
INFO:__main__:step=12500 loss=0.0001344 acc=1.0000000
INFO:__main__:step=12600 loss=0.0002226 acc=0.9999390
INFO:__main__:step=12700 loss=0.0002189 acc=0.9999390
INFO:__main__:step=12800 loss=0.0001012 acc=1.0000000
INFO:__main__:step=12900 loss=0.0001505 acc=1.0000000
INFO:__main__:step=13000 loss=0.0002257 acc=0.9999390
INFO:__main__:step=13100 loss=0.0001643 acc=1.0000000
INFO:__main__:step=13200 loss=0.0001547 acc=0.9999390
INFO:__main__:step=13300 loss=0.0002164 acc=1.0000000
INFO:__main__:step=13400 loss=0.0001538 acc=1.0000000
INFO:__main__:step=13500 loss=0.0001582 acc=1.0000000
INFO:__main__:step=13600 loss=0.0002629 acc=0.9999390
INFO:__main__:step=13700 loss=0.0002293 acc=1.0000000
INFO:__main__:step=13800 loss=0.0001947 acc=1.0000000
INFO:__main__:step=13900 loss=0.0001451 acc=1.0000000
INFO:__main__:step=14000 loss=0.0002371 acc=1.0000000
INFO:__main__:step=14100 loss=0.0003281 acc=1.0000000
INFO:__main__:step=14200 loss=0.0002205 acc=1.0000000
INFO:__main__:step=14300 loss=0.0001904 acc=1.0000000
INFO:__main__:step=14400 loss=0.0001126 acc=1.0000000
INFO:__main__:step=14500 loss=0.0002144 acc=1.0000000
INFO:__main__:step=14600 loss=0.0001922 acc=1.0000000
INFO:__main__:step=14700 loss=0.0002118 acc=1.0000000
INFO:__main__:step=14800 loss=0.0001527 acc=1.0000000
INFO:__main__:step=14900 loss=0.0000752 acc=1.0000000
INFO:__main__:step=15000 loss=0.0002345 acc=1.0000000
INFO:__main__:step=15100 loss=0.0002119 acc=1.0000000
INFO:__main__:step=15200 loss=0.0001223 acc=1.0000000
INFO:__main__:step=15300 loss=0.0000772 acc=1.0000000
INFO:__main__:step=15400 loss=0.0001805 acc=0.9999390
INFO:__main__:step=15500 loss=0.0003057 acc=1.0000000
INFO:__main__:step=15600 loss=0.0002293 acc=1.0000000
INFO:__main__:step=15700 loss=0.0000739 acc=1.0000000
INFO:__main__:step=15800 loss=0.0001586 acc=1.0000000
INFO:__main__:step=15900 loss=0.0001513 acc=1.0000000
INFO:__main__:step=16000 loss=0.0001348 acc=1.0000000
INFO:__main__:step=16100 loss=0.0002099 acc=1.0000000
INFO:__main__:step=16200 loss=0.0001405 acc=1.0000000
INFO:__main__:step=16300 loss=0.0003015 acc=0.9998779
INFO:__main__:step=16400 loss=0.0000603 acc=1.0000000
INFO:__main__:step=16500 loss=0.0001273 acc=1.0000000
INFO:__main__:step=16600 loss=0.0001151 acc=1.0000000
INFO:__main__:step=16700 loss=0.0001440 acc=1.0000000
INFO:__main__:step=16800 loss=0.0002359 acc=1.0000000
INFO:__main__:step=16900 loss=0.0002146 acc=1.0000000
INFO:__main__:step=17000 loss=0.0002382 acc=1.0000000
INFO:__main__:step=17100 loss=0.0000885 acc=1.0000000
INFO:__main__:step=17200 loss=0.0002271 acc=0.9999390
INFO:__main__:step=17300 loss=0.0000785 acc=1.0000000
INFO:__main__:step=17400 loss=0.0002242 acc=1.0000000
INFO:__main__:step=17500 loss=0.0001646 acc=1.0000000
INFO:__main__:step=17600 loss=0.0001174 acc=1.0000000
INFO:__main__:step=17700 loss=0.0001843 acc=0.9999390
INFO:__main__:step=17800 loss=0.0001872 acc=1.0000000
INFO:__main__:step=17900 loss=0.0001122 acc=1.0000000
INFO:__main__:step=18000 loss=0.0000516 acc=1.0000000
INFO:__main__:step=18100 loss=0.0001427 acc=1.0000000
INFO:__main__:step=18200 loss=0.0000453 acc=1.0000000
INFO:__main__:step=18300 loss=0.0001730 acc=1.0000000
INFO:__main__:step=18400 loss=0.0001801 acc=1.0000000
INFO:__main__:step=18500 loss=0.0001219 acc=1.0000000
INFO:__main__:step=18600 loss=0.0001443 acc=1.0000000
INFO:__main__:step=18700 loss=0.0003240 acc=1.0000000
INFO:__main__:step=18800 loss=0.0001341 acc=1.0000000
INFO:__main__:step=18900 loss=0.0000698 acc=1.0000000
INFO:__main__:step=19000 loss=0.0002490 acc=1.0000000
INFO:__main__:step=19100 loss=0.0002027 acc=1.0000000
INFO:__main__:step=19200 loss=0.0001338 acc=0.9999390
INFO:__main__:step=19300 loss=0.0001596 acc=1.0000000
INFO:__main__:step=19400 loss=0.0001416 acc=0.9999390
INFO:__main__:step=19500 loss=0.0001592 acc=1.0000000
INFO:__main__:step=19600 loss=0.0002262 acc=1.0000000
INFO:__main__:step=19700 loss=0.0000626 acc=1.0000000
INFO:__main__:step=19800 loss=0.0001256 acc=1.0000000
INFO:__main__:step=19900 loss=0.0002005 acc=0.9999390
INFO:__main__:step=20000 loss=0.0000979 acc=1.0000000
INFO:__main__:step=20100 loss=0.0002766 acc=1.0000000
INFO:__main__:step=20200 loss=0.0003364 acc=0.9999390
INFO:__main__:step=20300 loss=0.0001628 acc=1.0000000
INFO:__main__:step=20400 loss=0.0002390 acc=1.0000000
INFO:__main__:step=20500 loss=0.0001474 acc=1.0000000
INFO:__main__:step=20600 loss=0.0001439 acc=1.0000000
INFO:__main__:step=20700 loss=0.0000553 acc=1.0000000
INFO:__main__:step=20800 loss=0.0001755 acc=0.9999390
INFO:__main__:step=20900 loss=0.0000641 acc=1.0000000
INFO:__main__:step=21000 loss=0.0000668 acc=1.0000000
INFO:__main__:step=21100 loss=0.0002183 acc=1.0000000
INFO:__main__:step=21200 loss=0.0001400 acc=1.0000000
INFO:__main__:step=21300 loss=0.0001134 acc=1.0000000
INFO:__main__:step=21400 loss=0.0002051 acc=0.9999390
INFO:__main__:step=21500 loss=0.0001587 acc=1.0000000
INFO:__main__:step=21600 loss=0.0002183 acc=1.0000000
INFO:__main__:step=21700 loss=0.0000929 acc=1.0000000
INFO:__main__:step=21800 loss=0.0001406 acc=1.0000000
INFO:__main__:step=21900 loss=0.0001177 acc=1.0000000
INFO:__main__:step=22000 loss=0.0000872 acc=1.0000000
INFO:__main__:step=22100 loss=0.0000580 acc=1.0000000
INFO:__main__:step=22200 loss=0.0000653 acc=1.0000000
INFO:__main__:step=22300 loss=0.0001202 acc=1.0000000
INFO:__main__:step=22400 loss=0.0002056 acc=1.0000000
INFO:__main__:step=22500 loss=0.0001006 acc=1.0000000
INFO:__main__:step=22600 loss=0.0001436 acc=1.0000000
INFO:__main__:step=22700 loss=0.0001289 acc=1.0000000
INFO:__main__:step=22800 loss=0.0000839 acc=1.0000000
INFO:__main__:step=22900 loss=0.0001841 acc=1.0000000
INFO:__main__:step=23000 loss=0.0000884 acc=1.0000000
INFO:__main__:step=23100 loss=0.0000641 acc=1.0000000
INFO:__main__:step=23200 loss=0.0001370 acc=1.0000000
INFO:__main__:step=23300 loss=0.0002339 acc=0.9998779
INFO:__main__:step=23400 loss=0.0001042 acc=1.0000000
INFO:__main__:step=23500 loss=0.0001897 acc=1.0000000
INFO:__main__:step=23600 loss=0.0001677 acc=1.0000000
INFO:__main__:step=23700 loss=0.0001252 acc=1.0000000
INFO:__main__:step=23800 loss=0.0001060 acc=1.0000000
INFO:__main__:step=23900 loss=0.0001251 acc=1.0000000
INFO:__main__:step=24000 loss=0.0001638 acc=1.0000000
INFO:__main__:step=24100 loss=0.0001202 acc=1.0000000
INFO:__main__:step=24200 loss=0.0001683 acc=1.0000000
INFO:__main__:step=24300 loss=0.0000737 acc=1.0000000
INFO:__main__:step=24400 loss=0.0001864 acc=1.0000000
INFO:__main__:step=24500 loss=0.0001836 acc=1.0000000
INFO:__main__:step=24600 loss=0.0001927 acc=1.0000000
INFO:__main__:step=24700 loss=0.0000401 acc=1.0000000
INFO:__main__:step=24800 loss=0.0002011 acc=1.0000000
INFO:__main__:step=24900 loss=0.0002700 acc=0.9998779
"""

pattern = re.compile(r"step=\s*(\d+)\s+loss=([0-9.]+)\s+acc=([0-9.]+)")
rows = [(int(s), float(l), float(a)) for s, l, a in pattern.findall(text)]
df = pd.DataFrame(rows, columns=["step", "loss", "acc"]).sort_values("step").reset_index(drop=True)

# Avoid log(0) issues for loss plot by clamping at a tiny positive value
eps = 1e-10
df["loss_clamped"] = df["loss"].clip(lower=eps)

# Plot 1: Loss
plt.figure(figsize=(9, 4.8))
plt.plot(df["step"], df["loss_clamped"])
plt.yscale("log")
plt.xlabel("Step")
plt.ylabel("Loss (log scale)")
plt.title("Training Loss vs Step")
plt.tight_layout()
plt.savefig('training_loss_v_step.png')
plt.show()

# Plot 2: Accuracy
df["err"] = (1.0 - df["acc"]).clip(lower=eps)
plt.figure(figsize=(9, 4.8))
plt.plot(df["step"], df["err"])
plt.yscale("log")
plt.xlabel("Step")
plt.ylabel("Error rate (1 - accuracy) (log scale)")
plt.title("Training Error Rate vs Step")
plt.tight_layout()
plt.savefig('training_error_v_step.png')
plt.show()
```
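For reference, the parsing core of the deleted script reduces to a few lines; this sketch reuses the regex and one log line verbatim from the file above:

```python
import re

pattern = re.compile(r"step=\s*(\d+)\s+loss=([0-9.]+)\s+acc=([0-9.]+)")
line = "INFO:__main__:step= 100 loss=0.0061767 acc=0.9990845"
step, loss, acc = pattern.findall(line)[0]  # findall yields string tuples
print(int(step), float(loss), float(acc))   # -> 100 0.0061767 0.9990845
```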
pairwise_comp_nn.py (new file, 37 lines)

```diff
@@ -0,0 +1,37 @@
import torch
from torch import nn

# 2) Number "embedding" network: R -> R^d
class NumberEmbedder(nn.Module):
    def __init__(self, d=2, hidden=4):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(1, hidden),
            nn.ReLU(),
            nn.Linear(hidden, d),
        )

    def forward(self, x):
        return self.net(x)

# MLP comparator head: scores embed(a - b) against embed(b - a) -> logit for "a > b"
class PairwiseComparator(nn.Module):
    def __init__(self, d=2, hidden=4, k=0.5):
        super().__init__()
        self.log_k = nn.Parameter(torch.tensor([k]))
        self.embed = NumberEmbedder(d, hidden)
        self.head = nn.Sequential(
            nn.Linear(d, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, a, b):
        # trying to force antisym here: h(a,b) = -h(b,a)
        phi = self.head(self.embed(a - b))
        phi_neg = self.head(self.embed(b - a))
        logit = phi - phi_neg

        return (self.log_k ** 2) * logit
```
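Because forward scores both orderings and returns their difference, the logit is antisymmetric by construction: model(a, b) == -model(b, a), so the sigmoid probabilities for the two orderings always sum to 1. A minimal check, assuming the module above is importable as pairwise_comp_nn:

```python
import torch
import pairwise_comp_nn as comp_nn

model = comp_nn.PairwiseComparator(d=2, hidden=4)
a = torch.tensor([[3.0]])
b = torch.tensor([[5.0]])
with torch.no_grad():
    p_ab = torch.sigmoid(model(a, b))  # P(a > b)
    p_ba = torch.sigmoid(model(b, a))  # P(b > a)
print((p_ab + p_ba).item())  # ~1.0 for any weights, by antisymmetry
```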
pairwise_compare.py

```diff
@@ -2,23 +2,126 @@
# pairwise_compare.py
import logging
import random
import lzma

import torch
from torch import nn
from torch.nn import functional as F

import re
import pandas as pd
import matplotlib.pyplot as plt

import pairwise_comp_nn as comp_nn

# early pytorch device setup
DEVICE = torch.accelerator.current_accelerator() if torch.accelerator.is_available() else "cpu"

# Valves
DIMENSIONS = 1
TRAIN_STEPS = 25000
TRAIN_BATCHSZ = 16384
TRAIN_PROGRESS = 100
BATCH_LOWER = -512.0
BATCH_UPPER = 512.0
DO_VERBOSE_EARLY_TRAIN = False
DIMENSIONS = 2
HIDDEN_NEURONS = 4
ADAMW_LR = 5e-3
ADAMW_DECAY = 5e-4
TRAIN_STEPS = 2000
TRAIN_BATCHSZ = 8192
TRAIN_PROGRESS = 10
BATCH_LOWER = -100.0
BATCH_UPPER = 100.0

def get_torch_info():
# Files
MODEL_PATH = "./files/pwcomp.model"
LOGGING_PATH = "./files/output.log"
EMBED_CHART_PATH = "./files/embedding_chart.png"
EMBEDDINGS_DATA_PATH = "./files/embedding_data.csv"
TRAINING_LOG_PATH = "./files/training.log.xz"
LOSS_CHART_PATH = "./files/training_loss_v_step.png"
ACC_CHART_PATH = "./files/training_error_v_step.png"

# TODO: Move plotting into its own file
def parse_training_log(file_path: str) -> pd.DataFrame:
    text: str = ""
    with lzma.open(file_path, mode='rt') as f:
        text = f.read()

    pattern = re.compile(r"step=\s*(\d+)\s+loss=([0-9.]+)\s+acc=([0-9.]+)")
    rows = [(int(s), float(l), float(a)) for s, l, a in pattern.findall(text)]
    df = pd.DataFrame(rows, columns=["step", "loss", "acc"]).sort_values("step").reset_index(drop=True)

    # Avoid log(0) issues for loss plot by clamping at a tiny positive value
    eps = 1e-10
    df["loss_clamped"] = df["loss"].clip(lower=eps)

    return df

# TODO: Move plotting into its own file
def plt_loss_tstep(df: pd.DataFrame) -> None:
    # Plot 1: Loss
    plt.figure(figsize=(10, 6))
    plt.plot(df["step"], df["loss_clamped"])
    plt.yscale("log")
    plt.xlabel("Step")
    plt.ylabel("Loss (log scale)")
    plt.title("Training Loss vs Step")
    plt.tight_layout()
    plt.savefig(LOSS_CHART_PATH)
    plt.close()

    return None

# TODO: Move plotting into its own file
def plt_acc_tstep(df: pd.DataFrame, eps=1e-10) -> None:
    # Plot 2: Accuracy
    df["err"] = (1.0 - df["acc"]).clip(lower=eps)
    plt.figure(figsize=(10, 6))
    plt.plot(df["step"], df["err"])
    plt.yscale("log")
    plt.xlabel("Step")
    plt.ylabel("Error rate (1 - accuracy) (log scale)")
    plt.title("Training Error Rate vs Step")
    plt.tight_layout()
    plt.savefig(ACC_CHART_PATH)
    plt.close()

    return None

# TODO: Move plotting into its own file
def plt_embeddings(model: comp_nn.PairwiseComparator) -> None:
    import csv

    log.info("Starting embeddings sweep...")
    # samples for embedding mapping
    with torch.no_grad():
        xs = torch.arange(
            BATCH_LOWER,
            BATCH_UPPER + 1.0,
            0.1,
        ).unsqueeze(1).to(DEVICE)  # shape: (N, 1)

        embeddings = model.embed(xs)  # shape: (N, d)

    # move data back to CPU for plotting
    embeddings = embeddings.cpu()
    xs = xs.cpu()

    # Plot 3: x vs h(x)
    plt.figure(figsize=(10, 6))
    for i in range(embeddings.shape[1]):
        plt.plot(xs.squeeze(), embeddings[:, i], label=f"dim {i}")
    plt.title("x vs h(x)")
    plt.xlabel("x [input]")
    plt.ylabel("h(x) [embedding]")
    plt.legend()
    plt.savefig(EMBED_CHART_PATH)
    plt.close()

    # save all our embeddings data to csv
    csv_data = list(zip(xs.squeeze().tolist(), embeddings.tolist()))
    with open(file=EMBEDDINGS_DATA_PATH, mode="w", newline='') as f:
        csv_file = csv.writer(f)
        csv_file.writerows(csv_data)

    return None

def get_torch_info() -> None:
    log.info("PyTorch Version: %s", torch.__version__)
    log.info("HIP Version: %s", torch.version.hip)
    log.info("CUDA support: %s", torch.cuda.is_available())
```
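The DEVICE line above relies on the torch.accelerator API, which only exists in newer PyTorch releases; on older builds, a rough equivalent (my assumption, not something this diff provides) would be:

```python
import torch

# Fallback device selection for PyTorch builds without torch.accelerator
# (ROCm builds also report availability through torch.cuda)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
```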
```diff
@@ -28,93 +131,69 @@ def get_torch_info():

    log.info("Using %s compute mode", DEVICE)

def set_seed(seed: int):
def set_seed(seed: int) -> None:
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

# 1) Data: pairs (a, b) with label y = 1 if a > b else 0
def sample_batch(batch_size: int, low=BATCH_LOWER, high=BATCH_UPPER):
# pairs (a, b) with label y = 1 if a > b else 0 -> (a,b,y)
# uses epsi to select the window in which a == b for equality training
def sample_batch(batch_size: int, low=BATCH_LOWER, high=BATCH_UPPER, epsi=1e-4) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    a = (high - low) * torch.rand(batch_size, 1) + low
    b = (high - low) * torch.rand(batch_size, 1) + low

    # train for if a > b
    y = (a > b).float()
    epsi = 1e-4
    y = torch.where(a > b + epsi, 1.0,
                    torch.where(a < b - epsi, 0.0, 0.5))

    # removed but left for my notes; it seems training for equality hurts classifying results that are ~eq
    # when trained only on "if a > b => y", the model produces more accurate results when classifying if things are equal (~.5 prob).
    # eq = (a == b).float()
    # y = gt + 0.5 * eq
    return a, b, y

# 2) Number "embedding" network: R -> R^d
class NumberEmbedder(nn.Module):
    def __init__(self, d=8):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(1, 16),
            nn.ReLU(),
            nn.Linear(16, d),
        )

    def forward(self, x):
        return self.net(x)

# 3) Comparator head: takes (ea, eb) -> logit for "a > b"
class PairwiseComparator(nn.Module):
    def __init__(self, d=8):
        super().__init__()
        self.embed = NumberEmbedder(d)
        self.head = nn.Sequential(
            nn.Linear(2 * d + 1, 16),
            nn.ReLU(),
            nn.Linear(16, 1),
        )

    def forward(self, a, b):
        ea = self.embed(a)
        eb = self.embed(b)
        delta_ab = a - b
        x = torch.cat([ea, eb, delta_ab], dim=-1)

        return self.head(x)  # logits

def training_entry():
    get_torch_info()

    # all prng seeds to 0 for deterministic outputs during testing
    # the seed should be initialized normally otherwise
    set_seed(0)

    model = PairwiseComparator(d=DIMENSIONS).to(DEVICE)
    # opt = torch.optim.AdamW(model.parameters(), lr=2e-3)
    opt = torch.optim.Adadelta(model.parameters(), lr=1.0)
    model = comp_nn.PairwiseComparator(d=DIMENSIONS, hidden=HIDDEN_NEURONS).to(DEVICE)
    opt = torch.optim.AdamW(model.parameters(), lr=ADAMW_LR, weight_decay=ADAMW_DECAY)

    # 4) Train
    for step in range(TRAIN_STEPS):
        a, b, y = sample_batch(TRAIN_BATCHSZ)
        a, b, y = a.to(DEVICE), b.to(DEVICE), y.to(DEVICE)
    log.info(f"Using {TRAINING_LOG_PATH} as the logging destination for training...")
    with lzma.open(TRAINING_LOG_PATH, mode='wt') as tlog:
        # training loop
        training_start_time = datetime.datetime.now()
        last_ack = datetime.datetime.now()

        logits = model(a, b)
        loss_fn = F.binary_cross_entropy_with_logits(logits, y)
        for step in range(TRAIN_STEPS):
            a, b, y = sample_batch(TRAIN_BATCHSZ)
            a, b, y = a.to(DEVICE), b.to(DEVICE), y.to(DEVICE)

        opt.zero_grad()
        loss_fn.backward()
        opt.step()
            logits = model(a, b)
            loss_fn = F.binary_cross_entropy_with_logits(logits, y)

            opt.zero_grad()
            loss_fn.backward()
            opt.step()

            if step <= TRAIN_PROGRESS and DO_VERBOSE_EARLY_TRAIN is True:
                with torch.no_grad():
                    pred = (torch.sigmoid(logits) > 0.5).float()
                    acc = (pred == y).float().mean().item()
                    log.info(f"step={step:5d} loss={loss_fn.item():.7f} acc={acc:.7f}")
            elif step % TRAIN_PROGRESS == 0:
                with torch.no_grad():
                    pred = (torch.sigmoid(logits) > 0.5).float()
                    acc = (pred == y).float().mean().item()
                    log.info(f"step={step:5d} loss={loss_fn.item():.7f} acc={acc:.7f}")
                    tlog.write(f"step={step:5d} loss={loss_fn.item():.7f} acc={acc:.7f}\n")

    # 5) Quick test: evaluate accuracy on fresh pairs
            # also print to normal text log occasionally to show some activity.
            # every 10 steps check if it's been longer than 5 seconds since we've updated the user
            if step % 10 == 0:
                if (datetime.datetime.now() - last_ack).total_seconds() > 5:
                    log.info(f"still training... step={step} of {TRAIN_STEPS}")
                    last_ack = datetime.datetime.now()

    training_end_time = datetime.datetime.now()
    log.info(f"Training steps complete. Start time: {training_start_time} End time: {training_end_time}")

    # evaluate final model accuracy on fresh pairs
    with torch.no_grad():
        a, b, y = sample_batch(TRAIN_BATCHSZ)
        a, b, y = sample_batch(TRAIN_BATCHSZ*4)
        a, b, y = a.to(DEVICE), b.to(DEVICE), y.to(DEVICE)
        logits = model(a, b)
        pred = (torch.sigmoid(logits) > 0.5).float()
```
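The reworked sample_batch replaces the hard a > b cut with a ±epsi window that labels near-ties as 0.5. A small worked example of that rule, with illustrative values:

```python
import torch

epsi = 1e-4
a = torch.tensor([[2.0], [1.0], [1.00005]])
b = torch.tensor([[1.0], [2.0], [1.0]])
# 1.0 if a > b + epsi, 0.0 if a < b - epsi, else 0.5 (near-equality window)
y = torch.where(a > b + epsi, 1.0, torch.where(a < b - epsi, 0.0, 0.5))
print(y.squeeze().tolist())  # [1.0, 0.0, 0.5]
```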
```diff
@@ -122,13 +201,15 @@ def training_entry():
        acc = (pred == y).float().mean().item()
        log.info(f"Final test acc: {acc} errors: {errors}")

    # embed model depth into the model serialization
    torch.save({"state_dict": model.state_dict(), "d": DIMENSIONS}, "model.pth")
    log.info("Saved PyTorch Model State to model.pth")
    # embed model dimensions into the model serialization
    torch.save({"state_dict": model.state_dict(), "d": DIMENSIONS, "h": HIDDEN_NEURONS}, MODEL_PATH)
    log.info(f"Saved PyTorch Model State to {MODEL_PATH}")

def infer_entry():
    model_ckpt = torch.load("model.pth", map_location=DEVICE)
    model = PairwiseComparator(d=model_ckpt["d"]).to(DEVICE)
    get_torch_info()

    model_ckpt = torch.load(MODEL_PATH, map_location=DEVICE)
    model = comp_nn.PairwiseComparator(d=model_ckpt["d"], hidden=model_ckpt["h"]).to(DEVICE)
    model.load_state_dict(model_ckpt["state_dict"])
    model.eval()
```
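Storing "d" and "h" next to the state_dict lets infer_entry rebuild the module with matching layer shapes before calling load_state_dict; a checkpoint trained at a different width would otherwise fail with a size-mismatch error.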
```diff
@@ -139,44 +220,104 @@ def infer_entry():
    b = torch.tensor([[p[1]] for p in pairs], dtype=torch.float32, device=DEVICE)

    # sanity check before inference
    log.info(f"a.device: {a.device} model.device: {next(model.parameters()).device}")
    log.debug(f"a.device: {a.device} model.device: {next(model.parameters()).device}")

    with torch.no_grad():
        probs = torch.sigmoid(model(a, b))

    log.info(f"Output probabilities for {pairs.__len__()} pairs")
    for (x, y), p in zip(pairs, probs):
        log.info(f"P({x} > {y}) = {p.item():.3f}")


def graphs_entry():
    get_torch_info()

    model_ckpt = torch.load(MODEL_PATH, map_location=DEVICE)
    model = comp_nn.PairwiseComparator(d=model_ckpt["d"], hidden=model_ckpt["h"]).to(DEVICE)
    model.load_state_dict(model_ckpt["state_dict"])
    model.eval()

    plt_embeddings(model)

    data = parse_training_log(TRAINING_LOG_PATH)
    plt_loss_tstep(data)
    plt_acc_tstep(data)

help_text = r"""
pairwise_compare.py — tiny pairwise "a > b?" neural comparator

USAGE
  python3 pairwise_compare.py train
    Train a PairwiseComparator on synthetic (a,b) pairs sampled uniformly from
    [BATCH_LOWER, BATCH_UPPER]. Labels are:
      1.0 if a > b + epsi
      0.0 if a < b - epsi
      0.5 otherwise (near-equality window)
    Writes training metrics to:
      ./files/training.log.xz
    Saves the trained model checkpoint to:
      ./files/pwcomp.model

  python3 pairwise_compare.py infer
    Load ./files/pwcomp.model and run inference on a built-in list of test pairs.
    Prints probabilities as:
      P(a > b) = sigmoid(model(a,b))

  python3 pairwise_compare.py graphs
    Load ./files/pwcomp.model and generate plots + exports:
      ./files/embedding_chart.png (embed(x) vs x for each embedding dimension)
      ./files/embedding_data.csv (x and embedding vectors)
      ./files/training_loss_v_step.png
      ./files/training_error_v_step.png (1 - acc, log scale)
    Requires that ./files/training.log.xz exists (i.e., you ran "train" first).

FILES
  ./files/output.log        General runtime log (info/errors)
  ./files/pwcomp.model      Torch checkpoint: {"state_dict": ..., "d": DIMENSIONS}
  ./files/training.log.xz   step/loss/acc trace used for plots

NOTES
  - DEVICE is chosen via torch.accelerator if available, else CPU.
  - Hyperparameters are controlled by the "Valves" constants near the top.
"""

if __name__ == '__main__':
    import sys
    import os
    import datetime

    # TODO: tidy up the paths to files and checking if the directory exists
    if not os.path.exists("./files/"):
        os.mkdir("./files")

    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s - %(levelname)s - %(message)s',
                        handlers=[
                            logging.FileHandler(LOGGING_PATH),
                            logging.StreamHandler(stream=sys.stdout)
                        ])

    log = logging.getLogger(__name__)
    logging.basicConfig(filename='pairwise_compare.log', level=logging.INFO)
    log.info(f"Log opened {datetime.datetime.now()}")
    log.info(f"Log file {LOGGING_PATH} opened {datetime.datetime.now()}")

    get_torch_info()

    name = os.path.basename(sys.argv[0])
    if name == 'train.py':
        training_entry()
    elif name == 'infer.py':
        infer_entry()
    else:
        # alt call pattern
        # python3 pairwise_compare.py train
        # python3 pairwise_compare.py infer
        if len(sys.argv) > 1:
            mode = sys.argv[1].strip().lower()
            if mode == "train":
        # python3 pairwise_compare.py train
        # python3 pairwise_compare.py infer
        # python3 pairwise_compare.py graphs
        if len(sys.argv) > 1:
            match sys.argv[1].strip().lower():
                case "train":
                    training_entry()
            elif mode == "infer":
                case "infer":
                    infer_entry()
            else:
                case "graphs":
                    graphs_entry()
                case "help":
                    log.info(help_text)
                case mode:
                    log.error(f"Unknown operation: {mode}")
                log.error("Invalid call syntax, call script as \"train.py\" or \"infer.py\" or as pairwise_compare.py <mode> where mode is \"train\" or \"infer\"")
        else:
            log.error("Not enough arguments passed to script; call as train.py or infer.py or as pairwise_compare.py <mode> where mode is \"train\" or \"infer\"")
            log.error("valid options are one of [\"train\", \"infer\", \"graphs\", \"help\"]")
            log.info(help_text)

    log.info(f"Log closed {datetime.datetime.now()}")
```