import torch
from torch import nn


# 2) Number "embedding" network: R -> R^d
class NumberEmbedder(nn.Module):
    """Map a scalar feature to a ``d``-dimensional embedding with a small MLP.

    Architecture: ``Linear(1, hidden) -> ReLU -> Linear(hidden, d)``.

    Args:
        d: Output embedding dimension.
        hidden: Width of the hidden layer.
    """

    def __init__(self, d: int = 2, hidden: int = 4) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(1, hidden),
            nn.ReLU(),
            nn.Linear(hidden, d),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed ``x``.

        The last dimension of ``x`` must have size 1 (required by the first
        ``nn.Linear(1, hidden)``); the output replaces it with size ``d``,
        i.e. ``(..., 1) -> (..., d)``.
        """
        return self.net(x)


# 3) Comparator head: takes (a, b) -> logit for "a > b"
class PairwiseComparator(nn.Module):
    """Score the proposition "a > b" as a logit that is antisymmetric in (a, b).

    The model only sees the difference of its inputs: it embeds ``a - b`` and
    ``b - a`` with one shared ``NumberEmbedder``, passes both through the same
    MLP head, and subtracts the results. Hence, for any weights,
    ``forward(a, b) == -forward(b, a)`` and ``forward(a, a) == 0``.

    Args:
        d: Embedding dimension used by the internal ``NumberEmbedder``.
        hidden: Hidden width of both the embedder and the head MLP.
        k: Initial value of the learnable ``log_k`` parameter. The logit is
            multiplied by ``log_k ** 2``, so the initial scale is ``k ** 2``.
            Integers are accepted and converted to float.
    """

    def __init__(self, d: int = 2, hidden: int = 4, k: float = 0.5) -> None:
        super().__init__()
        # NOTE(review): despite its name, ``log_k`` is not used as a logarithm;
        # forward() squares it, which keeps the scale non-negative so it never
        # flips the sign of the logit. The name is kept so existing state_dict
        # keys still load. ``float(k)`` is needed because an integer tensor
        # cannot be wrapped in an nn.Parameter (it cannot require grad).
        self.log_k = nn.Parameter(torch.tensor([float(k)]))
        self.embed = NumberEmbedder(d, hidden)
        self.head = nn.Sequential(
            nn.Linear(d, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        """Return the scaled logit for "a > b".

        ``a`` and ``b`` must be broadcast-compatible, and their difference
        must have a last dimension of size 1 (required by ``NumberEmbedder``).
        The result has the broadcast shape of ``a - b`` with last dimension 1.
        """
        phi_pos = self.head(self.embed(a - b))
        phi_neg = self.head(self.embed(b - a))
        # Antisymmetric by construction: swapping a and b swaps the two
        # terms, which negates their difference exactly.
        logit = phi_pos - phi_neg
        return (self.log_k ** 2) * logit