Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ option(BUILD_SHARED_LIBS "Build shared libraries" ON)
option(WITH_TENSOR_PARALLEL "Compile with NCCL and MPI backend" OFF)
option(WITH_FLASH_ATTN "Compile with Flash Attention 2" OFF)
option(ENABLE_ADDRESS_SANITIZER "ASAN" OFF)
option(ENABLE_LORA_RUNTIME "Enable experimental runtime LoRA delta injection (W += scale*B@A in-place)" OFF)

MESSAGE(STATUS "Compiler Id: ${CMAKE_CXX_COMPILER_ID}")
MESSAGE(STATUS "Compiler Version: ${CMAKE_CXX_COMPILER_VERSION}")
Expand All @@ -33,6 +34,11 @@ if(ENABLE_PROFILING)
add_definitions(-DCT2_ENABLE_PROFILING)
endif()

if(ENABLE_LORA_RUNTIME)
message(STATUS "Enable experimental LoRA runtime delta injection")
add_definitions(-DCT2_ENABLE_LORA_RUNTIME)
endif()

if(DEFINED ENV{INTELROOT})
set(INTEL_ROOT_DEFAULT $ENV{INTELROOT})
elseif(DEFINED ENV{ONEAPI_ROOT})
Expand Down
17 changes: 17 additions & 0 deletions include/ctranslate2/models/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,23 @@ namespace ctranslate2 {
std::unordered_map<std::string, StorageView> get_variables() const;
bool layer_exists(std::string prefix) const;

#ifdef CT2_ENABLE_LORA_RUNTIME
// Applies a LoRA adapter delta in-place: W' = W + scale * (lora_B @ lora_A).
//
// This is a low-level primitive — it modifies a single named weight buffer directly.
// Adapter pool management, language switching, and revert logic are entirely the
// caller's responsibility (see tools/lora_swap/ for a reference implementation).
//
// Supported dtypes: float32, float16, bfloat16.
// int8 and packed weights (CT2_USE_EXPERIMENTAL_PACKED_GEMM) raise std::runtime_error.
// NOT thread-safe — caller must hold the inference lock before calling.
// Returns false if the variable name is not found (caller may skip silently).
bool apply_lora_delta(const std::string& name,
const StorageView& lora_A, // shape [r, d_in]
const StorageView& lora_B, // shape [d_out, r]
float scale = 1.0f) const;
#endif // CT2_ENABLE_LORA_RUNTIME

// Attributes are saved as scalar variables.
template <typename T>
T get_attribute(const std::string& name) const {
Expand Down
49 changes: 49 additions & 0 deletions python/cpp/whisper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,22 @@ namespace ctranslate2 {
return _pool->is_multilingual();
}

#ifdef CT2_ENABLE_LORA_RUNTIME
// Low-level primitive: writes W' = W + scale * (lora_B @ lora_A) into a single
// named weight buffer. Adapter pool management and language switching are the
// caller's responsibility. Acquires the inference lock for the duration of the write.
//
// Note: only patches replica 0. For inter_threads > 1, iterate all replicas.
bool apply_lora_delta(const std::string& name,
const StorageView& lora_A,
const StorageView& lora_B,
float scale = 1.0f) {
std::unique_lock<std::shared_mutex> lock(_mutex);
assert_model_is_ready();
return const_cast<models::Model&>(*model()).apply_lora_delta(name, lora_A, lora_B, scale);
}
#endif // CT2_ENABLE_LORA_RUNTIME

size_t n_mels() const {
return _pool->n_mels();
}
Expand Down Expand Up @@ -341,6 +357,39 @@ namespace ctranslate2 {
A list of alignment results.
)pbdoc")

#ifdef CT2_ENABLE_LORA_RUNTIME
.def("apply_lora_delta",
&WhisperWrapper::apply_lora_delta,
py::arg("name"),
py::arg("lora_A"),
py::arg("lora_B"),
py::arg("scale")=1.0f,
R"pbdoc(
[Experimental] Applies a LoRA adapter delta to a single named weight in-place.

This is a low-level primitive: it writes W' = W + scale * (lora_B @ lora_A)
directly into the weight buffer and returns immediately. Adapter pool
management, language switching, and revert logic are entirely the caller's
responsibility.

Supported dtypes: float32, float16, bfloat16.
int8 and packed-GEMM weights raise RuntimeError.

Note: only the first replica is patched (inter_threads=1 assumed).
For multi-replica deployments, iterate replicas manually.

Arguments:
name: CT2 variable name, e.g.
``"decoder/layer_0/self_attention/linear_0/weight"``
lora_A: StorageView of shape [r, d_in] (the A matrix).
lora_B: StorageView of shape [d_out, r] (the B matrix).
scale: Scalar multiplier (alpha/r in standard LoRA notation).

Returns:
True if the variable was found and updated; False otherwise.
)pbdoc")
#endif // CT2_ENABLE_LORA_RUNTIME

.def("unload_model", &WhisperWrapper::unload_model,
py::arg("to_cpu")=false,
py::call_guard<py::gil_scoped_release>(),
Expand Down
42 changes: 42 additions & 0 deletions src/models/model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,48 @@ namespace ctranslate2 {
return false;
}

#ifdef CT2_ENABLE_LORA_RUNTIME
bool Model::apply_lora_delta(const std::string& name,
const StorageView& lora_A,
const StorageView& lora_B,
float scale) const {
// Look up variable. Packed weights require unpack+repack, not yet supported.
auto it = _variable_index.find(name);
if (it == _variable_index.end()) {
const std::string packed_name = name + "_packed";
it = _variable_index.find(packed_name);
if (it == _variable_index.end())
return false;
throw std::runtime_error(
"apply_lora_delta: packed weights are not supported (name=" + name
+ "). Rebuild without CT2_USE_EXPERIMENTAL_PACKED_GEMM.");
}

StorageView& weight = *it->second;
const DataType orig_dtype = weight.dtype();
const Device compute_dev = weight.device();

if (orig_dtype == DataType::INT8)
throw std::runtime_error(
"apply_lora_delta: int8 quantized weights are not supported; "
"use compute_type=float32 or float16.");

// Cast LoRA matrices to float32 on weight's device.
StorageView A_f32 = lora_A.to(DataType::FLOAT32).to(compute_dev);
StorageView B_f32 = lora_B.to(DataType::FLOAT32).to(compute_dev);

// Cast weight to float32, run fused GEMM, write back.
// W_f32 = scale * B @ A + 1.0 * W_f32 (single cuBLAS SGEMM on GPU).
StorageView W_f32 = weight.to(DataType::FLOAT32);
const ops::Gemm gemm_op(/*alpha=*/scale, /*beta=*/1.0f,
/*trans_a=*/false, /*trans_b=*/false);
gemm_op(B_f32, A_f32, W_f32);

weight = W_f32.to(orig_dtype);
return true;
}
#endif // CT2_ENABLE_LORA_RUNTIME

bool Model::get_flag_with_default(const std::string& name, bool default_value) const {
return get_attribute_with_default(name, static_cast<int8_t>(default_value));
}
Expand Down
Loading