Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ export JWKS_URL=http://lfx-platform-heimdall.lfx.svc.cluster.local:4457/.well-kn
# JWT audience
export AUDIENCE=lfx-v2-project-service

# JWT signature algorithm (PS256, PS384, PS512, RS256, RS384, RS512, ES256, ES384, ES512)
export JWT_SIGNATURE_ALGORITHM=PS256

# Skip the ETag validation that requires the correct revision on PUT/DELETE requests.
# When this is set to false, it means you need to make a GET request on the resource
# to get the ETag response header and use it as the ETag request header on the PUT/DELETE
Expand Down
1 change: 1 addition & 0 deletions CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ func TestEndpoint(t *testing.T) {
| `JWKS_URL` | JWT verification endpoint | - | No |
| `AUDIENCE` | JWT audience | lfx-v2-project-service | No |
| `JWT_AUTH_DISABLED_MOCK_LOCAL_PRINCIPAL` | Mock auth for local dev | - | No |
| `JWT_SIGNATURE_ALGORITHM` | JWT signature algorithm | PS256 | No |

## Authorization (OpenFGA)

Expand Down
2 changes: 1 addition & 1 deletion charts/lfx-v2-project-service/Chart.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@ apiVersion: v2
name: lfx-v2-project-service
description: LFX Platform V2 Project Service chart
type: application
version: 0.5.3
version: 0.5.4
appVersion: "latest"
2 changes: 2 additions & 0 deletions charts/lfx-v2-project-service/templates/deployment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ spec:
value: {{ .Values.app.skipEtagValidation | quote }}
- name: JWT_AUTH_DISABLED_MOCK_LOCAL_PRINCIPAL
value: {{ .Values.app.jwtAuthDisabledMockLocalPrincipal }}
- name: JWT_SIGNATURE_ALGORITHM
value: {{ .Values.app.jwtSignatureAlgorithm }}
ports:
- containerPort: {{ .Values.service.port }}
name: web
Expand Down
4 changes: 4 additions & 0 deletions charts/lfx-v2-project-service/values.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ app:
# jwtAuthDisabledMockLocalPrincipal mocks auth for local development to use a set principal
# (only use for local development)
jwtAuthDisabledMockLocalPrincipal: ""
# jwtSignatureAlgorithm is the JWT signature algorithm for token validation
# Supported: PS256 (default), PS384, PS512, RS256, RS384, RS512, ES256, ES384, ES512
# Algorithm names are case-sensitive and must be uppercase
jwtSignatureAlgorithm: "PS256"
# use_oidc_contextualizer is a boolean to determine if the OIDC contextualizer should be used
use_oidc_contextualizer: true

Expand Down
1 change: 1 addition & 0 deletions cmd/project-api/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ func main() {
JWKSURL: os.Getenv("JWKS_URL"),
Audience: os.Getenv("AUDIENCE"),
MockLocalPrincipal: os.Getenv("JWT_AUTH_DISABLED_MOCK_LOCAL_PRINCIPAL"),
SignatureAlgorithm: os.Getenv("JWT_SIGNATURE_ALGORITHM"),
}
jwtAuth, err := auth.NewJWTAuth(jwtAuthConfig)
if err != nil {
Expand Down
57 changes: 50 additions & 7 deletions internal/infrastructure/auth/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,40 @@ import (
)

const (
// PS256 is the default for Heimdall's JWT finalizer.
signatureAlgorithm = validator.PS256
defaultIssuer = "heimdall"
defaultAudience = "lfx-v2-project-service"
defaultJWKSURL = "http://heimdall:4457/.well-known/jwks"
// PS256 is the default signature algorithm used when JWT_SIGNATURE_ALGORITHM is not set.
defaultSignatureAlgorithm = validator.PS256
defaultIssuer = "heimdall"
defaultAudience = "lfx-v2-project-service"
defaultJWKSURL = "http://heimdall:4457/.well-known/jwks"
)

// parseSignatureAlgorithm converts the algorithm string to a validator.SignatureAlgorithm.
// Returns PS256 as default if algoString is empty.
// Algorithm names are case-sensitive and must be uppercase (e.g., "PS256").
func parseSignatureAlgorithm(algoString string) (validator.SignatureAlgorithm, error) {
if algoString == "" {
return validator.PS256, nil
}

algorithms := map[string]validator.SignatureAlgorithm{
"PS256": validator.PS256,
"PS384": validator.PS384,
"PS512": validator.PS512,
"RS256": validator.RS256,
"RS384": validator.RS384,
"RS512": validator.RS512,
"ES256": validator.ES256,
"ES384": validator.ES384,
"ES512": validator.ES512,
}

if algo, exists := algorithms[algoString]; exists {
return algo, nil
}

return "", errors.New("unsupported JWT signature algorithm: " + algoString + " (supported: PS256, PS384, PS512, RS256, RS384, RS512, ES256, ES384, ES512)")
}

// JWTAuthConfig holds the configuration parameters for JWT authentication.
type JWTAuthConfig struct {
// JWKSURL is the URL to the JSON Web Key Set endpoint
Expand All @@ -33,6 +60,8 @@ type JWTAuthConfig struct {
Audience string
// MockLocalPrincipal is used for local development to bypass JWT validation
MockLocalPrincipal string
// SignatureAlgorithm is the JWT signature algorithm (e.g., PS256, RS256, ES256)
SignatureAlgorithm string
}

var (
Expand Down Expand Up @@ -67,6 +96,20 @@ type JWTAuth struct {
var _ domain.Authenticator = (*JWTAuth)(nil)

func NewJWTAuth(config JWTAuthConfig) (*JWTAuth, error) {
// Parse signature algorithm
algo, err := parseSignatureAlgorithm(config.SignatureAlgorithm)
if err != nil {
slog.With(constants.ErrKey, err).Error("invalid JWT signature algorithm")
return nil, err
}

// Log algorithm selection (especially if non-default)
if config.SignatureAlgorithm != "" && config.SignatureAlgorithm != "PS256" {
slog.Info("using non-default JWT signature algorithm",
"algorithm", config.SignatureAlgorithm,
)
}

// Set up defaults if not provided
jwksURLStr := config.JWKSURL
if jwksURLStr == "" {
Expand All @@ -92,10 +135,10 @@ func NewJWTAuth(config JWTAuthConfig) (*JWTAuth, error) {
}
provider := jwks.NewCachingProvider(issuer, 5*time.Minute, jwks.WithCustomJWKSURI(jwksURL))

// Set up the JWT validator.
// Set up the JWT validator with selected algorithm.
jwtValidator, err := validator.New(
provider.KeyFunc,
signatureAlgorithm,
algo,
issuer.String(),
[]string{audience},
validator.WithCustomClaims(customClaims),
Expand Down
86 changes: 85 additions & 1 deletion internal/infrastructure/auth/jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ func TestJWTAuth_Constants(t *testing.T) {
assert.Equal(t, "heimdall", defaultIssuer)
assert.Equal(t, "lfx-v2-project-service", defaultAudience)
assert.Equal(t, "http://heimdall:4457/.well-known/jwks", defaultJWKSURL)
assert.NotNil(t, signatureAlgorithm)
assert.NotNil(t, defaultSignatureAlgorithm)
})
}

Expand Down Expand Up @@ -307,6 +307,38 @@ func TestJWTAuth_ConfigurationHandling(t *testing.T) {
shouldError: false,
description: "should accept mock principal",
},
{
name: "custom signature algorithm ES256",
config: JWTAuthConfig{
SignatureAlgorithm: "ES256",
},
shouldError: false,
description: "should accept valid signature algorithm",
},
{
name: "custom signature algorithm RS256",
config: JWTAuthConfig{
SignatureAlgorithm: "RS256",
},
shouldError: false,
description: "should accept RS256 signature algorithm",
},
{
name: "invalid signature algorithm",
config: JWTAuthConfig{
SignatureAlgorithm: "INVALID",
},
shouldError: true,
description: "should reject invalid signature algorithm",
},
{
name: "lowercase signature algorithm rejected",
config: JWTAuthConfig{
SignatureAlgorithm: "ps256",
},
shouldError: true,
description: "should reject lowercase signature algorithm",
},
}

for _, tt := range tests {
Expand All @@ -326,3 +358,55 @@ func TestJWTAuth_ConfigurationHandling(t *testing.T) {
})
}
}

func TestParseSignatureAlgorithm(t *testing.T) {
tests := []struct {
name string
algorithm string
wantErr bool
}{
// Valid algorithms - PS family
{name: "PS256 valid", algorithm: "PS256", wantErr: false},
{name: "PS384 valid", algorithm: "PS384", wantErr: false},
{name: "PS512 valid", algorithm: "PS512", wantErr: false},

// Valid algorithms - RS family
{name: "RS256 valid", algorithm: "RS256", wantErr: false},
{name: "RS384 valid", algorithm: "RS384", wantErr: false},
{name: "RS512 valid", algorithm: "RS512", wantErr: false},

// Valid algorithms - ES family
{name: "ES256 valid", algorithm: "ES256", wantErr: false},
{name: "ES384 valid", algorithm: "ES384", wantErr: false},
{name: "ES512 valid", algorithm: "ES512", wantErr: false},

// Empty string uses default
{name: "empty defaults to PS256", algorithm: "", wantErr: false},

// Invalid - case sensitivity
{name: "lowercase rejected", algorithm: "ps256", wantErr: true},
{name: "mixed case rejected", algorithm: "Ps256", wantErr: true},

// Invalid - HMAC algorithms not supported
{name: "HS256 unsupported", algorithm: "HS256", wantErr: true},
{name: "HS384 unsupported", algorithm: "HS384", wantErr: true},
{name: "HS512 unsupported", algorithm: "HS512", wantErr: true},

// Invalid - unknown algorithms
{name: "unknown algorithm", algorithm: "UNKNOWN", wantErr: true},
{name: "typo", algorithm: "PS265", wantErr: true},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
algo, err := parseSignatureAlgorithm(tt.algorithm)
if tt.wantErr {
assert.Error(t, err, "expected error for algorithm %q", tt.algorithm)
assert.Empty(t, algo, "expected empty algorithm for %q", tt.algorithm)
} else {
assert.NoError(t, err, "unexpected error for algorithm %q", tt.algorithm)
assert.NotEmpty(t, algo, "expected valid algorithm for %q", tt.algorithm)
}
})
}
}
Loading