Use the AdamW optimizer again (AdamW is the go-to choice for small toy models and transformers). Refactor the NN classes into their own module, pairwise_comp_nn.py.
34 lines
871 B
Python
34 lines
871 B
Python
import torch
|
|
from torch import nn
|
|
|
|
# Number "embedding" network: R -> R^d
class NumberEmbedder(nn.Module):
    """Map each scalar to a ``d``-dimensional embedding with a small MLP.

    The network is ``Linear(1, hidden) -> ReLU -> Linear(hidden, d)``.

    Args:
        d: Output embedding width.
        hidden: Width of the hidden layer. Defaults to 16, the value that was
            previously hard-coded.

    Input/output shapes:
        ``x`` must have a trailing dimension of size 1 (the first layer is
        ``Linear(1, hidden)``); the result has shape ``(..., d)``.
    """

    def __init__(self, d: int = 8, *, hidden: int = 16) -> None:
        super().__init__()
        # Keep the ``net`` attribute name and layer order: they determine the
        # state_dict keys ("net.0.weight", "net.2.bias", ...).
        self.net = nn.Sequential(
            nn.Linear(1, hidden),
            nn.ReLU(),
            nn.Linear(hidden, d),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the embedding of ``x`` (shape ``(..., 1)`` -> ``(..., d)``)."""
        return self.net(x)
# Pairwise comparator: (a, b) -> logit for "a > b"
class PairwiseComparator(nn.Module):
    """Produce one unnormalized logit scoring whether ``a > b``.

    Both operands go through a single shared ``NumberEmbedder``. The two
    embeddings are concatenated with the raw difference ``a - b`` and passed
    to an MLP head ``Linear(2*d + 1, hidden) -> ReLU -> Linear(hidden, 1)``.

    Args:
        d: Embedding width used for each operand.
        hidden: Width of the head's hidden layer. Defaults to 16, the value
            that was previously hard-coded.
    """

    def __init__(self, d: int = 8, *, hidden: int = 16) -> None:
        super().__init__()
        # One embedder is reused for both operands, so a and b share weights.
        self.embed = NumberEmbedder(d)
        # Head input is [ea (d), eb (d), a - b (1)] -> 2 * d + 1 features.
        self.head = nn.Sequential(
            nn.Linear(2 * d + 1, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        """Return logits of shape ``(..., 1)`` for "a > b".

        ``a`` and ``b`` need a trailing dimension of size 1 (required by the
        embedder's first ``Linear(1, ...)`` layer). They also need identical
        leading dimensions, because ``torch.cat`` below does not broadcast.
        """
        ea = self.embed(a)
        eb = self.embed(b)
        # The raw signed gap is appended alongside the learned embeddings.
        delta_ab = a - b
        features = torch.cat([ea, eb, delta_ab], dim=-1)
        return self.head(features)  # unnormalized logits