Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions bitsandbytes/autograd/_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,10 +382,7 @@ def matmul_4bit(
bias: Optional[torch.Tensor] = None,
):
assert quant_state is not None
# Change dtype to input dtype on CPU
if A.device.type == "cpu":
quant_state.dtype = A.dtype

if getattr(quant_state, "packing_format_for_cpu", False):
out = F.gemv_4bit(A, B, out, state=quant_state)
if bias is not None:
Expand Down
115 changes: 79 additions & 36 deletions bitsandbytes/nn/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,42 +258,85 @@ def __setstate__(self, state):
self.bnb_quantized = state["bnb_quantized"]
self.module = state["module"]

# Map from state_dict key names (as produced by QuantState.as_dict) to
# the actual QuantState attribute/access path. FSDP's _get_fqns() resolves
# dotted FQN keys via getattr, so "weight.quant_map" becomes
# getattr(weight, "quant_map") — we must map that to quant_state.code.
_QUANT_STATE_ATTR_MAP = {
# Direct QuantState attributes
"absmax": lambda qs: qs.absmax,
"code": lambda qs: qs.code,
"blocksize": lambda qs: qs.blocksize,
"dtype": lambda qs: qs.dtype,
"shape": lambda qs: qs.shape,
"offset": lambda qs: qs.offset,
"state2": lambda qs: qs.state2,
# as_dict serializes code → "quant_map"
"quant_map": lambda qs: qs.code,
"quant_type": lambda qs: qs.quant_type,
# as_dict serializes nested state2 attributes under "nested_*" keys
"nested_absmax": lambda qs: qs.state2.absmax,
"nested_blocksize": lambda qs: qs.state2.blocksize,
"nested_quant_map": lambda qs: qs.state2.code,
"nested_dtype": lambda qs: qs.state2.dtype,
"nested_offset": lambda qs: qs.offset,
}

def __getattr__(self, name):
# Proxy known QuantState attributes so that PyTorch's FSDP state_dict
# machinery (which traverses FQN paths via getattr) can find them.
accessor = self._QUANT_STATE_ATTR_MAP.get(name)
if accessor is not None:
quant_state = self.__dict__.get("quant_state")
if quant_state is not None:
try:
return accessor(quant_state)
except AttributeError:
pass
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
# Properties that proxy QuantState attributes for FSDP state_dict traversal.
# FSDP's _get_fqns() resolves dotted FQN keys via getattr, e.g. "weight.absmax"
# becomes getattr(weight, "absmax"). Using @property instead of __getattr__
# avoids torch.compile graph breaks (see #1904), since Dynamo can trace
# descriptor protocol access but not __getattr__ on Tensor subclasses.
#
# Note: attributes that collide with Params4bit instance attrs (blocksize,
# quant_type) or Tensor attrs (dtype, shape) are intentionally omitted —
# they are packed into the bitsandbytes__* blob and not traversed by FSDP.

@property
def absmax(self):
    """Expose ``quant_state.absmax`` so FSDP's FQN traversal (getattr-based) can resolve it."""
    quant_state = self.__dict__.get("quant_state")
    if quant_state is None:
        raise AttributeError(f"'{type(self).__name__}' object has no attribute 'absmax'")
    return quant_state.absmax

@property
def code(self):
    """Expose ``quant_state.code`` (the quantization lookup table) for FSDP traversal."""
    quant_state = self.__dict__.get("quant_state")
    if quant_state is None:
        raise AttributeError(f"'{type(self).__name__}' object has no attribute 'code'")
    return quant_state.code

@property
def quant_map(self):
    """Alias for ``quant_state.code``: ``QuantState.as_dict`` serializes ``code`` under the
    state_dict key ``quant_map``, so FSDP resolves that key through this property."""
    quant_state = self.__dict__.get("quant_state")
    if quant_state is None:
        raise AttributeError(f"'{type(self).__name__}' object has no attribute 'quant_map'")
    return quant_state.code

@property
def offset(self):
    """Expose ``quant_state.offset`` for FSDP state_dict traversal."""
    quant_state = self.__dict__.get("quant_state")
    if quant_state is None:
        raise AttributeError(f"'{type(self).__name__}' object has no attribute 'offset'")
    return quant_state.offset

@property
def state2(self):
    """Expose ``quant_state.state2`` (nested/double-quantization state) for FSDP traversal."""
    quant_state = self.__dict__.get("quant_state")
    if quant_state is None:
        raise AttributeError(f"'{type(self).__name__}' object has no attribute 'state2'")
    return quant_state.state2

@property
def nested_absmax(self):
    """Expose ``quant_state.state2.absmax``; ``as_dict`` serializes nested stats as ``nested_*``."""
    quant_state = self.__dict__.get("quant_state")
    if quant_state is None or quant_state.state2 is None:
        raise AttributeError(f"'{type(self).__name__}' object has no attribute 'nested_absmax'")
    return quant_state.state2.absmax

@property
def nested_blocksize(self):
    """Expose ``quant_state.state2.blocksize`` under its serialized ``nested_*`` key."""
    quant_state = self.__dict__.get("quant_state")
    if quant_state is None or quant_state.state2 is None:
        raise AttributeError(f"'{type(self).__name__}' object has no attribute 'nested_blocksize'")
    return quant_state.state2.blocksize

@property
def nested_quant_map(self):
    """Expose ``quant_state.state2.code`` under its serialized ``nested_quant_map`` key."""
    quant_state = self.__dict__.get("quant_state")
    if quant_state is None or quant_state.state2 is None:
        raise AttributeError(f"'{type(self).__name__}' object has no attribute 'nested_quant_map'")
    return quant_state.state2.code

@property
def nested_dtype(self):
    """Expose ``quant_state.state2.dtype`` under its serialized ``nested_dtype`` key."""
    quant_state = self.__dict__.get("quant_state")
    if quant_state is None or quant_state.state2 is None:
        raise AttributeError(f"'{type(self).__name__}' object has no attribute 'nested_dtype'")
    return quant_state.state2.dtype

@property
def nested_offset(self):
    """Expose ``quant_state.offset`` under the ``nested_offset`` key.

    NOTE: this deliberately reads the *outer* state's ``offset`` (not ``state2``),
    mirroring how ``QuantState.as_dict`` serializes it under the ``nested_offset`` key.
    """
    quant_state = self.__dict__.get("quant_state")
    if quant_state is None:
        raise AttributeError(f"'{type(self).__name__}' object has no attribute 'nested_offset'")
    return quant_state.offset

def __deepcopy__(self, memo):
new_instance = type(self).__new__(type(self))
Expand Down
71 changes: 70 additions & 1 deletion tests/test_linear4bit.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,75 @@ def test_linear4bit_torch_compile(device, quant_type, compute_dtype, compress_st
torch.testing.assert_close(grad_compiled, grad_ref)


@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
@pytest.mark.skipif(torch.__version__ < (2, 8, 0, "dev"), reason="fullgraph requires torch 2.8+")
@pytest.mark.skipif(
    torch.__version__ < (2, 10) and sys.version_info >= (3, 14), reason="Not supported in Python 3.14 until torch 2.10"
)
def test_linear4bit_torch_compile_activation_checkpointing(device, quant_type, compress_statistics):
    """Regression test for #1904: Params4bit attribute access must not break torch.compile.

    Activation checkpointing replays the forward pass during backward, multiplying
    attribute accesses on Params4bit. If those accesses go through ``__getattr__``
    (instead of ``@property``), Dynamo cannot trace them and inserts graph breaks.
    ``fullgraph=True`` turns any graph break into a hard error, so this test fails
    loudly rather than silently losing performance.
    """
    if device == "hpu" and not is_supported_on_hpu(quant_type):
        pytest.skip("This configuration is not supported on HPU.")
    if device == "cuda" and platform.system() == "Windows":
        pytest.skip("Triton is not officially supported on Windows")

    hidden_dim = 256
    n_samples = 16
    dtype = torch.bfloat16

    torch.compiler.reset()

    class _CheckpointedStack(torch.nn.Module):
        # Four stacked 4-bit linear layers, each under non-reentrant checkpointing.
        def __init__(self):
            super().__init__()
            self.layers = torch.nn.ModuleList(
                bnb.nn.Linear4bit(
                    hidden_dim,
                    hidden_dim,
                    bias=False,
                    compute_dtype=dtype,
                    compress_statistics=compress_statistics,
                    quant_type=quant_type,
                )
                for _ in range(4)
            )

        def forward(self, h):
            for block in self.layers:
                h = torch.utils.checkpoint.checkpoint(block, h, use_reentrant=False)
            return h

    model = _CheckpointedStack().to(device)
    inputs = torch.randn(n_samples, hidden_dim, dtype=dtype, device=device, requires_grad=True)

    # Eager reference: forward + backward, then stash the input gradient.
    expected = model(inputs)
    expected.sum().backward()
    expected_grad = inputs.grad.clone()
    inputs.grad = None

    # Compiled run — fullgraph=True raises on any graph break.
    backend = "hpu_backend" if device == "hpu" else "inductor"
    compiled_model = torch.compile(model, fullgraph=True, backend=backend)

    actual = compiled_model(inputs)
    actual.sum().backward()
    actual_grad = inputs.grad.clone()

    torch.testing.assert_close(actual, expected)
    torch.testing.assert_close(actual_grad, expected_grad)


@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
Expand Down Expand Up @@ -494,7 +563,7 @@ def test_params4bit_quant_state_attr_access(device, quant_type, compress_statist
with pytest.raises(AttributeError, match="nonexistent_attribute"):
_ = w.nonexistent_attribute

# Verify that normal Params4bit attributes are unaffected by __getattr__
# Verify that normal Params4bit instance attributes are unaffected
assert isinstance(w.quant_state, bnb.functional.QuantState)
assert isinstance(w.bnb_quantized, bool)
assert w.bnb_quantized is True
Expand Down
Loading