diff --git a/pairwise_comp_nn.py b/pairwise_comp_nn.py index 84e5299..2a14466 100644 --- a/pairwise_comp_nn.py +++ b/pairwise_comp_nn.py @@ -14,7 +14,7 @@ class NumberEmbedder(nn.Module): def forward(self, x): return self.net(x) -# 3) Comparator head: takes (ea, eb, e) -> logit for "a > b" +# MLP Comparator head: takes (ea, eb, e) -> logit for "a > b" class PairwiseComparator(nn.Module): def __init__(self, d=2, hidden=4, k=0.5): super().__init__()