Files
mltoys/pairwise_comp_nn.py
Elaina Claus 6e31865a84 added epsilon for equality check
major layout changes in the network
2025-12-22 21:47:16 -05:00

37 lines
1.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import torch
from torch import nn
# 2) Number "embedding" network: R -> R^d
class NumberEmbedder(nn.Module):
    """Map a scalar feature to a ``d``-dimensional embedding with a small MLP.

    The network is ``Linear(1, hidden) -> ReLU -> Linear(hidden, d)``. Because
    the first layer has ``in_features=1``, the input's last axis must have
    size 1.

    Args:
        d: Width of the output embedding.
        hidden: Width of the single hidden layer.
    """

    def __init__(self, d=4, hidden=16):
        super().__init__()
        # Layers are instantiated input-side first, then wrapped in a
        # Sequential stored as ``self.net`` (sub-module indices 0, 1, 2).
        layers = [
            nn.Linear(1, hidden),
            nn.ReLU(),
            nn.Linear(hidden, d),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the embedding of ``x`` (last axis of size 1 -> size ``d``)."""
        embedding = self.net(x)
        return embedding
# 3) Pairwise comparator: takes raw numbers (a, b), embeds their difference -> logit for "a > b"
class PairwiseComparator(nn.Module):
    """Score a pair of numbers with an antisymmetric logit.

    ``forward(a, b)`` embeds the signed difference ``a - b`` and its negation
    ``b - a`` with one shared ``NumberEmbedder``, runs both embeddings through
    the same MLP head, and returns the scaled difference of the two head
    outputs. By construction ``forward(a, b) == -forward(b, a)`` and
    ``forward(a, a)`` is zero.

    Args:
        d: Embedding width produced by the embedder and consumed by the head.
        hidden: Hidden width used by both the embedder and the head.
        k: Initial value of the learnable scale parameter. The multiplier
            actually applied to the logit is the square of this parameter.
            Python ints are accepted and converted to float.
    """

    def __init__(self, d=4, hidden=16, k=1.0):
        super().__init__()
        # NOTE: despite its name, this parameter is not a logarithm:
        # forward() squares it. The attribute name is kept so that existing
        # state_dict keys ("log_k") continue to load.
        # float(k) guards against integer input: an integer tensor cannot
        # require gradients, so nn.Parameter would raise for e.g. k=1.
        self.log_k = nn.Parameter(torch.tensor([float(k)]))
        self.embed = NumberEmbedder(d, hidden)
        self.head = nn.Sequential(
            nn.Linear(d, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, a, b):
        """Return the scaled antisymmetric logit for the pair ``(a, b)``.

        Args:
            a: First operand. Assumed to be a float tensor whose last axis has
                size 1, as required by the embedder's first linear layer —
                TODO confirm against callers.
            b: Second operand, combined element-wise with ``a`` via ``a - b``.

        Returns:
            Tensor whose last axis has size 1: ``log_k ** 2`` times
            ``head(embed(a - b)) - head(embed(b - a))``.
        """
        # Enforce antisymmetry: h(a, b) = -h(b, a). Swapping the arguments
        # swaps phi and phi_neg, which flips the sign of their difference.
        phi = self.head(self.embed(a - b))
        phi_neg = self.head(self.embed(b - a))
        logit = phi - phi_neg
        # The squared scale is non-negative, so it never flips the logit's sign.
        return (self.log_k ** 2) * logit