Skip to content

Commit 1c27bdd

Browse files
committed
asn1: Add support for SET
Signed-off-by: Facundo Tuesca <[email protected]>
1 parent 7d42cdd commit 1c27bdd

8 files changed

Lines changed: 244 additions & 1 deletion

File tree

src/cryptography/hazmat/asn1/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
decode_der,
2020
encode_der,
2121
sequence,
22+
set,
2223
)
2324

2425
__all__ = [
@@ -38,4 +39,5 @@
3839
"decode_der",
3940
"encode_der",
4041
"sequence",
42+
"set",
4143
]

src/cryptography/hazmat/asn1/asn1.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,10 @@ def _normalize_field_type(
164164

165165
if hasattr(field_type, "__asn1_root__"):
166166
root_type = field_type.__asn1_root__
167-
if not isinstance(root_type, declarative_asn1.Type.Sequence):
167+
if not isinstance(
168+
root_type,
169+
(declarative_asn1.Type.Sequence, declarative_asn1.Type.Set),
170+
):
168171
raise TypeError(f"unsupported root type: {root_type}")
169172
return declarative_asn1.AnnotatedType(
170173
typing.cast(declarative_asn1.Type, root_type), annotation
@@ -325,6 +328,13 @@ def _register_asn1_sequence(cls: type[U]) -> None:
325328
setattr(cls, "__asn1_root__", root)
326329

327330

331+
def _register_asn1_set(cls: type[U]) -> None:
332+
raw_fields = get_type_hints(cls, include_extras=True)
333+
root = declarative_asn1.Type.Set(cls, _annotate_fields(raw_fields))
334+
335+
setattr(cls, "__asn1_root__", root)
336+
337+
328338
# Due to https://github.com/python/mypy/issues/19731, we can't define an alias
329339
# for `dataclass_transform` that conditionally points to `typing` or
330340
# `typing_extensions` depending on the Python version (like we do for
@@ -356,6 +366,29 @@ def sequence(cls: type[U]) -> type[U]:
356366
_register_asn1_sequence(dataclass_cls)
357367
return dataclass_cls
358368

369+
@typing_extensions.dataclass_transform(kw_only_default=True)
370+
def set(cls: type[U]) -> type[U]:
371+
# We use `dataclasses.dataclass` to add an __init__ method
372+
# to the class with keyword-only parameters.
373+
if sys.version_info >= (3, 10):
374+
dataclass_cls = dataclasses.dataclass(
375+
repr=False,
376+
eq=False,
377+
# `match_args` was added in Python 3.10 and defaults
378+
# to True
379+
match_args=False,
380+
# `kw_only` was added in Python 3.10 and defaults to
381+
# False
382+
kw_only=True,
383+
)(cls)
384+
else:
385+
dataclass_cls = dataclasses.dataclass(
386+
repr=False,
387+
eq=False,
388+
)(cls)
389+
_register_asn1_set(dataclass_cls)
390+
return dataclass_cls
391+
359392
else:
360393

361394
@typing.dataclass_transform(kw_only_default=True)
@@ -371,6 +404,19 @@ def sequence(cls: type[U]) -> type[U]:
371404
_register_asn1_sequence(dataclass_cls)
372405
return dataclass_cls
373406

407+
@typing.dataclass_transform(kw_only_default=True)
408+
def set(cls: type[U]) -> type[U]:
409+
# Only add an __init__ method, with keyword-only
410+
# parameters.
411+
dataclass_cls = dataclasses.dataclass(
412+
repr=False,
413+
eq=False,
414+
match_args=False,
415+
kw_only=True,
416+
)(cls)
417+
_register_asn1_set(dataclass_cls)
418+
return dataclass_cls
419+
374420

375421
# TODO: replace with `Default[U]` once the min Python version is >= 3.12
376422
@dataclasses.dataclass(frozen=True)

src/cryptography/hazmat/bindings/_rust/declarative_asn1.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ def non_root_python_to_rust(cls: type) -> Type: ...
1414
class Type:
1515
Sequence: typing.ClassVar[type]
1616
SequenceOf: typing.ClassVar[type]
17+
Set: typing.ClassVar[type]
1718
SetOf: typing.ClassVar[type]
1819
Option: typing.ClassVar[type]
1920
Choice: typing.ClassVar[type]

src/rust/src/declarative_asn1/decode.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,21 @@ pub(crate) fn decode_annotated_type<'a>(
286286
Ok(list.into_any())
287287
})?
288288
}
289+
Type::Set(cls, fields) => {
290+
let set_parse_result = read_value::<asn1::Set<'_>>(parser, encoding)?;
291+
292+
set_parse_result.parse(|d| -> ParseResult<pyo3::Bound<'a, pyo3::PyAny>> {
293+
let kwargs = pyo3::types::PyDict::new(py);
294+
let fields = fields.bind(py);
295+
for (name, ann_type) in fields.into_iter() {
296+
let ann_type = ann_type.cast::<AnnotatedType>()?;
297+
let value = decode_annotated_type(py, d, ann_type.get())?;
298+
kwargs.set_item(name, value)?;
299+
}
300+
let val = cls.call(py, (), Some(&kwargs))?.into_bound(py);
301+
Ok(val)
302+
})?
303+
}
289304
Type::SetOf(cls) => {
290305
let setof_parse_result = read_value::<asn1::Set<'_>>(parser, encoding)?;
291306

src/rust/src/declarative_asn1/encode.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,22 @@ impl asn1::Asn1Writable for AnnotatedTypeObject<'_> {
7878

7979
write_value(writer, &asn1::SequenceOfWriter::new(values), encoding)
8080
}
81+
Type::Set(_cls, fields) => write_value(
82+
writer,
83+
&asn1::SetWriter::new(&|w| {
84+
for (name, ann_type) in fields.bind(py).into_iter() {
85+
let name = name.cast::<pyo3::types::PyString>()?;
86+
let ann_type = ann_type.cast::<AnnotatedType>()?;
87+
let object = AnnotatedTypeObject {
88+
annotated_type: ann_type.get(),
89+
value: self.value.getattr(name)?,
90+
};
91+
w.write_element(&object)?;
92+
}
93+
Ok(())
94+
}),
95+
encoding,
96+
),
8197
Type::SetOf(cls) => {
8298
let setof = value.cast::<super::types::SetOf>()?;
8399
let values: Vec<AnnotatedTypeObject<'_>> = setof

src/rust/src/declarative_asn1/types.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ pub enum Type {
2323
Sequence(pyo3::Py<pyo3::types::PyType>, pyo3::Py<pyo3::types::PyDict>),
2424
/// SEQUENCE OF (`list[`T`]`)
2525
SequenceOf(pyo3::Py<AnnotatedType>),
26+
/// SET(`class`, `dict`)
27+
/// The first element is the Python class that represents the set,
28+
/// the second element is a dict of the (already converted) fields of the class.
29+
Set(pyo3::Py<pyo3::types::PyType>, pyo3::Py<pyo3::types::PyDict>),
2630
/// SET OF (`list[`T`]`)
2731
SetOf(pyo3::Py<AnnotatedType>),
2832
/// OPTIONAL (`T | None`)
@@ -650,6 +654,7 @@ pub(crate) fn is_tag_valid_for_type(
650654
) -> bool {
651655
match type_ {
652656
Type::Sequence(_, _) => check_tag_with_encoding(asn1::Sequence::TAG, encoding, tag),
657+
Type::Set(_, _) => check_tag_with_encoding(asn1::Set::TAG, encoding, tag),
653658
Type::SequenceOf(_) => check_tag_with_encoding(asn1::Sequence::TAG, encoding, tag),
654659
Type::SetOf(_) => check_tag_with_encoding(asn1::SetOf::<()>::TAG, encoding, tag),
655660
Type::Option(t) => is_tag_valid_for_type(py, tag, t.get().inner.get(), encoding),

tests/hazmat/asn1/test_api.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,10 @@ def test_fields_of_variant_type(self) -> None:
362362
assert seq._0 is type(None)
363363
assert seq._1 == {}
364364

365+
set = declarative_asn1.Type.Set(type(None), {})
366+
assert set._0 is type(None)
367+
assert set._1 == {}
368+
365369
ann_type = declarative_asn1.AnnotatedType(
366370
seq, declarative_asn1.Annotation()
367371
)
@@ -461,3 +465,69 @@ def test_fail_optional_tlv(self) -> None:
461465
@asn1.sequence
462466
class Example:
463467
invalid: typing.Union[asn1.TLV, None]
468+
469+
470+
class TestSetAPI:
471+
def test_fail_unsupported_field(self) -> None:
472+
class Unsupported:
473+
foo: int
474+
475+
with pytest.raises(TypeError, match="cannot handle type"):
476+
477+
@asn1.set
478+
class Example:
479+
foo: Unsupported
480+
481+
def test_fail_init_incorrect_field_name(self) -> None:
482+
@asn1.set
483+
class Example:
484+
foo: int
485+
486+
with pytest.raises(
487+
TypeError, match="got an unexpected keyword argument 'bar'"
488+
):
489+
Example(bar=3) # type: ignore[call-arg]
490+
491+
def test_fail_init_missing_field_name(self) -> None:
492+
@asn1.set
493+
class Example:
494+
foo: int
495+
496+
expected_err = (
497+
"missing 1 required keyword-only argument: 'foo'"
498+
if sys.version_info >= (3, 10)
499+
else "missing 1 required positional argument: 'foo'"
500+
)
501+
502+
with pytest.raises(TypeError, match=expected_err):
503+
Example() # type: ignore[call-arg]
504+
505+
def test_fail_positional_field_initialization(self) -> None:
506+
@asn1.set
507+
class Example:
508+
foo: int
509+
510+
# The kw-only init is only enforced in Python >= 3.10, which is
511+
# when the parameter `kw_only` for `dataclasses.datalass` was
512+
# added.
513+
if sys.version_info < (3, 10):
514+
assert Example(5).foo == 5 # type: ignore[misc]
515+
else:
516+
with pytest.raises(
517+
TypeError,
518+
match="takes 1 positional argument but 2 were given",
519+
):
520+
Example(5) # type: ignore[misc]
521+
522+
def test_fail_malformed_root_type(self) -> None:
523+
@asn1.set
524+
class Invalid:
525+
foo: int
526+
527+
setattr(Invalid, "__asn1_root__", int)
528+
529+
with pytest.raises(TypeError, match="unsupported root type"):
530+
531+
@asn1.set
532+
class Example:
533+
foo: Invalid

tests/hazmat/asn1/test_serialization.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -528,10 +528,16 @@ def test_ok_sequence_all_types_optional(self) -> None:
528528
class MyField:
529529
a: int
530530

531+
@asn1.set
532+
@_comparable_dataclass
533+
class MySetField:
534+
a: int
535+
531536
@asn1.sequence
532537
@_comparable_dataclass
533538
class Example:
534539
a: typing.Union[MyField, None]
540+
a2: typing.Union[MySetField, None]
535541
b: typing.Union[int, None]
536542
c: typing.Union[bytes, None]
537543
d: typing.Union[asn1.PrintableString, None]
@@ -553,6 +559,7 @@ class Example:
553559
(
554560
Example(
555561
a=None,
562+
a2=None,
556563
b=None,
557564
c=None,
558565
d=None,
@@ -589,6 +596,11 @@ class MyField:
589596
)
590597
default_oid = x509.ObjectIdentifier("1.3.6.1.4.1.343")
591598

599+
@asn1.set
600+
@_comparable_dataclass
601+
class MySetField:
602+
a: int
603+
592604
@asn1.sequence
593605
@_comparable_dataclass
594606
class Example:
@@ -628,6 +640,10 @@ class Example:
628640
MyField,
629641
asn1.Default(MyField(a=9)),
630642
]
643+
k3: Annotated[
644+
MySetField,
645+
asn1.Default(MySetField(a=9)),
646+
]
631647
z: Annotated[str, asn1.Default("a"), asn1.Implicit(0)]
632648
only_field_present: Annotated[
633649
str, asn1.Default("a"), asn1.Implicit(1)
@@ -649,6 +665,7 @@ class Example:
649665
j=3,
650666
k=asn1.Null(),
651667
k2=MyField(a=9),
668+
k3=MySetField(a=9),
652669
z="a",
653670
only_field_present="b",
654671
),
@@ -1047,6 +1064,77 @@ class Example:
10471064
)
10481065

10491066

1067+
class TestSet:
1068+
def test_ok_set_single_field(self) -> None:
1069+
@asn1.set
1070+
@_comparable_dataclass
1071+
class Example:
1072+
foo: int
1073+
1074+
assert_roundtrips([(Example(foo=9), b"\x31\x03\x02\x01\x09")])
1075+
1076+
def test_ok_set_multiple_fields(self) -> None:
1077+
@asn1.set
1078+
@_comparable_dataclass
1079+
class Example:
1080+
foo: int
1081+
bar: int
1082+
1083+
assert_roundtrips(
1084+
[(Example(foo=6, bar=9), b"\x31\x06\x02\x01\x06\x02\x01\x09")]
1085+
)
1086+
1087+
def test_fail_set_multiple_fields_wrong_order(self) -> None:
1088+
@asn1.set
1089+
@_comparable_dataclass
1090+
class Example:
1091+
foo: int
1092+
bar: int
1093+
1094+
with pytest.raises(
1095+
ValueError,
1096+
match=re.escape(
1097+
"invalid SET ordering while performing ASN.1 serialization"
1098+
),
1099+
):
1100+
assert_roundtrips(
1101+
[(Example(foo=9, bar=6), b"\x31\x06\x02\x01\x06\x02\x01\x09")]
1102+
)
1103+
1104+
def test_ok_nested_set(self) -> None:
1105+
@asn1.set
1106+
@_comparable_dataclass
1107+
class Child:
1108+
foo: int
1109+
1110+
@asn1.set
1111+
@_comparable_dataclass
1112+
class Parent:
1113+
foo: Child
1114+
1115+
assert_roundtrips(
1116+
[(Parent(foo=Child(foo=9)), b"\x31\x05\x31\x03\x02\x01\x09")]
1117+
)
1118+
1119+
def test_ok_set_multiple_types(self) -> None:
1120+
@asn1.set
1121+
@_comparable_dataclass
1122+
class Example:
1123+
a: bool
1124+
b: int
1125+
c: bytes
1126+
d: str
1127+
1128+
assert_roundtrips(
1129+
[
1130+
(
1131+
Example(a=True, b=9, c=b"c", d="d"),
1132+
b"\x31\x0c\x01\x01\xff\x02\x01\x09\x04\x01c\x0c\x01d",
1133+
)
1134+
]
1135+
)
1136+
1137+
10501138
class TestSize:
10511139
def test_ok_sequenceof_size_restriction(self) -> None:
10521140
@asn1.sequence

0 commit comments

Comments
 (0)