Fix the log entry for model state saving.
Adjust the default batch size and training step count to be more CPU/PyTorch friendly.
This commit is contained in:
@@ -13,8 +13,8 @@ DEVICE = torch.accelerator.current_accelerator() if torch.accelerator.is_availab
 # Valves
 DIMENSIONS = 1
-TRAIN_STEPS = 25000
-TRAIN_BATCHSZ = 16384
+TRAIN_STEPS = 10000
+TRAIN_BATCHSZ = 8192
 TRAIN_PROGRESS = 100
 BATCH_LOWER = -512.0
 BATCH_UPPER = 512.0
@@ -95,7 +95,7 @@ def training_entry():

     # embed model dimensions into the model serialization
     torch.save({"state_dict": model.state_dict(), "d": DIMENSIONS}, MODEL_PATH)
-    log.info("Saved PyTorch Model State to model.pth")
+    log.info(f"Saved PyTorch Model State to {MODEL_PATH}")


 def infer_entry():
     model_ckpt = torch.load(MODEL_PATH, map_location=DEVICE)
||||
Reference in New Issue
Block a user