diff --git a/pairwise_compare.py b/pairwise_compare.py index 4c80f19..2101a22 100755 --- a/pairwise_compare.py +++ b/pairwise_compare.py @@ -53,7 +53,7 @@ def parse_training_log(file_path: str) -> pd.DataFrame: # TODO: Move plotting into its own file def plt_loss_tstep(df: pd.DataFrame) -> None: # Plot 1: Loss - plt.figure(figsize=(8, 4)) + plt.figure(figsize=(10, 6)) plt.plot(df["step"], df["loss_clamped"]) plt.yscale("log") plt.xlabel("Step") @@ -69,7 +69,7 @@ def plt_loss_tstep(df: pd.DataFrame) -> None: def plt_acc_tstep(df: pd.DataFrame, eps=1e-10) -> None: # Plot 2: Accuracy df["err"] = (1.0 - df["acc"]).clip(lower=eps) - plt.figure(figsize=(8, 4)) + plt.figure(figsize=(10, 6)) plt.plot(df["step"], df["err"]) plt.yscale("log") plt.xlabel("Step") @@ -99,11 +99,14 @@ def plt_embeddings(model: comp_nn.PairwiseComparator) -> None: # move data back to CPU for plotting embeddings = embeddings.cpu() xs = xs.cpu() - + + # Plot 3: x vs h(x) + plt.figure(figsize=(10, 6)) for i in range(embeddings.shape[1]): plt.plot(xs.squeeze(), embeddings[:, i], label=f"dim {i}") - - plt.legend() + plt.title("x vs h(x)") + plt.xlabel("x [input]") + plt.ylabel("h(x) [embedding]") plt.savefig(EMBED_CHART_PATH) plt.close()