diff --git a/pairwise_compare.py b/pairwise_compare.py index f9517c0..83f30d8 100755 --- a/pairwise_compare.py +++ b/pairwise_compare.py @@ -13,8 +13,8 @@ DEVICE = torch.accelerator.current_accelerator() if torch.accelerator.is_availab # Valves DIMENSIONS = 1 -TRAIN_STEPS = 25000 -TRAIN_BATCHSZ = 16384 +TRAIN_STEPS = 10000 +TRAIN_BATCHSZ = 8192 TRAIN_PROGRESS = 100 BATCH_LOWER = -512.0 BATCH_UPPER = 512.0 @@ -95,7 +95,7 @@ def training_entry(): # embed model dimensions into the model serialization torch.save({"state_dict": model.state_dict(), "d": DIMENSIONS}, MODEL_PATH) - log.info("Saved PyTorch Model State to model.pth") + log.info(f"Saved PyTorch Model State to {MODEL_PATH}") def infer_entry(): model_ckpt = torch.load(MODEL_PATH, map_location=DEVICE)