diff --git a/util/pkg/vfs/context.go b/util/pkg/vfs/context.go index 1275effe3dd7a..b555b6e281dce 100644 --- a/util/pkg/vfs/context.go +++ b/util/pkg/vfs/context.go @@ -28,8 +28,10 @@ import ( "time" "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob" + "github.com/aws/aws-sdk-go-v2/aws" awsconfig "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/feature/ec2/imds" + "github.com/aws/aws-sdk-go-v2/service/s3" "github.com/gophercloud/gophercloud/v2" "google.golang.org/api/option" storage "google.golang.org/api/storage/v1" @@ -329,6 +331,8 @@ func RetryWithBackoff(backoff wait.Backoff, condition func() (bool, error)) (boo } func (c *VFSContext) buildS3Path(p string) (*S3Path, error) { + endpoint := os.Getenv("S3_ENDPOINT") + u, err := url.Parse(p) if err != nil { return nil, fmt.Errorf("invalid s3 path: %q", p) @@ -342,11 +346,24 @@ func (c *VFSContext) buildS3Path(p string) (*S3Path, error) { return nil, fmt.Errorf("invalid s3 path: %q", p) } - s3path := newS3Path(c.s3Context, u.Scheme, bucket, u.Path, true) + s3path := newS3Path(c.s3Context, u.Scheme, bucket, u.Path, true, func(o *s3.Options) { + if endpoint != "" { + o.BaseEndpoint = aws.String(endpoint) + o.UsePathStyle = true + o.DisableLogOutputChecksumValidationSkipped = true + } else { + o.EndpointResolverV2 = &ResolverV2{} + } + }) return s3path, nil } func (c *VFSContext) buildDOPath(p string) (*S3Path, error) { + endpoint := os.Getenv("S3_ENDPOINT") + if endpoint == "" { + return nil, fmt.Errorf("required S3_ENDPOINT env var for path: %q", p) + } + u, err := url.Parse(p) if err != nil { return nil, fmt.Errorf("invalid spaces path: %q", p) @@ -360,11 +377,20 @@ func (c *VFSContext) buildDOPath(p string) (*S3Path, error) { return nil, fmt.Errorf("invalid spaces path: %q", p) } - s3path := newS3Path(c.s3Context, u.Scheme, bucket, u.Path, false) + s3path := newS3Path(c.s3Context, u.Scheme, bucket, u.Path, false, func(o *s3.Options) { + o.BaseEndpoint = aws.String(endpoint) + o.UsePathStyle = 
true + o.DisableLogOutputChecksumValidationSkipped = true + }) return s3path, nil } func (c *VFSContext) buildLinodePath(p string) (*S3Path, error) { + endpoint := os.Getenv("S3_ENDPOINT") + if endpoint == "" { + return nil, fmt.Errorf("required S3_ENDPOINT env var for path: %q", p) + } + u, err := url.Parse(p) if err != nil { return nil, fmt.Errorf("invalid Linode object storage path: %q", p) @@ -378,11 +404,23 @@ func (c *VFSContext) buildLinodePath(p string) (*S3Path, error) { return nil, fmt.Errorf("invalid Linode object storage path: %q", p) } - s3path := newS3Path(c.s3Context, u.Scheme, bucket, u.Path, false) + s3path := newS3Path(c.s3Context, u.Scheme, bucket, u.Path, false, func(o *s3.Options) { + o.BaseEndpoint = aws.String(endpoint) + o.UsePathStyle = true + o.DisableLogOutputChecksumValidationSkipped = true + // Linode (Akamai) requires checksum-when-required behavior + o.RequestChecksumCalculation = aws.RequestChecksumCalculationWhenRequired + o.ResponseChecksumValidation = aws.ResponseChecksumValidationWhenRequired + }) return s3path, nil } func (c *VFSContext) buildHetznerPath(p string) (*S3Path, error) { + endpoint := os.Getenv("S3_ENDPOINT") + if endpoint == "" { + return nil, fmt.Errorf("required S3_ENDPOINT env var for path: %q", p) + } + u, err := url.Parse(p) if err != nil { return nil, fmt.Errorf("invalid Hetzner Object Storage path: %q", p) @@ -396,7 +434,11 @@ func (c *VFSContext) buildHetznerPath(p string) (*S3Path, error) { return nil, fmt.Errorf("invalid Hetzner object storage path: %q", p) } - s3path := newS3Path(c.s3Context, u.Scheme, bucket, u.Path, false) + s3path := newS3Path(c.s3Context, u.Scheme, bucket, u.Path, false, func(o *s3.Options) { + o.BaseEndpoint = aws.String(endpoint) + o.UsePathStyle = true + o.DisableLogOutputChecksumValidationSkipped = true + }) return s3path, nil } @@ -546,6 +588,11 @@ func (c *VFSContext) getAzureBlobClient(ctx context.Context) (*azblob.Client, er } func (c *VFSContext) buildSCWPath(p string) 
(*S3Path, error) { + endpoint := os.Getenv("S3_ENDPOINT") + if endpoint == "" { + return nil, fmt.Errorf("required S3_ENDPOINT env var for path: %q", p) + } + u, err := url.Parse(p) if err != nil { return nil, fmt.Errorf("invalid bucket path: %q", p) @@ -559,6 +606,10 @@ func (c *VFSContext) buildSCWPath(p string) (*S3Path, error) { return nil, fmt.Errorf("invalid bucket path: %q", p) } - s3path := newS3Path(c.s3Context, u.Scheme, bucket, u.Path, false) + s3path := newS3Path(c.s3Context, u.Scheme, bucket, u.Path, false, func(o *s3.Options) { + o.BaseEndpoint = aws.String(endpoint) + o.UsePathStyle = true + o.DisableLogOutputChecksumValidationSkipped = true + }) return s3path, nil } diff --git a/util/pkg/vfs/s3context.go b/util/pkg/vfs/s3context.go index 6678e7f37f8ef..4593cbed9e6d4 100644 --- a/util/pkg/vfs/s3context.go +++ b/util/pkg/vfs/s3context.go @@ -84,50 +84,40 @@ func (*ResolverV2) ResolveEndpoint(ctx context.Context, params s3.EndpointParame return s3.NewDefaultEndpointResolverV2().ResolveEndpoint(ctx, params) } -func (s *S3Context) getClient(ctx context.Context, region string, scheme string) (*s3.Client, error) { +func (s *S3Context) getClient(ctx context.Context, region string, optFn func(*s3.Options)) (*s3.Client, error) { s.mutex.Lock() defer s.mutex.Unlock() - // Client configuration is currently determined by region and process-wide environment. 
- s3Client := s.clients[region] - if s3Client == nil { - _, span := tracer.Start(ctx, "S3Context::getClient") - defer span.End() - - var config aws.Config - var err error - endpoint := os.Getenv("S3_ENDPOINT") - if endpoint == "" { - config, err = awsconfig.LoadDefaultConfig(ctx, awsconfig.WithRegion(region)) - if err != nil { - return nil, fmt.Errorf("error loading AWS config: %v", err) - } - } else { - // Use customized S3 storage - klog.V(2).Infof("Found S3_ENDPOINT=%q, using as non-AWS S3 backend", endpoint) - config, err = getCustomS3Config(ctx, region) - if err != nil { - return nil, err - } - } + if s3Client := s.clients[region]; s3Client != nil { + return s3Client, nil + } - s3Client = s3.NewFromConfig(config, func(o *s3.Options) { - if endpoint != "" { - o.BaseEndpoint = aws.String(endpoint) - o.UsePathStyle = true - o.DisableLogOutputChecksumValidationSkipped = true - // Linode (Akamai) requires checksum-when-required behavior - if scheme == "linode" { - o.RequestChecksumCalculation = aws.RequestChecksumCalculationWhenRequired - o.ResponseChecksumValidation = aws.ResponseChecksumValidationWhenRequired - } - } else { - o.EndpointResolverV2 = &ResolverV2{} - } - }) - s.clients[region] = s3Client + // Clients are cached by region only; configuration comes from the process-wide environment and optFn. + // optFn is applied only when the first request for a region builds the shared client; later calls reuse it as-is. 
+ _, span := tracer.Start(ctx, "S3Context::getClient") + defer span.End() + + var config aws.Config + var err error + endpoint := os.Getenv("S3_ENDPOINT") + if endpoint == "" { + config, err = awsconfig.LoadDefaultConfig(ctx, awsconfig.WithRegion(region)) + if err != nil { + return nil, fmt.Errorf("error loading AWS config: %v", err) + } + } else { + // Use customized S3 storage + klog.V(2).Infof("Found S3_ENDPOINT=%q, using as non-AWS S3 backend", endpoint) + config, err = getCustomS3Config(ctx, region) + if err != nil { + return nil, err + } } + s3Client := s3.NewFromConfig(config, optFn) + + s.clients[region] = s3Client + return s3Client, nil } @@ -204,7 +194,9 @@ func (s *S3Context) getDetailsForBucket(ctx context.Context, bucket string) (*S3 } var response *s3.GetBucketLocationOutput - s3Client, err := s.getClient(ctx, awsRegion, "s3") + s3Client, err := s.getClient(ctx, awsRegion, func(o *s3.Options) { + o.EndpointResolverV2 = &ResolverV2{} + }) if err != nil { return bucketDetails, fmt.Errorf("error connecting to S3: %s", err) } @@ -241,7 +233,7 @@ func (s *S3Context) getDetailsForBucket(ctx context.Context, bucket string) (*S3 return bucketDetails, nil } -func (b *S3BucketDetails) hasServerSideEncryptionByDefault(ctx context.Context, scheme string) bool { +func (b *S3BucketDetails) hasServerSideEncryptionByDefault(ctx context.Context, optFn func(*s3.Options)) bool { b.mutex.Lock() defer b.mutex.Unlock() @@ -257,7 +249,7 @@ func (b *S3BucketDetails) hasServerSideEncryptionByDefault(ctx context.Context, // We only make one attempt to find the SSE policy (even if there's an error) b.applyServerSideEncryptionByDefault = &applyServerSideEncryptionByDefault - client, err := b.context.getClient(ctx, b.region, scheme) + client, err := b.context.getClient(ctx, b.region, optFn) if err != nil { klog.Warningf("Unable to read bucket encryption policy for %q in region %q: will encrypt using AES256", b.name, b.region) return false diff --git a/util/pkg/vfs/s3fs.go 
b/util/pkg/vfs/s3fs.go index 1f91621757aff..723e24f75b1dc 100644 --- a/util/pkg/vfs/s3fs.go +++ b/util/pkg/vfs/s3fs.go @@ -51,6 +51,8 @@ type S3Path struct { scheme string // sse specifies if server side encryption should be enabled sse bool + // optFn sets provider-specific s3.Options; it is passed to S3Context.getClient when a client is requested. + optFn func(*s3.Options) } var ( @@ -64,7 +66,7 @@ type S3Acl struct { RequestACL *types.ObjectCannedACL } -func newS3Path(s3Context *S3Context, scheme string, bucket string, key string, sse bool) *S3Path { +func newS3Path(s3Context *S3Context, scheme string, bucket string, key string, sse bool, optFn func(*s3.Options)) *S3Path { bucket = strings.TrimSuffix(bucket, "/") key = strings.TrimPrefix(key, "/") @@ -74,6 +76,7 @@ func newS3Path(s3Context *S3Context, scheme string, bucket string, key string, s key: key, scheme: scheme, sse: sse, + optFn: optFn, } } @@ -267,6 +270,7 @@ func (p *S3Path) Join(relativePath ...string) Path { key: joined, scheme: p.scheme, sse: p.sse, + optFn: p.optFn, } } @@ -282,7 +286,7 @@ func (p *S3Path) getServerSideEncryption(ctx context.Context) (sse types.ServerS if err != nil { return "", "", err } - defaultEncryption := bucketDetails.hasServerSideEncryptionByDefault(ctx, p.scheme) + defaultEncryption := bucketDetails.hasServerSideEncryptionByDefault(ctx, p.optFn) if defaultEncryption { sseLog = "DefaultBucketEncryption" } else { @@ -465,6 +469,7 @@ func (p *S3Path) ReadDir() ([]Path, error) { etag: o.ETag, scheme: p.scheme, sse: p.sse, + optFn: p.optFn, } paths = append(paths, child) } @@ -507,6 +512,7 @@ func (p *S3Path) ReadTree(ctx context.Context) ([]Path, error) { etag: o.ETag, scheme: p.scheme, sse: p.sse, + optFn: p.optFn, } paths = append(paths, child) } @@ -529,7 +535,7 @@ func (p *S3Path) client(ctx context.Context) (*s3.Client, error) { return nil, err } - client, err := p.s3Context.getClient(ctx, bucketDetails.region, p.scheme) + client, err := p.s3Context.getClient(ctx, bucketDetails.region, p.optFn) if err != nil { return 
nil, err } diff --git a/util/pkg/vfs/s3fs_test.go b/util/pkg/vfs/s3fs_test.go index 884bb181caa88..898001cdba693 100644 --- a/util/pkg/vfs/s3fs_test.go +++ b/util/pkg/vfs/s3fs_test.go @@ -96,6 +96,7 @@ func Test_LinodePath_Parse(t *testing.T) { }, } for _, g := range grid { + t.Setenv("S3_ENDPOINT", "https://example.com") s3path, err := Context.buildLinodePath(g.Input) if !g.ExpectError { if err != nil { @@ -132,6 +133,7 @@ func Test_NonLinodeObjectStoragePaths_HaveCorrectScheme(t *testing.T) { for _, tc := range grid { t.Run(tc.name, func(t *testing.T) { + t.Setenv("S3_ENDPOINT", "https://example.com") s3path, err := tc.build(tc.input) if err != nil { t.Fatalf("unexpected error parsing %s path: %v", tc.name, err)