Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
285 changes: 280 additions & 5 deletions c10/cuda/CUDAStream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,43 @@
#include <atomic>
#include <cstdint>

// POSIX / system headers.
#include <cxxabi.h>
#include <dlfcn.h>
#include <execinfo.h>
#include <unistd.h>
#include "execinfo.h"  // NOTE(review): duplicates <execinfo.h> above; kept as-is

// C++ standard library.
#include <algorithm>
#include <array>    // std::array (default_streams)
#include <bitset>   // std::bitset (GetCuMask)
#include <cassert>  // assert (GetCuMask)
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iomanip>  // std::setfill / std::setw (create_masked_stream)
#include <iostream>
#include <memory>
#include <mutex>
#include <regex>
#include <set>
#include <sstream>
#include <string>
#include <thread>
#include <unordered_map>
#include <vector>   // std::vector (CU masks)

// Project headers.
#include <torch/csrc/autograd/function.h>


// Read environment variable `name`; fall back to `default_value` when unset.
static std::string CcaGetEnv(const char* name, const char* default_value) {
  const char* value = std::getenv(name);
  return value ? std::string(value) : std::string(default_value);
}

static int dev_idx_to_print = -1;

namespace c10::cuda {

namespace {
Expand Down Expand Up @@ -58,6 +95,9 @@ static c10::once_flag
[C10_COMPILE_TIME_MAX_GPUS][kStreamsPerPool];
#endif

static std::array<cudaStream_t, C10_COMPILE_TIME_MAX_GPUS> default_streams;
static c10::once_flag default_stream_flags[C10_COMPILE_TIME_MAX_GPUS];

// Note [HIP Lazy Streams]
// ~~~~~~~~~~~~~~~~~~~~~~~
// For ROCm/HIP, each stream is lazily initialized rather than creating all
Expand Down Expand Up @@ -197,14 +237,201 @@ static void initGlobalStreamState() {
: range;
}

// Init a single CUDA or HIP stream


// Build a CU bitmask (LSB-first vector of 32-bit words) with one bit set for
// every enabled CU. CU indices outside [0, totalCUs) are silently ignored.
static std::vector<uint32_t> createCustomMask(const std::vector<int>& enabledCUs, int totalCUs) {
  const int numWords = (totalCUs + 31) / 32;
  std::vector<uint32_t> cuMask(numWords, 0u);
  for (const int cu : enabledCUs) {
    if (cu < 0 || cu >= totalCUs) {
      continue;  // out-of-range entries contribute nothing
    }
    cuMask[cu / 32] |= (uint32_t{1} << (cu % 32));
  }
  return cuMask;
}

// Parse one comma-separated mask item: either a single CU index ("5") or a
// "start:count" range ("0:7" expands to CUs 0..6). A non-positive count is
// reported on stderr and yields no CUs. std::stoi may throw on bad input.
static std::vector<int> parseCustomItem(const std::string& item) {
  std::vector<int> cus;
  const size_t colonPos = item.find(':');
  const bool isRange = colonPos != std::string::npos && colonPos > 0 &&
      colonPos + 1 < item.length();
  if (!isRange) {
    // Plain single CU index.
    cus.push_back(std::stoi(item));
    return cus;
  }
  const int start = std::stoi(item.substr(0, colonPos));
  const int count = std::stoi(item.substr(colonPos + 1));
  if (count <= 0) {
    std::cerr << "Invalid count in custom CU mask: " << item
              << " (count must be > 0)" << std::endl;
    return cus;
  }
  for (int offset = 0; offset < count; ++offset) {
    cus.push_back(start + offset);
  }
  return cus;
}

// Return the total number of compute units (CUs) on the current device:
// 304 when the arch string contains "gfx942" (MI300-class), 256 otherwise.
// NOTE(review): mixes hipDeviceProp_t with cudaGetDeviceProperties -- this
// presumably relies on the HIP build aliasing cuda* to hip*; confirm.
int32_t getTotalCUs() {
  hipDeviceProp_t prop;
  C10_CUDA_CHECK(cudaGetDeviceProperties(&prop, c10::cuda::current_device()));
  std::string arch_name = prop.gcnArchName;
  return (arch_name.find("gfx942") != std::string::npos) ? 304 : 256;
}

// Build a CU-enable bitmask (LSB-first vector of 32-bit words) from a spec
// string, setting `enable_cu_num` (out) to the number of CUs requested.
// Two spec forms:
//   "custom=<items>"  explicit list; items are "N" or "start:count"
//   "<N>"             enable N CUs spread across shader engines via the
//                     mapping below; `lower_bits_zero` selects whether the
//                     enabled CUs come from the top or bottom of the range.
std::vector<uint32_t> GetCuMask(int32_t &enable_cu_num, std::string mask_str, bool lower_bits_zero) {
  const int32_t totalCUs = getTotalCUs();

  if (mask_str.substr(0, 7) == "custom=") {
    // Parse custom CU list: custom=0:7,32:8 or custom:1,2,5
    // NOTE(review): the "custom:1,2,5" form above would NOT match the
    // "custom=" prefix check -- presumably the comment means "custom=1,2,5".
    std::string cuList = mask_str.substr(7);
    std::vector<int> enabledCUs;
    std::stringstream ss(cuList);
    std::string item;

    enable_cu_num = 0;

    while (std::getline(ss, item, ',')) {
      // Trim whitespace
      item.erase(0, item.find_first_not_of(" \t"));
      item.erase(item.find_last_not_of(" \t") + 1);

      // Parse this item (could be single number or start:count)
      std::vector<int> itemCUs = parseCustomItem(item);
      enabledCUs.insert(enabledCUs.end(), itemCUs.begin(), itemCUs.end());
      // NOTE(review): out-of-range CUs are counted here but silently dropped
      // by createCustomMask, so enable_cu_num can overcount in that case.
      enable_cu_num += itemCUs.size();
    }
    return createCustomMask(enabledCUs, totalCUs);
  }

  // Debug escape hatch: flip which end of the CU range gets masked off.
  if (CcaGetEnv("DBGENV_REVERSE_MASK", "0") == "1") {
    lower_bits_zero = !lower_bits_zero;
  }

  // Numeric form: the string is the count of CUs to enable.
  enable_cu_num = std::stoi(mask_str);
  constexpr int32_t single_mask_bits = 32;

  constexpr int32_t max_cu_num = 512;
  assert(totalCUs <= max_cu_num);

  std::vector<uint32_t> mask;
  std::bitset<max_cu_num> bits_mask(0);
  // Take CUs from the top (lower_bits_zero) or the bottom of [0, totalCUs).
  int32_t start_idx = lower_bits_zero ? (totalCUs - enable_cu_num) : 0;
  int32_t end_idx = lower_bits_zero ? totalCUs : enable_cu_num;

  // Map a logical CU index to a physical mask bit so consecutive logical CUs
  // are spread across shader engines. Indices >= 288 bypass the 32-entry
  // mapping and are used as-is (relevant for the 304-CU gfx942 layout).
  auto get_mask_bit_index = [=](int index) -> int {
    constexpr int se_num = 32;
    constexpr int mapping_to_se[se_num] =
        {1, 5, 9, 13, 17, 21, 25, 29, 2, 6, 10, 14, 18, 22, 26, 30, 0, 4, 8, 12, 16, 20, 24, 28, 3, 7, 11, 15, 19, 23, 27, 31};
    // {1, 2, 5, 6, 9, 10, 13, 14, 17, 18, 21, 22, 25, 26, 29, 30, 0, 3, 4, 7, 8, 11, 12, 15, 16, 19, 20, 23, 24, 27, 28, 31};
    int se = mapping_to_se[index % se_num];
    int rtn = (se % 4) * 8 + (se / 4) + (index / se_num) * se_num;
    if (index >= 288) {
      rtn = index;
    }
    return rtn;
  };

  for (int32_t i = start_idx; i < end_idx; i++) {
    bits_mask.set(get_mask_bit_index(i));
  }

  // Slice the wide bitset into 32-bit words, LSB word first.
  for (int i = 0; i < (totalCUs + single_mask_bits - 1) / single_mask_bits; i++) {
    auto tmp_mask = bits_mask;
    for (int b = single_mask_bits; b < max_cu_num; ++b) {
      tmp_mask.reset(b);
    }
    mask.push_back(static_cast<uint32_t>(tmp_mask.to_ulong()));
    bits_mask = bits_mask >> single_mask_bits;
  }
  return mask;
}

/*


*/

static void initSingleDefaultStream(DeviceIndex device_index);

// Create a stream for `device_index`, optionally restricted to a subset of
// compute units via a CU mask. The env var `cu_num_env` (falling back to
// `env_default`) selects the behavior:
//   "-1"  -> null stream (legacy default stream)
//   "0"   -> plain priority stream, no CU mask
//   "999" -> alias of this device's lazily-created masked default stream
//   else  -> CU-masked stream; the spec string is parsed by GetCuMask()
// `lower_bits_zero` and `pri` are forwarded to GetCuMask / stream creation;
// `i` is the per-pool stream index, used only for logging.
// Diagnostics are printed for device 0 only.
static void create_masked_stream(cudaStream_t *stream, const char*cu_num_env, const char*env_default,int device_index, bool lower_bits_zero, int pri, int i = 0) {
  const char* create_msg = "";
  std::string env_str = CcaGetEnv(cu_num_env, env_default);
  const int32_t totalCUs = getTotalCUs();
  if (env_str == "-1") {
    *stream = nullptr;
    create_msg = "use_nullptr";
  } else if (env_str == "0") {
    C10_CUDA_CHECK(cudaStreamCreateWithPriority(stream, kDefaultFlags, pri));
    create_msg = "priority_stream";
  } else if (env_str == "999") {
    // The default compute stream must not alias itself.
    TORCH_CHECK(std::string("DBGENV_DEFAULT_COMP_STREAM_CU") != cu_num_env);
    c10::call_once(default_stream_flags[device_index], initSingleDefaultStream, device_index);
    *stream = default_streams[device_index];
    create_msg = "same_as_default_stream";
  } else {
    create_msg = "masked_stream";
    int32_t enable_cu_num = 0;
    std::vector<uint32_t> mask = GetCuMask(enable_cu_num, env_str, lower_bits_zero);
    TORCH_CHECK(enable_cu_num <= totalCUs);
    TORCH_CHECK(enable_cu_num > 0);

    if (device_index == 0) {
      {
        // Per-word mask dump, most-significant word first.
        std::ostringstream oss;
        oss << "cca_log create_masked_stream";
        for (int m = static_cast<int>(mask.size()) - 1; m >= 0; --m) {
          oss << std::hex << " [" << m << "]=" << mask[m];
        }
        oss << std::dec << " dev " << device_index
            << " env " << cu_num_env << " cu_num " << enable_cu_num << " i " << i;
        std::fprintf(stderr, "%s\n", oss.str().c_str());
      }

      {
        // Same mask as one contiguous hex literal.
        std::ostringstream oss;
        oss << "0x";
        for (int m = static_cast<int>(mask.size()) - 1; m >= 0; --m) {
          oss << std::hex << std::setfill('0') << std::setw(8) << mask[m];
        }
        std::fprintf(stderr, "%s\n", oss.str().c_str());
      }
    }

    C10_CUDA_CHECK(hipExtStreamCreateWithCUMask(stream, mask.size(), mask.data()));
  }
  if (device_index == 0) {
    // Log env_str (the effective value) rather than std::getenv(): getenv
    // returns nullptr when the variable is unset, and passing nullptr to a
    // "%s" conversion is undefined behavior.
    std::fprintf(stderr, "cca_log create_stream %s %p %s=%s i %d totalCUs %d GetTraceID %d\n", create_msg, (void*)*stream, cu_num_env, env_str.c_str(), i, totalCUs, GetTraceID(true));
  }
}
// Init a single CUDA or HIP stream
// See Note [HIP Lazy Streams]
static void initSingleStream(int p, DeviceIndex device_index, int i) {
CUDAGuard device_guard(device_index);
auto& stream = streams[p][device_index][i];
auto pri = -p; // lower number is higher priority

C10_CUDA_CHECK(cudaStreamCreateWithPriority(&stream, kDefaultFlags, pri));
dev_idx_to_print = std::stoi(CcaGetEnv("DBGENV_DEVIDX_PRINT", "-1"));

const char *env_name = "DBGENV_DEFAULT_RCCL_STREAM_CU";
bool lower_bits_zero = true;

if (i == 1) {
env_name = "DBGENV_2ND_COMP_STREAM_CU";
lower_bits_zero = false;
} else if (i == 2) {
env_name = "DBGENV_2ND_RCCL_STREAM_CU";
lower_bits_zero = true;
}

create_masked_stream(&stream, env_name, "0", device_index, lower_bits_zero, pri, i);

const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_stream_creation(
Expand All @@ -213,6 +440,16 @@ static void initSingleStream(int p, DeviceIndex device_index, int i) {
}
}

// Lazily create a device's replacement default stream, driven by the
// DBGENV_DEFAULT_COMP_STREAM_CU env var (default "-1" keeps the null
// stream). Invoked once per device via
// c10::call_once(default_stream_flags[device_index], ...).
static void initSingleDefaultStream(DeviceIndex device_index) {
  CUDAGuard device_guard(device_index);
  auto& stream = default_streams[device_index];
  // Lower number is higher priority. The original declared an unused
  // `pri = 0` local while actually passing -1; keep -1 (the behavior) and
  // drop the dead variable.
  const int pri = -1;

  dev_idx_to_print = std::stoi(CcaGetEnv("DBGENV_DEVIDX_PRINT", "-1"));

  create_masked_stream(&stream, "DBGENV_DEFAULT_COMP_STREAM_CU", "-1", device_index, false, pri);
}

// Creates the low and high priority stream pools for the specified device
// Warning: only call once per device!
static void initDeviceStreamState(DeviceIndex device_index) {
Expand All @@ -238,6 +475,9 @@ static void initCUDAStreamsOnce() {
for (const auto i : c10::irange(num_gpus)) {
current_streams[i] = makeStreamId(StreamIdType::DEFAULT, 0);
}

// TORCH_WARN("cca_log initCUDAStreamsOnce num_gpus ", (int)num_gpus, " GetTraceID ", GetTraceID())

}

// Helper to verify the GPU index is valid
Expand Down Expand Up @@ -285,7 +525,9 @@ cudaStream_t CUDAStream::stream() const {
").",
" Did you manufacture the StreamId yourself? Don't do that; use the",
" official API like c10::cuda::getStreamFromPool() to get a new stream.");
return nullptr;
// See Note [HIP Lazy Streams]
c10::call_once(default_stream_flags[device_index], initSingleDefaultStream, device_index);
return default_streams[device_index];
} else if (st.isExt()) {
// NOLINTNEXTLINE(performance-no-int-to-ptr)
return reinterpret_cast<cudaStream_t>(stream_id);
Expand Down Expand Up @@ -355,7 +597,17 @@ CUDAStream getDefaultCUDAStream(DeviceIndex device_index) {
c10::cuda::SetTargetDevice();
}
check_gpu(device_index);
return CUDAStreamForId(device_index, makeStreamId(StreamIdType::DEFAULT, 0));
auto rtn = CUDAStreamForId(device_index, makeStreamId(StreamIdType::DEFAULT, 0));
// TORCH_WARN("cca_log getDefaultCUDAStream device_index ", (int)device_index, " stream ", rtn.stream())

if (device_index == dev_idx_to_print) {
std::fprintf(stderr, "cca_log getDefaultCUDAStream device_index %d stream %p GetTraceID %d\n",
(int)device_index,
(void*)rtn.stream(),
GetTraceID());
}

return rtn;
}

CUDAStream getCurrentCUDAStream(DeviceIndex device_index) {
Expand All @@ -365,11 +617,34 @@ CUDAStream getCurrentCUDAStream(DeviceIndex device_index) {
c10::cuda::SetTargetDevice();
}
check_gpu(device_index);
return CUDAStreamForId(device_index, current_streams[device_index]);
auto rtn = CUDAStreamForId(device_index, current_streams[device_index]);
if (device_index == dev_idx_to_print) {
// TORCH_WARN("cca_log getCurrentCUDAStream device_index ", (int)device_index, " stream ", rtn.stream(), " GetTraceID ", GetTraceID())
std::fprintf(stderr, "cca_log getCurrentCUDAStream device_index %d stream %p tid %zu GetTraceID %d\n",
(int)device_index,
(void*)rtn.stream(),
std::hash<std::thread::id>{}(std::this_thread::get_id()),
GetTraceID());
}
return rtn;
}

// Make `stream` the current stream for its device. When the device matches
// DBGENV_DEVIDX_PRINT, trace the old -> new stream switch on stderr.
void setCurrentCUDAStream(CUDAStream stream) {
  initCUDAStreamsOnce();
  const auto device_index = stream.device_index();
  if (device_index == dev_idx_to_print) {
    const auto prev = CUDAStreamForId(device_index, current_streams[device_index]);
    std::fprintf(stderr, "cca_log setCurrentCUDAStream device_index %d from %p to %p tid %zu GetTraceID %d\n",
        (int)device_index,
        (void*)prev.stream(),
        (void*)stream.stream(),
        std::hash<std::thread::id>{}(std::this_thread::get_id()),
        GetTraceID()
    );
  }

  current_streams[device_index] = stream.id();
}

Expand Down
5 changes: 5 additions & 0 deletions caffe2/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1548,6 +1548,11 @@ target_link_libraries(torch_cpu PUBLIC c10)
target_link_libraries(torch_cpu PUBLIC ${Caffe2_PUBLIC_DEPENDENCY_LIBS})
target_link_libraries(torch_cpu PRIVATE ${Caffe2_DEPENDENCY_LIBS})
target_link_libraries(torch_cpu PRIVATE ${Caffe2_DEPENDENCY_WHOLE_LINK_LIBS})

# Link c10_hip into torch_cpu on ROCm builds.
# NOTE(review): presumably required so HIP-side symbols referenced from
# torch_cpu resolve at link time -- confirm against the linker error this fixes.
if(USE_ROCM)
target_link_libraries(torch_cpu PRIVATE c10_hip)
endif()

if(USE_MPI)
target_link_libraries(torch_cpu PRIVATE MPI::MPI_CXX)
endif()
Expand Down
2 changes: 2 additions & 0 deletions torch/csrc/api/include/torch/nn/parallel/data_parallel.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ void replicate_grad_edges(
auto grad_fn = std::make_shared<ReduceAdd>((*parameter).device());
grad_fn->set_next_edges(autograd::collect_next_edges(*parameter));

// std::fprintf(stderr, "cca_log 113 grad_fn->name %s\n", grad_fn->name().c_str());
for (const auto i : c10::irange(devices.size())) {
autograd::set_history(replicas[i]->parameters_[parameter.key()], grad_fn);
}
Expand All @@ -120,6 +121,7 @@ void replicate_grad_edges(
auto grad_fn = std::make_shared<ReduceAdd>((*buffer).device());
grad_fn->set_next_edges(autograd::collect_next_edges(*buffer));

// std::fprintf(stderr, "cca_log 123 grad_fn->name %s\n", grad_fn->name().c_str());
for (const auto i : c10::irange(devices.size())) {
autograd::set_history(replicas[i]->buffers_[buffer.key()], grad_fn);
}
Expand Down
2 changes: 2 additions & 0 deletions torch/csrc/autograd/VariableTypeManual.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ Tensor _fw_primal(c10::DispatchKeySet ks, const Tensor& self, int64_t level) {
})();

if (grad_fn) {
CCADEBUG(std::fprintf(stderr, "cca_log 137 grad_fn->name %s\n", grad_fn->name().c_str()));
set_history(flatten_tensor_args(result), grad_fn);
}
if (isFwGradDefined(self)) {
Expand Down Expand Up @@ -180,6 +181,7 @@ Tensor _make_dual(
})();

if (grad_fn) {
CCADEBUG(std::fprintf(stderr, "cca_log 183 grad_fn->name %s\n", grad_fn->name().c_str()));
set_history(flatten_tensor_args(result), grad_fn);
}

Expand Down
Loading