|
10 | 10 |
|
11 | 11 | namespace ck { |
12 | 12 |
|
13 | | -namespace detail { |
14 | | -template <typename X, typename Y> |
15 | | -struct tuple_concat; |
16 | | - |
17 | | -template <typename... Xs, typename... Ys> |
18 | | -struct tuple_concat<Tuple<Xs...>, Tuple<Ys...>> |
19 | | -{ |
20 | | - using type = Tuple<Xs..., Ys...>; |
21 | | -}; |
22 | | - |
23 | | -// StaticallyIndexedArrayImpl uses binary split for O(log N) depth |
| 13 | +// StaticallyIndexedArray backed by a plain C array instead of a Tuple assembled via template metaprogramming. |
| 14 | +// This avoids deep template instantiation while keeping the same interface. |
24 | 15 | template <typename T, index_t N> |
25 | | -struct StaticallyIndexedArrayImpl |
| 16 | +struct StaticallyIndexedArray |
26 | 17 | { |
27 | | - using type = |
28 | | - typename tuple_concat<typename StaticallyIndexedArrayImpl<T, N / 2>::type, |
29 | | - typename StaticallyIndexedArrayImpl<T, N - N / 2>::type>::type; |
30 | | -}; |
| 18 | + __host__ __device__ constexpr StaticallyIndexedArray() : data_{} {} |
| 19 | + |
| 20 | + // Single-element constructor (N == 1 only). Any X that exposes a static Size() is rejected unless |
| 21 | + // remove_cvref_t<X> is exactly T, so that containers are routed to the conversion constructor. |
| 22 | + template <typename X> |
| 23 | + requires(N == 1 && |
| 24 | + // Allow if X is same type as T or doesn't have Size() method |
| 25 | + (is_same<remove_cvref_t<X>, T>::value || !requires { remove_cvref_t<X>::Size(); })) |
| 26 | + __host__ __device__ constexpr StaticallyIndexedArray(X&& x) |
| 27 | + : data_{static_cast<T>(ck::forward<X>(x))} |
| 28 | + { |
| 29 | + } |
31 | 30 |
|
32 | | -template <typename T> |
33 | | -struct StaticallyIndexedArrayImpl<T, 0> |
34 | | -{ |
35 | | - using type = Tuple<>; |
| 31 | + // Multi-element constructor (N > 1): takes exactly N arguments, each static_cast to T |
| 32 | + template <typename... Xs> |
| 33 | + requires(sizeof...(Xs) == N && N > 1) |
| 34 | + __host__ __device__ constexpr StaticallyIndexedArray(Xs&&... xs) |
| 35 | + : data_{static_cast<T>(ck::forward<Xs>(xs))...} |
| 36 | + { |
| 37 | + } |
| 38 | + |
| 39 | + // Conversion constructor from any other container type with static Size() == N (Tuple, etc.) |
| 40 | + template <typename Container> |
| 41 | + requires(!is_same<remove_cvref_t<Container>, StaticallyIndexedArray>::value && |
| 42 | + requires { Container::Size(); } && Container::Size() == N) |
| 43 | + __host__ __device__ constexpr StaticallyIndexedArray(const Container& src) |
| 44 | + : StaticallyIndexedArray( |
| 45 | + make_from_container(src, typename arithmetic_sequence_gen<0, N, 1>::type{})) |
| 46 | + { |
| 47 | + } |
| 48 | + |
| 49 | + private: |
| 50 | + template <typename Container, index_t... Is> |
| 51 | + __host__ __device__ static constexpr StaticallyIndexedArray |
| 52 | + make_from_container(const Container& src, Sequence<Is...>) |
| 53 | + { |
| 54 | + return StaticallyIndexedArray{static_cast<T>(src[Number<Is>{}])...}; |
| 55 | + } |
| 56 | + |
| 57 | + public: |
| 58 | + __host__ __device__ static constexpr index_t Size() { return N; } |
| 59 | + |
| 60 | + // read access |
| 61 | + template <index_t I> |
| 62 | + __host__ __device__ constexpr const T& At(Number<I>) const |
| 63 | + { |
| 64 | + static_assert(I < N, "wrong! out of range"); |
| 65 | + return data_[I]; |
| 66 | + } |
| 67 | + |
| 68 | + // write access |
| 69 | + template <index_t I> |
| 70 | + __host__ __device__ constexpr T& At(Number<I>) |
| 71 | + { |
| 72 | + static_assert(I < N, "wrong! out of range"); |
| 73 | + return data_[I]; |
| 74 | + } |
| 75 | + |
| 76 | + // read access |
| 77 | + template <index_t I> |
| 78 | + __host__ __device__ constexpr const T& operator[](Number<I> i) const |
| 79 | + { |
| 80 | + return At(i); |
| 81 | + } |
| 82 | + |
| 83 | + // write access |
| 84 | + template <index_t I> |
| 85 | + __host__ __device__ constexpr T& operator()(Number<I> i) |
| 86 | + { |
| 87 | + return At(i); |
| 88 | + } |
| 89 | + |
| 90 | + template <typename U> |
| 91 | + __host__ __device__ constexpr auto operator=(const U& a) |
| 92 | + { |
| 93 | + static_assert(U::Size() == Size(), "wrong! size not the same"); |
| 94 | + static_for<0, Size(), 1>{}([&](auto i) { operator()(i) = a[i]; }); |
| 95 | + return *this; |
| 96 | + } |
| 97 | + |
| 98 | + __host__ __device__ static constexpr bool IsStaticBuffer() { return true; } |
| 99 | + |
| 100 | + T data_[N]; |
36 | 101 | }; |
37 | 102 |
|
| 103 | +// Specialization for empty array |
38 | 104 | template <typename T> |
39 | | -struct StaticallyIndexedArrayImpl<T, 1> |
| 105 | +struct StaticallyIndexedArray<T, 0> |
40 | 106 | { |
41 | | - using type = Tuple<T>; |
42 | | -}; |
43 | | -} // namespace detail |
| 107 | + __host__ __device__ constexpr StaticallyIndexedArray() = default; |
44 | 108 |
|
45 | | -template <typename T, index_t N> |
46 | | -using StaticallyIndexedArray = typename detail::StaticallyIndexedArrayImpl<T, N>::type; |
| 109 | + __host__ __device__ static constexpr index_t Size() { return 0; } |
| 110 | + |
| 111 | + template <typename U> |
| 112 | + __host__ __device__ constexpr auto operator=(const U&) |
| 113 | + { |
| 114 | + return *this; |
| 115 | + } |
| 116 | + |
| 117 | + __host__ __device__ static constexpr bool IsStaticBuffer() { return true; } |
| 118 | +}; |
47 | 119 |
|
48 | 120 | template <typename X, typename... Xs> |
49 | 121 | __host__ __device__ constexpr auto make_statically_indexed_array(const X& x, const Xs&... xs) |
50 | 122 | { |
51 | | - return StaticallyIndexedArray<X, sizeof...(Xs) + 1>(x, static_cast<X>(xs)...); |
| 123 | + return StaticallyIndexedArray<X, sizeof...(Xs) + 1>{x, static_cast<X>(xs)...}; |
52 | 124 | } |
53 | 125 |
|
54 | 126 | // make empty StaticallyIndexedArray |
55 | 127 | template <typename X> |
56 | 128 | __host__ __device__ constexpr auto make_statically_indexed_array() |
57 | 129 | { |
58 | | - return StaticallyIndexedArray<X, 0>(); |
| 130 | + return StaticallyIndexedArray<X, 0>{}; |
59 | 131 | } |
60 | 132 |
|
61 | 133 | template <typename T, index_t N> |
@@ -102,5 +174,78 @@ struct StaticallyIndexedArray_v2 |
102 | 174 | T data_[N]; |
103 | 175 | }; |
104 | 176 |
|
| 177 | +// Concepts for StaticallyIndexedArray arithmetic operators |
| 178 | +template <typename T> |
| 179 | +concept Scalar = ck::is_integral<T>::value || ck::is_floating_point<T>::value; |
| 180 | + |
| 181 | +template <typename T> |
| 182 | +concept IndexedContainer = !Scalar<T> && requires { T::Size(); }; |
| 183 | + |
| 184 | +// Arithmetic operators for StaticallyIndexedArray (to match Tuple operators) |
| 185 | + |
| 186 | +// StaticallyIndexedArray += X |
| 187 | +template <typename T, index_t N, IndexedContainer X> |
| 188 | +__host__ __device__ constexpr auto operator+=(StaticallyIndexedArray<T, N>& y, const X& x) |
| 189 | +{ |
| 190 | + static_assert(X::Size() == N, "wrong! size not the same"); |
| 191 | + static_for<0, N, 1>{}([&](auto i) { y(i) += x[i]; }); |
| 192 | + return y; |
| 193 | +} |
| 194 | + |
| 195 | +// StaticallyIndexedArray -= X |
| 196 | +template <typename T, index_t N, IndexedContainer X> |
| 197 | +__host__ __device__ constexpr auto operator-=(StaticallyIndexedArray<T, N>& y, const X& x) |
| 198 | +{ |
| 199 | + static_assert(X::Size() == N, "wrong! size not the same"); |
| 200 | + static_for<0, N, 1>{}([&](auto i) { y(i) -= x[i]; }); |
| 201 | + return y; |
| 202 | +} |
| 203 | + |
| 204 | +// StaticallyIndexedArray + Y |
| 205 | +template <typename T, index_t N, IndexedContainer Y> |
| 206 | +__host__ __device__ constexpr auto operator+(const StaticallyIndexedArray<T, N>& x, const Y& y) |
| 207 | +{ |
| 208 | + static_assert(Y::Size() == N, "wrong! size not the same"); |
| 209 | + StaticallyIndexedArray<T, N> r; |
| 210 | + static_for<0, N, 1>{}([&](auto i) { r(i) = x[i] + y[i]; }); |
| 211 | + return r; |
| 212 | +} |
| 213 | + |
| 214 | +// StaticallyIndexedArray - Y |
| 215 | +template <typename T, index_t N, IndexedContainer Y> |
| 216 | +__host__ __device__ constexpr auto operator-(const StaticallyIndexedArray<T, N>& x, const Y& y) |
| 217 | +{ |
| 218 | + static_assert(Y::Size() == N, "wrong! size not the same"); |
| 219 | + StaticallyIndexedArray<T, N> r; |
| 220 | + static_for<0, N, 1>{}([&](auto i) { r(i) = x[i] - y[i]; }); |
| 221 | + return r; |
| 222 | +} |
| 223 | + |
| 224 | +// StaticallyIndexedArray * Y (element-wise) |
| 225 | +template <typename T, index_t N, IndexedContainer Y> |
| 226 | +__host__ __device__ constexpr auto operator*(const StaticallyIndexedArray<T, N>& x, const Y& y) |
| 227 | +{ |
| 228 | + static_assert(Y::Size() == N, "wrong! size not the same"); |
| 229 | + StaticallyIndexedArray<T, N> r; |
| 230 | + static_for<0, N, 1>{}([&](auto i) { r(i) = x[i] * y[i]; }); |
| 231 | + return r; |
| 232 | +} |
| 233 | + |
| 234 | +// scalar * StaticallyIndexedArray |
| 235 | +template <typename T, index_t N, Scalar S> |
| 236 | +__host__ __device__ constexpr auto operator*(S a, const StaticallyIndexedArray<T, N>& x) |
| 237 | +{ |
| 238 | + StaticallyIndexedArray<T, N> r; |
| 239 | + static_for<0, N, 1>{}([&](auto i) { r(i) = a * x[i]; }); |
| 240 | + return r; |
| 241 | +} |
| 242 | + |
| 243 | +// StaticallyIndexedArray * scalar |
| 244 | +template <typename T, index_t N, Scalar S> |
| 245 | +__host__ __device__ constexpr auto operator*(const StaticallyIndexedArray<T, N>& x, S a) |
| 246 | +{ |
| 247 | + return a * x; |
| 248 | +} |
| 249 | + |
105 | 250 | } // namespace ck |
106 | 251 | #endif |
0 commit comments