Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 24 additions & 60 deletions include/infinicore/nn/rope.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
#include "../context/context.hpp"
#include "../tensor.hpp"
#include "module.hpp"
#include <memory>
#include "rope_scaling_configs.hpp"
#include <cmath>
#include <memory>

namespace infinicore::nn {

Expand All @@ -18,70 +19,33 @@ class RoPE : public Module {
GPT_NEOX = 1, // GPT-NeoX style RoPE algorithm (First half dimensions for sin, second half for cos)
};

enum class ScalingType {
DEFAULT = 0, // Default RoPE
LONGROPE = 1 // Long-RoPE
};

class ScalingConfig {
public:
virtual ~ScalingConfig() = default;
ScalingType type() const { return type_; }

protected:
ScalingType type_ = ScalingType::DEFAULT;
ScalingConfig(ScalingType type) : type_(type) {}
};

// longrope scaling
class LongRopeConfig : public ScalingConfig {
protected:
std::vector<float> short_factor_;
std::vector<float> long_factor_;
size_t original_max_position_embeddings_;
float factor_;

public:
LongRopeConfig(
std::vector<float> short_factor,
std::vector<float> long_factor,
size_t original_max_position_embeddings,
float factor = 1.0f)
: ScalingConfig(ScalingType::LONGROPE),
short_factor_(short_factor),
long_factor_(long_factor),
original_max_position_embeddings_(original_max_position_embeddings),
factor_(factor == 1.0f ? 1.0f : std::sqrt(1 + std::log(factor) / std::log(original_max_position_embeddings))) {}
~LongRopeConfig() override = default;
size_t original_max_position_embeddings() const { return original_max_position_embeddings_; }
const std::vector<float> &short_factor() const { return short_factor_; }
const std::vector<float> &long_factor() const { return long_factor_; }
float factor() const { return factor_; }
};

/**
* @brief Construct a RoPE layer
*
* @param head_dim Dimension of each attention head (must be even)
* @param rotary_dim Number of dimensions to apply rotation to (must be even).
* For full rotation models, this equals head_dim;
* for partial rotation models, this equals head_dim * partial_rotary_factor.
* @param max_seq_len Maximum sequence length for pre-computed cache
* @param theta Base frequency for rotary embeddings (default: 10000.0)
* @param algo RoPE algorithm type (default: Algo::GPT_J)
* @param dtype Data type for sin/cos cache (default: DataType::F32)
* @param device Device to create the cache on
* @param scaling RoPE scaling type (default: nullptr)
* @param scaling RoPE scaling configuration (default: nullptr)
*/
RoPE(size_t head_dim,
size_t rotary_dim,
Comment thread
rubik-hua marked this conversation as resolved.
size_t max_seq_len,
double theta = 10000.0,
Algo algo = Algo::GPT_J,
const DataType &dtype = DataType::F32,
const Device &device = Device(),
std::shared_ptr<ScalingConfig> scaling = nullptr);
std::shared_ptr<RopeScalingConfig> scaling = nullptr);

/**
* @brief Forward pass: apply RoPE to a tensor
*
* @param x Input tensor of shape (..., head_dim) where ... is any number of dimensions
* @param x Input tensor of shape (..., rotary_dim) where ... is any number of dimensions
* @param pos Position IDs tensor of shape (*,) typically [seq_len] or [batch, seq_len]
* @param in_place If true, modify input tensor in place (default: false)
* @return Rotated tensor with same shape as input
Expand All @@ -90,32 +54,32 @@ class RoPE : public Module {
* For attention mechanisms, call this method separately for query and key tensors.
*
* Common input shapes:
* - [batch, num_heads, seq_len, head_dim]
* - [batch, seq_len, num_heads, head_dim]
* - [seq_len, head_dim]
* - [batch, num_heads, seq_len, rotary_dim]
* - [batch, seq_len, num_heads, rotary_dim]
* - [seq_len, rotary_dim]
*/
Tensor forward(const Tensor &x, const Tensor &pos, bool in_place = false) const;

/**
* @brief Forward pass: apply RoPE to a tensor in place
*
* @param y Output tensor of shape (..., head_dim) where ... is any number of dimensions
* @param x Input tensor of shape (..., head_dim) where ... is any number of dimensions
* @param y Output tensor of shape (..., rotary_dim) where ... is any number of dimensions
* @param x Input tensor of shape (..., rotary_dim) where ... is any number of dimensions
* @param pos Position IDs tensor of shape (*,) typically [seq_len] or [batch, seq_len]
* @return Rotated tensor with same shape as input
*
* Applies rotary position embeddings to the input tensor.
* For attention mechanisms, call this method separately for query and key tensors.
*
* Common input shapes:
* - [batch, num_heads, seq_len, head_dim]
* - [batch, seq_len, num_heads, head_dim]
* - [seq_len, head_dim]
* - [batch, num_heads, seq_len, rotary_dim]
* - [batch, seq_len, num_heads, rotary_dim]
* - [seq_len, rotary_dim]
*/
Tensor forward(const Tensor &y, const Tensor &x, const Tensor &pos) const;

// Module information
size_t head_dim() const { return head_dim_; }
size_t rotary_dim() const { return rotary_dim_; }
size_t max_seq_len() const { return max_seq_len_; }
double theta() const { return theta_; }
Algo algo() const { return algo_; }
Expand All @@ -132,12 +96,12 @@ class RoPE : public Module {
private:
void initialize_cache();

size_t head_dim_; // Dimension of each attention head
size_t max_seq_len_; // Maximum sequence length
double theta_; // Base frequency for rotary embeddings
Algo algo_; // RoPE algorithm type
DataType dtype_; // Data type for cache tables
std::shared_ptr<ScalingConfig> scaling_; // RoPE scaling type
size_t rotary_dim_; // Number of dimensions rotation is applied to (even); sin/cos caches hold rotary_dim_ / 2 columns
size_t max_seq_len_; // Maximum sequence length
double theta_; // Base frequency for rotary embeddings
Algo algo_; // RoPE algorithm type
DataType dtype_; // Data type for cache tables
std::shared_ptr<RopeScalingConfig> scaling_; // RoPE scaling configuration
};

} // namespace infinicore::nn
90 changes: 90 additions & 0 deletions include/infinicore/nn/rope_scaling_configs.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
#pragma once
#include <memory>
#include <vector>

namespace infinicore::nn {

/**
 * @brief Base class for RoPE scaling strategies.
 *
 * The RoPE cache builder calls these two hooks for every (position, dimension) entry,
 * so a new scaling scheme is added by overriding them instead of branching on a type
 * tag inside the cache loop.
 *
 * The class has no pure virtual functions: the base implementation is the identity
 * (both hooks return 1.0), which corresponds to plain, unscaled RoPE.
 */
class RopeScalingConfig {
public:
    virtual ~RopeScalingConfig() = default;

    /**
     * @brief Frequency scaling factor for one cache entry.
     *
     * The returned value is multiplied into the rotation angle:
     * angle = pos * base_inv_freq * get_freq_scale(...).
     *
     * @param pos Sequence position of the cache row
     * @param dim_idx Frequency index of the cache column (0 to rotary_dim/2 - 1)
     * @param base_inv_freq Pre-computed base inverse frequency for this column
     *                      (1.0 / theta^(2 * dim_idx / rotary_dim))
     * @return Frequency scaling factor (1.0 in the base class)
     */
    virtual float get_freq_scale([[maybe_unused]] size_t pos,
                                 [[maybe_unused]] size_t dim_idx,
                                 [[maybe_unused]] float base_inv_freq) const {
        // Parameters are named for the documentation of overriders; the identity scaling ignores them.
        return 1.0f;
    }

    /**
     * @brief Magnitude scaling factor for one cache entry.
     *
     * The returned value multiplies both the sin and the cos value written to the cache.
     *
     * @param pos Sequence position of the cache row
     * @param dim_idx Frequency index of the cache column (0 to rotary_dim/2 - 1)
     * @param base_inv_freq Pre-computed base inverse frequency for this column
     * @return Magnitude scaling factor (1.0 in the base class)
     */
    virtual float get_magnitude_scale([[maybe_unused]] size_t pos,
                                      [[maybe_unused]] size_t dim_idx,
                                      [[maybe_unused]] float base_inv_freq) const {
        return 1.0f;
    }
};

/**
 * @brief LongRoPE scaling configuration.
 *
 * Holds two per-dimension factor tables (short and long), the original maximum
 * position count that separates them, and a single magnitude factor. The hook
 * overrides are defined out of line in the corresponding source file.
 */
class LongRopeScalingConfig : public RopeScalingConfig {
public:
    /**
     * @param short_factor Per-dimension factors, indexed by dim_idx
     * @param long_factor Per-dimension factors, indexed by dim_idx
     * @param original_max_position_embeddings Position threshold between the two tables
     * @param factor Context extension factor used to derive the stored magnitude factor (default: 1.0)
     */
    LongRopeScalingConfig(
        std::vector<float> short_factor,
        std::vector<float> long_factor,
        size_t original_max_position_embeddings,
        float factor = 1.0f);

    // Read-only access to the stored configuration.
    const std::vector<float> &short_factor() const { return short_factor_; }
    const std::vector<float> &long_factor() const { return long_factor_; }
    size_t original_max_position_embeddings() const { return original_max_position_embeddings_; }
    // Returns the value stored by the constructor, which is derived from the
    // `factor` argument rather than being the raw argument itself.
    float factor() const { return factor_; }

    // RopeScalingConfig hooks.
    float get_freq_scale(size_t pos, size_t dim_idx, float base_inv_freq) const override;
    float get_magnitude_scale(size_t pos, size_t dim_idx, float base_inv_freq) const override;

private:
    std::vector<float> short_factor_;
    std::vector<float> long_factor_;
    size_t original_max_position_embeddings_;
    float factor_;
};

// TODO(rubik) implement in cpp
/**
 * @brief Llama3 frequency-aware RoPE scaling configuration.
 *
 * Intended for Llama 3.1 style RoPE scaling (smooth interpolation based on wavelength).
 * Only the frequency hook is overridden; the magnitude hook is inherited from
 * RopeScalingConfig and therefore returns 1.0f.
 */
class Llama3RopeScalingConfig : public RopeScalingConfig {
public:
    /**
     * @param factor Overall scaling factor
     * @param low_freq_factor Low-frequency band parameter
     * @param high_freq_factor High-frequency band parameter
     * @param original_max_position_embeddings Context length of the original model
     */
    Llama3RopeScalingConfig(float factor,
                            float low_freq_factor,
                            float high_freq_factor,
                            size_t original_max_position_embeddings);

    // Frequency hook; defined out of line in the source file.
    float get_freq_scale(size_t pos, size_t dim_idx, float base_inv_freq) const override;

private:
    float factor_;
    float low_freq_factor_;
    float high_freq_factor_;
    size_t original_max_position_embeddings_;
};

} // namespace infinicore::nn
57 changes: 23 additions & 34 deletions src/infinicore/nn/rope.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,22 @@
namespace infinicore::nn {

RoPE::RoPE(size_t head_dim,
size_t rotary_dim,
size_t max_seq_len,
double theta,
Algo algo,
const DataType &dtype,
const Device &device,
std::shared_ptr<ScalingConfig> scaling)
: head_dim_(head_dim),
std::shared_ptr<RopeScalingConfig> scaling)
: rotary_dim_(rotary_dim),
max_seq_len_(max_seq_len),
theta_(theta),
algo_(algo),
dtype_(dtype),
scaling_(scaling) {
if (head_dim % 2 != 0) {
throw std::invalid_argument("head_dim must be even for RoPE, got " + std::to_string(head_dim));
// TODO use head_dim
if (rotary_dim % 2 != 0) {
throw std::invalid_argument("rotary_dim must be even for RoPE, got " + std::to_string(rotary_dim));
}

device_ = device;
Expand All @@ -35,14 +37,14 @@ RoPE::RoPE(size_t head_dim,
}

void RoPE::initialize_cache() {
size_t cache_dim = head_dim_ / 2;
size_t cache_dim = rotary_dim_ / 2;

// Create sin and cos cache tables: [max_seq_len, cache_dim]
INFINICORE_NN_BUFFER_INIT(sin_cache, ({max_seq_len_, cache_dim}, dtype_, device_));
INFINICORE_NN_BUFFER_INIT(cos_cache, ({max_seq_len_, cache_dim}, dtype_, device_));

// Pre-compute sin and cos values
// Frequency generation always uses GPT-J style (theta^(-2j/head_dim)).
// Frequency generation always uses GPT-J style (theta^(-2j/rotary_dim)).
// The rotation algorithm (algo_) controls how dimensions are paired in the kernel.

// Compute on CPU first, then copy to device
Expand All @@ -53,33 +55,20 @@ void RoPE::initialize_cache() {
std::vector<float> cos_data(max_seq_len_ * cache_dim);

for (size_t pos = 0; pos < max_seq_len_; pos++) {
for (size_t j = 0; j < cache_dim; j++) {
// GPT-J style inverse frequency: theta^(-2j/head_dim)
// Compute directly in float to avoid double->float casting
float inv_freq;
float table_factor = 1.0f;
if (scaling_ == nullptr) {
inv_freq = 1.0f / std::pow(static_cast<float>(theta_), 2.0f * static_cast<float>(j) / static_cast<float>(head_dim_));
} else if (scaling_->type() == ScalingType::LONGROPE) {
std::shared_ptr<LongRopeConfig> lr = std::dynamic_pointer_cast<LongRopeConfig>(scaling_);
table_factor = lr->factor();
float _ext;
if (pos < lr->original_max_position_embeddings()) {
_ext = lr->short_factor()[j];
} else {
_ext = lr->long_factor()[j];
}
inv_freq = 1.0f / (_ext * std::pow(static_cast<float>(theta_), 2.0f * static_cast<float>(j) / static_cast<float>(head_dim_)));
} else {
inv_freq = 1.0f / std::pow(static_cast<float>(theta_), 2.0f * static_cast<float>(j) / static_cast<float>(head_dim_));
}

// Compute angle: position * inverse_frequency
float angle = static_cast<float>(pos) * inv_freq;

// Compute sin and cos directly on float
sin_data[pos * cache_dim + j] = std::sin(angle) * table_factor;
cos_data[pos * cache_dim + j] = std::cos(angle) * table_factor;
for (size_t dim_idx = 0; dim_idx < cache_dim; dim_idx++) {
// 1. Base inverse frequency (shared across all RoPE types)
Comment thread
rubik-hua marked this conversation as resolved.
float base_inv_freq = 1.0f / std::pow(static_cast<float>(theta_), 2.0f * static_cast<float>(dim_idx) / static_cast<float>(rotary_dim_));

// 2. Polymorphic scaling resolution
// Passing pre-computed base_inv_freq avoids redundant pow() calculations in subclasses
float freq_scale = scaling_ ? scaling_->get_freq_scale(pos, dim_idx, base_inv_freq) : 1.0f;
float mag_scale = scaling_ ? scaling_->get_magnitude_scale(pos, dim_idx, base_inv_freq) : 1.0f;

// 3. Compute final angle and sin/cos values
float angle = static_cast<float>(pos) * base_inv_freq * freq_scale;

sin_data[pos * cache_dim + dim_idx] = std::sin(angle) * mag_scale;
cos_data[pos * cache_dim + dim_idx] = std::cos(angle) * mag_scale;
}
}

Expand Down Expand Up @@ -147,7 +136,7 @@ Tensor RoPE::forward(const Tensor &y, const Tensor &x, const Tensor &pos) const

std::string RoPE::extra_repr() const {
std::string algo_str = (algo_ == Algo::GPT_J) ? "GPT_J" : "GPT_NEOX";
return "RoPE(head_dim=" + std::to_string(head_dim_) + ", max_seq_len=" + std::to_string(max_seq_len_) + ", theta=" + std::to_string(theta_) + ", algo=" + algo_str + ", dtype=" + std::to_string(static_cast<int>(dtype_)) + ")";
return "RoPE(rotary_dim=" + std::to_string(rotary_dim_) + ", max_seq_len=" + std::to_string(max_seq_len_) + ", theta=" + std::to_string(theta_) + ", algo=" + algo_str + ", dtype=" + std::to_string(static_cast<int>(dtype_)) + ")";
}

} // namespace infinicore::nn
46 changes: 46 additions & 0 deletions src/infinicore/nn/rope_scaling_configs.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#include "infinicore/nn/rope_scaling_configs.hpp"
#include <cmath>
#include <stdexcept>

namespace infinicore::nn {

// LongRopeScalingConfig Implementation
/**
 * @brief Store the LongRoPE factor tables and derive the magnitude factor.
 *
 * The stored magnitude factor is
 *   1.0                                                         if factor <= 1
 *   sqrt(1 + log(factor) / log(original_max_position_embeddings)) otherwise.
 *
 * A factor at or below 1 means the context is not being extended, so no magnitude
 * correction is applied. Comparing with `<=` (rather than `== 1`) also keeps
 * fractional factors from yielding a magnitude below 1 or a NaN from a negative
 * sqrt argument.
 *
 * @param short_factor Per-dimension factors used below the original context length
 * @param long_factor Per-dimension factors used at or beyond the original context length
 * @param original_max_position_embeddings Position threshold between the two tables
 * @param factor Context extension factor
 */
LongRopeScalingConfig::LongRopeScalingConfig(
    std::vector<float> short_factor,
    std::vector<float> long_factor,
    size_t original_max_position_embeddings,
    float factor)
    : short_factor_(std::move(short_factor)),
      long_factor_(std::move(long_factor)),
      original_max_position_embeddings_(original_max_position_embeddings),
      factor_(factor <= 1.0f
                  ? 1.0f
                  : std::sqrt(1 + std::log(factor) / std::log(static_cast<double>(original_max_position_embeddings)))) {}

/**
 * @brief Frequency scale for LongRoPE: the reciprocal of the per-dimension factor.
 *
 * Positions below original_max_position_embeddings use the short table, all
 * other positions use the long table. The caller multiplies the base inverse
 * frequency by the returned value, which is equivalent to
 * inv_freq = 1 / (ext * theta^(2 * dim_idx / rotary_dim)).
 *
 * @param pos Sequence position
 * @param dim_idx Index into the selected factor table
 * @return 1 / ext for the selected table entry
 * @throws std::out_of_range if dim_idx is not a valid index into the selected table
 */
float LongRopeScalingConfig::get_freq_scale(size_t pos, size_t dim_idx, float /*base_inv_freq*/) const {
    const std::vector<float> &factors = (pos < original_max_position_embeddings_) ? short_factor_ : long_factor_;
    // at() rather than operator[]: a factor table shorter than the cache width
    // would otherwise be read out of bounds (undefined behaviour).
    const float ext = factors.at(dim_idx);
    return 1.0f / ext;
}

/**
 * @brief Magnitude scale for LongRoPE.
 *
 * The value is the factor stored at construction; it does not depend on the
 * position, the dimension index, or the base inverse frequency.
 */
float LongRopeScalingConfig::get_magnitude_scale(size_t /*pos*/, size_t /*dim_idx*/, float /*base_inv_freq*/) const {
    return factor_;
}

// TODO(rubik) llama3 implement here
// Llama3RopeScalingConfig Implementation
Llama3RopeScalingConfig::Llama3RopeScalingConfig(
float factor,
float low_freq_factor,
float high_freq_factor,
size_t original_max_position_embeddings)
: factor_(factor),
low_freq_factor_(low_freq_factor),
high_freq_factor_(high_freq_factor),
original_max_position_embeddings_(original_max_position_embeddings) {}

float Llama3RopeScalingConfig::get_freq_scale(size_t pos, size_t dim_idx, float base_inv_freq) const {
return 1.0f;
}

} // namespace infinicore::nn
Loading