Skip to content

Commit 2397be7

Browse files
committed
asn1: Add support for SET
Signed-off-by: Facundo Tuesca <facundo.tuesca@trailofbits.com>
1 parent b48368d commit 2397be7

7 files changed

Lines changed: 156 additions & 0 deletions

File tree

src/cryptography/hazmat/asn1/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
decode_der,
2020
encode_der,
2121
sequence,
22+
set,
2223
)
2324

2425
__all__ = [
@@ -38,4 +39,5 @@
3839
"decode_der",
3940
"encode_der",
4041
"sequence",
42+
"set",
4143
]

src/cryptography/hazmat/asn1/asn1.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,16 @@ def _register_asn1_sequence(cls: type[U]) -> None:
326326
setattr(cls, "__asn1_root__", root)
327327

328328

329+
def _register_asn1_set(cls: type[U]) -> None:
330+
raw_fields = get_type_hints(cls, include_extras=True)
331+
root = declarative_asn1.AnnotatedType(
332+
declarative_asn1.Type.Set(cls, _annotate_fields(raw_fields)),
333+
declarative_asn1.Annotation(),
334+
)
335+
336+
setattr(cls, "__asn1_root__", root)
337+
338+
329339
# Due to https://github.com/python/mypy/issues/19731, we can't define an alias
330340
# for `dataclass_transform` that conditionally points to `typing` or
331341
# `typing_extensions` depending on the Python version (like we do for
@@ -357,6 +367,29 @@ def sequence(cls: type[U]) -> type[U]:
357367
_register_asn1_sequence(dataclass_cls)
358368
return dataclass_cls
359369

370+
@typing_extensions.dataclass_transform(kw_only_default=True)
371+
def set(cls: type[U]) -> type[U]:
372+
# We use `dataclasses.dataclass` to add an __init__ method
373+
# to the class with keyword-only parameters.
374+
if sys.version_info >= (3, 10):
375+
dataclass_cls = dataclasses.dataclass(
376+
repr=False,
377+
eq=False,
378+
# `match_args` was added in Python 3.10 and defaults
379+
# to True
380+
match_args=False,
381+
# `kw_only` was added in Python 3.10 and defaults to
382+
# False
383+
kw_only=True,
384+
)(cls)
385+
else:
386+
dataclass_cls = dataclasses.dataclass(
387+
repr=False,
388+
eq=False,
389+
)(cls)
390+
_register_asn1_set(dataclass_cls)
391+
return dataclass_cls
392+
360393
else:
361394

362395
@typing.dataclass_transform(kw_only_default=True)
@@ -372,6 +405,19 @@ def sequence(cls: type[U]) -> type[U]:
372405
_register_asn1_sequence(dataclass_cls)
373406
return dataclass_cls
374407

408+
@typing.dataclass_transform(kw_only_default=True)
409+
def set(cls: type[U]) -> type[U]:
410+
# Only add an __init__ method, with keyword-only
411+
# parameters.
412+
dataclass_cls = dataclasses.dataclass(
413+
repr=False,
414+
eq=False,
415+
match_args=False,
416+
kw_only=True,
417+
)(cls)
418+
_register_asn1_set(dataclass_cls)
419+
return dataclass_cls
420+
375421

376422
# TODO: replace with `Default[U]` once the min Python version is >= 3.12
377423
@dataclasses.dataclass(frozen=True)

src/cryptography/hazmat/bindings/_rust/declarative_asn1.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ def non_root_python_to_rust(cls: type) -> Type: ...
1414
class Type:
1515
Sequence: typing.ClassVar[type]
1616
SequenceOf: typing.ClassVar[type]
17+
Set: typing.ClassVar[type]
1718
SetOf: typing.ClassVar[type]
1819
Option: typing.ClassVar[type]
1920
Choice: typing.ClassVar[type]

src/rust/src/declarative_asn1/decode.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,21 @@ pub(crate) fn decode_annotated_type<'a>(
286286
Ok(list.into_any())
287287
})?
288288
}
289+
Type::Set(cls, fields) => {
290+
let set_parse_result = read_value::<asn1::Set<'_>>(parser, encoding)?;
291+
292+
set_parse_result.parse(|d| -> ParseResult<pyo3::Bound<'a, pyo3::PyAny>> {
293+
let kwargs = pyo3::types::PyDict::new(py);
294+
let fields = fields.bind(py);
295+
for (name, ann_type) in fields.into_iter() {
296+
let ann_type = ann_type.cast::<AnnotatedType>()?;
297+
let value = decode_annotated_type(py, d, ann_type.get())?;
298+
kwargs.set_item(name, value)?;
299+
}
300+
let val = cls.call(py, (), Some(&kwargs))?.into_bound(py);
301+
Ok(val)
302+
})?
303+
}
289304
Type::SetOf(cls) => {
290305
let setof_parse_result = read_value::<asn1::Set<'_>>(parser, encoding)?;
291306

src/rust/src/declarative_asn1/encode.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,22 @@ impl asn1::Asn1Writable for AnnotatedTypeObject<'_> {
7878

7979
write_value(writer, &asn1::SequenceOfWriter::new(values), encoding)
8080
}
81+
Type::Set(_cls, fields) => write_value(
82+
writer,
83+
&asn1::SetWriter::new(&|w| {
84+
for (name, ann_type) in fields.bind(py).into_iter() {
85+
let name = name.cast::<pyo3::types::PyString>()?;
86+
let ann_type = ann_type.cast::<AnnotatedType>()?;
87+
let object = AnnotatedTypeObject {
88+
annotated_type: ann_type.get(),
89+
value: self.value.getattr(name)?,
90+
};
91+
w.write_element(&object)?;
92+
}
93+
Ok(())
94+
}),
95+
encoding,
96+
),
8197
Type::SetOf(cls) => {
8298
let setof = value.cast::<super::types::SetOf>()?;
8399
let values: Vec<AnnotatedTypeObject<'_>> = setof

src/rust/src/declarative_asn1/types.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ pub enum Type {
2323
Sequence(pyo3::Py<pyo3::types::PyType>, pyo3::Py<pyo3::types::PyDict>),
2424
/// SEQUENCE OF (`list[`T`]`)
2525
SequenceOf(pyo3::Py<AnnotatedType>),
26+
/// SET(`class`, `dict`)
27+
/// The first element is the Python class that represents the set,
28+
/// the second element is a dict of the (already converted) fields of the class.
29+
Set(pyo3::Py<pyo3::types::PyType>, pyo3::Py<pyo3::types::PyDict>),
2630
/// SET OF (`list[`T`]`)
2731
SetOf(pyo3::Py<AnnotatedType>),
2832
/// OPTIONAL (`T | None`)
@@ -640,6 +644,7 @@ pub(crate) fn is_tag_valid_for_type(
640644
) -> bool {
641645
match type_ {
642646
Type::Sequence(_, _) => check_tag_with_encoding(asn1::Sequence::TAG, encoding, tag),
647+
Type::Set(_, _) => check_tag_with_encoding(asn1::Set::TAG, encoding, tag),
643648
Type::SequenceOf(_) => check_tag_with_encoding(asn1::Sequence::TAG, encoding, tag),
644649
Type::SetOf(_) => check_tag_with_encoding(asn1::SetOf::<()>::TAG, encoding, tag),
645650
Type::Option(t) => is_tag_valid_for_type(py, tag, t.get().inner.get(), encoding),

tests/hazmat/asn1/test_serialization.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1037,6 +1037,77 @@ class Example:
10371037
)
10381038

10391039

1040+
class TestSet:
1041+
def test_ok_set_single_field(self) -> None:
1042+
@asn1.set
1043+
@_comparable_dataclass
1044+
class Example:
1045+
foo: int
1046+
1047+
assert_roundtrips([(Example(foo=9), b"\x31\x03\x02\x01\x09")])
1048+
1049+
def test_ok_set_multiple_fields(self) -> None:
1050+
@asn1.set
1051+
@_comparable_dataclass
1052+
class Example:
1053+
foo: int
1054+
bar: int
1055+
1056+
assert_roundtrips(
1057+
[(Example(foo=6, bar=9), b"\x31\x06\x02\x01\x06\x02\x01\x09")]
1058+
)
1059+
1060+
def test_fail_set_multiple_fields_wrong_order(self) -> None:
1061+
@asn1.set
1062+
@_comparable_dataclass
1063+
class Example:
1064+
foo: int
1065+
bar: int
1066+
1067+
with pytest.raises(
1068+
ValueError,
1069+
match=re.escape(
1070+
"invalid SET ordering while performing ASN.1 serialization"
1071+
),
1072+
):
1073+
assert_roundtrips(
1074+
[(Example(foo=9, bar=6), b"\x31\x06\x02\x01\x06\x02\x01\x09")]
1075+
)
1076+
1077+
def test_ok_nested_set(self) -> None:
1078+
@asn1.set
1079+
@_comparable_dataclass
1080+
class Child:
1081+
foo: int
1082+
1083+
@asn1.set
1084+
@_comparable_dataclass
1085+
class Parent:
1086+
foo: Child
1087+
1088+
assert_roundtrips(
1089+
[(Parent(foo=Child(foo=9)), b"\x31\x05\x31\x03\x02\x01\x09")]
1090+
)
1091+
1092+
def test_ok_set_multiple_types(self) -> None:
1093+
@asn1.set
1094+
@_comparable_dataclass
1095+
class Example:
1096+
a: bool
1097+
b: int
1098+
c: bytes
1099+
d: str
1100+
1101+
assert_roundtrips(
1102+
[
1103+
(
1104+
Example(a=True, b=9, c=b"c", d="d"),
1105+
b"\x31\x0c\x01\x01\xff\x02\x01\x09\x04\x01c\x0c\x01d",
1106+
)
1107+
]
1108+
)
1109+
1110+
10401111
class TestSize:
10411112
def test_ok_sequenceof_size_restriction(self) -> None:
10421113
@asn1.sequence

0 commit comments

Comments
 (0)