first commit

2025-12-18 20:05:57 -05:00
commit 5c89149388
6 changed files with 814 additions and 0 deletions

3
.gitignore vendored Normal file

@@ -0,0 +1,3 @@
.venv/
*.pth
*.log

55
README.md Normal file

@@ -0,0 +1,55 @@
# Toy ML models and experiments
This repo contains small machine-learning experiments I am working through to learn how these models work.
## Setup
### Debian 13 (Trixie) Python3 setup
```bash
$ sudo apt install python3-full python3-venv
python3-full is already the newest version (3.13.5-1).
python3-venv is already the newest version (3.13.5-1).
```
### Get the code and init the venv
```bash
git clone http://the.git.repo/user/mltoys.git mltoys && cd mltoys
python3 -m venv .venv
source .venv/bin/activate
```
### Installing Python deps
For instructions on installing PyTorch, refer to the [PyTorch Home page].
```bash
pip3 install numpy
# use the NVIDIA CUDA or CPU-only packages if required.
# I develop on a ROCm machine, so this repo's instructions use the ROCm packages.
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/rocm6.4
pip3 install pandas matplotlib
```
## Running the code
### Training the model
```bash
(.venv) ➜ python3 pairwise_compare.py train
```
### Running inference
```bash
(.venv) ➜ python3 pairwise_compare.py infer
```
### Some general notes
- pairwise_compare.py logs to `pairwise_compare.log` using the Python `logging` module.
- Trained model state (the weights plus the embedding width) is saved to `model.pth`.
- output_graphs.py generates two graphs for inspecting the training process; copy the step log lines into the script's `text` variable to generate the graphs (a sketch for reading the log file directly follows at the end of this README).
[PyTorch Home page]: https://pytorch.org
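
As an alternative to pasting log lines into output_graphs.py, something like the following should also work (a minimal, untested sketch rather than part of the current scripts; it assumes the default `pairwise_compare.log` path and the `step= N loss=X acc=Y` line format written by the training loop):

```python
import re
import pandas as pd
import matplotlib.pyplot as plt

# Parse "step= N loss=X acc=Y" lines straight from the training log.
pattern = re.compile(r"step=\s*(\d+)\s+loss=([0-9.]+)\s+acc=([0-9.]+)")
with open("pairwise_compare.log") as f:
    rows = [(int(s), float(loss), float(acc)) for s, loss, acc in pattern.findall(f.read())]

df = pd.DataFrame(rows, columns=["step", "loss", "acc"]).sort_values("step")
plt.plot(df["step"], df["loss"].clip(lower=1e-10))
plt.yscale("log")
plt.xlabel("Step")
plt.ylabel("Loss (log scale)")
plt.title("Training Loss vs Step")
plt.show()
```

The same parsed DataFrame can feed the error-rate plot from output_graphs.py.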

1
infer.py Symbolic link

@@ -0,0 +1 @@
pairwise_compare.py

573
output_graphs.py Normal file

@@ -0,0 +1,573 @@
import re
import pandas as pd
import matplotlib.pyplot as plt
text = r"""INFO:__main__:step= 0 loss=1.3149878 acc=0.9093018
INFO:__main__:step= 1 loss=1.0067503 acc=0.9183350
INFO:__main__:step= 2 loss=0.7588759 acc=0.9278564
INFO:__main__:step= 3 loss=0.5432727 acc=0.9384766
INFO:__main__:step= 4 loss=0.3854972 acc=0.9427490
INFO:__main__:step= 5 loss=0.2379047 acc=0.9512939
INFO:__main__:step= 6 loss=0.1357967 acc=0.9575195
INFO:__main__:step= 7 loss=0.0659471 acc=0.9808960
INFO:__main__:step= 8 loss=0.0325629 acc=0.9878540
INFO:__main__:step= 9 loss=0.0159342 acc=0.9951172
INFO:__main__:step= 10 loss=0.0123130 acc=0.9995117
INFO:__main__:step= 11 loss=0.0129503 acc=0.9963989
INFO:__main__:step= 12 loss=0.0198095 acc=0.9927979
INFO:__main__:step= 13 loss=0.0261302 acc=0.9901733
INFO:__main__:step= 14 loss=0.0338637 acc=0.9891968
INFO:__main__:step= 15 loss=0.0425015 acc=0.9861450
INFO:__main__:step= 16 loss=0.0458022 acc=0.9865112
INFO:__main__:step= 17 loss=0.0492692 acc=0.9859619
INFO:__main__:step= 18 loss=0.0543331 acc=0.9850464
INFO:__main__:step= 19 loss=0.0575214 acc=0.9846802
INFO:__main__:step= 20 loss=0.0635471 acc=0.9841919
INFO:__main__:step= 21 loss=0.0503580 acc=0.9871216
INFO:__main__:step= 22 loss=0.0503206 acc=0.9863892
INFO:__main__:step= 23 loss=0.0549617 acc=0.9852905
INFO:__main__:step= 24 loss=0.0438731 acc=0.9872437
INFO:__main__:step= 25 loss=0.0471392 acc=0.9867554
INFO:__main__:step= 26 loss=0.0431895 acc=0.9868774
INFO:__main__:step= 27 loss=0.0354491 acc=0.9892578
INFO:__main__:step= 28 loss=0.0311145 acc=0.9905396
INFO:__main__:step= 29 loss=0.0305991 acc=0.9897461
INFO:__main__:step= 30 loss=0.0229458 acc=0.9920044
INFO:__main__:step= 31 loss=0.0203844 acc=0.9925537
INFO:__main__:step= 32 loss=0.0190371 acc=0.9929810
INFO:__main__:step= 33 loss=0.0167215 acc=0.9935303
INFO:__main__:step= 34 loss=0.0153083 acc=0.9937744
INFO:__main__:step= 35 loss=0.0120912 acc=0.9947510
INFO:__main__:step= 36 loss=0.0101961 acc=0.9961548
INFO:__main__:step= 37 loss=0.0100932 acc=0.9968262
INFO:__main__:step= 38 loss=0.0101871 acc=0.9968262
INFO:__main__:step= 39 loss=0.0088590 acc=0.9978027
INFO:__main__:step= 40 loss=0.0090834 acc=0.9982300
INFO:__main__:step= 41 loss=0.0084685 acc=0.9993896
INFO:__main__:step= 42 loss=0.0075352 acc=0.9989014
INFO:__main__:step= 43 loss=0.0080332 acc=0.9987793
INFO:__main__:step= 44 loss=0.0087203 acc=0.9979858
INFO:__main__:step= 45 loss=0.0086979 acc=0.9977417
INFO:__main__:step= 46 loss=0.0086313 acc=0.9972534
INFO:__main__:step= 47 loss=0.0092596 acc=0.9971924
INFO:__main__:step= 48 loss=0.0086056 acc=0.9978027
INFO:__main__:step= 49 loss=0.0101687 acc=0.9960327
INFO:__main__:step= 50 loss=0.0083168 acc=0.9978638
INFO:__main__:step= 51 loss=0.0095370 acc=0.9965820
INFO:__main__:step= 52 loss=0.0093318 acc=0.9968872
INFO:__main__:step= 53 loss=0.0093573 acc=0.9966431
INFO:__main__:step= 54 loss=0.0091653 acc=0.9971313
INFO:__main__:step= 55 loss=0.0097087 acc=0.9962769
INFO:__main__:step= 56 loss=0.0094240 acc=0.9972534
INFO:__main__:step= 57 loss=0.0088949 acc=0.9970093
INFO:__main__:step= 58 loss=0.0085519 acc=0.9975586
INFO:__main__:step= 59 loss=0.0098900 acc=0.9967651
INFO:__main__:step= 60 loss=0.0085984 acc=0.9977417
INFO:__main__:step= 61 loss=0.0082639 acc=0.9984131
INFO:__main__:step= 62 loss=0.0083590 acc=0.9979858
INFO:__main__:step= 63 loss=0.0072852 acc=0.9990234
INFO:__main__:step= 64 loss=0.0080845 acc=0.9983521
INFO:__main__:step= 65 loss=0.0074006 acc=0.9988403
INFO:__main__:step= 66 loss=0.0067656 acc=0.9990845
INFO:__main__:step= 67 loss=0.0072893 acc=0.9992065
INFO:__main__:step= 68 loss=0.0077519 acc=0.9991455
INFO:__main__:step= 69 loss=0.0072619 acc=0.9992065
INFO:__main__:step= 70 loss=0.0069199 acc=0.9993896
INFO:__main__:step= 71 loss=0.0070979 acc=0.9992676
INFO:__main__:step= 72 loss=0.0073295 acc=0.9992676
INFO:__main__:step= 73 loss=0.0074595 acc=0.9994507
INFO:__main__:step= 74 loss=0.0069510 acc=0.9993896
INFO:__main__:step= 75 loss=0.0072545 acc=0.9993896
INFO:__main__:step= 76 loss=0.0070704 acc=0.9994507
INFO:__main__:step= 77 loss=0.0080527 acc=0.9992065
INFO:__main__:step= 78 loss=0.0078222 acc=0.9990234
INFO:__main__:step= 79 loss=0.0074403 acc=0.9993896
INFO:__main__:step= 80 loss=0.0070136 acc=0.9992065
INFO:__main__:step= 81 loss=0.0068662 acc=0.9992676
INFO:__main__:step= 82 loss=0.0071025 acc=0.9992676
INFO:__main__:step= 83 loss=0.0068718 acc=0.9993896
INFO:__main__:step= 84 loss=0.0072839 acc=0.9991455
INFO:__main__:step= 85 loss=0.0071138 acc=0.9992676
INFO:__main__:step= 86 loss=0.0070522 acc=0.9992065
INFO:__main__:step= 87 loss=0.0073670 acc=0.9992065
INFO:__main__:step= 88 loss=0.0069347 acc=0.9995117
INFO:__main__:step= 89 loss=0.0066190 acc=0.9993286
INFO:__main__:step= 90 loss=0.0069868 acc=0.9997559
INFO:__main__:step= 91 loss=0.0071491 acc=0.9995728
INFO:__main__:step= 92 loss=0.0071107 acc=0.9991455
INFO:__main__:step= 93 loss=0.0065876 acc=0.9996338
INFO:__main__:step= 94 loss=0.0067994 acc=0.9995117
INFO:__main__:step= 95 loss=0.0072922 acc=0.9995728
INFO:__main__:step= 96 loss=0.0070875 acc=0.9995117
INFO:__main__:step= 97 loss=0.0067269 acc=0.9996338
INFO:__main__:step= 98 loss=0.0069193 acc=0.9995728
INFO:__main__:step= 99 loss=0.0068322 acc=0.9995728
INFO:__main__:step= 100 loss=0.0065636 acc=0.9995117
INFO:__main__:step= 101 loss=0.0072776 acc=0.9995117
INFO:__main__:step= 102 loss=0.0074568 acc=0.9996338
INFO:__main__:step= 103 loss=0.0067574 acc=0.9996338
INFO:__main__:step= 104 loss=0.0067869 acc=0.9995728
INFO:__main__:step= 105 loss=0.0064811 acc=0.9996948
INFO:__main__:step= 106 loss=0.0077020 acc=0.9990234
INFO:__main__:step= 107 loss=0.0069068 acc=0.9996338
INFO:__main__:step= 108 loss=0.0068666 acc=0.9994507
INFO:__main__:step= 109 loss=0.0066869 acc=0.9992065
INFO:__main__:step= 110 loss=0.0070414 acc=0.9993286
INFO:__main__:step= 111 loss=0.0072129 acc=0.9994507
INFO:__main__:step= 112 loss=0.0067772 acc=0.9993896
INFO:__main__:step= 113 loss=0.0069291 acc=0.9995728
INFO:__main__:step= 114 loss=0.0065003 acc=0.9996338
INFO:__main__:step= 115 loss=0.0071769 acc=0.9993286
INFO:__main__:step= 116 loss=0.0068795 acc=0.9993896
INFO:__main__:step= 117 loss=0.0068955 acc=0.9995117
INFO:__main__:step= 118 loss=0.0071484 acc=0.9993896
INFO:__main__:step= 119 loss=0.0072527 acc=0.9992676
INFO:__main__:step= 120 loss=0.0070072 acc=0.9996948
INFO:__main__:step= 121 loss=0.0069677 acc=0.9995117
INFO:__main__:step= 122 loss=0.0070129 acc=0.9996338
INFO:__main__:step= 123 loss=0.0068947 acc=0.9990845
INFO:__main__:step= 124 loss=0.0066504 acc=0.9996338
INFO:__main__:step= 125 loss=0.0060845 acc=0.9994507
INFO:__main__:step= 126 loss=0.0063553 acc=0.9995117
INFO:__main__:step= 127 loss=0.0063697 acc=0.9996948
INFO:__main__:step= 128 loss=0.0064274 acc=0.9995728
INFO:__main__:step= 129 loss=0.0057200 acc=0.9998169
INFO:__main__:step= 130 loss=0.0064411 acc=0.9996948
INFO:__main__:step= 131 loss=0.0060729 acc=0.9998779
INFO:__main__:step= 132 loss=0.0062524 acc=0.9996948
INFO:__main__:step= 133 loss=0.0060491 acc=0.9991455
INFO:__main__:step= 134 loss=0.0065517 acc=0.9996338
INFO:__main__:step= 135 loss=0.0058316 acc=0.9998169
INFO:__main__:step= 136 loss=0.0059018 acc=0.9997559
INFO:__main__:step= 137 loss=0.0065307 acc=0.9996948
INFO:__main__:step= 138 loss=0.0064329 acc=0.9995117
INFO:__main__:step= 139 loss=0.0066867 acc=0.9995728
INFO:__main__:step= 140 loss=0.0070452 acc=0.9992676
INFO:__main__:step= 141 loss=0.0064459 acc=0.9993896
INFO:__main__:step= 142 loss=0.0061031 acc=0.9997559
INFO:__main__:step= 143 loss=0.0067490 acc=0.9992676
INFO:__main__:step= 144 loss=0.0065364 acc=0.9998779
INFO:__main__:step= 145 loss=0.0058783 acc=0.9993286
INFO:__main__:step= 146 loss=0.0060647 acc=0.9998169
INFO:__main__:step= 147 loss=0.0064641 acc=0.9996338
INFO:__main__:step= 148 loss=0.0065979 acc=0.9996338
INFO:__main__:step= 149 loss=0.0058760 acc=0.9997559
INFO:__main__:step= 150 loss=0.0058440 acc=0.9998779
INFO:__main__:step= 151 loss=0.0060483 acc=0.9998169
INFO:__main__:step= 152 loss=0.0062729 acc=0.9998169
INFO:__main__:step= 153 loss=0.0061169 acc=0.9996338
INFO:__main__:step= 154 loss=0.0064641 acc=0.9995728
INFO:__main__:step= 155 loss=0.0062200 acc=0.9995117
INFO:__main__:step= 156 loss=0.0062547 acc=0.9997559
INFO:__main__:step= 157 loss=0.0064813 acc=0.9995728
INFO:__main__:step= 158 loss=0.0061465 acc=0.9996338
INFO:__main__:step= 159 loss=0.0056798 acc=0.9995728
INFO:__main__:step= 160 loss=0.0064472 acc=0.9996338
INFO:__main__:step= 161 loss=0.0055585 acc=0.9997559
INFO:__main__:step= 162 loss=0.0057363 acc=0.9995117
INFO:__main__:step= 163 loss=0.0061281 acc=0.9995728
INFO:__main__:step= 164 loss=0.0062811 acc=0.9995117
INFO:__main__:step= 165 loss=0.0058447 acc=0.9997559
INFO:__main__:step= 166 loss=0.0054476 acc=0.9998169
INFO:__main__:step= 167 loss=0.0067439 acc=0.9998779
INFO:__main__:step= 168 loss=0.0063363 acc=0.9992676
INFO:__main__:step= 169 loss=0.0060227 acc=0.9995117
INFO:__main__:step= 170 loss=0.0060772 acc=0.9998169
INFO:__main__:step= 171 loss=0.0060449 acc=0.9994507
INFO:__main__:step= 172 loss=0.0061583 acc=0.9997559
INFO:__main__:step= 173 loss=0.0061034 acc=0.9995728
INFO:__main__:step= 174 loss=0.0060934 acc=0.9996948
INFO:__main__:step= 175 loss=0.0062173 acc=0.9995117
INFO:__main__:step= 176 loss=0.0065296 acc=0.9995728
INFO:__main__:step= 177 loss=0.0067725 acc=0.9994507
INFO:__main__:step= 178 loss=0.0064399 acc=0.9997559
INFO:__main__:step= 179 loss=0.0060964 acc=0.9997559
INFO:__main__:step= 180 loss=0.0053458 acc=0.9995728
INFO:__main__:step= 181 loss=0.0056739 acc=0.9996338
INFO:__main__:step= 182 loss=0.0054851 acc=0.9996338
INFO:__main__:step= 183 loss=0.0064169 acc=0.9997559
INFO:__main__:step= 184 loss=0.0058324 acc=0.9996948
INFO:__main__:step= 185 loss=0.0058293 acc=0.9994507
INFO:__main__:step= 186 loss=0.0067926 acc=0.9996338
INFO:__main__:step= 187 loss=0.0053189 acc=0.9994507
INFO:__main__:step= 188 loss=0.0059640 acc=0.9998169
INFO:__main__:step= 189 loss=0.0055875 acc=0.9996338
INFO:__main__:step= 190 loss=0.0057980 acc=0.9995117
INFO:__main__:step= 191 loss=0.0057864 acc=0.9996338
INFO:__main__:step= 192 loss=0.0052436 acc=0.9999390
INFO:__main__:step= 193 loss=0.0063448 acc=0.9994507
INFO:__main__:step= 194 loss=0.0056429 acc=0.9996948
INFO:__main__:step= 195 loss=0.0057287 acc=0.9993286
INFO:__main__:step= 196 loss=0.0058553 acc=0.9998169
INFO:__main__:step= 197 loss=0.0058331 acc=0.9992676
INFO:__main__:step= 198 loss=0.0052676 acc=0.9995117
INFO:__main__:step= 199 loss=0.0058648 acc=0.9996948
INFO:__main__:step= 200 loss=0.0061336 acc=0.9998169
INFO:__main__:step= 201 loss=0.0053004 acc=0.9998169
INFO:__main__:step= 202 loss=0.0061638 acc=0.9997559
INFO:__main__:step= 203 loss=0.0065025 acc=0.9995117
INFO:__main__:step= 204 loss=0.0059157 acc=0.9998169
INFO:__main__:step= 205 loss=0.0054036 acc=0.9998779
INFO:__main__:step= 206 loss=0.0056138 acc=0.9996338
INFO:__main__:step= 207 loss=0.0054218 acc=0.9996338
INFO:__main__:step= 208 loss=0.0066803 acc=0.9992676
INFO:__main__:step= 209 loss=0.0058715 acc=0.9996948
INFO:__main__:step= 210 loss=0.0055717 acc=0.9998169
INFO:__main__:step= 211 loss=0.0060626 acc=0.9994507
INFO:__main__:step= 212 loss=0.0051552 acc=0.9998779
INFO:__main__:step= 213 loss=0.0059606 acc=0.9995117
INFO:__main__:step= 214 loss=0.0049876 acc=0.9998779
INFO:__main__:step= 215 loss=0.0057766 acc=0.9996338
INFO:__main__:step= 216 loss=0.0060298 acc=0.9996338
INFO:__main__:step= 217 loss=0.0055154 acc=0.9996338
INFO:__main__:step= 218 loss=0.0055578 acc=0.9998169
INFO:__main__:step= 219 loss=0.0052298 acc=0.9995728
INFO:__main__:step= 220 loss=0.0064167 acc=0.9995117
INFO:__main__:step= 221 loss=0.0056415 acc=0.9996338
INFO:__main__:step= 222 loss=0.0052709 acc=0.9995728
INFO:__main__:step= 223 loss=0.0060531 acc=0.9995728
INFO:__main__:step= 224 loss=0.0049705 acc=0.9995117
INFO:__main__:step= 225 loss=0.0053588 acc=0.9997559
INFO:__main__:step= 226 loss=0.0052681 acc=0.9998779
INFO:__main__:step= 227 loss=0.0053794 acc=0.9996948
INFO:__main__:step= 228 loss=0.0056610 acc=0.9997559
INFO:__main__:step= 229 loss=0.0057926 acc=0.9995728
INFO:__main__:step= 230 loss=0.0053835 acc=0.9998169
INFO:__main__:step= 231 loss=0.0045384 acc=0.9996948
INFO:__main__:step= 232 loss=0.0048465 acc=0.9997559
INFO:__main__:step= 233 loss=0.0057317 acc=0.9995117
INFO:__main__:step= 234 loss=0.0058601 acc=0.9995117
INFO:__main__:step= 235 loss=0.0050546 acc=0.9997559
INFO:__main__:step= 236 loss=0.0049725 acc=1.0000000
INFO:__main__:step= 237 loss=0.0048411 acc=0.9995728
INFO:__main__:step= 238 loss=0.0054285 acc=0.9998169
INFO:__main__:step= 239 loss=0.0051857 acc=0.9997559
INFO:__main__:step= 240 loss=0.0054422 acc=0.9996948
INFO:__main__:step= 241 loss=0.0054665 acc=0.9999390
INFO:__main__:step= 242 loss=0.0049092 acc=0.9996948
INFO:__main__:step= 243 loss=0.0050945 acc=0.9995728
INFO:__main__:step= 244 loss=0.0057586 acc=0.9992676
INFO:__main__:step= 245 loss=0.0055652 acc=0.9997559
INFO:__main__:step= 246 loss=0.0052052 acc=0.9996948
INFO:__main__:step= 247 loss=0.0050273 acc=0.9996948
INFO:__main__:step= 248 loss=0.0055413 acc=0.9994507
INFO:__main__:step= 249 loss=0.0050175 acc=0.9998169
INFO:__main__:step= 250 loss=0.0049932 acc=0.9998169
INFO:__main__:step= 251 loss=0.0050813 acc=0.9997559
INFO:__main__:step= 252 loss=0.0051735 acc=0.9998169
INFO:__main__:step= 253 loss=0.0049972 acc=0.9998169
INFO:__main__:step= 254 loss=0.0051353 acc=0.9996338
INFO:__main__:step= 255 loss=0.0052580 acc=0.9996338
INFO:__main__:step= 256 loss=0.0047390 acc=0.9996948
INFO:__main__:step= 257 loss=0.0051188 acc=0.9997559
INFO:__main__:step= 258 loss=0.0046280 acc=0.9997559
INFO:__main__:step= 259 loss=0.0047725 acc=0.9997559
INFO:__main__:step= 260 loss=0.0045635 acc=0.9996338
INFO:__main__:step= 261 loss=0.0053751 acc=0.9995728
INFO:__main__:step= 262 loss=0.0046179 acc=0.9998779
INFO:__main__:step= 263 loss=0.0052214 acc=0.9994507
INFO:__main__:step= 264 loss=0.0044471 acc=0.9996948
INFO:__main__:step= 265 loss=0.0053352 acc=0.9997559
INFO:__main__:step= 266 loss=0.0053903 acc=0.9998779
INFO:__main__:step= 267 loss=0.0050240 acc=0.9996948
INFO:__main__:step= 268 loss=0.0048375 acc=0.9996948
INFO:__main__:step= 269 loss=0.0055583 acc=0.9995117
INFO:__main__:step= 270 loss=0.0047512 acc=1.0000000
INFO:__main__:step= 271 loss=0.0045881 acc=0.9998779
INFO:__main__:step= 272 loss=0.0045754 acc=0.9996338
INFO:__main__:step= 273 loss=0.0051980 acc=0.9996338
INFO:__main__:step= 274 loss=0.0048658 acc=0.9997559
INFO:__main__:step= 275 loss=0.0049445 acc=0.9996338
INFO:__main__:step= 276 loss=0.0051786 acc=0.9998779
INFO:__main__:step= 277 loss=0.0045741 acc=0.9997559
INFO:__main__:step= 278 loss=0.0054667 acc=0.9995728
INFO:__main__:step= 279 loss=0.0049829 acc=0.9998169
INFO:__main__:step= 280 loss=0.0049411 acc=0.9995117
INFO:__main__:step= 281 loss=0.0051374 acc=0.9998779
INFO:__main__:step= 282 loss=0.0046585 acc=0.9996338
INFO:__main__:step= 283 loss=0.0049610 acc=0.9997559
INFO:__main__:step= 284 loss=0.0044862 acc=0.9998779
INFO:__main__:step= 285 loss=0.0050020 acc=0.9996948
INFO:__main__:step= 286 loss=0.0043370 acc=0.9999390
INFO:__main__:step= 287 loss=0.0051258 acc=0.9996338
INFO:__main__:step= 288 loss=0.0043914 acc=0.9998779
INFO:__main__:step= 289 loss=0.0043173 acc=0.9996948
INFO:__main__:step= 290 loss=0.0047723 acc=0.9998779
INFO:__main__:step= 291 loss=0.0049867 acc=0.9996948
INFO:__main__:step= 292 loss=0.0046409 acc=0.9998169
INFO:__main__:step= 293 loss=0.0039387 acc=1.0000000
INFO:__main__:step= 294 loss=0.0051583 acc=0.9994507
INFO:__main__:step= 295 loss=0.0043871 acc=0.9998169
INFO:__main__:step= 296 loss=0.0046067 acc=0.9997559
INFO:__main__:step= 297 loss=0.0043712 acc=0.9996948
INFO:__main__:step= 298 loss=0.0047873 acc=0.9995117
INFO:__main__:step= 299 loss=0.0047681 acc=0.9998779
INFO:__main__:step= 300 loss=0.0050251 acc=0.9997559
INFO:__main__:step= 301 loss=0.0045633 acc=0.9998169
INFO:__main__:step= 302 loss=0.0043650 acc=0.9999390
INFO:__main__:step= 303 loss=0.0049224 acc=0.9997559
INFO:__main__:step= 304 loss=0.0049409 acc=0.9998169
INFO:__main__:step= 305 loss=0.0052364 acc=0.9996948
INFO:__main__:step= 306 loss=0.0051323 acc=0.9995728
INFO:__main__:step= 307 loss=0.0047827 acc=0.9997559
INFO:__main__:step= 308 loss=0.0044320 acc=0.9997559
INFO:__main__:step= 309 loss=0.0050638 acc=0.9997559
INFO:__main__:step= 310 loss=0.0052357 acc=0.9997559
INFO:__main__:step= 311 loss=0.0044889 acc=0.9995728
INFO:__main__:step= 312 loss=0.0043897 acc=0.9997559
INFO:__main__:step= 313 loss=0.0050018 acc=0.9998779
INFO:__main__:step= 314 loss=0.0043691 acc=0.9998169
INFO:__main__:step= 315 loss=0.0042810 acc=0.9999390
INFO:__main__:step= 316 loss=0.0045961 acc=0.9996338
INFO:__main__:step= 317 loss=0.0047868 acc=0.9998169
INFO:__main__:step= 318 loss=0.0047044 acc=0.9998169
INFO:__main__:step= 319 loss=0.0046348 acc=0.9998169
INFO:__main__:step= 320 loss=0.0043805 acc=0.9998779
INFO:__main__:step= 321 loss=0.0041868 acc=0.9997559
INFO:__main__:step= 322 loss=0.0045165 acc=0.9997559
INFO:__main__:step= 323 loss=0.0050955 acc=0.9997559
INFO:__main__:step= 324 loss=0.0046222 acc=0.9996948
INFO:__main__:step= 325 loss=0.0045308 acc=0.9996338
INFO:__main__:step= 326 loss=0.0052673 acc=0.9993896
INFO:__main__:step= 327 loss=0.0049429 acc=0.9995728
INFO:__main__:step= 328 loss=0.0042904 acc=0.9997559
INFO:__main__:step= 329 loss=0.0047324 acc=0.9996948
INFO:__main__:step= 330 loss=0.0049455 acc=0.9996338
INFO:__main__:step= 331 loss=0.0042475 acc=0.9998779
INFO:__main__:step= 332 loss=0.0045173 acc=0.9996948
INFO:__main__:step= 333 loss=0.0047958 acc=0.9996338
INFO:__main__:step= 334 loss=0.0045062 acc=0.9996948
INFO:__main__:step= 335 loss=0.0043826 acc=0.9996338
INFO:__main__:step= 336 loss=0.0047664 acc=0.9998169
INFO:__main__:step= 337 loss=0.0048210 acc=0.9997559
INFO:__main__:step= 338 loss=0.0048734 acc=0.9996338
INFO:__main__:step= 339 loss=0.0038904 acc=0.9998169
INFO:__main__:step= 340 loss=0.0046818 acc=0.9993896
INFO:__main__:step= 341 loss=0.0050763 acc=0.9995728
INFO:__main__:step= 342 loss=0.0042104 acc=0.9998779
INFO:__main__:step= 343 loss=0.0045529 acc=0.9996338
INFO:__main__:step= 344 loss=0.0044936 acc=0.9997559
INFO:__main__:step= 345 loss=0.0045376 acc=0.9998779
INFO:__main__:step= 346 loss=0.0046967 acc=0.9998169
INFO:__main__:step= 347 loss=0.0042559 acc=0.9995728
INFO:__main__:step= 348 loss=0.0049302 acc=0.9993286
INFO:__main__:step= 349 loss=0.0045091 acc=0.9996948
INFO:__main__:step= 350 loss=0.0043695 acc=0.9996948
INFO:__main__:step= 351 loss=0.0048456 acc=0.9995117
INFO:__main__:step= 352 loss=0.0043468 acc=0.9996948
INFO:__main__:step= 353 loss=0.0044966 acc=0.9998779
INFO:__main__:step= 354 loss=0.0042396 acc=0.9998169
INFO:__main__:step= 355 loss=0.0048033 acc=0.9994507
INFO:__main__:step= 356 loss=0.0043947 acc=0.9997559
INFO:__main__:step= 357 loss=0.0042306 acc=0.9996338
INFO:__main__:step= 358 loss=0.0040438 acc=0.9998169
INFO:__main__:step= 359 loss=0.0046469 acc=0.9997559
INFO:__main__:step= 360 loss=0.0044782 acc=0.9998169
INFO:__main__:step= 361 loss=0.0045767 acc=0.9995117
INFO:__main__:step= 362 loss=0.0044509 acc=0.9998169
INFO:__main__:step= 363 loss=0.0049030 acc=0.9996948
INFO:__main__:step= 364 loss=0.0042743 acc=0.9998779
INFO:__main__:step= 365 loss=0.0043227 acc=0.9998779
INFO:__main__:step= 366 loss=0.0045902 acc=0.9996338
INFO:__main__:step= 367 loss=0.0040066 acc=1.0000000
INFO:__main__:step= 368 loss=0.0041235 acc=0.9998779
INFO:__main__:step= 369 loss=0.0042298 acc=0.9998779
INFO:__main__:step= 370 loss=0.0041855 acc=0.9996948
INFO:__main__:step= 371 loss=0.0042327 acc=0.9999390
INFO:__main__:step= 372 loss=0.0045756 acc=0.9998169
INFO:__main__:step= 373 loss=0.0041091 acc=0.9999390
INFO:__main__:step= 374 loss=0.0046337 acc=0.9995728
INFO:__main__:step= 375 loss=0.0048172 acc=0.9996338
INFO:__main__:step= 376 loss=0.0049001 acc=0.9997559
INFO:__main__:step= 377 loss=0.0047032 acc=0.9998779
INFO:__main__:step= 378 loss=0.0039463 acc=0.9997559
INFO:__main__:step= 379 loss=0.0043672 acc=0.9997559
INFO:__main__:step= 380 loss=0.0043338 acc=0.9995728
INFO:__main__:step= 381 loss=0.0044772 acc=0.9996948
INFO:__main__:step= 382 loss=0.0040340 acc=0.9997559
INFO:__main__:step= 383 loss=0.0042726 acc=0.9997559
INFO:__main__:step= 384 loss=0.0042994 acc=0.9998779
INFO:__main__:step= 385 loss=0.0040821 acc=0.9996948
INFO:__main__:step= 386 loss=0.0036824 acc=0.9998779
INFO:__main__:step= 387 loss=0.0049368 acc=0.9998169
INFO:__main__:step= 388 loss=0.0040065 acc=0.9997559
INFO:__main__:step= 389 loss=0.0043364 acc=0.9998779
INFO:__main__:step= 390 loss=0.0042396 acc=0.9997559
INFO:__main__:step= 391 loss=0.0041326 acc=0.9999390
INFO:__main__:step= 392 loss=0.0045820 acc=0.9998169
INFO:__main__:step= 393 loss=0.0045091 acc=0.9998779
INFO:__main__:step= 394 loss=0.0041117 acc=0.9998169
INFO:__main__:step= 395 loss=0.0037867 acc=0.9996948
INFO:__main__:step= 396 loss=0.0041869 acc=0.9998169
INFO:__main__:step= 397 loss=0.0041473 acc=0.9996948
INFO:__main__:step= 398 loss=0.0045813 acc=0.9995728
INFO:__main__:step= 399 loss=0.0037894 acc=0.9998779
INFO:__main__:step= 400 loss=0.0042846 acc=0.9998169
INFO:__main__:step= 401 loss=0.0039527 acc=0.9996338
INFO:__main__:step= 402 loss=0.0042966 acc=0.9998169
INFO:__main__:step= 403 loss=0.0034902 acc=0.9998169
INFO:__main__:step= 404 loss=0.0039595 acc=0.9998169
INFO:__main__:step= 405 loss=0.0036952 acc=0.9998779
INFO:__main__:step= 406 loss=0.0039154 acc=0.9998779
INFO:__main__:step= 407 loss=0.0041118 acc=0.9998169
INFO:__main__:step= 408 loss=0.0037084 acc=0.9997559
INFO:__main__:step= 409 loss=0.0038251 acc=0.9998779
INFO:__main__:step= 410 loss=0.0042379 acc=0.9996948
INFO:__main__:step= 411 loss=0.0044292 acc=0.9996338
INFO:__main__:step= 412 loss=0.0048427 acc=1.0000000
INFO:__main__:step= 413 loss=0.0044034 acc=0.9996948
INFO:__main__:step= 414 loss=0.0039434 acc=0.9998779
INFO:__main__:step= 415 loss=0.0036064 acc=0.9999390
INFO:__main__:step= 416 loss=0.0043831 acc=0.9998169
INFO:__main__:step= 417 loss=0.0036476 acc=0.9999390
INFO:__main__:step= 418 loss=0.0039738 acc=0.9998779
INFO:__main__:step= 419 loss=0.0043787 acc=0.9996948
INFO:__main__:step= 420 loss=0.0042686 acc=0.9996948
INFO:__main__:step= 421 loss=0.0045272 acc=0.9998169
INFO:__main__:step= 422 loss=0.0043030 acc=0.9998779
INFO:__main__:step= 423 loss=0.0043391 acc=0.9996338
INFO:__main__:step= 424 loss=0.0037387 acc=0.9997559
INFO:__main__:step= 425 loss=0.0039780 acc=0.9998779
INFO:__main__:step= 426 loss=0.0041521 acc=0.9998779
INFO:__main__:step= 427 loss=0.0044029 acc=0.9996948
INFO:__main__:step= 428 loss=0.0037069 acc=0.9996948
INFO:__main__:step= 429 loss=0.0039751 acc=0.9997559
INFO:__main__:step= 430 loss=0.0036961 acc=0.9999390
INFO:__main__:step= 431 loss=0.0041230 acc=0.9996948
INFO:__main__:step= 432 loss=0.0037204 acc=0.9998169
INFO:__main__:step= 433 loss=0.0041116 acc=0.9998169
INFO:__main__:step= 434 loss=0.0043973 acc=0.9998779
INFO:__main__:step= 435 loss=0.0042320 acc=0.9995728
INFO:__main__:step= 436 loss=0.0048164 acc=0.9995728
INFO:__main__:step= 437 loss=0.0038487 acc=0.9996948
INFO:__main__:step= 438 loss=0.0044019 acc=0.9996948
INFO:__main__:step= 439 loss=0.0036137 acc=0.9997559
INFO:__main__:step= 440 loss=0.0040031 acc=0.9997559
INFO:__main__:step= 441 loss=0.0037439 acc=1.0000000
INFO:__main__:step= 442 loss=0.0047578 acc=0.9997559
INFO:__main__:step= 443 loss=0.0037242 acc=1.0000000
INFO:__main__:step= 444 loss=0.0035887 acc=0.9998169
INFO:__main__:step= 445 loss=0.0040946 acc=0.9998779
INFO:__main__:step= 446 loss=0.0038853 acc=0.9998779
INFO:__main__:step= 447 loss=0.0045090 acc=0.9997559
INFO:__main__:step= 448 loss=0.0035975 acc=0.9998779
INFO:__main__:step= 449 loss=0.0039571 acc=0.9998169
INFO:__main__:step= 450 loss=0.0033344 acc=0.9998779
INFO:__main__:step= 451 loss=0.0038010 acc=0.9996338
INFO:__main__:step= 452 loss=0.0037623 acc=0.9999390
INFO:__main__:step= 453 loss=0.0039248 acc=0.9998779
INFO:__main__:step= 454 loss=0.0038458 acc=0.9996948
INFO:__main__:step= 455 loss=0.0038334 acc=0.9997559
INFO:__main__:step= 456 loss=0.0039828 acc=0.9997559
INFO:__main__:step= 457 loss=0.0037553 acc=0.9997559
INFO:__main__:step= 458 loss=0.0034900 acc=0.9998169
INFO:__main__:step= 459 loss=0.0039507 acc=0.9997559
INFO:__main__:step= 460 loss=0.0040168 acc=0.9997559
INFO:__main__:step= 461 loss=0.0039682 acc=0.9996338
INFO:__main__:step= 462 loss=0.0033216 acc=1.0000000
INFO:__main__:step= 463 loss=0.0035992 acc=0.9998779
INFO:__main__:step= 464 loss=0.0035374 acc=0.9998169
INFO:__main__:step= 465 loss=0.0039145 acc=0.9998169
INFO:__main__:step= 466 loss=0.0036567 acc=0.9996338
INFO:__main__:step= 467 loss=0.0035094 acc=0.9998169
INFO:__main__:step= 468 loss=0.0038882 acc=0.9996338
INFO:__main__:step= 469 loss=0.0037108 acc=0.9998169
INFO:__main__:step= 470 loss=0.0034592 acc=0.9997559
INFO:__main__:step= 471 loss=0.0031493 acc=0.9999390
INFO:__main__:step= 472 loss=0.0038837 acc=0.9996948
INFO:__main__:step= 473 loss=0.0035162 acc=0.9997559
INFO:__main__:step= 474 loss=0.0037461 acc=0.9998169
INFO:__main__:step= 475 loss=0.0037000 acc=1.0000000
INFO:__main__:step= 476 loss=0.0042010 acc=0.9998169
INFO:__main__:step= 477 loss=0.0034492 acc=0.9998779
INFO:__main__:step= 478 loss=0.0035461 acc=1.0000000
INFO:__main__:step= 479 loss=0.0035708 acc=0.9998169
INFO:__main__:step= 480 loss=0.0040509 acc=0.9998169
INFO:__main__:step= 481 loss=0.0033549 acc=0.9996948
INFO:__main__:step= 482 loss=0.0035421 acc=1.0000000
INFO:__main__:step= 483 loss=0.0032111 acc=0.9998779
INFO:__main__:step= 484 loss=0.0039644 acc=0.9997559
INFO:__main__:step= 485 loss=0.0042124 acc=0.9996338
INFO:__main__:step= 486 loss=0.0040706 acc=0.9994507
INFO:__main__:step= 487 loss=0.0035862 acc=0.9998169
INFO:__main__:step= 488 loss=0.0044468 acc=0.9996338
INFO:__main__:step= 489 loss=0.0036269 acc=0.9999390
INFO:__main__:step= 490 loss=0.0031157 acc=0.9998169
INFO:__main__:step= 491 loss=0.0037752 acc=0.9996948
INFO:__main__:step= 492 loss=0.0033704 acc=0.9997559
INFO:__main__:step= 493 loss=0.0040347 acc=0.9996948
INFO:__main__:step= 494 loss=0.0034053 acc=0.9998169
INFO:__main__:step= 495 loss=0.0036225 acc=0.9998779
INFO:__main__:step= 496 loss=0.0042589 acc=0.9997559
INFO:__main__:step= 497 loss=0.0040456 acc=0.9996338
INFO:__main__:step= 498 loss=0.0035443 acc=0.9997559
INFO:__main__:step= 499 loss=0.0035868 acc=1.0000000
INFO:__main__:step= 500 loss=0.0035023 acc=0.9998779
INFO:__main__:step= 1000 loss=0.0022645 acc=0.9996948
INFO:__main__:step= 1500 loss=0.0006305 acc=0.9999390
INFO:__main__:step= 2000 loss=0.0006203 acc=0.9997559
INFO:__main__:step= 2500 loss=0.0004746 acc=0.9996948
INFO:__main__:step= 3000 loss=0.0002501 acc=0.9999390
INFO:__main__:step= 3500 loss=0.0001484 acc=1.0000000
INFO:__main__:step= 4000 loss=0.0001385 acc=1.0000000
INFO:__main__:step= 4500 loss=0.0001816 acc=0.9999390
INFO:__main__:step= 5000 loss=0.0001943 acc=0.9999390
INFO:__main__:step= 5500 loss=0.0001595 acc=1.0000000
INFO:__main__:step= 6000 loss=0.0001608 acc=1.0000000
INFO:__main__:step= 6500 loss=0.0000808 acc=1.0000000
INFO:__main__:step= 7000 loss=0.0001676 acc=0.9998779
INFO:__main__:step= 7500 loss=0.0001089 acc=1.0000000
INFO:__main__:step= 8000 loss=0.0000847 acc=1.0000000
INFO:__main__:step= 8500 loss=0.0000707 acc=1.0000000
INFO:__main__:step= 9000 loss=0.0001036 acc=1.0000000
INFO:__main__:step= 9500 loss=0.0001216 acc=1.0000000
INFO:__main__:step=10000 loss=0.0000516 acc=1.0000000
INFO:__main__:step=10500 loss=0.0000347 acc=1.0000000
INFO:__main__:step=11000 loss=0.0000673 acc=1.0000000
INFO:__main__:step=11500 loss=0.0000634 acc=1.0000000
INFO:__main__:step=12000 loss=0.0000239 acc=1.0000000
INFO:__main__:step=12500 loss=0.0000039 acc=1.0000000
INFO:__main__:step=13000 loss=0.0000418 acc=1.0000000
INFO:__main__:step=13500 loss=0.0000079 acc=1.0000000
INFO:__main__:step=14000 loss=0.0000367 acc=1.0000000
INFO:__main__:step=14500 loss=0.0000217 acc=1.0000000
INFO:__main__:step=15000 loss=0.0000267 acc=1.0000000
INFO:__main__:step=15500 loss=0.0000759 acc=1.0000000
INFO:__main__:step=16000 loss=0.0000215 acc=1.0000000
INFO:__main__:step=16500 loss=0.0000015 acc=1.0000000
INFO:__main__:step=17000 loss=0.0000008 acc=1.0000000
INFO:__main__:step=17500 loss=0.0000147 acc=1.0000000
INFO:__main__:step=18000 loss=0.0000000 acc=1.0000000
INFO:__main__:step=18500 loss=0.0000001 acc=1.0000000
INFO:__main__:step=19000 loss=0.0000318 acc=1.0000000
INFO:__main__:step=19500 loss=0.0000243 acc=1.0000000
"""
pattern = re.compile(r"step=\s*(\d+)\s+loss=([0-9.]+)\s+acc=([0-9.]+)")
rows = [(int(s), float(l), float(a)) for s, l, a in pattern.findall(text)]
df = pd.DataFrame(rows, columns=["step", "loss", "acc"]).sort_values("step").reset_index(drop=True)
# Avoid log(0) issues for loss plot by clamping at a tiny positive value
eps = 1e-10
df["loss_clamped"] = df["loss"].clip(lower=eps)
# Plot 1: Loss
plt.figure(figsize=(9, 4.8))
plt.plot(df["step"], df["loss_clamped"])
plt.yscale("log")
plt.xlabel("Step")
plt.ylabel("Loss (log scale)")
plt.title("Training Loss vs Step")
plt.tight_layout()
plt.show()
# Plot 2: Error rate (1 - accuracy)
df["err"] = (1.0 - df["acc"]).clip(lower=eps)
plt.figure(figsize=(9, 4.8))
plt.plot(df["step"], df["err"])
plt.yscale("log")
plt.xlabel("Step")
plt.ylabel("Error rate (1 - accuracy) (log scale)")
plt.title("Training Error Rate vs Step")
plt.tight_layout()
plt.show()
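# Note (optional sketch, not part of the original workflow): plt.show() opens
# interactive windows, so on a headless machine you could replace each
# plt.show() call above with a plt.savefig(...) call, e.g.
# plt.savefig("loss_vs_step.png") and plt.savefig("error_rate_vs_step.png");
# the filenames here are arbitrary choices.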

181
pairwise_compare.py Executable file

@@ -0,0 +1,181 @@
#!/usr/bin/env python3
# pairwise_compare.py
import logging
import random
import torch
from torch import nn
from torch.nn import functional as F
# early pytorch device setup
DEVICE = torch.accelerator.current_accelerator() if torch.accelerator.is_available() else "cpu"
# Valves
DIMENSIONS = 1  # embedding width d used by NumberEmbedder / PairwiseComparator
TRAIN_STEPS = 20000  # number of optimizer steps
TRAIN_BATCHSZ = 16384  # pairs sampled per training batch
TRAIN_PROGRESS = 500  # log progress every TRAIN_PROGRESS steps
BATCH_LOWER = -512.0  # lower bound of the sampled number range
BATCH_UPPER = 512.0  # upper bound of the sampled number range
DO_VERBOSE_EARLY_TRAIN = True  # also log every step up to TRAIN_PROGRESS

def get_torch_info():
    log.info("PyTorch Version: %s", torch.__version__)
    log.info("HIP Version: %s", torch.version.hip)
    log.info("CUDA support: %s", torch.cuda.is_available())
    if torch.cuda.is_available():
        log.info("CUDA device detected: %s", torch.cuda.get_device_name(0))
    log.info("Using %s compute mode", DEVICE)


def set_seed(seed: int):
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


# 1) Data: pairs (a, b) with label y = 1 if a > b else 0
def sample_batch(batch_size: int, low=BATCH_LOWER, high=BATCH_UPPER):
    a = (high - low) * torch.rand(batch_size, 1) + low
    b = (high - low) * torch.rand(batch_size, 1) + low
    # train on whether a > b
    y = (a > b).float()
    # Removed but left for my notes: training for equality seems to hurt classification of near-equal pairs.
    # When trained only on "y = 1 if a > b", the model classifies equal inputs more sensibly (probability ~0.5).
    # eq = (a == b).float()
    # y = gt + 0.5 * eq
    return a, b, y

# 2) Number "embedding" network: R -> R^d
class NumberEmbedder(nn.Module):
    def __init__(self, d=8):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(1, 16),
            nn.ReLU(),
            nn.Linear(16, d),
        )

    def forward(self, x):
        return self.net(x)

# 3) Comparator head: takes (ea, eb) -> logit for "a > b"
class PairwiseComparator(nn.Module):
    def __init__(self, d=8):
        super().__init__()
        self.embed = NumberEmbedder(d)
        self.head = nn.Sequential(
            nn.Linear(2 * d + 1, 16),
            nn.ReLU(),
            nn.Linear(16, 1),
        )

    def forward(self, a, b):
        ea = self.embed(a)
        eb = self.embed(b)
        delta_ab = a - b
        x = torch.cat([ea, eb, delta_ab], dim=-1)
        return self.head(x)  # logits

def training_entry():
    # All PRNG seeds are set to 0 for deterministic outputs during testing;
    # the seed should be initialized normally otherwise.
    set_seed(0)
    model = PairwiseComparator(d=DIMENSIONS).to(DEVICE)
    opt = torch.optim.AdamW(model.parameters(), lr=2e-3)
    # 4) Train
    for step in range(TRAIN_STEPS):
        a, b, y = sample_batch(TRAIN_BATCHSZ)
        a, b, y = a.to(DEVICE), b.to(DEVICE), y.to(DEVICE)
        logits = model(a, b)
        loss = F.binary_cross_entropy_with_logits(logits, y)
        opt.zero_grad()
        loss.backward()
        opt.step()
        # log every step during the early phase (if enabled), then every TRAIN_PROGRESS steps
        if (step <= TRAIN_PROGRESS and DO_VERBOSE_EARLY_TRAIN) or step % TRAIN_PROGRESS == 0:
            with torch.no_grad():
                pred = (torch.sigmoid(logits) > 0.5).float()
                acc = (pred == y).float().mean().item()
            log.info(f"step={step:5d} loss={loss.item():.7f} acc={acc:.7f}")
    # 5) Quick test: evaluate accuracy on fresh pairs
    with torch.no_grad():
        a, b, y = sample_batch(TRAIN_BATCHSZ)
        a, b, y = a.to(DEVICE), b.to(DEVICE), y.to(DEVICE)
        logits = model(a, b)
        pred = (torch.sigmoid(logits) > 0.5).float()
        errors = (pred != y).sum().item()
        acc = (pred == y).float().mean().item()
    log.info(f"Final test acc: {acc} errors: {errors}")
    # embed the embedding width into the model serialization
    torch.save({"state_dict": model.state_dict(), "d": DIMENSIONS}, "model.pth")
    log.info("Saved PyTorch Model State to model.pth")

def infer_entry():
    model_ckpt = torch.load("model.pth", map_location=DEVICE)
    model = PairwiseComparator(d=model_ckpt["d"]).to(DEVICE)
    model.load_state_dict(model_ckpt["state_dict"])
    model.eval()
    # sample pairs
    pairs = [(1, 2), (10, 3), (5, 5), (10, 35), (-64, 11), (300, 162), (2, 0), (2, 1), (3, 1), (4, 1), (3, 10), (30, 1), (0, 0), (-162, 237),
             (10, 20), (100, 30), (50, 50), (100, 350), (-640, 110), (30, -420), (200, 0), (92, 5), (30, 17), (42, 10), (30, 100), (30, 1), (0, 400), (-42, -42)]
    a = torch.tensor([[p[0]] for p in pairs], dtype=torch.float32, device=DEVICE)
    b = torch.tensor([[p[1]] for p in pairs], dtype=torch.float32, device=DEVICE)
    # sanity check before inference
    log.info(f"a.device: {a.device} model.device: {next(model.parameters()).device}")
    with torch.no_grad():
        probs = torch.sigmoid(model(a, b))
    for (x, y), p in zip(pairs, probs):
        log.info(f"P({x} > {y}) = {p.item():.3f}")

if __name__ == '__main__':
    import sys
    import os
    import datetime
    log = logging.getLogger(__name__)
    logging.basicConfig(filename='pairwise_compare.log', level=logging.INFO)
    log.info(f"Log opened {datetime.datetime.now()}")
    get_torch_info()
    name = os.path.basename(sys.argv[0])
    if name == 'train.py':
        training_entry()
    elif name == 'infer.py':
        infer_entry()
    else:
        # alt call pattern:
        #   python3 pairwise_compare.py train
        #   python3 pairwise_compare.py infer
        if len(sys.argv) > 1:
            mode = sys.argv[1].strip().lower()
            if mode == "train":
                training_entry()
            elif mode == "infer":
                infer_entry()
            else:
                log.error(f"Unknown operation: {mode}")
                log.error("Invalid call syntax; invoke the script as \"train.py\" or \"infer.py\", or as pairwise_compare.py <mode> where mode is \"train\" or \"infer\"")
        else:
            log.error("Not enough arguments passed to the script; invoke it as train.py or infer.py, or as pairwise_compare.py <mode> where mode is \"train\" or \"infer\"")
    log.info(f"Log closed {datetime.datetime.now()}")

1
train.py Symbolic link

@@ -0,0 +1 @@
pairwise_compare.py