
Commit 7fbb144

Merge pull request #17 from OpenTabular/feature/nodegam
Add NodeGAM model implementation and related utilities
2 parents 80c20aa + 42b4a7c commit 7fbb144

8 files changed

Lines changed: 2216 additions & 9 deletions

File tree

nampy/arch_utils/nn_utils.py

Lines changed: 372 additions & 0 deletions
@@ -0,0 +1,372 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torch.jit import script

"""Neural Network related utils like Entmax and Modules."""


def check_numpy(x):
    """Makes sure x is a numpy array. If not, casts it to one."""
    if isinstance(x, torch.Tensor):
        x = x.detach().cpu().numpy()
    x = np.asarray(x)
    assert isinstance(x, np.ndarray)
    return x


def process_in_chunks(function, *args, batch_size, out=None, **kwargs):
    """Computes output by applying a batch-parallel function to a large data tensor in chunks.

    Args:
        function: a function(*[x[indices, ...] for x in args]) -> out[indices, ...].
        args: one or many tensors, each [num_instances, ...].
        batch_size: maximum chunk size processed in one go.
        out: memory buffer for out, defaults to torch.zeros of appropriate size and type.

    Returns:
        out: the outputs of function(data), computed in a memory-efficient (mini-batch) way.
    """
    total_size = args[0].shape[0]
    first_output = function(*[x[0: batch_size] for x in args])
    output_shape = (total_size,) + tuple(first_output.shape[1:])
    if out is None:
        out = torch.zeros(*output_shape, dtype=first_output.dtype, device=first_output.device,
                          layout=first_output.layout, **kwargs)

    out[0: batch_size] = first_output
    for i in range(batch_size, total_size, batch_size):
        batch_ix = slice(i, min(i + batch_size, total_size))
        out[batch_ix] = function(*[x[batch_ix] for x in args])
    return out


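For illustration (not part of the committed file), a minimal sketch of how process_in_chunks can run a model over a tensor too large for a single forward pass; the model and shapes here are hypothetical:

# Hypothetical usage sketch of process_in_chunks (not part of this commit).
model = nn.Linear(16, 4)           # any batch-parallel callable works
data = torch.randn(10_000, 16)     # large input, processed 256 rows at a time
with torch.no_grad():
    preds = process_in_chunks(model, data, batch_size=256)
assert preds.shape == (10_000, 4)
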
def to_one_hot(y, depth=None):
    """Converts an integer target tensor to a one-hot encoding.

    Takes an integer tensor with n dims and converts it to a 1-hot representation with n + 1 dims.
    The (n+1)'th dimension has zeros everywhere except at the y'th index, where it equals 1.

    Args:
        y: Input integer (IntTensor, LongTensor or Variable) of any shape.
        depth (int): The size of the one hot dimension.

    Returns:
        y_onehot: The onehot encoding of y.
    """
    y_flat = y.to(torch.int64).view(-1, 1)
    depth = depth if depth is not None else int(torch.max(y_flat)) + 1
    y_one_hot = torch.zeros(y_flat.size()[0], depth, device=y.device).scatter_(1, y_flat, 1)
    y_one_hot = y_one_hot.view(*(tuple(y.shape) + (-1,)))
    return y_one_hot


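A quick sketch of the expected behaviour (illustrative only, not part of the commit):

# Illustrative check of to_one_hot (not part of this commit).
labels = torch.tensor([0, 2, 1])
one_hot = to_one_hot(labels, depth=3)
# tensor([[1., 0., 0.],
#         [0., 0., 1.],
#         [0., 1., 0.]])
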
def _make_ix_like(input, dim=0):
    """Returns the index tensor 1..d broadcastable along `dim`, used by the sparsemax/entmax thresholds."""
    d = input.size(dim)
    rho = torch.arange(1, d + 1, device=input.device, dtype=input.dtype)
    view = [1] * input.dim()
    view[0] = -1
    return rho.view(view).transpose(0, dim)


class SparsemaxFunction(Function):
    """Sparsemax function.

    An implementation of sparsemax (Martins & Astudillo, 2016). See
    :cite:`DBLP:journals/corr/MartinsA16` for a detailed description.

    By Ben Peters and Vlad Niculae.
    """

    @staticmethod
    def forward(ctx, input, dim=-1):
        """sparsemax: normalizing sparse transform (a la softmax).

        Args:
            input: Any dimension.
            dim: Dimension along which to apply.

        Returns:
            output (Tensor): Same shape as input.
        """
        ctx.dim = dim
        max_val, _ = input.max(dim=dim, keepdim=True)
        input -= max_val  # same numerical stability trick as for softmax
        tau, supp_size = SparsemaxFunction._threshold_and_support(input, dim=dim)
        output = torch.clamp(input - tau, min=0)
        ctx.save_for_backward(supp_size, output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        supp_size, output = ctx.saved_tensors
        dim = ctx.dim
        grad_input = grad_output.clone()
        grad_input[output == 0] = 0

        v_hat = grad_input.sum(dim=dim) / supp_size.to(output.dtype).squeeze()
        v_hat = v_hat.unsqueeze(dim)
        grad_input = torch.where(output != 0, grad_input - v_hat, grad_input)
        return grad_input, None

    @staticmethod
    def _threshold_and_support(input, dim=-1):
        """Sparsemax building block: compute the threshold.

        Args:
            input: Any dimension.
            dim: Dimension along which to apply the sparsemax.

        Returns:
            The threshold value tau and the support size.
        """
        input_srt, _ = torch.sort(input, descending=True, dim=dim)
        input_cumsum = input_srt.cumsum(dim) - 1
        rhos = _make_ix_like(input, dim)
        support = rhos * input_srt > input_cumsum

        support_size = support.sum(dim=dim).unsqueeze(dim)
        tau = input_cumsum.gather(dim, support_size - 1)
        tau /= support_size.to(input.dtype)
        return tau, support_size


sparsemax = lambda input, dim=-1: SparsemaxFunction.apply(input, dim)
sparsemoid = lambda input: (0.5 * input + 0.5).clamp_(0, 1)


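As a quick illustration (not part of the committed file), sparsemax behaves like a drop-in replacement for softmax that produces exact zeros:

# Illustrative comparison of sparsemax and softmax (not part of this commit).
logits = torch.tensor([[1.0, 0.8, -1.0]])
print(F.softmax(logits, dim=-1))   # all entries strictly positive
print(sparsemax(logits, dim=-1))   # tensor([[0.6000, 0.4000, 0.0000]]); still sums to 1
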
class Entmax15Function(Function):
    """Entropy Max (EntMax).

    An implementation of exact Entmax with alpha=1.5 (B. Peters, V. Niculae, A. Martins). See
    :cite:`https://arxiv.org/abs/1905.05702` for a detailed description.
    Source: https://github.com/deep-spin/entmax
    """

    @staticmethod
    def forward(ctx, input, dim=-1):
        ctx.dim = dim

        max_val, _ = input.max(dim=dim, keepdim=True)
        input = input - max_val  # same numerical stability trick as for softmax
        input = input / 2  # divide by 2 to solve the actual Entmax

        tau_star, _ = Entmax15Function._threshold_and_support(input, dim)
        output = torch.clamp(input - tau_star, min=0) ** 2
        ctx.save_for_backward(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        Y, = ctx.saved_tensors
        gppr = Y.sqrt()  # = 1 / g'' (Y)
        dX = grad_output * gppr
        q = dX.sum(ctx.dim) / gppr.sum(ctx.dim)
        q = q.unsqueeze(ctx.dim)
        dX -= q * gppr
        return dX, None

    @staticmethod
    def _threshold_and_support(input, dim=-1):
        Xsrt, _ = torch.sort(input, descending=True, dim=dim)

        rho = _make_ix_like(input, dim)
        mean = Xsrt.cumsum(dim) / rho
        mean_sq = (Xsrt ** 2).cumsum(dim) / rho
        ss = rho * (mean_sq - mean ** 2)
        delta = (1 - ss) / rho

        # NOTE: this is not exactly the same as in the reference algorithm.
        # Fortunately it seems the clamped values never wrongly
        # get selected by tau <= sorted_z. Prove this!
        delta_nz = torch.clamp(delta, 0)
        tau = mean - torch.sqrt(delta_nz)

        support_size = (tau <= Xsrt).sum(dim).unsqueeze(dim)
        tau_star = tau.gather(dim, support_size - 1)
        return tau_star, support_size


class Entmoid15(Function):
    """A highly optimized equivalent of lambda x: Entmax15([x, 0])."""

    @staticmethod
    def forward(ctx, input):
        output = Entmoid15._forward(input)
        ctx.save_for_backward(output)
        return output

    @staticmethod
    @script
    def _forward(input):
        input, is_pos = abs(input), input >= 0
        tau = (input + torch.sqrt(F.relu(8 - input ** 2))) / 2
        tau.masked_fill_(tau <= input, 2.0)
        y_neg = 0.25 * F.relu(tau - input, inplace=True) ** 2
        return torch.where(is_pos, 1 - y_neg, y_neg)

    @staticmethod
    def backward(ctx, grad_output):
        return Entmoid15._backward(ctx.saved_tensors[0], grad_output)

    @staticmethod
    @script
    def _backward(output, grad_output):
        gppr0, gppr1 = output.sqrt(), (1 - output).sqrt()
        grad_input = grad_output * gppr0
        q = grad_input / (gppr0 + gppr1)
        grad_input -= q * gppr0
        return grad_input


entmax15 = lambda input, dim=-1: Entmax15Function.apply(input, dim)
entmoid15 = Entmoid15.apply


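A hypothetical sanity check (not part of the committed file) showing the sparse outputs of the two transforms:

# Illustrative usage of entmax15 / entmoid15 (not part of this commit).
logits = torch.tensor([[2.0, 1.0, -2.0]])
probs = entmax15(logits, dim=-1)      # sparse: last entry is exactly 0, row sums to 1
gate = entmoid15(torch.tensor(0.0))   # sigmoid-like gate; entmoid15(0) == 0.5
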
def my_one_hot(val, dim=-1):
    """Makes a one-hot encoding along a given dimension, not just the last dimension.

    Args:
        val: A pytorch tensor.
        dim: The dimension along which to take the argmax.
    """
    max_cls = torch.argmax(val, dim=dim)
    onehot = F.one_hot(max_cls, num_classes=val.shape[dim])

    # swap the one-hot dimension back into place
    if dim != -1 and dim != val.ndim - 1:
        the_dim = list(range(onehot.ndim))
        the_dim.insert(dim, the_dim.pop(-1))
        onehot = onehot.permute(the_dim)

    return onehot


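For instance (illustrative only, not part of the commit), taking the hard argmax along dimension 0 of a 2-D tensor:

# Illustrative usage of my_one_hot (not part of this commit).
scores = torch.tensor([[0.1, 0.9],
                       [0.8, 0.2]])
my_one_hot(scores, dim=0)  # tensor([[0, 1], [1, 0]]): column-wise argmax as one-hot
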
class _Temp(nn.Module):
    """Shared base class for temperature annealing of EntMax/SoftMax/GumbelMax functions."""

    def __init__(self, steps, max_temp=1., min_temp=0.01, sample_soft=False):
        """Anneals the temperature from max to min in log10 space.

        Args:
            steps: The number of steps to change from max_temp to min_temp in log10 space.
            max_temp: The max (initial) temperature.
            min_temp: The min (final) temperature.
            sample_soft: If False, the model does a hard operation after the specified steps.
        """
        super().__init__()
        self.steps = steps
        self.min_temp = min_temp
        self.max_temp = max_temp
        self.sample_soft = sample_soft

        # Initialize as an nn.Parameter to store it in the model state_dict
        self.tau = nn.Parameter(torch.tensor(max_temp, dtype=torch.float32), requires_grad=False)

    def forward(self, logits, dim=-1):
        # During training and under annealing, run a soft max operation
        if self.sample_soft or (self.training and self.tau.item() > self.min_temp):
            return self.forward_with_tau(logits, dim=dim)

        # At test time, sample a hard max
        with torch.no_grad():
            return self.discrete_op(logits, dim=dim)

    def discrete_op(self, logits, dim=-1):
        return my_one_hot(logits, dim=dim).float()

    @property
    def is_deterministic(self):
        return (not self.sample_soft) and (not self.training or self.tau.item() <= self.min_temp)

    def temp_step_callback(self, step):
        # Calculate the temperature; fractional steps are allowed
        if step >= self.steps:
            self.tau.data = torch.tensor(self.min_temp, dtype=torch.float32)
        else:
            logmin = np.log10(self.min_temp)
            logmax = np.log10(self.max_temp)
            # Linearly interpolate in log space
            logtemp = logmax + step / self.steps * (logmin - logmax)
            temp = (10 ** logtemp)
            self.tau.data = torch.tensor(temp, dtype=torch.float32)

    def forward_with_tau(self, logits, dim):
        raise NotImplementedError()


class SMTemp(_Temp):
    """Softmax with temperature annealing."""

    def forward_with_tau(self, logits, dim):
        return F.softmax(logits / self.tau.item(), dim=dim)


class GSMTemp(_Temp):
    """Gumbel Softmax with temperature annealing."""

    def forward_with_tau(self, logits, dim):
        return F.gumbel_softmax(logits, tau=self.tau.item(), dim=dim)


class EM15Temp(_Temp):
    """EntMax15 with temperature annealing."""

    def forward_with_tau(self, logits, dim):
        return entmax15(logits / self.tau.item(), dim=dim)


class EMoid15Temp(_Temp):
    """Entmoid with temperature annealing."""

    def __init__(self, **kwargs):
        # This variant always performs the soft operation.
        kwargs['sample_soft'] = True
        super().__init__(**kwargs)

    def forward_with_tau(self, logits, dim=-1):
        return entmoid15(logits / self.tau.item())

    def discrete_op(self, logits, dim=-1):
        # Do not special-case logits == 0: it is rare in practice and outputting 0.5 there is fine.
        return torch.sign(logits) / 2 + 0.5


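A hypothetical training-loop sketch (not part of the committed file) showing how the annealing callback is meant to interact with these modules; the step counts are made up:

# Illustrative sketch of temperature annealing (not part of this commit).
choice = EM15Temp(steps=1000, max_temp=1.0, min_temp=0.01)
logits = torch.randn(4, 8)

choice.train()
for step in range(1500):
    out = choice(logits)             # soft, sparse choices while tau > min_temp
    choice.temp_step_callback(step)  # anneal tau from 1.0 down to 0.01 in log10 space

choice.eval()
hard = choice(logits)                # hard one-hot choices at test time
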
class Lambda(nn.Module):
    """Wraps an arbitrary callable as an nn.Module."""

    def __init__(self, func):
        super().__init__()
        self.func = func

    def forward(self, *args, **kwargs):
        return self.func(*args, **kwargs)


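For example (illustrative only, not part of the commit), Lambda lets a plain function slot into an nn.Sequential:

# Illustrative usage of Lambda (not part of this commit).
net = nn.Sequential(nn.Linear(4, 4), Lambda(lambda x: x * 2))
net(torch.randn(1, 4))
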
class ModuleWithInit(nn.Module):
    """Base class for a pytorch module with a data-aware initializer run on the first batch."""

    def __init__(self):
        super().__init__()
        self._is_initialized_tensor = nn.Parameter(torch.tensor(0, dtype=torch.float32), requires_grad=False)
        self._is_initialized_bool = None
        # Note: this module uses a separate pair of initialization flags so as to achieve both
        # * persistence: is_initialized is saved alongside the model in state_dict
        # * speed: the model doesn't need to re-read the tensor flag on every call
        # Please DO NOT use these flags in child modules.
        # The flag is stored as torch.float32 so that apex 16-bit precision training works.

    def initialize(self, *args, **kwargs):
        """Initialize module tensors using the first batch of data."""
        raise NotImplementedError("Please implement initialize() in the subclass.")

    def __call__(self, *args, **kwargs):
        if self._is_initialized_bool is None:
            self._is_initialized_bool = bool(self._is_initialized_tensor.item())
        if not self._is_initialized_bool:
            self.initialize(*args, **kwargs)
            self._is_initialized_tensor.data[...] = 1
            self._is_initialized_bool = True
        return super().__call__(*args, **kwargs)
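
As an illustration (not part of the committed file), a minimal sketch of how a subclass might use the data-aware initializer; the normalization layer shown here is hypothetical:

# Hypothetical ModuleWithInit subclass (not part of this commit).
class DataAwareScale(ModuleWithInit):
    """Scales inputs by the inverse std of the first batch it sees."""

    def __init__(self, num_features):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(num_features))

    def initialize(self, x):
        # Called exactly once, on the first batch passed through the module.
        with torch.no_grad():
            self.scale.data = 1.0 / (x.std(dim=0) + 1e-6)

    def forward(self, x):
        return x * self.scale


layer = DataAwareScale(8)
out = layer(torch.randn(32, 8))  # first call triggers initialize() before forward()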
