diff --git a/pairwise_comp_nn.py b/pairwise_comp_nn.py index e03184d..6bfff5f 100644 --- a/pairwise_comp_nn.py +++ b/pairwise_comp_nn.py @@ -3,7 +3,7 @@ from torch import nn # 2) Number "embedding" network: R -> R^d class NumberEmbedder(nn.Module): - def __init__(self, d=4, hidden=16): + def __init__(self, d=2, hidden=4): super().__init__() self.net = nn.Sequential( nn.Linear(1, hidden), @@ -16,7 +16,7 @@ class NumberEmbedder(nn.Module): # 3) Comparator head: takes (ea, eb, e) -> logit for "a > b" class PairwiseComparator(nn.Module): - def __init__(self, d=4, hidden=16, k=0.5): + def __init__(self, d=2, hidden=4, k=0.5): super().__init__() self.log_k = nn.Parameter(torch.tensor([k])) self.embed = NumberEmbedder(d, hidden)