diff --git a/go.mod b/go.mod index 3bccf8ee50df6..232a975a09c6c 100644 --- a/go.mod +++ b/go.mod @@ -57,6 +57,7 @@ require ( github.com/prometheus/client_golang v1.23.2 github.com/scaleway/scaleway-sdk-go v1.0.0-beta.35 github.com/sergi/go-diff v1.4.0 + github.com/smallstep/pkcs7 v0.2.1 github.com/spf13/cobra v1.10.2 github.com/spf13/pflag v1.0.10 github.com/spf13/viper v1.21.0 diff --git a/go.sum b/go.sum index 2611f05ff1681..9745084160a74 100644 --- a/go.sum +++ b/go.sum @@ -294,6 +294,7 @@ github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/go-configfs-tsm v0.3.3-0.20240919001351-b4b5b84fdcbc h1:SG12DWUUM5igxm+//YX5Yq4vhdoRnOG9HkCodkOn+YU= @@ -567,6 +568,8 @@ github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPx github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/smallstep/pkcs7 v0.2.1 h1:6Kfzr/QizdIuB6LSv8y1LJdZ3aPSfTNhTLqAx9CTLfA= +github.com/smallstep/pkcs7 v0.2.1/go.mod h1:RcXHsMfL+BzH8tRhmrF1NkkpebKpq3JEM66cOFxanf0= github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 h1:+jumHNA0Wrelhe64i8F6HNlS8pkoyMv5sreGx2Ry5Rw= github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8/go.mod h1:3n1Cwaq1E1/1lhQhtRK2ts/ZwZEhjcQeJQ1RuC6Q/8U= github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I= @@ -609,6 +612,7 @@ github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavM github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= go.opentelemetry.io/contrib/bridges/prometheus v0.57.0 h1:UW0+QyeyBVhn+COBec3nGhfnFe5lwB0ic1JBVjzhk0w= @@ -672,12 +676,22 @@ golang.org/x/crypto v0.0.0-20190923035154-9ee001bba392/go.mod h1:/lpIB1dKB+9EgE3 golang.org/x/crypto v0.0.0-20191002192127-34f69633bfdc/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= +golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= +golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= +golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M= golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4= golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA= golang.org/x/exp v0.0.0-20260312153236-7ab1446f8b90 h1:jiDhWWeC7jfWqR9c/uplMOqJ0sbNlNWv0UkzE0vX1MA= golang.org/x/exp v0.0.0-20260312153236-7ab1446f8b90/go.mod h1:xE1HEv6b+1SCZ5/uscMRjUBKtIxworgEcEi+/n9NQDQ= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/mod v0.34.0 h1:xIHgNUUnW6sYkcM5Jleh05DvLOtwc6RitGHbDk4akRI= golang.org/x/mod v0.34.0/go.mod h1:ykgH52iCZe79kzLLMhyCUzhMci+nQj+0XkbXpNYtVjY= golang.org/x/net v0.0.0-20181023162649-9b4f9f5ad519/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -689,6 +703,13 @@ golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20190923162816-aa69164e4478/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= +golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= +golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= +golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= +golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0= golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw= golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs= @@ -698,6 +719,12 @@ golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= +golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -711,20 +738,46 @@ golang.org/x/sys v0.0.0-20190922100055-0a153f010e69/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190924154521-2837fb4f24fe/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200122134326-e047566fdf82/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210426230700-d19ff857e887/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= +golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= +golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU= +golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= +golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY= +golang.org/x/term v0.29.0/go.mod h1:6bl4lRlvVuDgSf3179VpIxBF0o10JUpXWOnI7nErv7s= golang.org/x/term v0.41.0 h1:QCgPso/Q3RTJx2Th4bDLqML4W6iJiaXFq2/ftQF13YU= golang.org/x/term v0.41.0/go.mod h1:3pfBgksrReYfZ5lvYM0kSO0LIkAl4Yl2bXOkKP7Ec2A= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY= golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8= golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA= golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= @@ -736,6 +789,10 @@ golang.org/x/tools v0.0.0-20200505023115-26f46d2f7ef8/go.mod h1:EkVYQZoAsY45+roY golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= +golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58= +golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= golang.org/x/tools v0.43.0 h1:12BdW9CeB3Z+J/I/wj34VMl8X+fEXBxVR90JeMX5E7s= golang.org/x/tools v0.43.0/go.mod h1:uHkMso649BX2cZK6+RpuIPXS3ho2hZo4FVwfoy1vIk0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/upup/pkg/fi/cloudup/azure/attest.go b/upup/pkg/fi/cloudup/azure/attest.go new file mode 100644 index 0000000000000..a0bcd6430a28d --- /dev/null +++ b/upup/pkg/fi/cloudup/azure/attest.go @@ -0,0 +1,578 @@ +/* +Copyright 2026 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package azure + +import ( + "bytes" + "crypto/sha256" + "crypto/x509" + "encoding/base64" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "path" + "strings" + "sync" + "time" + + "github.com/smallstep/pkcs7" + expirationcache "k8s.io/client-go/tools/cache" + "k8s.io/klog/v2" +) + +const ( + // attestedDocumentTimeFormat is the timestamp format used by the Azure IMDS attested document. + attestedDocumentTimeFormat = "01/02/06 15:04:05 -0700" + + // attestedDocumentNonceLength is the length of the hex-encoded SHA256 prefix + // used as the Azure IMDS attested document nonce. IMDS enforces a 32-character + // maximum for the nonce parameter; 32 hex chars is 128 bits of entropy, well + // above the cryptographic nonce floor. + attestedDocumentNonceLength = 32 + + azureMetadataDNSName = "metadata.azure.com" + azureMetadataSubdomainSuffix = ".metadata.azure.com" + + // microsoftIntermediateCertBaseURL is the only location we should consult + // when constructing Azure metadata intermediate certificate URLs. + microsoftIntermediateCertBaseURL = "https://www.microsoft.com/pkiops/certs" + + // intermediateCertRefreshInterval is how long intermediate CA certificates + // are cached before re-fetching from the Microsoft PKI repository. Microsoft + // rotates intermediate CAs infrequently (typically yearly), so 24 hours + // keeps the cache fresh without unnecessary network requests. + intermediateCertRefreshInterval = 24 * time.Hour + + // intermediateCertNegativeCacheInterval is how long a failed intermediate + // fetch result is cached. Short enough to recover from a transient Microsoft + // PKI blip, long enough to stop an attacker from using bogus AIA URLs to + // amplify network fetches via repeated verification attempts. + intermediateCertNegativeCacheInterval = 5 * time.Minute + + // intermediateCertMaxResponseBytes caps the body size accepted from an + // intermediate certificate fetch. Real Microsoft PKI intermediates are + // ~1.5 KB in DER; 16 KiB leaves ample headroom while preventing a + // pathological response from consuming memory. + intermediateCertMaxResponseBytes = 16 * 1024 +) + +var ( + // intermediateCertPositiveCache caches successful intermediate fetches. + // intermediateCertNegativeCache caches recent fetch failures for a shorter + // window, keyed the same way, so attackers cannot amplify fetches via bogus + // AIA URLs. Stores are read positive-first; transient overlap is harmless. + intermediateCertPositiveCache = expirationcache.NewTTLStore( + intermediateCertCacheEntryKeyFunc, intermediateCertRefreshInterval) + intermediateCertNegativeCache = expirationcache.NewTTLStore( + intermediateCertCacheEntryKeyFunc, intermediateCertNegativeCacheInterval) + + cachedSystemCertPool *x509.CertPool + cachedSystemMu sync.Mutex + + // Reuse one client for intermediate fetches so each lookup does not build a + // new transport stack. The allowlist lives in URL validation, not in the client. + intermediateCertHTTPClient = &http.Client{Timeout: 10 * time.Second} +) + +// intermediateCacheKey scopes the cache to the issuing CA identity rather than +// the leaf signer. Caching by issuer avoids refetching for every leaf, while +// RawIssuer and AuthorityKeyId together distinguish issuers that may share +// names across renewals or cross-signs. +type intermediateCacheKey struct { + rawIssuer string + authorityKeyID string +} + +// intermediateCertCacheEntry is the object stored in the TTLStore caches. A +// nil pool marks a negative entry (a cached fetch failure). +type intermediateCertCacheEntry struct { + key intermediateCacheKey + pool *x509.CertPool +} + +// intermediateCertCacheEntryKeyFunc is the TTLStore key function for cache +// entries. \x00 is used as a separator since neither field can contain it in +// valid x509 content. +func intermediateCertCacheEntryKeyFunc(obj any) (string, error) { + e, ok := obj.(*intermediateCertCacheEntry) + if !ok { + return "", fmt.Errorf("unexpected cache entry type %T", obj) + } + return e.key.rawIssuer + "\x00" + e.key.authorityKeyID, nil +} + +// attestedData is the JSON content inside the PKCS7 signed data from the IMDS attested document. +type attestedData struct { + VMId string `json:"vmId"` + SubscriptionId string `json:"subscriptionId"` + Nonce string `json:"nonce"` + TimeStamp attestedTimeStamp `json:"timeStamp"` +} + +// attestedTimeStamp represents the creation and expiration time of the attested document. +type attestedTimeStamp struct { + CreatedOn string `json:"createdOn"` + ExpiresOn string `json:"expiresOn"` +} + +// verifyAttestedDocument verifies a PKCS7 attested document using the system +// root certificate pool, first trying the embedded PKCS7 chain and only +// falling back to Microsoft PKI when intermediates are missing. +func verifyAttestedDocument(signature string, body []byte) (*attestedData, error) { + rootCertPool, err := systemCertPool() + if err != nil { + return nil, err + } + + return verifyAttestedDocumentWithRootAndFetcher(signature, body, rootCertPool, intermediateCertPoolForSigner) +} + +// verifyAttestedDocumentWithRootAndFetcher verifies a PKCS7 attested document +// using the supplied root pool and intermediate fetcher. +func verifyAttestedDocumentWithRootAndFetcher(signature string, body []byte, rootCertPool *x509.CertPool, fetchIntermediates func(*x509.Certificate) (*x509.CertPool, error)) (*attestedData, error) { + if rootCertPool == nil { + return nil, fmt.Errorf("root certificate pool is required") + } + if fetchIntermediates == nil { + return nil, fmt.Errorf("intermediate certificate fetch function is required") + } + + p7, signer, err := parseAndValidatePKCS7Signer(signature) + if err != nil { + return nil, err + } + + // Signer SAN was validated above; now it is safe to hit the network. + klog.V(2).Infof("Fetching intermediate certificates for signer issuer %q", signer.Issuer) + intermediateCerts, err := fetchIntermediates(signer) + if err != nil { + return nil, fmt.Errorf("fetching intermediate certificates: %w", err) + } + if err := verifySignerCertChain(signer, p7.Certificates, rootCertPool, intermediateCerts); err != nil { + return nil, fmt.Errorf("verifying PKCS7 certificate chain: %w", err) + } + klog.V(2).Infof("PKCS7 certificate chain verified for signer issuer %q", signer.Issuer) + + return parseAndValidateAttestedDocumentContent(p7.Content, body) +} + +// verifyAttestedDocumentWithCertPools verifies a base64-encoded PKCS7 attested +// document using caller-supplied root and intermediate certificate pools. +func verifyAttestedDocumentWithCertPools(signature string, body []byte, rootCertPool *x509.CertPool, intermediateCerts *x509.CertPool) (*attestedData, error) { + if rootCertPool == nil { + return nil, fmt.Errorf("root certificate pool is required") + } + if intermediateCerts == nil { + return nil, fmt.Errorf("intermediate certificate pool is required") + } + + p7, signer, err := parseAndValidatePKCS7Signer(signature) + if err != nil { + return nil, err + } + + if err := verifySignerCertChain(signer, p7.Certificates, rootCertPool, intermediateCerts); err != nil { + return nil, fmt.Errorf("verifying PKCS7 certificate chain: %w", err) + } + + return parseAndValidateAttestedDocumentContent(p7.Content, body) +} + +// parseAndValidatePKCS7Signer decodes and parses a base64-encoded PKCS7 +// signature, verifies its self-signature, and validates that the signer +// certificate's SAN identifies an Azure metadata endpoint. All checks here +// are CPU-only; no network I/O is performed, so this is safe to call before +// triggering intermediate certificate fetches. +func parseAndValidatePKCS7Signer(signature string) (*pkcs7.PKCS7, *x509.Certificate, error) { + if signature == "" { + return nil, nil, fmt.Errorf("empty PKCS7 signature") + } + + sigBytes, err := base64.StdEncoding.DecodeString(signature) + if err != nil { + return nil, nil, fmt.Errorf("decoding PKCS7 signature: %w", err) + } + klog.V(4).Infof("Decoded PKCS7 signature (%d bytes)", len(sigBytes)) + + p7, err := pkcs7.Parse(sigBytes) + if err != nil { + return nil, nil, fmt.Errorf("parsing PKCS7 signature: %w", err) + } + klog.V(8).Infof("Parsed PKCS7 structure with %d embedded certificate(s)", len(p7.Certificates)) + + // Verify the PKCS7 signature against the embedded leaf certificate. + if err := p7.Verify(); err != nil { + return nil, nil, fmt.Errorf("verifying PKCS7 signature: %w", err) + } + klog.V(4).Infof("PKCS7 self-signature verified") + + signer := p7.GetOnlySigner() + if signer == nil { + return nil, nil, fmt.Errorf("getting PKCS7 signer certificate") + } + klog.V(8).Infof("PKCS7 signer certificate: subject=%q issuer=%q SANs=%v", signer.Subject, signer.Issuer, signer.DNSNames) + if err := validateAzureMetadataSignerSAN(signer); err != nil { + return nil, nil, fmt.Errorf("validating PKCS7 signer SAN: %w", err) + } + klog.V(4).Infof("PKCS7 signer SAN validated as Azure metadata endpoint") + + return p7, signer, nil +} + +// nonceForBody derives the IMDS attestation nonce from the request body. +// Must be identical on the authenticator and verifier sides. +func nonceForBody(body []byte) string { + hash := sha256.Sum256(body) + return hex.EncodeToString(hash[:])[:attestedDocumentNonceLength] +} + +// parseAndValidateAttestedDocumentContent unmarshals the signed attestation +// payload and validates its nonce and expiration. +func parseAndValidateAttestedDocumentContent(content []byte, body []byte) (*attestedData, error) { + // Extract and validate the signed attested data. + var data attestedData + if err := json.Unmarshal(content, &data); err != nil { + return nil, fmt.Errorf("unmarshalling attested data: %w", err) + } + klog.V(4).Infof("Attested document content: vmId=%q subscriptionId=%q nonce=%q createdOn=%q expiresOn=%q", data.VMId, data.SubscriptionId, data.Nonce, data.TimeStamp.CreatedOn, data.TimeStamp.ExpiresOn) + + // Verify the nonce matches the request body hash (replay protection). + expectedNonce := nonceForBody(body) + if data.Nonce != expectedNonce { + return nil, fmt.Errorf("attested document nonce mismatch: got=%q expected=%q", data.Nonce, expectedNonce) + } + klog.V(4).Infof("Attested document nonce verified") + + // Verify the attested document has not expired. + if data.TimeStamp.ExpiresOn != "" { + expiresOn, err := time.Parse(attestedDocumentTimeFormat, data.TimeStamp.ExpiresOn) + if err != nil { + return nil, fmt.Errorf("parsing attested document expiration: %w", err) + } + if time.Now().After(expiresOn) { + return nil, fmt.Errorf("attested document expired at %s", data.TimeStamp.ExpiresOn) + } + klog.V(4).Infof("Attested document not expired (expiresOn=%s)", expiresOn.Format(time.RFC3339)) + } + + return &data, nil +} + +// systemCertPool returns a cached system root certificate pool, +// loading it on first call. +func systemCertPool() (*x509.CertPool, error) { + cachedSystemMu.Lock() + defer cachedSystemMu.Unlock() + + if cachedSystemCertPool != nil { + return cachedSystemCertPool, nil + } + + pool, err := x509.SystemCertPool() + if err != nil { + return nil, fmt.Errorf("loading system certificate pool: %w", err) + } + + cachedSystemCertPool = pool + return pool, nil +} + +// intermediateCertPoolForSigner returns intermediates for the signer's issuer +// using the package-level TTL caches, fetching from Microsoft PKI on miss. +func intermediateCertPoolForSigner(signer *x509.Certificate) (*x509.CertPool, error) { + return intermediateCertPoolWithCaches(signer, fetchIntermediateCerts, intermediateCertPositiveCache, intermediateCertNegativeCache) +} + +// intermediateCertPoolWithCaches performs a cached lookup against the supplied +// positive and negative TTL caches, invoking fetch on a miss. Tests inject +// their own stores and fetchers. +func intermediateCertPoolWithCaches(signer *x509.Certificate, fetch func(*x509.Certificate) (*x509.CertPool, error), positive, negative expirationcache.Store) (*x509.CertPool, error) { + if signer == nil { + return nil, fmt.Errorf("signer certificate is required") + } + if fetch == nil { + return nil, fmt.Errorf("intermediate certificate fetch function is required") + } + if positive == nil || negative == nil { + return nil, fmt.Errorf("intermediate certificate caches are required") + } + + entry := &intermediateCertCacheEntry{key: intermediateCacheKey{ + rawIssuer: string(signer.RawIssuer), + authorityKeyID: string(signer.AuthorityKeyId), + }} + keyStr, err := intermediateCertCacheEntryKeyFunc(entry) + if err != nil { + return nil, err + } + + // Positive cache wins over negative: a successful later fetch overwrites + // any stale negative entry, which expires on its own shorter TTL. + if obj, ok, _ := positive.GetByKey(keyStr); ok { + klog.V(4).Infof("Intermediate certificate cache hit (positive) for signer issuer %q", signer.Issuer) + return obj.(*intermediateCertCacheEntry).pool, nil + } + if _, ok, _ := negative.GetByKey(keyStr); ok { + return nil, fmt.Errorf("intermediate certificate fetch recently failed for signer issuer %q (cached)", signer.Issuer) + } + + klog.V(2).Infof("Intermediate certificate cache miss for signer issuer %q; fetching from Microsoft PKI", signer.Issuer) + pool, fetchErr := fetch(signer) + entry.pool = pool + if fetchErr != nil { + // List() walks every entry and lazily deletes expired ones; ListKeys() + // would not trigger expiration. Doing this before each write bounds + // cache memory to ~(write_rate × TTL) without a background goroutine, + // which matters most for the negative cache since an attacker rotating + // issuer keys can drive writes to it at the fetch rate. Cost is O(N) + // per write, so in attack conditions writes become slower as the cache + // grows, which also acts as a natural rate limit. For legitimate + // traffic (a handful of entries), this is effectively free. + _ = negative.List() + _ = negative.Add(entry) + return nil, fetchErr + } + // Positive cache writes are rare and bounded to the legitimate Microsoft + // PKI issuer set, so this List() is symmetric with the negative path but + // essentially free in practice. + _ = positive.List() + _ = positive.Add(entry) + return pool, nil +} + +// fetchIntermediateCerts fetches intermediate CA certificates from validated +// Microsoft PKI AIA URLs from the signer certificate. +func fetchIntermediateCerts(signer *x509.Certificate) (*x509.CertPool, error) { + return fetchIntermediateCertsFromBaseURL(intermediateCertHTTPClient, microsoftIntermediateCertBaseURL, signer) +} + +// fetchIntermediateCertsFromBaseURL collects validated Microsoft PKI AIA URLs +// for the signer, downloads each certificate, and keeps only those that match +// the signer's issuer identity. +func fetchIntermediateCertsFromBaseURL(client *http.Client, baseURL string, signer *x509.Certificate) (*x509.CertPool, error) { + if client == nil { + return nil, fmt.Errorf("HTTP client is required") + } + if signer == nil { + return nil, fmt.Errorf("signer certificate is required") + } + + urls, err := microsoftIntermediateCandidateURLs(baseURL, signer) + if err != nil { + return nil, err + } + + pool := x509.NewCertPool() + matched := 0 + for _, url := range urls { + klog.V(4).Infof("Fetching intermediate certificate from %s", url) + cert, err := fetchCertificate(client, url) + if err != nil { + return nil, err + } + if err := validateFetchedIntermediateForSigner(signer, cert); err != nil { + klog.V(4).Infof("Fetched certificate from %s did not match signer issuer: %v", url, err) + continue + } + klog.V(4).Infof("Fetched intermediate certificate from %s matched signer issuer (subject=%q)", url, cert.Subject) + pool.AddCert(cert) + matched++ + } + if matched == 0 { + return nil, fmt.Errorf("no fetched intermediate certificates matched signer issuer %q", signer.Issuer) + } + klog.V(2).Infof("Fetched %d intermediate certificate(s) matching signer issuer %q from %d candidate URL(s)", matched, signer.Issuer, len(urls)) + + return pool, nil +} + +// fetchCertificate fetches and parses a DER-encoded certificate from the given URL. +func fetchCertificate(client *http.Client, url string) (*x509.Certificate, error) { + if client == nil { + return nil, fmt.Errorf("HTTP client is required") + } + if url == "" { + return nil, fmt.Errorf("certificate URL is required") + } + + resp, err := client.Get(url) + if err != nil { + return nil, fmt.Errorf("fetching intermediate certificate from %s: %w", url, err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("fetching intermediate certificate from %s: status=%d", url, resp.StatusCode) + } + + // Cap the body read to reject pathologically large responses. Read one + // extra byte so we can distinguish "at the limit" from "exceeded limit". + body, err := io.ReadAll(io.LimitReader(resp.Body, intermediateCertMaxResponseBytes+1)) + if err != nil { + return nil, fmt.Errorf("reading intermediate certificate from %s: %w", url, err) + } + if len(body) > intermediateCertMaxResponseBytes { + return nil, fmt.Errorf("intermediate certificate from %s exceeds %d bytes", url, intermediateCertMaxResponseBytes) + } + + return x509.ParseCertificate(body) +} + +// validateFetchedIntermediateForSigner checks that a fetched intermediate is +// actually the issuer referenced by the signer certificate before it is used or cached. +func validateFetchedIntermediateForSigner(signer *x509.Certificate, cert *x509.Certificate) error { + if signer == nil { + return fmt.Errorf("signer certificate is required") + } + if cert == nil { + return fmt.Errorf("fetched certificate is required") + } + if !cert.IsCA { + return fmt.Errorf("fetched certificate is not a CA certificate") + } + if len(signer.RawIssuer) > 0 && !bytes.Equal(cert.RawSubject, signer.RawIssuer) { + return fmt.Errorf("fetched certificate subject does not match signer issuer") + } + if len(signer.AuthorityKeyId) > 0 && !bytes.Equal(cert.SubjectKeyId, signer.AuthorityKeyId) { + return fmt.Errorf("fetched certificate subject key identifier does not match signer authority key identifier") + } + + return nil +} + +// microsoftIntermediateCandidateURLs treats signer AIA values as untrusted +// input. It keeps only entries that stay within the configured Microsoft PKI +// host/path allowlist and normalizes them onto the configured scheme and host. +func microsoftIntermediateCandidateURLs(baseURL string, signer *x509.Certificate) ([]string, error) { + if signer == nil { + return nil, fmt.Errorf("signer certificate is required") + } + if baseURL == "" { + return nil, fmt.Errorf("base URL is required") + } + + base, err := url.Parse(baseURL) + if err != nil { + return nil, fmt.Errorf("parsing base URL: %w", err) + } + if !base.IsAbs() || base.Host == "" { + return nil, fmt.Errorf("base URL must be absolute") + } + + basePath := path.Clean(strings.TrimRight(base.Path, "/")) + if basePath == "." || basePath == "/" { + return nil, fmt.Errorf("base URL path is too broad") + } + + var urls []string + seen := make(map[string]struct{}) + for _, rawURL := range signer.IssuingCertificateURL { + normalized, ok := normalizeMicrosoftIntermediateURL(base, basePath, rawURL) + if !ok { + continue + } + if _, found := seen[normalized]; found { + continue + } + seen[normalized] = struct{}{} + urls = append(urls, normalized) + } + + if len(urls) == 0 { + return nil, fmt.Errorf("no valid Microsoft PKI AIA URLs found") + } + + return urls, nil +} + +// normalizeMicrosoftIntermediateURL copies only the allowed parts of a signer +// AIA URL onto the configured Microsoft PKI base URL. This keeps the path we +// need while ignoring attacker-controlled scheme, query, fragment, and userinfo. +func normalizeMicrosoftIntermediateURL(base *url.URL, basePath string, rawURL string) (string, bool) { + if base == nil { + return "", false + } + + candidate, err := url.Parse(rawURL) + if err != nil { + return "", false + } + if candidate.User != nil || candidate.RawQuery != "" || candidate.Fragment != "" { + return "", false + } + if candidate.Scheme != "http" && candidate.Scheme != "https" { + return "", false + } + if !strings.EqualFold(candidate.Hostname(), base.Hostname()) || candidate.Port() != base.Port() { + return "", false + } + + candidatePath := path.Clean(candidate.Path) + if candidatePath != basePath && !strings.HasPrefix(candidatePath, basePath+"/") { + return "", false + } + + return (&url.URL{ + Scheme: base.Scheme, + Host: base.Host, + Path: candidatePath, + }).String(), true +} + +// Azure guidance requires the metadata signer certificate to identify +// metadata.azure.com or a regional *.metadata.azure.com name in its DNS SANs. +func validateAzureMetadataSignerSAN(signer *x509.Certificate) error { + if signer == nil { + return fmt.Errorf("signer certificate is required") + } + + for _, dnsName := range signer.DNSNames { + if dnsName == azureMetadataDNSName || strings.HasSuffix(dnsName, azureMetadataSubdomainSuffix) { + return nil + } + } + + return fmt.Errorf("signer certificate SAN does not match Azure metadata domains") +} + +// verifySignerCertChain verifies that the signer certificate chains to a trusted root CA. +func verifySignerCertChain(signer *x509.Certificate, pkcs7Certs []*x509.Certificate, rootCertPool *x509.CertPool, intermediateCerts *x509.CertPool) error { + if signer == nil { + return fmt.Errorf("signer certificate is required") + } + if rootCertPool == nil { + return fmt.Errorf("root certificate pool is required") + } + if intermediateCerts == nil { + return fmt.Errorf("intermediate certificate pool is required") + } + + intermediates := intermediateCerts.Clone() + for _, cert := range pkcs7Certs { + intermediates.AddCert(cert) + } + + _, err := signer.Verify(x509.VerifyOptions{ + Roots: rootCertPool, + Intermediates: intermediates, + KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageAny}, + }) + return err +} diff --git a/upup/pkg/fi/cloudup/azure/attest_test.go b/upup/pkg/fi/cloudup/azure/attest_test.go new file mode 100644 index 0000000000000..43005a8450cad --- /dev/null +++ b/upup/pkg/fi/cloudup/azure/attest_test.go @@ -0,0 +1,926 @@ +/* +Copyright 2026 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package azure + +import ( + "bytes" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "math/big" + "net/http" + "slices" + "testing" + "time" + + "github.com/smallstep/pkcs7" + expirationcache "k8s.io/client-go/tools/cache" +) + +// roundTripperFunc adapts a function into an http.RoundTripper for tests. +type roundTripperFunc func(*http.Request) (*http.Response, error) + +// RoundTrip implements http.RoundTripper. +func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} + +// newTestIntermediateCertCaches returns fresh positive and negative TTL +// caches so tests run independently from the package-level caches. +func newTestIntermediateCertCaches(positiveTTL, negativeTTL time.Duration) (expirationcache.Store, expirationcache.Store) { + return expirationcache.NewTTLStore(intermediateCertCacheEntryKeyFunc, positiveTTL), + expirationcache.NewTTLStore(intermediateCertCacheEntryKeyFunc, negativeTTL) +} + +// testPKI generates a CA certificate, a leaf certificate, and their keys for test PKCS7 signing. +func testPKI(tb testing.TB) (*x509.Certificate, *rsa.PrivateKey, *x509.Certificate, *rsa.PrivateKey) { + tb.Helper() + + caKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + tb.Fatalf("generating CA key: %v", err) + } + caTemplate := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{CommonName: "Test Root CA"}, + SubjectKeyId: []byte("test-ca-ski"), + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(24 * time.Hour), + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, + BasicConstraintsValid: true, + IsCA: true, + } + caCertDER, err := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &caKey.PublicKey, caKey) + if err != nil { + tb.Fatalf("creating CA cert: %v", err) + } + caCert, err := x509.ParseCertificate(caCertDER) + if err != nil { + tb.Fatalf("parsing CA cert: %v", err) + } + + leafKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + tb.Fatalf("generating leaf key: %v", err) + } + leafTemplate := &x509.Certificate{ + SerialNumber: big.NewInt(2), + Subject: pkix.Name{CommonName: "Test Signer"}, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(24 * time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature, + DNSNames: []string{"metadata.azure.com"}, + } + leafCertDER, err := x509.CreateCertificate(rand.Reader, leafTemplate, caCert, &leafKey.PublicKey, caKey) + if err != nil { + tb.Fatalf("creating leaf cert: %v", err) + } + leafCert, err := x509.ParseCertificate(leafCertDER) + if err != nil { + tb.Fatalf("parsing leaf cert: %v", err) + } + + return caCert, caKey, leafCert, leafKey +} + +// testPKIChain generates a root CA, an intermediate CA, and a leaf signer for +// tests that need to distinguish embedded-chain verification from fetch fallback. +func testPKIChain(tb testing.TB) (*x509.Certificate, *x509.Certificate, *x509.Certificate, *rsa.PrivateKey) { + tb.Helper() + + rootKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + tb.Fatalf("generating root key: %v", err) + } + rootTemplate := &x509.Certificate{ + SerialNumber: big.NewInt(10), + Subject: pkix.Name{CommonName: "Test Root CA"}, + SubjectKeyId: []byte("test-root-ski"), + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(24 * time.Hour), + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, + BasicConstraintsValid: true, + IsCA: true, + } + rootCertDER, err := x509.CreateCertificate(rand.Reader, rootTemplate, rootTemplate, &rootKey.PublicKey, rootKey) + if err != nil { + tb.Fatalf("creating root cert: %v", err) + } + rootCert, err := x509.ParseCertificate(rootCertDER) + if err != nil { + tb.Fatalf("parsing root cert: %v", err) + } + + intermediateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + tb.Fatalf("generating intermediate key: %v", err) + } + intermediateTemplate := &x509.Certificate{ + SerialNumber: big.NewInt(11), + Subject: pkix.Name{CommonName: "Test Intermediate CA"}, + SubjectKeyId: []byte("test-intermediate-ski"), + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(24 * time.Hour), + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, + BasicConstraintsValid: true, + IsCA: true, + } + intermediateCertDER, err := x509.CreateCertificate(rand.Reader, intermediateTemplate, rootCert, &intermediateKey.PublicKey, rootKey) + if err != nil { + tb.Fatalf("creating intermediate cert: %v", err) + } + intermediateCert, err := x509.ParseCertificate(intermediateCertDER) + if err != nil { + tb.Fatalf("parsing intermediate cert: %v", err) + } + + leafKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + tb.Fatalf("generating leaf key: %v", err) + } + leafTemplate := &x509.Certificate{ + SerialNumber: big.NewInt(12), + Subject: pkix.Name{CommonName: "Test Signer"}, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(24 * time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature, + DNSNames: []string{"metadata.azure.com"}, + } + leafCertDER, err := x509.CreateCertificate(rand.Reader, leafTemplate, intermediateCert, &leafKey.PublicKey, intermediateKey) + if err != nil { + tb.Fatalf("creating leaf cert: %v", err) + } + leafCert, err := x509.ParseCertificate(leafCertDER) + if err != nil { + tb.Fatalf("parsing leaf cert: %v", err) + } + + return rootCert, intermediateCert, leafCert, leafKey +} + +// createTestPKCS7 creates a PKCS7 SignedData containing the given content, signed by the leaf cert. +func createTestPKCS7(tb testing.TB, content []byte, leafCert *x509.Certificate, leafKey *rsa.PrivateKey, parents ...*x509.Certificate) []byte { + tb.Helper() + + sd, err := pkcs7.NewSignedData(content) + if err != nil { + tb.Fatalf("creating signed data: %v", err) + } + if len(parents) == 0 { + if err := sd.AddSigner(leafCert, leafKey, pkcs7.SignerInfoConfig{}); err != nil { + tb.Fatalf("adding signer: %v", err) + } + } else { + if err := sd.AddSignerChain(leafCert, leafKey, parents, pkcs7.SignerInfoConfig{}); err != nil { + tb.Fatalf("adding signer chain: %v", err) + } + } + derBytes, err := sd.Finish() + if err != nil { + tb.Fatalf("finishing signed data: %v", err) + } + return derBytes +} + +// testSignature signs attested data and returns the PKCS7 as base64 text. +func testSignature(tb testing.TB, data attestedData, leafCert *x509.Certificate, leafKey *rsa.PrivateKey, caCert *x509.Certificate) string { + tb.Helper() + content, _ := json.Marshal(data) + pkcs7DER := createTestPKCS7(tb, content, leafCert, leafKey, caCert) + return base64.StdEncoding.EncodeToString(pkcs7DER) +} + +// TestMicrosoftIntermediateCandidateURLs verifies that only normalized +// Microsoft PKI AIA URLs are returned. +func TestMicrosoftIntermediateCandidateURLs(t *testing.T) { + signer := &x509.Certificate{ + IssuingCertificateURL: []string{ + "http://www.microsoft.com/pkiops/certs/Completely%20New%20Azure%20Metadata%20Issuing%20CA%2042%20-%20xsign.crt", + "https://www.microsoft.com/pkiops/certs/Completely%20New%20Azure%20Metadata%20Issuing%20CA%2042.crt?ignored=1", + "https://www.microsoft.com.evil.test/pkiops/certs/not-allowed.crt", + }, + } + + got, err := microsoftIntermediateCandidateURLs(microsoftIntermediateCertBaseURL, signer) + if err != nil { + t.Fatalf("collecting candidate URLs: %v", err) + } + + want := []string{ + "https://www.microsoft.com/pkiops/certs/Completely%20New%20Azure%20Metadata%20Issuing%20CA%2042%20-%20xsign.crt", + } + if !slices.Equal(got, want) { + t.Fatalf("candidate URLs mismatch: got %v, want %v", got, want) + } +} + +// TestMicrosoftIntermediateCandidateURLs_RejectsNonMicrosoftPKIURLs verifies +// that out-of-scope AIA URLs are rejected. +func TestMicrosoftIntermediateCandidateURLs_RejectsNonMicrosoftPKIURLs(t *testing.T) { + signer := &x509.Certificate{ + IssuingCertificateURL: []string{ + "https://www.microsoft.com.evil.test/pkiops/certs/not-allowed.crt", + "https://www.microsoft.com@evil.test/pkiops/certs/not-allowed.crt", + "https://www.microsoft.com/pkiops/other/not-allowed.crt", + }, + } + + _, err := microsoftIntermediateCandidateURLs(microsoftIntermediateCertBaseURL, signer) + if err == nil { + t.Fatal("expected error for non-Microsoft PKI AIA URLs") + } +} + +// TestFetchIntermediateCertsFromBaseURL_UsesValidatedSignerAIA verifies that +// fetching uses only normalized, validated AIA URLs. +func TestFetchIntermediateCertsFromBaseURL_UsesValidatedSignerAIA(t *testing.T) { + caCert, _, _, _ := testPKI(t) + signer := &x509.Certificate{ + RawIssuer: caCert.RawSubject, + AuthorityKeyId: caCert.SubjectKeyId, + IssuingCertificateURL: []string{ + "http://example.test/pkiops/certs/allowed.crt", + "http://127.0.0.1:1/not-used.crt", + }, + } + + var gotURLs []string + client := &http.Client{ + Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + gotURLs = append(gotURLs, req.URL.String()) + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewReader(caCert.Raw)), + Header: make(http.Header), + }, nil + }), + } + + baseURL := "https://example.test/pkiops/certs" + pool, err := fetchIntermediateCertsFromBaseURL(client, baseURL, signer) + if err != nil { + t.Fatalf("fetching intermediate certificates: %v", err) + } + if pool == nil { + t.Fatal("expected intermediate certificate pool") + } + + wantURLs := []string{ + "https://example.test/pkiops/certs/allowed.crt", + } + if !slices.Equal(gotURLs, wantURLs) { + t.Fatalf("requested URLs mismatch: got %v, want %v", gotURLs, wantURLs) + } +} + +// TestFetchIntermediateCertsFromBaseURL_RejectsNonMatchingIntermediate verifies +// that fetched certificates are ignored unless they match the signer's issuer identity. +func TestFetchIntermediateCertsFromBaseURL_RejectsNonMatchingIntermediate(t *testing.T) { + expectedIssuer, _, _, _ := testPKI(t) + _, otherIssuer, _, _ := testPKIChain(t) + + signer := &x509.Certificate{ + RawIssuer: expectedIssuer.RawSubject, + AuthorityKeyId: expectedIssuer.SubjectKeyId, + IssuingCertificateURL: []string{ + "https://example.test/pkiops/certs/not-the-issuer.crt", + }, + } + + client := &http.Client{ + Transport: roundTripperFunc(func(*http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewReader(otherIssuer.Raw)), + Header: make(http.Header), + }, nil + }), + } + + _, err := fetchIntermediateCertsFromBaseURL(client, "https://example.test/pkiops/certs", signer) + if err == nil { + t.Fatal("expected error for non-matching fetched intermediate") + } +} + +// TestFetchIntermediateCertsFromBaseURL_KeepsOnlyMatchingIntermediate verifies +// that when multiple AIA URLs return different certs, only the one that matches +// the signer's issuer identity ends up in the pool. +func TestFetchIntermediateCertsFromBaseURL_KeepsOnlyMatchingIntermediate(t *testing.T) { + matchingCA, _, leafCert, _ := testPKI(t) + _, otherIssuer, _, _ := testPKIChain(t) + + signer := &x509.Certificate{ + RawIssuer: matchingCA.RawSubject, + AuthorityKeyId: matchingCA.SubjectKeyId, + IssuingCertificateURL: []string{ + "https://example.test/pkiops/certs/first.crt", + "https://example.test/pkiops/certs/second.crt", + }, + } + + calls := 0 + client := &http.Client{ + Transport: roundTripperFunc(func(*http.Request) (*http.Response, error) { + calls++ + var body []byte + if calls == 1 { + body = otherIssuer.Raw // first URL returns wrong cert + } else { + body = matchingCA.Raw // second URL returns the right cert + } + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewReader(body)), + Header: make(http.Header), + }, nil + }), + } + + pool, err := fetchIntermediateCertsFromBaseURL(client, "https://example.test/pkiops/certs", signer) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if calls != 2 { + t.Fatalf("expected both URLs to be fetched, got %d calls", calls) + } + // Pool should contain matchingCA (not otherIssuer). Verify by using the pool + // as roots and verifying the leaf signed by matchingCA — succeeds iff the + // matching CA is in the pool. + if _, err := leafCert.Verify(x509.VerifyOptions{ + Roots: pool, + KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageAny}, + }); err != nil { + t.Errorf("expected pool to contain matching CA: %v", err) + } +} + +// TestValidateFetchedIntermediateForSigner exercises each per-branch +// reject/accept rule for the fetched intermediate identity check. +func TestValidateFetchedIntermediateForSigner(t *testing.T) { + caCert, _, _, _ := testPKI(t) + + signerMatching := &x509.Certificate{ + RawIssuer: caCert.RawSubject, + AuthorityKeyId: caCert.SubjectKeyId, + } + + t.Run("accepts matching CA", func(t *testing.T) { + if err := validateFetchedIntermediateForSigner(signerMatching, caCert); err != nil { + t.Fatalf("unexpected rejection: %v", err) + } + }) + + t.Run("rejects non-CA cert", func(t *testing.T) { + nonCA := &x509.Certificate{ + RawSubject: caCert.RawSubject, + SubjectKeyId: caCert.SubjectKeyId, + IsCA: false, + } + if err := validateFetchedIntermediateForSigner(signerMatching, nonCA); err == nil { + t.Fatal("expected rejection of non-CA certificate") + } + }) + + t.Run("rejects subject mismatch", func(t *testing.T) { + signer := &x509.Certificate{ + RawIssuer: []byte("different-issuer"), + AuthorityKeyId: caCert.SubjectKeyId, + } + if err := validateFetchedIntermediateForSigner(signer, caCert); err == nil { + t.Fatal("expected rejection when cert subject does not match signer issuer") + } + }) + + t.Run("rejects SKI/AKI mismatch", func(t *testing.T) { + signer := &x509.Certificate{ + RawIssuer: caCert.RawSubject, + AuthorityKeyId: []byte("different-akid"), + } + if err := validateFetchedIntermediateForSigner(signer, caCert); err == nil { + t.Fatal("expected rejection when cert SKI does not match signer AKI") + } + }) + + t.Run("rejects nil signer", func(t *testing.T) { + if err := validateFetchedIntermediateForSigner(nil, caCert); err == nil { + t.Fatal("expected error for nil signer") + } + }) + + t.Run("rejects nil cert", func(t *testing.T) { + if err := validateFetchedIntermediateForSigner(signerMatching, nil); err == nil { + t.Fatal("expected error for nil cert") + } + }) +} + +// TestFetchCertificate_RejectsBadResponse exercises the rejection paths of +// fetchCertificate: oversized body, non-200 status, and unparseable DER. +func TestFetchCertificate_RejectsBadResponse(t *testing.T) { + testCases := []struct { + name string + status int + body []byte + }{ + {"oversized body", http.StatusOK, bytes.Repeat([]byte{0x30}, intermediateCertMaxResponseBytes+1)}, + {"non-200 status", http.StatusNotFound, nil}, + {"invalid DER", http.StatusOK, []byte("not a certificate")}, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + client := &http.Client{ + Transport: roundTripperFunc(func(*http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: tc.status, + Body: io.NopCloser(bytes.NewReader(tc.body)), + Header: make(http.Header), + }, nil + }), + } + if _, err := fetchCertificate(client, "https://example.test/cert.crt"); err == nil { + t.Fatal("expected error") + } + }) + } +} + +// TestIntermediateCertPoolWithCaches_ReusesSameSignerCacheEntry verifies +// cache reuse for repeated lookups of the same signer issuer. +func TestIntermediateCertPoolWithCaches_ReusesSameSignerCacheEntry(t *testing.T) { + positive, negative := newTestIntermediateCertCaches(24*time.Hour, 5*time.Minute) + + signer := &x509.Certificate{ + RawIssuer: []byte("issuer-1"), + AuthorityKeyId: []byte("akid-1"), + } + + expectedPool := x509.NewCertPool() + fetchCalls := 0 + fetch := func(gotSigner *x509.Certificate) (*x509.CertPool, error) { + fetchCalls++ + if gotSigner != signer { + t.Fatalf("unexpected signer: got %p, want %p", gotSigner, signer) + } + return expectedPool, nil + } + + first, err := intermediateCertPoolWithCaches(signer, fetch, positive, negative) + if err != nil { + t.Fatalf("first cache fetch: %v", err) + } + second, err := intermediateCertPoolWithCaches(signer, fetch, positive, negative) + if err != nil { + t.Fatalf("second cache fetch: %v", err) + } + + if fetchCalls != 1 { + t.Fatalf("fetch call count mismatch: got %d, want 1", fetchCalls) + } + if first != expectedPool || second != expectedPool { + t.Fatal("expected cached pool to be reused for the same signer") + } +} + +// TestIntermediateCertPoolWithCaches_SeparatesDifferentSignerCacheEntries +// verifies that different issuer keys do not share one cache entry. +func TestIntermediateCertPoolWithCaches_SeparatesDifferentSignerCacheEntries(t *testing.T) { + positive, negative := newTestIntermediateCertCaches(24*time.Hour, 5*time.Minute) + + signer1 := &x509.Certificate{ + RawIssuer: []byte("issuer-1"), + AuthorityKeyId: []byte("akid-1"), + } + signer2 := &x509.Certificate{ + RawIssuer: []byte("issuer-1"), + AuthorityKeyId: []byte("akid-2"), + } + + pool1 := x509.NewCertPool() + pool2 := x509.NewCertPool() + fetchCalls := 0 + fetch := func(gotSigner *x509.Certificate) (*x509.CertPool, error) { + fetchCalls++ + switch gotSigner { + case signer1: + return pool1, nil + case signer2: + return pool2, nil + default: + t.Fatalf("unexpected signer: %p", gotSigner) + return nil, nil + } + } + + first, err := intermediateCertPoolWithCaches(signer1, fetch, positive, negative) + if err != nil { + t.Fatalf("first signer cache fetch: %v", err) + } + second, err := intermediateCertPoolWithCaches(signer2, fetch, positive, negative) + if err != nil { + t.Fatalf("second signer cache fetch: %v", err) + } + firstAgain, err := intermediateCertPoolWithCaches(signer1, fetch, positive, negative) + if err != nil { + t.Fatalf("first signer cache reuse: %v", err) + } + + if fetchCalls != 2 { + t.Fatalf("fetch call count mismatch: got %d, want 2", fetchCalls) + } + if first != pool1 || firstAgain != pool1 { + t.Fatal("expected signer1 to reuse its own cached pool") + } + if second != pool2 { + t.Fatal("expected signer2 to use a distinct cached pool") + } +} + +// TestIntermediateCertPoolWithCaches_CachesFetchFailure verifies that a +// failed fetch is remembered so repeated attempts for the same issuer cannot +// amplify into repeated network calls. +func TestIntermediateCertPoolWithCaches_CachesFetchFailure(t *testing.T) { + positive, negative := newTestIntermediateCertCaches(24*time.Hour, 5*time.Minute) + + signer := &x509.Certificate{ + RawIssuer: []byte("issuer-1"), + AuthorityKeyId: []byte("akid-1"), + } + + fetchCalls := 0 + fetch := func(*x509.Certificate) (*x509.CertPool, error) { + fetchCalls++ + return nil, fmt.Errorf("simulated fetch failure") + } + + if _, err := intermediateCertPoolWithCaches(signer, fetch, positive, negative); err == nil { + t.Fatal("expected error from first fetch") + } + if _, err := intermediateCertPoolWithCaches(signer, fetch, positive, negative); err == nil { + t.Fatal("expected error from cached negative entry") + } + + if fetchCalls != 1 { + t.Fatalf("fetch call count mismatch: got %d, want 1 (negative cache should absorb the second call)", fetchCalls) + } +} + +// TestIntermediateCertPoolWithCaches_NegativeCacheExpires verifies that +// once the negative cache TTL passes, the fetcher is invoked again. +func TestIntermediateCertPoolWithCaches_NegativeCacheExpires(t *testing.T) { + // Short negative TTL so the test doesn't have to sleep long. + positive, negative := newTestIntermediateCertCaches(24*time.Hour, 50*time.Millisecond) + + signer := &x509.Certificate{ + RawIssuer: []byte("issuer-1"), + AuthorityKeyId: []byte("akid-1"), + } + + fetchCalls := 0 + fetch := func(*x509.Certificate) (*x509.CertPool, error) { + fetchCalls++ + return nil, fmt.Errorf("simulated fetch failure") + } + + if _, err := intermediateCertPoolWithCaches(signer, fetch, positive, negative); err == nil { + t.Fatal("expected error from first fetch") + } + + time.Sleep(100 * time.Millisecond) + + if _, err := intermediateCertPoolWithCaches(signer, fetch, positive, negative); err == nil { + t.Fatal("expected error from re-fetch after negative TTL") + } + + if fetchCalls != 2 { + t.Fatalf("fetch call count mismatch: got %d, want 2 (negative cache should have expired)", fetchCalls) + } +} + +// TestIntermediateCertPoolWithCaches_PositiveOverridesStaleNegative verifies +// that a positive cache entry wins over a still-live negative entry for the +// same issuer, per the positive-first lookup order documented in the code. +func TestIntermediateCertPoolWithCaches_PositiveOverridesStaleNegative(t *testing.T) { + positive, negative := newTestIntermediateCertCaches(24*time.Hour, 5*time.Minute) + + signer := &x509.Certificate{ + RawIssuer: []byte("issuer-1"), + AuthorityKeyId: []byte("akid-1"), + } + + expectedPool := x509.NewCertPool() + callCount := 0 + fetch := func(*x509.Certificate) (*x509.CertPool, error) { + callCount++ + if callCount == 1 { + return nil, fmt.Errorf("simulated fetch failure") + } + return expectedPool, nil + } + + // First call fails and is cached negatively. + if _, err := intermediateCertPoolWithCaches(signer, fetch, positive, negative); err == nil { + t.Fatal("expected error from first fetch") + } + + // Inject a positive entry under the same key, simulating a later success + // that bypassed the negative-cache check (e.g., via a second TTLStore + // instance or concurrent write). + entry := &intermediateCertCacheEntry{ + key: intermediateCacheKey{rawIssuer: "issuer-1", authorityKeyID: "akid-1"}, + pool: expectedPool, + } + if err := positive.Add(entry); err != nil { + t.Fatalf("adding positive entry: %v", err) + } + + // Next lookup must return the positive pool and must not call fetch. + got, err := intermediateCertPoolWithCaches(signer, fetch, positive, negative) + if err != nil { + t.Fatalf("expected cache hit, got error: %v", err) + } + if got != expectedPool { + t.Fatal("expected positive cached pool to be returned over the negative entry") + } + if callCount != 1 { + t.Fatalf("fetch call count mismatch: got %d, want 1 (positive hit should not re-fetch)", callCount) + } +} + +// TestValidateAzureMetadataSignerSAN verifies that signer certificates are +// accepted iff their SAN identifies an Azure metadata endpoint. +func TestValidateAzureMetadataSignerSAN(t *testing.T) { + testCases := []struct { + name string + dnsNames []string + wantErr bool + }{ + {"exact metadata.azure.com", []string{"metadata.azure.com"}, false}, + {"regional subdomain", []string{"northeurope.metadata.azure.com"}, false}, + {"mixed with other names", []string{"other.example.com", "metadata.azure.com"}, false}, + {"non-Azure metadata cert", []string{"metadata.example.com"}, true}, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + signer := &x509.Certificate{DNSNames: tc.dnsNames} + err := validateAzureMetadataSignerSAN(signer) + if tc.wantErr { + if err == nil { + t.Fatal("expected SAN validation error") + } + return + } + if err != nil { + t.Errorf("expected SAN validation to pass, got: %v", err) + } + }) + } +} + +// TestVerifyAttestedDocumentWithCertPools exercises chain-trust, replay +// protection, expiration, and the happy path in one table. +func TestVerifyAttestedDocumentWithCertPools(t *testing.T) { + caCert, _, leafCert, leafKey := testPKI(t) + body := []byte("test-body") + + trustedPool := x509.NewCertPool() + trustedPool.AddCert(caCert) + + validTimestamps := attestedTimeStamp{ + CreatedOn: time.Now().UTC().Format(attestedDocumentTimeFormat), + ExpiresOn: time.Now().Add(time.Hour).UTC().Format(attestedDocumentTimeFormat), + } + expiredTimestamps := attestedTimeStamp{ + CreatedOn: time.Now().Add(-2 * time.Hour).UTC().Format(attestedDocumentTimeFormat), + ExpiresOn: time.Now().Add(-1 * time.Hour).UTC().Format(attestedDocumentTimeFormat), + } + + testCases := []struct { + name string + data attestedData + trustStore *x509.CertPool + wantErr bool + wantVMId string + }{ + { + name: "untrusted signature", + data: attestedData{VMId: "test-vm-id", Nonce: nonceForBody(body), TimeStamp: validTimestamps}, + trustStore: x509.NewCertPool(), // empty — test CA is not trusted + wantErr: true, + }, + { + name: "nonce mismatch", + data: attestedData{VMId: "test-vm-id", Nonce: "wrong-nonce", TimeStamp: validTimestamps}, + trustStore: trustedPool, + wantErr: true, + }, + { + name: "expired document", + data: attestedData{VMId: "test-vm-id", Nonce: nonceForBody(body), TimeStamp: expiredTimestamps}, + trustStore: trustedPool, + wantErr: true, + }, + { + name: "success", + data: attestedData{ + VMId: "02aab8a4-74ef-476e-8182-f6d2ba4166a6", + SubscriptionId: "8d10da13-8125-4ba9-a717-bf7490507b3d", + Nonce: nonceForBody(body), + TimeStamp: validTimestamps, + }, + trustStore: trustedPool, + wantVMId: "02aab8a4-74ef-476e-8182-f6d2ba4166a6", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + sig := testSignature(t, tc.data, leafCert, leafKey, caCert) + result, err := verifyAttestedDocumentWithCertPools(sig, body, tc.trustStore, x509.NewCertPool()) + if tc.wantErr { + if err == nil { + t.Fatal("expected error") + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if tc.wantVMId != "" && result.VMId != tc.wantVMId { + t.Errorf("vmId: got %q, want %q", result.VMId, tc.wantVMId) + } + }) + } +} + +// TestVerifyAttestedDocumentWithRootAndFetcher_SkipsFetchWhenEmbeddedChainSuffices +// verifies that the verifier accepts an attestation using only embedded PKCS7 +// certificates and does not consult Microsoft PKI in that case. +func TestVerifyAttestedDocumentWithRootAndFetcher_AlwaysFetchesIntermediates(t *testing.T) { + rootCert, intermediateCert, leafCert, leafKey := testPKIChain(t) + + body := []byte("test-body") + content, err := json.Marshal(attestedData{ + VMId: "test-vm-id", + SubscriptionId: "test-subscription", + Nonce: nonceForBody(body), + TimeStamp: attestedTimeStamp{ + CreatedOn: time.Now().UTC().Format(attestedDocumentTimeFormat), + ExpiresOn: time.Now().Add(time.Hour).UTC().Format(attestedDocumentTimeFormat), + }, + }) + if err != nil { + t.Fatalf("marshalling attested data: %v", err) + } + // Include the intermediate in the PKCS7 to confirm the fetcher is still called. + sig := base64.StdEncoding.EncodeToString(createTestPKCS7(t, content, leafCert, leafKey, intermediateCert)) + + rootPool := x509.NewCertPool() + rootPool.AddCert(rootCert) + + fetchCalls := 0 + result, err := verifyAttestedDocumentWithRootAndFetcher(sig, body, rootPool, func(*x509.Certificate) (*x509.CertPool, error) { + fetchCalls++ + pool := x509.NewCertPool() + pool.AddCert(intermediateCert) + return pool, nil + }) + if err != nil { + t.Fatalf("unexpected verification error: %v", err) + } + if result.VMId != "test-vm-id" { + t.Fatalf("expected vmId %q, got %q", "test-vm-id", result.VMId) + } + if fetchCalls != 1 { + t.Fatalf("fetch call count mismatch: got %d, want 1", fetchCalls) + } +} + +// TestVerifyAttestedDocumentWithRootAndFetcher_FetchesMissingIntermediate verifies +// that the verifier falls back to the fetcher only when the embedded PKCS7 +// chain does not already contain the needed intermediate. +func TestVerifyAttestedDocumentWithRootAndFetcher_FetchesMissingIntermediate(t *testing.T) { + rootCert, intermediateCert, leafCert, leafKey := testPKIChain(t) + + body := []byte("test-body") + content, err := json.Marshal(attestedData{ + VMId: "test-vm-id", + SubscriptionId: "test-subscription", + Nonce: nonceForBody(body), + TimeStamp: attestedTimeStamp{ + CreatedOn: time.Now().UTC().Format(attestedDocumentTimeFormat), + ExpiresOn: time.Now().Add(time.Hour).UTC().Format(attestedDocumentTimeFormat), + }, + }) + if err != nil { + t.Fatalf("marshalling attested data: %v", err) + } + sig := base64.StdEncoding.EncodeToString(createTestPKCS7(t, content, leafCert, leafKey)) + + rootPool := x509.NewCertPool() + rootPool.AddCert(rootCert) + + fetchCalls := 0 + result, err := verifyAttestedDocumentWithRootAndFetcher(sig, body, rootPool, func(*x509.Certificate) (*x509.CertPool, error) { + fetchCalls++ + pool := x509.NewCertPool() + pool.AddCert(intermediateCert) + return pool, nil + }) + if err != nil { + t.Fatalf("unexpected verification error: %v", err) + } + if result.VMId != "test-vm-id" { + t.Fatalf("expected vmId %q, got %q", "test-vm-id", result.VMId) + } + if fetchCalls != 1 { + t.Fatalf("fetch call count mismatch: got %d, want 1", fetchCalls) + } +} + +// TestVerifyAttestedDocumentWithRootAndFetcher_PropagatesFetchError verifies +// that when the fetcher fails, the error is propagated to the caller. +func TestVerifyAttestedDocumentWithRootAndFetcher_PropagatesFetchError(t *testing.T) { + rootCert, _, leafCert, leafKey := testPKIChain(t) + + body := []byte("test-body") + content, err := json.Marshal(attestedData{ + VMId: "test-vm-id", + Nonce: nonceForBody(body), + TimeStamp: attestedTimeStamp{ + CreatedOn: time.Now().UTC().Format(attestedDocumentTimeFormat), + ExpiresOn: time.Now().Add(time.Hour).UTC().Format(attestedDocumentTimeFormat), + }, + }) + if err != nil { + t.Fatalf("marshalling attested data: %v", err) + } + sig := base64.StdEncoding.EncodeToString(createTestPKCS7(t, content, leafCert, leafKey)) + + rootPool := x509.NewCertPool() + rootPool.AddCert(rootCert) + + fetchErr := fmt.Errorf("simulated fetch failure") + _, err = verifyAttestedDocumentWithRootAndFetcher(sig, body, rootPool, func(*x509.Certificate) (*x509.CertPool, error) { + return nil, fetchErr + }) + if err == nil { + t.Fatal("expected error when fetch fails") + } + if !errors.Is(err, fetchErr) { + t.Errorf("expected wrapped fetchErr in error tree, got: %v", err) + } +} + +// TestParseAndValidateAttestedDocumentContent_AcceptsMicrosoftFormat verifies +// that the verifier decodes the JSON field names and timestamp format that +// Azure IMDS actually emits — including the -0000 timezone suffix that Go's +// default time formatter does not produce. This locks in compatibility with +// real Microsoft IMDS responses without depending on a signed fixture whose +// certificates expire. +func TestParseAndValidateAttestedDocumentContent_AcceptsMicrosoftFormat(t *testing.T) { + body := []byte("test-body") + content := fmt.Sprintf(`{ + "vmId": "3ceb0a9e-ff74-4e17-924a-f2acd3b31310", + "subscriptionId": "46678f10-4bbb-447e-98e8-d2829589f2d8", + "nonce": %q, + "timeStamp": { + "createdOn": "01/01/99 00:00:00 -0000", + "expiresOn": "01/01/50 00:00:00 -0000" + } + }`, nonceForBody(body)) + + data, err := parseAndValidateAttestedDocumentContent([]byte(content), body) + if err != nil { + t.Fatalf("parsing Microsoft-format content: %v", err) + } + if data.VMId != "3ceb0a9e-ff74-4e17-924a-f2acd3b31310" { + t.Errorf("vmId decode mismatch: got %q", data.VMId) + } + if data.SubscriptionId != "46678f10-4bbb-447e-98e8-d2829589f2d8" { + t.Errorf("subscriptionId decode mismatch: got %q", data.SubscriptionId) + } + if data.TimeStamp.ExpiresOn != "01/01/50 00:00:00 -0000" { + t.Errorf("expiresOn decode mismatch: got %q", data.TimeStamp.ExpiresOn) + } +} diff --git a/upup/pkg/fi/cloudup/azure/authenticator.go b/upup/pkg/fi/cloudup/azure/authenticator.go index b009804560832..e71962f213de4 100644 --- a/upup/pkg/fi/cloudup/azure/authenticator.go +++ b/upup/pkg/fi/cloudup/azure/authenticator.go @@ -17,80 +17,53 @@ limitations under the License. package azure import ( - "encoding/json" "fmt" - "io" - "net/http" + "k8s.io/klog/v2" "k8s.io/kops/pkg/bootstrap" ) +// AzureAuthenticationTokenPrefix prefixes bootstrap tokens created from Azure +// IMDS instance identity data. const AzureAuthenticationTokenPrefix = "x-azure-id " -type azureAuthenticator struct { -} +type azureAuthenticator struct{} var _ bootstrap.Authenticator = (*azureAuthenticator)(nil) +// NewAzureAuthenticator returns an authenticator that mints Azure bootstrap +// tokens backed by IMDS metadata and an attested document signature. func NewAzureAuthenticator() (bootstrap.Authenticator, error) { return &azureAuthenticator{}, nil } +// CreateToken fetches the local VM identity from IMDS and returns a bootstrap +// token containing the resource ID and signed attested document. func (h *azureAuthenticator) CreateToken(body []byte) (string, error) { + klog.V(2).Infof("Azure authenticator creating bootstrap token") + + // Query IMDS for the VM's resource ID. metadata, err := QueryComputeInstanceMetadata() if err != nil { return "", fmt.Errorf("querying instance metadata: %w", err) } - if metadata == nil || metadata.VMID == "" { - return "", fmt.Errorf("missing virtual machine ID") + if metadata == nil || metadata.ResourceID == "" { + return "", fmt.Errorf("missing resource ID") } + klog.V(4).Infof("Azure authenticator obtained resource ID %q", metadata.ResourceID) - token := metadata.ResourceID + " " + metadata.VMID - - return AzureAuthenticationTokenPrefix + token, nil -} - -// InstanceMetadata contains compute instance metadata from the Azure IMDS. -type InstanceMetadata struct { - SubscriptionID string `json:"subscriptionId"` - ResourceGroupName string `json:"resourceGroupName"` - ResourceID string `json:"resourceId"` - VMID string `json:"vmId"` -} - -// QueryComputeInstanceMetadata queries Azure Instance Metadata Service (IMDS) -// https://learn.microsoft.com/en-us/azure/virtual-machines/instance-metadata-service -func QueryComputeInstanceMetadata() (*InstanceMetadata, error) { - transport := &http.Transport{Proxy: nil} - - client := http.Client{Transport: transport} - - req, err := http.NewRequest("GET", "http://169.254.169.254/metadata/instance/compute", nil) + // Query IMDS for a PKCS7-signed attested document containing the nonce. + nonce := nonceForBody(body) + klog.V(4).Infof("Azure authenticator requesting attested document with nonce %q", nonce) + doc, err := queryIMDSAttestedDocument(nonce) if err != nil { - return nil, fmt.Errorf("creating a new request: %w", err) + return "", fmt.Errorf("querying attested document: %w", err) } - req.Header.Add("Metadata", "True") - - q := req.URL.Query() - q.Add("api-version", "2025-04-07") - q.Add("format", "json") - req.URL.RawQuery = q.Encode() - - resp, err := client.Do(req) - if err != nil { - return nil, fmt.Errorf("sending request to the instance metadata server: %w", err) - } - - defer resp.Body.Close() - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("reading a response from the metadata server: %w", err) - } - metadata := &InstanceMetadata{} - err = json.Unmarshal(body, metadata) - if err != nil { - return nil, fmt.Errorf("unmarshalling instance metadata: %w", err) + if doc.Signature == "" { + return "", fmt.Errorf("empty attested document signature") } + klog.V(2).Infof("Azure authenticator obtained attested document for resource %q", metadata.ResourceID) - return metadata, nil + // Token format: "x-azure-id " + return AzureAuthenticationTokenPrefix + metadata.ResourceID + " " + doc.Signature, nil } diff --git a/upup/pkg/fi/cloudup/azure/imds.go b/upup/pkg/fi/cloudup/azure/imds.go new file mode 100644 index 0000000000000..7615719d99f55 --- /dev/null +++ b/upup/pkg/fi/cloudup/azure/imds.go @@ -0,0 +1,128 @@ +/* +Copyright 2026 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package azure + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "time" + + "k8s.io/klog/v2" +) + +const ( + // imdsBaseURL is the base URL for the Azure Instance Metadata Service. + imdsBaseURL = "http://169.254.169.254" + + // imdsAPIVersion is the IMDS API version. + // https://learn.microsoft.com/en-us/azure/virtual-machines/instance-metadata-service + imdsAPIVersion = "2025-04-07" +) + +// imdsHTTPClient is shared by all IMDS queries. The explicit Transport{Proxy: nil} +// bypasses any system proxy since IMDS lives at the link-local 169.254.169.254 +// address and must not be routed through one. +var imdsHTTPClient = &http.Client{ + Transport: &http.Transport{Proxy: nil}, + Timeout: 10 * time.Second, +} + +// InstanceMetadata contains compute instance metadata from the Azure IMDS. +type InstanceMetadata struct { + SubscriptionID string `json:"subscriptionId"` + ResourceGroupName string `json:"resourceGroupName"` + ResourceID string `json:"resourceId"` + VMID string `json:"vmId"` +} + +// attestedDocument is the JSON response from the IMDS attested/document endpoint. +type attestedDocument struct { + Encoding string `json:"encoding"` + Signature string `json:"signature"` +} + +// queryIMDS queries an Azure IMDS endpoint and unmarshals the JSON response. +// https://learn.microsoft.com/en-us/azure/virtual-machines/instance-metadata-service +func queryIMDS(path string, params url.Values, result any) error { + if path == "" { + return fmt.Errorf("IMDS path is required") + } + if result == nil { + return fmt.Errorf("result is required") + } + + req, err := http.NewRequest("GET", imdsBaseURL+path, nil) + if err != nil { + return fmt.Errorf("creating IMDS request: %w", err) + } + req.Header.Add("Metadata", "True") + + params.Set("api-version", imdsAPIVersion) + req.URL.RawQuery = params.Encode() + + klog.V(4).Infof("Querying Azure IMDS: %s (api-version=%s)", path, imdsAPIVersion) + + resp, err := imdsHTTPClient.Do(req) + if err != nil { + return fmt.Errorf("querying IMDS %s: %w", path, err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("querying IMDS %s: status %d", path, resp.StatusCode) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("reading IMDS response: %w", err) + } + klog.V(8).Infof("Azure IMDS response for %s: %d bytes", path, len(body)) + + if err := json.Unmarshal(body, result); err != nil { + return fmt.Errorf("unmarshalling IMDS response: %w", err) + } + + return nil +} + +// QueryComputeInstanceMetadata queries Azure IMDS for compute instance metadata. +func QueryComputeInstanceMetadata() (*InstanceMetadata, error) { + metadata := &InstanceMetadata{} + params := url.Values{"format": {"json"}} + if err := queryIMDS("/metadata/instance/compute", params, metadata); err != nil { + return nil, err + } + return metadata, nil +} + +// queryIMDSAttestedDocument queries the Azure IMDS attested document endpoint. +// The nonce is included in the PKCS7 signed content for replay protection. +func queryIMDSAttestedDocument(nonce string) (*attestedDocument, error) { + if nonce == "" { + return nil, fmt.Errorf("nonce is required") + } + + doc := &attestedDocument{} + params := url.Values{"nonce": {nonce}} + if err := queryIMDS("/metadata/attested/document", params, doc); err != nil { + return nil, err + } + return doc, nil +} diff --git a/upup/pkg/fi/cloudup/azure/verifier.go b/upup/pkg/fi/cloudup/azure/verifier.go index 9b3157d56f679..2bb192a7787ba 100644 --- a/upup/pkg/fi/cloudup/azure/verifier.go +++ b/upup/pkg/fi/cloudup/azure/verifier.go @@ -28,6 +28,7 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azidentity" compute "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute" network "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork" + "k8s.io/klog/v2" "k8s.io/kops/pkg/bootstrap" "k8s.io/kops/pkg/wellknownports" ) @@ -37,6 +38,7 @@ const ( InstanceGroupNameTag = "kops.k8s.io_instancegroup" ) +// AzureVerifierOptions configures the Azure bootstrap token verifier. type AzureVerifierOptions struct { ClusterName string `json:"clusterName,omitempty"` } @@ -48,8 +50,10 @@ type azureVerifier struct { var _ bootstrap.Verifier = (*azureVerifier)(nil) +// NewAzureVerifier returns a verifier that validates Azure IMDS attestation +// tokens and resolves the claimed VM identity through the Azure API. func NewAzureVerifier(ctx context.Context, opt *AzureVerifierOptions) (bootstrap.Verifier, error) { - azureClient, err := newClient() + azureClient, err := newVerifierClient() if err != nil { return nil, err } @@ -64,30 +68,62 @@ func NewAzureVerifier(ctx context.Context, opt *AzureVerifierOptions) (bootstrap }, nil } +// VerifyToken validates the Azure attestation token, confirms the claimed VM +// through the Azure API, and returns the node bootstrap identity. func (a azureVerifier) VerifyToken(ctx context.Context, rawRequest *http.Request, token string, body []byte) (*bootstrap.VerifyResult, error) { if !strings.HasPrefix(token, AzureAuthenticationTokenPrefix) { return nil, bootstrap.ErrNotThisVerifier } + // Token format: "x-azure-id " v := strings.Split(strings.TrimPrefix(token, AzureAuthenticationTokenPrefix), " ") if len(v) != 2 { return nil, fmt.Errorf("incorrect token format") } resourceID := v[0] - vmID := v[1] + signature := v[1] + klog.V(2).Infof("Azure verifier processing token for resource %q", resourceID) + + // Parse the resource ID early to reject malformed tokens before expensive crypto. res, err := arm.ParseResourceID(resourceID) if err != nil { - return nil, fmt.Errorf("error parsing token: %v", err) + return nil, fmt.Errorf("parsing resource ID: %w", err) + } + klog.V(4).Infof("Azure verifier parsed resource ID: type=%q subscription=%q resourceGroup=%q name=%q", res.ResourceType, res.SubscriptionID, res.ResourceGroupName, res.Name) + + // Reject resource IDs outside the verifier's own subscription / resource + // group. The Azure API lookup below is already scoped to kops-controller's + // subscription and resource group, so any claim that names a different + // location cannot describe a cluster VM. Failing here avoids a wasted + // Azure API call and makes the scope explicit instead of implicit. + if !strings.EqualFold(res.SubscriptionID, a.client.subscriptionID) { + return nil, fmt.Errorf("resource ID subscription %q does not match verifier subscription %q", res.SubscriptionID, a.client.subscriptionID) + } + if !strings.EqualFold(res.ResourceGroupName, a.client.resourceGroup) { + return nil, fmt.Errorf("resource ID resource group %q does not match verifier resource group %q", res.ResourceGroupName, a.client.resourceGroup) } + klog.V(4).Infof("Azure verifier confirmed resource ID belongs to subscription %q resource group %q", a.client.subscriptionID, a.client.resourceGroup) + // Verify the PKCS7 attested document: signature, certificate chain, nonce, and expiration. + klog.V(2).Infof("Azure verifier verifying attested document for resource %q", resourceID) + data, err := verifyAttestedDocument(signature, body) + if err != nil { + return nil, err + } + klog.V(2).Infof("Azure verifier attested document verified: vmId=%q subscriptionId=%q", data.VMId, data.SubscriptionId) + + // Look up the VM or VMSS VM via the Azure API using the resource ID, + // cross-verify the attested vmId, and extract node identity. var nodeName, igName string var addrs, challengeEndpoints []string switch res.ResourceType.String() { case "Microsoft.Compute/virtualMachines": vmName := res.Name + klog.V(2).Infof("Azure verifier looking up VM %q via Azure API", vmName) + // Fetch the VM from the Azure API. vm, err := a.client.vmsClient.Get(ctx, a.client.resourceGroup, vmName, nil) if err != nil { return nil, fmt.Errorf("getting info for VM %q: %w", vmName, err) @@ -95,20 +131,27 @@ func (a azureVerifier) VerifyToken(ctx context.Context, rawRequest *http.Request if vm.Properties == nil || vm.Properties.VMID == nil { return nil, fmt.Errorf("determining VMID for VM %q", vmName) } - if vmID != *vm.Properties.VMID { - return nil, fmt.Errorf("matching VMID %q to VM %q", vmID, vmName) + + // Cross-verify: the vmId from the cryptographically signed attested document + // must match the vmId from the Azure API for the claimed resource ID. + klog.V(4).Infof("Azure verifier cross-verifying vmId: attested=%q api=%q", data.VMId, *vm.Properties.VMID) + if data.VMId != *vm.Properties.VMID { + return nil, fmt.Errorf("attested vmId %q does not match VM %q (API vmId %q)", data.VMId, vmName, *vm.Properties.VMID) } if vm.Properties.OSProfile == nil || vm.Properties.OSProfile.ComputerName == nil || *vm.Properties.OSProfile.ComputerName == "" { return nil, fmt.Errorf("determining ComputerName for VM %q", vmName) } + // Extract node name and instance group from VM metadata. nodeName = strings.ToLower(*vm.Properties.OSProfile.ComputerName) if igNameTag, ok := vm.Tags[InstanceGroupNameTag]; ok && igNameTag != nil { igName = *igNameTag } else { return nil, fmt.Errorf("determining IG name for VM %q", vmName) } + klog.V(4).Infof("Azure verifier VM %q identity: node=%q instanceGroup=%q", vmName, nodeName, igName) + // Collect private IP addresses from the VM's network interface. ni, err := a.client.nisClient.Get(ctx, a.client.resourceGroup, nodeName, nil) if err != nil { return nil, fmt.Errorf("getting info for VM network interface %q: %w", vmName, err) @@ -124,11 +167,14 @@ func (a azureVerifier) VerifyToken(ctx context.Context, rawRequest *http.Request case "Microsoft.Compute/virtualMachineScaleSets/virtualMachines": vmssName := res.Parent.Name vmssIndex := res.Name + klog.V(2).Infof("Azure verifier looking up VMSS VM %q #%s via Azure API", vmssName, vmssIndex) + // Verify the VMSS belongs to this cluster. if !strings.HasSuffix(vmssName, "."+a.clusterName) { return nil, fmt.Errorf("matching cluster name %q to VMSS %q", a.clusterName, vmssName) } + // Fetch the VMSS VM from the Azure API. vm, err := a.client.vmssVMsClient.Get(ctx, a.client.resourceGroup, vmssName, vmssIndex, nil) if err != nil { return nil, fmt.Errorf("getting info for VMSS VM %q #%s: %w", vmssName, vmssIndex, err) @@ -136,20 +182,27 @@ func (a azureVerifier) VerifyToken(ctx context.Context, rawRequest *http.Request if vm.Properties == nil || vm.Properties.VMID == nil { return nil, fmt.Errorf("determining VMID for VMSS %q VM #%s", vmssName, vmssIndex) } - if vmID != *vm.Properties.VMID { - return nil, fmt.Errorf("matching VMID %q to VMSS %q VM #%s", vmID, vmssName, vmssIndex) + + // Cross-verify: the vmId from the cryptographically signed attested document + // must match the vmId from the Azure API for the claimed resource ID. + klog.V(4).Infof("Azure verifier cross-verifying vmId: attested=%q api=%q", data.VMId, *vm.Properties.VMID) + if data.VMId != *vm.Properties.VMID { + return nil, fmt.Errorf("attested vmId %q does not match VMSS %q VM #%s (API vmId %q)", data.VMId, vmssName, vmssIndex, *vm.Properties.VMID) } if vm.Properties.OSProfile == nil || vm.Properties.OSProfile.ComputerName == nil || *vm.Properties.OSProfile.ComputerName == "" { return nil, fmt.Errorf("determining ComputerName for VMSS %q VM #%s", vmssName, vmssIndex) } + // Extract node name and instance group from VMSS VM metadata. nodeName = strings.ToLower(*vm.Properties.OSProfile.ComputerName) if igNameTag, ok := vm.Tags[InstanceGroupNameTag]; ok && igNameTag != nil { igName = *igNameTag } else { return nil, fmt.Errorf("determining IG name for VM %q", vmssName) } + klog.V(4).Infof("Azure verifier VMSS VM %q #%s identity: node=%q instanceGroup=%q", vmssName, vmssIndex, nodeName, igName) + // Collect private IP addresses from the VMSS VM's network interface. ni, err := a.client.nisClient.GetVirtualMachineScaleSetNetworkInterface(ctx, a.client.resourceGroup, vmssName, vmssIndex, vmssName, nil) if err != nil { return nil, fmt.Errorf("getting info for VMSS VM network interface %q #%s: %w", vmssName, vmssIndex, err) @@ -166,6 +219,7 @@ func (a azureVerifier) VerifyToken(ctx context.Context, rawRequest *http.Request return nil, fmt.Errorf("unsupported resource type %q", res.ResourceType) } + // Validate that we found at least one address and challenge endpoint. if len(addrs) == 0 { return nil, fmt.Errorf("determining certificate alternate names for node %q", nodeName) } @@ -180,19 +234,22 @@ func (a azureVerifier) VerifyToken(ctx context.Context, rawRequest *http.Request ChallengeEndpoint: challengeEndpoints[0], } + klog.V(2).Infof("Azure verifier verified node %q in instance group %q with %d address(es)", nodeName, igName, len(addrs)) return result, nil } // client is an Azure client. type client struct { - resourceGroup string - nisClient *network.InterfacesClient - vmsClient *compute.VirtualMachinesClient - vmssVMsClient *compute.VirtualMachineScaleSetVMsClient + subscriptionID string + resourceGroup string + nisClient *network.InterfacesClient + vmsClient *compute.VirtualMachinesClient + vmssVMsClient *compute.VirtualMachineScaleSetVMsClient } -// newClient returns a new Client. -func newClient() (*client, error) { +// newVerifierClient builds Azure API clients scoped to the local instance's +// subscription and resource group from IMDS metadata. +func newVerifierClient() (*client, error) { metadata, err := QueryComputeInstanceMetadata() if err != nil || metadata == nil { return nil, fmt.Errorf("getting instance metadata: %w", err) @@ -203,6 +260,7 @@ func newClient() (*client, error) { if metadata.SubscriptionID == "" { return nil, fmt.Errorf("empty subscription ID") } + klog.V(4).Infof("Azure verifier client using subscription %q resource group %q", metadata.SubscriptionID, metadata.ResourceGroupName) cred, err := azidentity.NewDefaultAzureCredential(nil) if err != nil { @@ -223,9 +281,10 @@ func newClient() (*client, error) { } return &client{ - resourceGroup: metadata.ResourceGroupName, - nisClient: nisClient, - vmsClient: vmsClient, - vmssVMsClient: vmssVMsClient, + subscriptionID: metadata.SubscriptionID, + resourceGroup: metadata.ResourceGroupName, + nisClient: nisClient, + vmsClient: vmsClient, + vmssVMsClient: vmssVMsClient, }, nil } diff --git a/upup/pkg/fi/cloudup/azure/verifier_test.go b/upup/pkg/fi/cloudup/azure/verifier_test.go new file mode 100644 index 0000000000000..56535a56ea5a1 --- /dev/null +++ b/upup/pkg/fi/cloudup/azure/verifier_test.go @@ -0,0 +1,67 @@ +/* +Copyright 2026 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package azure + +import ( + "encoding/base64" + "testing" + + "k8s.io/kops/pkg/bootstrap" +) + +// TestVerifyToken covers the early rejection paths: wrong prefix (different +// cloud verifier), malformed two-part payload, mismatched subscription/RG, +// and unparseable PKCS7. +func TestVerifyToken(t *testing.T) { + matchingResourceID := "/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Compute/virtualMachines/vm" + wrongSubResourceID := "/subscriptions/other/resourceGroups/rg/providers/Microsoft.Compute/virtualMachines/vm" + wrongRGResourceID := "/subscriptions/sub/resourceGroups/other/providers/Microsoft.Compute/virtualMachines/vm" + invalidPKCS7 := base64.StdEncoding.EncodeToString([]byte("not-pkcs7")) + + testCases := []struct { + name string + token string + wantErr error // explicit error to compare with ==; nil means "any non-nil error" + }{ + {"wrong prefix", "x-aws-sts something", bootstrap.ErrNotThisVerifier}, + {"missing signature", AzureAuthenticationTokenPrefix + "no-space-here", nil}, + {"subscription mismatch", AzureAuthenticationTokenPrefix + wrongSubResourceID + " " + invalidPKCS7, nil}, + {"resource group mismatch", AzureAuthenticationTokenPrefix + wrongRGResourceID + " " + invalidPKCS7, nil}, + {"invalid PKCS7", AzureAuthenticationTokenPrefix + matchingResourceID + " " + invalidPKCS7, nil}, + } + + v := &azureVerifier{ + client: &client{ + subscriptionID: "sub", + resourceGroup: "rg", + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + _, err := v.VerifyToken(nil, nil, tc.token, nil) + if tc.wantErr != nil { + if err != tc.wantErr { + t.Errorf("expected %v, got: %v", tc.wantErr, err) + } + return + } + if err == nil { + t.Error("expected error") + } + }) + } +} diff --git a/vendor/github.com/smallstep/pkcs7/.gitignore b/vendor/github.com/smallstep/pkcs7/.gitignore new file mode 100644 index 0000000000000..948aae2acf4a7 --- /dev/null +++ b/vendor/github.com/smallstep/pkcs7/.gitignore @@ -0,0 +1,28 @@ +# Compiled Object files, Static and Dynamic libs (Shared Objects) +*.o +*.a +*.so + +# Folders +_obj +_test + +# Architecture specific extensions/prefixes +*.[568vq] +[568vq].out + +*.cgo1.go +*.cgo2.c +_cgo_defun.c +_cgo_gotypes.go +_cgo_export.* + +_testmain.go + +*.exe +*.test +*.prof + +# Development +.envrc +coverage.out \ No newline at end of file diff --git a/vendor/github.com/smallstep/pkcs7/LICENSE b/vendor/github.com/smallstep/pkcs7/LICENSE new file mode 100644 index 0000000000000..75f3209085b8e --- /dev/null +++ b/vendor/github.com/smallstep/pkcs7/LICENSE @@ -0,0 +1,22 @@ +The MIT License (MIT) + +Copyright (c) 2015 Andrew Smith + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + diff --git a/vendor/github.com/smallstep/pkcs7/Makefile b/vendor/github.com/smallstep/pkcs7/Makefile new file mode 100644 index 0000000000000..47c73b8684e81 --- /dev/null +++ b/vendor/github.com/smallstep/pkcs7/Makefile @@ -0,0 +1,20 @@ +all: vet staticcheck test + +test: + go test -covermode=count -coverprofile=coverage.out . + +showcoverage: test + go tool cover -html=coverage.out + +vet: + go vet . + +lint: + golint . + +staticcheck: + staticcheck . + +gettools: + go get -u honnef.co/go/tools/... + go get -u golang.org/x/lint/golint diff --git a/vendor/github.com/smallstep/pkcs7/README.md b/vendor/github.com/smallstep/pkcs7/README.md new file mode 100644 index 0000000000000..9d94e65f251f0 --- /dev/null +++ b/vendor/github.com/smallstep/pkcs7/README.md @@ -0,0 +1,63 @@ +# pkcs7 + +[![Go Reference](https://pkg.go.dev/badge/github.com/smallstep/pkcs7.svg)](https://pkg.go.dev/github.com/smallstep/pkcs7) +[![Build Status](https://github.com/smallstep/pkcs7/workflows/CI/badge.svg?query=branch%3Amain+event%3Apush)](https://github.com/smallstep/pkcs7/actions/workflows/ci.yml?query=branch%3Amain+event%3Apush) + +pkcs7 implements parsing and creating signed and enveloped messages. + +```go +package main + +import ( + "bytes" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "fmt" + "os" + + "github.com/smallstep/pkcs7" +) + +func SignAndDetach(content []byte, cert *x509.Certificate, privkey *rsa.PrivateKey) (signed []byte, err error) { + toBeSigned, err := NewSignedData(content) + if err != nil { + return fmt.Errorf("Cannot initialize signed data: %w", err) + } + if err = toBeSigned.AddSigner(cert, privkey, SignerInfoConfig{}); err != nil { + return fmt.Errorf("Cannot add signer: %w", err) + } + + // Detach signature, omit if you want an embedded signature + toBeSigned.Detach() + + signed, err = toBeSigned.Finish() + if err != nil { + return fmt.Errorf("Cannot finish signing data: %w", err) + } + + // Verify the signature + pem.Encode(os.Stdout, &pem.Block{Type: "PKCS7", Bytes: signed}) + p7, err := pkcs7.Parse(signed) + if err != nil { + return fmt.Errorf("Cannot parse our signed data: %w", err) + } + + // since the signature was detached, reattach the content here + p7.Content = content + + if bytes.Compare(content, p7.Content) != 0 { + return fmt.Errorf("Our content was not in the parsed data:\n\tExpected: %s\n\tActual: %s", content, p7.Content) + } + if err = p7.Verify(); err != nil { + return fmt.Errorf("Cannot verify our signed data: %w", err) + } + + return signed, nil +} +``` + + +## Credits + +This is a fork of [mozilla-services/pkcs7](https://github.com/mozilla-services/pkcs7) which, itself, was a fork of [fullsailor/pkcs7](https://github.com/fullsailor/pkcs7). diff --git a/vendor/github.com/smallstep/pkcs7/ber.go b/vendor/github.com/smallstep/pkcs7/ber.go new file mode 100644 index 0000000000000..52333215dee27 --- /dev/null +++ b/vendor/github.com/smallstep/pkcs7/ber.go @@ -0,0 +1,266 @@ +package pkcs7 + +import ( + "bytes" + "errors" +) + +type asn1Object interface { + EncodeTo(writer *bytes.Buffer) error +} + +type asn1Structured struct { + tagBytes []byte + content []asn1Object +} + +func (s asn1Structured) EncodeTo(out *bytes.Buffer) error { + inner := new(bytes.Buffer) + for _, obj := range s.content { + err := obj.EncodeTo(inner) + if err != nil { + return err + } + } + out.Write(s.tagBytes) + encodeLength(out, inner.Len()) + out.Write(inner.Bytes()) + return nil +} + +type asn1Primitive struct { + tagBytes []byte + length int + content []byte +} + +func (p asn1Primitive) EncodeTo(out *bytes.Buffer) error { + _, err := out.Write(p.tagBytes) + if err != nil { + return err + } + if err = encodeLength(out, p.length); err != nil { + return err + } + // fmt.Printf("%s--> tag: % X length: %d\n", strings.Repeat("| ", encodeIndent), p.tagBytes, p.length) + // fmt.Printf("%s--> content length: %d\n", strings.Repeat("| ", encodeIndent), len(p.content)) + out.Write(p.content) + + return nil +} + +func ber2der(ber []byte) ([]byte, error) { + if len(ber) == 0 { + return nil, errors.New("ber2der: input ber is empty") + } + // fmt.Printf("--> ber2der: Transcoding %d bytes\n", len(ber)) + out := new(bytes.Buffer) + + obj, _, err := readObject(ber, 0) + if err != nil { + return nil, err + } + obj.EncodeTo(out) + + // if offset < len(ber) { + // return nil, fmt.Errorf("ber2der: Content longer than expected. Got %d, expected %d", offset, len(ber)) + // } + + return out.Bytes(), nil +} + +// encodes lengths that are longer than 127 into string of bytes +func marshalLongLength(out *bytes.Buffer, i int) (err error) { + n := lengthLength(i) + + for ; n > 0; n-- { + err = out.WriteByte(byte(i >> uint((n-1)*8))) + if err != nil { + return + } + } + + return nil +} + +// computes the byte length of an encoded length value +func lengthLength(i int) (numBytes int) { + numBytes = 1 + for i > 255 { + numBytes++ + i >>= 8 + } + return +} + +// encodes the length in DER format +// If the length fits in 7 bits, the value is encoded directly. +// +// Otherwise, the number of bytes to encode the length is first determined. +// This number is likely to be 4 or less for a 32bit length. This number is +// added to 0x80. The length is encoded in big endian encoding follow after +// +// Examples: +// length | byte 1 | bytes n +// 0 | 0x00 | - +// 120 | 0x78 | - +// 200 | 0x81 | 0xC8 +// 500 | 0x82 | 0x01 0xF4 +// +func encodeLength(out *bytes.Buffer, length int) (err error) { + if length >= 128 { + l := lengthLength(length) + err = out.WriteByte(0x80 | byte(l)) + if err != nil { + return + } + err = marshalLongLength(out, length) + if err != nil { + return + } + } else { + err = out.WriteByte(byte(length)) + if err != nil { + return + } + } + return +} + +func readObject(ber []byte, offset int) (asn1Object, int, error) { + berLen := len(ber) + if offset >= berLen { + return nil, 0, errors.New("ber2der: offset is after end of ber data") + } + tagStart := offset + b := ber[offset] + offset++ + if offset >= berLen { + return nil, 0, errors.New("ber2der: cannot move offset forward, end of ber data reached") + } + tag := b & 0x1F // last 5 bits + if tag == 0x1F { + tag = 0 + for ber[offset] >= 0x80 { + tag = tag*128 + ber[offset] - 0x80 + offset++ + if offset > berLen { + return nil, 0, errors.New("ber2der: cannot move offset forward, end of ber data reached") + } + } + // jvehent 20170227: this doesn't appear to be used anywhere... + // tag = tag*128 + ber[offset] - 0x80 + offset++ + if offset > berLen { + return nil, 0, errors.New("ber2der: cannot move offset forward, end of ber data reached") + } + } + tagEnd := offset + + kind := b & 0x20 + if kind == 0 { + debugprint("--> Primitive\n") + } else { + debugprint("--> Constructed\n") + } + // read length + var length int + l := ber[offset] + offset++ + if offset > berLen { + return nil, 0, errors.New("ber2der: cannot move offset forward, end of ber data reached") + } + indefinite := false + if l > 0x80 { + numberOfBytes := (int)(l & 0x7F) + if numberOfBytes > 4 { // int is only guaranteed to be 32bit + return nil, 0, errors.New("ber2der: BER tag length too long") + } + if numberOfBytes == 4 && (int)(ber[offset]) > 0x7F { + return nil, 0, errors.New("ber2der: BER tag length is negative") + } + if (int)(ber[offset]) == 0x0 { + return nil, 0, errors.New("ber2der: BER tag length has leading zero") + } + debugprint("--> (compute length) indicator byte: %x\n", l) + debugprint("--> (compute length) length bytes: % X\n", ber[offset:offset+numberOfBytes]) + for i := 0; i < numberOfBytes; i++ { + length = length*256 + (int)(ber[offset]) + offset++ + if offset > berLen { + return nil, 0, errors.New("ber2der: cannot move offset forward, end of ber data reached") + } + } + } else if l == 0x80 { + indefinite = true + } else { + length = (int)(l) + } + if length < 0 { + return nil, 0, errors.New("ber2der: invalid negative value found in BER tag length") + } + // fmt.Printf("--> length : %d\n", length) + contentEnd := offset + length + if contentEnd > len(ber) { + return nil, 0, errors.New("ber2der: BER tag length is more than available data") + } + debugprint("--> content start : %d\n", offset) + debugprint("--> content end : %d\n", contentEnd) + debugprint("--> content : % X\n", ber[offset:contentEnd]) + var obj asn1Object + if indefinite && kind == 0 { + return nil, 0, errors.New("ber2der: Indefinite form tag must have constructed encoding") + } + if kind == 0 { + obj = asn1Primitive{ + tagBytes: ber[tagStart:tagEnd], + length: length, + content: ber[offset:contentEnd], + } + } else { + var subObjects []asn1Object + for (offset < contentEnd) || indefinite { + var subObj asn1Object + var err error + subObj, offset, err = readObject(ber, offset) + if err != nil { + return nil, 0, err + } + subObjects = append(subObjects, subObj) + + if indefinite { + terminated, err := isIndefiniteTermination(ber, offset) + if err != nil { + return nil, 0, err + } + + if terminated { + break + } + } + } + obj = asn1Structured{ + tagBytes: ber[tagStart:tagEnd], + content: subObjects, + } + } + + // Apply indefinite form length with 0x0000 terminator. + if indefinite { + contentEnd = offset + 2 + } + + return obj, contentEnd, nil +} + +func isIndefiniteTermination(ber []byte, offset int) (bool, error) { + if len(ber)-offset < 2 { + return false, errors.New("ber2der: Invalid BER format") + } + + return bytes.Index(ber[offset:], []byte{0x0, 0x0}) == 0, nil +} + +func debugprint(format string, a ...interface{}) { + // fmt.Printf(format, a) +} diff --git a/vendor/github.com/smallstep/pkcs7/decrypt.go b/vendor/github.com/smallstep/pkcs7/decrypt.go new file mode 100644 index 0000000000000..76dc17f74ce13 --- /dev/null +++ b/vendor/github.com/smallstep/pkcs7/decrypt.go @@ -0,0 +1,233 @@ +package pkcs7 + +import ( + "bytes" + "crypto" + "crypto/aes" + "crypto/cipher" + "crypto/des" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/asn1" + "errors" + "fmt" +) + +// ErrUnsupportedAlgorithm tells you when our quick dev assumptions have failed +var ErrUnsupportedAlgorithm = errors.New("pkcs7: cannot decrypt data: only RSA, DES, DES-EDE3, AES-256-CBC and AES-128-GCM supported") + +// ErrUnsupportedAsymmetricEncryptionAlgorithm is returned when attempting to use an unknown asymmetric encryption algorithm +var ErrUnsupportedAsymmetricEncryptionAlgorithm = errors.New("pkcs7: cannot decrypt data: only RSA PKCS#1 v1.5 and RSA OAEP are supported") + +// ErrUnsupportedKeyType is returned when attempting to encrypting keys using a key that's not an RSA key +var ErrUnsupportedKeyType = errors.New("pkcs7: only RSA keys are supported") + +// ErrNotEncryptedContent is returned when attempting to Decrypt data that is not encrypted data +var ErrNotEncryptedContent = errors.New("pkcs7: content data is a decryptable data type") + +// Decrypt decrypts encrypted content info for recipient cert and private key +func (p7 *PKCS7) Decrypt(cert *x509.Certificate, pkey crypto.PrivateKey) ([]byte, error) { + data, ok := p7.raw.(envelopedData) + if !ok { + return nil, ErrNotEncryptedContent + } + recipient := selectRecipientForCertificate(data.RecipientInfos, cert) + if recipient.EncryptedKey == nil { + return nil, errors.New("pkcs7: no enveloped recipient for provided certificate") + } + switch pkey := pkey.(type) { + case crypto.Decrypter: + var opts crypto.DecrypterOpts + switch algorithm := recipient.KeyEncryptionAlgorithm.Algorithm; { + case algorithm.Equal(OIDEncryptionAlgorithmRSAESOAEP): + hashFunc, err := getHashFuncForKeyEncryptionAlgorithm(recipient.KeyEncryptionAlgorithm) + if err != nil { + return nil, err + } + opts = &rsa.OAEPOptions{Hash: hashFunc} + case algorithm.Equal(OIDEncryptionAlgorithmRSA): + opts = &rsa.PKCS1v15DecryptOptions{} + default: + return nil, ErrUnsupportedAsymmetricEncryptionAlgorithm + } + contentKey, err := pkey.Decrypt(rand.Reader, recipient.EncryptedKey, opts) + if err != nil { + return nil, err + } + return data.EncryptedContentInfo.decrypt(contentKey) + } + return nil, ErrUnsupportedAlgorithm +} + +// RFC 4055, 4.1 +// The current ASN.1 parser does not support non-integer defaults so the 'default:' tags here do nothing. +type rsaOAEPAlgorithmParameters struct { + HashFunc pkix.AlgorithmIdentifier `asn1:"optional,explicit,tag:0,default:sha1Identifier"` + MaskGenFunc pkix.AlgorithmIdentifier `asn1:"optional,explicit,tag:1,default:mgf1SHA1Identifier"` + PSourceFunc pkix.AlgorithmIdentifier `asn1:"optional,explicit,tag:2,default:pSpecifiedEmptyIdentifier"` +} + +func getHashFuncForKeyEncryptionAlgorithm(keyEncryptionAlgorithm pkix.AlgorithmIdentifier) (crypto.Hash, error) { + invalidHashFunc := crypto.Hash(0) + params := &rsaOAEPAlgorithmParameters{ + HashFunc: pkix.AlgorithmIdentifier{Algorithm: OIDDigestAlgorithmSHA1}, // set default hash algorithm to SHA1 + } + var rest []byte + rest, err := asn1.Unmarshal(keyEncryptionAlgorithm.Parameters.FullBytes, params) + if err != nil { + return invalidHashFunc, fmt.Errorf("pkcs7: failed unmarshaling key encryption algorithm parameters: %v", err) + } + if len(rest) != 0 { + return invalidHashFunc, errors.New("pkcs7: trailing data after RSA OAEP parameters") + } + + switch { + case params.HashFunc.Algorithm.Equal(OIDDigestAlgorithmSHA1): + return crypto.SHA1, nil + case params.HashFunc.Algorithm.Equal(OIDDigestAlgorithmSHA224): + return crypto.SHA224, nil + case params.HashFunc.Algorithm.Equal(OIDDigestAlgorithmSHA256): + return crypto.SHA256, nil + case params.HashFunc.Algorithm.Equal(OIDDigestAlgorithmSHA384): + return crypto.SHA384, nil + case params.HashFunc.Algorithm.Equal(OIDDigestAlgorithmSHA512): + return crypto.SHA512, nil + default: + return invalidHashFunc, errors.New("pkcs7: unsupported hash function for RSA OAEP") + } +} + +// DecryptUsingPSK decrypts encrypted data using caller provided +// pre-shared secret +func (p7 *PKCS7) DecryptUsingPSK(key []byte) ([]byte, error) { + data, ok := p7.raw.(encryptedData) + if !ok { + return nil, ErrNotEncryptedContent + } + return data.EncryptedContentInfo.decrypt(key) +} + +func (eci encryptedContentInfo) decrypt(key []byte) ([]byte, error) { + alg := eci.ContentEncryptionAlgorithm.Algorithm + if !alg.Equal(OIDEncryptionAlgorithmDESCBC) && + !alg.Equal(OIDEncryptionAlgorithmDESEDE3CBC) && + !alg.Equal(OIDEncryptionAlgorithmAES256CBC) && + !alg.Equal(OIDEncryptionAlgorithmAES128CBC) && + !alg.Equal(OIDEncryptionAlgorithmAES128GCM) && + !alg.Equal(OIDEncryptionAlgorithmAES256GCM) { + return nil, ErrUnsupportedAlgorithm + } + + // EncryptedContent can either be constructed of multple OCTET STRINGs + // or _be_ a tagged OCTET STRING + var cyphertext []byte + if eci.EncryptedContent.IsCompound { + // Complex case to concat all of the children OCTET STRINGs + var buf bytes.Buffer + cypherbytes := eci.EncryptedContent.Bytes + for { + var part []byte + cypherbytes, _ = asn1.Unmarshal(cypherbytes, &part) + buf.Write(part) + if cypherbytes == nil { + break + } + } + cyphertext = buf.Bytes() + } else { + // Simple case, the bytes _are_ the cyphertext + cyphertext = eci.EncryptedContent.Bytes + } + + var block cipher.Block + var err error + + switch { + case alg.Equal(OIDEncryptionAlgorithmDESCBC): + block, err = des.NewCipher(key) + case alg.Equal(OIDEncryptionAlgorithmDESEDE3CBC): + block, err = des.NewTripleDESCipher(key) + case alg.Equal(OIDEncryptionAlgorithmAES256CBC), alg.Equal(OIDEncryptionAlgorithmAES256GCM): + fallthrough + case alg.Equal(OIDEncryptionAlgorithmAES128GCM), alg.Equal(OIDEncryptionAlgorithmAES128CBC): + block, err = aes.NewCipher(key) + } + + if err != nil { + return nil, err + } + + if alg.Equal(OIDEncryptionAlgorithmAES128GCM) || alg.Equal(OIDEncryptionAlgorithmAES256GCM) { + params := aesGCMParameters{} + paramBytes := eci.ContentEncryptionAlgorithm.Parameters.Bytes + + _, err := asn1.Unmarshal(paramBytes, ¶ms) + if err != nil { + return nil, err + } + + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, err + } + + if len(params.Nonce) != gcm.NonceSize() { + return nil, errors.New("pkcs7: encryption algorithm parameters are incorrect") + } + if params.ICVLen != gcm.Overhead() { + return nil, errors.New("pkcs7: encryption algorithm parameters are incorrect") + } + + plaintext, err := gcm.Open(nil, params.Nonce, cyphertext, nil) + if err != nil { + return nil, err + } + + return plaintext, nil + } + + iv := eci.ContentEncryptionAlgorithm.Parameters.Bytes + if len(iv) != block.BlockSize() { + return nil, errors.New("pkcs7: encryption algorithm parameters are malformed") + } + mode := cipher.NewCBCDecrypter(block, iv) + plaintext := make([]byte, len(cyphertext)) + mode.CryptBlocks(plaintext, cyphertext) + if plaintext, err = unpad(plaintext, mode.BlockSize()); err != nil { + return nil, err + } + return plaintext, nil +} + +func unpad(data []byte, blocklen int) ([]byte, error) { + if blocklen < 1 { + return nil, fmt.Errorf("pkcs7: invalid blocklen %d", blocklen) + } + if len(data)%blocklen != 0 || len(data) == 0 { + return nil, fmt.Errorf("pkcs7: invalid data len %d", len(data)) + } + + // the last byte is the length of padding + padlen := int(data[len(data)-1]) + + // check padding integrity, all bytes should be the same + pad := data[len(data)-padlen:] + for _, padbyte := range pad { + if padbyte != byte(padlen) { + return nil, errors.New("pkcs7: invalid padding") + } + } + + return data[:len(data)-padlen], nil +} + +func selectRecipientForCertificate(recipients []recipientInfo, cert *x509.Certificate) recipientInfo { + for _, recp := range recipients { + if isCertMatchForIssuerAndSerial(cert, recp.IssuerAndSerialNumber) { + return recp + } + } + return recipientInfo{} +} diff --git a/vendor/github.com/smallstep/pkcs7/encrypt.go b/vendor/github.com/smallstep/pkcs7/encrypt.go new file mode 100644 index 0000000000000..a5c96e7553100 --- /dev/null +++ b/vendor/github.com/smallstep/pkcs7/encrypt.go @@ -0,0 +1,475 @@ +package pkcs7 + +import ( + "bytes" + "crypto" + "crypto/aes" + "crypto/cipher" + "crypto/des" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/asn1" + "errors" + "fmt" +) + +type envelopedData struct { + Version int + RecipientInfos []recipientInfo `asn1:"set"` + EncryptedContentInfo encryptedContentInfo +} + +type encryptedData struct { + Version int + EncryptedContentInfo encryptedContentInfo +} + +type recipientInfo struct { + Version int + IssuerAndSerialNumber issuerAndSerial + KeyEncryptionAlgorithm pkix.AlgorithmIdentifier + EncryptedKey []byte +} + +type encryptedContentInfo struct { + ContentType asn1.ObjectIdentifier + ContentEncryptionAlgorithm pkix.AlgorithmIdentifier + EncryptedContent asn1.RawValue `asn1:"tag:0,optional"` +} + +const ( + // EncryptionAlgorithmDESCBC is the DES CBC encryption algorithm + EncryptionAlgorithmDESCBC = iota + + // EncryptionAlgorithmAES128CBC is the AES 128 bits with CBC encryption algorithm + // Avoid this algorithm unless required for interoperability; use AES GCM instead. + EncryptionAlgorithmAES128CBC + + // EncryptionAlgorithmAES256CBC is the AES 256 bits with CBC encryption algorithm + // Avoid this algorithm unless required for interoperability; use AES GCM instead. + EncryptionAlgorithmAES256CBC + + // EncryptionAlgorithmAES128GCM is the AES 128 bits with GCM encryption algorithm + EncryptionAlgorithmAES128GCM + + // EncryptionAlgorithmAES256GCM is the AES 256 bits with GCM encryption algorithm + EncryptionAlgorithmAES256GCM +) + +// ContentEncryptionAlgorithm determines the algorithm used to encrypt the +// plaintext message. Change the value of this variable to change which +// algorithm is used in the Encrypt() function. +var ContentEncryptionAlgorithm = EncryptionAlgorithmDESCBC + +// ErrUnsupportedEncryptionAlgorithm is returned when attempting to encrypt +// content with an unsupported algorithm. +var ErrUnsupportedEncryptionAlgorithm = errors.New("pkcs7: cannot encrypt content: only DES-CBC, AES-CBC, and AES-GCM supported") + +// KeyEncryptionAlgorithm determines the algorithm used to encrypt a +// content key. Change the value of this variable to change which +// algorithm is used in the Encrypt() function. +var KeyEncryptionAlgorithm = OIDEncryptionAlgorithmRSA + +// ErrUnsupportedKeyEncryptionAlgorithm is returned when an +// unsupported key encryption algorithm OID is provided. +var ErrUnsupportedKeyEncryptionAlgorithm = errors.New("pkcs7: unsupported key encryption algorithm provided") + +// KeyEncryptionHash determines the crypto.Hash algorithm to use +// when encrypting a content key. Change the value of this variable +// to change which algorithm is used in the Encrypt() function. +var KeyEncryptionHash = crypto.SHA256 + +// ErrUnsupportedKeyEncryptionHash is returned when an +// unsupported key encryption hash is provided. +var ErrUnsupportedKeyEncryptionHash = errors.New("pkcs7: unsupported key encryption hash provided") + +// ErrPSKNotProvided is returned when attempting to encrypt +// using a PSK without actually providing the PSK. +var ErrPSKNotProvided = errors.New("pkcs7: cannot encrypt content: PSK not provided") + +const nonceSize = 12 + +type aesGCMParameters struct { + Nonce []byte `asn1:"tag:4"` + ICVLen int +} + +func encryptAESGCM(content []byte, key []byte) ([]byte, *encryptedContentInfo, error) { + var keyLen int + var algID asn1.ObjectIdentifier + switch ContentEncryptionAlgorithm { + case EncryptionAlgorithmAES128GCM: + keyLen = 16 + algID = OIDEncryptionAlgorithmAES128GCM + case EncryptionAlgorithmAES256GCM: + keyLen = 32 + algID = OIDEncryptionAlgorithmAES256GCM + default: + return nil, nil, fmt.Errorf("invalid ContentEncryptionAlgorithm in encryptAESGCM: %d", ContentEncryptionAlgorithm) + } + if key == nil { + // Create AES key + key = make([]byte, keyLen) + + _, err := rand.Read(key) + if err != nil { + return nil, nil, err + } + } + + // Create nonce + nonce := make([]byte, nonceSize) + + _, err := rand.Read(nonce) + if err != nil { + return nil, nil, err + } + + // Encrypt content + block, err := aes.NewCipher(key) + if err != nil { + return nil, nil, err + } + + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, nil, err + } + + ciphertext := gcm.Seal(nil, nonce, content, nil) + + // Prepare ASN.1 Encrypted Content Info + paramSeq := aesGCMParameters{ + Nonce: nonce, + ICVLen: gcm.Overhead(), + } + + paramBytes, err := asn1.Marshal(paramSeq) + if err != nil { + return nil, nil, err + } + + eci := encryptedContentInfo{ + ContentType: OIDData, + ContentEncryptionAlgorithm: pkix.AlgorithmIdentifier{ + Algorithm: algID, + Parameters: asn1.RawValue{ + Tag: asn1.TagSequence, + Bytes: paramBytes, + }, + }, + EncryptedContent: marshalEncryptedContent(ciphertext), + } + + return key, &eci, nil +} + +func encryptDESCBC(content []byte, key []byte) ([]byte, *encryptedContentInfo, error) { + if key == nil { + // Create DES key + key = make([]byte, 8) + + _, err := rand.Read(key) + if err != nil { + return nil, nil, err + } + } + + // Create CBC IV + iv := make([]byte, des.BlockSize) + _, err := rand.Read(iv) + if err != nil { + return nil, nil, err + } + + // Encrypt padded content + block, err := des.NewCipher(key) + if err != nil { + return nil, nil, err + } + mode := cipher.NewCBCEncrypter(block, iv) + plaintext, err := pad(content, mode.BlockSize()) + if err != nil { + return nil, nil, err + } + cyphertext := make([]byte, len(plaintext)) + mode.CryptBlocks(cyphertext, plaintext) + + // Prepare ASN.1 Encrypted Content Info + eci := encryptedContentInfo{ + ContentType: OIDData, + ContentEncryptionAlgorithm: pkix.AlgorithmIdentifier{ + Algorithm: OIDEncryptionAlgorithmDESCBC, + Parameters: asn1.RawValue{Tag: 4, Bytes: iv}, + }, + EncryptedContent: marshalEncryptedContent(cyphertext), + } + + return key, &eci, nil +} + +func encryptAESCBC(content []byte, key []byte) ([]byte, *encryptedContentInfo, error) { + var keyLen int + var algID asn1.ObjectIdentifier + switch ContentEncryptionAlgorithm { + case EncryptionAlgorithmAES128CBC: + keyLen = 16 + algID = OIDEncryptionAlgorithmAES128CBC + case EncryptionAlgorithmAES256CBC: + keyLen = 32 + algID = OIDEncryptionAlgorithmAES256CBC + default: + return nil, nil, fmt.Errorf("invalid ContentEncryptionAlgorithm in encryptAESCBC: %d", ContentEncryptionAlgorithm) + } + + if key == nil { + // Create AES key + key = make([]byte, keyLen) + + _, err := rand.Read(key) + if err != nil { + return nil, nil, err + } + } + + // Create CBC IV + iv := make([]byte, aes.BlockSize) + _, err := rand.Read(iv) + if err != nil { + return nil, nil, err + } + + // Encrypt padded content + block, err := aes.NewCipher(key) + if err != nil { + return nil, nil, err + } + mode := cipher.NewCBCEncrypter(block, iv) + plaintext, err := pad(content, mode.BlockSize()) + if err != nil { + return nil, nil, err + } + cyphertext := make([]byte, len(plaintext)) + mode.CryptBlocks(cyphertext, plaintext) + + // Prepare ASN.1 Encrypted Content Info + eci := encryptedContentInfo{ + ContentType: OIDData, + ContentEncryptionAlgorithm: pkix.AlgorithmIdentifier{ + Algorithm: algID, + Parameters: asn1.RawValue{Tag: 4, Bytes: iv}, + }, + EncryptedContent: marshalEncryptedContent(cyphertext), + } + + return key, &eci, nil +} + +// Encrypt creates and returns an envelope data PKCS7 structure with encrypted +// recipient keys for each recipient public key. +// +// The algorithm used to perform encryption is determined by the current value +// of the global ContentEncryptionAlgorithm package variable. By default, the +// value is EncryptionAlgorithmDESCBC. To use a different algorithm, change the +// value before calling Encrypt(). For example: +// +// ContentEncryptionAlgorithm = EncryptionAlgorithmAES256GCM +// +// TODO(fullsailor): Add support for encrypting content with other algorithms +func Encrypt(content []byte, recipients []*x509.Certificate) ([]byte, error) { + var eci *encryptedContentInfo + var key []byte + var err error + + // Apply chosen symmetric encryption method + switch ContentEncryptionAlgorithm { + case EncryptionAlgorithmDESCBC: + key, eci, err = encryptDESCBC(content, nil) + case EncryptionAlgorithmAES128CBC: + fallthrough + case EncryptionAlgorithmAES256CBC: + key, eci, err = encryptAESCBC(content, nil) + case EncryptionAlgorithmAES128GCM: + fallthrough + case EncryptionAlgorithmAES256GCM: + key, eci, err = encryptAESGCM(content, nil) + + default: + return nil, ErrUnsupportedEncryptionAlgorithm + } + + if err != nil { + return nil, err + } + + // Prepare each recipient's encrypted cipher key + recipientInfos := make([]recipientInfo, len(recipients)) + for i, recipient := range recipients { + algorithm := KeyEncryptionAlgorithm + hash := KeyEncryptionHash + var kea pkix.AlgorithmIdentifier + switch { + case algorithm.Equal(OIDEncryptionAlgorithmRSAESOAEP): + parameters, err := getParametersForKeyEncryptionAlgorithm(algorithm, hash) + if err != nil { + return nil, fmt.Errorf("failed to get parameters for key encryption: %v", err) + } + kea = pkix.AlgorithmIdentifier{ + Algorithm: algorithm, + Parameters: parameters, + } + case algorithm.Equal(OIDEncryptionAlgorithmRSA): + kea = pkix.AlgorithmIdentifier{ + Algorithm: algorithm, + } + default: + return nil, ErrUnsupportedKeyEncryptionAlgorithm + } + encrypted, err := encryptKey(key, recipient, algorithm, hash) + if err != nil { + return nil, err + } + ias, err := cert2issuerAndSerial(recipient) + if err != nil { + return nil, err + } + info := recipientInfo{ + Version: 0, + IssuerAndSerialNumber: ias, + KeyEncryptionAlgorithm: kea, + EncryptedKey: encrypted, + } + recipientInfos[i] = info + } + + // Prepare envelope content + envelope := envelopedData{ + EncryptedContentInfo: *eci, + Version: 0, + RecipientInfos: recipientInfos, + } + innerContent, err := asn1.Marshal(envelope) + if err != nil { + return nil, err + } + + // Prepare outer payload structure + wrapper := contentInfo{ + ContentType: OIDEnvelopedData, + Content: asn1.RawValue{Class: 2, Tag: 0, IsCompound: true, Bytes: innerContent}, + } + + return asn1.Marshal(wrapper) +} + +func getParametersForKeyEncryptionAlgorithm(algorithm asn1.ObjectIdentifier, hash crypto.Hash) (asn1.RawValue, error) { + if !algorithm.Equal(OIDEncryptionAlgorithmRSAESOAEP) { + return asn1.RawValue{}, nil // return empty; not used + } + + params := rsaOAEPAlgorithmParameters{} + switch hash { + case crypto.SHA1: + params.HashFunc = pkix.AlgorithmIdentifier{Algorithm: OIDDigestAlgorithmSHA1} + case crypto.SHA224: + params.HashFunc = pkix.AlgorithmIdentifier{Algorithm: OIDDigestAlgorithmSHA224} + case crypto.SHA256: + params.HashFunc = pkix.AlgorithmIdentifier{Algorithm: OIDDigestAlgorithmSHA256} + case crypto.SHA384: + params.HashFunc = pkix.AlgorithmIdentifier{Algorithm: OIDDigestAlgorithmSHA384} + case crypto.SHA512: + params.HashFunc = pkix.AlgorithmIdentifier{Algorithm: OIDDigestAlgorithmSHA512} + default: + return asn1.RawValue{}, ErrUnsupportedAlgorithm + } + + b, err := asn1.Marshal(params) + if err != nil { + return asn1.RawValue{}, fmt.Errorf("failed marshaling key encryption parameters: %v", err) + } + + return asn1.RawValue{ + FullBytes: b, + }, nil +} + +// EncryptUsingPSK creates and returns an encrypted data PKCS7 structure, +// encrypted using caller provided pre-shared secret. +func EncryptUsingPSK(content []byte, key []byte) ([]byte, error) { + var eci *encryptedContentInfo + var err error + + if key == nil { + return nil, ErrPSKNotProvided + } + + // Apply chosen symmetric encryption method + switch ContentEncryptionAlgorithm { + case EncryptionAlgorithmDESCBC: + _, eci, err = encryptDESCBC(content, key) + + case EncryptionAlgorithmAES128GCM: + fallthrough + case EncryptionAlgorithmAES256GCM: + _, eci, err = encryptAESGCM(content, key) + + default: + return nil, ErrUnsupportedEncryptionAlgorithm + } + + if err != nil { + return nil, err + } + + // Prepare encrypted-data content + ed := encryptedData{ + Version: 0, + EncryptedContentInfo: *eci, + } + innerContent, err := asn1.Marshal(ed) + if err != nil { + return nil, err + } + + // Prepare outer payload structure + wrapper := contentInfo{ + ContentType: OIDEncryptedData, + Content: asn1.RawValue{Class: 2, Tag: 0, IsCompound: true, Bytes: innerContent}, + } + + return asn1.Marshal(wrapper) +} + +func marshalEncryptedContent(content []byte) asn1.RawValue { + return asn1.RawValue{Bytes: content, Class: 2, IsCompound: false} +} + +func encryptKey(key []byte, recipient *x509.Certificate, algorithm asn1.ObjectIdentifier, hash crypto.Hash) ([]byte, error) { + pub, ok := recipient.PublicKey.(*rsa.PublicKey) + if !ok { + return nil, ErrUnsupportedKeyType + } + + switch { + case algorithm.Equal(OIDEncryptionAlgorithmRSA): + return rsa.EncryptPKCS1v15(rand.Reader, pub, key) + case algorithm.Equal(OIDEncryptionAlgorithmRSAESOAEP): + return rsa.EncryptOAEP(hash.New(), rand.Reader, pub, key, nil) + default: + return nil, ErrUnsupportedKeyEncryptionAlgorithm + } +} + +func pad(data []byte, blocklen int) ([]byte, error) { + if blocklen < 1 { + return nil, fmt.Errorf("invalid blocklen %d", blocklen) + } + padlen := blocklen - (len(data) % blocklen) + if padlen == 0 { + padlen = blocklen + } + pad := bytes.Repeat([]byte{byte(padlen)}, padlen) + return append(data, pad...), nil +} diff --git a/vendor/github.com/smallstep/pkcs7/internal/legacy/x509/debug.go b/vendor/github.com/smallstep/pkcs7/internal/legacy/x509/debug.go new file mode 100644 index 0000000000000..378cc265d264b --- /dev/null +++ b/vendor/github.com/smallstep/pkcs7/internal/legacy/x509/debug.go @@ -0,0 +1,14 @@ +package legacyx509 + +import "fmt" + +// legacyGodebugSetting is a type mimicking Go's internal godebug package +// settings, which are used to enable / disable certain functionalities at +// build time. +type legacyGodebugSetting int + +func (s legacyGodebugSetting) Value() string { + return fmt.Sprintf("%d", s) +} + +func (s legacyGodebugSetting) IncNonDefault() {} diff --git a/vendor/github.com/smallstep/pkcs7/internal/legacy/x509/doc.go b/vendor/github.com/smallstep/pkcs7/internal/legacy/x509/doc.go new file mode 100644 index 0000000000000..7d1469b6d0023 --- /dev/null +++ b/vendor/github.com/smallstep/pkcs7/internal/legacy/x509/doc.go @@ -0,0 +1,14 @@ +/* +Package legacyx509 is a copy of certain parts of Go's crypto/x509 package. +It is based on Go 1.23, and has just the parts copied over required for +parsing X509 certificates. + +The primary reason this copy exists is to keep support for parsing PKCS7 +messages containing Simple Certificate Enrolment Protocol (SCEP) requests +from Windows devices. Go 1.23 made a change marking certificates with a +critical authority key identifier as invalid, which is mandated by RFC 5280, +but apparently Windows marks those specific certificates as such, resulting +in those SCEP requests failing from being parsed correctly. +*/ + +package legacyx509 diff --git a/vendor/github.com/smallstep/pkcs7/internal/legacy/x509/oid.go b/vendor/github.com/smallstep/pkcs7/internal/legacy/x509/oid.go new file mode 100644 index 0000000000000..8268a07c50a04 --- /dev/null +++ b/vendor/github.com/smallstep/pkcs7/internal/legacy/x509/oid.go @@ -0,0 +1,377 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package legacyx509 + +import ( + "bytes" + "encoding/asn1" + "errors" + "math" + "math/big" + "math/bits" + "strconv" + "strings" +) + +var ( + errInvalidOID = errors.New("invalid oid") +) + +// An OID represents an ASN.1 OBJECT IDENTIFIER. +type OID struct { + der []byte +} + +// ParseOID parses a Object Identifier string, represented by ASCII numbers separated by dots. +func ParseOID(oid string) (OID, error) { + var o OID + return o, o.unmarshalOIDText(oid) +} + +func newOIDFromDER(der []byte) (OID, bool) { + if len(der) == 0 || der[len(der)-1]&0x80 != 0 { + return OID{}, false + } + + start := 0 + for i, v := range der { + // ITU-T X.690, section 8.19.2: + // The subidentifier shall be encoded in the fewest possible octets, + // that is, the leading octet of the subidentifier shall not have the value 0x80. + if i == start && v == 0x80 { + return OID{}, false + } + if v&0x80 == 0 { + start = i + 1 + } + } + + return OID{der}, true +} + +// OIDFromInts creates a new OID using ints, each integer is a separate component. +func OIDFromInts(oid []uint64) (OID, error) { + if len(oid) < 2 || oid[0] > 2 || (oid[0] < 2 && oid[1] >= 40) { + return OID{}, errInvalidOID + } + + length := base128IntLength(oid[0]*40 + oid[1]) + for _, v := range oid[2:] { + length += base128IntLength(v) + } + + der := make([]byte, 0, length) + der = appendBase128Int(der, oid[0]*40+oid[1]) + for _, v := range oid[2:] { + der = appendBase128Int(der, v) + } + return OID{der}, nil +} + +func base128IntLength(n uint64) int { + if n == 0 { + return 1 + } + return (bits.Len64(n) + 6) / 7 +} + +func appendBase128Int(dst []byte, n uint64) []byte { + for i := base128IntLength(n) - 1; i >= 0; i-- { + o := byte(n >> uint(i*7)) + o &= 0x7f + if i != 0 { + o |= 0x80 + } + dst = append(dst, o) + } + return dst +} + +func base128BigIntLength(n *big.Int) int { + if n.Cmp(big.NewInt(0)) == 0 { + return 1 + } + return (n.BitLen() + 6) / 7 +} + +func appendBase128BigInt(dst []byte, n *big.Int) []byte { + if n.Cmp(big.NewInt(0)) == 0 { + return append(dst, 0) + } + + for i := base128BigIntLength(n) - 1; i >= 0; i-- { + o := byte(big.NewInt(0).Rsh(n, uint(i)*7).Bits()[0]) + o &= 0x7f + if i != 0 { + o |= 0x80 + } + dst = append(dst, o) + } + return dst +} + +// AppendText implements [encoding.TextAppender] +func (o OID) AppendText(b []byte) ([]byte, error) { + return append(b, o.String()...), nil +} + +// MarshalText implements [encoding.TextMarshaler] +func (o OID) MarshalText() ([]byte, error) { + return o.AppendText(nil) +} + +// UnmarshalText implements [encoding.TextUnmarshaler] +func (o *OID) UnmarshalText(text []byte) error { + return o.unmarshalOIDText(string(text)) +} + +// cutString slices s around the first instance of sep, +// returning the text before and after sep. +// The found result reports whether sep appears in s. +// If sep does not appear in s, cut returns s, "", false. +func cutString(s, sep string) (before, after string, found bool) { + if i := strings.Index(s, sep); i >= 0 { + return s[:i], s[i+len(sep):], true + } + return s, "", false +} + +func (o *OID) unmarshalOIDText(oid string) error { + // (*big.Int).SetString allows +/- signs, but we don't want + // to allow them in the string representation of Object Identifier, so + // reject such encodings. + for _, c := range oid { + isDigit := c >= '0' && c <= '9' + if !isDigit && c != '.' { + return errInvalidOID + } + } + + var ( + firstNum string + secondNum string + ) + + var nextComponentExists bool + firstNum, oid, nextComponentExists = cutString(oid, ".") + if !nextComponentExists { + return errInvalidOID + } + secondNum, oid, nextComponentExists = cutString(oid, ".") + + var ( + first = big.NewInt(0) + second = big.NewInt(0) + ) + + if _, ok := first.SetString(firstNum, 10); !ok { + return errInvalidOID + } + if _, ok := second.SetString(secondNum, 10); !ok { + return errInvalidOID + } + + if first.Cmp(big.NewInt(2)) > 0 || (first.Cmp(big.NewInt(2)) < 0 && second.Cmp(big.NewInt(40)) >= 0) { + return errInvalidOID + } + + firstComponent := first.Mul(first, big.NewInt(40)) + firstComponent.Add(firstComponent, second) + + der := appendBase128BigInt(make([]byte, 0, 32), firstComponent) + + for nextComponentExists { + var strNum string + strNum, oid, nextComponentExists = cutString(oid, ".") + b, ok := big.NewInt(0).SetString(strNum, 10) + if !ok { + return errInvalidOID + } + der = appendBase128BigInt(der, b) + } + + o.der = der + return nil +} + +// AppendBinary implements [encoding.BinaryAppender] +func (o OID) AppendBinary(b []byte) ([]byte, error) { + return append(b, o.der...), nil +} + +// MarshalBinary implements [encoding.BinaryMarshaler] +func (o OID) MarshalBinary() ([]byte, error) { + return o.AppendBinary(nil) +} + +// cloneBytes returns a copy of b[:len(b)]. +// The result may have additional unused capacity. +// Clone(nil) returns nil. +func cloneBytes(b []byte) []byte { + if b == nil { + return nil + } + return append([]byte{}, b...) +} + +// UnmarshalBinary implements [encoding.BinaryUnmarshaler] +func (o *OID) UnmarshalBinary(b []byte) error { + oid, ok := newOIDFromDER(cloneBytes(b)) + if !ok { + return errInvalidOID + } + *o = oid + return nil +} + +// Equal returns true when oid and other represents the same Object Identifier. +func (oid OID) Equal(other OID) bool { + // There is only one possible DER encoding of + // each unique Object Identifier. + return bytes.Equal(oid.der, other.der) +} + +func parseBase128Int(bytes []byte, initOffset int) (ret, offset int, failed bool) { + offset = initOffset + var ret64 int64 + for shifted := 0; offset < len(bytes); shifted++ { + // 5 * 7 bits per byte == 35 bits of data + // Thus the representation is either non-minimal or too large for an int32 + if shifted == 5 { + failed = true + return + } + ret64 <<= 7 + b := bytes[offset] + // integers should be minimally encoded, so the leading octet should + // never be 0x80 + if shifted == 0 && b == 0x80 { + failed = true + return + } + ret64 |= int64(b & 0x7f) + offset++ + if b&0x80 == 0 { + ret = int(ret64) + // Ensure that the returned value fits in an int on all platforms + if ret64 > math.MaxInt32 { + failed = true + } + return + } + } + failed = true + return +} + +// EqualASN1OID returns whether an OID equals an asn1.ObjectIdentifier. If +// asn1.ObjectIdentifier cannot represent the OID specified by oid, because +// a component of OID requires more than 31 bits, it returns false. +func (oid OID) EqualASN1OID(other asn1.ObjectIdentifier) bool { + if len(other) < 2 { + return false + } + v, offset, failed := parseBase128Int(oid.der, 0) + if failed { + // This should never happen, since we've already parsed the OID, + // but just in case. + return false + } + if v < 80 { + a, b := v/40, v%40 + if other[0] != a || other[1] != b { + return false + } + } else { + a, b := 2, v-80 + if other[0] != a || other[1] != b { + return false + } + } + + i := 2 + for ; offset < len(oid.der); i++ { + v, offset, failed = parseBase128Int(oid.der, offset) + if failed { + // Again, shouldn't happen, since we've already parsed + // the OID, but better safe than sorry. + return false + } + if i >= len(other) || v != other[i] { + return false + } + } + + return i == len(other) +} + +// Strings returns the string representation of the Object Identifier. +func (oid OID) String() string { + var b strings.Builder + b.Grow(32) + const ( + valSize = 64 // size in bits of val. + bitsPerByte = 7 + maxValSafeShift = (1 << (valSize - bitsPerByte)) - 1 + ) + var ( + start = 0 + val = uint64(0) + numBuf = make([]byte, 0, 21) + bigVal *big.Int + overflow bool + ) + for i, v := range oid.der { + curVal := v & 0x7F + valEnd := v&0x80 == 0 + if valEnd { + if start != 0 { + b.WriteByte('.') + } + } + if !overflow && val > maxValSafeShift { + if bigVal == nil { + bigVal = new(big.Int) + } + bigVal = bigVal.SetUint64(val) + overflow = true + } + if overflow { + bigVal = bigVal.Lsh(bigVal, bitsPerByte).Or(bigVal, big.NewInt(int64(curVal))) + if valEnd { + if start == 0 { + b.WriteString("2.") + bigVal = bigVal.Sub(bigVal, big.NewInt(80)) + } + numBuf = bigVal.Append(numBuf, 10) + b.Write(numBuf) + numBuf = numBuf[:0] + val = 0 + start = i + 1 + overflow = false + } + continue + } + val <<= bitsPerByte + val |= uint64(curVal) + if valEnd { + if start == 0 { + if val < 80 { + b.Write(strconv.AppendUint(numBuf, val/40, 10)) + b.WriteByte('.') + b.Write(strconv.AppendUint(numBuf, val%40, 10)) + } else { + b.WriteString("2.") + b.Write(strconv.AppendUint(numBuf, val-80, 10)) + } + } else { + b.Write(strconv.AppendUint(numBuf, val, 10)) + } + val = 0 + start = i + 1 + } + } + return b.String() +} diff --git a/vendor/github.com/smallstep/pkcs7/internal/legacy/x509/parser.go b/vendor/github.com/smallstep/pkcs7/internal/legacy/x509/parser.go new file mode 100644 index 0000000000000..ec57e79f6ebf8 --- /dev/null +++ b/vendor/github.com/smallstep/pkcs7/internal/legacy/x509/parser.go @@ -0,0 +1,1027 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package legacyx509 + +import ( + "bytes" + "crypto/dsa" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/rsa" + "crypto/x509/pkix" + "encoding/asn1" + "errors" + "fmt" + "math/big" + "net" + "net/url" + "strconv" + "strings" + "time" + "unicode/utf16" + "unicode/utf8" + + "golang.org/x/crypto/cryptobyte" + cryptobyte_asn1 "golang.org/x/crypto/cryptobyte/asn1" + + stdx509 "crypto/x509" +) + +// ParseCertificates parses one or more certificates from the given ASN.1 DER +// data. The certificates must be concatenated with no intermediate padding. +func ParseCertificates(der []byte) ([]*stdx509.Certificate, error) { + var certs []*stdx509.Certificate + for len(der) > 0 { + cert, err := parseCertificate(der) + if err != nil { + return nil, err + } + certs = append(certs, cert) + der = der[len(cert.Raw):] + } + return certs, nil +} + +// isPrintable reports whether the given b is in the ASN.1 PrintableString set. +// This is a simplified version of encoding/asn1.isPrintable. +func isPrintable(b byte) bool { + return 'a' <= b && b <= 'z' || + 'A' <= b && b <= 'Z' || + '0' <= b && b <= '9' || + '\'' <= b && b <= ')' || + '+' <= b && b <= '/' || + b == ' ' || + b == ':' || + b == '=' || + b == '?' || + // This is technically not allowed in a PrintableString. + // However, x509 certificates with wildcard strings don't + // always use the correct string type so we permit it. + b == '*' || + // This is not technically allowed either. However, not + // only is it relatively common, but there are also a + // handful of CA certificates that contain it. At least + // one of which will not expire until 2027. + b == '&' +} + +// parseASN1String parses the ASN.1 string types T61String, PrintableString, +// UTF8String, BMPString, IA5String, and NumericString. This is mostly copied +// from the respective encoding/asn1.parse... methods, rather than just +// increasing the API surface of that package. +func parseASN1String(tag cryptobyte_asn1.Tag, value []byte) (string, error) { + switch tag { + case cryptobyte_asn1.T61String: + return string(value), nil + case cryptobyte_asn1.PrintableString: + for _, b := range value { + if !isPrintable(b) { + return "", errors.New("invalid PrintableString") + } + } + return string(value), nil + case cryptobyte_asn1.UTF8String: + if !utf8.Valid(value) { + return "", errors.New("invalid UTF-8 string") + } + return string(value), nil + case cryptobyte_asn1.Tag(asn1.TagBMPString): + if len(value)%2 != 0 { + return "", errors.New("invalid BMPString") + } + + // Strip terminator if present. + if l := len(value); l >= 2 && value[l-1] == 0 && value[l-2] == 0 { + value = value[:l-2] + } + + s := make([]uint16, 0, len(value)/2) + for len(value) > 0 { + s = append(s, uint16(value[0])<<8+uint16(value[1])) + value = value[2:] + } + + return string(utf16.Decode(s)), nil + case cryptobyte_asn1.IA5String: + s := string(value) + if isIA5String(s) != nil { + return "", errors.New("invalid IA5String") + } + return s, nil + case cryptobyte_asn1.Tag(asn1.TagNumericString): + for _, b := range value { + if !('0' <= b && b <= '9' || b == ' ') { + return "", errors.New("invalid NumericString") + } + } + return string(value), nil + } + return "", fmt.Errorf("unsupported string type: %v", tag) +} + +// parseName parses a DER encoded Name as defined in RFC 5280. We may +// want to export this function in the future for use in crypto/tls. +func parseName(raw cryptobyte.String) (*pkix.RDNSequence, error) { + if !raw.ReadASN1(&raw, cryptobyte_asn1.SEQUENCE) { + return nil, errors.New("x509: invalid RDNSequence") + } + + var rdnSeq pkix.RDNSequence + for !raw.Empty() { + var rdnSet pkix.RelativeDistinguishedNameSET + var set cryptobyte.String + if !raw.ReadASN1(&set, cryptobyte_asn1.SET) { + return nil, errors.New("x509: invalid RDNSequence") + } + for !set.Empty() { + var atav cryptobyte.String + if !set.ReadASN1(&atav, cryptobyte_asn1.SEQUENCE) { + return nil, errors.New("x509: invalid RDNSequence: invalid attribute") + } + var attr pkix.AttributeTypeAndValue + if !atav.ReadASN1ObjectIdentifier(&attr.Type) { + return nil, errors.New("x509: invalid RDNSequence: invalid attribute type") + } + var rawValue cryptobyte.String + var valueTag cryptobyte_asn1.Tag + if !atav.ReadAnyASN1(&rawValue, &valueTag) { + return nil, errors.New("x509: invalid RDNSequence: invalid attribute value") + } + var err error + attr.Value, err = parseASN1String(valueTag, rawValue) + if err != nil { + return nil, fmt.Errorf("x509: invalid RDNSequence: invalid attribute value: %s", err) + } + rdnSet = append(rdnSet, attr) + } + + rdnSeq = append(rdnSeq, rdnSet) + } + + return &rdnSeq, nil +} + +func parseAI(der cryptobyte.String) (pkix.AlgorithmIdentifier, error) { + ai := pkix.AlgorithmIdentifier{} + if !der.ReadASN1ObjectIdentifier(&ai.Algorithm) { + return ai, errors.New("x509: malformed OID") + } + if der.Empty() { + return ai, nil + } + var params cryptobyte.String + var tag cryptobyte_asn1.Tag + if !der.ReadAnyASN1Element(¶ms, &tag) { + return ai, errors.New("x509: malformed parameters") + } + ai.Parameters.Tag = int(tag) + ai.Parameters.FullBytes = params + return ai, nil +} + +func parseTime(der *cryptobyte.String) (time.Time, error) { + var t time.Time + switch { + case der.PeekASN1Tag(cryptobyte_asn1.UTCTime): + if !der.ReadASN1UTCTime(&t) { + return t, errors.New("x509: malformed UTCTime") + } + case der.PeekASN1Tag(cryptobyte_asn1.GeneralizedTime): + if !der.ReadASN1GeneralizedTime(&t) { + return t, errors.New("x509: malformed GeneralizedTime") + } + default: + return t, errors.New("x509: unsupported time format") + } + return t, nil +} + +func parseValidity(der cryptobyte.String) (time.Time, time.Time, error) { + notBefore, err := parseTime(&der) + if err != nil { + return time.Time{}, time.Time{}, err + } + notAfter, err := parseTime(&der) + if err != nil { + return time.Time{}, time.Time{}, err + } + + return notBefore, notAfter, nil +} + +func parseExtension(der cryptobyte.String) (pkix.Extension, error) { + var ext pkix.Extension + if !der.ReadASN1ObjectIdentifier(&ext.Id) { + return ext, errors.New("x509: malformed extension OID field") + } + if der.PeekASN1Tag(cryptobyte_asn1.BOOLEAN) { + if !der.ReadASN1Boolean(&ext.Critical) { + return ext, errors.New("x509: malformed extension critical field") + } + } + var val cryptobyte.String + if !der.ReadASN1(&val, cryptobyte_asn1.OCTET_STRING) { + return ext, errors.New("x509: malformed extension value field") + } + ext.Value = val + return ext, nil +} + +func parsePublicKey(keyData *publicKeyInfo) (interface{}, error) { + oid := keyData.Algorithm.Algorithm + params := keyData.Algorithm.Parameters + der := cryptobyte.String(keyData.PublicKey.RightAlign()) + switch { + case oid.Equal(oidPublicKeyRSA): + // RSA public keys must have a NULL in the parameters. + // See RFC 3279, Section 2.3.1. + if !bytes.Equal(params.FullBytes, asn1.NullBytes) { + return nil, errors.New("x509: RSA key missing NULL parameters") + } + + p := &pkcs1PublicKey{N: new(big.Int)} + if !der.ReadASN1(&der, cryptobyte_asn1.SEQUENCE) { + return nil, errors.New("x509: invalid RSA public key") + } + if !der.ReadASN1Integer(p.N) { + return nil, errors.New("x509: invalid RSA modulus") + } + if !der.ReadASN1Integer(&p.E) { + return nil, errors.New("x509: invalid RSA public exponent") + } + + if p.N.Sign() <= 0 { + return nil, errors.New("x509: RSA modulus is not a positive number") + } + if p.E <= 0 { + return nil, errors.New("x509: RSA public exponent is not a positive number") + } + + pub := &rsa.PublicKey{ + E: p.E, + N: p.N, + } + return pub, nil + case oid.Equal(oidPublicKeyECDSA): + paramsDer := cryptobyte.String(params.FullBytes) + namedCurveOID := new(asn1.ObjectIdentifier) + if !paramsDer.ReadASN1ObjectIdentifier(namedCurveOID) { + return nil, errors.New("x509: invalid ECDSA parameters") + } + namedCurve := namedCurveFromOID(*namedCurveOID) + if namedCurve == nil { + return nil, errors.New("x509: unsupported elliptic curve") + } + x, y := elliptic.Unmarshal(namedCurve, der) + if x == nil { + return nil, errors.New("x509: failed to unmarshal elliptic curve point") + } + pub := &ecdsa.PublicKey{ + Curve: namedCurve, + X: x, + Y: y, + } + return pub, nil + case oid.Equal(oidPublicKeyEd25519): + // RFC 8410, Section 3 + // > For all of the OIDs, the parameters MUST be absent. + if len(params.FullBytes) != 0 { + return nil, errors.New("x509: Ed25519 key encoded with illegal parameters") + } + if len(der) != ed25519.PublicKeySize { + return nil, errors.New("x509: wrong Ed25519 public key size") + } + return ed25519.PublicKey(der), nil + // case oid.Equal(oidPublicKeyX25519): + // // RFC 8410, Section 3 + // // > For all of the OIDs, the parameters MUST be absent. + // if len(params.FullBytes) != 0 { + // return nil, errors.New("x509: X25519 key encoded with illegal parameters") + // } + // return ecdh.X25519().NewPublicKey(der) + case oid.Equal(oidPublicKeyDSA): + y := new(big.Int) + if !der.ReadASN1Integer(y) { + return nil, errors.New("x509: invalid DSA public key") + } + pub := &dsa.PublicKey{ + Y: y, + Parameters: dsa.Parameters{ + P: new(big.Int), + Q: new(big.Int), + G: new(big.Int), + }, + } + paramsDer := cryptobyte.String(params.FullBytes) + if !paramsDer.ReadASN1(¶msDer, cryptobyte_asn1.SEQUENCE) || + !paramsDer.ReadASN1Integer(pub.Parameters.P) || + !paramsDer.ReadASN1Integer(pub.Parameters.Q) || + !paramsDer.ReadASN1Integer(pub.Parameters.G) { + return nil, errors.New("x509: invalid DSA parameters") + } + if pub.Y.Sign() <= 0 || pub.Parameters.P.Sign() <= 0 || + pub.Parameters.Q.Sign() <= 0 || pub.Parameters.G.Sign() <= 0 { + return nil, errors.New("x509: zero or negative DSA parameter") + } + return pub, nil + default: + return nil, errors.New("x509: unknown public key algorithm") + } +} + +func parseKeyUsageExtension(der cryptobyte.String) (stdx509.KeyUsage, error) { + var usageBits asn1.BitString + if !der.ReadASN1BitString(&usageBits) { + return 0, errors.New("x509: invalid key usage") + } + + var usage int + for i := 0; i < 9; i++ { + if usageBits.At(i) != 0 { + usage |= 1 << uint(i) + } + } + return stdx509.KeyUsage(usage), nil +} + +func parseBasicConstraintsExtension(der cryptobyte.String) (bool, int, error) { + var isCA bool + if !der.ReadASN1(&der, cryptobyte_asn1.SEQUENCE) { + return false, 0, errors.New("x509: invalid basic constraints") + } + if der.PeekASN1Tag(cryptobyte_asn1.BOOLEAN) { + if !der.ReadASN1Boolean(&isCA) { + return false, 0, errors.New("x509: invalid basic constraints") + } + } + maxPathLen := -1 + if der.PeekASN1Tag(cryptobyte_asn1.INTEGER) { + if !der.ReadASN1Integer(&maxPathLen) { + return false, 0, errors.New("x509: invalid basic constraints") + } + } + + // TODO: map out.MaxPathLen to 0 if it has the -1 default value? (Issue 19285) + return isCA, maxPathLen, nil +} + +func forEachSAN(der cryptobyte.String, callback func(tag int, data []byte) error) error { + if !der.ReadASN1(&der, cryptobyte_asn1.SEQUENCE) { + return errors.New("x509: invalid subject alternative names") + } + for !der.Empty() { + var san cryptobyte.String + var tag cryptobyte_asn1.Tag + if !der.ReadAnyASN1(&san, &tag) { + return errors.New("x509: invalid subject alternative name") + } + if err := callback(int(tag^0x80), san); err != nil { + return err + } + } + + return nil +} + +func parseSANExtension(der cryptobyte.String) (dnsNames, emailAddresses []string, ipAddresses []net.IP, uris []*url.URL, err error) { + err = forEachSAN(der, func(tag int, data []byte) error { + switch tag { + case nameTypeEmail: + email := string(data) + if err := isIA5String(email); err != nil { + return errors.New("x509: SAN rfc822Name is malformed") + } + emailAddresses = append(emailAddresses, email) + case nameTypeDNS: + name := string(data) + if err := isIA5String(name); err != nil { + return errors.New("x509: SAN dNSName is malformed") + } + dnsNames = append(dnsNames, string(name)) + case nameTypeURI: + uriStr := string(data) + if err := isIA5String(uriStr); err != nil { + return errors.New("x509: SAN uniformResourceIdentifier is malformed") + } + uri, err := url.Parse(uriStr) + if err != nil { + return fmt.Errorf("x509: cannot parse URI %q: %s", uriStr, err) + } + if len(uri.Host) > 0 { + if _, ok := domainToReverseLabels(uri.Host); !ok { + return fmt.Errorf("x509: cannot parse URI %q: invalid domain", uriStr) + } + } + uris = append(uris, uri) + case nameTypeIP: + switch len(data) { + case net.IPv4len, net.IPv6len: + ipAddresses = append(ipAddresses, data) + default: + return errors.New("x509: cannot parse IP address of length " + strconv.Itoa(len(data))) + } + } + + return nil + }) + + return +} + +func parseAuthorityKeyIdentifier(e pkix.Extension) ([]byte, error) { + // RFC 5280, Section 4.2.1.1 + // if e.Critical { + // // Conforming CAs MUST mark this extension as non-critical + // return nil, errors.New("x509: authority key identifier incorrectly marked critical") + // } + val := cryptobyte.String(e.Value) + var akid cryptobyte.String + if !val.ReadASN1(&akid, cryptobyte_asn1.SEQUENCE) { + return nil, errors.New("x509: invalid authority key identifier") + } + if akid.PeekASN1Tag(cryptobyte_asn1.Tag(0).ContextSpecific()) { + if !akid.ReadASN1(&akid, cryptobyte_asn1.Tag(0).ContextSpecific()) { + return nil, errors.New("x509: invalid authority key identifier") + } + return akid, nil + } + return nil, nil +} + +func parseExtKeyUsageExtension(der cryptobyte.String) ([]stdx509.ExtKeyUsage, []asn1.ObjectIdentifier, error) { + var extKeyUsages []stdx509.ExtKeyUsage + var unknownUsages []asn1.ObjectIdentifier + if !der.ReadASN1(&der, cryptobyte_asn1.SEQUENCE) { + return nil, nil, errors.New("x509: invalid extended key usages") + } + for !der.Empty() { + var eku asn1.ObjectIdentifier + if !der.ReadASN1ObjectIdentifier(&eku) { + return nil, nil, errors.New("x509: invalid extended key usages") + } + if extKeyUsage, ok := extKeyUsageFromOID(eku); ok { + extKeyUsages = append(extKeyUsages, stdx509.ExtKeyUsage(extKeyUsage)) + } else { + unknownUsages = append(unknownUsages, eku) + } + } + return extKeyUsages, unknownUsages, nil +} + +// func parseCertificatePoliciesExtension(der cryptobyte.String) ([]OID, error) { +// var oids []OID +// if !der.ReadASN1(&der, cryptobyte_asn1.SEQUENCE) { +// return nil, errors.New("x509: invalid certificate policies") +// } +// for !der.Empty() { +// var cp cryptobyte.String +// var OIDBytes cryptobyte.String +// if !der.ReadASN1(&cp, cryptobyte_asn1.SEQUENCE) || !cp.ReadASN1(&OIDBytes, cryptobyte_asn1.OBJECT_IDENTIFIER) { +// return nil, errors.New("x509: invalid certificate policies") +// } +// oid, ok := newOIDFromDER(OIDBytes) +// if !ok { +// return nil, errors.New("x509: invalid certificate policies") +// } +// oids = append(oids, oid) +// } +// return oids, nil +// } + +// isValidIPMask reports whether mask consists of zero or more 1 bits, followed by zero bits. +func isValidIPMask(mask []byte) bool { + seenZero := false + + for _, b := range mask { + if seenZero { + if b != 0 { + return false + } + + continue + } + + switch b { + case 0x00, 0x80, 0xc0, 0xe0, 0xf0, 0xf8, 0xfc, 0xfe: + seenZero = true + case 0xff: + default: + return false + } + } + + return true +} + +func parseNameConstraintsExtension(out *stdx509.Certificate, e pkix.Extension) (unhandled bool, err error) { + // RFC 5280, 4.2.1.10 + + // NameConstraints ::= SEQUENCE { + // permittedSubtrees [0] GeneralSubtrees OPTIONAL, + // excludedSubtrees [1] GeneralSubtrees OPTIONAL } + // + // GeneralSubtrees ::= SEQUENCE SIZE (1..MAX) OF GeneralSubtree + // + // GeneralSubtree ::= SEQUENCE { + // base GeneralName, + // minimum [0] BaseDistance DEFAULT 0, + // maximum [1] BaseDistance OPTIONAL } + // + // BaseDistance ::= INTEGER (0..MAX) + + outer := cryptobyte.String(e.Value) + var toplevel, permitted, excluded cryptobyte.String + var havePermitted, haveExcluded bool + if !outer.ReadASN1(&toplevel, cryptobyte_asn1.SEQUENCE) || + !outer.Empty() || + !toplevel.ReadOptionalASN1(&permitted, &havePermitted, cryptobyte_asn1.Tag(0).ContextSpecific().Constructed()) || + !toplevel.ReadOptionalASN1(&excluded, &haveExcluded, cryptobyte_asn1.Tag(1).ContextSpecific().Constructed()) || + !toplevel.Empty() { + return false, errors.New("x509: invalid NameConstraints extension") + } + + if !havePermitted && !haveExcluded || len(permitted) == 0 && len(excluded) == 0 { + // From RFC 5280, Section 4.2.1.10: + // “either the permittedSubtrees field + // or the excludedSubtrees MUST be + // present” + return false, errors.New("x509: empty name constraints extension") + } + + getValues := func(subtrees cryptobyte.String) (dnsNames []string, ips []*net.IPNet, emails, uriDomains []string, err error) { + for !subtrees.Empty() { + var seq, value cryptobyte.String + var tag cryptobyte_asn1.Tag + if !subtrees.ReadASN1(&seq, cryptobyte_asn1.SEQUENCE) || + !seq.ReadAnyASN1(&value, &tag) { + return nil, nil, nil, nil, fmt.Errorf("x509: invalid NameConstraints extension") + } + + var ( + dnsTag = cryptobyte_asn1.Tag(2).ContextSpecific() + emailTag = cryptobyte_asn1.Tag(1).ContextSpecific() + ipTag = cryptobyte_asn1.Tag(7).ContextSpecific() + uriTag = cryptobyte_asn1.Tag(6).ContextSpecific() + ) + + switch tag { + case dnsTag: + domain := string(value) + if err := isIA5String(domain); err != nil { + return nil, nil, nil, nil, errors.New("x509: invalid constraint value: " + err.Error()) + } + + trimmedDomain := domain + if len(trimmedDomain) > 0 && trimmedDomain[0] == '.' { + // constraints can have a leading + // period to exclude the domain + // itself, but that's not valid in a + // normal domain name. + trimmedDomain = trimmedDomain[1:] + } + if _, ok := domainToReverseLabels(trimmedDomain); !ok { + return nil, nil, nil, nil, fmt.Errorf("x509: failed to parse dnsName constraint %q", domain) + } + dnsNames = append(dnsNames, domain) + + case ipTag: + l := len(value) + var ip, mask []byte + + switch l { + case 8: + ip = value[:4] + mask = value[4:] + + case 32: + ip = value[:16] + mask = value[16:] + + default: + return nil, nil, nil, nil, fmt.Errorf("x509: IP constraint contained value of length %d", l) + } + + if !isValidIPMask(mask) { + return nil, nil, nil, nil, fmt.Errorf("x509: IP constraint contained invalid mask %x", mask) + } + + ips = append(ips, &net.IPNet{IP: net.IP(ip), Mask: net.IPMask(mask)}) + + case emailTag: + constraint := string(value) + if err := isIA5String(constraint); err != nil { + return nil, nil, nil, nil, errors.New("x509: invalid constraint value: " + err.Error()) + } + + // If the constraint contains an @ then + // it specifies an exact mailbox name. + if strings.Contains(constraint, "@") { + if _, ok := parseRFC2821Mailbox(constraint); !ok { + return nil, nil, nil, nil, fmt.Errorf("x509: failed to parse rfc822Name constraint %q", constraint) + } + } else { + // Otherwise it's a domain name. + domain := constraint + if len(domain) > 0 && domain[0] == '.' { + domain = domain[1:] + } + if _, ok := domainToReverseLabels(domain); !ok { + return nil, nil, nil, nil, fmt.Errorf("x509: failed to parse rfc822Name constraint %q", constraint) + } + } + emails = append(emails, constraint) + + case uriTag: + domain := string(value) + if err := isIA5String(domain); err != nil { + return nil, nil, nil, nil, errors.New("x509: invalid constraint value: " + err.Error()) + } + + if net.ParseIP(domain) != nil { + return nil, nil, nil, nil, fmt.Errorf("x509: failed to parse URI constraint %q: cannot be IP address", domain) + } + + trimmedDomain := domain + if len(trimmedDomain) > 0 && trimmedDomain[0] == '.' { + // constraints can have a leading + // period to exclude the domain itself, + // but that's not valid in a normal + // domain name. + trimmedDomain = trimmedDomain[1:] + } + if _, ok := domainToReverseLabels(trimmedDomain); !ok { + return nil, nil, nil, nil, fmt.Errorf("x509: failed to parse URI constraint %q", domain) + } + uriDomains = append(uriDomains, domain) + + default: + unhandled = true + } + } + + return dnsNames, ips, emails, uriDomains, nil + } + + if out.PermittedDNSDomains, out.PermittedIPRanges, out.PermittedEmailAddresses, out.PermittedURIDomains, err = getValues(permitted); err != nil { + return false, err + } + if out.ExcludedDNSDomains, out.ExcludedIPRanges, out.ExcludedEmailAddresses, out.ExcludedURIDomains, err = getValues(excluded); err != nil { + return false, err + } + out.PermittedDNSDomainsCritical = e.Critical + + return unhandled, nil +} + +func processExtensions(out *stdx509.Certificate) error { + var err error + for _, e := range out.Extensions { + unhandled := false + + if len(e.Id) == 4 && e.Id[0] == 2 && e.Id[1] == 5 && e.Id[2] == 29 { + switch e.Id[3] { + case 15: + out.KeyUsage, err = parseKeyUsageExtension(e.Value) + if err != nil { + return err + } + case 19: + out.IsCA, out.MaxPathLen, err = parseBasicConstraintsExtension(e.Value) + if err != nil { + return err + } + out.BasicConstraintsValid = true + out.MaxPathLenZero = out.MaxPathLen == 0 + case 17: + out.DNSNames, out.EmailAddresses, out.IPAddresses, out.URIs, err = parseSANExtension(e.Value) + if err != nil { + return err + } + + if len(out.DNSNames) == 0 && len(out.EmailAddresses) == 0 && len(out.IPAddresses) == 0 && len(out.URIs) == 0 { + // If we didn't parse anything then we do the critical check, below. + unhandled = true + } + + case 30: + unhandled, err = parseNameConstraintsExtension(out, e) + if err != nil { + return err + } + + case 31: + // RFC 5280, 4.2.1.13 + + // CRLDistributionPoints ::= SEQUENCE SIZE (1..MAX) OF DistributionPoint + // + // DistributionPoint ::= SEQUENCE { + // distributionPoint [0] DistributionPointName OPTIONAL, + // reasons [1] ReasonFlags OPTIONAL, + // cRLIssuer [2] GeneralNames OPTIONAL } + // + // DistributionPointName ::= CHOICE { + // fullName [0] GeneralNames, + // nameRelativeToCRLIssuer [1] RelativeDistinguishedName } + val := cryptobyte.String(e.Value) + if !val.ReadASN1(&val, cryptobyte_asn1.SEQUENCE) { + return errors.New("x509: invalid CRL distribution points") + } + for !val.Empty() { + var dpDER cryptobyte.String + if !val.ReadASN1(&dpDER, cryptobyte_asn1.SEQUENCE) { + return errors.New("x509: invalid CRL distribution point") + } + var dpNameDER cryptobyte.String + var dpNamePresent bool + if !dpDER.ReadOptionalASN1(&dpNameDER, &dpNamePresent, cryptobyte_asn1.Tag(0).Constructed().ContextSpecific()) { + return errors.New("x509: invalid CRL distribution point") + } + if !dpNamePresent { + continue + } + if !dpNameDER.ReadASN1(&dpNameDER, cryptobyte_asn1.Tag(0).Constructed().ContextSpecific()) { + return errors.New("x509: invalid CRL distribution point") + } + for !dpNameDER.Empty() { + if !dpNameDER.PeekASN1Tag(cryptobyte_asn1.Tag(6).ContextSpecific()) { + break + } + var uri cryptobyte.String + if !dpNameDER.ReadASN1(&uri, cryptobyte_asn1.Tag(6).ContextSpecific()) { + return errors.New("x509: invalid CRL distribution point") + } + out.CRLDistributionPoints = append(out.CRLDistributionPoints, string(uri)) + } + } + + case 35: + out.AuthorityKeyId, err = parseAuthorityKeyIdentifier(e) + if err != nil { + return err + } + case 37: + out.ExtKeyUsage, out.UnknownExtKeyUsage, err = parseExtKeyUsageExtension(e.Value) + if err != nil { + return err + } + case 14: + // RFC 5280, 4.2.1.2 + if e.Critical { + // Conforming CAs MUST mark this extension as non-critical + return errors.New("x509: subject key identifier incorrectly marked critical") + } + val := cryptobyte.String(e.Value) + var skid cryptobyte.String + if !val.ReadASN1(&skid, cryptobyte_asn1.OCTET_STRING) { + return errors.New("x509: invalid subject key identifier") + } + out.SubjectKeyId = skid + // case 32: + // out.Policies, err = parseCertificatePoliciesExtension(e.Value) + // if err != nil { + // return err + // } + // out.PolicyIdentifiers = make([]asn1.ObjectIdentifier, 0, len(out.Policies)) + // for _, oid := range out.Policies { + // if oid, ok := oid.toASN1OID(); ok { + // out.PolicyIdentifiers = append(out.PolicyIdentifiers, oid) + // } + // } + default: + // Unknown extensions are recorded if critical. + unhandled = true + } + } else if e.Id.Equal(oidExtensionAuthorityInfoAccess) { + // RFC 5280 4.2.2.1: Authority Information Access + if e.Critical { + // Conforming CAs MUST mark this extension as non-critical + return errors.New("x509: authority info access incorrectly marked critical") + } + val := cryptobyte.String(e.Value) + if !val.ReadASN1(&val, cryptobyte_asn1.SEQUENCE) { + return errors.New("x509: invalid authority info access") + } + for !val.Empty() { + var aiaDER cryptobyte.String + if !val.ReadASN1(&aiaDER, cryptobyte_asn1.SEQUENCE) { + return errors.New("x509: invalid authority info access") + } + var method asn1.ObjectIdentifier + if !aiaDER.ReadASN1ObjectIdentifier(&method) { + return errors.New("x509: invalid authority info access") + } + if !aiaDER.PeekASN1Tag(cryptobyte_asn1.Tag(6).ContextSpecific()) { + continue + } + if !aiaDER.ReadASN1(&aiaDER, cryptobyte_asn1.Tag(6).ContextSpecific()) { + return errors.New("x509: invalid authority info access") + } + switch { + case method.Equal(oidAuthorityInfoAccessOcsp): + out.OCSPServer = append(out.OCSPServer, string(aiaDER)) + case method.Equal(oidAuthorityInfoAccessIssuers): + out.IssuingCertificateURL = append(out.IssuingCertificateURL, string(aiaDER)) + } + } + } else { + // Unknown extensions are recorded if critical. + unhandled = true + } + + if e.Critical && unhandled { + out.UnhandledCriticalExtensions = append(out.UnhandledCriticalExtensions, e.Id) + } + } + + return nil +} + +var x509negativeserial = legacyGodebugSetting(0) // replaces godebug.New("x509negativeserial") + +func parseCertificate(der []byte) (*stdx509.Certificate, error) { + cert := &stdx509.Certificate{} + + input := cryptobyte.String(der) + // we read the SEQUENCE including length and tag bytes so that + // we can populate Certificate.Raw, before unwrapping the + // SEQUENCE so it can be operated on + if !input.ReadASN1Element(&input, cryptobyte_asn1.SEQUENCE) { + return nil, errors.New("x509: malformed certificate") + } + cert.Raw = input + if !input.ReadASN1(&input, cryptobyte_asn1.SEQUENCE) { + return nil, errors.New("x509: malformed certificate") + } + + var tbs cryptobyte.String + // do the same trick again as above to extract the raw + // bytes for Certificate.RawTBSCertificate + if !input.ReadASN1Element(&tbs, cryptobyte_asn1.SEQUENCE) { + return nil, errors.New("x509: malformed tbs certificate") + } + cert.RawTBSCertificate = tbs + if !tbs.ReadASN1(&tbs, cryptobyte_asn1.SEQUENCE) { + return nil, errors.New("x509: malformed tbs certificate") + } + + if !tbs.ReadOptionalASN1Integer(&cert.Version, cryptobyte_asn1.Tag(0).Constructed().ContextSpecific(), 0) { + return nil, errors.New("x509: malformed version") + } + if cert.Version < 0 { + return nil, errors.New("x509: malformed version") + } + // for backwards compat reasons Version is one-indexed, + // rather than zero-indexed as defined in 5280 + cert.Version++ + if cert.Version > 3 { + return nil, errors.New("x509: invalid version") + } + + serial := new(big.Int) + if !tbs.ReadASN1Integer(serial) { + return nil, errors.New("x509: malformed serial number") + } + if serial.Sign() == -1 { + if x509negativeserial.Value() != "1" { + return nil, errors.New("x509: negative serial number") + } else { + x509negativeserial.IncNonDefault() + } + } + cert.SerialNumber = serial + + var sigAISeq cryptobyte.String + if !tbs.ReadASN1(&sigAISeq, cryptobyte_asn1.SEQUENCE) { + return nil, errors.New("x509: malformed signature algorithm identifier") + } + // Before parsing the inner algorithm identifier, extract + // the outer algorithm identifier and make sure that they + // match. + var outerSigAISeq cryptobyte.String + if !input.ReadASN1(&outerSigAISeq, cryptobyte_asn1.SEQUENCE) { + return nil, errors.New("x509: malformed algorithm identifier") + } + if !bytes.Equal(outerSigAISeq, sigAISeq) { + return nil, errors.New("x509: inner and outer signature algorithm identifiers don't match") + } + sigAI, err := parseAI(sigAISeq) + if err != nil { + return nil, err + } + cert.SignatureAlgorithm = getSignatureAlgorithmFromAI(sigAI) + + var issuerSeq cryptobyte.String + if !tbs.ReadASN1Element(&issuerSeq, cryptobyte_asn1.SEQUENCE) { + return nil, errors.New("x509: malformed issuer") + } + cert.RawIssuer = issuerSeq + issuerRDNs, err := parseName(issuerSeq) + if err != nil { + return nil, err + } + cert.Issuer.FillFromRDNSequence(issuerRDNs) + + var validity cryptobyte.String + if !tbs.ReadASN1(&validity, cryptobyte_asn1.SEQUENCE) { + return nil, errors.New("x509: malformed validity") + } + cert.NotBefore, cert.NotAfter, err = parseValidity(validity) + if err != nil { + return nil, err + } + + var subjectSeq cryptobyte.String + if !tbs.ReadASN1Element(&subjectSeq, cryptobyte_asn1.SEQUENCE) { + return nil, errors.New("x509: malformed issuer") + } + cert.RawSubject = subjectSeq + subjectRDNs, err := parseName(subjectSeq) + if err != nil { + return nil, err + } + cert.Subject.FillFromRDNSequence(subjectRDNs) + + var spki cryptobyte.String + if !tbs.ReadASN1Element(&spki, cryptobyte_asn1.SEQUENCE) { + return nil, errors.New("x509: malformed spki") + } + cert.RawSubjectPublicKeyInfo = spki + if !spki.ReadASN1(&spki, cryptobyte_asn1.SEQUENCE) { + return nil, errors.New("x509: malformed spki") + } + var pkAISeq cryptobyte.String + if !spki.ReadASN1(&pkAISeq, cryptobyte_asn1.SEQUENCE) { + return nil, errors.New("x509: malformed public key algorithm identifier") + } + pkAI, err := parseAI(pkAISeq) + if err != nil { + return nil, err + } + cert.PublicKeyAlgorithm = getPublicKeyAlgorithmFromOID(pkAI.Algorithm) + var spk asn1.BitString + if !spki.ReadASN1BitString(&spk) { + return nil, errors.New("x509: malformed subjectPublicKey") + } + if cert.PublicKeyAlgorithm != stdx509.UnknownPublicKeyAlgorithm { + cert.PublicKey, err = parsePublicKey(&publicKeyInfo{ + Algorithm: pkAI, + PublicKey: spk, + }) + if err != nil { + return nil, err + } + } + + if cert.Version > 1 { + if !tbs.SkipOptionalASN1(cryptobyte_asn1.Tag(1).ContextSpecific()) { + return nil, errors.New("x509: malformed issuerUniqueID") + } + if !tbs.SkipOptionalASN1(cryptobyte_asn1.Tag(2).ContextSpecific()) { + return nil, errors.New("x509: malformed subjectUniqueID") + } + if cert.Version == 3 { + var extensions cryptobyte.String + var present bool + if !tbs.ReadOptionalASN1(&extensions, &present, cryptobyte_asn1.Tag(3).Constructed().ContextSpecific()) { + return nil, errors.New("x509: malformed extensions") + } + if present { + seenExts := make(map[string]bool) + if !extensions.ReadASN1(&extensions, cryptobyte_asn1.SEQUENCE) { + return nil, errors.New("x509: malformed extensions") + } + for !extensions.Empty() { + var extension cryptobyte.String + if !extensions.ReadASN1(&extension, cryptobyte_asn1.SEQUENCE) { + return nil, errors.New("x509: malformed extension") + } + ext, err := parseExtension(extension) + if err != nil { + return nil, err + } + oidStr := ext.Id.String() + if seenExts[oidStr] { + return nil, fmt.Errorf("x509: certificate contains duplicate extension with OID %q", oidStr) + } + seenExts[oidStr] = true + cert.Extensions = append(cert.Extensions, ext) + } + err = processExtensions(cert) + if err != nil { + return nil, err + } + } + } + } + + var signature asn1.BitString + if !input.ReadASN1BitString(&signature) { + return nil, errors.New("x509: malformed signature") + } + cert.Signature = signature.RightAlign() + + return cert, nil +} diff --git a/vendor/github.com/smallstep/pkcs7/internal/legacy/x509/pkcs1.go b/vendor/github.com/smallstep/pkcs7/internal/legacy/x509/pkcs1.go new file mode 100644 index 0000000000000..da3c38a4e4bd8 --- /dev/null +++ b/vendor/github.com/smallstep/pkcs7/internal/legacy/x509/pkcs1.go @@ -0,0 +1,15 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package legacyx509 + +import ( + "math/big" +) + +// pkcs1PublicKey reflects the ASN.1 structure of a PKCS #1 public key. +type pkcs1PublicKey struct { + N *big.Int + E int +} diff --git a/vendor/github.com/smallstep/pkcs7/internal/legacy/x509/verify.go b/vendor/github.com/smallstep/pkcs7/internal/legacy/x509/verify.go new file mode 100644 index 0000000000000..901e3ba85b444 --- /dev/null +++ b/vendor/github.com/smallstep/pkcs7/internal/legacy/x509/verify.go @@ -0,0 +1,193 @@ +package legacyx509 + +import ( + "bytes" + "strings" +) + +// rfc2821Mailbox represents a “mailbox” (which is an email address to most +// people) by breaking it into the “local” (i.e. before the '@') and “domain” +// parts. +type rfc2821Mailbox struct { + local, domain string +} + +// parseRFC2821Mailbox parses an email address into local and domain parts, +// based on the ABNF for a “Mailbox” from RFC 2821. According to RFC 5280, +// Section 4.2.1.6 that's correct for an rfc822Name from a certificate: “The +// format of an rfc822Name is a "Mailbox" as defined in RFC 2821, Section 4.1.2”. +func parseRFC2821Mailbox(in string) (mailbox rfc2821Mailbox, ok bool) { + if len(in) == 0 { + return mailbox, false + } + + localPartBytes := make([]byte, 0, len(in)/2) + + if in[0] == '"' { + // Quoted-string = DQUOTE *qcontent DQUOTE + // non-whitespace-control = %d1-8 / %d11 / %d12 / %d14-31 / %d127 + // qcontent = qtext / quoted-pair + // qtext = non-whitespace-control / + // %d33 / %d35-91 / %d93-126 + // quoted-pair = ("\" text) / obs-qp + // text = %d1-9 / %d11 / %d12 / %d14-127 / obs-text + // + // (Names beginning with “obs-” are the obsolete syntax from RFC 2822, + // Section 4. Since it has been 16 years, we no longer accept that.) + in = in[1:] + QuotedString: + for { + if len(in) == 0 { + return mailbox, false + } + c := in[0] + in = in[1:] + + switch { + case c == '"': + break QuotedString + + case c == '\\': + // quoted-pair + if len(in) == 0 { + return mailbox, false + } + if in[0] == 11 || + in[0] == 12 || + (1 <= in[0] && in[0] <= 9) || + (14 <= in[0] && in[0] <= 127) { + localPartBytes = append(localPartBytes, in[0]) + in = in[1:] + } else { + return mailbox, false + } + + case c == 11 || + c == 12 || + // Space (char 32) is not allowed based on the + // BNF, but RFC 3696 gives an example that + // assumes that it is. Several “verified” + // errata continue to argue about this point. + // We choose to accept it. + c == 32 || + c == 33 || + c == 127 || + (1 <= c && c <= 8) || + (14 <= c && c <= 31) || + (35 <= c && c <= 91) || + (93 <= c && c <= 126): + // qtext + localPartBytes = append(localPartBytes, c) + + default: + return mailbox, false + } + } + } else { + // Atom ("." Atom)* + NextChar: + for len(in) > 0 { + // atext from RFC 2822, Section 3.2.4 + c := in[0] + + switch { + case c == '\\': + // Examples given in RFC 3696 suggest that + // escaped characters can appear outside of a + // quoted string. Several “verified” errata + // continue to argue the point. We choose to + // accept it. + in = in[1:] + if len(in) == 0 { + return mailbox, false + } + fallthrough + + case ('0' <= c && c <= '9') || + ('a' <= c && c <= 'z') || + ('A' <= c && c <= 'Z') || + c == '!' || c == '#' || c == '$' || c == '%' || + c == '&' || c == '\'' || c == '*' || c == '+' || + c == '-' || c == '/' || c == '=' || c == '?' || + c == '^' || c == '_' || c == '`' || c == '{' || + c == '|' || c == '}' || c == '~' || c == '.': + localPartBytes = append(localPartBytes, in[0]) + in = in[1:] + + default: + break NextChar + } + } + + if len(localPartBytes) == 0 { + return mailbox, false + } + + // From RFC 3696, Section 3: + // “period (".") may also appear, but may not be used to start + // or end the local part, nor may two or more consecutive + // periods appear.” + twoDots := []byte{'.', '.'} + if localPartBytes[0] == '.' || + localPartBytes[len(localPartBytes)-1] == '.' || + bytes.Contains(localPartBytes, twoDots) { + return mailbox, false + } + } + + if len(in) == 0 || in[0] != '@' { + return mailbox, false + } + in = in[1:] + + // The RFC species a format for domains, but that's known to be + // violated in practice so we accept that anything after an '@' is the + // domain part. + if _, ok := domainToReverseLabels(in); !ok { + return mailbox, false + } + + mailbox.local = string(localPartBytes) + mailbox.domain = in + return mailbox, true +} + +// domainToReverseLabels converts a textual domain name like foo.example.com to +// the list of labels in reverse order, e.g. ["com", "example", "foo"]. +func domainToReverseLabels(domain string) (reverseLabels []string, ok bool) { + for len(domain) > 0 { + if i := strings.LastIndexByte(domain, '.'); i == -1 { + reverseLabels = append(reverseLabels, domain) + domain = "" + } else { + reverseLabels = append(reverseLabels, domain[i+1:]) + domain = domain[:i] + if i == 0 { // domain == "" + // domain is prefixed with an empty label, append an empty + // string to reverseLabels to indicate this. + reverseLabels = append(reverseLabels, "") + } + } + } + + if len(reverseLabels) > 0 && len(reverseLabels[0]) == 0 { + // An empty label at the end indicates an absolute value. + return nil, false + } + + for _, label := range reverseLabels { + if len(label) == 0 { + // Empty labels are otherwise invalid. + return nil, false + } + + for _, c := range label { + if c < 33 || c > 126 { + // Invalid character. + return nil, false + } + } + } + + return reverseLabels, true +} diff --git a/vendor/github.com/smallstep/pkcs7/internal/legacy/x509/x509.go b/vendor/github.com/smallstep/pkcs7/internal/legacy/x509/x509.go new file mode 100644 index 0000000000000..a4500bfb172e5 --- /dev/null +++ b/vendor/github.com/smallstep/pkcs7/internal/legacy/x509/x509.go @@ -0,0 +1,488 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package x509 implements a subset of the X.509 standard. +// +// It allows parsing and generating certificates, certificate signing +// requests, certificate revocation lists, and encoded public and private keys. +// It provides a certificate verifier, complete with a chain builder. +// +// The package targets the X.509 technical profile defined by the IETF (RFC +// 2459/3280/5280), and as further restricted by the CA/Browser Forum Baseline +// Requirements. There is minimal support for features outside of these +// profiles, as the primary goal of the package is to provide compatibility +// with the publicly trusted TLS certificate ecosystem and its policies and +// constraints. +// +// On macOS and Windows, certificate verification is handled by system APIs, but +// the package aims to apply consistent validation rules across operating +// systems. +package legacyx509 + +import ( + "bytes" + "crypto" + "crypto/elliptic" + stdx509 "crypto/x509" + "crypto/x509/pkix" + "encoding/asn1" + "fmt" + "strconv" + "unicode" + + // Explicitly import these for their crypto.RegisterHash init side-effects. + // Keep these as blank imports, even if they're imported above. + _ "crypto/sha1" + _ "crypto/sha256" + _ "crypto/sha512" +) + +type publicKeyInfo struct { + Raw asn1.RawContent + Algorithm pkix.AlgorithmIdentifier + PublicKey asn1.BitString +} + +type SignatureAlgorithm int + +const ( + UnknownSignatureAlgorithm SignatureAlgorithm = iota + + MD2WithRSA // Unsupported. + MD5WithRSA // Only supported for signing, not verification. + SHA1WithRSA // Only supported for signing, and verification of CRLs, CSRs, and OCSP responses. + SHA256WithRSA + SHA384WithRSA + SHA512WithRSA + DSAWithSHA1 // Unsupported. + DSAWithSHA256 // Unsupported. + ECDSAWithSHA1 // Only supported for signing, and verification of CRLs, CSRs, and OCSP responses. + ECDSAWithSHA256 + ECDSAWithSHA384 + ECDSAWithSHA512 + SHA256WithRSAPSS + SHA384WithRSAPSS + SHA512WithRSAPSS + PureEd25519 +) + +func (algo SignatureAlgorithm) String() string { + for _, details := range signatureAlgorithmDetails { + if details.algo == algo { + return details.name + } + } + return strconv.Itoa(int(algo)) +} + +type PublicKeyAlgorithm int + +const ( + UnknownPublicKeyAlgorithm PublicKeyAlgorithm = iota + RSA + DSA // Only supported for parsing. + ECDSA + Ed25519 +) + +var publicKeyAlgoName = [...]string{ + RSA: "RSA", + DSA: "DSA", + ECDSA: "ECDSA", + Ed25519: "Ed25519", +} + +func (algo PublicKeyAlgorithm) String() string { + if 0 < algo && int(algo) < len(publicKeyAlgoName) { + return publicKeyAlgoName[algo] + } + return strconv.Itoa(int(algo)) +} + +// OIDs for signature algorithms +// +// pkcs-1 OBJECT IDENTIFIER ::= { +// iso(1) member-body(2) us(840) rsadsi(113549) pkcs(1) 1 } +// +// RFC 3279 2.2.1 RSA Signature Algorithms +// +// md5WithRSAEncryption OBJECT IDENTIFIER ::= { pkcs-1 4 } +// +// sha-1WithRSAEncryption OBJECT IDENTIFIER ::= { pkcs-1 5 } +// +// dsaWithSha1 OBJECT IDENTIFIER ::= { +// iso(1) member-body(2) us(840) x9-57(10040) x9cm(4) 3 } +// +// RFC 3279 2.2.3 ECDSA Signature Algorithm +// +// ecdsa-with-SHA1 OBJECT IDENTIFIER ::= { +// iso(1) member-body(2) us(840) ansi-x962(10045) +// signatures(4) ecdsa-with-SHA1(1)} +// +// RFC 4055 5 PKCS #1 Version 1.5 +// +// sha256WithRSAEncryption OBJECT IDENTIFIER ::= { pkcs-1 11 } +// +// sha384WithRSAEncryption OBJECT IDENTIFIER ::= { pkcs-1 12 } +// +// sha512WithRSAEncryption OBJECT IDENTIFIER ::= { pkcs-1 13 } +// +// RFC 5758 3.1 DSA Signature Algorithms +// +// dsaWithSha256 OBJECT IDENTIFIER ::= { +// joint-iso-ccitt(2) country(16) us(840) organization(1) gov(101) +// csor(3) algorithms(4) id-dsa-with-sha2(3) 2} +// +// RFC 5758 3.2 ECDSA Signature Algorithm +// +// ecdsa-with-SHA256 OBJECT IDENTIFIER ::= { iso(1) member-body(2) +// us(840) ansi-X9-62(10045) signatures(4) ecdsa-with-SHA2(3) 2 } +// +// ecdsa-with-SHA384 OBJECT IDENTIFIER ::= { iso(1) member-body(2) +// us(840) ansi-X9-62(10045) signatures(4) ecdsa-with-SHA2(3) 3 } +// +// ecdsa-with-SHA512 OBJECT IDENTIFIER ::= { iso(1) member-body(2) +// us(840) ansi-X9-62(10045) signatures(4) ecdsa-with-SHA2(3) 4 } +// +// RFC 8410 3 Curve25519 and Curve448 Algorithm Identifiers +// +// id-Ed25519 OBJECT IDENTIFIER ::= { 1 3 101 112 } +var ( + oidSignatureMD5WithRSA = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 4} + oidSignatureSHA1WithRSA = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 5} + oidSignatureSHA256WithRSA = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 11} + oidSignatureSHA384WithRSA = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 12} + oidSignatureSHA512WithRSA = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 13} + oidSignatureRSAPSS = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 10} + oidSignatureDSAWithSHA1 = asn1.ObjectIdentifier{1, 2, 840, 10040, 4, 3} + oidSignatureDSAWithSHA256 = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 3, 2} + oidSignatureECDSAWithSHA1 = asn1.ObjectIdentifier{1, 2, 840, 10045, 4, 1} + oidSignatureECDSAWithSHA256 = asn1.ObjectIdentifier{1, 2, 840, 10045, 4, 3, 2} + oidSignatureECDSAWithSHA384 = asn1.ObjectIdentifier{1, 2, 840, 10045, 4, 3, 3} + oidSignatureECDSAWithSHA512 = asn1.ObjectIdentifier{1, 2, 840, 10045, 4, 3, 4} + oidSignatureEd25519 = asn1.ObjectIdentifier{1, 3, 101, 112} + + oidSHA256 = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 2, 1} + oidSHA384 = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 2, 2} + oidSHA512 = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 2, 3} + + oidMGF1 = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 8} + + // oidISOSignatureSHA1WithRSA means the same as oidSignatureSHA1WithRSA + // but it's specified by ISO. Microsoft's makecert.exe has been known + // to produce certificates with this OID. + oidISOSignatureSHA1WithRSA = asn1.ObjectIdentifier{1, 3, 14, 3, 2, 29} +) + +var signatureAlgorithmDetails = []struct { + algo SignatureAlgorithm + name string + oid asn1.ObjectIdentifier + params asn1.RawValue + pubKeyAlgo PublicKeyAlgorithm + hash crypto.Hash + isRSAPSS bool +}{ + {MD5WithRSA, "MD5-RSA", oidSignatureMD5WithRSA, asn1.NullRawValue, RSA, crypto.MD5, false}, + {SHA1WithRSA, "SHA1-RSA", oidSignatureSHA1WithRSA, asn1.NullRawValue, RSA, crypto.SHA1, false}, + {SHA1WithRSA, "SHA1-RSA", oidISOSignatureSHA1WithRSA, asn1.NullRawValue, RSA, crypto.SHA1, false}, + {SHA256WithRSA, "SHA256-RSA", oidSignatureSHA256WithRSA, asn1.NullRawValue, RSA, crypto.SHA256, false}, + {SHA384WithRSA, "SHA384-RSA", oidSignatureSHA384WithRSA, asn1.NullRawValue, RSA, crypto.SHA384, false}, + {SHA512WithRSA, "SHA512-RSA", oidSignatureSHA512WithRSA, asn1.NullRawValue, RSA, crypto.SHA512, false}, + {SHA256WithRSAPSS, "SHA256-RSAPSS", oidSignatureRSAPSS, pssParametersSHA256, RSA, crypto.SHA256, true}, + {SHA384WithRSAPSS, "SHA384-RSAPSS", oidSignatureRSAPSS, pssParametersSHA384, RSA, crypto.SHA384, true}, + {SHA512WithRSAPSS, "SHA512-RSAPSS", oidSignatureRSAPSS, pssParametersSHA512, RSA, crypto.SHA512, true}, + {DSAWithSHA1, "DSA-SHA1", oidSignatureDSAWithSHA1, emptyRawValue, DSA, crypto.SHA1, false}, + {DSAWithSHA256, "DSA-SHA256", oidSignatureDSAWithSHA256, emptyRawValue, DSA, crypto.SHA256, false}, + {ECDSAWithSHA1, "ECDSA-SHA1", oidSignatureECDSAWithSHA1, emptyRawValue, ECDSA, crypto.SHA1, false}, + {ECDSAWithSHA256, "ECDSA-SHA256", oidSignatureECDSAWithSHA256, emptyRawValue, ECDSA, crypto.SHA256, false}, + {ECDSAWithSHA384, "ECDSA-SHA384", oidSignatureECDSAWithSHA384, emptyRawValue, ECDSA, crypto.SHA384, false}, + {ECDSAWithSHA512, "ECDSA-SHA512", oidSignatureECDSAWithSHA512, emptyRawValue, ECDSA, crypto.SHA512, false}, + {PureEd25519, "Ed25519", oidSignatureEd25519, emptyRawValue, Ed25519, crypto.Hash(0) /* no pre-hashing */, false}, +} + +var emptyRawValue = asn1.RawValue{} + +// DER encoded RSA PSS parameters for the +// SHA256, SHA384, and SHA512 hashes as defined in RFC 3447, Appendix A.2.3. +// The parameters contain the following values: +// - hashAlgorithm contains the associated hash identifier with NULL parameters +// - maskGenAlgorithm always contains the default mgf1SHA1 identifier +// - saltLength contains the length of the associated hash +// - trailerField always contains the default trailerFieldBC value +var ( + pssParametersSHA256 = asn1.RawValue{FullBytes: []byte{48, 52, 160, 15, 48, 13, 6, 9, 96, 134, 72, 1, 101, 3, 4, 2, 1, 5, 0, 161, 28, 48, 26, 6, 9, 42, 134, 72, 134, 247, 13, 1, 1, 8, 48, 13, 6, 9, 96, 134, 72, 1, 101, 3, 4, 2, 1, 5, 0, 162, 3, 2, 1, 32}} + pssParametersSHA384 = asn1.RawValue{FullBytes: []byte{48, 52, 160, 15, 48, 13, 6, 9, 96, 134, 72, 1, 101, 3, 4, 2, 2, 5, 0, 161, 28, 48, 26, 6, 9, 42, 134, 72, 134, 247, 13, 1, 1, 8, 48, 13, 6, 9, 96, 134, 72, 1, 101, 3, 4, 2, 2, 5, 0, 162, 3, 2, 1, 48}} + pssParametersSHA512 = asn1.RawValue{FullBytes: []byte{48, 52, 160, 15, 48, 13, 6, 9, 96, 134, 72, 1, 101, 3, 4, 2, 3, 5, 0, 161, 28, 48, 26, 6, 9, 42, 134, 72, 134, 247, 13, 1, 1, 8, 48, 13, 6, 9, 96, 134, 72, 1, 101, 3, 4, 2, 3, 5, 0, 162, 3, 2, 1, 64}} +) + +// pssParameters reflects the parameters in an AlgorithmIdentifier that +// specifies RSA PSS. See RFC 3447, Appendix A.2.3. +type pssParameters struct { + // The following three fields are not marked as + // optional because the default values specify SHA-1, + // which is no longer suitable for use in signatures. + Hash pkix.AlgorithmIdentifier `asn1:"explicit,tag:0"` + MGF pkix.AlgorithmIdentifier `asn1:"explicit,tag:1"` + SaltLength int `asn1:"explicit,tag:2"` + TrailerField int `asn1:"optional,explicit,tag:3,default:1"` +} + +func getSignatureAlgorithmFromAI(ai pkix.AlgorithmIdentifier) stdx509.SignatureAlgorithm { + if ai.Algorithm.Equal(oidSignatureEd25519) { + // RFC 8410, Section 3 + // > For all of the OIDs, the parameters MUST be absent. + if len(ai.Parameters.FullBytes) != 0 { + return stdx509.UnknownSignatureAlgorithm + } + } + + if !ai.Algorithm.Equal(oidSignatureRSAPSS) { + for _, details := range signatureAlgorithmDetails { + if ai.Algorithm.Equal(details.oid) { + return stdx509.SignatureAlgorithm(details.algo) + } + } + return stdx509.UnknownSignatureAlgorithm + } + + // RSA PSS is special because it encodes important parameters + // in the Parameters. + + var params pssParameters + if _, err := asn1.Unmarshal(ai.Parameters.FullBytes, ¶ms); err != nil { + return stdx509.UnknownSignatureAlgorithm + } + + var mgf1HashFunc pkix.AlgorithmIdentifier + if _, err := asn1.Unmarshal(params.MGF.Parameters.FullBytes, &mgf1HashFunc); err != nil { + return stdx509.UnknownSignatureAlgorithm + } + + // PSS is greatly overburdened with options. This code forces them into + // three buckets by requiring that the MGF1 hash function always match the + // message hash function (as recommended in RFC 3447, Section 8.1), that the + // salt length matches the hash length, and that the trailer field has the + // default value. + if (len(params.Hash.Parameters.FullBytes) != 0 && !bytes.Equal(params.Hash.Parameters.FullBytes, asn1.NullBytes)) || + !params.MGF.Algorithm.Equal(oidMGF1) || + !mgf1HashFunc.Algorithm.Equal(params.Hash.Algorithm) || + (len(mgf1HashFunc.Parameters.FullBytes) != 0 && !bytes.Equal(mgf1HashFunc.Parameters.FullBytes, asn1.NullBytes)) || + params.TrailerField != 1 { + return stdx509.UnknownSignatureAlgorithm + } + + switch { + case params.Hash.Algorithm.Equal(oidSHA256) && params.SaltLength == 32: + return stdx509.SHA256WithRSAPSS + case params.Hash.Algorithm.Equal(oidSHA384) && params.SaltLength == 48: + return stdx509.SHA384WithRSAPSS + case params.Hash.Algorithm.Equal(oidSHA512) && params.SaltLength == 64: + return stdx509.SHA512WithRSAPSS + } + + return stdx509.UnknownSignatureAlgorithm +} + +var ( + // RFC 3279, 2.3 Public Key Algorithms + // + // pkcs-1 OBJECT IDENTIFIER ::== { iso(1) member-body(2) us(840) + // rsadsi(113549) pkcs(1) 1 } + // + // rsaEncryption OBJECT IDENTIFIER ::== { pkcs1-1 1 } + // + // id-dsa OBJECT IDENTIFIER ::== { iso(1) member-body(2) us(840) + // x9-57(10040) x9cm(4) 1 } + oidPublicKeyRSA = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 1} + oidPublicKeyDSA = asn1.ObjectIdentifier{1, 2, 840, 10040, 4, 1} + // RFC 5480, 2.1.1 Unrestricted Algorithm Identifier and Parameters + // + // id-ecPublicKey OBJECT IDENTIFIER ::= { + // iso(1) member-body(2) us(840) ansi-X9-62(10045) keyType(2) 1 } + oidPublicKeyECDSA = asn1.ObjectIdentifier{1, 2, 840, 10045, 2, 1} + // RFC 8410, Section 3 + // + // id-X25519 OBJECT IDENTIFIER ::= { 1 3 101 110 } + // id-Ed25519 OBJECT IDENTIFIER ::= { 1 3 101 112 } + oidPublicKeyX25519 = asn1.ObjectIdentifier{1, 3, 101, 110} + oidPublicKeyEd25519 = asn1.ObjectIdentifier{1, 3, 101, 112} +) + +// getPublicKeyAlgorithmFromOID returns the exposed PublicKeyAlgorithm +// identifier for public key types supported in certificates and CSRs. Marshal +// and Parse functions may support a different set of public key types. +func getPublicKeyAlgorithmFromOID(oid asn1.ObjectIdentifier) stdx509.PublicKeyAlgorithm { + switch { + case oid.Equal(oidPublicKeyRSA): + return stdx509.RSA + case oid.Equal(oidPublicKeyDSA): + return stdx509.DSA + case oid.Equal(oidPublicKeyECDSA): + return stdx509.ECDSA + case oid.Equal(oidPublicKeyEd25519): + return stdx509.Ed25519 + } + return stdx509.UnknownPublicKeyAlgorithm +} + +// RFC 5480, 2.1.1.1. Named Curve +// +// secp224r1 OBJECT IDENTIFIER ::= { +// iso(1) identified-organization(3) certicom(132) curve(0) 33 } +// +// secp256r1 OBJECT IDENTIFIER ::= { +// iso(1) member-body(2) us(840) ansi-X9-62(10045) curves(3) +// prime(1) 7 } +// +// secp384r1 OBJECT IDENTIFIER ::= { +// iso(1) identified-organization(3) certicom(132) curve(0) 34 } +// +// secp521r1 OBJECT IDENTIFIER ::= { +// iso(1) identified-organization(3) certicom(132) curve(0) 35 } +// +// NB: secp256r1 is equivalent to prime256v1 +var ( + oidNamedCurveP224 = asn1.ObjectIdentifier{1, 3, 132, 0, 33} + oidNamedCurveP256 = asn1.ObjectIdentifier{1, 2, 840, 10045, 3, 1, 7} + oidNamedCurveP384 = asn1.ObjectIdentifier{1, 3, 132, 0, 34} + oidNamedCurveP521 = asn1.ObjectIdentifier{1, 3, 132, 0, 35} +) + +func namedCurveFromOID(oid asn1.ObjectIdentifier) elliptic.Curve { + switch { + case oid.Equal(oidNamedCurveP224): + return elliptic.P224() + case oid.Equal(oidNamedCurveP256): + return elliptic.P256() + case oid.Equal(oidNamedCurveP384): + return elliptic.P384() + case oid.Equal(oidNamedCurveP521): + return elliptic.P521() + } + return nil +} + +// KeyUsage represents the set of actions that are valid for a given key. It's +// a bitmap of the KeyUsage* constants. +type KeyUsage int + +const ( + KeyUsageDigitalSignature KeyUsage = 1 << iota + KeyUsageContentCommitment + KeyUsageKeyEncipherment + KeyUsageDataEncipherment + KeyUsageKeyAgreement + KeyUsageCertSign + KeyUsageCRLSign + KeyUsageEncipherOnly + KeyUsageDecipherOnly +) + +// RFC 5280, 4.2.1.12 Extended Key Usage +// +// anyExtendedKeyUsage OBJECT IDENTIFIER ::= { id-ce-extKeyUsage 0 } +// +// id-kp OBJECT IDENTIFIER ::= { id-pkix 3 } +// +// id-kp-serverAuth OBJECT IDENTIFIER ::= { id-kp 1 } +// id-kp-clientAuth OBJECT IDENTIFIER ::= { id-kp 2 } +// id-kp-codeSigning OBJECT IDENTIFIER ::= { id-kp 3 } +// id-kp-emailProtection OBJECT IDENTIFIER ::= { id-kp 4 } +// id-kp-timeStamping OBJECT IDENTIFIER ::= { id-kp 8 } +// id-kp-OCSPSigning OBJECT IDENTIFIER ::= { id-kp 9 } +var ( + oidExtKeyUsageAny = asn1.ObjectIdentifier{2, 5, 29, 37, 0} + oidExtKeyUsageServerAuth = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 1} + oidExtKeyUsageClientAuth = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 2} + oidExtKeyUsageCodeSigning = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 3} + oidExtKeyUsageEmailProtection = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 4} + oidExtKeyUsageIPSECEndSystem = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 5} + oidExtKeyUsageIPSECTunnel = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 6} + oidExtKeyUsageIPSECUser = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 7} + oidExtKeyUsageTimeStamping = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 8} + oidExtKeyUsageOCSPSigning = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 9} + oidExtKeyUsageMicrosoftServerGatedCrypto = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 311, 10, 3, 3} + oidExtKeyUsageNetscapeServerGatedCrypto = asn1.ObjectIdentifier{2, 16, 840, 1, 113730, 4, 1} + oidExtKeyUsageMicrosoftCommercialCodeSigning = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 311, 2, 1, 22} + oidExtKeyUsageMicrosoftKernelCodeSigning = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 311, 61, 1, 1} +) + +// ExtKeyUsage represents an extended set of actions that are valid for a given key. +// Each of the ExtKeyUsage* constants define a unique action. +type ExtKeyUsage int + +const ( + ExtKeyUsageAny ExtKeyUsage = iota + ExtKeyUsageServerAuth + ExtKeyUsageClientAuth + ExtKeyUsageCodeSigning + ExtKeyUsageEmailProtection + ExtKeyUsageIPSECEndSystem + ExtKeyUsageIPSECTunnel + ExtKeyUsageIPSECUser + ExtKeyUsageTimeStamping + ExtKeyUsageOCSPSigning + ExtKeyUsageMicrosoftServerGatedCrypto + ExtKeyUsageNetscapeServerGatedCrypto + ExtKeyUsageMicrosoftCommercialCodeSigning + ExtKeyUsageMicrosoftKernelCodeSigning +) + +// extKeyUsageOIDs contains the mapping between an ExtKeyUsage and its OID. +var extKeyUsageOIDs = []struct { + extKeyUsage ExtKeyUsage + oid asn1.ObjectIdentifier +}{ + {ExtKeyUsageAny, oidExtKeyUsageAny}, + {ExtKeyUsageServerAuth, oidExtKeyUsageServerAuth}, + {ExtKeyUsageClientAuth, oidExtKeyUsageClientAuth}, + {ExtKeyUsageCodeSigning, oidExtKeyUsageCodeSigning}, + {ExtKeyUsageEmailProtection, oidExtKeyUsageEmailProtection}, + {ExtKeyUsageIPSECEndSystem, oidExtKeyUsageIPSECEndSystem}, + {ExtKeyUsageIPSECTunnel, oidExtKeyUsageIPSECTunnel}, + {ExtKeyUsageIPSECUser, oidExtKeyUsageIPSECUser}, + {ExtKeyUsageTimeStamping, oidExtKeyUsageTimeStamping}, + {ExtKeyUsageOCSPSigning, oidExtKeyUsageOCSPSigning}, + {ExtKeyUsageMicrosoftServerGatedCrypto, oidExtKeyUsageMicrosoftServerGatedCrypto}, + {ExtKeyUsageNetscapeServerGatedCrypto, oidExtKeyUsageNetscapeServerGatedCrypto}, + {ExtKeyUsageMicrosoftCommercialCodeSigning, oidExtKeyUsageMicrosoftCommercialCodeSigning}, + {ExtKeyUsageMicrosoftKernelCodeSigning, oidExtKeyUsageMicrosoftKernelCodeSigning}, +} + +func extKeyUsageFromOID(oid asn1.ObjectIdentifier) (eku ExtKeyUsage, ok bool) { + for _, pair := range extKeyUsageOIDs { + if oid.Equal(pair.oid) { + return pair.extKeyUsage, true + } + } + return +} + +const ( + nameTypeEmail = 1 + nameTypeDNS = 2 + nameTypeURI = 6 + nameTypeIP = 7 +) + +var ( + oidExtensionAuthorityInfoAccess = []int{1, 3, 6, 1, 5, 5, 7, 1, 1} +) + +var ( + oidAuthorityInfoAccessOcsp = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 48, 1} + oidAuthorityInfoAccessIssuers = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 48, 2} +) + +func isIA5String(s string) error { + for _, r := range s { + // Per RFC5280 "IA5String is limited to the set of ASCII characters" + if r > unicode.MaxASCII { + return fmt.Errorf("x509: %q cannot be encoded as an IA5String", s) + } + } + + return nil +} diff --git a/vendor/github.com/smallstep/pkcs7/pkcs7.go b/vendor/github.com/smallstep/pkcs7/pkcs7.go new file mode 100644 index 0000000000000..dd5b18380ad4f --- /dev/null +++ b/vendor/github.com/smallstep/pkcs7/pkcs7.go @@ -0,0 +1,353 @@ +// Package pkcs7 implements parsing and generation of some PKCS#7 structures. +package pkcs7 + +import ( + "bytes" + "crypto" + "crypto/dsa" + "crypto/ecdsa" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/asn1" + "errors" + "fmt" + "io" + "sort" + "sync" + + _ "crypto/sha1" // for crypto.SHA1 + + legacyx509 "github.com/smallstep/pkcs7/internal/legacy/x509" +) + +// PKCS7 Represents a PKCS7 structure +type PKCS7 struct { + Content []byte + Certificates []*x509.Certificate + CRLs []pkix.CertificateList + Signers []signerInfo + Hasher Hasher + raw interface{} +} + +// Hasher is an interface defining a custom hash calculator. +type Hasher interface { + Hash(crypto.Hash, io.Reader) ([]byte, error) +} + +type contentInfo struct { + ContentType asn1.ObjectIdentifier + Content asn1.RawValue `asn1:"explicit,optional,tag:0"` +} + +// ErrUnsupportedContentType is returned when a PKCS7 content type is not supported. +// Currently only Data (1.2.840.113549.1.7.1), Signed Data (1.2.840.113549.1.7.2), +// and Enveloped Data are supported (1.2.840.113549.1.7.3) +var ErrUnsupportedContentType = errors.New("pkcs7: cannot parse data: unimplemented content type") + +type unsignedData []byte + +var ( + // Signed Data OIDs + OIDData = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 7, 1} + OIDSignedData = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 7, 2} + OIDEnvelopedData = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 7, 3} + OIDEncryptedData = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 7, 6} + OIDAttributeContentType = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 9, 3} + OIDAttributeMessageDigest = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 9, 4} + OIDAttributeSigningTime = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 9, 5} + + // Digest Algorithms + OIDDigestAlgorithmSHA1 = asn1.ObjectIdentifier{1, 3, 14, 3, 2, 26} + OIDDigestAlgorithmSHA256 = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 2, 1} + OIDDigestAlgorithmSHA384 = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 2, 2} + OIDDigestAlgorithmSHA512 = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 2, 3} + OIDDigestAlgorithmSHA224 = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 2, 4} + + OIDDigestAlgorithmDSA = asn1.ObjectIdentifier{1, 2, 840, 10040, 4, 1} + OIDDigestAlgorithmDSASHA1 = asn1.ObjectIdentifier{1, 2, 840, 10040, 4, 3} + + OIDDigestAlgorithmECDSASHA1 = asn1.ObjectIdentifier{1, 2, 840, 10045, 4, 1} + OIDDigestAlgorithmECDSASHA256 = asn1.ObjectIdentifier{1, 2, 840, 10045, 4, 3, 2} + OIDDigestAlgorithmECDSASHA384 = asn1.ObjectIdentifier{1, 2, 840, 10045, 4, 3, 3} + OIDDigestAlgorithmECDSASHA512 = asn1.ObjectIdentifier{1, 2, 840, 10045, 4, 3, 4} + + // Signature Algorithms + OIDEncryptionAlgorithmRSAMD5 = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 4} // see https://www.rfc-editor.org/rfc/rfc8017#appendix-A.2.4 + OIDEncryptionAlgorithmRSASHA1 = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 5} // ditto + OIDEncryptionAlgorithmRSASHA256 = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 11} // ditto + OIDEncryptionAlgorithmRSASHA384 = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 12} // ditto + OIDEncryptionAlgorithmRSASHA512 = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 13} // ditto + OIDEncryptionAlgorithmRSASHA224 = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 14} // ditto + + OIDEncryptionAlgorithmECDSAP256 = asn1.ObjectIdentifier{1, 2, 840, 10045, 3, 1, 7} + OIDEncryptionAlgorithmECDSAP384 = asn1.ObjectIdentifier{1, 3, 132, 0, 34} + OIDEncryptionAlgorithmECDSAP521 = asn1.ObjectIdentifier{1, 3, 132, 0, 35} + + // Asymmetric Encryption Algorithms + OIDEncryptionAlgorithmRSA = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 1} // see https://www.rfc-editor.org/rfc/rfc8017#appendix-A.2.2 + OIDEncryptionAlgorithmRSAESOAEP = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 7} // see https://www.rfc-editor.org/rfc/rfc8017#appendix-A.2.1 + + // Symmetric Encryption Algorithms + OIDEncryptionAlgorithmDESCBC = asn1.ObjectIdentifier{1, 3, 14, 3, 2, 7} // see https://www.rfc-editor.org/rfc/rfc8018.html#appendix-B.2.1 + OIDEncryptionAlgorithmDESEDE3CBC = asn1.ObjectIdentifier{1, 2, 840, 113549, 3, 7} // see https://www.rfc-editor.org/rfc/rfc8018.html#appendix-B.2.2 + OIDEncryptionAlgorithmAES256CBC = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 1, 42} // see https://www.rfc-editor.org/rfc/rfc3565.html#section-4.1 + OIDEncryptionAlgorithmAES128GCM = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 1, 6} // see https://www.rfc-editor.org/rfc/rfc5084.html#section-3.2 + OIDEncryptionAlgorithmAES128CBC = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 1, 2} // see https://www.rfc-editor.org/rfc/rfc8018.html#appendix-B.2.5 + OIDEncryptionAlgorithmAES256GCM = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 1, 46} // see https://www.rfc-editor.org/rfc/rfc5084.html#section-3.2 +) + +func getHashForOID(oid asn1.ObjectIdentifier) (crypto.Hash, error) { + switch { + case oid.Equal(OIDDigestAlgorithmSHA1), oid.Equal(OIDDigestAlgorithmECDSASHA1), + oid.Equal(OIDDigestAlgorithmDSA), oid.Equal(OIDDigestAlgorithmDSASHA1), + oid.Equal(OIDEncryptionAlgorithmRSA): + return crypto.SHA1, nil + case oid.Equal(OIDDigestAlgorithmSHA256), oid.Equal(OIDDigestAlgorithmECDSASHA256): + return crypto.SHA256, nil + case oid.Equal(OIDDigestAlgorithmSHA384), oid.Equal(OIDDigestAlgorithmECDSASHA384): + return crypto.SHA384, nil + case oid.Equal(OIDDigestAlgorithmSHA512), oid.Equal(OIDDigestAlgorithmECDSASHA512): + return crypto.SHA512, nil + } + return crypto.Hash(0), ErrUnsupportedAlgorithm +} + +// getDigestOIDForSignatureAlgorithm takes an x509.SignatureAlgorithm +// and returns the corresponding OID digest algorithm +func getDigestOIDForSignatureAlgorithm(digestAlg x509.SignatureAlgorithm) (asn1.ObjectIdentifier, error) { + switch digestAlg { + case x509.SHA1WithRSA, x509.ECDSAWithSHA1: + return OIDDigestAlgorithmSHA1, nil + case x509.SHA256WithRSA, x509.ECDSAWithSHA256: + return OIDDigestAlgorithmSHA256, nil + case x509.SHA384WithRSA, x509.ECDSAWithSHA384: + return OIDDigestAlgorithmSHA384, nil + case x509.SHA512WithRSA, x509.ECDSAWithSHA512: + return OIDDigestAlgorithmSHA512, nil + } + return nil, fmt.Errorf("pkcs7: cannot convert hash to oid, unknown hash algorithm") +} + +// getOIDForEncryptionAlgorithm takes the public or private key type of the signer and +// the OID of a digest algorithm to return the appropriate signerInfo.DigestEncryptionAlgorithm +func getOIDForEncryptionAlgorithm(pkey interface{}, OIDDigestAlg asn1.ObjectIdentifier) (asn1.ObjectIdentifier, error) { + switch k := pkey.(type) { + case *rsa.PrivateKey, *rsa.PublicKey: + switch { + default: + return OIDEncryptionAlgorithmRSA, nil + case OIDDigestAlg.Equal(OIDEncryptionAlgorithmRSA): + return OIDEncryptionAlgorithmRSA, nil + case OIDDigestAlg.Equal(OIDDigestAlgorithmSHA1): + return OIDEncryptionAlgorithmRSASHA1, nil + case OIDDigestAlg.Equal(OIDDigestAlgorithmSHA256): + return OIDEncryptionAlgorithmRSASHA256, nil + case OIDDigestAlg.Equal(OIDDigestAlgorithmSHA384): + return OIDEncryptionAlgorithmRSASHA384, nil + case OIDDigestAlg.Equal(OIDDigestAlgorithmSHA512): + return OIDEncryptionAlgorithmRSASHA512, nil + } + case *ecdsa.PrivateKey, *ecdsa.PublicKey: + switch { + case OIDDigestAlg.Equal(OIDDigestAlgorithmSHA1): + return OIDDigestAlgorithmECDSASHA1, nil + case OIDDigestAlg.Equal(OIDDigestAlgorithmSHA256): + return OIDDigestAlgorithmECDSASHA256, nil + case OIDDigestAlg.Equal(OIDDigestAlgorithmSHA384): + return OIDDigestAlgorithmECDSASHA384, nil + case OIDDigestAlg.Equal(OIDDigestAlgorithmSHA512): + return OIDDigestAlgorithmECDSASHA512, nil + } + case *dsa.PrivateKey, *dsa.PublicKey: + return OIDDigestAlgorithmDSA, nil + case crypto.Signer: + // This generic case is here to cover types from other packages. It + // was specifically added to handle the private keyRSA type in the + // github.com/go-piv/piv-go/piv package. + return getOIDForEncryptionAlgorithm(k.Public(), OIDDigestAlg) + } + return nil, fmt.Errorf("pkcs7: cannot convert encryption algorithm to oid, unknown private key type %T", pkey) + +} + +// Parse decodes a DER encoded PKCS7 package +func Parse(data []byte) (p7 *PKCS7, err error) { + if len(data) == 0 { + return nil, errors.New("pkcs7: input data is empty") + } + var info contentInfo + der, err := ber2der(data) + if err != nil { + return nil, err + } + rest, err := asn1.Unmarshal(der, &info) + if len(rest) > 0 { + err = asn1.SyntaxError{Msg: "trailing data"} + return + } + if err != nil { + return + } + + // fmt.Printf("--> Content Type: %s", info.ContentType) + switch { + case info.ContentType.Equal(OIDSignedData): + return parseSignedData(info.Content.Bytes) + case info.ContentType.Equal(OIDEnvelopedData): + return parseEnvelopedData(info.Content.Bytes) + case info.ContentType.Equal(OIDEncryptedData): + return parseEncryptedData(info.Content.Bytes) + } + return nil, ErrUnsupportedContentType +} + +func parseEnvelopedData(data []byte) (*PKCS7, error) { + var ed envelopedData + if _, err := asn1.Unmarshal(data, &ed); err != nil { + return nil, err + } + return &PKCS7{ + raw: ed, + }, nil +} + +func parseEncryptedData(data []byte) (*PKCS7, error) { + var ed encryptedData + if _, err := asn1.Unmarshal(data, &ed); err != nil { + return nil, err + } + return &PKCS7{ + raw: ed, + }, nil +} + +// SetFallbackLegacyX509CertificateParserEnabled enables parsing certificates +// embedded in a PKCS7 message using the logic from crypto/x509 from before +// Go 1.23. Go 1.23 introduced a breaking change in case a certificate contains +// a critical authority key identifier, which is the correct thing to do based +// on RFC 5280, but it breaks Windows devices performing the Simple Certificate +// Enrolment Protocol (SCEP), as the certificates embedded in those requests +// apparently have authority key identifier extensions marked critical. +// +// See https://go-review.googlesource.com/c/go/+/562341 for the change in the +// Go source. +// +// When [SetFallbackLegacyX509CertificateParserEnabled] is called with true, it +// enables parsing using the legacy crypto/x509 certificate parser. It'll first +// try to parse the certificates using the regular Go crypto/x509 package, but +// if it fails on the above case, it'll retry parsing the certificates using a +// copy of the crypto/x509 package based on Go 1.23, but skips checking the +// authority key identifier extension being critical or not. +func SetFallbackLegacyX509CertificateParserEnabled(v bool) { + legacyX509CertificateParser.Lock() + legacyX509CertificateParser.enabled = v + legacyX509CertificateParser.Unlock() +} + +var legacyX509CertificateParser struct { + sync.RWMutex + enabled bool +} + +func isLegacyX509ParserEnabled() bool { + legacyX509CertificateParser.RLock() + defer legacyX509CertificateParser.RUnlock() + return legacyX509CertificateParser.enabled +} + +func (raw rawCertificates) Parse() ([]*x509.Certificate, error) { + if len(raw.Raw) == 0 { + return nil, nil + } + + var val asn1.RawValue + if _, err := asn1.Unmarshal(raw.Raw, &val); err != nil { + return nil, err + } + + certificates, err := x509.ParseCertificates(val.Bytes) + if err != nil && err.Error() == "x509: authority key identifier incorrectly marked critical" { + if isLegacyX509ParserEnabled() { + certificates, err = legacyx509.ParseCertificates(val.Bytes) + } + } + + return certificates, err +} + +func isCertMatchForIssuerAndSerial(cert *x509.Certificate, ias issuerAndSerial) bool { + return cert.SerialNumber.Cmp(ias.SerialNumber) == 0 && bytes.Equal(cert.RawIssuer, ias.IssuerName.FullBytes) +} + +// Attribute represents a key value pair attribute. Value must be marshalable byte +// `encoding/asn1` +type Attribute struct { + Type asn1.ObjectIdentifier + Value interface{} +} + +type attributes struct { + types []asn1.ObjectIdentifier + values []interface{} +} + +// Add adds the attribute, maintaining insertion order +func (attrs *attributes) Add(attrType asn1.ObjectIdentifier, value interface{}) { + attrs.types = append(attrs.types, attrType) + attrs.values = append(attrs.values, value) +} + +type sortableAttribute struct { + SortKey []byte + Attribute attribute +} + +type attributeSet []sortableAttribute + +func (sa attributeSet) Len() int { + return len(sa) +} + +func (sa attributeSet) Less(i, j int) bool { + return bytes.Compare(sa[i].SortKey, sa[j].SortKey) < 0 +} + +func (sa attributeSet) Swap(i, j int) { + sa[i], sa[j] = sa[j], sa[i] +} + +func (sa attributeSet) Attributes() []attribute { + attrs := make([]attribute, len(sa)) + for i, attr := range sa { + attrs[i] = attr.Attribute + } + return attrs +} + +func (attrs *attributes) ForMarshalling() ([]attribute, error) { + sortables := make(attributeSet, len(attrs.types)) + for i := range sortables { + attrType := attrs.types[i] + attrValue := attrs.values[i] + asn1Value, err := asn1.Marshal(attrValue) + if err != nil { + return nil, err + } + attr := attribute{ + Type: attrType, + Value: asn1.RawValue{Tag: 17, IsCompound: true, Bytes: asn1Value}, // 17 == SET tag + } + encoded, err := asn1.Marshal(attr) + if err != nil { + return nil, err + } + sortables[i] = sortableAttribute{ + SortKey: encoded, + Attribute: attr, + } + } + sort.Sort(sortables) + return sortables.Attributes(), nil +} diff --git a/vendor/github.com/smallstep/pkcs7/sign.go b/vendor/github.com/smallstep/pkcs7/sign.go new file mode 100644 index 0000000000000..74ce50d802004 --- /dev/null +++ b/vendor/github.com/smallstep/pkcs7/sign.go @@ -0,0 +1,474 @@ +package pkcs7 + +import ( + "bytes" + "crypto" + "crypto/dsa" + "crypto/rand" + "crypto/x509" + "crypto/x509/pkix" + "encoding/asn1" + "errors" + "fmt" + "math/big" + "sync" + "time" +) + +func init() { + defaultMessageDigestAlgorithm.oid = OIDDigestAlgorithmSHA1 +} + +var defaultMessageDigestAlgorithm struct { + sync.RWMutex + oid asn1.ObjectIdentifier +} + +// SetDefaultDigestAlgorithm sets the default digest algorithm +// to be used for signing operations on [SignedData]. +// +// This must be called before creating a new instance of [SignedData] +// using [NewSignedData]. +// +// When this function is not called, the default digest algorithm is SHA1. +func SetDefaultDigestAlgorithm(d asn1.ObjectIdentifier) error { + defaultMessageDigestAlgorithm.Lock() + defer defaultMessageDigestAlgorithm.Unlock() + + switch { + case d.Equal(OIDDigestAlgorithmSHA1), + d.Equal(OIDDigestAlgorithmSHA224), d.Equal(OIDDigestAlgorithmSHA256), + d.Equal(OIDDigestAlgorithmSHA384), d.Equal(OIDDigestAlgorithmSHA512), + d.Equal(OIDDigestAlgorithmDSA), d.Equal(OIDDigestAlgorithmDSASHA1), + d.Equal(OIDDigestAlgorithmECDSASHA1), d.Equal(OIDDigestAlgorithmECDSASHA256), + d.Equal(OIDDigestAlgorithmECDSASHA384), d.Equal(OIDDigestAlgorithmECDSASHA512): + break + default: + return fmt.Errorf("unsupported message digest algorithm %v", d) + } + + defaultMessageDigestAlgorithm.oid = d + + return nil +} + +func defaultMessageDigestAlgorithmOID() asn1.ObjectIdentifier { + defaultMessageDigestAlgorithm.RLock() + defer defaultMessageDigestAlgorithm.RUnlock() + + return defaultMessageDigestAlgorithm.oid +} + +// SignedData is an opaque data structure for creating signed data payloads +type SignedData struct { + sd signedData + certs []*x509.Certificate + data, messageDigest []byte + digestOid asn1.ObjectIdentifier + encryptionOid asn1.ObjectIdentifier +} + +// NewSignedData takes data and initializes a PKCS7 SignedData struct that is +// ready to be signed via AddSigner. The digest algorithm is set to SHA1 by default +// and can be changed by calling SetDigestAlgorithm. +func NewSignedData(data []byte) (*SignedData, error) { + content, err := asn1.Marshal(data) + if err != nil { + return nil, err + } + ci := contentInfo{ + ContentType: OIDData, + Content: asn1.RawValue{Class: 2, Tag: 0, Bytes: content, IsCompound: true}, + } + sd := signedData{ + ContentInfo: ci, + Version: 1, + } + return &SignedData{sd: sd, data: data, digestOid: defaultMessageDigestAlgorithmOID()}, nil +} + +// SignerInfoConfig are optional values to include when adding a signer +type SignerInfoConfig struct { + ExtraSignedAttributes []Attribute + ExtraUnsignedAttributes []Attribute +} + +type signedData struct { + Version int `asn1:"default:1"` + DigestAlgorithmIdentifiers []pkix.AlgorithmIdentifier `asn1:"set"` + ContentInfo contentInfo + Certificates rawCertificates `asn1:"optional,tag:0"` + CRLs []pkix.CertificateList `asn1:"optional,tag:1"` + SignerInfos []signerInfo `asn1:"set"` +} + +type signerInfo struct { + Version int `asn1:"default:1"` + IssuerAndSerialNumber issuerAndSerial + DigestAlgorithm pkix.AlgorithmIdentifier + AuthenticatedAttributes []attribute `asn1:"optional,omitempty,tag:0"` + DigestEncryptionAlgorithm pkix.AlgorithmIdentifier + EncryptedDigest []byte + UnauthenticatedAttributes []attribute `asn1:"optional,omitempty,tag:1"` +} + +type attribute struct { + Type asn1.ObjectIdentifier + Value asn1.RawValue `asn1:"set"` +} + +func marshalAttributes(attrs []attribute) ([]byte, error) { + encodedAttributes, err := asn1.Marshal(struct { + A []attribute `asn1:"set"` + }{A: attrs}) + if err != nil { + return nil, err + } + + // Remove the leading sequence octets + var raw asn1.RawValue + asn1.Unmarshal(encodedAttributes, &raw) + return raw.Bytes, nil +} + +type rawCertificates struct { + Raw asn1.RawContent +} + +type issuerAndSerial struct { + IssuerName asn1.RawValue + SerialNumber *big.Int +} + +// SetDigestAlgorithm sets the digest algorithm to be used in the signing process. +// +// This should be called before adding signers +func (sd *SignedData) SetDigestAlgorithm(d asn1.ObjectIdentifier) { + sd.digestOid = d +} + +// SetEncryptionAlgorithm sets the encryption algorithm to be used in the signing process. +// +// This should be called before adding signers +func (sd *SignedData) SetEncryptionAlgorithm(d asn1.ObjectIdentifier) { + sd.encryptionOid = d +} + +// AddSigner is a wrapper around AddSignerChain() that adds a signer without any parent. +func (sd *SignedData) AddSigner(ee *x509.Certificate, pkey crypto.PrivateKey, config SignerInfoConfig) error { + var parents []*x509.Certificate + return sd.AddSignerChain(ee, pkey, parents, config) +} + +// AddSignerChain signs attributes about the content and adds certificates +// and signers infos to the Signed Data. The certificate and private key +// of the end-entity signer are used to issue the signature, and any +// parent of that end-entity that need to be added to the list of +// certifications can be specified in the parents slice. +// +// The signature algorithm used to hash the data is the one of the end-entity +// certificate. +func (sd *SignedData) AddSignerChain(ee *x509.Certificate, pkey crypto.PrivateKey, parents []*x509.Certificate, config SignerInfoConfig) error { + // Following RFC 2315, 9.2 SignerInfo type, the distinguished name of + // the issuer of the end-entity signer is stored in the issuerAndSerialNumber + // section of the SignedData.SignerInfo, alongside the serial number of + // the end-entity. + var ias issuerAndSerial + ias.SerialNumber = ee.SerialNumber + if len(parents) == 0 { + // no parent, the issuer is the end-entity cert itself + ias.IssuerName = asn1.RawValue{FullBytes: ee.RawIssuer} + } else { + err := verifyPartialChain(ee, parents) + if err != nil { + return err + } + // the first parent is the issuer + ias.IssuerName = asn1.RawValue{FullBytes: parents[0].RawSubject} + } + sd.sd.DigestAlgorithmIdentifiers = append(sd.sd.DigestAlgorithmIdentifiers, + pkix.AlgorithmIdentifier{Algorithm: sd.digestOid}, + ) + hash, err := getHashForOID(sd.digestOid) + if err != nil { + return err + } + h := hash.New() + h.Write(sd.data) + sd.messageDigest = h.Sum(nil) + encryptionOid, err := getOIDForEncryptionAlgorithm(pkey, sd.digestOid) + if err != nil { + return err + } + attrs := &attributes{} + attrs.Add(OIDAttributeContentType, sd.sd.ContentInfo.ContentType) + attrs.Add(OIDAttributeMessageDigest, sd.messageDigest) + attrs.Add(OIDAttributeSigningTime, time.Now().UTC()) + for _, attr := range config.ExtraSignedAttributes { + attrs.Add(attr.Type, attr.Value) + } + finalAttrs, err := attrs.ForMarshalling() + if err != nil { + return err + } + unsignedAttrs := &attributes{} + for _, attr := range config.ExtraUnsignedAttributes { + unsignedAttrs.Add(attr.Type, attr.Value) + } + finalUnsignedAttrs, err := unsignedAttrs.ForMarshalling() + if err != nil { + return err + } + // create signature of signed attributes + signature, err := signAttributes(finalAttrs, pkey, hash) + if err != nil { + return err + } + signer := signerInfo{ + AuthenticatedAttributes: finalAttrs, + UnauthenticatedAttributes: finalUnsignedAttrs, + DigestAlgorithm: pkix.AlgorithmIdentifier{Algorithm: sd.digestOid}, + DigestEncryptionAlgorithm: pkix.AlgorithmIdentifier{Algorithm: encryptionOid}, + IssuerAndSerialNumber: ias, + EncryptedDigest: signature, + Version: 1, + } + sd.certs = append(sd.certs, ee) + if len(parents) > 0 { + sd.certs = append(sd.certs, parents...) + } + sd.sd.SignerInfos = append(sd.sd.SignerInfos, signer) + return nil +} + +// SignWithoutAttr issues a signature on the content of the pkcs7 SignedData. +// Unlike AddSigner/AddSignerChain, it calculates the digest on the data alone +// and does not include any signed attributes like timestamp and so on. +// +// This function is needed to sign old Android APKs, something you probably +// shouldn't do unless you're maintaining backward compatibility for old +// applications. +func (sd *SignedData) SignWithoutAttr(ee *x509.Certificate, pkey crypto.PrivateKey, config SignerInfoConfig) error { + var signature []byte + sd.sd.DigestAlgorithmIdentifiers = append(sd.sd.DigestAlgorithmIdentifiers, pkix.AlgorithmIdentifier{Algorithm: sd.digestOid}) + hash, err := getHashForOID(sd.digestOid) + if err != nil { + return err + } + h := hash.New() + h.Write(sd.data) + sd.messageDigest = h.Sum(nil) + switch pkey := pkey.(type) { + case *dsa.PrivateKey: + // dsa doesn't implement crypto.Signer so we make a special case + // https://github.com/golang/go/issues/27889 + r, s, err := dsa.Sign(rand.Reader, pkey, sd.messageDigest) + if err != nil { + return err + } + signature, err = asn1.Marshal(dsaSignature{r, s}) + if err != nil { + return err + } + default: + key, ok := pkey.(crypto.Signer) + if !ok { + return errors.New("pkcs7: private key does not implement crypto.Signer") + } + signature, err = key.Sign(rand.Reader, sd.messageDigest, hash) + if err != nil { + return err + } + } + var ias issuerAndSerial + ias.SerialNumber = ee.SerialNumber + // no parent, the issue is the end-entity cert itself + ias.IssuerName = asn1.RawValue{FullBytes: ee.RawIssuer} + if sd.encryptionOid == nil { + // if the encryption algorithm wasn't set by SetEncryptionAlgorithm, + // infer it from the digest algorithm + sd.encryptionOid, err = getOIDForEncryptionAlgorithm(pkey, sd.digestOid) + } + if err != nil { + return err + } + signer := signerInfo{ + DigestAlgorithm: pkix.AlgorithmIdentifier{Algorithm: sd.digestOid}, + DigestEncryptionAlgorithm: pkix.AlgorithmIdentifier{Algorithm: sd.encryptionOid}, + IssuerAndSerialNumber: ias, + EncryptedDigest: signature, + Version: 1, + } + // create signature of signed attributes + sd.certs = append(sd.certs, ee) + sd.sd.SignerInfos = append(sd.sd.SignerInfos, signer) + return nil +} + +func (si *signerInfo) SetUnauthenticatedAttributes(extraUnsignedAttrs []Attribute) error { + unsignedAttrs := &attributes{} + for _, attr := range extraUnsignedAttrs { + unsignedAttrs.Add(attr.Type, attr.Value) + } + finalUnsignedAttrs, err := unsignedAttrs.ForMarshalling() + if err != nil { + return err + } + + si.UnauthenticatedAttributes = finalUnsignedAttrs + + return nil +} + +// AddCertificate adds the certificate to the payload. Useful for parent certificates +func (sd *SignedData) AddCertificate(cert *x509.Certificate) { + sd.certs = append(sd.certs, cert) +} + +// Detach removes content from the signed data struct to make it a detached signature. +// This must be called right before Finish() +func (sd *SignedData) Detach() { + sd.sd.ContentInfo = contentInfo{ContentType: OIDData} +} + +// GetSignedData returns the private Signed Data +func (sd *SignedData) GetSignedData() *signedData { + return &sd.sd +} + +// Finish marshals the content and its signers +func (sd *SignedData) Finish() ([]byte, error) { + sd.sd.Certificates = marshalCertificates(sd.certs) + inner, err := asn1.Marshal(sd.sd) + if err != nil { + return nil, err + } + outer := contentInfo{ + ContentType: OIDSignedData, + Content: asn1.RawValue{Class: 2, Tag: 0, Bytes: inner, IsCompound: true}, + } + return asn1.Marshal(outer) +} + +// RemoveAuthenticatedAttributes removes authenticated attributes from signedData +// similar to OpenSSL's PKCS7_NOATTR or -noattr flags +func (sd *SignedData) RemoveAuthenticatedAttributes() { + for i := range sd.sd.SignerInfos { + sd.sd.SignerInfos[i].AuthenticatedAttributes = nil + } +} + +// RemoveUnauthenticatedAttributes removes unauthenticated attributes from signedData +func (sd *SignedData) RemoveUnauthenticatedAttributes() { + for i := range sd.sd.SignerInfos { + sd.sd.SignerInfos[i].UnauthenticatedAttributes = nil + } +} + +// verifyPartialChain checks that a given cert is issued by the first parent in the list, +// then continue down the path. It doesn't require the last parent to be a root CA, +// or to be trusted in any truststore. It simply verifies that the chain provided, albeit +// partial, makes sense. +func verifyPartialChain(cert *x509.Certificate, parents []*x509.Certificate) error { + if len(parents) == 0 { + return fmt.Errorf("pkcs7: zero parents provided to verify the signature of certificate %q", cert.Subject.CommonName) + } + err := cert.CheckSignatureFrom(parents[0]) + if err != nil { + return fmt.Errorf("pkcs7: certificate signature from parent is invalid: %v", err) + } + if len(parents) == 1 { + // there is no more parent to check, return + return nil + } + return verifyPartialChain(parents[0], parents[1:]) +} + +func cert2issuerAndSerial(cert *x509.Certificate) (issuerAndSerial, error) { + var ias issuerAndSerial + // The issuer RDNSequence has to match exactly the sequence in the certificate + // We cannot use cert.Issuer.ToRDNSequence() here since it mangles the sequence + ias.IssuerName = asn1.RawValue{FullBytes: cert.RawIssuer} + ias.SerialNumber = cert.SerialNumber + + return ias, nil +} + +// signs the DER encoded form of the attributes with the private key +func signAttributes(attrs []attribute, pkey crypto.PrivateKey, digestAlg crypto.Hash) ([]byte, error) { + attrBytes, err := marshalAttributes(attrs) + if err != nil { + return nil, err + } + h := digestAlg.New() + h.Write(attrBytes) + hash := h.Sum(nil) + + // dsa doesn't implement crypto.Signer so we make a special case + // https://github.com/golang/go/issues/27889 + switch pkey := pkey.(type) { + case *dsa.PrivateKey: + r, s, err := dsa.Sign(rand.Reader, pkey, hash) + if err != nil { + return nil, err + } + return asn1.Marshal(dsaSignature{r, s}) + } + + key, ok := pkey.(crypto.Signer) + if !ok { + return nil, errors.New("pkcs7: private key does not implement crypto.Signer") + } + return key.Sign(rand.Reader, hash, digestAlg) +} + +type dsaSignature struct { + R, S *big.Int +} + +// concats and wraps the certificates in the RawValue structure +func marshalCertificates(certs []*x509.Certificate) rawCertificates { + var buf bytes.Buffer + for _, cert := range certs { + buf.Write(cert.Raw) + } + rawCerts, _ := marshalCertificateBytes(buf.Bytes()) + return rawCerts +} + +// Even though, the tag & length are stripped out during marshalling the +// RawContent, we have to encode it into the RawContent. If its missing, +// then `asn1.Marshal()` will strip out the certificate wrapper instead. +func marshalCertificateBytes(certs []byte) (rawCertificates, error) { + var val = asn1.RawValue{Bytes: certs, Class: 2, Tag: 0, IsCompound: true} + b, err := asn1.Marshal(val) + if err != nil { + return rawCertificates{}, err + } + return rawCertificates{Raw: b}, nil +} + +// DegenerateCertificate creates a signed data structure containing only the +// provided certificate or certificate chain. +func DegenerateCertificate(cert []byte) ([]byte, error) { + rawCert, err := marshalCertificateBytes(cert) + if err != nil { + return nil, err + } + emptyContent := contentInfo{ContentType: OIDData} + sd := signedData{ + Version: 1, + ContentInfo: emptyContent, + Certificates: rawCert, + CRLs: []pkix.CertificateList{}, + } + content, err := asn1.Marshal(sd) + if err != nil { + return nil, err + } + signedContent := contentInfo{ + ContentType: OIDSignedData, + Content: asn1.RawValue{Class: 2, Tag: 0, Bytes: content, IsCompound: true}, + } + return asn1.Marshal(signedContent) +} diff --git a/vendor/github.com/smallstep/pkcs7/verify.go b/vendor/github.com/smallstep/pkcs7/verify.go new file mode 100644 index 0000000000000..f9ad34bbab50f --- /dev/null +++ b/vendor/github.com/smallstep/pkcs7/verify.go @@ -0,0 +1,385 @@ +package pkcs7 + +import ( + "bytes" + "crypto" + "crypto/subtle" + "crypto/x509" + "crypto/x509/pkix" + "encoding/asn1" + "errors" + "fmt" + "time" +) + +// Verify is a wrapper around VerifyWithChain() that initializes an empty +// trust store, effectively disabling certificate verification when validating +// a signature. +func (p7 *PKCS7) Verify() (err error) { + return p7.VerifyWithChain(nil) +} + +// VerifyWithChain checks the signatures of a PKCS7 object. +// +// If truststore is not nil, it also verifies the chain of trust of +// the end-entity signer cert to one of the roots in the +// truststore. When the PKCS7 object includes the signing time +// authenticated attr verifies the chain at that time and UTC now +// otherwise. +func (p7 *PKCS7) VerifyWithChain(truststore *x509.CertPool) (err error) { + if len(p7.Signers) == 0 { + return errors.New("pkcs7: Message has no signers") + } + for _, signer := range p7.Signers { + if err := verifySignature(p7, signer, truststore); err != nil { + return err + } + } + return nil +} + +// VerifyWithChainAtTime checks the signatures of a PKCS7 object. +// +// If truststore is not nil, it also verifies the chain of trust of +// the end-entity signer cert to a root in the truststore at +// currentTime. It does not use the signing time authenticated +// attribute. +func (p7 *PKCS7) VerifyWithChainAtTime(truststore *x509.CertPool, currentTime time.Time) (err error) { + if len(p7.Signers) == 0 { + return errors.New("pkcs7: Message has no signers") + } + for _, signer := range p7.Signers { + if err := verifySignatureAtTime(p7, signer, truststore, currentTime); err != nil { + return err + } + } + return nil +} + +// SigningTimeNotValidError is returned when the signing time attribute +// falls outside of the signer certificate validity. +type SigningTimeNotValidError struct { + SigningTime time.Time + NotBefore time.Time // NotBefore of signer + NotAfter time.Time // NotAfter of signer +} + +func (e *SigningTimeNotValidError) Error() string { + return fmt.Sprintf("pkcs7: signing time %q is outside of certificate validity %q to %q", + e.SigningTime.Format(time.RFC3339), + e.NotBefore.Format(time.RFC3339), + e.NotAfter.Format(time.RFC3339)) +} + +func verifySignatureAtTime(p7 *PKCS7, signer signerInfo, truststore *x509.CertPool, currentTime time.Time) (err error) { + signedData := p7.Content + ee := getCertFromCertsByIssuerAndSerial(p7.Certificates, signer.IssuerAndSerialNumber) + if ee == nil { + return errors.New("pkcs7: No certificate for signer") + } + if len(signer.AuthenticatedAttributes) > 0 { + // TODO(fullsailor): First check the content type match + var ( + digest []byte + signingTime time.Time + ) + err := unmarshalAttribute(signer.AuthenticatedAttributes, OIDAttributeMessageDigest, &digest) + if err != nil { + return err + } + hash, err := getHashForOID(signer.DigestAlgorithm.Algorithm) + if err != nil { + return err + } + computed, err := calculateHash(p7.Hasher, hash, p7.Content) + if err != nil { + return err + } + if subtle.ConstantTimeCompare(digest, computed) != 1 { + return &MessageDigestMismatchError{ + ExpectedDigest: digest, + ActualDigest: computed, + } + } + signedData, err = marshalAttributes(signer.AuthenticatedAttributes) + if err != nil { + return err + } + err = unmarshalAttribute(signer.AuthenticatedAttributes, OIDAttributeSigningTime, &signingTime) + if err == nil { + // signing time found, performing validity check + if signingTime.After(ee.NotAfter) || signingTime.Before(ee.NotBefore) { + return &SigningTimeNotValidError{ + SigningTime: signingTime, + NotBefore: ee.NotBefore, + NotAfter: ee.NotAfter, + } + } + } + } + if truststore != nil { + _, err = verifyCertChain(ee, p7.Certificates, truststore, currentTime) + if err != nil { + return err + } + } + sigalg, err := getSignatureAlgorithm(signer.DigestEncryptionAlgorithm, signer.DigestAlgorithm) + if err != nil { + return err + } + return ee.CheckSignature(sigalg, signedData, signer.EncryptedDigest) +} + +func verifySignature(p7 *PKCS7, signer signerInfo, truststore *x509.CertPool) (err error) { + signedData := p7.Content + ee := getCertFromCertsByIssuerAndSerial(p7.Certificates, signer.IssuerAndSerialNumber) + if ee == nil { + return errors.New("pkcs7: No certificate for signer") + } + signingTime := time.Now().UTC() + if len(signer.AuthenticatedAttributes) > 0 { + // TODO(fullsailor): First check the content type match + var digest []byte + err := unmarshalAttribute(signer.AuthenticatedAttributes, OIDAttributeMessageDigest, &digest) + if err != nil { + return err + } + hash, err := getHashForOID(signer.DigestAlgorithm.Algorithm) + if err != nil { + return err + } + computed, err := calculateHash(p7.Hasher, hash, p7.Content) + if err != nil { + return err + } + if subtle.ConstantTimeCompare(digest, computed) != 1 { + return &MessageDigestMismatchError{ + ExpectedDigest: digest, + ActualDigest: computed, + } + } + signedData, err = marshalAttributes(signer.AuthenticatedAttributes) + if err != nil { + return err + } + err = unmarshalAttribute(signer.AuthenticatedAttributes, OIDAttributeSigningTime, &signingTime) + if err == nil { + // signing time found, performing validity check + if signingTime.After(ee.NotAfter) || signingTime.Before(ee.NotBefore) { + return &SigningTimeNotValidError{ + SigningTime: signingTime, + NotBefore: ee.NotBefore, + NotAfter: ee.NotAfter, + } + } + } + } + if truststore != nil { + _, err = verifyCertChain(ee, p7.Certificates, truststore, signingTime) + if err != nil { + return err + } + } + sigalg, err := getSignatureAlgorithm(signer.DigestEncryptionAlgorithm, signer.DigestAlgorithm) + if err != nil { + return err + } + return ee.CheckSignature(sigalg, signedData, signer.EncryptedDigest) +} + +// GetOnlySigner returns an x509.Certificate for the first signer of the signed +// data payload. If there are more or less than one signer, nil is returned +func (p7 *PKCS7) GetOnlySigner() *x509.Certificate { + if len(p7.Signers) != 1 { + return nil + } + signer := p7.Signers[0] + return getCertFromCertsByIssuerAndSerial(p7.Certificates, signer.IssuerAndSerialNumber) +} + +// UnmarshalSignedAttribute decodes a single attribute from the signer info +func (p7 *PKCS7) UnmarshalSignedAttribute(attributeType asn1.ObjectIdentifier, out interface{}) error { + sd, ok := p7.raw.(signedData) + if !ok { + return errors.New("pkcs7: payload is not signedData content") + } + if len(sd.SignerInfos) < 1 { + return errors.New("pkcs7: payload has no signers") + } + attributes := sd.SignerInfos[0].AuthenticatedAttributes + return unmarshalAttribute(attributes, attributeType, out) +} + +func parseSignedData(data []byte) (*PKCS7, error) { + var sd signedData + asn1.Unmarshal(data, &sd) + certs, err := sd.Certificates.Parse() + if err != nil { + return nil, err + } + // fmt.Printf("--> Signed Data Version %d\n", sd.Version) + + var compound asn1.RawValue + var content unsignedData + + // The Content.Bytes maybe empty on PKI responses. + if len(sd.ContentInfo.Content.Bytes) > 0 { + if _, err := asn1.Unmarshal(sd.ContentInfo.Content.Bytes, &compound); err != nil { + return nil, err + } + } + // Compound octet string + if compound.IsCompound { + if compound.Tag == 4 { + for len(compound.Bytes) > 0 { + var cdata asn1.RawValue + if _, err = asn1.Unmarshal(compound.Bytes, &cdata); err != nil { + return nil, err + } + content = append(content, cdata.Bytes...) + compound.Bytes = compound.Bytes[len(cdata.FullBytes):] + } + } else { + content = compound.Bytes + } + } else { + // assuming this is tag 04 + content = compound.Bytes + } + return &PKCS7{ + Content: content, + Certificates: certs, + CRLs: sd.CRLs, + Signers: sd.SignerInfos, + raw: sd}, nil +} + +// verifyCertChain takes an end-entity certs, a list of potential intermediates and a +// truststore, and built all potential chains between the EE and a trusted root. +// +// When verifying chains that may have expired, currentTime can be set to a past date +// to allow the verification to pass. If unset, currentTime is set to the current UTC time. +func verifyCertChain(ee *x509.Certificate, certs []*x509.Certificate, truststore *x509.CertPool, currentTime time.Time) (chains [][]*x509.Certificate, err error) { + intermediates := x509.NewCertPool() + for _, intermediate := range certs { + intermediates.AddCert(intermediate) + } + verifyOptions := x509.VerifyOptions{ + Roots: truststore, + Intermediates: intermediates, + KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageAny}, + CurrentTime: currentTime, + } + chains, err = ee.Verify(verifyOptions) + if err != nil { + return chains, fmt.Errorf("pkcs7: failed to verify certificate chain: %v", err) + } + return +} + +// MessageDigestMismatchError is returned when the signer data digest does not +// match the computed digest for the contained content +type MessageDigestMismatchError struct { + ExpectedDigest []byte + ActualDigest []byte +} + +func (err *MessageDigestMismatchError) Error() string { + return fmt.Sprintf("pkcs7: Message digest mismatch\n\tExpected: %X\n\tActual : %X", err.ExpectedDigest, err.ActualDigest) +} + +func getSignatureAlgorithm(digestEncryption, digest pkix.AlgorithmIdentifier) (x509.SignatureAlgorithm, error) { + switch { + case digestEncryption.Algorithm.Equal(OIDDigestAlgorithmECDSASHA1): + return x509.ECDSAWithSHA1, nil + case digestEncryption.Algorithm.Equal(OIDDigestAlgorithmECDSASHA256): + return x509.ECDSAWithSHA256, nil + case digestEncryption.Algorithm.Equal(OIDDigestAlgorithmECDSASHA384): + return x509.ECDSAWithSHA384, nil + case digestEncryption.Algorithm.Equal(OIDDigestAlgorithmECDSASHA512): + return x509.ECDSAWithSHA512, nil + case digestEncryption.Algorithm.Equal(OIDEncryptionAlgorithmRSA), + digestEncryption.Algorithm.Equal(OIDEncryptionAlgorithmRSASHA1), + digestEncryption.Algorithm.Equal(OIDEncryptionAlgorithmRSASHA256), + digestEncryption.Algorithm.Equal(OIDEncryptionAlgorithmRSASHA384), + digestEncryption.Algorithm.Equal(OIDEncryptionAlgorithmRSASHA512): + switch { + case digest.Algorithm.Equal(OIDDigestAlgorithmSHA1), digest.Algorithm.Equal(OIDEncryptionAlgorithmRSASHA1): + return x509.SHA1WithRSA, nil + case digest.Algorithm.Equal(OIDDigestAlgorithmSHA256), digest.Algorithm.Equal(OIDEncryptionAlgorithmRSASHA256): + return x509.SHA256WithRSA, nil + case digest.Algorithm.Equal(OIDDigestAlgorithmSHA384), digest.Algorithm.Equal(OIDEncryptionAlgorithmRSASHA384): + return x509.SHA384WithRSA, nil + case digest.Algorithm.Equal(OIDDigestAlgorithmSHA512), digest.Algorithm.Equal(OIDEncryptionAlgorithmRSASHA512): + return x509.SHA512WithRSA, nil + default: + return -1, fmt.Errorf("pkcs7: unsupported digest %q for encryption algorithm %q", + digest.Algorithm.String(), digestEncryption.Algorithm.String()) + } + case digestEncryption.Algorithm.Equal(OIDDigestAlgorithmDSA), + digestEncryption.Algorithm.Equal(OIDDigestAlgorithmDSASHA1): + switch { + case digest.Algorithm.Equal(OIDDigestAlgorithmSHA1): + return x509.DSAWithSHA1, nil + case digest.Algorithm.Equal(OIDDigestAlgorithmSHA256): + return x509.DSAWithSHA256, nil + default: + return -1, fmt.Errorf("pkcs7: unsupported digest %q for encryption algorithm %q", + digest.Algorithm.String(), digestEncryption.Algorithm.String()) + } + case digestEncryption.Algorithm.Equal(OIDEncryptionAlgorithmECDSAP256), + digestEncryption.Algorithm.Equal(OIDEncryptionAlgorithmECDSAP384), + digestEncryption.Algorithm.Equal(OIDEncryptionAlgorithmECDSAP521): + switch { + case digest.Algorithm.Equal(OIDDigestAlgorithmSHA1): + return x509.ECDSAWithSHA1, nil + case digest.Algorithm.Equal(OIDDigestAlgorithmSHA256): + return x509.ECDSAWithSHA256, nil + case digest.Algorithm.Equal(OIDDigestAlgorithmSHA384): + return x509.ECDSAWithSHA384, nil + case digest.Algorithm.Equal(OIDDigestAlgorithmSHA512): + return x509.ECDSAWithSHA512, nil + default: + return -1, fmt.Errorf("pkcs7: unsupported digest %q for encryption algorithm %q", + digest.Algorithm.String(), digestEncryption.Algorithm.String()) + } + default: + return -1, fmt.Errorf("pkcs7: unsupported algorithm %q", + digestEncryption.Algorithm.String()) + } +} + +func getCertFromCertsByIssuerAndSerial(certs []*x509.Certificate, ias issuerAndSerial) *x509.Certificate { + for _, cert := range certs { + if isCertMatchForIssuerAndSerial(cert, ias) { + return cert + } + } + return nil +} + +func unmarshalAttribute(attrs []attribute, attributeType asn1.ObjectIdentifier, out interface{}) error { + for _, attr := range attrs { + if attr.Type.Equal(attributeType) { + _, err := asn1.Unmarshal(attr.Value.Bytes, out) + return err + } + } + return errors.New("pkcs7: attribute type not in attributes") +} + +func calculateHash(hasher Hasher, hashFunc crypto.Hash, content []byte) (computed []byte, err error) { + if hasher != nil { + computed, err = hasher.Hash(hashFunc, bytes.NewReader(content)) + } else { + if !hashFunc.Available() { + return nil, fmt.Errorf("hash function %v not available", hashFunc) + } + + h := hashFunc.New() + _, _ = h.Write(content) + computed = h.Sum(nil) + } + + return +} diff --git a/vendor/modules.txt b/vendor/modules.txt index 0210cda41a224..06220f477c89f 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -974,6 +974,10 @@ github.com/shopspring/decimal # github.com/sirupsen/logrus v1.9.3 ## explicit; go 1.13 github.com/sirupsen/logrus +# github.com/smallstep/pkcs7 v0.2.1 +## explicit; go 1.14 +github.com/smallstep/pkcs7 +github.com/smallstep/pkcs7/internal/legacy/x509 # github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 ## explicit; go 1.20 github.com/sourcegraph/conc