diff --git a/model/soft_dtw_cuda.py b/model/soft_dtw_cuda.py
index 906a877..739ed34 100644
--- a/model/soft_dtw_cuda.py
+++ b/model/soft_dtw_cuda.py
@@ -60,7 +60,7 @@ def compute_softdtw_cuda(D, gamma, warp, bandwidth, max_i, max_j, n_passes, R):
         cuda.syncthreads()
 
 
 @cuda.jit
-def compute_softdtw_backward_cuda(D, R, inv_gamma, warp, bandwidth, max_i, max_j, n_passes, E, G):
+def compute_softdtw_backward_cuda(D, R, inv_gamma, warp, bandwidth, max_i, max_j, n_passes, E):
     k = cuda.blockIdx.x
     tid = cuda.threadIdx.x
@@ -84,7 +84,6 @@ def compute_softdtw_backward_cuda(D, R, inv_gamma, warp, bandwidth, max_i, max_j
                 b = math.exp((R[k, i, j + 1] - R[k, i, j] - D[k, i, j] - warp) * inv_gamma)
                 c = math.exp((R[k, i + 1, j + 1] - R[k, i, j] - D[k, i, j]) * inv_gamma)
                 E[k, i, j] = E[k, i + 1, j] * a + E[k, i, j + 1] * b + E[k, i + 1, j + 1] * c
-                G[k, i, j] = E[k, i + 1, j]+E[k, i, j+1]+E[k, i+1, j+1]
 
         cuda.syncthreads()
@@ -142,18 +141,16 @@ def backward(ctx, grad_output):
         E = torch.zeros((B, N + 2, M + 2), dtype=dtype, device=dev)
         E[:, -1, -1] = 1
-        G = torch.zeros((B, N + 2, M + 2), dtype=dtype, device=dev)
-        G[:, -1, -1] = 1
 
         compute_softdtw_backward_cuda[B, threads_per_block](cuda.as_cuda_array(D_),
                                                             cuda.as_cuda_array(R),
                                                             1.0 / gamma.item(), warp.item(), bandwidth.item(), N, M, n_passes,
-                                                            cuda.as_cuda_array(E), cuda.as_cuda_array(G))
-        G = G[:, 1:N + 1, 1:M + 1]  # dR_D
+                                                            cuda.as_cuda_array(E))
+        E = E[:, 1:N + 1, 1:M + 1]  # dR_D
 
-        tmp_G = G.unsqueeze(-1).expand(-1, -1, -1, H)
-        tmp_G = tmp_G * torch.sign(raw_D)
-        dR_X = tmp_G.sum(dim=2)
+        tmp_E = E.unsqueeze(-1).expand(-1, -1, -1, H)
+        tmp_E = tmp_E * torch.sign(raw_D)
+        dR_X = tmp_E.sum(dim=2)
 
         return grad_output.view(-1, 1, 1).expand_as(dR_X) * dR_X, None, None, None, None, None
@@ -194,10 +191,8 @@ def cpu_compute_softdtw_backward(D_, R, gamma, warp, bandwidth):
     M = D_.shape[2]
     D = np.zeros((B, N + 2, M + 2))
     E = np.zeros((B, N + 2, M + 2))
-    G = np.zeros((B, N + 2, M + 2))
     D[:, 1:N + 1, 1:M + 1] = D_
     E[:, -1, -1] = 1
-    G[:, -1, -1] = 1
     for k in range(B):
         for j in range(M, 0, -1):
             for i in range(N, 0, -1):
@@ -216,9 +211,8 @@ def cpu_compute_softdtw_backward(D_, R, gamma, warp, bandwidth):
                 b = np.exp(b0)
                 c = np.exp(c0)
                 E[k, i, j] = E[k, i + 1, j] * a + E[k, i, j + 1] * b + E[k, i + 1, j + 1] * c
-                G[k, i, j] = E[k, i + 1, j]+E[k, i, j+1]+E[k, i+1, j+1]
 
-    return G[:, 1:N + 1, 1:M + 1]
+    return E[:, 1:N + 1, 1:M + 1]
 
 
 class CPUSoftDTW(Function):
@@ -249,10 +243,10 @@ def backward(ctx, grad_output):
         g_ = gamma.item()
         w_ = warp.item()
         b_ = bandwidth.item()
-        G = torch.Tensor(cpu_compute_softdtw_backward(D_, R_, g_, w_, b_)).to(dev).type(dtype)
-        tmp_G = G.unsqueeze(-1).expand(-1, -1, -1, H)
-        tmp_G = tmp_G * torch.sign(raw_D)
-        dR_X = tmp_G.sum(dim=2)
+        E = torch.Tensor(cpu_compute_softdtw_backward(D_, R_, g_, w_, b_)).to(dev).type(dtype)
+        tmp_E = E.unsqueeze(-1).expand(-1, -1, -1, H)
+        tmp_E = tmp_E * torch.sign(raw_D)
+        dR_X = tmp_E.sum(dim=2)
 
         return grad_output.view(-1, 1, 1).expand_as(dR_X) * dR_X, None, None, None, None, None
@@ -298,9 +292,10 @@ def _manhattan_dist_func(x, y):
         return torch.abs(x - y).sum(3), (x - y)
 
     def forward(self, X, Y):
-
+        n_hidden = X.size(-1)
         func_dtw = self._get_func_dtw(X, Y)
         D_xy, raw_D_xy = self.dist_func(X, Y)
+        D_xy = D_xy / n_hidden
         return func_dtw(X, raw_D_xy, D_xy, self.gamma, self.warp, self.bandwidth)
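
Note: in the soft-DTW backward recursion, E already accumulates the gradient
dR/dD (the "# dR_D" slice above), with each neighbor weighted by its exp(...)
transition factor. The separate buffer G summed the same three neighbors
unweighted, so it was both extra work and not the soft-DTW gradient; this
patch routes the gradient through E alone and drops the second
(B, N + 2, M + 2) allocation on both the CUDA and CPU paths. The forward pass
now also divides the Manhattan distance by n_hidden, so the scale of D_xy
(and hence the effective gamma) no longer grows with the hidden size.

A minimal smoke test for the patched backward, as a sketch only: the wrapper
class name SoftDTW and its constructor arguments are assumptions here, since
only forward(X, Y) and the gamma/warp/bandwidth attributes appear in this
diff:

    import torch
    from model.soft_dtw_cuda import SoftDTW  # assumed export name

    sdtw = SoftDTW(gamma=0.1, warp=1.0, bandwidth=0.0)  # assumed ctor signature
    X = torch.randn(4, 32, 64, requires_grad=True)  # (batch, len_x, n_hidden)
    Y = torch.randn(4, 40, 64)                      # (batch, len_y, n_hidden)

    loss = sdtw(X, Y).mean()
    loss.backward()  # gradient now flows through E only; no G buffer is allocated
    assert X.grad is not None and X.grad.shape == X.shape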