From d46712ff53a84a1cf5466f502a9a65dd78baf8c5 Mon Sep 17 00:00:00 2001 From: Elaina Claus Date: Thu, 25 Dec 2025 01:23:54 +0000 Subject: [PATCH] make the defaults more sane for the task its still overkill --- pairwise_comp_nn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pairwise_comp_nn.py b/pairwise_comp_nn.py index e03184d..6bfff5f 100644 --- a/pairwise_comp_nn.py +++ b/pairwise_comp_nn.py @@ -3,7 +3,7 @@ from torch import nn # 2) Number "embedding" network: R -> R^d class NumberEmbedder(nn.Module): - def __init__(self, d=4, hidden=16): + def __init__(self, d=2, hidden=4): super().__init__() self.net = nn.Sequential( nn.Linear(1, hidden), @@ -16,7 +16,7 @@ class NumberEmbedder(nn.Module): # 3) Comparator head: takes (ea, eb, e) -> logit for "a > b" class PairwiseComparator(nn.Module): - def __init__(self, d=4, hidden=16, k=0.5): + def __init__(self, d=2, hidden=4, k=0.5): super().__init__() self.log_k = nn.Parameter(torch.tensor([k])) self.embed = NumberEmbedder(d, hidden)