43 changes: 27 additions & 16 deletions thunder/benchmarks/benchmark_inference.py
@@ -272,50 +272,56 @@ def __init__(self, config: InferenceBenchmarkConfig):
assert all(p.device == torch.device("meta") for p in model.parameters())

tp_plan = {
"*.layers.*.self_attn.q_proj": ColwiseParallel(use_local_output=True),
"*.layers.*.self_attn.k_proj": ColwiseParallel(use_local_output=True),
"*.layers.*.self_attn.v_proj": ColwiseParallel(use_local_output=True),
"*.layers.*.self_attn.o_proj": RowwiseParallel(use_local_output=True),
"*.layers.*.feed_forward.gate_proj": ColwiseParallel(use_local_output=False),
"*.layers.*.feed_forward.up_proj": ColwiseParallel(use_local_output=False),
"*.layers.*.feed_forward.down_proj": RowwiseParallel(use_local_output=True),
"model.layers.*.self_attn.q_proj": ColwiseParallel(use_local_output=True),
"model.layers.*.self_attn.k_proj": ColwiseParallel(use_local_output=True),
"model.layers.*.self_attn.v_proj": ColwiseParallel(use_local_output=True),
"model.layers.*.self_attn.o_proj": RowwiseParallel(use_local_output=True),
"model.layers.*.feed_forward.gate_proj": ColwiseParallel(use_local_output=False),
"model.layers.*.feed_forward.up_proj": ColwiseParallel(use_local_output=False),
"model.layers.*.feed_forward.down_proj": RowwiseParallel(use_local_output=True),
}

if not self.config.disable_moe_replacement:
tp_plan.update(
{
# Custom MoE
"*.layers.*.feed_forward.shared_experts.gate_proj": ColwiseParallel(
"model.layers.*.feed_forward.shared_experts.gate_proj": ColwiseParallel(
use_local_output=False, output_layouts=Shard(2)
),
"*.layers.*.feed_forward.shared_experts.up_proj": ColwiseParallel(
"model.layers.*.feed_forward.shared_experts.up_proj": ColwiseParallel(
use_local_output=False, output_layouts=Shard(2)
),
"*.layers.*.feed_forward.shared_experts.down_proj": RowwiseParallel(),
"*.layers.*.feed_forward.routed_experts.gate_proj": GroupedLinearColwiseParallel(
"model.layers.*.feed_forward.shared_experts.down_proj": RowwiseParallel(),
"model.layers.*.feed_forward.routed_experts.gate_proj": GroupedLinearColwiseParallel(
use_local_output=False
),
"*.layers.*.feed_forward.routed_experts.up_proj": GroupedLinearColwiseParallel(
"model.layers.*.feed_forward.routed_experts.up_proj": GroupedLinearColwiseParallel(
use_local_output=False
),
"*.layers.*.feed_forward.routed_experts.down_proj": GroupedLinearRowwiseParallel(),
"model.layers.*.feed_forward.routed_experts.down_proj": GroupedLinearRowwiseParallel(),
}
)

else:
tp_plan.update(
{
# HF MoE
"*.layers.*.feed_forward.shared_expert.gate_proj": ColwiseParallel(use_local_output=False),
"*.layers.*.feed_forward.shared_expert.up_proj": ColwiseParallel(use_local_output=False),
"*.layers.*.feed_forward.shared_expert.down_proj": RowwiseParallel(use_local_output=True),
"model.layers.*.feed_forward.shared_expert.gate_proj": ColwiseParallel(use_local_output=False),
"model.layers.*.feed_forward.shared_expert.up_proj": ColwiseParallel(use_local_output=False),
"model.layers.*.feed_forward.shared_expert.down_proj": RowwiseParallel(use_local_output=True),
# TODO: Need to write a ParallelStyle for HF's grouped_mm implementation.
}
)

if mesh:
model = parallelize_module(model, mesh, tp_plan)

# Sanity check - verify attention projections are sharded
assert type(model.model.layers[0].self_attn.q_proj.weight) == DTensor, "q_proj should be DTensor"
assert type(model.model.layers[0].self_attn.k_proj.weight) == DTensor, "k_proj should be DTensor"
assert type(model.model.layers[0].self_attn.v_proj.weight) == DTensor, "v_proj should be DTensor"
assert type(model.model.layers[0].self_attn.o_proj.weight) == DTensor, "o_proj should be DTensor"

# Sanity check
if not self.config.disable_moe_replacement:
assert type(model.model.layers[1].feed_forward.shared_experts.gate_proj.weight) == DTensor
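The plan keys above are module-name patterns resolved against the model's fully qualified module names (e.g. `model.layers.0.self_attn.q_proj` in the HF layout), which is why the anchored `model.` prefix replaces the leading `*`. A minimal, hypothetical pre-check — using Python's `fnmatch` only as a stand-in for the wildcard matching that `parallelize_module` performs internally — could confirm every pattern hits at least one module before parallelizing:

```python
# Illustration only: verify every tp_plan pattern matches at least one module FQN
# before calling parallelize_module. fnmatch is used as an approximation of the
# wildcard matching applied to plan keys; the real matching rules may differ.
from fnmatch import fnmatch


def check_tp_plan_coverage(model, tp_plan):
    fqns = [name for name, _ in model.named_modules()]
    unmatched = [
        pattern for pattern in tp_plan
        if not any(fnmatch(fqn, pattern) for fqn in fqns)
    ]
    if unmatched:
        raise ValueError(f"tp_plan patterns with no matching module: {unmatched}")
```

Calling such a check right before `parallelize_module(model, mesh, tp_plan)` would surface a pattern mismatch (such as the old `*.layers.*` prefix) early, instead of only tripping the DTensor asserts afterwards.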
@@ -410,12 +416,17 @@ def generate_batch(self) -> tuple[torch.Tensor, HybridChunkedCache]:
input_ids = torch.randint(0, self.vocab_size, (batch_size, input_length), device=DEVICE)
if LooseVersion(transformers.__version__) >= LooseVersion("4.55"):
# Transformers deprecated HybridChunkedCache in favour of StaticCache in 4.55.x
# NOTE: tp_size should only reflect tensor parallelism size, not total world size.
# If 2D parallelism (data parallel + tensor parallel) is added in the future,
# tp_size should be set to the tensor parallelism size only (e.g., WORLD_SIZE // DP_SIZE),
# not WORLD_SIZE, so that StaticCache correctly handles sharded KV heads.
past_key_values = StaticCache(
config=self.hf_config,
max_batch_size=input_ids.shape[0],
max_cache_len=input_ids.shape[1] + self.config.output_length,
device=DEVICE,
dtype=torch.bfloat16,
tp_size=WORLD_SIZE if mesh else 1,
)
else:
past_key_values = HybridChunkedCache(
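The NOTE in this hunk concerns the `tp_size` argument: with pure tensor parallelism it equals the world size, but under a future 2D (DP × TP) mesh it must be the tensor-parallel dimension only, so that StaticCache divides the KV heads by the right factor. A minimal sketch of that derivation, assuming a hypothetical `DP_SIZE` environment variable alongside the benchmark's existing `WORLD_SIZE` and `mesh` globals:

```python
import os

# Illustration only: how tp_size could be derived if data parallelism were added.
# DP_SIZE is a hypothetical knob; the benchmark currently runs TP-only, so
# tp_size == WORLD_SIZE whenever a mesh is present.
dp_size = int(os.environ.get("DP_SIZE", "1"))
tp_size = (WORLD_SIZE // dp_size) if mesh else 1

past_key_values = StaticCache(
    config=self.hf_config,
    max_batch_size=input_ids.shape[0],
    max_cache_len=input_ids.shape[1] + self.config.output_length,
    device=DEVICE,
    dtype=torch.bfloat16,
    tp_size=tp_size,  # number of KV-head shards, not the total world size
)
```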