commit 5c891493883ff6aa48bebfbed1c3361c72705a26 Author: Elaina Claus Date: Thu Dec 18 20:05:57 2025 -0500 first commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..a9b424a --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +.venv/ +*.pth +*.log diff --git a/README.md b/README.md new file mode 100644 index 0000000..a1cbde8 --- /dev/null +++ b/README.md @@ -0,0 +1,55 @@ +# Toy ML models and experiments + +This repo contains some experiments with ML I am working on to learn how they work. + +## Setup + +### Debian 13 (Trixie) Python3 setup + +```bash +$ sudo apt install python3-full python3-venv +python3-full is already the newest version (3.13.5-1). +python3-venv is already the newest version (3.13.5-1). +``` + +### Get the code and init the venv + +```bash +git clone http://the.git.repo/user/mltoys.git mltoys && cd mltoys +python3 -m venv .venv +source .venv/bin/activate +``` + +### Installing python deps + +For instructions installing pytorch refer to the [PyTorch Home page] + +```bash +pip3 install numpy +# use the nvidia CUDA or CPU only packages if required. +# I use the ROCm packages, so the repo uses the ROCm packages. +pip3 install torch torchvision --index-url https://download.pytorch.org/whl/rocm6.4 +pip3 install pandas matplotlib +``` + +## Running the code + +### Training the model + +```bash +(.venv) ➜ python3 pairwise_compare.py train +``` + +### Running inference + +```bash +(.venv) ➜ python3 pairwise_compare.py infer +``` + +### Some General notes + +- This module logs to pairwise_compare.log using the python logging module +- Saved model state (training data) is saved at model.pth +- output_graphs.py generates two graphs to inspect the training process, copy the step log into the script to generate graphs. + +[PyTorch Home page]: https://pytorch.org diff --git a/infer.py b/infer.py new file mode 120000 index 0000000..235fa81 --- /dev/null +++ b/infer.py @@ -0,0 +1 @@ +pairwise_compare.py \ No newline at end of file diff --git a/output_graphs.py b/output_graphs.py new file mode 100644 index 0000000..dd5302b --- /dev/null +++ b/output_graphs.py @@ -0,0 +1,573 @@ +import re +import pandas as pd +import matplotlib.pyplot as plt + +text = r"""INFO:__main__:step= 0 loss=1.3149878 acc=0.9093018 +INFO:__main__:step= 1 loss=1.0067503 acc=0.9183350 +INFO:__main__:step= 2 loss=0.7588759 acc=0.9278564 +INFO:__main__:step= 3 loss=0.5432727 acc=0.9384766 +INFO:__main__:step= 4 loss=0.3854972 acc=0.9427490 +INFO:__main__:step= 5 loss=0.2379047 acc=0.9512939 +INFO:__main__:step= 6 loss=0.1357967 acc=0.9575195 +INFO:__main__:step= 7 loss=0.0659471 acc=0.9808960 +INFO:__main__:step= 8 loss=0.0325629 acc=0.9878540 +INFO:__main__:step= 9 loss=0.0159342 acc=0.9951172 +INFO:__main__:step= 10 loss=0.0123130 acc=0.9995117 +INFO:__main__:step= 11 loss=0.0129503 acc=0.9963989 +INFO:__main__:step= 12 loss=0.0198095 acc=0.9927979 +INFO:__main__:step= 13 loss=0.0261302 acc=0.9901733 +INFO:__main__:step= 14 loss=0.0338637 acc=0.9891968 +INFO:__main__:step= 15 loss=0.0425015 acc=0.9861450 +INFO:__main__:step= 16 loss=0.0458022 acc=0.9865112 +INFO:__main__:step= 17 loss=0.0492692 acc=0.9859619 +INFO:__main__:step= 18 loss=0.0543331 acc=0.9850464 +INFO:__main__:step= 19 loss=0.0575214 acc=0.9846802 +INFO:__main__:step= 20 loss=0.0635471 acc=0.9841919 +INFO:__main__:step= 21 loss=0.0503580 acc=0.9871216 +INFO:__main__:step= 22 loss=0.0503206 acc=0.9863892 +INFO:__main__:step= 23 loss=0.0549617 acc=0.9852905 +INFO:__main__:step= 24 loss=0.0438731 acc=0.9872437 +INFO:__main__:step= 25 loss=0.0471392 acc=0.9867554 +INFO:__main__:step= 26 loss=0.0431895 acc=0.9868774 +INFO:__main__:step= 27 loss=0.0354491 acc=0.9892578 +INFO:__main__:step= 28 loss=0.0311145 acc=0.9905396 +INFO:__main__:step= 29 loss=0.0305991 acc=0.9897461 +INFO:__main__:step= 30 loss=0.0229458 acc=0.9920044 +INFO:__main__:step= 31 loss=0.0203844 acc=0.9925537 +INFO:__main__:step= 32 loss=0.0190371 acc=0.9929810 +INFO:__main__:step= 33 loss=0.0167215 acc=0.9935303 +INFO:__main__:step= 34 loss=0.0153083 acc=0.9937744 +INFO:__main__:step= 35 loss=0.0120912 acc=0.9947510 +INFO:__main__:step= 36 loss=0.0101961 acc=0.9961548 +INFO:__main__:step= 37 loss=0.0100932 acc=0.9968262 +INFO:__main__:step= 38 loss=0.0101871 acc=0.9968262 +INFO:__main__:step= 39 loss=0.0088590 acc=0.9978027 +INFO:__main__:step= 40 loss=0.0090834 acc=0.9982300 +INFO:__main__:step= 41 loss=0.0084685 acc=0.9993896 +INFO:__main__:step= 42 loss=0.0075352 acc=0.9989014 +INFO:__main__:step= 43 loss=0.0080332 acc=0.9987793 +INFO:__main__:step= 44 loss=0.0087203 acc=0.9979858 +INFO:__main__:step= 45 loss=0.0086979 acc=0.9977417 +INFO:__main__:step= 46 loss=0.0086313 acc=0.9972534 +INFO:__main__:step= 47 loss=0.0092596 acc=0.9971924 +INFO:__main__:step= 48 loss=0.0086056 acc=0.9978027 +INFO:__main__:step= 49 loss=0.0101687 acc=0.9960327 +INFO:__main__:step= 50 loss=0.0083168 acc=0.9978638 +INFO:__main__:step= 51 loss=0.0095370 acc=0.9965820 +INFO:__main__:step= 52 loss=0.0093318 acc=0.9968872 +INFO:__main__:step= 53 loss=0.0093573 acc=0.9966431 +INFO:__main__:step= 54 loss=0.0091653 acc=0.9971313 +INFO:__main__:step= 55 loss=0.0097087 acc=0.9962769 +INFO:__main__:step= 56 loss=0.0094240 acc=0.9972534 +INFO:__main__:step= 57 loss=0.0088949 acc=0.9970093 +INFO:__main__:step= 58 loss=0.0085519 acc=0.9975586 +INFO:__main__:step= 59 loss=0.0098900 acc=0.9967651 +INFO:__main__:step= 60 loss=0.0085984 acc=0.9977417 +INFO:__main__:step= 61 loss=0.0082639 acc=0.9984131 +INFO:__main__:step= 62 loss=0.0083590 acc=0.9979858 +INFO:__main__:step= 63 loss=0.0072852 acc=0.9990234 +INFO:__main__:step= 64 loss=0.0080845 acc=0.9983521 +INFO:__main__:step= 65 loss=0.0074006 acc=0.9988403 +INFO:__main__:step= 66 loss=0.0067656 acc=0.9990845 +INFO:__main__:step= 67 loss=0.0072893 acc=0.9992065 +INFO:__main__:step= 68 loss=0.0077519 acc=0.9991455 +INFO:__main__:step= 69 loss=0.0072619 acc=0.9992065 +INFO:__main__:step= 70 loss=0.0069199 acc=0.9993896 +INFO:__main__:step= 71 loss=0.0070979 acc=0.9992676 +INFO:__main__:step= 72 loss=0.0073295 acc=0.9992676 +INFO:__main__:step= 73 loss=0.0074595 acc=0.9994507 +INFO:__main__:step= 74 loss=0.0069510 acc=0.9993896 +INFO:__main__:step= 75 loss=0.0072545 acc=0.9993896 +INFO:__main__:step= 76 loss=0.0070704 acc=0.9994507 +INFO:__main__:step= 77 loss=0.0080527 acc=0.9992065 +INFO:__main__:step= 78 loss=0.0078222 acc=0.9990234 +INFO:__main__:step= 79 loss=0.0074403 acc=0.9993896 +INFO:__main__:step= 80 loss=0.0070136 acc=0.9992065 +INFO:__main__:step= 81 loss=0.0068662 acc=0.9992676 +INFO:__main__:step= 82 loss=0.0071025 acc=0.9992676 +INFO:__main__:step= 83 loss=0.0068718 acc=0.9993896 +INFO:__main__:step= 84 loss=0.0072839 acc=0.9991455 +INFO:__main__:step= 85 loss=0.0071138 acc=0.9992676 +INFO:__main__:step= 86 loss=0.0070522 acc=0.9992065 +INFO:__main__:step= 87 loss=0.0073670 acc=0.9992065 +INFO:__main__:step= 88 loss=0.0069347 acc=0.9995117 +INFO:__main__:step= 89 loss=0.0066190 acc=0.9993286 +INFO:__main__:step= 90 loss=0.0069868 acc=0.9997559 +INFO:__main__:step= 91 loss=0.0071491 acc=0.9995728 +INFO:__main__:step= 92 loss=0.0071107 acc=0.9991455 +INFO:__main__:step= 93 loss=0.0065876 acc=0.9996338 +INFO:__main__:step= 94 loss=0.0067994 acc=0.9995117 +INFO:__main__:step= 95 loss=0.0072922 acc=0.9995728 +INFO:__main__:step= 96 loss=0.0070875 acc=0.9995117 +INFO:__main__:step= 97 loss=0.0067269 acc=0.9996338 +INFO:__main__:step= 98 loss=0.0069193 acc=0.9995728 +INFO:__main__:step= 99 loss=0.0068322 acc=0.9995728 +INFO:__main__:step= 100 loss=0.0065636 acc=0.9995117 +INFO:__main__:step= 101 loss=0.0072776 acc=0.9995117 +INFO:__main__:step= 102 loss=0.0074568 acc=0.9996338 +INFO:__main__:step= 103 loss=0.0067574 acc=0.9996338 +INFO:__main__:step= 104 loss=0.0067869 acc=0.9995728 +INFO:__main__:step= 105 loss=0.0064811 acc=0.9996948 +INFO:__main__:step= 106 loss=0.0077020 acc=0.9990234 +INFO:__main__:step= 107 loss=0.0069068 acc=0.9996338 +INFO:__main__:step= 108 loss=0.0068666 acc=0.9994507 +INFO:__main__:step= 109 loss=0.0066869 acc=0.9992065 +INFO:__main__:step= 110 loss=0.0070414 acc=0.9993286 +INFO:__main__:step= 111 loss=0.0072129 acc=0.9994507 +INFO:__main__:step= 112 loss=0.0067772 acc=0.9993896 +INFO:__main__:step= 113 loss=0.0069291 acc=0.9995728 +INFO:__main__:step= 114 loss=0.0065003 acc=0.9996338 +INFO:__main__:step= 115 loss=0.0071769 acc=0.9993286 +INFO:__main__:step= 116 loss=0.0068795 acc=0.9993896 +INFO:__main__:step= 117 loss=0.0068955 acc=0.9995117 +INFO:__main__:step= 118 loss=0.0071484 acc=0.9993896 +INFO:__main__:step= 119 loss=0.0072527 acc=0.9992676 +INFO:__main__:step= 120 loss=0.0070072 acc=0.9996948 +INFO:__main__:step= 121 loss=0.0069677 acc=0.9995117 +INFO:__main__:step= 122 loss=0.0070129 acc=0.9996338 +INFO:__main__:step= 123 loss=0.0068947 acc=0.9990845 +INFO:__main__:step= 124 loss=0.0066504 acc=0.9996338 +INFO:__main__:step= 125 loss=0.0060845 acc=0.9994507 +INFO:__main__:step= 126 loss=0.0063553 acc=0.9995117 +INFO:__main__:step= 127 loss=0.0063697 acc=0.9996948 +INFO:__main__:step= 128 loss=0.0064274 acc=0.9995728 +INFO:__main__:step= 129 loss=0.0057200 acc=0.9998169 +INFO:__main__:step= 130 loss=0.0064411 acc=0.9996948 +INFO:__main__:step= 131 loss=0.0060729 acc=0.9998779 +INFO:__main__:step= 132 loss=0.0062524 acc=0.9996948 +INFO:__main__:step= 133 loss=0.0060491 acc=0.9991455 +INFO:__main__:step= 134 loss=0.0065517 acc=0.9996338 +INFO:__main__:step= 135 loss=0.0058316 acc=0.9998169 +INFO:__main__:step= 136 loss=0.0059018 acc=0.9997559 +INFO:__main__:step= 137 loss=0.0065307 acc=0.9996948 +INFO:__main__:step= 138 loss=0.0064329 acc=0.9995117 +INFO:__main__:step= 139 loss=0.0066867 acc=0.9995728 +INFO:__main__:step= 140 loss=0.0070452 acc=0.9992676 +INFO:__main__:step= 141 loss=0.0064459 acc=0.9993896 +INFO:__main__:step= 142 loss=0.0061031 acc=0.9997559 +INFO:__main__:step= 143 loss=0.0067490 acc=0.9992676 +INFO:__main__:step= 144 loss=0.0065364 acc=0.9998779 +INFO:__main__:step= 145 loss=0.0058783 acc=0.9993286 +INFO:__main__:step= 146 loss=0.0060647 acc=0.9998169 +INFO:__main__:step= 147 loss=0.0064641 acc=0.9996338 +INFO:__main__:step= 148 loss=0.0065979 acc=0.9996338 +INFO:__main__:step= 149 loss=0.0058760 acc=0.9997559 +INFO:__main__:step= 150 loss=0.0058440 acc=0.9998779 +INFO:__main__:step= 151 loss=0.0060483 acc=0.9998169 +INFO:__main__:step= 152 loss=0.0062729 acc=0.9998169 +INFO:__main__:step= 153 loss=0.0061169 acc=0.9996338 +INFO:__main__:step= 154 loss=0.0064641 acc=0.9995728 +INFO:__main__:step= 155 loss=0.0062200 acc=0.9995117 +INFO:__main__:step= 156 loss=0.0062547 acc=0.9997559 +INFO:__main__:step= 157 loss=0.0064813 acc=0.9995728 +INFO:__main__:step= 158 loss=0.0061465 acc=0.9996338 +INFO:__main__:step= 159 loss=0.0056798 acc=0.9995728 +INFO:__main__:step= 160 loss=0.0064472 acc=0.9996338 +INFO:__main__:step= 161 loss=0.0055585 acc=0.9997559 +INFO:__main__:step= 162 loss=0.0057363 acc=0.9995117 +INFO:__main__:step= 163 loss=0.0061281 acc=0.9995728 +INFO:__main__:step= 164 loss=0.0062811 acc=0.9995117 +INFO:__main__:step= 165 loss=0.0058447 acc=0.9997559 +INFO:__main__:step= 166 loss=0.0054476 acc=0.9998169 +INFO:__main__:step= 167 loss=0.0067439 acc=0.9998779 +INFO:__main__:step= 168 loss=0.0063363 acc=0.9992676 +INFO:__main__:step= 169 loss=0.0060227 acc=0.9995117 +INFO:__main__:step= 170 loss=0.0060772 acc=0.9998169 +INFO:__main__:step= 171 loss=0.0060449 acc=0.9994507 +INFO:__main__:step= 172 loss=0.0061583 acc=0.9997559 +INFO:__main__:step= 173 loss=0.0061034 acc=0.9995728 +INFO:__main__:step= 174 loss=0.0060934 acc=0.9996948 +INFO:__main__:step= 175 loss=0.0062173 acc=0.9995117 +INFO:__main__:step= 176 loss=0.0065296 acc=0.9995728 +INFO:__main__:step= 177 loss=0.0067725 acc=0.9994507 +INFO:__main__:step= 178 loss=0.0064399 acc=0.9997559 +INFO:__main__:step= 179 loss=0.0060964 acc=0.9997559 +INFO:__main__:step= 180 loss=0.0053458 acc=0.9995728 +INFO:__main__:step= 181 loss=0.0056739 acc=0.9996338 +INFO:__main__:step= 182 loss=0.0054851 acc=0.9996338 +INFO:__main__:step= 183 loss=0.0064169 acc=0.9997559 +INFO:__main__:step= 184 loss=0.0058324 acc=0.9996948 +INFO:__main__:step= 185 loss=0.0058293 acc=0.9994507 +INFO:__main__:step= 186 loss=0.0067926 acc=0.9996338 +INFO:__main__:step= 187 loss=0.0053189 acc=0.9994507 +INFO:__main__:step= 188 loss=0.0059640 acc=0.9998169 +INFO:__main__:step= 189 loss=0.0055875 acc=0.9996338 +INFO:__main__:step= 190 loss=0.0057980 acc=0.9995117 +INFO:__main__:step= 191 loss=0.0057864 acc=0.9996338 +INFO:__main__:step= 192 loss=0.0052436 acc=0.9999390 +INFO:__main__:step= 193 loss=0.0063448 acc=0.9994507 +INFO:__main__:step= 194 loss=0.0056429 acc=0.9996948 +INFO:__main__:step= 195 loss=0.0057287 acc=0.9993286 +INFO:__main__:step= 196 loss=0.0058553 acc=0.9998169 +INFO:__main__:step= 197 loss=0.0058331 acc=0.9992676 +INFO:__main__:step= 198 loss=0.0052676 acc=0.9995117 +INFO:__main__:step= 199 loss=0.0058648 acc=0.9996948 +INFO:__main__:step= 200 loss=0.0061336 acc=0.9998169 +INFO:__main__:step= 201 loss=0.0053004 acc=0.9998169 +INFO:__main__:step= 202 loss=0.0061638 acc=0.9997559 +INFO:__main__:step= 203 loss=0.0065025 acc=0.9995117 +INFO:__main__:step= 204 loss=0.0059157 acc=0.9998169 +INFO:__main__:step= 205 loss=0.0054036 acc=0.9998779 +INFO:__main__:step= 206 loss=0.0056138 acc=0.9996338 +INFO:__main__:step= 207 loss=0.0054218 acc=0.9996338 +INFO:__main__:step= 208 loss=0.0066803 acc=0.9992676 +INFO:__main__:step= 209 loss=0.0058715 acc=0.9996948 +INFO:__main__:step= 210 loss=0.0055717 acc=0.9998169 +INFO:__main__:step= 211 loss=0.0060626 acc=0.9994507 +INFO:__main__:step= 212 loss=0.0051552 acc=0.9998779 +INFO:__main__:step= 213 loss=0.0059606 acc=0.9995117 +INFO:__main__:step= 214 loss=0.0049876 acc=0.9998779 +INFO:__main__:step= 215 loss=0.0057766 acc=0.9996338 +INFO:__main__:step= 216 loss=0.0060298 acc=0.9996338 +INFO:__main__:step= 217 loss=0.0055154 acc=0.9996338 +INFO:__main__:step= 218 loss=0.0055578 acc=0.9998169 +INFO:__main__:step= 219 loss=0.0052298 acc=0.9995728 +INFO:__main__:step= 220 loss=0.0064167 acc=0.9995117 +INFO:__main__:step= 221 loss=0.0056415 acc=0.9996338 +INFO:__main__:step= 222 loss=0.0052709 acc=0.9995728 +INFO:__main__:step= 223 loss=0.0060531 acc=0.9995728 +INFO:__main__:step= 224 loss=0.0049705 acc=0.9995117 +INFO:__main__:step= 225 loss=0.0053588 acc=0.9997559 +INFO:__main__:step= 226 loss=0.0052681 acc=0.9998779 +INFO:__main__:step= 227 loss=0.0053794 acc=0.9996948 +INFO:__main__:step= 228 loss=0.0056610 acc=0.9997559 +INFO:__main__:step= 229 loss=0.0057926 acc=0.9995728 +INFO:__main__:step= 230 loss=0.0053835 acc=0.9998169 +INFO:__main__:step= 231 loss=0.0045384 acc=0.9996948 +INFO:__main__:step= 232 loss=0.0048465 acc=0.9997559 +INFO:__main__:step= 233 loss=0.0057317 acc=0.9995117 +INFO:__main__:step= 234 loss=0.0058601 acc=0.9995117 +INFO:__main__:step= 235 loss=0.0050546 acc=0.9997559 +INFO:__main__:step= 236 loss=0.0049725 acc=1.0000000 +INFO:__main__:step= 237 loss=0.0048411 acc=0.9995728 +INFO:__main__:step= 238 loss=0.0054285 acc=0.9998169 +INFO:__main__:step= 239 loss=0.0051857 acc=0.9997559 +INFO:__main__:step= 240 loss=0.0054422 acc=0.9996948 +INFO:__main__:step= 241 loss=0.0054665 acc=0.9999390 +INFO:__main__:step= 242 loss=0.0049092 acc=0.9996948 +INFO:__main__:step= 243 loss=0.0050945 acc=0.9995728 +INFO:__main__:step= 244 loss=0.0057586 acc=0.9992676 +INFO:__main__:step= 245 loss=0.0055652 acc=0.9997559 +INFO:__main__:step= 246 loss=0.0052052 acc=0.9996948 +INFO:__main__:step= 247 loss=0.0050273 acc=0.9996948 +INFO:__main__:step= 248 loss=0.0055413 acc=0.9994507 +INFO:__main__:step= 249 loss=0.0050175 acc=0.9998169 +INFO:__main__:step= 250 loss=0.0049932 acc=0.9998169 +INFO:__main__:step= 251 loss=0.0050813 acc=0.9997559 +INFO:__main__:step= 252 loss=0.0051735 acc=0.9998169 +INFO:__main__:step= 253 loss=0.0049972 acc=0.9998169 +INFO:__main__:step= 254 loss=0.0051353 acc=0.9996338 +INFO:__main__:step= 255 loss=0.0052580 acc=0.9996338 +INFO:__main__:step= 256 loss=0.0047390 acc=0.9996948 +INFO:__main__:step= 257 loss=0.0051188 acc=0.9997559 +INFO:__main__:step= 258 loss=0.0046280 acc=0.9997559 +INFO:__main__:step= 259 loss=0.0047725 acc=0.9997559 +INFO:__main__:step= 260 loss=0.0045635 acc=0.9996338 +INFO:__main__:step= 261 loss=0.0053751 acc=0.9995728 +INFO:__main__:step= 262 loss=0.0046179 acc=0.9998779 +INFO:__main__:step= 263 loss=0.0052214 acc=0.9994507 +INFO:__main__:step= 264 loss=0.0044471 acc=0.9996948 +INFO:__main__:step= 265 loss=0.0053352 acc=0.9997559 +INFO:__main__:step= 266 loss=0.0053903 acc=0.9998779 +INFO:__main__:step= 267 loss=0.0050240 acc=0.9996948 +INFO:__main__:step= 268 loss=0.0048375 acc=0.9996948 +INFO:__main__:step= 269 loss=0.0055583 acc=0.9995117 +INFO:__main__:step= 270 loss=0.0047512 acc=1.0000000 +INFO:__main__:step= 271 loss=0.0045881 acc=0.9998779 +INFO:__main__:step= 272 loss=0.0045754 acc=0.9996338 +INFO:__main__:step= 273 loss=0.0051980 acc=0.9996338 +INFO:__main__:step= 274 loss=0.0048658 acc=0.9997559 +INFO:__main__:step= 275 loss=0.0049445 acc=0.9996338 +INFO:__main__:step= 276 loss=0.0051786 acc=0.9998779 +INFO:__main__:step= 277 loss=0.0045741 acc=0.9997559 +INFO:__main__:step= 278 loss=0.0054667 acc=0.9995728 +INFO:__main__:step= 279 loss=0.0049829 acc=0.9998169 +INFO:__main__:step= 280 loss=0.0049411 acc=0.9995117 +INFO:__main__:step= 281 loss=0.0051374 acc=0.9998779 +INFO:__main__:step= 282 loss=0.0046585 acc=0.9996338 +INFO:__main__:step= 283 loss=0.0049610 acc=0.9997559 +INFO:__main__:step= 284 loss=0.0044862 acc=0.9998779 +INFO:__main__:step= 285 loss=0.0050020 acc=0.9996948 +INFO:__main__:step= 286 loss=0.0043370 acc=0.9999390 +INFO:__main__:step= 287 loss=0.0051258 acc=0.9996338 +INFO:__main__:step= 288 loss=0.0043914 acc=0.9998779 +INFO:__main__:step= 289 loss=0.0043173 acc=0.9996948 +INFO:__main__:step= 290 loss=0.0047723 acc=0.9998779 +INFO:__main__:step= 291 loss=0.0049867 acc=0.9996948 +INFO:__main__:step= 292 loss=0.0046409 acc=0.9998169 +INFO:__main__:step= 293 loss=0.0039387 acc=1.0000000 +INFO:__main__:step= 294 loss=0.0051583 acc=0.9994507 +INFO:__main__:step= 295 loss=0.0043871 acc=0.9998169 +INFO:__main__:step= 296 loss=0.0046067 acc=0.9997559 +INFO:__main__:step= 297 loss=0.0043712 acc=0.9996948 +INFO:__main__:step= 298 loss=0.0047873 acc=0.9995117 +INFO:__main__:step= 299 loss=0.0047681 acc=0.9998779 +INFO:__main__:step= 300 loss=0.0050251 acc=0.9997559 +INFO:__main__:step= 301 loss=0.0045633 acc=0.9998169 +INFO:__main__:step= 302 loss=0.0043650 acc=0.9999390 +INFO:__main__:step= 303 loss=0.0049224 acc=0.9997559 +INFO:__main__:step= 304 loss=0.0049409 acc=0.9998169 +INFO:__main__:step= 305 loss=0.0052364 acc=0.9996948 +INFO:__main__:step= 306 loss=0.0051323 acc=0.9995728 +INFO:__main__:step= 307 loss=0.0047827 acc=0.9997559 +INFO:__main__:step= 308 loss=0.0044320 acc=0.9997559 +INFO:__main__:step= 309 loss=0.0050638 acc=0.9997559 +INFO:__main__:step= 310 loss=0.0052357 acc=0.9997559 +INFO:__main__:step= 311 loss=0.0044889 acc=0.9995728 +INFO:__main__:step= 312 loss=0.0043897 acc=0.9997559 +INFO:__main__:step= 313 loss=0.0050018 acc=0.9998779 +INFO:__main__:step= 314 loss=0.0043691 acc=0.9998169 +INFO:__main__:step= 315 loss=0.0042810 acc=0.9999390 +INFO:__main__:step= 316 loss=0.0045961 acc=0.9996338 +INFO:__main__:step= 317 loss=0.0047868 acc=0.9998169 +INFO:__main__:step= 318 loss=0.0047044 acc=0.9998169 +INFO:__main__:step= 319 loss=0.0046348 acc=0.9998169 +INFO:__main__:step= 320 loss=0.0043805 acc=0.9998779 +INFO:__main__:step= 321 loss=0.0041868 acc=0.9997559 +INFO:__main__:step= 322 loss=0.0045165 acc=0.9997559 +INFO:__main__:step= 323 loss=0.0050955 acc=0.9997559 +INFO:__main__:step= 324 loss=0.0046222 acc=0.9996948 +INFO:__main__:step= 325 loss=0.0045308 acc=0.9996338 +INFO:__main__:step= 326 loss=0.0052673 acc=0.9993896 +INFO:__main__:step= 327 loss=0.0049429 acc=0.9995728 +INFO:__main__:step= 328 loss=0.0042904 acc=0.9997559 +INFO:__main__:step= 329 loss=0.0047324 acc=0.9996948 +INFO:__main__:step= 330 loss=0.0049455 acc=0.9996338 +INFO:__main__:step= 331 loss=0.0042475 acc=0.9998779 +INFO:__main__:step= 332 loss=0.0045173 acc=0.9996948 +INFO:__main__:step= 333 loss=0.0047958 acc=0.9996338 +INFO:__main__:step= 334 loss=0.0045062 acc=0.9996948 +INFO:__main__:step= 335 loss=0.0043826 acc=0.9996338 +INFO:__main__:step= 336 loss=0.0047664 acc=0.9998169 +INFO:__main__:step= 337 loss=0.0048210 acc=0.9997559 +INFO:__main__:step= 338 loss=0.0048734 acc=0.9996338 +INFO:__main__:step= 339 loss=0.0038904 acc=0.9998169 +INFO:__main__:step= 340 loss=0.0046818 acc=0.9993896 +INFO:__main__:step= 341 loss=0.0050763 acc=0.9995728 +INFO:__main__:step= 342 loss=0.0042104 acc=0.9998779 +INFO:__main__:step= 343 loss=0.0045529 acc=0.9996338 +INFO:__main__:step= 344 loss=0.0044936 acc=0.9997559 +INFO:__main__:step= 345 loss=0.0045376 acc=0.9998779 +INFO:__main__:step= 346 loss=0.0046967 acc=0.9998169 +INFO:__main__:step= 347 loss=0.0042559 acc=0.9995728 +INFO:__main__:step= 348 loss=0.0049302 acc=0.9993286 +INFO:__main__:step= 349 loss=0.0045091 acc=0.9996948 +INFO:__main__:step= 350 loss=0.0043695 acc=0.9996948 +INFO:__main__:step= 351 loss=0.0048456 acc=0.9995117 +INFO:__main__:step= 352 loss=0.0043468 acc=0.9996948 +INFO:__main__:step= 353 loss=0.0044966 acc=0.9998779 +INFO:__main__:step= 354 loss=0.0042396 acc=0.9998169 +INFO:__main__:step= 355 loss=0.0048033 acc=0.9994507 +INFO:__main__:step= 356 loss=0.0043947 acc=0.9997559 +INFO:__main__:step= 357 loss=0.0042306 acc=0.9996338 +INFO:__main__:step= 358 loss=0.0040438 acc=0.9998169 +INFO:__main__:step= 359 loss=0.0046469 acc=0.9997559 +INFO:__main__:step= 360 loss=0.0044782 acc=0.9998169 +INFO:__main__:step= 361 loss=0.0045767 acc=0.9995117 +INFO:__main__:step= 362 loss=0.0044509 acc=0.9998169 +INFO:__main__:step= 363 loss=0.0049030 acc=0.9996948 +INFO:__main__:step= 364 loss=0.0042743 acc=0.9998779 +INFO:__main__:step= 365 loss=0.0043227 acc=0.9998779 +INFO:__main__:step= 366 loss=0.0045902 acc=0.9996338 +INFO:__main__:step= 367 loss=0.0040066 acc=1.0000000 +INFO:__main__:step= 368 loss=0.0041235 acc=0.9998779 +INFO:__main__:step= 369 loss=0.0042298 acc=0.9998779 +INFO:__main__:step= 370 loss=0.0041855 acc=0.9996948 +INFO:__main__:step= 371 loss=0.0042327 acc=0.9999390 +INFO:__main__:step= 372 loss=0.0045756 acc=0.9998169 +INFO:__main__:step= 373 loss=0.0041091 acc=0.9999390 +INFO:__main__:step= 374 loss=0.0046337 acc=0.9995728 +INFO:__main__:step= 375 loss=0.0048172 acc=0.9996338 +INFO:__main__:step= 376 loss=0.0049001 acc=0.9997559 +INFO:__main__:step= 377 loss=0.0047032 acc=0.9998779 +INFO:__main__:step= 378 loss=0.0039463 acc=0.9997559 +INFO:__main__:step= 379 loss=0.0043672 acc=0.9997559 +INFO:__main__:step= 380 loss=0.0043338 acc=0.9995728 +INFO:__main__:step= 381 loss=0.0044772 acc=0.9996948 +INFO:__main__:step= 382 loss=0.0040340 acc=0.9997559 +INFO:__main__:step= 383 loss=0.0042726 acc=0.9997559 +INFO:__main__:step= 384 loss=0.0042994 acc=0.9998779 +INFO:__main__:step= 385 loss=0.0040821 acc=0.9996948 +INFO:__main__:step= 386 loss=0.0036824 acc=0.9998779 +INFO:__main__:step= 387 loss=0.0049368 acc=0.9998169 +INFO:__main__:step= 388 loss=0.0040065 acc=0.9997559 +INFO:__main__:step= 389 loss=0.0043364 acc=0.9998779 +INFO:__main__:step= 390 loss=0.0042396 acc=0.9997559 +INFO:__main__:step= 391 loss=0.0041326 acc=0.9999390 +INFO:__main__:step= 392 loss=0.0045820 acc=0.9998169 +INFO:__main__:step= 393 loss=0.0045091 acc=0.9998779 +INFO:__main__:step= 394 loss=0.0041117 acc=0.9998169 +INFO:__main__:step= 395 loss=0.0037867 acc=0.9996948 +INFO:__main__:step= 396 loss=0.0041869 acc=0.9998169 +INFO:__main__:step= 397 loss=0.0041473 acc=0.9996948 +INFO:__main__:step= 398 loss=0.0045813 acc=0.9995728 +INFO:__main__:step= 399 loss=0.0037894 acc=0.9998779 +INFO:__main__:step= 400 loss=0.0042846 acc=0.9998169 +INFO:__main__:step= 401 loss=0.0039527 acc=0.9996338 +INFO:__main__:step= 402 loss=0.0042966 acc=0.9998169 +INFO:__main__:step= 403 loss=0.0034902 acc=0.9998169 +INFO:__main__:step= 404 loss=0.0039595 acc=0.9998169 +INFO:__main__:step= 405 loss=0.0036952 acc=0.9998779 +INFO:__main__:step= 406 loss=0.0039154 acc=0.9998779 +INFO:__main__:step= 407 loss=0.0041118 acc=0.9998169 +INFO:__main__:step= 408 loss=0.0037084 acc=0.9997559 +INFO:__main__:step= 409 loss=0.0038251 acc=0.9998779 +INFO:__main__:step= 410 loss=0.0042379 acc=0.9996948 +INFO:__main__:step= 411 loss=0.0044292 acc=0.9996338 +INFO:__main__:step= 412 loss=0.0048427 acc=1.0000000 +INFO:__main__:step= 413 loss=0.0044034 acc=0.9996948 +INFO:__main__:step= 414 loss=0.0039434 acc=0.9998779 +INFO:__main__:step= 415 loss=0.0036064 acc=0.9999390 +INFO:__main__:step= 416 loss=0.0043831 acc=0.9998169 +INFO:__main__:step= 417 loss=0.0036476 acc=0.9999390 +INFO:__main__:step= 418 loss=0.0039738 acc=0.9998779 +INFO:__main__:step= 419 loss=0.0043787 acc=0.9996948 +INFO:__main__:step= 420 loss=0.0042686 acc=0.9996948 +INFO:__main__:step= 421 loss=0.0045272 acc=0.9998169 +INFO:__main__:step= 422 loss=0.0043030 acc=0.9998779 +INFO:__main__:step= 423 loss=0.0043391 acc=0.9996338 +INFO:__main__:step= 424 loss=0.0037387 acc=0.9997559 +INFO:__main__:step= 425 loss=0.0039780 acc=0.9998779 +INFO:__main__:step= 426 loss=0.0041521 acc=0.9998779 +INFO:__main__:step= 427 loss=0.0044029 acc=0.9996948 +INFO:__main__:step= 428 loss=0.0037069 acc=0.9996948 +INFO:__main__:step= 429 loss=0.0039751 acc=0.9997559 +INFO:__main__:step= 430 loss=0.0036961 acc=0.9999390 +INFO:__main__:step= 431 loss=0.0041230 acc=0.9996948 +INFO:__main__:step= 432 loss=0.0037204 acc=0.9998169 +INFO:__main__:step= 433 loss=0.0041116 acc=0.9998169 +INFO:__main__:step= 434 loss=0.0043973 acc=0.9998779 +INFO:__main__:step= 435 loss=0.0042320 acc=0.9995728 +INFO:__main__:step= 436 loss=0.0048164 acc=0.9995728 +INFO:__main__:step= 437 loss=0.0038487 acc=0.9996948 +INFO:__main__:step= 438 loss=0.0044019 acc=0.9996948 +INFO:__main__:step= 439 loss=0.0036137 acc=0.9997559 +INFO:__main__:step= 440 loss=0.0040031 acc=0.9997559 +INFO:__main__:step= 441 loss=0.0037439 acc=1.0000000 +INFO:__main__:step= 442 loss=0.0047578 acc=0.9997559 +INFO:__main__:step= 443 loss=0.0037242 acc=1.0000000 +INFO:__main__:step= 444 loss=0.0035887 acc=0.9998169 +INFO:__main__:step= 445 loss=0.0040946 acc=0.9998779 +INFO:__main__:step= 446 loss=0.0038853 acc=0.9998779 +INFO:__main__:step= 447 loss=0.0045090 acc=0.9997559 +INFO:__main__:step= 448 loss=0.0035975 acc=0.9998779 +INFO:__main__:step= 449 loss=0.0039571 acc=0.9998169 +INFO:__main__:step= 450 loss=0.0033344 acc=0.9998779 +INFO:__main__:step= 451 loss=0.0038010 acc=0.9996338 +INFO:__main__:step= 452 loss=0.0037623 acc=0.9999390 +INFO:__main__:step= 453 loss=0.0039248 acc=0.9998779 +INFO:__main__:step= 454 loss=0.0038458 acc=0.9996948 +INFO:__main__:step= 455 loss=0.0038334 acc=0.9997559 +INFO:__main__:step= 456 loss=0.0039828 acc=0.9997559 +INFO:__main__:step= 457 loss=0.0037553 acc=0.9997559 +INFO:__main__:step= 458 loss=0.0034900 acc=0.9998169 +INFO:__main__:step= 459 loss=0.0039507 acc=0.9997559 +INFO:__main__:step= 460 loss=0.0040168 acc=0.9997559 +INFO:__main__:step= 461 loss=0.0039682 acc=0.9996338 +INFO:__main__:step= 462 loss=0.0033216 acc=1.0000000 +INFO:__main__:step= 463 loss=0.0035992 acc=0.9998779 +INFO:__main__:step= 464 loss=0.0035374 acc=0.9998169 +INFO:__main__:step= 465 loss=0.0039145 acc=0.9998169 +INFO:__main__:step= 466 loss=0.0036567 acc=0.9996338 +INFO:__main__:step= 467 loss=0.0035094 acc=0.9998169 +INFO:__main__:step= 468 loss=0.0038882 acc=0.9996338 +INFO:__main__:step= 469 loss=0.0037108 acc=0.9998169 +INFO:__main__:step= 470 loss=0.0034592 acc=0.9997559 +INFO:__main__:step= 471 loss=0.0031493 acc=0.9999390 +INFO:__main__:step= 472 loss=0.0038837 acc=0.9996948 +INFO:__main__:step= 473 loss=0.0035162 acc=0.9997559 +INFO:__main__:step= 474 loss=0.0037461 acc=0.9998169 +INFO:__main__:step= 475 loss=0.0037000 acc=1.0000000 +INFO:__main__:step= 476 loss=0.0042010 acc=0.9998169 +INFO:__main__:step= 477 loss=0.0034492 acc=0.9998779 +INFO:__main__:step= 478 loss=0.0035461 acc=1.0000000 +INFO:__main__:step= 479 loss=0.0035708 acc=0.9998169 +INFO:__main__:step= 480 loss=0.0040509 acc=0.9998169 +INFO:__main__:step= 481 loss=0.0033549 acc=0.9996948 +INFO:__main__:step= 482 loss=0.0035421 acc=1.0000000 +INFO:__main__:step= 483 loss=0.0032111 acc=0.9998779 +INFO:__main__:step= 484 loss=0.0039644 acc=0.9997559 +INFO:__main__:step= 485 loss=0.0042124 acc=0.9996338 +INFO:__main__:step= 486 loss=0.0040706 acc=0.9994507 +INFO:__main__:step= 487 loss=0.0035862 acc=0.9998169 +INFO:__main__:step= 488 loss=0.0044468 acc=0.9996338 +INFO:__main__:step= 489 loss=0.0036269 acc=0.9999390 +INFO:__main__:step= 490 loss=0.0031157 acc=0.9998169 +INFO:__main__:step= 491 loss=0.0037752 acc=0.9996948 +INFO:__main__:step= 492 loss=0.0033704 acc=0.9997559 +INFO:__main__:step= 493 loss=0.0040347 acc=0.9996948 +INFO:__main__:step= 494 loss=0.0034053 acc=0.9998169 +INFO:__main__:step= 495 loss=0.0036225 acc=0.9998779 +INFO:__main__:step= 496 loss=0.0042589 acc=0.9997559 +INFO:__main__:step= 497 loss=0.0040456 acc=0.9996338 +INFO:__main__:step= 498 loss=0.0035443 acc=0.9997559 +INFO:__main__:step= 499 loss=0.0035868 acc=1.0000000 +INFO:__main__:step= 500 loss=0.0035023 acc=0.9998779 +INFO:__main__:step= 1000 loss=0.0022645 acc=0.9996948 +INFO:__main__:step= 1500 loss=0.0006305 acc=0.9999390 +INFO:__main__:step= 2000 loss=0.0006203 acc=0.9997559 +INFO:__main__:step= 2500 loss=0.0004746 acc=0.9996948 +INFO:__main__:step= 3000 loss=0.0002501 acc=0.9999390 +INFO:__main__:step= 3500 loss=0.0001484 acc=1.0000000 +INFO:__main__:step= 4000 loss=0.0001385 acc=1.0000000 +INFO:__main__:step= 4500 loss=0.0001816 acc=0.9999390 +INFO:__main__:step= 5000 loss=0.0001943 acc=0.9999390 +INFO:__main__:step= 5500 loss=0.0001595 acc=1.0000000 +INFO:__main__:step= 6000 loss=0.0001608 acc=1.0000000 +INFO:__main__:step= 6500 loss=0.0000808 acc=1.0000000 +INFO:__main__:step= 7000 loss=0.0001676 acc=0.9998779 +INFO:__main__:step= 7500 loss=0.0001089 acc=1.0000000 +INFO:__main__:step= 8000 loss=0.0000847 acc=1.0000000 +INFO:__main__:step= 8500 loss=0.0000707 acc=1.0000000 +INFO:__main__:step= 9000 loss=0.0001036 acc=1.0000000 +INFO:__main__:step= 9500 loss=0.0001216 acc=1.0000000 +INFO:__main__:step=10000 loss=0.0000516 acc=1.0000000 +INFO:__main__:step=10500 loss=0.0000347 acc=1.0000000 +INFO:__main__:step=11000 loss=0.0000673 acc=1.0000000 +INFO:__main__:step=11500 loss=0.0000634 acc=1.0000000 +INFO:__main__:step=12000 loss=0.0000239 acc=1.0000000 +INFO:__main__:step=12500 loss=0.0000039 acc=1.0000000 +INFO:__main__:step=13000 loss=0.0000418 acc=1.0000000 +INFO:__main__:step=13500 loss=0.0000079 acc=1.0000000 +INFO:__main__:step=14000 loss=0.0000367 acc=1.0000000 +INFO:__main__:step=14500 loss=0.0000217 acc=1.0000000 +INFO:__main__:step=15000 loss=0.0000267 acc=1.0000000 +INFO:__main__:step=15500 loss=0.0000759 acc=1.0000000 +INFO:__main__:step=16000 loss=0.0000215 acc=1.0000000 +INFO:__main__:step=16500 loss=0.0000015 acc=1.0000000 +INFO:__main__:step=17000 loss=0.0000008 acc=1.0000000 +INFO:__main__:step=17500 loss=0.0000147 acc=1.0000000 +INFO:__main__:step=18000 loss=0.0000000 acc=1.0000000 +INFO:__main__:step=18500 loss=0.0000001 acc=1.0000000 +INFO:__main__:step=19000 loss=0.0000318 acc=1.0000000 +INFO:__main__:step=19500 loss=0.0000243 acc=1.0000000 +""" + +pattern = re.compile(r"step=\s*(\d+)\s+loss=([0-9.]+)\s+acc=([0-9.]+)") +rows = [(int(s), float(l), float(a)) for s, l, a in pattern.findall(text)] +df = pd.DataFrame(rows, columns=["step", "loss", "acc"]).sort_values("step").reset_index(drop=True) + +# Avoid log(0) issues for loss plot by clamping at a tiny positive value +eps = 1e-10 +df["loss_clamped"] = df["loss"].clip(lower=eps) + +# Plot 1: Loss +plt.figure(figsize=(9, 4.8)) +plt.plot(df["step"], df["loss_clamped"]) +plt.yscale("log") +plt.xlabel("Step") +plt.ylabel("Loss (log scale)") +plt.title("Training Loss vs Step") +plt.tight_layout() +plt.show() + +# Plot 2: Accuracy +df["err"] = (1.0 - df["acc"]).clip(lower=eps) +plt.figure(figsize=(9, 4.8)) +plt.plot(df["step"], df["err"]) +plt.yscale("log") +plt.xlabel("Step") +plt.ylabel("Error rate (1 - accuracy) (log scale)") +plt.title("Training Error Rate vs Step") +plt.tight_layout() +plt.show() diff --git a/pairwise_compare.py b/pairwise_compare.py new file mode 100755 index 0000000..8c6eeda --- /dev/null +++ b/pairwise_compare.py @@ -0,0 +1,181 @@ +#!/usr/bin/env python3 +# pairwise_compare.py +import logging +import random +import torch +from torch import nn +from torch.nn import functional as F + +# early pytorch device setup +DEVICE = torch.accelerator.current_accelerator() if torch.accelerator.is_available() else "cpu" + +# Valves +DIMENSIONS = 1 +TRAIN_STEPS = 20000 +TRAIN_BATCHSZ = 16384 +TRAIN_PROGRESS = 500 +BATCH_LOWER = -512.0 +BATCH_UPPER = 512.0 +DO_VERBOSE_EARLY_TRAIN = True + +def get_torch_info(): + log.info("PyTorch Version: %s", torch.__version__) + log.info("HIP Version: %s", torch.version.hip) + log.info("CUDA support: %s", torch.cuda.is_available()) + + if torch.cuda.is_available(): + log.info("CUDA device detected: %s", torch.cuda.get_device_name(0)) + + log.info("Using %s compute mode", DEVICE) + +def set_seed(seed: int): + random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + +# 1) Data: pairs (a, b) with label y = 1 if a > b else 0 +def sample_batch(batch_size: int, low=BATCH_LOWER, high=BATCH_UPPER): + a = (high - low) * torch.rand(batch_size, 1) + low + b = (high - low) * torch.rand(batch_size, 1) + low + + # train for if a > b + y = (a > b).float() + + # removed but left for my notes; it seems training for equality hurts classifing results that are ~eq + # when trained only on "if a > b => y", the model produces more accurate results when classifing if things are equal (~.5 prob). + # eq = (a == b).float() + # y = gt + 0.5 * eq + return a, b, y + +# 2) Number "embedding" network: R -> R^d +class NumberEmbedder(nn.Module): + def __init__(self, d=8): + super().__init__() + self.net = nn.Sequential( + nn.Linear(1, 16), + nn.ReLU(), + nn.Linear(16, d), + ) + + def forward(self, x): + return self.net(x) + +# 3) Comparator head: takes (ea, eb) -> logit for "a > b" +class PairwiseComparator(nn.Module): + def __init__(self, d=8): + super().__init__() + self.embed = NumberEmbedder(d) + self.head = nn.Sequential( + nn.Linear(2 * d + 1, 16), + nn.ReLU(), + nn.Linear(16, 1), + ) + + def forward(self, a, b): + ea = self.embed(a) + eb = self.embed(b) + delta_ab = a - b + x = torch.cat([ea, eb, delta_ab], dim=-1) + + return self.head(x) # logits + +def training_entry(): + # all prng seeds to 0 for deterministic outputs durring testing + # the seed should initialized normally otherwise + set_seed(0) + + model = PairwiseComparator(d=DIMENSIONS).to(DEVICE) + opt = torch.optim.AdamW(model.parameters(), lr=2e-3) + + # 4) Train + for step in range(TRAIN_STEPS): + a, b, y = sample_batch(TRAIN_BATCHSZ) + a, b, y = a.to(DEVICE), b.to(DEVICE), y.to(DEVICE) + + logits = model(a, b) + loss_fn = F.binary_cross_entropy_with_logits(logits, y) + + opt.zero_grad() + loss_fn.backward() + opt.step() + + if step <= TRAIN_PROGRESS and DO_VERBOSE_EARLY_TRAIN is True: + with torch.no_grad(): + pred = (torch.sigmoid(logits) > 0.5).float() + acc = (pred == y).float().mean().item() + log.info(f"step={step:5d} loss={loss_fn.item():.7f} acc={acc:.7f}") + elif step % TRAIN_PROGRESS == 0: + with torch.no_grad(): + pred = (torch.sigmoid(logits) > 0.5).float() + acc = (pred == y).float().mean().item() + log.info(f"step={step:5d} loss={loss_fn.item():.7f} acc={acc:.7f}") + + # 5) Quick test: evaluate accuracy on fresh pairs + with torch.no_grad(): + a, b, y = sample_batch(TRAIN_BATCHSZ) + a, b, y = a.to(DEVICE), b.to(DEVICE), y.to(DEVICE) + logits = model(a, b) + pred = (torch.sigmoid(logits) > 0.5).float() + errors = (pred != y).sum().item() + acc = (pred == y).float().mean().item() + log.info(f"Final test acc: {acc} errors: {errors}") + + # embed model depth into the model serialization + torch.save({"state_dict": model.state_dict(), "d": DIMENSIONS}, "model.pth") + log.info("Saved PyTorch Model State to model.pth") + +def infer_entry(): + model_ckpt = torch.load("model.pth", map_location=DEVICE) + model = PairwiseComparator(d=model_ckpt["d"]).to(DEVICE) + model.load_state_dict(model_ckpt["state_dict"]) + model.eval() + + # sample pairs + pairs = [(1, 2), (10, 3), (5, 5), (10, 35), (-64, 11), (300, 162), (2, 0), (2, 1), (3, 1), (4, 1), (3, 10),(30, 1), (0, 0), (-162, 237), + (10, 20), (100, 30), (50, 50), (100, 350), (-640, 110), (30, -420), (200, 0), (92, 5), (30, 17), (42, 10), (30, 100),(30, 1), (0, 400), (-42, -42)] + a = torch.tensor([[p[0]] for p in pairs], dtype=torch.float32, device=DEVICE) + b = torch.tensor([[p[1]] for p in pairs], dtype=torch.float32, device=DEVICE) + + # sanity check before inference + log.info(f"a.device: {a.device} model.device: {next(model.parameters()).device}") + + with torch.no_grad(): + probs = torch.sigmoid(model(a, b)) + + for (x, y), p in zip(pairs, probs): + log.info(f"P({x} > {y}) = {p.item():.3f}") + +if __name__ == '__main__': + import sys + import os + import datetime + + log = logging.getLogger(__name__) + logging.basicConfig(filename='pairwise_compare.log', level=logging.INFO) + log.info(f"Log opened {datetime.datetime.now()}") + + get_torch_info() + + name = os.path.basename(sys.argv[0]) + if name == 'train.py': + training_entry() + elif name == 'infer.py': + infer_entry() + else: + # alt call patern + # python3 pairwise_compare.py train + # python3 pairwise_compare.py infer + if len(sys.argv) > 1: + mode = sys.argv[1].strip().lower() + if mode == "train": + training_entry() + elif mode == "infer": + infer_entry() + else: + log.error(f"Unknown operation: {mode}") + log.error("Invalid call syntax, call script as \"train.py\" or \"infer.py\" or as pairwise_compare.py where mode is \"train\" or \"infer\"") + else: + log.error("Not enough arguments passed to script; call as train.py or infer.py or as pairwise_compare.py where mode is \"train\" or \"infer\"") + + log.info(f"Log closed {datetime.datetime.now()}") \ No newline at end of file diff --git a/train.py b/train.py new file mode 120000 index 0000000..235fa81 --- /dev/null +++ b/train.py @@ -0,0 +1 @@ +pairwise_compare.py \ No newline at end of file