-
Notifications
You must be signed in to change notification settings - Fork 806
feat(plugins): add ATR plugin (Agent Threat Rules guardrail) #4591
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,48 @@ | ||
| # ATR Plugin | ||
|
|
||
| Screens LLM requests with [Agent Threat Rules (ATR)](https://github.com/Agent-Threat-Rule/agent-threat-rules) | ||
| — an open, MIT-licensed detection-rule standard for AI-agent / LLM / MCP threats | ||
| (prompt injection, tool poisoning, credential exfiltration, skill supply-chain | ||
| attacks). | ||
|
|
||
| The plugin keeps the gateway **language-agnostic**: instead of porting the ATR | ||
| engine to Go, it calls an OpenAI-compatible `/v1/moderations` endpoint backed by | ||
| ATR (for example `pyatr.adapters.openai_moderation`, or any service that returns | ||
| the OpenAI moderation shape). On a flagged prompt the request is short-circuited | ||
| in `PreLLMHook` with a `403` before the provider call. | ||
|
|
||
| ## Usage | ||
|
|
||
| ```go | ||
| import "github.com/maximhq/bifrost/plugins/atr" | ||
|
|
||
| plugin, err := atr.Init(&atr.Config{ | ||
| Endpoint: "http://localhost:8000/v1/moderations", // ATR-backed moderation | ||
| FailClosed: false, // fail open if ATR is down | ||
| }) | ||
| // register `plugin` with your Bifrost instance like any other LLMPlugin | ||
| ``` | ||
|
|
||
| Or construct directly: `atr.New(endpoint, failClosed)`. | ||
|
|
||
| ## Config | ||
|
|
||
| | Field | Type | Default | Description | | ||
| |-------|------|---------|-------------| | ||
| | `endpoint` | string | — (required) | OpenAI-compatible `/v1/moderations` URL backed by ATR. | | ||
| | `fail_closed` | bool | `false` | When the moderation call fails (endpoint unreachable, non-2xx status, or a malformed/resultless response body): `true` blocks the request, `false` lets it through. | | ||
|
|
||
| ## Behavior | ||
|
|
||
| - `PreLLMHook` flattens the chat messages, POSTs `{"input": "<text>"}` to the | ||
| endpoint, and reads the standard `{"results":[{"flagged":bool,"categories":{…}}]}` | ||
| response. If `flagged`, it returns an `LLMPluginShortCircuit` carrying a `403` | ||
| `BifrostError` whose message lists the matched categories. | ||
| - Empty prompts and benign prompts pass through untouched. | ||
| - `PreRequestHook` / `PostLLMHook` are pass-throughs. | ||
|
|
||
| ## Tests | ||
|
|
||
| `go test ./...` — covers prompt extraction, block-on-flagged, allow-on-benign, | ||
| fail-open, fail-closed, and config validation (uses `httptest`, no live endpoint | ||
| required). |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,205 @@ | ||
| // Package atr is a Bifrost plugin that screens LLM requests with Agent Threat | ||
| // Rules (ATR), an open detection-rule standard for AI-agent / LLM / MCP threats. | ||
| // | ||
| // The plugin calls an OpenAI-compatible /v1/moderations endpoint backed by ATR | ||
| // (e.g. `pyatr.adapters.openai_moderation`), so the gateway stays language- | ||
| // agnostic — no Go port of the ATR engine is required. When the endpoint flags | ||
| // the prompt, the request is short-circuited before the provider call. | ||
| package atr | ||
|
|
||
import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"net/http"
	"sort"
	"strings"
	"time"

	"github.com/maximhq/bifrost/core/schemas"
)
|
|
||
| // PluginName is the registered name of the ATR plugin. | ||
| const PluginName = "atr" | ||
|
|
||
| // Compile-time check that *Plugin implements the Bifrost schemas.LLMPlugin interface. | ||
| var _ schemas.LLMPlugin = (*Plugin)(nil) | ||
|
|
||
// Config holds the user-facing settings for the ATR plugin.
type Config struct {
	// Endpoint is the URL of an OpenAI-compatible /v1/moderations service
	// backed by ATR, for example "http://localhost:8000/v1/moderations".
	Endpoint string `json:"endpoint"`

	// FailClosed selects what happens when the moderation call fails: true
	// blocks the request, false (the default) lets it through so the gateway
	// stays available while ATR is down.
	FailClosed bool `json:"fail_closed"`
}
|
|
||
// Plugin screens requests by sending their prompt text to an ATR-backed
// moderation endpoint.
type Plugin struct {
	// endpoint is the moderation URL every screening request is POSTed to.
	endpoint string
	// client is the HTTP client used for moderation calls.
	client *http.Client
	// failClosed blocks the request when the moderation call returns an error.
	failClosed bool
}
|
|
||
| // Init constructs the ATR plugin from config (matches the Bifrost built-in | ||
| // plugin convention used by InstantiatePlugin). | ||
| func Init(config *Config) (schemas.LLMPlugin, error) { | ||
| if config == nil || strings.TrimSpace(config.Endpoint) == "" { | ||
| return nil, errors.New("atr: config.endpoint is required") | ||
| } | ||
| return New(config.Endpoint, config.FailClosed), nil | ||
| } | ||
|
|
||
| // New returns an ATR plugin that calls the given OpenAI-compatible | ||
| // /v1/moderations endpoint (e.g. "http://localhost:8000/v1/moderations"). | ||
| func New(endpoint string, failClosed bool) *Plugin { | ||
| return &Plugin{ | ||
| endpoint: endpoint, | ||
| failClosed: failClosed, | ||
| client: &http.Client{Timeout: 10 * time.Second}, | ||
| } | ||
| } | ||
|
|
||
| func (p *Plugin) GetName() string { return PluginName } | ||
|
|
||
| func (p *Plugin) Cleanup() error { return nil } | ||
|
|
||
| // PreRequestHook does not participate in routing. | ||
| func (p *Plugin) PreRequestHook(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) error { | ||
| return nil | ||
| } | ||
|
|
||
| // PreLLMHook scans the chat prompt and short-circuits with a 403 when ATR flags it. | ||
| func (p *Plugin) PreLLMHook( | ||
| ctx *schemas.BifrostContext, | ||
| req *schemas.BifrostRequest, | ||
| ) (*schemas.BifrostRequest, *schemas.LLMPluginShortCircuit, error) { | ||
| text := promptText(req) | ||
| if strings.TrimSpace(text) == "" { | ||
| return req, nil, nil | ||
| } | ||
|
|
||
| // Propagate the request's cancellation/deadline into the moderation call. | ||
| // *BifrostContext implements context.Context; guard the nil pointer. | ||
| var reqCtx context.Context = context.Background() | ||
| if ctx != nil { | ||
| reqCtx = ctx | ||
| } | ||
|
|
||
| flagged, categories, err := p.moderate(reqCtx, text) | ||
| if err != nil { | ||
| if p.failClosed { | ||
| return req, blockShortCircuit("Agent Threat Rules moderation endpoint unreachable"), nil | ||
| } | ||
| return req, nil, nil // fail open | ||
| } | ||
| if flagged { | ||
| msg := "Blocked by Agent Threat Rules" | ||
| if categories != "" { | ||
| msg += ": " + categories | ||
| } | ||
| return req, blockShortCircuit(msg), nil | ||
| } | ||
| return req, nil, nil | ||
| } | ||
|
|
||
| // PostLLMHook is a pass-through. | ||
| func (p *Plugin) PostLLMHook( | ||
| ctx *schemas.BifrostContext, | ||
| resp *schemas.BifrostResponse, | ||
| bifrostErr *schemas.BifrostError, | ||
| ) (*schemas.BifrostResponse, *schemas.BifrostError, error) { | ||
| return resp, bifrostErr, nil | ||
| } | ||
|
|
||
| func blockShortCircuit(message string) *schemas.LLMPluginShortCircuit { | ||
| code := http.StatusForbidden | ||
| return &schemas.LLMPluginShortCircuit{ | ||
| Error: &schemas.BifrostError{ | ||
| StatusCode: &code, | ||
| IsBifrostError: true, | ||
| Error: &schemas.ErrorField{Message: message}, | ||
| }, | ||
| } | ||
| } | ||
|
|
||
| // promptText flattens the chat request's messages into scannable text. | ||
| func promptText(req *schemas.BifrostRequest) string { | ||
| if req == nil || req.ChatRequest == nil { | ||
| return "" | ||
| } | ||
| var sb strings.Builder | ||
| for _, msg := range req.ChatRequest.Input { | ||
| if msg.Content == nil { | ||
| continue | ||
| } | ||
| if msg.Content.ContentStr != nil { | ||
| sb.WriteString(*msg.Content.ContentStr) | ||
| sb.WriteString("\n") | ||
| } | ||
| for _, block := range msg.Content.ContentBlocks { | ||
| if block.Text != nil { | ||
| sb.WriteString(*block.Text) | ||
| sb.WriteString("\n") | ||
| } | ||
| } | ||
| } | ||
| return sb.String() | ||
| } | ||
|
|
||
// moderationResponse is the subset of the OpenAI-style moderation response
// body that this plugin reads.
type moderationResponse struct {
	// Results holds the decoded "results" array.
	Results []struct {
		// Flagged is the endpoint's overall verdict for this result.
		Flagged bool `json:"flagged"`
		// Categories maps each category name to whether it matched.
		Categories map[string]bool `json:"categories"`
	} `json:"results"`
}
|
|
||
| // moderate POSTs the text to the ATR moderation endpoint and reports whether it | ||
| // was flagged plus the truthy category names. A non-2xx status or a malformed | ||
| // (resultless) body is returned as an error so the caller's fail-closed policy | ||
| // applies rather than silently allowing the request. | ||
| func (p *Plugin) moderate(ctx context.Context, text string) (bool, string, error) { | ||
| payload, err := json.Marshal(map[string]string{"input": text}) | ||
| if err != nil { | ||
| return false, "", err | ||
| } | ||
|
|
||
| httpReq, err := http.NewRequestWithContext( | ||
| ctx, http.MethodPost, p.endpoint, bytes.NewReader(payload), | ||
| ) | ||
|
greptile-apps[bot] marked this conversation as resolved.
|
||
| if err != nil { | ||
| return false, "", err | ||
| } | ||
| httpReq.Header.Set("Content-Type", "application/json") | ||
|
|
||
| resp, err := p.client.Do(httpReq) | ||
| if err != nil { | ||
| return false, "", err | ||
| } | ||
| defer resp.Body.Close() | ||
|
|
||
| if resp.StatusCode < 200 || resp.StatusCode >= 300 { | ||
| return false, "", fmt.Errorf("atr: moderation endpoint returned status %d", resp.StatusCode) | ||
| } | ||
|
|
||
| var out moderationResponse | ||
| if err := json.NewDecoder(resp.Body).Decode(&out); err != nil { | ||
|
Comment on lines
+187
to
+188
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| return false, "", err | ||
| } | ||
| if len(out.Results) == 0 { | ||
| return false, "", errors.New("atr: moderation response contained no results") | ||
| } | ||
| if !out.Results[0].Flagged { | ||
| return false, "", nil | ||
| } | ||
|
coderabbitai[bot] marked this conversation as resolved.
|
||
|
|
||
| var categories []string | ||
| for name, truthy := range out.Results[0].Categories { | ||
| if truthy { | ||
| categories = append(categories, name) | ||
| } | ||
| } | ||
| return true, strings.Join(categories, ", "), nil | ||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.