fix log entry for model state saving
adjust default batch size and passes to be a bit more CPU pytorch friendly
This commit is contained in:
@@ -13,8 +13,8 @@ DEVICE = torch.accelerator.current_accelerator() if torch.accelerator.is_availab
|
|||||||
|
|
||||||
# Valves
|
# Valves
|
||||||
DIMENSIONS = 1
|
DIMENSIONS = 1
|
||||||
TRAIN_STEPS = 25000
|
TRAIN_STEPS = 10000
|
||||||
TRAIN_BATCHSZ = 16384
|
TRAIN_BATCHSZ = 8192
|
||||||
TRAIN_PROGRESS = 100
|
TRAIN_PROGRESS = 100
|
||||||
BATCH_LOWER = -512.0
|
BATCH_LOWER = -512.0
|
||||||
BATCH_UPPER = 512.0
|
BATCH_UPPER = 512.0
|
||||||
@@ -95,7 +95,7 @@ def training_entry():
|
|||||||
|
|
||||||
# embed model dimensions into the model serialization
|
# embed model dimensions into the model serialization
|
||||||
torch.save({"state_dict": model.state_dict(), "d": DIMENSIONS}, MODEL_PATH)
|
torch.save({"state_dict": model.state_dict(), "d": DIMENSIONS}, MODEL_PATH)
|
||||||
log.info("Saved PyTorch Model State to model.pth")
|
log.info(f"Saved PyTorch Model State to {MODEL_PATH}")
|
||||||
|
|
||||||
def infer_entry():
|
def infer_entry():
|
||||||
model_ckpt = torch.load(MODEL_PATH, map_location=DEVICE)
|
model_ckpt = torch.load(MODEL_PATH, map_location=DEVICE)
|
||||||
|
|||||||
Reference in New Issue
Block a user