Skip to content

Commit 60fd6c7

Browse files
authored
Fix var to be array api compatible by using float64 for __truediv__ (#6)
1 parent d53b143 commit 60fd6c7

1 file changed

Lines changed: 3 additions & 1 deletion

File tree

transformcl.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,9 @@ def var(cl):
138138
139139
"""
140140
xp = array_namespace(cl)
141-
ell = xp.arange(cl.shape[-1])
141+
# ell cannot be an integer here as, within the array api
142+
# only floating-point dtypes are allowed in __truediv__
143+
ell = xp.arange(cl.shape[-1], dtype=xp.float64)
142144
return xp.sum((2 * ell + 1) / (4 * xp.pi) * cl, axis=-1)
143145

144146

0 commit comments

Comments
 (0)