Fix model-save log entry to report the actual model path

Adjust default batch size and training steps to be more friendly to CPU-only PyTorch runs
This commit is contained in:
2025-12-20 01:54:42 +00:00
parent 0d6a92823a
commit 5e5ad1bc20

View File

@@ -13,8 +13,8 @@ DEVICE = torch.accelerator.current_accelerator() if torch.accelerator.is_availab
  # Valves
  DIMENSIONS = 1
- TRAIN_STEPS = 25000
+ TRAIN_STEPS = 10000
- TRAIN_BATCHSZ = 16384
+ TRAIN_BATCHSZ = 8192
  TRAIN_PROGRESS = 100
  BATCH_LOWER = -512.0
  BATCH_UPPER = 512.0
@@ -95,7 +95,7 @@ def training_entry():
  # embed model dimensions into the model serialization
  torch.save({"state_dict": model.state_dict(), "d": DIMENSIONS}, MODEL_PATH)
- log.info("Saved PyTorch Model State to model.pth")
+ log.info(f"Saved PyTorch Model State to {MODEL_PATH}")
  def infer_entry():
      model_ckpt = torch.load(MODEL_PATH, map_location=DEVICE)