diff --git a/thunder/benchmarks/benchmark_inference.py b/thunder/benchmarks/benchmark_inference.py
index 0ae7c5c527..c97e3b08fc 100644
--- a/thunder/benchmarks/benchmark_inference.py
+++ b/thunder/benchmarks/benchmark_inference.py
@@ -272,33 +272,33 @@ def __init__(self, config: InferenceBenchmarkConfig):
         assert all(p.device == torch.device("meta") for p in model.parameters())
 
         tp_plan = {
-            "*.layers.*.self_attn.q_proj": ColwiseParallel(use_local_output=True),
-            "*.layers.*.self_attn.k_proj": ColwiseParallel(use_local_output=True),
-            "*.layers.*.self_attn.v_proj": ColwiseParallel(use_local_output=True),
-            "*.layers.*.self_attn.o_proj": RowwiseParallel(use_local_output=True),
-            "*.layers.*.feed_forward.gate_proj": ColwiseParallel(use_local_output=False),
-            "*.layers.*.feed_forward.up_proj": ColwiseParallel(use_local_output=False),
-            "*.layers.*.feed_forward.down_proj": RowwiseParallel(use_local_output=True),
+            "model.layers.*.self_attn.q_proj": ColwiseParallel(use_local_output=True),
+            "model.layers.*.self_attn.k_proj": ColwiseParallel(use_local_output=True),
+            "model.layers.*.self_attn.v_proj": ColwiseParallel(use_local_output=True),
+            "model.layers.*.self_attn.o_proj": RowwiseParallel(use_local_output=True),
+            "model.layers.*.feed_forward.gate_proj": ColwiseParallel(use_local_output=False),
+            "model.layers.*.feed_forward.up_proj": ColwiseParallel(use_local_output=False),
+            "model.layers.*.feed_forward.down_proj": RowwiseParallel(use_local_output=True),
         }
 
         if not self.config.disable_moe_replacement:
             tp_plan.update(
                 {
                     # Custom MoE
-                    "*.layers.*.feed_forward.shared_experts.gate_proj": ColwiseParallel(
+                    "model.layers.*.feed_forward.shared_experts.gate_proj": ColwiseParallel(
                         use_local_output=False, output_layouts=Shard(2)
                     ),
-                    "*.layers.*.feed_forward.shared_experts.up_proj": ColwiseParallel(
+                    "model.layers.*.feed_forward.shared_experts.up_proj": ColwiseParallel(
                         use_local_output=False, output_layouts=Shard(2)
                     ),
-                    "*.layers.*.feed_forward.shared_experts.down_proj": RowwiseParallel(),
-                    "*.layers.*.feed_forward.routed_experts.gate_proj": GroupedLinearColwiseParallel(
+                    "model.layers.*.feed_forward.shared_experts.down_proj": RowwiseParallel(),
+                    "model.layers.*.feed_forward.routed_experts.gate_proj": GroupedLinearColwiseParallel(
                         use_local_output=False
                     ),
-                    "*.layers.*.feed_forward.routed_experts.up_proj": GroupedLinearColwiseParallel(
+                    "model.layers.*.feed_forward.routed_experts.up_proj": GroupedLinearColwiseParallel(
                         use_local_output=False
                     ),
-                    "*.layers.*.feed_forward.routed_experts.down_proj": GroupedLinearRowwiseParallel(),
+                    "model.layers.*.feed_forward.routed_experts.down_proj": GroupedLinearRowwiseParallel(),
                 }
             )
 
@@ -306,9 +306,9 @@ def __init__(self, config: InferenceBenchmarkConfig):
             tp_plan.update(
                 {
                     # HF MoE
-                    "*.layers.*.feed_forward.shared_expert.gate_proj": ColwiseParallel(use_local_output=False),
-                    "*.layers.*.feed_forward.shared_expert.up_proj": ColwiseParallel(use_local_output=False),
-                    "*.layers.*.feed_forward.shared_expert.down_proj": RowwiseParallel(use_local_output=True),
+                    "model.layers.*.feed_forward.shared_expert.gate_proj": ColwiseParallel(use_local_output=False),
+                    "model.layers.*.feed_forward.shared_expert.up_proj": ColwiseParallel(use_local_output=False),
+                    "model.layers.*.feed_forward.shared_expert.down_proj": RowwiseParallel(use_local_output=True),
                     # TODO: Need to write a ParallelStyle for HF's grouped_mm implementation.
                 }
             )
@@ -316,6 +316,12 @@ def __init__(self, config: InferenceBenchmarkConfig):
         if mesh:
             model = parallelize_module(model, mesh, tp_plan)
 
+        # Sanity check - verify attention projections are sharded
+        assert type(model.model.layers[0].self_attn.q_proj.weight) == DTensor, "q_proj should be DTensor"
+        assert type(model.model.layers[0].self_attn.k_proj.weight) == DTensor, "k_proj should be DTensor"
+        assert type(model.model.layers[0].self_attn.v_proj.weight) == DTensor, "v_proj should be DTensor"
+        assert type(model.model.layers[0].self_attn.o_proj.weight) == DTensor, "o_proj should be DTensor"
+
         # Sanity check
         if not self.config.disable_moe_replacement:
             assert type(model.model.layers[1].feed_forward.shared_experts.gate_proj.weight) == DTensor
@@ -410,12 +416,17 @@ def generate_batch(self) -> tuple[torch.Tensor, HybridChunkedCache]:
         input_ids = torch.randint(0, self.vocab_size, (batch_size, input_length), device=DEVICE)
         if LooseVersion(transformers.__version__) >= LooseVersion("4.55"):
             # Transformers deprecated HybridChunkedCache in favour of static in 4.55.x
+            # NOTE: tp_size should only reflect tensor parallelism size, not total world size.
+            # If 2D parallelism (data parallel + tensor parallel) is added in the future,
+            # tp_size should be set to the tensor parallelism size only (e.g., WORLD_SIZE // DP_SIZE),
+            # not WORLD_SIZE, so that StaticCache correctly handles sharded KV heads.
            past_key_values = StaticCache(
                 config=self.hf_config,
                 max_batch_size=input_ids.shape[0],
                 max_cache_len=input_ids.shape[1] + self.config.output_length,
                 device=DEVICE,
                 dtype=torch.bfloat16,
+                tp_size=WORLD_SIZE if mesh else 1,
             )
         else:
             past_key_values = HybridChunkedCache(
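Reviewer note, not part of the patch: since the fix is switching the `tp_plan` keys from `*.layers.*` to `model.layers.*`, it can help to preview which module FQNs a plan key actually selects before calling `parallelize_module`. The sketch below approximates the matching behaviour of recent PyTorch, which fnmatch-es each dot-separated segment of a plan key against module names as it descends the module tree; the helper `preview_tp_plan_matches` and the toy module classes are illustrative assumptions, not code from this repository.

```python
# Minimal sketch: preview which submodule FQNs a tp_plan key would select.
# Assumption: parallelize_module matches plan keys segment-by-segment with
# fnmatch, so we emulate that against model.named_modules().
from fnmatch import fnmatch

import torch.nn as nn


def preview_tp_plan_matches(model: nn.Module, pattern: str) -> list[str]:
    """Return FQNs of submodules whose dotted path matches `pattern` segment-wise."""
    segments = pattern.split(".")
    matches = []
    for fqn, _ in model.named_modules():
        parts = fqn.split(".")
        if len(parts) == len(segments) and all(
            fnmatch(part, seg) for part, seg in zip(parts, segments)
        ):
            matches.append(fqn)
    return matches


if __name__ == "__main__":
    # Toy stand-in for the HF layout, where FQNs look like
    # "model.layers.0.self_attn.q_proj" relative to the wrapper module.
    class Attn(nn.Module):
        def __init__(self):
            super().__init__()
            self.q_proj = nn.Linear(8, 8)

    class Layer(nn.Module):
        def __init__(self):
            super().__init__()
            self.self_attn = Attn()

    class Inner(nn.Module):
        def __init__(self):
            super().__init__()
            self.layers = nn.ModuleList([Layer(), Layer()])

    class Wrapper(nn.Module):
        def __init__(self):
            super().__init__()
            self.model = Inner()

    print(preview_tp_plan_matches(Wrapper(), "model.layers.*.self_attn.q_proj"))
    # -> ['model.layers.0.self_attn.q_proj', 'model.layers.1.self_attn.q_proj']
```

An empty preview is exactly the failure mode the new DTensor asserts guard against: if a key matches nothing, those projections are simply left un-sharded. The same caution applies to the cache: as the NOTE in the patch says, `tp_size` must track the tensor-parallel mesh size (here equal to WORLD_SIZE because the mesh is TP-only), not the total world size once data parallelism is added.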