Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 152 additions & 1 deletion go/cmd/gitter/gitter.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,16 @@ const gitStoreFileName = "git-store"
var endpointHandlers = map[string]http.HandlerFunc{
"GET /git": gitHandler,
"POST /cache": cacheHandler,
"GET /tags": tagsHandler,
"POST /affected-commits": affectedCommitsHandler,
}

var (
gFetch singleflight.Group
gArchive singleflight.Group
gLoad singleflight.Group
gLsRemote singleflight.Group
gLocalTags singleflight.Group
persistencePath = filepath.Join(defaultGitterWorkDir, persistenceFileName)
gitStorePath = filepath.Join(defaultGitterWorkDir, gitStoreFileName)
fetchTimeout time.Duration
Expand All @@ -67,6 +70,11 @@ var (
repoCache *ristretto.Cache[string, *Repository]
repoTTL time.Duration
repoCacheMaxCostBytes int64
// Cache for invalid (does not exist, or does not have tags) repos
// Maps repo URL to the HTTP status code (404 or 204) to return
invalidRepoCache *ristretto.Cache[string, int]
invalidRepoTTL time.Duration
invalidRepoCacheMaxEntries int64
)

const shutdownTimeout = 10 * time.Second
Expand Down Expand Up @@ -164,6 +172,28 @@ func CloseRepoCache() {
}
}

// InitInvalidRepoCache initializes the cache for invalid repositories.
func InitInvalidRepoCache() {
var err error
invalidRepoCache, err = ristretto.NewCache(&ristretto.Config[string, int]{
NumCounters: invalidRepoCacheMaxEntries * 10,
MaxCost: invalidRepoCacheMaxEntries, // Cost for each entry is 1
BufferItems: 64,
// Check for TTL expiry every 60 seconds
TtlTickerDurationInSec: 60,
})
if err != nil {
logger.FatalContext(context.Background(), "Failed to initialize invalid repository cache", slog.Any("err", err))
}
}

// CloseInvalidRepoCache closes the cache for invalid repositories.
func CloseInvalidRepoCache() {
if invalidRepoCache != nil {
invalidRepoCache.Close()
}
}

// prepareCmd prepares the command with context cancellation handled by sending SIGINT.
func prepareCmd(ctx context.Context, dir string, env []string, name string, args ...string) *exec.Cmd {
cmd := exec.CommandContext(ctx, name, args...)
Expand Down Expand Up @@ -340,7 +370,7 @@ func getFreshRepo(ctx context.Context, w http.ResponseWriter, repoURL string, fo
return nil, err
}

repoAny, err, _ := gLoad.Do(repoPath, func() (any, error) {
repoAny, err, _ := gLoad.Do(repoURL, func() (any, error) {
repoLock := GetRepoLock(repoURL)
repoLock.RLock()
defer repoLock.RUnlock()
Expand Down Expand Up @@ -480,6 +510,8 @@ func main() {
concurrentLimit := flag.Int("concurrent-limit", 100, "Concurrent limit for unique requests")
flag.DurationVar(&repoTTL, "repo-cache-ttl", time.Hour, "Repository LRU cache time-to-live duration")
repoMaxCostStr := flag.String("repo-cache-max-cost", "1GiB", "Repository LRU cache max cost (in bytes)")
flag.DurationVar(&invalidRepoTTL, "invalid-repo-cache-ttl", time.Hour, "Invalid repository cache time-to-live duration")
flag.Int64Var(&invalidRepoCacheMaxEntries, "invalid-repo-cache-max-entries", 5000, "Invalid repository cache max entries")
flag.Parse()
semaphore = make(chan struct{}, *concurrentLimit)

Expand All @@ -501,6 +533,8 @@ func main() {
loadLastFetchMap()
InitRepoCache()
defer CloseRepoCache()
InitInvalidRepoCache()
defer CloseInvalidRepoCache()

// Create a context that listens for the interrupt signal from the OS.
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
Expand Down Expand Up @@ -749,3 +783,120 @@ func affectedCommitsHandler(w http.ResponseWriter, r *http.Request) {
}
logger.InfoContext(ctx, "Request completed successfully: /affected-commits", slog.Duration("duration", time.Since(start)))
}

func makeTagsResponse(tagsMap map[string]SHA1) *pb.TagsResponse {
resp := &pb.TagsResponse{
Tags: make([]*pb.Ref, 0, len(tagsMap)),
}
for tag, hash := range tagsMap {
resp.Tags = append(resp.Tags, &pb.Ref{
Label: tag,
Hash: hash[:],
})
}

return resp
}

func tagsHandler(w http.ResponseWriter, r *http.Request) {
start := time.Now()

repoURL, err := prepareURL(r, r.URL.Query().Get("url"))
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}

ctx := context.WithValue(r.Context(), urlKey, repoURL)
logger.InfoContext(ctx, "Received request: /tags")

// Previously cached invalid repo (does not exist or does not have tags)
// Get() will not return if the entry is past its TTL, so we can safely return the same http status code as is.
if code, found := invalidRepoCache.Get(repoURL); found {
logger.InfoContext(ctx, "Invalid repo cache hit", slog.Int("code", code))
w.WriteHeader(code)

return
}

var tagsMap map[string]SHA1

// If repository is recently loaded, we can return the tags directly from the cached repo
if cachedRepo, found := repoCache.Get(repoURL); found {
logger.InfoContext(ctx, "Repo cache hit, returning cached tags")
tagsMap = make(map[string]SHA1)
for tag, idx := range cachedRepo.tagToCommit {
tagsMap[tag] = cachedRepo.commits[idx].Hash
}
} else {
repo := NewRepository(repoURL)

// If repoPath is not empty, it means there is a local git directory for this repo on disk
// We want to use show-ref instead of ls-remote because it's faster and we don't have to worry about rate limits
if repo.repoPath != "" {
logger.DebugContext(ctx, "Local repo found, using show-ref")
if _, errFetch, _ := gFetch.Do(repoURL, func() (any, error) {
return nil, FetchRepo(ctx, repoURL, false)
}); errFetch != nil {
logger.ErrorContext(ctx, "Error fetching repo", slog.Any("error", errFetch))
http.Error(w, "Error fetching repository", http.StatusInternalServerError)

return
}

tagsMapAny, errLocal, _ := gLocalTags.Do(repoURL, func() (any, error) {
return repo.GetLocalTags(ctx)
})
if errLocal != nil {
logger.ErrorContext(ctx, "Error parsing local tags", slog.Any("error", errLocal))
http.Error(w, "Error parsing local tags", http.StatusInternalServerError)

return
}
tagsMap = tagsMapAny.(map[string]SHA1)
} else {
// If repo is not on disk, we use ls-remote to get the tags instead
logger.DebugContext(ctx, "Local repo not found, using ls-remote")
tagsMapAny, errLsRemote, _ := gLsRemote.Do(repoURL, func() (any, error) {
return repo.GetRemoteTags(ctx)
})
if errLsRemote != nil {
if isAuthError(errLsRemote) {
invalidRepoCache.SetWithTTL(repoURL, http.StatusNotFound, 1, invalidRepoTTL)
http.Error(w, "Repository not found", http.StatusNotFound)

return
}
logger.ErrorContext(ctx, "Error running git ls-remote", slog.Any("error", errLsRemote))
http.Error(w, "Error listing remote tags", http.StatusInternalServerError)

return
}
tagsMap = tagsMapAny.(map[string]SHA1)
}
}

if len(tagsMap) == 0 {
logger.InfoContext(ctx, "No tags in repository")
invalidRepoCache.SetWithTTL(repoURL, http.StatusNoContent, 1, invalidRepoTTL)
w.WriteHeader(http.StatusNoContent)

return
}

resp := makeTagsResponse(tagsMap)
out, err := marshalResponse(r, resp)
if err != nil {
logger.ErrorContext(ctx, "Error marshaling tags response", slog.Any("error", err))
http.Error(w, fmt.Sprintf("Error marshaling tags response: %v", err), http.StatusInternalServerError)

return
}

w.Header().Set("Content-Type", r.Header.Get("Content-Type"))
w.WriteHeader(http.StatusOK)
if _, err := w.Write(out); err != nil {
logger.ErrorContext(ctx, "Error writing tags response", slog.Any("error", err))
}
logger.InfoContext(ctx, "Request completed successfully: /tags", slog.Duration("duration", time.Since(start)))
}
99 changes: 85 additions & 14 deletions go/cmd/gitter/gitter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (

"github.com/google/go-cmp/cmp"
pb "github.com/google/osv.dev/go/cmd/gitter/pb/repository"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
)

Expand Down Expand Up @@ -197,7 +196,17 @@ func setupTest(t *testing.T) {
// Initialize semaphore for tests
semaphore = make(chan struct{}, 100)

t.Cleanup(resetSaveTimer)
// Initialize caches for tests
repoCacheMaxCostBytes = 1024 * 1024 // 1MB for test
invalidRepoCacheMaxEntries = 100
InitRepoCache()
InitInvalidRepoCache()

t.Cleanup(func() {
resetSaveTimer()
CloseRepoCache()
CloseInvalidRepoCache()
})
}

func TestGitHandler_Integration(t *testing.T) {
Expand Down Expand Up @@ -268,9 +277,8 @@ func TestCacheHandler(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
reqProto := &pb.CacheRequest{Url: tt.url}
body, _ := protojson.Marshal(reqProto)
body, _ := proto.Marshal(reqProto)
req, err := http.NewRequest(http.MethodPost, "/cache", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -357,9 +365,8 @@ func TestAffectedCommitsHandler(t *testing.T) {
Events: events,
}

body, _ := protojson.Marshal(reqProto)
body, _ := proto.Marshal(reqProto)
req, err := http.NewRequest(http.MethodPost, "/affected-commits", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
if err != nil {
t.Fatal(err)
}
Expand All @@ -376,14 +383,8 @@ func TestAffectedCommitsHandler(t *testing.T) {
}

respBody := &pb.AffectedCommitsResponse{}
if rr.Header().Get("Content-Type") == "application/json" {
if err := protojson.Unmarshal(rr.Body.Bytes(), respBody); err != nil {
t.Fatalf("Failed to unmarshal JSON response: %v", err)
}
} else {
if err := proto.Unmarshal(rr.Body.Bytes(), respBody); err != nil {
t.Fatalf("Failed to unmarshal proto response: %v", err)
}
if err := proto.Unmarshal(rr.Body.Bytes(), respBody); err != nil {
t.Fatalf("Failed to unmarshal proto response: %v", err)
}

var gotHashes []string
Expand All @@ -400,3 +401,73 @@ func TestAffectedCommitsHandler(t *testing.T) {
})
}
}

func TestTagsHandler(t *testing.T) {
setupTest(t)

tests := []struct {
name string
url string
expectedCode int
expectedTags map[string]string
}{
{
name: "Valid repo with tags",
url: "https://github.com/oliverchang/osv-test.git",
expectedCode: http.StatusOK,
expectedTags: map[string]string{
"v0.2": "8d8242f545e9cec3e6d0d2e3f5bde8be1c659735",
"branch-v0.1.1": "4c155795426727ea05575bd5904321def23c03f4",
"branch-v0.1.1-with-fix": "b9b3fd4732695b83c3068b7b6a14bb372ec31f98",
"branch_1_cherrypick_regress": "febfac1940086bc1f6d3dc33fda0a1d1ba336209",
"v0.1": "a2ba949290915d445d34d0e8e9de2e7ce38198fc",
"v0.1.1": "b1c95a196f22d06fcf80df8c6691cd113d8fefff",
},
},
{
name: "Repo exist but no tags",
// This repo hasn't gotten a commit in 8 years so should be fairly stable for our testing.
url: "https://github.com/torvalds/test-tlb.git",
expectedCode: http.StatusNoContent,
expectedTags: nil,
},
{
name: "Non-existent repo",
url: "https://github.com/google/this-repo-does-not-exist-12345.git",
expectedCode: http.StatusNotFound,
expectedTags: nil,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, "/tags?url="+tt.url, nil)
if err != nil {
t.Fatal(err)
}
rr := httptest.NewRecorder()
tagsHandler(rr, req)

if status := rr.Code; status != tt.expectedCode {
t.Errorf("handler returned wrong status code: got %v want %v",
status, tt.expectedCode)
}

if tt.expectedTags != nil {
respBody := &pb.TagsResponse{}
if err := proto.Unmarshal(rr.Body.Bytes(), respBody); err != nil {
t.Fatalf("Failed to unmarshal proto response: %v", err)
}

gotTags := make(map[string]string)
for _, ref := range respBody.GetTags() {
gotTags[ref.GetLabel()] = hex.EncodeToString(ref.GetHash())
}

if diff := cmp.Diff(tt.expectedTags, gotTags); diff != "" {
t.Errorf("handler returned wrong tags (-want +got):\n%s", diff)
}
}
})
}
}
Loading