diff --git a/CMakeLists.txt b/CMakeLists.txt index cf80e37b5..31c95b2c9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -24,6 +24,7 @@ option(BUILD_SHARED_LIBS "Build shared libraries" ON) option(WITH_TENSOR_PARALLEL "Compile with NCCL and MPI backend" OFF) option(WITH_FLASH_ATTN "Compile with Flash Attention 2" OFF) option(ENABLE_ADDRESS_SANITIZER "ASAN" OFF) +option(ENABLE_LORA_RUNTIME "Enable experimental runtime LoRA delta injection (W += scale*B@A in-place)" OFF) MESSAGE(STATUS "Compiler Id: ${CMAKE_CXX_COMPILER_ID}") MESSAGE(STATUS "Compiler Version: ${CMAKE_CXX_COMPILER_VERSION}") @@ -33,6 +34,11 @@ if(ENABLE_PROFILING) add_definitions(-DCT2_ENABLE_PROFILING) endif() +if(ENABLE_LORA_RUNTIME) + message(STATUS "Enable experimental LoRA runtime delta injection") + add_definitions(-DCT2_ENABLE_LORA_RUNTIME) +endif() + if(DEFINED ENV{INTELROOT}) set(INTEL_ROOT_DEFAULT $ENV{INTELROOT}) elseif(DEFINED ENV{ONEAPI_ROOT}) diff --git a/include/ctranslate2/models/model.h b/include/ctranslate2/models/model.h index 6f1d5c760..9a6c33220 100644 --- a/include/ctranslate2/models/model.h +++ b/include/ctranslate2/models/model.h @@ -123,6 +123,23 @@ namespace ctranslate2 { std::unordered_map get_variables() const; bool layer_exists(std::string prefix) const; +#ifdef CT2_ENABLE_LORA_RUNTIME + // Applies a LoRA adapter delta in-place: W' = W + scale * (lora_B @ lora_A). + // + // This is a low-level primitive — it modifies a single named weight buffer directly. + // Adapter pool management, language switching, and revert logic are entirely the + // caller's responsibility (see tools/lora_swap/ for a reference implementation). + // + // Supported dtypes: float32, float16, bfloat16. + // int8 and packed weights (CT2_USE_EXPERIMENTAL_PACKED_GEMM) raise std::runtime_error. + // NOT thread-safe — caller must hold the inference lock before calling. + // Returns false if the variable name is not found (caller may skip silently). + bool apply_lora_delta(const std::string& name, + const StorageView& lora_A, // shape [r, d_in] + const StorageView& lora_B, // shape [d_out, r] + float scale = 1.0f) const; +#endif // CT2_ENABLE_LORA_RUNTIME + // Attributes are saved as scalar variables. template T get_attribute(const std::string& name) const { diff --git a/python/cpp/whisper.cc b/python/cpp/whisper.cc index 34500745c..fb3d3e1c9 100644 --- a/python/cpp/whisper.cc +++ b/python/cpp/whisper.cc @@ -15,6 +15,22 @@ namespace ctranslate2 { return _pool->is_multilingual(); } +#ifdef CT2_ENABLE_LORA_RUNTIME + // Low-level primitive: writes W' = W + scale * (lora_B @ lora_A) into a single + // named weight buffer. Adapter pool management and language switching are the + // caller's responsibility. Acquires the inference lock for the duration of the write. + // + // Note: only patches replica 0. For inter_threads > 1, iterate all replicas. + bool apply_lora_delta(const std::string& name, + const StorageView& lora_A, + const StorageView& lora_B, + float scale = 1.0f) { + std::unique_lock lock(_mutex); + assert_model_is_ready(); + return const_cast(*model()).apply_lora_delta(name, lora_A, lora_B, scale); + } +#endif // CT2_ENABLE_LORA_RUNTIME + size_t n_mels() const { return _pool->n_mels(); } @@ -341,6 +357,39 @@ namespace ctranslate2 { A list of alignment results. )pbdoc") +#ifdef CT2_ENABLE_LORA_RUNTIME + .def("apply_lora_delta", + &WhisperWrapper::apply_lora_delta, + py::arg("name"), + py::arg("lora_A"), + py::arg("lora_B"), + py::arg("scale")=1.0f, + R"pbdoc( + [Experimental] Applies a LoRA adapter delta to a single named weight in-place. + + This is a low-level primitive: it writes W' = W + scale * (lora_B @ lora_A) + directly into the weight buffer and returns immediately. Adapter pool + management, language switching, and revert logic are entirely the caller's + responsibility. + + Supported dtypes: float32, float16, bfloat16. + int8 and packed-GEMM weights raise RuntimeError. + + Note: only the first replica is patched (inter_threads=1 assumed). + For multi-replica deployments, iterate replicas manually. + + Arguments: + name: CT2 variable name, e.g. + ``"decoder/layer_0/self_attention/linear_0/weight"`` + lora_A: StorageView of shape [r, d_in] (the A matrix). + lora_B: StorageView of shape [d_out, r] (the B matrix). + scale: Scalar multiplier (alpha/r in standard LoRA notation). + + Returns: + True if the variable was found and updated; False otherwise. + )pbdoc") +#endif // CT2_ENABLE_LORA_RUNTIME + .def("unload_model", &WhisperWrapper::unload_model, py::arg("to_cpu")=false, py::call_guard(), diff --git a/src/models/model.cc b/src/models/model.cc index 1d8295b0a..209b8fb24 100644 --- a/src/models/model.cc +++ b/src/models/model.cc @@ -264,6 +264,48 @@ namespace ctranslate2 { return false; } +#ifdef CT2_ENABLE_LORA_RUNTIME + bool Model::apply_lora_delta(const std::string& name, + const StorageView& lora_A, + const StorageView& lora_B, + float scale) const { + // Look up variable. Packed weights require unpack+repack, not yet supported. + auto it = _variable_index.find(name); + if (it == _variable_index.end()) { + const std::string packed_name = name + "_packed"; + it = _variable_index.find(packed_name); + if (it == _variable_index.end()) + return false; + throw std::runtime_error( + "apply_lora_delta: packed weights are not supported (name=" + name + + "). Rebuild without CT2_USE_EXPERIMENTAL_PACKED_GEMM."); + } + + StorageView& weight = *it->second; + const DataType orig_dtype = weight.dtype(); + const Device compute_dev = weight.device(); + + if (orig_dtype == DataType::INT8) + throw std::runtime_error( + "apply_lora_delta: int8 quantized weights are not supported; " + "use compute_type=float32 or float16."); + + // Cast LoRA matrices to float32 on weight's device. + StorageView A_f32 = lora_A.to(DataType::FLOAT32).to(compute_dev); + StorageView B_f32 = lora_B.to(DataType::FLOAT32).to(compute_dev); + + // Cast weight to float32, run fused GEMM, write back. + // W_f32 = scale * B @ A + 1.0 * W_f32 (single cuBLAS SGEMM on GPU). + StorageView W_f32 = weight.to(DataType::FLOAT32); + const ops::Gemm gemm_op(/*alpha=*/scale, /*beta=*/1.0f, + /*trans_a=*/false, /*trans_b=*/false); + gemm_op(B_f32, A_f32, W_f32); + + weight = W_f32.to(orig_dtype); + return true; + } +#endif // CT2_ENABLE_LORA_RUNTIME + bool Model::get_flag_with_default(const std::string& name, bool default_value) const { return get_attribute_with_default(name, static_cast(default_value)); }