Compare commits: 5c89149388...trunk

21 commits

| SHA1 |
|---|
| 9af36e7145 |
| 23fddbe5b9 |
| acbccebb2c |
| edf8d46123 |
| 4f500e8b4c |
| 921e24b451 |
| d46712ff53 |
| cd72cd7052 |
| 755161c152 |
| 1d70935b64 |
| 0e2098ceec |
| cfef24921d |
| 9ea8ef3458 |
| 6e31865a84 |
| 997303028e |
| c3fbc44a34 |
| 5e5ad1bc20 |
| 0d6a92823a |
| cfcec07b9c |
| bf15e7b6c8 |
| 0385b0acc8 |
.gitignore (vendored, 5 lines changed)
@@ -1,3 +1,6 @@
+__pycache__/
+files/
 .venv/
-*.pth
+*.png
+*.model
 *.log
@@ -25,11 +25,11 @@ source .venv/bin/activate
 
 For instructions installing pytorch refer to the [PyTorch Home page]
 
 ```bash
-pip3 install numpy
-# use the nvidia CUDA or CPU only packages if required.
-# I use the ROCm packages, so the repo uses the ROCm packages.
+pip3 install numpy pandas matplotlib
+# I use the ROCm packages
 pip3 install torch torchvision --index-url https://download.pytorch.org/whl/rocm6.4
-pip3 install pandas matplotlib
+# if you need the CPU only package
+# pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu
 ```
 
 ## Running the code
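A quick way to confirm which wheel actually got installed (a sketch, not part of the commit; it mirrors the `get_torch_info()` logging this repo uses elsewhere):

```python
import torch

print(torch.__version__)           # the ROCm wheel reports a "+rocm" suffix
print(torch.version.hip)           # HIP version string on ROCm builds, None otherwise
print(torch.cuda.is_available())   # also True on ROCm, where HIP backs the CUDA API
```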
output_graphs.py (deleted, 573 lines)

@@ -1,573 +0,0 @@
import re
import pandas as pd
import matplotlib.pyplot as plt

text = r"""INFO:__main__:step= 0 loss=1.3149878 acc=0.9093018
INFO:__main__:step= 1 loss=1.0067503 acc=0.9183350
INFO:__main__:step= 2 loss=0.7588759 acc=0.9278564
INFO:__main__:step= 3 loss=0.5432727 acc=0.9384766
INFO:__main__:step= 4 loss=0.3854972 acc=0.9427490
INFO:__main__:step= 5 loss=0.2379047 acc=0.9512939
[... hard-coded training log continues: every step through 500, then every 500th step from 1000 to 19500 ...]
INFO:__main__:step=19000 loss=0.0000318 acc=1.0000000
INFO:__main__:step=19500 loss=0.0000243 acc=1.0000000
"""

pattern = re.compile(r"step=\s*(\d+)\s+loss=([0-9.]+)\s+acc=([0-9.]+)")
rows = [(int(s), float(l), float(a)) for s, l, a in pattern.findall(text)]
df = pd.DataFrame(rows, columns=["step", "loss", "acc"]).sort_values("step").reset_index(drop=True)

# Avoid log(0) issues for loss plot by clamping at a tiny positive value
eps = 1e-10
df["loss_clamped"] = df["loss"].clip(lower=eps)

# Plot 1: Loss
plt.figure(figsize=(9, 4.8))
plt.plot(df["step"], df["loss_clamped"])
plt.yscale("log")
plt.xlabel("Step")
plt.ylabel("Loss (log scale)")
plt.title("Training Loss vs Step")
plt.tight_layout()
plt.show()

# Plot 2: Accuracy
df["err"] = (1.0 - df["acc"]).clip(lower=eps)
plt.figure(figsize=(9, 4.8))
plt.plot(df["step"], df["err"])
plt.yscale("log")
plt.xlabel("Step")
plt.ylabel("Error rate (1 - accuracy) (log scale)")
plt.title("Training Error Rate vs Step")
plt.tight_layout()
plt.show()
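This helper script goes away because its pieces now live in pairwise_compare.py: the hard-coded log becomes ./files/training.log.xz, and the parsing/plotting becomes `parse_training_log`, `plt_loss_tstep`, and `plt_acc_tstep` (see the diff below). For reference, what the surviving regex extracts from a single log line (a sketch, not part of the commit):

```python
import re

pattern = re.compile(r"step=\s*(\d+)\s+loss=([0-9.]+)\s+acc=([0-9.]+)")
line = "INFO:__main__:step=   42 loss=0.0075352 acc=0.9989014"

# findall returns one (step, loss, acc) string tuple per match
step, loss, acc = pattern.findall(line)[0]
print(int(step), float(loss), float(acc))  # -> 42 0.0075352 0.9989014
```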
pairwise_comp_nn.py (new file, 37 lines)

@@ -0,0 +1,37 @@
+import torch
+from torch import nn
+
+
+# Number "embedding" network: R -> R^d
+class NumberEmbedder(nn.Module):
+    def __init__(self, d=2, hidden=4):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Linear(1, hidden),
+            nn.ReLU(),
+            nn.Linear(hidden, d),
+        )
+
+    def forward(self, x):
+        return self.net(x)
+
+
+# MLP comparator head: maps embed(a - b) -> logit for "a > b"
+class PairwiseComparator(nn.Module):
+    def __init__(self, d=2, hidden=4, k=0.5):
+        super().__init__()
+        self.log_k = nn.Parameter(torch.tensor([k]))
+        self.embed = NumberEmbedder(d, hidden)
+        self.head = nn.Sequential(
+            nn.Linear(d, hidden),
+            nn.ReLU(),
+            nn.Linear(hidden, hidden),
+            nn.ReLU(),
+            nn.Linear(hidden, 1),
+        )
+
+    def forward(self, a, b):
+        # trying to force antisymmetry here: h(a,b) = -h(b,a)
+        phi = self.head(self.embed(a - b))
+        phi_neg = self.head(self.embed(b - a))
+        logit = phi - phi_neg
+
+        return (self.log_k ** 2) * logit
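Because the logit is computed as phi(a-b) - phi(b-a), swapping the arguments flips its sign exactly, so P(a > b) = 1 - P(b > a) by construction once the sigmoid is applied. A minimal sketch verifying that property (assumes the new file above is importable as `pairwise_comp_nn`):

```python
import torch
import pairwise_comp_nn as comp_nn

model = comp_nn.PairwiseComparator(d=2, hidden=4)
a, b = torch.randn(8, 1), torch.randn(8, 1)

with torch.no_grad():
    h_ab = model(a, b)
    h_ba = model(b, a)

# phi(a-b) - phi(b-a) negates when the arguments swap
print(torch.allclose(h_ab, -h_ba, atol=1e-6))  # True
```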
pairwise_compare.py

@@ -2,23 +2,126 @@
 # pairwise_compare.py
 import logging
 import random
+import lzma
+
 import torch
-from torch import nn
 from torch.nn import functional as F
+
+import re
+import pandas as pd
+import matplotlib.pyplot as plt
+
+import pairwise_comp_nn as comp_nn
 
 # early pytorch device setup
 DEVICE = torch.accelerator.current_accelerator() if torch.accelerator.is_available() else "cpu"
 
 # Valves
-DIMENSIONS = 1
-TRAIN_STEPS = 20000
-TRAIN_BATCHSZ = 16384
-TRAIN_PROGRESS = 500
-BATCH_LOWER = -512.0
-BATCH_UPPER = 512.0
-DO_VERBOSE_EARLY_TRAIN = True
+DIMENSIONS = 2
+HIDDEN_NEURONS = 4
+ADAMW_LR = 5e-3
+ADAMW_DECAY = 5e-4
+TRAIN_STEPS = 2000
+TRAIN_BATCHSZ = 8192
+TRAIN_PROGRESS = 10
+BATCH_LOWER = -100.0
+BATCH_UPPER = 100.0
 
-def get_torch_info():
+# Files
+MODEL_PATH = "./files/pwcomp.model"
+LOGGING_PATH = "./files/output.log"
+EMBED_CHART_PATH = "./files/embedding_chart.png"
+EMBEDDINGS_DATA_PATH = "./files/embedding_data.csv"
+TRAINING_LOG_PATH = "./files/training.log.xz"
+LOSS_CHART_PATH = "./files/training_loss_v_step.png"
+ACC_CHART_PATH = "./files/training_error_v_step.png"
+
+
+# TODO: Move plotting into its own file
+def parse_training_log(file_path: str) -> pd.DataFrame:
+    text: str = ""
+    with lzma.open(file_path, mode='rt') as f:
+        text = f.read()
+
+    pattern = re.compile(r"step=\s*(\d+)\s+loss=([0-9.]+)\s+acc=([0-9.]+)")
+    rows = [(int(s), float(l), float(a)) for s, l, a in pattern.findall(text)]
+    df = pd.DataFrame(rows, columns=["step", "loss", "acc"]).sort_values("step").reset_index(drop=True)
+
+    # Avoid log(0) issues for loss plot by clamping at a tiny positive value
+    eps = 1e-10
+    df["loss_clamped"] = df["loss"].clip(lower=eps)
+
+    return df
+
+
+# TODO: Move plotting into its own file
+def plt_loss_tstep(df: pd.DataFrame) -> None:
+    # Plot 1: Loss
+    plt.figure(figsize=(10, 6))
+    plt.plot(df["step"], df["loss_clamped"])
+    plt.yscale("log")
+    plt.xlabel("Step")
+    plt.ylabel("Loss (log scale)")
+    plt.title("Training Loss vs Step")
+    plt.tight_layout()
+    plt.savefig(LOSS_CHART_PATH)
+    plt.close()
+
+    return None
+
+
+# TODO: Move plotting into its own file
+def plt_acc_tstep(df: pd.DataFrame, eps=1e-10) -> None:
+    # Plot 2: Accuracy
+    df["err"] = (1.0 - df["acc"]).clip(lower=eps)
+    plt.figure(figsize=(10, 6))
+    plt.plot(df["step"], df["err"])
+    plt.yscale("log")
+    plt.xlabel("Step")
+    plt.ylabel("Error rate (1 - accuracy) (log scale)")
+    plt.title("Training Error Rate vs Step")
+    plt.tight_layout()
+    plt.savefig(ACC_CHART_PATH)
+    plt.close()
+
+    return None
+
+
+# TODO: Move plotting into its own file
+def plt_embeddings(model: comp_nn.PairwiseComparator) -> None:
+    import csv
+
+    log.info("Starting embeddings sweep...")
+    # samples for embedding mapping
+    with torch.no_grad():
+        xs = torch.arange(
+            BATCH_LOWER,
+            BATCH_UPPER + 1.0,
+            0.1,
+        ).unsqueeze(1).to(DEVICE)  # shape: (N, 1)
+
+        embeddings = model.embed(xs)  # shape: (N, d)
+
+    # move data back to CPU for plotting
+    embeddings = embeddings.cpu()
+    xs = xs.cpu()
+
+    # Plot 3: x vs h(x)
+    plt.figure(figsize=(10, 6))
+    for i in range(embeddings.shape[1]):
+        plt.plot(xs.squeeze(), embeddings[:, i], label=f"dim {i}")
+    plt.title("x vs h(x)")
+    plt.xlabel("x [input]")
+    plt.ylabel("h(x) [embedding]")
+    plt.legend()
+    plt.savefig(EMBED_CHART_PATH)
+    plt.close()
+
+    # save all our embeddings data to csv
+    csv_data = list(zip(xs.squeeze().tolist(), embeddings.tolist()))
+    with open(file=EMBEDDINGS_DATA_PATH, mode="w", newline='') as f:
+        csv_file = csv.writer(f)
+        csv_file.writerows(csv_data)
+
+    return None
+
+
+def get_torch_info() -> None:
     log.info("PyTorch Version: %s", torch.__version__)
     log.info("HIP Version: %s", torch.version.hip)
     log.info("CUDA support: %s", torch.cuda.is_available())
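The training metrics now round-trip through an xz-compressed text file: `training_entry` writes each step with `lzma.open(..., 'wt')` and `parse_training_log` reads it back with mode `'rt'`. A minimal sketch of that round trip (paths as defined above; assumes the ./files/ directory exists, which the script creates at startup):

```python
import lzma

# write one metrics line the way training_entry does
with lzma.open("./files/training.log.xz", mode="wt") as tlog:
    tlog.write("step=    0 loss=1.3149878 acc=0.9093018\n")

# read it back the way parse_training_log does
with lzma.open("./files/training.log.xz", mode="rt") as f:
    print(f.read(), end="")
```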
@@ -28,67 +131,40 @@ def get_torch_info():
 
     log.info("Using %s compute mode", DEVICE)
 
 
-def set_seed(seed: int):
+def set_seed(seed: int) -> None:
     random.seed(seed)
     torch.manual_seed(seed)
     if torch.cuda.is_available():
         torch.cuda.manual_seed_all(seed)
 
 
-# 1) Data: pairs (a, b) with label y = 1 if a > b else 0
-def sample_batch(batch_size: int, low=BATCH_LOWER, high=BATCH_UPPER):
+# pairs (a, b) with label y = 1 if a > b else 0 -> (a,b,y)
+# uses epsi to select the window in which a == b for equality training
+def sample_batch(batch_size: int, low=BATCH_LOWER, high=BATCH_UPPER, epsi=1e-4) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     a = (high - low) * torch.rand(batch_size, 1) + low
     b = (high - low) * torch.rand(batch_size, 1) + low
 
-    # train for if a > b
-    y = (a > b).float()
-    # removed but left for my notes; it seems training for equality hurts classifing results that are ~eq
-    # when trained only on "if a > b => y", the model produces more accurate results when classifing if things are equal (~.5 prob).
-    # eq = (a == b).float()
-    # y = gt + 0.5 * eq
+    y = torch.where(a > b + epsi, 1.0,
+                    torch.where(a < b - epsi, 0.0, 0.5))
     return a, b, y
 
 
-# 2) Number "embedding" network: R -> R^d
-class NumberEmbedder(nn.Module):
-    def __init__(self, d=8):
-        super().__init__()
-        self.net = nn.Sequential(
-            nn.Linear(1, 16),
-            nn.ReLU(),
-            nn.Linear(16, d),
-        )
-
-    def forward(self, x):
-        return self.net(x)
-
-
-# 3) Comparator head: takes (ea, eb) -> logit for "a > b"
-class PairwiseComparator(nn.Module):
-    def __init__(self, d=8):
-        super().__init__()
-        self.embed = NumberEmbedder(d)
-        self.head = nn.Sequential(
-            nn.Linear(2 * d + 1, 16),
-            nn.ReLU(),
-            nn.Linear(16, 1),
-        )
-
-    def forward(self, a, b):
-        ea = self.embed(a)
-        eb = self.embed(b)
-        delta_ab = a - b
-        x = torch.cat([ea, eb, delta_ab], dim=-1)
-
-        return self.head(x)  # logits
-
-
 def training_entry():
+    get_torch_info()
+
     # all prng seeds to 0 for deterministic outputs durring testing
     # the seed should initialized normally otherwise
     set_seed(0)
 
-    model = PairwiseComparator(d=DIMENSIONS).to(DEVICE)
-    opt = torch.optim.AdamW(model.parameters(), lr=2e-3)
+    model = comp_nn.PairwiseComparator(d=DIMENSIONS, hidden=HIDDEN_NEURONS).to(DEVICE)
+    opt = torch.optim.AdamW(model.parameters(), lr=ADAMW_LR, weight_decay=ADAMW_DECAY)
+
+    log.info(f"Using {TRAINING_LOG_PATH} as the logging destination for training...")
+    with lzma.open(TRAINING_LOG_PATH, mode='wt') as tlog:
+        # training loop
+        training_start_time = datetime.datetime.now()
+        last_ack = datetime.datetime.now()
 
-    # 4) Train
         for step in range(TRAIN_STEPS):
             a, b, y = sample_batch(TRAIN_BATCHSZ)
             a, b, y = a.to(DEVICE), b.to(DEVICE), y.to(DEVICE)
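The new labels are three-way soft targets rather than a hard 0/1 split. How the tie window behaves is easiest to see on concrete numbers (a sketch; the tensors here are hand-picked):

```python
import torch

epsi = 1e-4
a = torch.tensor([[2.0], [1.0], [1.00005]])
b = torch.tensor([[1.0], [2.0], [1.0]])

# same nested rule as sample_batch above
y = torch.where(a > b + epsi, 1.0,
                torch.where(a < b - epsi, 0.0, 0.5))
print(y.squeeze().tolist())  # [1.0, 0.0, 0.5] -> greater, lesser, inside the "equal" window
```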
@@ -100,20 +176,24 @@ def training_entry():
             loss_fn.backward()
             opt.step()
 
-        if step <= TRAIN_PROGRESS and DO_VERBOSE_EARLY_TRAIN is True:
             with torch.no_grad():
                 pred = (torch.sigmoid(logits) > 0.5).float()
                 acc = (pred == y).float().mean().item()
-                log.info(f"step={step:5d} loss={loss_fn.item():.7f} acc={acc:.7f}")
-        elif step % TRAIN_PROGRESS == 0:
-            with torch.no_grad():
-                pred = (torch.sigmoid(logits) > 0.5).float()
-                acc = (pred == y).float().mean().item()
-                log.info(f"step={step:5d} loss={loss_fn.item():.7f} acc={acc:.7f}")
+                tlog.write(f"step={step:5d} loss={loss_fn.item():.7f} acc={acc:.7f}\n")
 
-    # 5) Quick test: evaluate accuracy on fresh pairs
+            # also print to normal text log occasionally to show some activity.
+            # every 10 steps check if its been longer than 5 seconds since we've updated the user
+            if step % 10 == 0:
+                if (datetime.datetime.now() - last_ack).total_seconds() > 5:
+                    log.info(f"still training... step={step} of {TRAIN_STEPS}")
+                    last_ack = datetime.datetime.now()
+
+    training_end_time = datetime.datetime.now()
+    log.info(f"Training steps complete. Start time: {training_start_time} End time: {training_end_time}")
+
+    # evaluate final model accuracy on fresh pairs
     with torch.no_grad():
-        a, b, y = sample_batch(TRAIN_BATCHSZ)
+        a, b, y = sample_batch(TRAIN_BATCHSZ*4)
         a, b, y = a.to(DEVICE), b.to(DEVICE), y.to(DEVICE)
         logits = model(a, b)
         pred = (torch.sigmoid(logits) > 0.5).float()
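The lines that compute `logits` and `loss_fn` fall between hunks, so this compare view never shows them. Given the soft 0.0/0.5/1.0 labels, the `sigmoid(logits)` usage, and the `from torch.nn import functional as F` import, a BCE-with-logits step is the natural fit; the sketch below is an assumption on that basis, not the repo's confirmed code:

```python
import torch
from torch.nn import functional as F

def train_step(model, opt, a, b, y):
    logits = model(a, b)  # (batch, 1) raw comparator scores
    # binary cross-entropy with logits accepts soft targets such as 0.5
    loss_fn = F.binary_cross_entropy_with_logits(logits, y)  # assumed loss
    opt.zero_grad()
    loss_fn.backward()
    opt.step()
    return logits, loss_fn
```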
@@ -121,13 +201,15 @@ def training_entry():
         acc = (pred == y).float().mean().item()
     log.info(f"Final test acc: {acc} errors: {errors}")
 
-    # embed model depth into the model serialization
-    torch.save({"state_dict": model.state_dict(), "d": DIMENSIONS}, "model.pth")
-    log.info("Saved PyTorch Model State to model.pth")
+    # embed model dimensions into the model serialization
+    torch.save({"state_dict": model.state_dict(), "d": DIMENSIONS, "h": HIDDEN_NEURONS}, MODEL_PATH)
+    log.info(f"Saved PyTorch Model State to {MODEL_PATH}")
 
 
 def infer_entry():
-    model_ckpt = torch.load("model.pth", map_location=DEVICE)
-    model = PairwiseComparator(d=model_ckpt["d"]).to(DEVICE)
+    get_torch_info()
+
+    model_ckpt = torch.load(MODEL_PATH, map_location=DEVICE)
+    model = comp_nn.PairwiseComparator(d=model_ckpt["d"], hidden=model_ckpt["h"]).to(DEVICE)
     model.load_state_dict(model_ckpt["state_dict"])
     model.eval()
 
@@ -138,44 +220,104 @@ def infer_entry():
     b = torch.tensor([[p[1]] for p in pairs], dtype=torch.float32, device=DEVICE)
 
     # sanity check before inference
-    log.info(f"a.device: {a.device} model.device: {next(model.parameters()).device}")
+    log.debug(f"a.device: {a.device} model.device: {next(model.parameters()).device}")
 
     with torch.no_grad():
         probs = torch.sigmoid(model(a, b))
 
+    log.info(f"Output probabilities for {pairs.__len__()} pairs")
     for (x, y), p in zip(pairs, probs):
         log.info(f"P({x} > {y}) = {p.item():.3f}")
 
 
+def graphs_entry():
+    get_torch_info()
+
+    model_ckpt = torch.load(MODEL_PATH, map_location=DEVICE)
+    model = comp_nn.PairwiseComparator(d=model_ckpt["d"], hidden=model_ckpt["h"]).to(DEVICE)
+    model.load_state_dict(model_ckpt["state_dict"])
+    model.eval()
+
+    plt_embeddings(model)
+
+    data = parse_training_log(TRAINING_LOG_PATH)
+    plt_loss_tstep(data)
+    plt_acc_tstep(data)
+
+
+help_text = r"""
+pairwise_compare.py — tiny pairwise "a > b?" neural comparator
+
+USAGE
+  python3 pairwise_compare.py train
+      Train a PairwiseComparator on synthetic (a,b) pairs sampled uniformly from
+      [BATCH_LOWER, BATCH_UPPER]. Labels are:
+          1.0 if a > b + epsi
+          0.0 if a < b - epsi
+          0.5 otherwise (near-equality window)
+      Writes training metrics to:
+          ./files/training.log.xz
+      Saves the trained model checkpoint to:
+          ./files/pwcomp.model
+
+  python3 pairwise_compare.py infer
+      Load ./files/pwcomp.model and run inference on a built-in list of test pairs.
+      Prints probabilities as:
+          P(a > b) = sigmoid(model(a,b))
+
+  python3 pairwise_compare.py graphs
+      Load ./files/pwcomp.model and generate plots + exports:
+          ./files/embedding_chart.png   (embed(x) vs x for each embedding dimension)
+          ./files/embedding_data.csv    (x and embedding vectors)
+          ./files/training_loss_v_step.png
+          ./files/training_error_v_step.png   (1 - acc, log scale)
+      Requires that ./files/training.log.xz exists (i.e., you ran "train" first).
+
+FILES
+  ./files/output.log        General runtime log (info/errors)
+  ./files/pwcomp.model      Torch checkpoint: {"state_dict": ..., "d": DIMENSIONS}
+  ./files/training.log.xz   step/loss/acc trace used for plots
+
+NOTES
+  - DEVICE is chosen via torch.accelerator if available, else CPU.
+  - Hyperparameters are controlled by the "Valves" constants near the top.
+"""
+
 if __name__ == '__main__':
     import sys
     import os
     import datetime
 
+    # TODO: tidy up the paths to files and checking if the directory exists
+    if not os.path.exists("./files/"):
+        os.mkdir("./files")
+
+    logging.basicConfig(level=logging.INFO,
+                        format='%(asctime)s - %(levelname)s - %(message)s',
+                        handlers=[
+                            logging.FileHandler(LOGGING_PATH),
+                            logging.StreamHandler(stream=sys.stdout)
+                        ])
+
     log = logging.getLogger(__name__)
-    logging.basicConfig(filename='pairwise_compare.log', level=logging.INFO)
-    log.info(f"Log opened {datetime.datetime.now()}")
-
-    get_torch_info()
-
-    name = os.path.basename(sys.argv[0])
-    if name == 'train.py':
-        training_entry()
-    elif name == 'infer.py':
-        infer_entry()
-    else:
-        # alt call patern
-        # python3 pairwise_compare.py train
-        # python3 pairwise_compare.py infer
-        if len(sys.argv) > 1:
-            mode = sys.argv[1].strip().lower()
-            if mode == "train":
-                training_entry()
-            elif mode == "infer":
-                infer_entry()
-            else:
-                log.error(f"Unknown operation: {mode}")
-                log.error("Invalid call syntax, call script as \"train.py\" or \"infer.py\" or as pairwise_compare.py <mode> where mode is \"train\" or \"infer\"")
-        else:
-            log.error("Not enough arguments passed to script; call as train.py or infer.py or as pairwise_compare.py <mode> where mode is \"train\" or \"infer\"")
+    log.info(f"Log file {LOGGING_PATH} opened {datetime.datetime.now()}")
+
+    # python3 pairwise_compare.py train
+    # python3 pairwise_compare.py infer
+    # python3 pairwise_compare.py graphs
+    if len(sys.argv) > 1:
+        match sys.argv[1].strip().lower():
+            case "train":
+                training_entry()
+            case "infer":
+                infer_entry()
+            case "graphs":
+                graphs_entry()
+            case "help":
+                log.info(help_text)
+            case mode:
+                log.error(f"Unknown operation: {mode}")
+                log.error("valid options are one of [\"train\", \"infer\", \"graphs\", \"help\"]")
+    else:
+        log.info(help_text)
 
     log.info(f"Log closed {datetime.datetime.now()}")
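One subtlety in the rewritten dispatch: `case mode:` is a capture pattern, so it matches anything not handled by an earlier case and binds the value to `mode`, which is why the error branch can still report the unknown string. A minimal illustration of that Python 3.10+ behavior (a sketch; the function name here is made up):

```python
def dispatch(arg: str) -> str:
    match arg.strip().lower():
        case "train":
            return "training"
        case "infer":
            return "inference"
        case mode:  # capture pattern: catches everything else and binds it
            return f"unknown operation: {mode}"

print(dispatch("TRAIN"))   # training
print(dispatch("bogus"))   # unknown operation: bogus
```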