Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ WORKDIR /src
COPY go.mod go.sum ./
RUN go mod download
COPY *.go ./
COPY crypt ./crypt
COPY wrap ./wrap
COPY cmd ./cmd
RUN CGO_ENABLED=0 go build -trimpath -buildvcs=false -ldflags=-buildid= -o /modelwrap ./cmd/modelwrap
Expand Down
4 changes: 0 additions & 4 deletions cmd/modelwrap/launcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,6 @@ func launch(opts cliOptions) int {
// of the same CLI inside the packer image.
func dockerRunArgs(opts cliOptions) ([]string, error) {
args := []string{"run", "--rm"}
if opts.Encrypt {
// EMWP packing needs loop device and device-mapper access.
args = append(args, "--privileged")
}

hostDir := func(path, fallback string) (string, error) {
if path == "" {
Expand Down
9 changes: 8 additions & 1 deletion cmd/modelwrap/launcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func TestDockerRunArgs(t *testing.T) {
}

want := []string{
"run", "--rm", "--privileged",
"run", "--rm",
"-v", filepath.Join(dir, "output") + ":/output",
"-v", filepath.Join(dir, "cache") + ":/cache",
"-v", filepath.Join(dir, "weights") + ":/model:ro",
Expand All @@ -54,6 +54,13 @@ func TestDockerRunArgs(t *testing.T) {
t.Fatalf("dockerRunArgs mismatch:\n got %q\nwant %q", got, want)
}

// EMWP packing is fully userspace and must not request privilege.
for _, arg := range got {
if arg == "--privileged" {
t.Fatal("EMWP packing must not be privileged")
}
}

// Secret values must never appear in the docker command line.
for _, arg := range got {
if arg == "secret" {
Expand Down
160 changes: 160 additions & 0 deletions crypt/crypt.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
// Package crypt implements the EMWP dm-crypt encryption (aes-xts-plain64,
// 512-bit key, 4096-byte sectors) in pure Go.
//
// The output is byte-identical to `cryptsetup open --type plain` with the
// parameters the packer uses, so the kernel dm-crypt consumer decrypts
// modelwrap artifacts unchanged. Producing the ciphertext in userspace lets
// the packer encrypt EMWP artifacts without cryptsetup, loop devices,
// device-mapper, or a privileged container, and without writing the volume
// key to disk.
//
// Only the IV-per-sector convention is dm-crypt specific; everything else
// (the two-key split, the tweak derivation, the little-endian sector
// encoding) is standard XTS as implemented by golang.org/x/crypto/xts. That
// one convention is pinned by the dm-crypt golden-vector test.
package crypt

import (
"crypto/aes"
"fmt"
"io"

"golang.org/x/crypto/xts"

"github.com/tinfoilsh/modelwrap"
)

// SectorSize is the dm-crypt data-unit size (cryptsetup --sector-size). Each
// SectorSize-byte block is encrypted as one XTS data unit: the tweak is
// derived once from the sector's IV and chained by the GF(2^128) multiply
// across the block's 16-byte AES blocks.
const SectorSize = modelwrap.EMWPSectorSize

// ivSectorRatio converts a SectorSize-byte sector index into the 512-byte
// sector number that plain64 uses for the IV. dm-crypt keeps 512-byte IV
// numbering unless cryptsetup is given --iv-large-sectors, which the packer
// does not use, so each 4096-byte sector advances the IV by 8. This is the
// single dm-crypt specific convention; if TestDmcryptGolden ever fails, this
// constant (8 vs 1) is the first thing to check.
const ivSectorRatio = SectorSize / 512

// streamChunkBytes is the streaming buffer size: 4 MiB, sector-aligned,
// matching the old copyToDevice buffer so large artifacts never load fully
// into memory.
const streamChunkBytes = 1024 * SectorSize

// The packer and consumer always open dm-crypt with skip 0, so the volume's
// first sector is IV unit 0; there is no non-zero-skip path to support.

func newCipher(volumeKey []byte) (*xts.Cipher, error) {
if len(volumeKey) != modelwrap.EMWPKeyBytes {
return nil, fmt.Errorf("volume key is %d bytes, want %d", len(volumeKey), modelwrap.EMWPKeyBytes)
}
return xts.NewCipher(aes.NewCipher, volumeKey)
}

// transform encrypts or decrypts a sector-aligned buffer in place. baseUnit
// is the 0-based sector index of buf[0] within the volume.
func transform(c *xts.Cipher, buf []byte, baseUnit uint64, decrypt bool) {
for off := 0; off < len(buf); off += SectorSize {
iv := (baseUnit + uint64(off/SectorSize)) * ivSectorRatio
s := buf[off : off+SectorSize]
if decrypt {
c.Decrypt(s, s, iv)
} else {
c.Encrypt(s, s, iv)
}
}
}

// Encrypt encrypts a whole sector-aligned plaintext with the raw 64-byte
// dm-crypt volume key (already derived via modelwrap.DeriveKey). It is a
// convenience for small buffers and tests; the packer uses EncryptStream.
func Encrypt(volumeKey, plaintext []byte) ([]byte, error) {
return inMemory(volumeKey, plaintext, false)
}

// Decrypt is the inverse of Encrypt.
func Decrypt(volumeKey, ciphertext []byte) ([]byte, error) {
return inMemory(volumeKey, ciphertext, true)
}

func inMemory(volumeKey, in []byte, decrypt bool) ([]byte, error) {
if len(in)%SectorSize != 0 {
return nil, fmt.Errorf("data length %d is not a multiple of sector size %d", len(in), SectorSize)
}
c, err := newCipher(volumeKey)
if err != nil {
return nil, err
}
out := make([]byte, len(in))
copy(out, in)
transform(c, out, 0, decrypt)
return out, nil
}

// EncryptStream reads plaintext from src and writes ciphertext to dst,
// encrypting in sector-aligned chunks so large artifacts never load fully
// into memory. It returns the number of ciphertext bytes written (always a
// multiple of SectorSize).
//
// A trailing partial sector is zero-padded. In practice MWP images are always
// sector-aligned (the EROFS image and the dm-verity hash tree are both whole
// multiples of the 4096-byte block), so real artifacts encrypt with no
// padding and the padding path is purely defensive.
func EncryptStream(volumeKey []byte, dst io.Writer, src io.Reader) (int64, error) {
return stream(volumeKey, dst, src, false)
}

// DecryptStream is the inverse of EncryptStream. Its input must be
// sector-aligned (ciphertext always is); a trailing partial sector is an
// error rather than being padded.
func DecryptStream(volumeKey []byte, dst io.Writer, src io.Reader) (int64, error) {
return stream(volumeKey, dst, src, true)
}

func stream(volumeKey []byte, dst io.Writer, src io.Reader, decrypt bool) (int64, error) {
c, err := newCipher(volumeKey)
if err != nil {
return 0, err
}
buf := make([]byte, streamChunkBytes)
var baseUnit uint64
var written int64
for {
n, readErr := io.ReadFull(src, buf)

// A short read from io.ReadFull means either end of stream (io.EOF or
// io.ErrUnexpectedEOF) or a genuine read error. Only end of stream
// justifies padding a trailing partial sector; on a real error we must
// not emit that (incorrectly padded) ciphertext, so surface the error
// before writing anything.
eof := readErr == io.EOF || readErr == io.ErrUnexpectedEOF
if readErr != nil && !eof {
return written, readErr
}

if n > 0 {
full := n
if rem := n % SectorSize; rem != 0 {
if decrypt {
return written, fmt.Errorf("ciphertext is not sector-aligned (trailing %d bytes)", rem)
}
// A partial sector here implies clean EOF, so the tail is the
// real end of the data: zero-pad it to a full sector.
full = n - rem + SectorSize
Comment thread
cubic-dev-ai[bot] marked this conversation as resolved.
clear(buf[n:full])
}
transform(c, buf[:full], baseUnit, decrypt)
if _, err := dst.Write(buf[:full]); err != nil {
return written, err
}
baseUnit += uint64(full / SectorSize)
written += int64(full)
}

if eof {
return written, nil
}
}
}
177 changes: 177 additions & 0 deletions crypt/crypt_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
package crypt

import (
"bytes"
"errors"
"os"
"path/filepath"
"testing"
)

// TestDmcryptGolden is the authoritative compatibility check: it asserts
// crypt.Encrypt reproduces, byte for byte, the ciphertext that the real
// cryptsetup produced for the same key and plaintext with the packer's exact
// flags (see testdata/gen-golden.sh). Passing this proves the kernel
// dm-crypt consumer will decrypt artifacts this package encrypts, and pins
// the one dm-crypt specific convention (the per-sector IV) without needing
// cryptsetup at test time.
func TestDmcryptGolden(t *testing.T) {
key := readTestdata(t, "key.bin")
pt := readTestdata(t, "plaintext.bin")
ct := readTestdata(t, "ciphertext.bin")

got, err := Encrypt(key, pt)
if err != nil {
t.Fatalf("Encrypt: %v", err)
}
if !bytes.Equal(got, ct) {
t.Fatalf("ciphertext does not match dm-crypt golden vector;\n"+
"the per-sector IV convention is likely wrong (try ivSectorRatio=1).\n"+
"first mismatch at byte %d", firstDiff(got, ct))
}

back, err := Decrypt(key, ct)
if err != nil {
t.Fatalf("Decrypt: %v", err)
}
if !bytes.Equal(back, pt) {
t.Fatal("Decrypt(golden ciphertext) != plaintext")
}
}

// TestDmcryptGoldenStream checks the streaming path produces the same golden
// ciphertext and round-trips, including across chunk boundaries.
func TestDmcryptGoldenStream(t *testing.T) {
key := readTestdata(t, "key.bin")
pt := readTestdata(t, "plaintext.bin")
ct := readTestdata(t, "ciphertext.bin")

var enc bytes.Buffer
if _, err := EncryptStream(key, &enc, bytes.NewReader(pt)); err != nil {
t.Fatalf("EncryptStream: %v", err)
}
if !bytes.Equal(enc.Bytes(), ct) {
t.Fatalf("streamed ciphertext != golden (first mismatch at %d)", firstDiff(enc.Bytes(), ct))
}

var dec bytes.Buffer
if _, err := DecryptStream(key, &dec, bytes.NewReader(ct)); err != nil {
t.Fatalf("DecryptStream: %v", err)
}
if !bytes.Equal(dec.Bytes(), pt) {
t.Fatal("DecryptStream(golden) != plaintext")
}
}

// TestStreamPadsTrailingSector confirms a non-sector-aligned plaintext is
// zero-padded on encrypt and recovered (with padding) on decrypt, matching
// the old backing-file behavior.
func TestStreamPadsTrailingSector(t *testing.T) {
key := bytes.Repeat([]byte{0x5A}, 64)
pt := bytes.Repeat([]byte{0xEE}, 2*SectorSize+123) // unaligned

var enc bytes.Buffer
n, err := EncryptStream(key, &enc, bytes.NewReader(pt))
if err != nil {
t.Fatalf("EncryptStream: %v", err)
}
if n != 3*SectorSize || enc.Len() != 3*SectorSize {
t.Fatalf("padded ciphertext = %d bytes, want %d", enc.Len(), 3*SectorSize)
}

var dec bytes.Buffer
if _, err := DecryptStream(key, &dec, &enc); err != nil {
t.Fatalf("DecryptStream: %v", err)
}
if !bytes.Equal(dec.Bytes()[:len(pt)], pt) {
t.Fatal("decrypted prefix != original plaintext")
}
for i := len(pt); i < dec.Len(); i++ {
if dec.Bytes()[i] != 0 {
t.Fatalf("pad byte at %d = %d, want 0", i, dec.Bytes()[i])
}
}
}

// TestIdenticalSectorsDistinctCiphertext confirms the per-sector IV is
// actually applied: identical plaintext sectors at different offsets must
// encrypt to different ciphertext. testdata's sector 5 duplicates sector 0.
func TestIdenticalSectorsDistinctCiphertext(t *testing.T) {
key := readTestdata(t, "key.bin")
pt := readTestdata(t, "plaintext.bin")
if !bytes.Equal(pt[0:SectorSize], pt[5*SectorSize:6*SectorSize]) {
t.Skip("testdata sector 5 no longer duplicates sector 0")
}
ct, err := Encrypt(key, pt)
if err != nil {
t.Fatalf("Encrypt: %v", err)
}
if bytes.Equal(ct[0:SectorSize], ct[5*SectorSize:6*SectorSize]) {
t.Fatal("identical plaintext sectors produced identical ciphertext; IV not per-sector")
}
}

func TestRejectsBadInput(t *testing.T) {
good := bytes.Repeat([]byte{1}, 64)
if _, err := Encrypt(good[:32], make([]byte, SectorSize)); err == nil {
t.Fatal("expected error for short key")
}
if _, err := Encrypt(good, make([]byte, SectorSize+1)); err == nil {
t.Fatal("expected error for non-sector-multiple length")
}
if _, err := DecryptStream(good, &bytes.Buffer{}, bytes.NewReader(make([]byte, SectorSize+1))); err == nil {
t.Fatal("expected error decrypting non-sector-aligned ciphertext")
}
}

// errReader yields data once together with a non-EOF error, mimicking a
// reader that fails mid-stream after a partial, non-sector-aligned read.
type errReader struct {
data []byte
err error
done bool
}

func (r *errReader) Read(p []byte) (int, error) {
if r.done {
return 0, r.err
}
r.done = true
return copy(p, r.data), r.err
}

// TestEncryptStreamReadErrorNoWrite guards against treating a failed partial
// read as end-of-stream: a genuine read error must surface without emitting
// any (zero-padded) ciphertext.
func TestEncryptStreamReadErrorNoWrite(t *testing.T) {
key := bytes.Repeat([]byte{0x11}, 64)
boom := errors.New("boom")
r := &errReader{data: bytes.Repeat([]byte{0xCD}, 100), err: boom} // 100 % SectorSize != 0

var dst bytes.Buffer
n, err := EncryptStream(key, &dst, r)
if !errors.Is(err, boom) {
t.Fatalf("err = %v, want boom", err)
}
if n != 0 || dst.Len() != 0 {
t.Fatalf("wrote %d bytes (reported %d) despite read error; want none", dst.Len(), n)
}
}

func readTestdata(t *testing.T, name string) []byte {
t.Helper()
b, err := os.ReadFile(filepath.Join("testdata", name))
if err != nil {
t.Fatalf("reading testdata/%s: %v", name, err)
}
return b
}

func firstDiff(a, b []byte) int {
for i := 0; i < len(a) && i < len(b); i++ {
if a[i] != b[i] {
return i
}
}
return min(len(a), len(b))
}
Binary file added crypt/testdata/ciphertext.bin
Binary file not shown.
Loading
Loading