Skip to content

Commit 1b33b98

Browse files
tenpercentclaude
andcommitted
Rewrite StaticallyIndexedArray to use C-array instead of Tuple
Replace the recursive template metaprogramming implementation of StaticallyIndexedArray with a simple C-array based struct. This avoids deep template instantiation while maintaining the same interface. Key changes: - StaticallyIndexedArray now stores `T data_[N]` instead of inheriting from Tuple - Added constexpr conversion constructor to convert from any indexed container (Tuple, etc.) - Added arithmetic operators (+, -, *, +=, -=) using C++20 concepts - Added overloads for container_reorder_given_new2old/old2new - Added overloads for get_container_subset and set_container_subset - Specialization for empty array (N=0) Co-Authored-By: Claude <[email protected]>
1 parent 57c8cb1 commit 1b33b98

File tree

2 files changed

+224
-28
lines changed

2 files changed

+224
-28
lines changed

include/ck/utility/container_helper.hpp

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,25 @@ __host__ __device__ constexpr auto container_reorder_given_old2new(const Tuple<T
7676
old_tuple, typename sequence_map_inverse<decltype(old2new)>::type{});
7777
}
7878

79+
template <typename T, index_t N, index_t... IRs>
80+
__host__ __device__ constexpr auto
81+
container_reorder_given_new2old(const StaticallyIndexedArray<T, N>& old_arr,
82+
Sequence<IRs...> /*new2old*/)
83+
{
84+
static_assert(N == sizeof...(IRs), "wrong! size not consistent");
85+
static_assert(is_valid_sequence_map<Sequence<IRs...>>{}, "wrong! invalid reorder map");
86+
return make_statically_indexed_array<T>(old_arr[Number<IRs>{}]...);
87+
}
88+
89+
template <typename T, index_t N, index_t... IRs>
90+
__host__ __device__ constexpr auto
91+
container_reorder_given_old2new(const StaticallyIndexedArray<T, N>& old_arr,
92+
Sequence<IRs...> old2new)
93+
{
94+
return container_reorder_given_new2old(
95+
old_arr, typename sequence_map_inverse<decltype(old2new)>::type{});
96+
}
97+
7998
template <index_t... Is, index_t... IRs>
8099
__host__ __device__ constexpr auto container_reorder_given_new2old(Sequence<Is...> /* old_seq */,
81100
Sequence<IRs...> /*new2old*/)
@@ -358,6 +377,15 @@ __host__ __device__ constexpr auto get_container_subset(const Tuple<Ts...>& tup,
358377
return make_tuple(tup[Number<Is>{}]...);
359378
}
360379

380+
template <typename T, index_t N, index_t... Is>
381+
__host__ __device__ constexpr auto get_container_subset(const StaticallyIndexedArray<T, N>& arr,
382+
Sequence<Is...>)
383+
{
384+
static_assert(N >= sizeof...(Is), "wrong! size");
385+
386+
return StaticallyIndexedArray<T, sizeof...(Is)>{arr[Number<Is>{}]...};
387+
}
388+
361389
template <typename T, index_t N, index_t... Is>
362390
__host__ __device__ constexpr void
363391
set_container_subset(Array<T, N>& y, Sequence<Is...> picks, const Array<T, sizeof...(Is)>& x)
@@ -376,6 +404,29 @@ set_container_subset(Tuple<Ys...>& y, Sequence<Is...> picks, const Tuple<Xs...>&
376404
static_for<0, sizeof...(Is), 1>{}([&](auto i) { y(picks[i]) = x[i]; });
377405
}
378406

407+
template <typename T, index_t N, index_t... Is>
408+
__host__ __device__ constexpr void
409+
set_container_subset(StaticallyIndexedArray<T, N>& y,
410+
Sequence<Is...> picks,
411+
const StaticallyIndexedArray<T, sizeof...(Is)>& x)
412+
{
413+
static_assert(N >= sizeof...(Is), "wrong! size");
414+
415+
static_for<0, sizeof...(Is), 1>{}([&](auto i) { y(picks[i]) = x[i]; });
416+
}
417+
418+
// Generic set_container_subset for StaticallyIndexedArray destination with any indexed source
419+
template <typename T, index_t N, index_t... Is, typename Src>
420+
requires requires { Src::Size(); }
421+
__host__ __device__ constexpr void
422+
set_container_subset(StaticallyIndexedArray<T, N>& y, Sequence<Is...> picks, const Src& x)
423+
{
424+
static_assert(N >= sizeof...(Is), "wrong! size");
425+
static_assert(Src::Size() == sizeof...(Is), "wrong! size mismatch");
426+
427+
static_for<0, sizeof...(Is), 1>{}([&](auto i) { y(picks[i]) = x[i]; });
428+
}
429+
379430
template <index_t... Is>
380431
__host__ __device__ constexpr auto sequence_to_tuple_of_number(Sequence<Is...>)
381432
{

include/ck/utility/statically_indexed_array.hpp

Lines changed: 173 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -10,52 +10,124 @@
1010

1111
namespace ck {
1212

13-
namespace detail {
14-
template <typename X, typename Y>
15-
struct tuple_concat;
16-
17-
template <typename... Xs, typename... Ys>
18-
struct tuple_concat<Tuple<Xs...>, Tuple<Ys...>>
19-
{
20-
using type = Tuple<Xs..., Ys...>;
21-
};
22-
23-
// StaticallyIndexedArrayImpl uses binary split for O(log N) depth
13+
// StaticallyIndexedArray using simple C-array instead of template metaprogramming
14+
// This avoids deep template instantiation while maintaining the same interface
2415
template <typename T, index_t N>
25-
struct StaticallyIndexedArrayImpl
16+
struct StaticallyIndexedArray
2617
{
27-
using type =
28-
typename tuple_concat<typename StaticallyIndexedArrayImpl<T, N / 2>::type,
29-
typename StaticallyIndexedArrayImpl<T, N - N / 2>::type>::type;
30-
};
18+
__host__ __device__ constexpr StaticallyIndexedArray() : data_{} {}
19+
20+
// Single-element constructor - exclude containers with matching size (to prefer conversion
21+
// constructor)
22+
template <typename X>
23+
requires(N == 1 &&
24+
// Allow if X is same type as T or doesn't have Size() method
25+
(is_same<remove_cvref_t<X>, T>::value || !requires { remove_cvref_t<X>::Size(); }))
26+
__host__ __device__ constexpr StaticallyIndexedArray(X&& x)
27+
: data_{static_cast<T>(ck::forward<X>(x))}
28+
{
29+
}
3130

32-
template <typename T>
33-
struct StaticallyIndexedArrayImpl<T, 0>
34-
{
35-
using type = Tuple<>;
31+
// Multi-element constructor
32+
template <typename... Xs>
33+
requires(sizeof...(Xs) == N && N > 1)
34+
__host__ __device__ constexpr StaticallyIndexedArray(Xs&&... xs)
35+
: data_{static_cast<T>(ck::forward<Xs>(xs))...}
36+
{
37+
}
38+
39+
// Conversion constructor from any indexed container (Tuple, etc.)
40+
template <typename Container>
41+
requires(!is_same<remove_cvref_t<Container>, StaticallyIndexedArray>::value &&
42+
requires { Container::Size(); } && Container::Size() == N)
43+
__host__ __device__ constexpr StaticallyIndexedArray(const Container& src)
44+
: StaticallyIndexedArray(
45+
make_from_container(src, typename arithmetic_sequence_gen<0, N, 1>::type{}))
46+
{
47+
}
48+
49+
private:
50+
template <typename Container, index_t... Is>
51+
__host__ __device__ static constexpr StaticallyIndexedArray
52+
make_from_container(const Container& src, Sequence<Is...>)
53+
{
54+
return StaticallyIndexedArray{static_cast<T>(src[Number<Is>{}])...};
55+
}
56+
57+
public:
58+
__host__ __device__ static constexpr index_t Size() { return N; }
59+
60+
// read access
61+
template <index_t I>
62+
__host__ __device__ constexpr const T& At(Number<I>) const
63+
{
64+
static_assert(I < N, "wrong! out of range");
65+
return data_[I];
66+
}
67+
68+
// write access
69+
template <index_t I>
70+
__host__ __device__ constexpr T& At(Number<I>)
71+
{
72+
static_assert(I < N, "wrong! out of range");
73+
return data_[I];
74+
}
75+
76+
// read access
77+
template <index_t I>
78+
__host__ __device__ constexpr const T& operator[](Number<I> i) const
79+
{
80+
return At(i);
81+
}
82+
83+
// write access
84+
template <index_t I>
85+
__host__ __device__ constexpr T& operator()(Number<I> i)
86+
{
87+
return At(i);
88+
}
89+
90+
template <typename U>
91+
__host__ __device__ constexpr auto operator=(const U& a)
92+
{
93+
static_assert(U::Size() == Size(), "wrong! size not the same");
94+
static_for<0, Size(), 1>{}([&](auto i) { operator()(i) = a[i]; });
95+
return *this;
96+
}
97+
98+
__host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
99+
100+
T data_[N];
36101
};
37102

103+
// Specialization for empty array
38104
template <typename T>
39-
struct StaticallyIndexedArrayImpl<T, 1>
105+
struct StaticallyIndexedArray<T, 0>
40106
{
41-
using type = Tuple<T>;
42-
};
43-
} // namespace detail
107+
__host__ __device__ constexpr StaticallyIndexedArray() = default;
44108

45-
template <typename T, index_t N>
46-
using StaticallyIndexedArray = typename detail::StaticallyIndexedArrayImpl<T, N>::type;
109+
__host__ __device__ static constexpr index_t Size() { return 0; }
110+
111+
template <typename U>
112+
__host__ __device__ constexpr auto operator=(const U&)
113+
{
114+
return *this;
115+
}
116+
117+
__host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
118+
};
47119

48120
template <typename X, typename... Xs>
49121
__host__ __device__ constexpr auto make_statically_indexed_array(const X& x, const Xs&... xs)
50122
{
51-
return StaticallyIndexedArray<X, sizeof...(Xs) + 1>(x, static_cast<X>(xs)...);
123+
return StaticallyIndexedArray<X, sizeof...(Xs) + 1>{x, static_cast<X>(xs)...};
52124
}
53125

54126
// make empty StaticallyIndexedArray
55127
template <typename X>
56128
__host__ __device__ constexpr auto make_statically_indexed_array()
57129
{
58-
return StaticallyIndexedArray<X, 0>();
130+
return StaticallyIndexedArray<X, 0>{};
59131
}
60132

61133
template <typename T, index_t N>
@@ -102,5 +174,78 @@ struct StaticallyIndexedArray_v2
102174
T data_[N];
103175
};
104176

177+
// Concepts for StaticallyIndexedArray arithmetic operators
178+
template <typename T>
179+
concept Scalar = ck::is_integral<T>::value || ck::is_floating_point<T>::value;
180+
181+
template <typename T>
182+
concept IndexedContainer = !Scalar<T> && requires { T::Size(); };
183+
184+
// Arithmetic operators for StaticallyIndexedArray (to match Tuple operators)
185+
186+
// StaticallyIndexedArray += X
187+
template <typename T, index_t N, IndexedContainer X>
188+
__host__ __device__ constexpr auto operator+=(StaticallyIndexedArray<T, N>& y, const X& x)
189+
{
190+
static_assert(X::Size() == N, "wrong! size not the same");
191+
static_for<0, N, 1>{}([&](auto i) { y(i) += x[i]; });
192+
return y;
193+
}
194+
195+
// StaticallyIndexedArray -= X
196+
template <typename T, index_t N, IndexedContainer X>
197+
__host__ __device__ constexpr auto operator-=(StaticallyIndexedArray<T, N>& y, const X& x)
198+
{
199+
static_assert(X::Size() == N, "wrong! size not the same");
200+
static_for<0, N, 1>{}([&](auto i) { y(i) -= x[i]; });
201+
return y;
202+
}
203+
204+
// StaticallyIndexedArray + Y
205+
template <typename T, index_t N, IndexedContainer Y>
206+
__host__ __device__ constexpr auto operator+(const StaticallyIndexedArray<T, N>& x, const Y& y)
207+
{
208+
static_assert(Y::Size() == N, "wrong! size not the same");
209+
StaticallyIndexedArray<T, N> r;
210+
static_for<0, N, 1>{}([&](auto i) { r(i) = x[i] + y[i]; });
211+
return r;
212+
}
213+
214+
// StaticallyIndexedArray - Y
215+
template <typename T, index_t N, IndexedContainer Y>
216+
__host__ __device__ constexpr auto operator-(const StaticallyIndexedArray<T, N>& x, const Y& y)
217+
{
218+
static_assert(Y::Size() == N, "wrong! size not the same");
219+
StaticallyIndexedArray<T, N> r;
220+
static_for<0, N, 1>{}([&](auto i) { r(i) = x[i] - y[i]; });
221+
return r;
222+
}
223+
224+
// StaticallyIndexedArray * Y (element-wise)
225+
template <typename T, index_t N, IndexedContainer Y>
226+
__host__ __device__ constexpr auto operator*(const StaticallyIndexedArray<T, N>& x, const Y& y)
227+
{
228+
static_assert(Y::Size() == N, "wrong! size not the same");
229+
StaticallyIndexedArray<T, N> r;
230+
static_for<0, N, 1>{}([&](auto i) { r(i) = x[i] * y[i]; });
231+
return r;
232+
}
233+
234+
// scalar * StaticallyIndexedArray
235+
template <typename T, index_t N, Scalar S>
236+
__host__ __device__ constexpr auto operator*(S a, const StaticallyIndexedArray<T, N>& x)
237+
{
238+
StaticallyIndexedArray<T, N> r;
239+
static_for<0, N, 1>{}([&](auto i) { r(i) = a * x[i]; });
240+
return r;
241+
}
242+
243+
// StaticallyIndexedArray * scalar
244+
template <typename T, index_t N, Scalar S>
245+
__host__ __device__ constexpr auto operator*(const StaticallyIndexedArray<T, N>& x, S a)
246+
{
247+
return a * x;
248+
}
249+
105250
} // namespace ck
106251
#endif

0 commit comments

Comments
 (0)