Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 56 additions & 5 deletions util/pkg/vfs/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@ import (
"time"

"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob"
"github.com/aws/aws-sdk-go-v2/aws"
awsconfig "github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/gophercloud/gophercloud/v2"
"google.golang.org/api/option"
storage "google.golang.org/api/storage/v1"
Expand Down Expand Up @@ -329,6 +331,8 @@ func RetryWithBackoff(backoff wait.Backoff, condition func() (bool, error)) (boo
}

func (c *VFSContext) buildS3Path(p string) (*S3Path, error) {
endpoint := os.Getenv("S3_ENDPOINT")

u, err := url.Parse(p)
if err != nil {
return nil, fmt.Errorf("invalid s3 path: %q", p)
Expand All @@ -342,11 +346,24 @@ func (c *VFSContext) buildS3Path(p string) (*S3Path, error) {
return nil, fmt.Errorf("invalid s3 path: %q", p)
}

s3path := newS3Path(c.s3Context, u.Scheme, bucket, u.Path, true)
s3path := newS3Path(c.s3Context, u.Scheme, bucket, u.Path, true, func(o *s3.Options) {
if endpoint != "" {
o.BaseEndpoint = aws.String(endpoint)
o.UsePathStyle = true
o.DisableLogOutputChecksumValidationSkipped = true
} else {
o.EndpointResolverV2 = &ResolverV2{}
}
})
return s3path, nil
}

func (c *VFSContext) buildDOPath(p string) (*S3Path, error) {
endpoint := os.Getenv("S3_ENDPOINT")
if endpoint == "" {
return nil, fmt.Errorf("required S3_ENDPOINT env var for path: %q", p)
}

u, err := url.Parse(p)
if err != nil {
return nil, fmt.Errorf("invalid spaces path: %q", p)
Expand All @@ -360,11 +377,20 @@ func (c *VFSContext) buildDOPath(p string) (*S3Path, error) {
return nil, fmt.Errorf("invalid spaces path: %q", p)
}

s3path := newS3Path(c.s3Context, u.Scheme, bucket, u.Path, false)
s3path := newS3Path(c.s3Context, u.Scheme, bucket, u.Path, false, func(o *s3.Options) {
o.BaseEndpoint = aws.String(endpoint)
o.UsePathStyle = true
o.DisableLogOutputChecksumValidationSkipped = true
})
return s3path, nil
}

func (c *VFSContext) buildLinodePath(p string) (*S3Path, error) {
endpoint := os.Getenv("S3_ENDPOINT")
if endpoint == "" {
return nil, fmt.Errorf("required S3_ENDPOINT env var for path: %q", p)
}

u, err := url.Parse(p)
if err != nil {
return nil, fmt.Errorf("invalid Linode object storage path: %q", p)
Expand All @@ -378,11 +404,23 @@ func (c *VFSContext) buildLinodePath(p string) (*S3Path, error) {
return nil, fmt.Errorf("invalid Linode object storage path: %q", p)
}

s3path := newS3Path(c.s3Context, u.Scheme, bucket, u.Path, false)
s3path := newS3Path(c.s3Context, u.Scheme, bucket, u.Path, false, func(o *s3.Options) {
o.BaseEndpoint = aws.String(endpoint)
o.UsePathStyle = true
o.DisableLogOutputChecksumValidationSkipped = true
// Linode (Akamai) requires checksum-when-required behavior
o.RequestChecksumCalculation = aws.RequestChecksumCalculationWhenRequired
o.ResponseChecksumValidation = aws.ResponseChecksumValidationWhenRequired
})
return s3path, nil
}

func (c *VFSContext) buildHetznerPath(p string) (*S3Path, error) {
endpoint := os.Getenv("S3_ENDPOINT")
if endpoint == "" {
return nil, fmt.Errorf("required S3_ENDPOINT env var for path: %q", p)
}

u, err := url.Parse(p)
if err != nil {
return nil, fmt.Errorf("invalid Hetzner Object Storage path: %q", p)
Expand All @@ -396,7 +434,11 @@ func (c *VFSContext) buildHetznerPath(p string) (*S3Path, error) {
return nil, fmt.Errorf("invalid Hetzner object storage path: %q", p)
}

s3path := newS3Path(c.s3Context, u.Scheme, bucket, u.Path, false)
s3path := newS3Path(c.s3Context, u.Scheme, bucket, u.Path, false, func(o *s3.Options) {
o.BaseEndpoint = aws.String(endpoint)
o.UsePathStyle = true
o.DisableLogOutputChecksumValidationSkipped = true
})
return s3path, nil
}

Expand Down Expand Up @@ -546,6 +588,11 @@ func (c *VFSContext) getAzureBlobClient(ctx context.Context) (*azblob.Client, er
}

func (c *VFSContext) buildSCWPath(p string) (*S3Path, error) {
endpoint := os.Getenv("S3_ENDPOINT")
if endpoint == "" {
return nil, fmt.Errorf("required S3_ENDPOINT env var for path: %q", p)
}

u, err := url.Parse(p)
if err != nil {
return nil, fmt.Errorf("invalid bucket path: %q", p)
Expand All @@ -559,6 +606,10 @@ func (c *VFSContext) buildSCWPath(p string) (*S3Path, error) {
return nil, fmt.Errorf("invalid bucket path: %q", p)
}

s3path := newS3Path(c.s3Context, u.Scheme, bucket, u.Path, false)
s3path := newS3Path(c.s3Context, u.Scheme, bucket, u.Path, false, func(o *s3.Options) {
o.BaseEndpoint = aws.String(endpoint)
o.UsePathStyle = true
o.DisableLogOutputChecksumValidationSkipped = true
})
return s3path, nil
}
74 changes: 33 additions & 41 deletions util/pkg/vfs/s3context.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,50 +84,40 @@ func (*ResolverV2) ResolveEndpoint(ctx context.Context, params s3.EndpointParame
return s3.NewDefaultEndpointResolverV2().ResolveEndpoint(ctx, params)
}

func (s *S3Context) getClient(ctx context.Context, region string, scheme string) (*s3.Client, error) {
func (s *S3Context) getClient(ctx context.Context, region string, optFn func(*s3.Options)) (*s3.Client, error) {
s.mutex.Lock()
defer s.mutex.Unlock()

// Client configuration is currently determined by region and process-wide environment.
s3Client := s.clients[region]
if s3Client == nil {
_, span := tracer.Start(ctx, "S3Context::getClient")
defer span.End()

var config aws.Config
var err error
endpoint := os.Getenv("S3_ENDPOINT")
if endpoint == "" {
config, err = awsconfig.LoadDefaultConfig(ctx, awsconfig.WithRegion(region))
if err != nil {
return nil, fmt.Errorf("error loading AWS config: %v", err)
}
} else {
// Use customized S3 storage
klog.V(2).Infof("Found S3_ENDPOINT=%q, using as non-AWS S3 backend", endpoint)
config, err = getCustomS3Config(ctx, region)
if err != nil {
return nil, err
}
}
if s3Client := s.clients[region]; s3Client != nil {
return s3Client, nil
}

s3Client = s3.NewFromConfig(config, func(o *s3.Options) {
if endpoint != "" {
o.BaseEndpoint = aws.String(endpoint)
o.UsePathStyle = true
o.DisableLogOutputChecksumValidationSkipped = true
// Linode (Akamai) requires checksum-when-required behavior
if scheme == "linode" {
o.RequestChecksumCalculation = aws.RequestChecksumCalculationWhenRequired
o.ResponseChecksumValidation = aws.ResponseChecksumValidationWhenRequired
}
} else {
o.EndpointResolverV2 = &ResolverV2{}
}
})
s.clients[region] = s3Client
// Client configuration is determined by region and process-wide environment.
// The first request for a region creates the shared client for that region.
_, span := tracer.Start(ctx, "S3Context::getClient")
defer span.End()

var config aws.Config
var err error
endpoint := os.Getenv("S3_ENDPOINT")
if endpoint == "" {
config, err = awsconfig.LoadDefaultConfig(ctx, awsconfig.WithRegion(region))
if err != nil {
return nil, fmt.Errorf("error loading AWS config: %v", err)
}
} else {
// Use customized S3 storage
klog.V(2).Infof("Found S3_ENDPOINT=%q, using as non-AWS S3 backend", endpoint)
config, err = getCustomS3Config(ctx, region)
if err != nil {
return nil, err
}
}

s3Client := s3.NewFromConfig(config, optFn)

s.clients[region] = s3Client

return s3Client, nil
}

Expand Down Expand Up @@ -204,7 +194,9 @@ func (s *S3Context) getDetailsForBucket(ctx context.Context, bucket string) (*S3
}
var response *s3.GetBucketLocationOutput

s3Client, err := s.getClient(ctx, awsRegion, "s3")
s3Client, err := s.getClient(ctx, awsRegion, func(o *s3.Options) {
o.EndpointResolverV2 = &ResolverV2{}
})
if err != nil {
return bucketDetails, fmt.Errorf("error connecting to S3: %s", err)
}
Expand Down Expand Up @@ -241,7 +233,7 @@ func (s *S3Context) getDetailsForBucket(ctx context.Context, bucket string) (*S3
return bucketDetails, nil
}

func (b *S3BucketDetails) hasServerSideEncryptionByDefault(ctx context.Context, scheme string) bool {
func (b *S3BucketDetails) hasServerSideEncryptionByDefault(ctx context.Context, optFn func(*s3.Options)) bool {
b.mutex.Lock()
defer b.mutex.Unlock()

Expand All @@ -257,7 +249,7 @@ func (b *S3BucketDetails) hasServerSideEncryptionByDefault(ctx context.Context,
// We only make one attempt to find the SSE policy (even if there's an error)
b.applyServerSideEncryptionByDefault = &applyServerSideEncryptionByDefault

client, err := b.context.getClient(ctx, b.region, scheme)
client, err := b.context.getClient(ctx, b.region, optFn)
if err != nil {
klog.Warningf("Unable to read bucket encryption policy for %q in region %q: will encrypt using AES256", b.name, b.region)
return false
Expand Down
12 changes: 9 additions & 3 deletions util/pkg/vfs/s3fs.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ type S3Path struct {
scheme string
// sse specifies if server side encryption should be enabled
sse bool
// optFn configures provider-specific S3 client options.
optFn func(*s3.Options)
}

var (
Expand All @@ -64,7 +66,7 @@ type S3Acl struct {
RequestACL *types.ObjectCannedACL
}

func newS3Path(s3Context *S3Context, scheme string, bucket string, key string, sse bool) *S3Path {
func newS3Path(s3Context *S3Context, scheme string, bucket string, key string, sse bool, optFn func(*s3.Options)) *S3Path {
bucket = strings.TrimSuffix(bucket, "/")
key = strings.TrimPrefix(key, "/")

Expand All @@ -74,6 +76,7 @@ func newS3Path(s3Context *S3Context, scheme string, bucket string, key string, s
key: key,
scheme: scheme,
sse: sse,
optFn: optFn,
}
}

Expand Down Expand Up @@ -267,6 +270,7 @@ func (p *S3Path) Join(relativePath ...string) Path {
key: joined,
scheme: p.scheme,
sse: p.sse,
optFn: p.optFn,
}
}

Expand All @@ -282,7 +286,7 @@ func (p *S3Path) getServerSideEncryption(ctx context.Context) (sse types.ServerS
if err != nil {
return "", "", err
}
defaultEncryption := bucketDetails.hasServerSideEncryptionByDefault(ctx, p.scheme)
defaultEncryption := bucketDetails.hasServerSideEncryptionByDefault(ctx, p.optFn)
if defaultEncryption {
sseLog = "DefaultBucketEncryption"
} else {
Expand Down Expand Up @@ -465,6 +469,7 @@ func (p *S3Path) ReadDir() ([]Path, error) {
etag: o.ETag,
scheme: p.scheme,
sse: p.sse,
optFn: p.optFn,
}
paths = append(paths, child)
}
Expand Down Expand Up @@ -507,6 +512,7 @@ func (p *S3Path) ReadTree(ctx context.Context) ([]Path, error) {
etag: o.ETag,
scheme: p.scheme,
sse: p.sse,
optFn: p.optFn,
}
paths = append(paths, child)
}
Expand All @@ -529,7 +535,7 @@ func (p *S3Path) client(ctx context.Context) (*s3.Client, error) {
return nil, err
}

client, err := p.s3Context.getClient(ctx, bucketDetails.region, p.scheme)
client, err := p.s3Context.getClient(ctx, bucketDetails.region, p.optFn)
if err != nil {
return nil, err
}
Expand Down
2 changes: 2 additions & 0 deletions util/pkg/vfs/s3fs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ func Test_LinodePath_Parse(t *testing.T) {
},
}
for _, g := range grid {
t.Setenv("S3_ENDPOINT", "https://example.com")
s3path, err := Context.buildLinodePath(g.Input)
if !g.ExpectError {
if err != nil {
Expand Down Expand Up @@ -132,6 +133,7 @@ func Test_NonLinodeObjectStoragePaths_HaveCorrectScheme(t *testing.T) {

for _, tc := range grid {
t.Run(tc.name, func(t *testing.T) {
t.Setenv("S3_ENDPOINT", "https://example.com")
s3path, err := tc.build(tc.input)
if err != nil {
t.Fatalf("unexpected error parsing %s path: %v", tc.name, err)
Expand Down
Loading