Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions coordinator/internal/stateguard/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"github.com/edgelesssys/contrast/internal/atls"
"github.com/edgelesssys/contrast/internal/attestation"
"github.com/edgelesssys/contrast/internal/attestation/certcache"
"github.com/edgelesssys/contrast/internal/attestation/insecure"
"github.com/edgelesssys/contrast/internal/attestation/snp"
"github.com/edgelesssys/contrast/internal/attestation/tdx"
"github.com/edgelesssys/contrast/internal/constants"
Expand Down Expand Up @@ -104,6 +105,12 @@ func (c *Credentials) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.A
logger.NewWithAttrs(logger.NewNamed(c.logger, "validator"), map[string]string{"reference-values": name}), &authInfo, name))
}

if state.Manifest().AllowInsecure() {
validators = append(validators, insecure.NewValidatorWithReportSetter(
logger.NewWithAttrs(logger.NewNamed(c.logger, "validator"), map[string]string{"reference-values": "insecure"}),
&authInfo, "insecure"))
}

serverCfg, err := atls.CreateAttestationServerTLSConfig(c.issuer, validators, c.attestationFailuresCounter)
if err != nil {
log.Error("Could not create TLS config", "error", err)
Expand Down
20 changes: 17 additions & 3 deletions coordinator/internal/stateguard/stateguard.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ var (
// ErrConcurrentUpdate is returned by state-modifying operations if the input oldState is not
// the current state. This usually happens when a concurrent operation succeeded.
ErrConcurrentUpdate = errors.New("coordinator state was updated concurrently")

// ErrInsecureNotAllowed is returned when a manifest contains insecure platforms but the
// coordinator was not started with the allow-insecure flag.
ErrInsecureNotAllowed = errors.New("manifest contains insecure platforms, but the coordinator is not configured to allow them")
)

// Guard manages the manifest state of Contrast.
Expand All @@ -65,6 +69,9 @@ type Guard struct {
logger *slog.Logger
metrics metrics

// allowInsecure controls whether manifests with insecure platforms are accepted.
allowInsecure bool

clock clock.Clock
}

Expand All @@ -73,7 +80,10 @@ type metrics struct {
}

// New creates a new state Guard instance.
func New(hist *history.History, reg *prometheus.Registry, log *slog.Logger) *Guard {
//
// If allowInsecure is true, the Guard will accept manifests that contain insecure platforms.
// Otherwise, setting such a manifest will be rejected with ErrInsecureNotAllowed.
func New(hist *history.History, reg *prometheus.Registry, log *slog.Logger, allowInsecure bool) *Guard {
manifestGeneration := promauto.With(reg).NewGauge(prometheus.GaugeOpts{
Subsystem: "contrast_coordinator",
Name: "manifest_generation",
Expand All @@ -82,8 +92,9 @@ func New(hist *history.History, reg *prometheus.Registry, log *slog.Logger) *Gua
manifestGeneration.Set(0)

return &Guard{
hist: hist,
logger: log.WithGroup("stateguard"),
hist: hist,
logger: log.WithGroup("stateguard"),
allowInsecure: allowInsecure,
metrics: metrics{
manifestGeneration: manifestGeneration,
},
Expand Down Expand Up @@ -271,6 +282,9 @@ func (g *Guard) UpdateState(_ context.Context, oldState *State, se *seedengine.S
if err := json.Unmarshal(manifestBytes, &mnfst); err != nil {
return nil, fmt.Errorf("unmarshaling manifest: %w", err)
}
if !g.allowInsecure && mnfst.AllowInsecure() {
return nil, ErrInsecureNotAllowed
}
policyMap := make(map[[history.HashSize]byte][]byte)
for _, policy := range policies {
policyHash, err := g.hist.SetPolicy(policy)
Expand Down
63 changes: 58 additions & 5 deletions coordinator/internal/stateguard/stateguard_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,37 @@ func TestResetState(t *testing.T) {
require.ErrorIs(err, assert.AnError)
}

func TestUpdateStateInsecure(t *testing.T) {
ctx := t.Context()

_, insecureManifestBytes, policies := newInsecureManifest(t)
se := newSeedEngine(t)

t.Run("rejected when allowInsecure is false", func(t *testing.T) {
require := require.New(t)

store := aferostore.New(&afero.Afero{Fs: afero.NewMemMapFs()})
hist := history.NewWithStore(slog.Default(), store)
g := New(hist, prometheus.NewRegistry(), slog.Default(), false)

state, err := g.UpdateState(ctx, nil, se, insecureManifestBytes, policies)
require.ErrorIs(err, ErrInsecureNotAllowed)
require.Nil(state)
})

t.Run("accepted when allowInsecure is true", func(t *testing.T) {
require := require.New(t)

store := aferostore.New(&afero.Afero{Fs: afero.NewMemMapFs()})
hist := history.NewWithStore(slog.Default(), store)
g := New(hist, prometheus.NewRegistry(), slog.Default(), true)

state, err := g.UpdateState(ctx, nil, se, insecureManifestBytes, policies)
require.NoError(err)
require.NotNil(state)
})
}

func TestConcurrentUpdateState(t *testing.T) {
ctx := t.Context()
assert := assert.New(t)
Expand All @@ -200,7 +231,7 @@ func TestConcurrentUpdateState(t *testing.T) {
Store: aferostore.New(&afero.Afero{Fs: afero.NewMemMapFs()}),
}
hist := history.NewWithStore(slog.Default(), store)
guard := New(hist, prometheus.NewRegistry(), slog.Default())
guard := New(hist, prometheus.NewRegistry(), slog.Default(), false)

numWorkers := 20

Expand Down Expand Up @@ -303,7 +334,7 @@ func TestWatchHistory(t *testing.T) {
notifications: make(chan []byte),
}
hist := history.NewWithStore(slog.Default(), store)
g := New(hist, prometheus.NewRegistry(), slog.Default())
g := New(hist, prometheus.NewRegistry(), slog.Default(), false)

_, manifestBytes, policies := newManifest(t)

Expand Down Expand Up @@ -352,7 +383,7 @@ func TestWatchHistoryLateNotifications(t *testing.T) {
notifications: make(chan []byte),
}
hist := history.NewWithStore(slog.Default(), store)
g := New(hist, prometheus.NewRegistry(), slog.Default())
g := New(hist, prometheus.NewRegistry(), slog.Default(), false)

_, manifestBytes, policies := newManifest(t)

Expand Down Expand Up @@ -409,7 +440,7 @@ func TestBadStoreWatcherIsRestarted(t *testing.T) {
store.storeUpdates.Store(&ch)
hist := history.NewWithStore(slog.Default(), store)
reg := prometheus.NewRegistry()
a := New(hist, reg, slog.Default())
a := New(hist, reg, slog.Default(), false)
clock := &waitingClock{
FakeClock: testingclock.NewFakeClock(time.Now()),
afterCalls: make(chan struct{}, 1),
Expand Down Expand Up @@ -502,7 +533,7 @@ func newTestGuard(t *testing.T) (*Guard, *prometheus.Registry) {
store := aferostore.New(&afero.Afero{Fs: afero.NewMemMapFs()})
hist := history.NewWithStore(slog.Default(), store)
reg := prometheus.NewRegistry()
return New(hist, reg, slog.Default()), reg
return New(hist, reg, slog.Default(), false), reg
}

func newManifest(t *testing.T) (*manifest.Manifest, []byte, [][]byte) {
Expand Down Expand Up @@ -543,6 +574,28 @@ func newManifest(t *testing.T) (*manifest.Manifest, []byte, [][]byte) {
return mnfst, mnfstBytes, [][]byte{policy}
}

func newInsecureManifest(t *testing.T) (*manifest.Manifest, []byte, [][]byte) {
t.Helper()
policy := []byte("=== SOME REGO HERE ===")
policyHash := sha256.Sum256(policy)
policyHashHex := manifest.NewHexString(policyHash[:])

mnfst := &manifest.Manifest{}
mnfst.Policies = map[manifest.HexString]manifest.PolicyEntry{
policyHashHex: {
SANs: []string{"test"},
WorkloadSecretID: "test2",
Role: manifest.RoleCoordinator,
},
}
mnfst.ReferenceValues.SNP = []manifest.SNPReferenceValues{
{Platform: "Metal-QEMU-Insecure"},
}
mnfstBytes, err := json.Marshal(mnfst)
require.NoError(t, err)
return mnfst, mnfstBytes, [][]byte{policy}
}

func newSeedEngine(t *testing.T) *seedengine.SeedEngine {
t.Helper()
data := make([]byte, 32)
Expand Down
5 changes: 4 additions & 1 deletion coordinator/internal/userapi/userapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,11 @@ func (s *Server) SetManifest(ctx context.Context, req *userapi.SetManifestReques
state, err := s.guard.UpdateState(ctx, oldState, se, req.GetManifest(), req.GetPolicies())
if err != nil {
code := codes.Internal
if errors.Is(err, stateguard.ErrConcurrentUpdate) {
switch {
case errors.Is(err, stateguard.ErrConcurrentUpdate):
code = codes.FailedPrecondition
case errors.Is(err, stateguard.ErrInsecureNotAllowed):
code = codes.InvalidArgument
}
return nil, status.Errorf(code, "updating Coordinator state: %v", err)
}
Expand Down
60 changes: 53 additions & 7 deletions coordinator/internal/userapi/userapi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,34 @@ func TestSetManifest(t *testing.T) {
require.Equal(codes.InvalidArgument, status.Code(err))
})

t.Run("insecure manifest rejected", func(t *testing.T) {
require := require.New(t)

// Default coordinator does not allow insecure manifests.
coordinator := newCoordinator()
m := newInsecureManifest(t)
manifestBytes, err := json.Marshal(m)
require.NoError(err)
req := &userapi.SetManifestRequest{Manifest: manifestBytes}
_, err = coordinator.SetManifest(t.Context(), req)
require.Error(err)
require.Equal(codes.InvalidArgument, status.Code(err))
require.ErrorContains(err, "insecure")
})

t.Run("insecure manifest accepted when allowed", func(t *testing.T) {
require := require.New(t)

coordinator := newCoordinatorAllowInsecure()
m := newInsecureManifest(t)
manifestBytes, err := json.Marshal(m)
require.NoError(err)
req := &userapi.SetManifestRequest{Manifest: manifestBytes}
resp, err := coordinator.SetManifest(t.Context(), req)
require.NoError(err)
require.NotNil(resp)
})

t.Run("atomic manifest update", func(t *testing.T) {
require := require.New(t)

Expand Down Expand Up @@ -404,7 +432,7 @@ func TestRecovery(t *testing.T) {
fs := afero.NewMemMapFs()
store := aferostore.New(&afero.Afero{Fs: fs})
hist := history.NewWithStore(slog.Default(), store)
auth := stateguard.New(hist, prometheus.NewRegistry(), logger)
auth := stateguard.New(hist, prometheus.NewRegistry(), logger, false)
discovery := &stubDiscovery{
peers: tc.peers,
err: tc.peersErr,
Expand Down Expand Up @@ -438,7 +466,7 @@ func TestRecovery(t *testing.T) {
}

// Simulate a restarted Coordinator.
a.guard = stateguard.New(hist, prometheus.NewRegistry(), slog.Default())
a.guard = stateguard.New(hist, prometheus.NewRegistry(), slog.Default(), false)
_, err = a.GetManifests(t.Context(), nil)
require.ErrorContains(err, ErrNeedsRecovery.Error())
_, err = a.Recover(rpcContext(t.Context(), seedShareOwnerKey), recoverReq)
Expand All @@ -460,7 +488,7 @@ func TestRecoveryFlow(t *testing.T) {
fs := afero.NewMemMapFs()
store := aferostore.New(&afero.Afero{Fs: fs})
hist := history.NewWithStore(slog.Default(), store)
auth := stateguard.New(hist, prometheus.NewRegistry(), logger)
auth := stateguard.New(hist, prometheus.NewRegistry(), logger, false)
a := New(logger, auth, &stubDiscovery{})

// 2. A manifest is set and the returned seed is recorded.
Expand Down Expand Up @@ -496,7 +524,7 @@ func TestRecoveryFlow(t *testing.T) {
// 3. A new Coordinator is created with the existing history.
// GetManifests and SetManifest are expected to fail.

a.guard = stateguard.New(hist, prometheus.NewRegistry(), slog.Default())
a.guard = stateguard.New(hist, prometheus.NewRegistry(), slog.Default(), false)
_, err = a.SetManifest(t.Context(), req)
require.ErrorContains(err, ErrNeedsRecovery.Error())

Expand Down Expand Up @@ -539,7 +567,7 @@ func TestUserAPIConcurrent(t *testing.T) {
fs := afero.NewBasePathFs(afero.NewOsFs(), t.TempDir())
store := aferostore.New(&afero.Afero{Fs: fs})
hist := history.NewWithStore(slog.Default(), store)
auth := stateguard.New(hist, prometheus.NewRegistry(), logger)
auth := stateguard.New(hist, prometheus.NewRegistry(), logger, false)
coordinator := New(logger, auth, &stubDiscovery{})

setReq := &userapi.SetManifestRequest{
Expand Down Expand Up @@ -853,14 +881,32 @@ func newCoordinatorWithRegistry(reg *prometheus.Registry) *Server {
fs := afero.NewMemMapFs()
store := aferostore.New(&afero.Afero{Fs: fs})
hist := history.NewWithStore(slog.Default(), store)
auth := stateguard.New(hist, reg, logger)
auth := stateguard.New(hist, reg, logger, false)
return New(logger, auth, &stubDiscovery{})
}

func newCoordinatorAllowInsecure() *Server {
logger := slog.Default()
fs := afero.NewMemMapFs()
store := aferostore.New(&afero.Afero{Fs: fs})
hist := history.NewWithStore(slog.Default(), store)
auth := stateguard.New(hist, prometheus.NewRegistry(), logger, true)
return New(logger, auth, &stubDiscovery{})
}

func newInsecureManifest(t *testing.T) *manifest.Manifest {
t.Helper()
mnfst := &manifest.Manifest{}
mnfst.ReferenceValues.SNP = []manifest.SNPReferenceValues{
{Platform: "Metal-QEMU-Insecure"},
}
return mnfst
}

func newCoordinatorWithWatcher(t *testing.T, hist *history.History) *Server {
t.Helper()
logger := slog.Default()
auth := stateguard.New(hist, prometheus.NewRegistry(), logger)
auth := stateguard.New(hist, prometheus.NewRegistry(), logger, false)
coordinator := New(logger, auth, &stubDiscovery{})

ctx, cancel := context.WithCancel(t.Context())
Expand Down
8 changes: 7 additions & 1 deletion coordinator/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ import (

const (
metricsEnvVar = "CONTRAST_METRICS"
allowInsecureEnvVar = "CONTRAST_ALLOW_INSECURE"
probeAndMetricsPort = 9102
// transitEngineAPIPort specifies the default port to expose the transit engine API.
transitEngineAPIPort = "8200"
Expand Down Expand Up @@ -115,7 +116,12 @@ func run() (retErr error) {

hist := history.NewWithStore(logger.WithGroup("history"), store)

meshAuth := stateguard.New(hist, promRegistry, logger)
_, allowInsecure := os.LookupEnv(allowInsecureEnvVar)
if allowInsecure {
logger.Warn("Coordinator is configured to allow insecure manifests")
}

meshAuth := stateguard.New(hist, promRegistry, logger, allowInsecure)

issuer, err := issuer.New(logger)
if err != nil {
Expand Down
5 changes: 3 additions & 2 deletions internal/atls/issuer/issuer_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
package issuer

import (
"fmt"
"log/slog"

"github.com/edgelesssys/contrast/internal/atls"
"github.com/edgelesssys/contrast/internal/attestation/insecure"
snpissuer "github.com/edgelesssys/contrast/internal/attestation/snp/issuer"
tdxissuer "github.com/edgelesssys/contrast/internal/attestation/tdx/issuer"
"github.com/edgelesssys/contrast/internal/logger"
Expand All @@ -29,6 +29,7 @@ func New(log *slog.Logger) (atls.Issuer, error) {
logger.NewWithAttrs(logger.NewNamed(log, "issuer"), map[string]string{"tee-type": "tdx"}),
), nil
default:
return nil, fmt.Errorf("unsupported platform: %T", cpuid.CPU)
log.Warn("No TEE platform detected, using insecure attestation issuer")
return insecure.NewIssuer(), nil
}
}
Loading
Loading