// Review notes for this patch (text placed before the first "diff --git" header).
//
// Format: this is a unified git diff touching include/infinicore/nn/rope.hpp and
// src/infinicore/nn/rope.cc. Its line breaks have been collapsed into spaces, and text
// inside angle brackets appears to have been stripped (bare "#include", "std::shared_ptr
// scaling", "static_cast(dtype_)") -- recover the original patch before applying it.
//
// What the patch does:
//   * rope.hpp: re-documents rotary_dim_ as the number of dimensions that are rotated and
//     adds a head_dim_ member documented as the per-head dimension.
//   * RoPE::RoPE: stores the head_dim argument in head_dim_ and, after the existing
//     even-ness check, asserts 0 < rotary_dim <= head_dim_.
//   * RoPE::forward(x, pos, in_place): y is either Tensor(x) (in_place) or a
//     Tensor::empty(...) built from x's shape/dtype/device; in the latter case x is copied
//     into y first when rotary_dim_ < head_dim_. op::rope_ is then called on
//     narrow({{ndim - 1, 0, rotary_dim_}}) of both y and x instead of on the full tensors.
//     (Presumably the triple is {dim, start, length}, i.e. the first rotary_dim_ entries of
//     the last axis -- TODO confirm against the narrow() declaration.)
//   * RoPE::forward(y, x, pos): the same narrowed op::rope_ call, then returns y.
//   * RoPE::extra_repr: additionally prints head_dim.
//
// NOTE(review): the range check uses assert(), while the even-ness check directly above it
//   throws std::invalid_argument. assert() expands to nothing when NDEBUG is defined, so an
//   out-of-range rotary_dim would not be rejected in such builds; consider throwing
//   std::invalid_argument here as well.
// NOTE(review): forward(y, x, pos) writes only the narrowed slice of y. Unlike the
//   three-argument overload above it, it never copies x into y when rotary_dim_ < head_dim_,
//   so entries of y past rotary_dim_ keep whatever y already held. Confirm that callers
//   pre-fill y (or pass y aliasing x), or add the copy here.
// NOTE(review): ndim is size_t, so "ndim - 1" wraps around if x->ndim() returns 0; nothing
//   in the visible code guards against that.
// NOTE(review): head_dim_ is only compared with rotary_dim in the constructor; the visible
//   forward() code does not check it against the last dimension of x.
// NOTE(review): the narrow + op::rope_ sequence is duplicated in both forward overloads; the
//   first overload could delegate to forward(y, x, pos) once y is prepared.
// NOTE(review): "Tensor y;" requires Tensor to be default-constructible -- confirm.
diff --git a/include/infinicore/nn/rope.hpp b/include/infinicore/nn/rope.hpp index 4385d3a64..020257f0d 100644 --- a/include/infinicore/nn/rope.hpp +++ b/include/infinicore/nn/rope.hpp @@ -96,7 +96,8 @@ class RoPE : public Module { private: void initialize_cache(); - size_t rotary_dim_; // Dimension of each attention head + size_t rotary_dim_; // Number of dimensions to apply rotation to (must be even). + size_t head_dim_; // Dimension of each attention head size_t max_seq_len_; // Maximum sequence length double theta_; // Base frequency for rotary embeddings Algo algo_; // RoPE algorithm type diff --git a/src/infinicore/nn/rope.cc b/src/infinicore/nn/rope.cc index e60091f12..831347674 100644 --- a/src/infinicore/nn/rope.cc +++ b/src/infinicore/nn/rope.cc @@ -3,6 +3,7 @@ #include "../utils.hpp" #include "infinicore/ops.hpp" #include +#include #include #include #include @@ -20,6 +21,7 @@ RoPE::RoPE(size_t head_dim, const Device &device, std::shared_ptr scaling) : rotary_dim_(rotary_dim), + head_dim_(head_dim), max_seq_len_(max_seq_len), theta_(theta), algo_(algo), @@ -29,7 +31,7 @@ RoPE::RoPE(size_t head_dim, if (rotary_dim % 2 != 0) { throw std::invalid_argument("rotary_dim must be even for RoPE, got " + std::to_string(rotary_dim)); } - + assert((rotary_dim > 0) && (rotary_dim <= head_dim_)); device_ = device; // Initialize cache tables @@ -120,23 +122,34 @@ void RoPE::initialize_cache() { } Tensor RoPE::forward(const Tensor &x, const Tensor &pos, bool in_place) const { + Tensor y; if (in_place) { - Tensor y = Tensor(x); - op::rope_(y, x, pos, sin_cache_, cos_cache_, algo_); - return y; + y = Tensor(x); + } else { + y = Tensor::empty(x->shape(), x->dtype(), x->device()); + if (rotary_dim_ < head_dim_) { + y->copy_from(x); + } } - return op::rope(x, pos, sin_cache_, cos_cache_, algo_); + size_t ndim = x->ndim(); + op::rope_(y->narrow({{ndim - 1, 0, rotary_dim_}}), + x->narrow({{ndim - 1, 0, rotary_dim_}}), + pos, sin_cache_, cos_cache_, algo_); + return y; } Tensor 
RoPE::forward(const Tensor &y, const Tensor &x, const Tensor &pos) const { - op::rope_(y, x, pos, sin_cache_, cos_cache_, algo_); + size_t ndim = x->ndim(); + op::rope_(y->narrow({{ndim - 1, 0, rotary_dim_}}), + x->narrow({{ndim - 1, 0, rotary_dim_}}), + pos, sin_cache_, cos_cache_, algo_); return y; } std::string RoPE::extra_repr() const { std::string algo_str = (algo_ == Algo::GPT_J) ? "GPT_J" : "GPT_NEOX"; - return "RoPE(rotary_dim=" + std::to_string(rotary_dim_) + ", max_seq_len=" + std::to_string(max_seq_len_) + ", theta=" + std::to_string(theta_) + ", algo=" + algo_str + ", dtype=" + std::to_string(static_cast(dtype_)) + ")"; + return "RoPE(head_dim=" + std::to_string(head_dim_) + ", rotary_dim=" + std::to_string(rotary_dim_) + ", max_seq_len=" + std::to_string(max_seq_len_) + ", theta=" + std::to_string(theta_) + ", algo=" + algo_str + ", dtype=" + std::to_string(static_cast(dtype_)) + ")"; } } // namespace infinicore::nn