import torch
from torch import nn


def _as_feature_column(x: torch.Tensor) -> torch.Tensor:
    """Return *x* with a trailing feature axis of size 1.

    Tensors whose last axis already has size 1 (e.g. ``(batch, 1)``) are
    returned unchanged. Anything else -- a 0-d scalar, a ``(batch,)`` vector,
    a ``(batch, k)`` grid -- is treated as a collection of independent scalars
    and gets ``unsqueeze(-1)`` so it can be fed to ``nn.Linear(1, ...)``.
    """
    if x.dim() == 0 or x.shape[-1] != 1:
        return x.unsqueeze(-1)
    return x


# 2) Number "embedding" network: R -> R^d
class NumberEmbedder(nn.Module):
    """Map each scalar to a ``d``-dimensional embedding with a small MLP.

    Args:
        d: Output embedding width.
        hidden: Width of the single hidden layer.
    """

    def __init__(self, d: int = 8, hidden: int = 16):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(1, hidden),
            nn.ReLU(),
            nn.Linear(hidden, d),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed every scalar in *x*.

        Accepts ``(..., 1)`` tensors, and also 0-d scalars or tensors whose
        last axis is not 1 (e.g. ``(batch,)``); the latter first receive a
        trailing size-1 axis.

        Returns:
            The MLP output, whose last axis has size ``d`` in place of the
            size-1 feature axis.
        """
        return self.net(_as_feature_column(x))


# 3) Comparator head: takes (ea, eb) -> logit for "a > b"
class PairwiseComparator(nn.Module):
    """Predict a logit for "a > b" from two embeddings and their raw difference.

    Args:
        d: Embedding width of the shared ``NumberEmbedder``.
        hidden: Width of the comparator head's hidden layer.
    """

    def __init__(self, d: int = 8, hidden: int = 16):
        super().__init__()
        # The embedder is created before the head; keep this order so seeded
        # parameter initialization stays reproducible.
        self.embed = NumberEmbedder(d)
        self.head = nn.Sequential(
            nn.Linear(2 * d + 1, hidden),  # input is [ea, eb, a - b]
            nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        """Return logits for "a > b".

        *a* and *b* may be ``(..., 1)`` tensors, ``(batch,)`` vectors or
        scalars. They are broadcast against each other, so e.g. a batch of
        ``a`` values can be compared with a single threshold ``b``
        (``torch.cat`` alone does not broadcast).

        Returns:
            Logits with a trailing axis of size 1 over the broadcast shape;
            apply ``torch.sigmoid`` to obtain probabilities.
        """
        a, b = torch.broadcast_tensors(_as_feature_column(a), _as_feature_column(b))
        ea = self.embed(a)
        eb = self.embed(b)
        # The raw difference gives the head a direct, sign-carrying signal
        # alongside the learned embeddings.
        delta_ab = a - b
        x = torch.cat([ea, eb, delta_ab], dim=-1)
        return self.head(x)  # logits