Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
242 changes: 96 additions & 146 deletions common/websocket/update_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,11 @@
package websocket

import (
"archive/zip"
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"os/exec"
"path/filepath"
"strings"
"sync"
Expand All @@ -40,27 +38,22 @@
// ---------------------------------------------------------------------------

const (
defaultGitHubRepo = "Tencent/AI-Infra-Guard"
defaultGitHubRepo = "https://github.com/Tencent/AI-Infra-Guard.git"
defaultGitHubBranch = "main"
githubZipURLFmt = "https://codeload.github.com/%s/zip/refs/heads/%s"
githubTagZipURLFmt = "https://codeload.github.com/%s/zip/refs/tags/%s"

// dataDirs lists the sub-directories inside data/ that are synced.
// Callers may override via UpdateDataRequest.Dirs.
// dataDirsDefault lists the sub-directories inside data/ that are synced by default.
dataDirsDefault = "fingerprints,vuln,vuln_en,mcp,eval,agents"
)

// UpdateStatus holds the current state of a data-sync operation.
type UpdateStatus struct {
Running bool `json:"running"`
Success *bool `json:"success,omitempty"`
StartedAt time.Time `json:"started_at,omitempty"`
FinishedAt *time.Time `json:"finished_at,omitempty"`
Message string `json:"message"`
// FilesUpdated is the number of files written to disk.
FilesUpdated int `json:"files_updated"`
// Ref is the branch or tag that was used.
Ref string `json:"ref,omitempty"`
Running bool `json:"running"`
Success *bool `json:"success,omitempty"`
StartedAt time.Time `json:"started_at,omitempty"`
FinishedAt *time.Time `json:"finished_at,omitempty"`
Message string `json:"message"`
FilesUpdated int `json:"files_updated"`
Ref string `json:"ref,omitempty"`
}

var (
Expand All @@ -75,16 +68,14 @@
// UpdateDataRequest is the JSON body for POST /api/v1/system/update-data.
//
// {
// "ref": "main", // branch or tag, default: "main"
// "is_tag": false, // set true when ref is a tag
// "github_token": "", // optional, avoids GitHub rate-limit (60 req/h anon)
// "dirs": "fingerprints,vuln,vuln_en,mcp,eval,agents" // optional
// "ref": "main", // branch name or tag, default: "main"
// "is_tag": false, // set true when ref is a Git tag (e.g. "v4.1.3")
// "dirs": "fingerprints,vuln,vuln_en,mcp,eval,agents" // optional
// }
type UpdateDataRequest struct {
Ref string `json:"ref"`
IsTag bool `json:"is_tag"`
GithubToken string `json:"github_token"`
Dirs string `json:"dirs"`
Ref string `json:"ref"`
IsTag bool `json:"is_tag"`
Dirs string `json:"dirs"`
}

// ---------------------------------------------------------------------------
Expand All @@ -109,10 +100,11 @@
// HandleTriggerDataUpdate godoc
//
// @Summary Trigger data directory sync from GitHub
// @Description Downloads the repository archive from GitHub and overwrites the local
// @Description data/ sub-directories (fingerprints, vuln, vuln_en, mcp, eval, agents).
// @Description Clones the repository into a temporary directory and copies the requested
// @Description data/ sub-directories (fingerprints, vuln, vuln_en, mcp, eval, agents)
// @Description to the working directory. No GitHub token is required.
// @Description The operation runs asynchronously; poll GET /api/v1/system/update-status
// @Description for progress. Only one sync may run at a time.
// @Description for progress. Only one sync may run at a time.
// @Tags system
// @Accept json
// @Produce json
Expand Down Expand Up @@ -143,7 +135,7 @@
updateStatus = &UpdateStatus{
Running: true,
StartedAt: time.Now(),
Message: "downloading archive from GitHub…",
Message: "cloning repository…",
Ref: req.Ref,
}
updateMu.Unlock()
Expand Down Expand Up @@ -180,145 +172,103 @@
updateMu.Unlock()
}

// 1. Build download URL
var downloadURL string
if req.IsTag {
downloadURL = fmt.Sprintf(githubTagZipURLFmt, defaultGitHubRepo, req.Ref)
} else {
downloadURL = fmt.Sprintf(githubZipURLFmt, defaultGitHubRepo, req.Ref)
}

// 2. Download archive
setStatus(fmt.Sprintf("downloading %s …", downloadURL), 0)
body, err := downloadArchive(downloadURL, req.GithubToken)
// 1. Create a temporary directory for the clone.
tmpDir, err := os.MkdirTemp("", "aig-data-sync-*")
if err != nil {
finish(false, fmt.Sprintf("download failed: %v", err), 0)
finish(false, fmt.Sprintf("failed to create temp dir: %v", err), 0)
return
}
defer os.RemoveAll(tmpDir)

// 2. git clone --depth 1 --branch <ref> <repo> <tmpDir>
setStatus(fmt.Sprintf("git clone --depth 1 --branch %s …", req.Ref), 0)
cloneArgs := []string{
"clone", "--depth", "1",
"--branch", req.Ref,
defaultGitHubRepo,
Comment on lines +235 to +237
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Honor is_tag when resolving clone target

The API still accepts is_tag, but runDataUpdate always executes git clone --branch <ref> and never checks req.IsTag. When a branch and tag share the same name, Git resolves --branch to the branch, so callers requesting a tag can silently sync the wrong revision. The prior implementation used separate heads/tags URLs and respected this flag.

Useful? React with 👍 / 👎.

tmpDir,
}
cloneCmd := exec.Command("git", cloneArgs...) // #nosec G204 — args are not user-controlled paths

Check failure

Code scanning / CodeQL

Command built from user-controlled sources Critical

This command depends on a
user-provided value
.
Comment thread
github-advanced-security[bot] marked this conversation as resolved.
Fixed
cloneCmd.Env = append(os.Environ(), "GIT_TERMINAL_PROMPT=0")
if out, err := cloneCmd.CombinedOutput(); err != nil {
finish(false, fmt.Sprintf("git clone failed: %v\n%s", err, strings.TrimSpace(string(out))), 0)
return
}

// 3. Extract & overwrite
setStatus("extracting archive …", 0)
// 3. Copy the requested data/ sub-directories into the working directory.
setStatus("copying data directories…", 0)
dirs := splitDirs(req.Dirs)
n, err := extractDataDirs(body, dirs)
filesWritten, err := copyDataDirs(tmpDir, dirs)
if err != nil {
finish(false, fmt.Sprintf("extraction failed: %v", err), n)
finish(false, fmt.Sprintf("copy failed: %v", err), filesWritten)
return
}

finish(true, fmt.Sprintf("sync complete — %d file(s) updated from ref %q", n, req.Ref), n)
finish(true, fmt.Sprintf("sync complete — %d file(s) updated from ref %q", filesWritten, req.Ref), filesWritten)
}

// downloadArchive fetches the zip archive and returns its bytes.
func downloadArchive(url, token string) ([]byte, error) {
client := &http.Client{Timeout: 5 * time.Minute}
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return nil, err
}
if token != "" {
req.Header.Set("Authorization", "token "+token)
}
req.Header.Set("User-Agent", "AI-Infra-Guard/data-updater")
// copyDataDirs copies data/<dir>/ from srcRoot (the cloned repo) into the
// current working directory, overwriting existing files.
func copyDataDirs(srcRoot string, dirs []string) (int, error) {
total := 0
for _, d := range dirs {
d = strings.TrimSpace(d)
if d == "" {
continue
}
srcDir := filepath.Join(srcRoot, "data", d)
dstDir := filepath.Join("data", d)
Comment on lines +276 to +283
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Validate requested dirs remain inside data/

copyDataDirs joins each user-supplied entry directly into both source and destination paths without sanitizing path traversal segments. A request like {"dirs":"../cmd"} resolves to <clone>/cmd and writes into ./cmd, so this endpoint can overwrite arbitrary project files outside data/ when called with crafted input. The previous archive-based logic only matched fixed data/<subdir> entries and did not allow escaping the data root.

Useful? React with 👍 / 👎.


resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if _, err := os.Stat(srcDir); os.IsNotExist(err) {

Check failure

Code scanning / CodeQL

Uncontrolled data used in path expression High

This path depends on a
user-provided value
.
Comment thread
github-advanced-security[bot] marked this conversation as resolved.
Fixed
// sub-directory not present in this ref — skip silently
continue
}

if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("HTTP %d from %s", resp.StatusCode, url)
n, err := copyDir(srcDir, dstDir)
if err != nil {
return total, fmt.Errorf("copying data/%s: %w", d, err)
}
total += n
}

return io.ReadAll(resp.Body)
return total, nil
}

// extractDataDirs extracts the requested data sub-directories from the zip
// archive and writes them to the local filesystem.
//
// GitHub's archive has a single top-level directory named
// "<repo>-<ref>/", e.g. "AI-Infra-Guard-main/".
// We strip that prefix and write only the files under data/<dir>/.
func extractDataDirs(zipBytes []byte, dirs []string) (int, error) {
zr, err := zip.NewReader(bytes.NewReader(zipBytes), int64(len(zipBytes)))
if err != nil {
return 0, fmt.Errorf("invalid zip: %w", err)
// copyDir recursively copies all files from src to dst, creating dst if needed.
// Returns the number of files written.
func copyDir(src, dst string) (int, error) {
if err := os.MkdirAll(dst, 0o755); err != nil {

Check failure

Code scanning / CodeQL

Uncontrolled data used in path expression High

This path depends on a
user-provided value
.
Comment thread
github-advanced-security[bot] marked this conversation as resolved.
Fixed
return 0, err
}

// Find the top-level prefix (first directory entry).
prefix := ""
for _, f := range zr.File {
if f.FileInfo().IsDir() {
parts := strings.SplitN(f.Name, "/", 2)
prefix = parts[0] + "/"
break
}
entries, err := os.ReadDir(src)

Check failure

Code scanning / CodeQL

Uncontrolled data used in path expression High

This path depends on a
user-provided value
.
Comment thread
github-advanced-security[bot] marked this conversation as resolved.
Fixed
if err != nil {
return 0, err
}

// Build a quick lookup set for the requested dirs.
wantDir := make(map[string]bool, len(dirs))
for _, d := range dirs {
wantDir[strings.TrimSpace(d)] = true
}
total := 0
for _, e := range entries {
srcPath := filepath.Join(src, e.Name())
dstPath := filepath.Join(dst, e.Name())

filesWritten := 0
for _, f := range zr.File {
// Strip the top-level prefix.
rel := strings.TrimPrefix(f.Name, prefix)
// We only care about files under data/<wantDir>/
if !strings.HasPrefix(rel, "data/") {
continue
}
// rel is now like "data/fingerprints/foo.yaml"
parts := strings.SplitN(rel, "/", 3) // ["data", "subdir", "rest"]
if len(parts) < 3 {
continue // skip "data/" itself or "data/subdir/" directory entries
}
subDir := parts[1]
if !wantDir[subDir] {
continue
}
if f.FileInfo().IsDir() {
if err := os.MkdirAll(rel, 0o755); err != nil {
return filesWritten, fmt.Errorf("mkdir %s: %w", rel, err)
if e.IsDir() {
n, err := copyDir(srcPath, dstPath)
if err != nil {
return total, err
}
total += n
continue
}

// Ensure parent directory exists.
if err := os.MkdirAll(filepath.Dir(rel), 0o755); err != nil {
return filesWritten, fmt.Errorf("mkdir %s: %w", filepath.Dir(rel), err)
}

// Write file.
rc, err := f.Open()
data, err := os.ReadFile(srcPath) // #nosec G304 — srcPath is under tmpDir controlled by us

Check failure

Code scanning / CodeQL

Uncontrolled data used in path expression High

This path depends on a
user-provided value
.
Comment thread
github-advanced-security[bot] marked this conversation as resolved.
Fixed
if err != nil {
return filesWritten, fmt.Errorf("open zip entry %s: %w", f.Name, err)
}
written, writeErr := writeFile(rel, rc)
rc.Close()
if writeErr != nil {
return filesWritten, fmt.Errorf("write %s: %w", rel, writeErr)
return total, fmt.Errorf("read %s: %w", srcPath, err)
}
if written {
filesWritten++
if err := os.WriteFile(dstPath, data, 0o644); err != nil {

Check failure

Code scanning / CodeQL

Uncontrolled data used in path expression High

This path depends on a
user-provided value
.
Comment thread
github-advanced-security[bot] marked this conversation as resolved.
Fixed
return total, fmt.Errorf("write %s: %w", dstPath, err)
}
total++
}

return filesWritten, nil
}

// writeFile atomically writes the content of rc to path.
// It reports whether the file was actually written (always true on success).
func writeFile(path string, rc io.Reader) (bool, error) {
data, err := io.ReadAll(rc)
if err != nil {
return false, err
}
if err := os.WriteFile(path, data, 0o644); err != nil {
return false, err
}
return true, nil
return total, nil
}

// splitDirs splits a comma-separated list of directory names.
Expand All @@ -340,13 +290,13 @@

// updateStatusJSON is used only for Swagger doc generation.
type updateStatusJSON struct {
Running bool `json:"running"`
Success *bool `json:"success,omitempty"`
StartedAt time.Time `json:"started_at,omitempty"`
FinishedAt *time.Time `json:"finished_at,omitempty"`
Message string `json:"message"`
FilesUpdated int `json:"files_updated"`
Ref string `json:"ref,omitempty"`
Running bool `json:"running"`
Success *bool `json:"success,omitempty"`
StartedAt time.Time `json:"started_at,omitempty"`
FinishedAt *time.Time `json:"finished_at,omitempty"`
Message string `json:"message"`
FilesUpdated int `json:"files_updated"`
Ref string `json:"ref,omitempty"`
}

// MarshalJSON implements json.Marshaler so UpdateStatus can be serialised
Expand Down
Loading
Loading