Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions plugins/atr/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# ATR Plugin

Screens LLM requests with [Agent Threat Rules (ATR)](https://github.com/Agent-Threat-Rule/agent-threat-rules)
— an open, MIT-licensed detection-rule standard for AI-agent / LLM / MCP threats
(prompt injection, tool poisoning, credential exfiltration, skill supply-chain
attacks).

The plugin keeps the gateway **language-agnostic**: instead of porting the ATR
engine to Go, it calls an OpenAI-compatible `/v1/moderations` endpoint backed by
ATR (for example `pyatr.adapters.openai_moderation`, or any service that returns
the OpenAI moderation shape). On a flagged prompt the request is short-circuited
in `PreLLMHook` with a `403` before the provider call.

## Usage

```go
import "github.com/maximhq/bifrost/plugins/atr"

plugin, err := atr.Init(&atr.Config{
Endpoint: "http://localhost:8000/v1/moderations", // ATR-backed moderation
FailClosed: false, // fail open if ATR is down
})
// register `plugin` with your Bifrost instance like any other LLMPlugin
```

Or construct directly: `atr.New(endpoint, failClosed)`.

## Config

| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `endpoint` | string | — (required) | OpenAI-compatible `/v1/moderations` URL backed by ATR. |
| `fail_closed` | bool | `false` | When the moderation endpoint is unreachable: `true` blocks the request, `false` lets it through. |

## Behavior

- `PreLLMHook` flattens the chat messages, POSTs `{"input": "<text>"}` to the
endpoint, and reads the standard `{"results":[{"flagged":bool,"categories":{…}}]}`
response. If `flagged`, it returns an `LLMPluginShortCircuit` carrying a `403`
`BifrostError` whose message lists the matched categories.
- Empty prompts and benign prompts pass through untouched.
- `PreRequestHook` / `PostLLMHook` are pass-throughs.

## Tests

`go test ./...` — covers prompt extraction, block-on-flagged, allow-on-benign,
fail-open, fail-closed, and config validation (uses `httptest`, no live endpoint
required).
205 changes: 205 additions & 0 deletions plugins/atr/atr.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
// Package atr is a Bifrost plugin that screens LLM requests with Agent Threat
// Rules (ATR), an open detection-rule standard for AI-agent / LLM / MCP threats.
//
// The plugin calls an OpenAI-compatible /v1/moderations endpoint backed by ATR
// (e.g. `pyatr.adapters.openai_moderation`), so the gateway stays language-
// agnostic — no Go port of the ATR engine is required. When the endpoint flags
// the prompt, the request is short-circuited before the provider call.
package atr

import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"
"time"

"github.com/maximhq/bifrost/core/schemas"
)

// PluginName is the registered name of the ATR plugin.
const PluginName = "atr"

// Plugin satisfies the Bifrost LLM plugin interface.
var _ schemas.LLMPlugin = (*Plugin)(nil)

// Config configures the ATR plugin.
type Config struct {
// Endpoint is the OpenAI-compatible /v1/moderations URL backed by ATR
// (e.g. "http://localhost:8000/v1/moderations").
Endpoint string `json:"endpoint"`
// FailClosed blocks the request when the moderation endpoint is unreachable.
// Default false (fail open) keeps the gateway available if ATR is down.
FailClosed bool `json:"fail_closed"`
}

// Plugin screens requests against an ATR moderation endpoint.
type Plugin struct {
endpoint string
client *http.Client
failClosed bool
}

// Init constructs the ATR plugin from config (matches the Bifrost built-in
// plugin convention used by InstantiatePlugin).
func Init(config *Config) (schemas.LLMPlugin, error) {
if config == nil || strings.TrimSpace(config.Endpoint) == "" {
return nil, errors.New("atr: config.endpoint is required")
}
return New(config.Endpoint, config.FailClosed), nil
}

// New returns an ATR plugin that calls the given OpenAI-compatible
// /v1/moderations endpoint (e.g. "http://localhost:8000/v1/moderations").
func New(endpoint string, failClosed bool) *Plugin {
return &Plugin{
endpoint: endpoint,
failClosed: failClosed,
client: &http.Client{Timeout: 10 * time.Second},
}
}

func (p *Plugin) GetName() string { return PluginName }

func (p *Plugin) Cleanup() error { return nil }

// PreRequestHook does not participate in routing.
func (p *Plugin) PreRequestHook(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) error {
return nil
}

// PreLLMHook scans the chat prompt and short-circuits with a 403 when ATR flags it.
func (p *Plugin) PreLLMHook(
ctx *schemas.BifrostContext,
req *schemas.BifrostRequest,
) (*schemas.BifrostRequest, *schemas.LLMPluginShortCircuit, error) {
text := promptText(req)
if strings.TrimSpace(text) == "" {
return req, nil, nil
}

// Propagate the request's cancellation/deadline into the moderation call.
// *BifrostContext implements context.Context; guard the nil pointer.
var reqCtx context.Context = context.Background()
if ctx != nil {
reqCtx = ctx
}

flagged, categories, err := p.moderate(reqCtx, text)
if err != nil {
if p.failClosed {
return req, blockShortCircuit("Agent Threat Rules moderation endpoint unreachable"), nil
}
return req, nil, nil // fail open
}
if flagged {
msg := "Blocked by Agent Threat Rules"
if categories != "" {
msg += ": " + categories
}
return req, blockShortCircuit(msg), nil
}
return req, nil, nil
}

// PostLLMHook is a pass-through.
func (p *Plugin) PostLLMHook(
ctx *schemas.BifrostContext,
resp *schemas.BifrostResponse,
bifrostErr *schemas.BifrostError,
) (*schemas.BifrostResponse, *schemas.BifrostError, error) {
return resp, bifrostErr, nil
}

func blockShortCircuit(message string) *schemas.LLMPluginShortCircuit {
code := http.StatusForbidden
return &schemas.LLMPluginShortCircuit{
Error: &schemas.BifrostError{
StatusCode: &code,
IsBifrostError: true,
Error: &schemas.ErrorField{Message: message},
},
}
}
Comment thread
greptile-apps[bot] marked this conversation as resolved.

// promptText flattens the chat request's messages into scannable text.
func promptText(req *schemas.BifrostRequest) string {
if req == nil || req.ChatRequest == nil {
return ""
}
var sb strings.Builder
for _, msg := range req.ChatRequest.Input {
if msg.Content == nil {
continue
}
if msg.Content.ContentStr != nil {
sb.WriteString(*msg.Content.ContentStr)
sb.WriteString("\n")
}
for _, block := range msg.Content.ContentBlocks {
if block.Text != nil {
sb.WriteString(*block.Text)
sb.WriteString("\n")
}
}
}
return sb.String()
}

type moderationResponse struct {
Results []struct {
Flagged bool `json:"flagged"`
Categories map[string]bool `json:"categories"`
} `json:"results"`
}

// moderate POSTs the text to the ATR moderation endpoint and reports whether it
// was flagged plus the truthy category names. A non-2xx status or a malformed
// (resultless) body is returned as an error so the caller's fail-closed policy
// applies rather than silently allowing the request.
func (p *Plugin) moderate(ctx context.Context, text string) (bool, string, error) {
payload, err := json.Marshal(map[string]string{"input": text})
if err != nil {
return false, "", err
}

httpReq, err := http.NewRequestWithContext(
ctx, http.MethodPost, p.endpoint, bytes.NewReader(payload),
)
Comment thread
greptile-apps[bot] marked this conversation as resolved.
if err != nil {
return false, "", err
}
httpReq.Header.Set("Content-Type", "application/json")

resp, err := p.client.Do(httpReq)
if err != nil {
return false, "", err
}
defer resp.Body.Close()

if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return false, "", fmt.Errorf("atr: moderation endpoint returned status %d", resp.StatusCode)
}

var out moderationResponse
if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
Comment on lines +187 to +188

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Unbounded response body read

json.NewDecoder(resp.Body) reads from the raw body without a size cap. A misconfigured or adversarial ATR service could stream a very large body and cause unbounded memory growth inside the plugin. Wrap the body with io.LimitReader(resp.Body, maxBodyBytes) (e.g. 1 MB) before passing it to the decoder.

return false, "", err
}
if len(out.Results) == 0 {
return false, "", errors.New("atr: moderation response contained no results")
}
if !out.Results[0].Flagged {
return false, "", nil
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.

var categories []string
for name, truthy := range out.Results[0].Categories {
if truthy {
categories = append(categories, name)
}
}
return true, strings.Join(categories, ", "), nil
}
Loading