Skip to content
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -359,8 +359,9 @@ There are several ways to adjust functionality of the library:
1. `ClaimsUpdater` - interface with `Update(claims Claims) Claims` method. This is the primary way to alter a token at login time and add any attributes, set ip, email, admin status, roles and so on.
1. `Validator` - interface with `Validate(token string, claims Claims) bool` method. This is post-token hook and will be called on **each request** wrapped with `Auth` middleware. This will be the place for special logic to reject some tokens or users.
1. `UserUpdater` - interface with `Update(claims token.User) token.User` method. This method will be called on **each request** wrapped with `UpdateUser` middleware. This will be the place for special logic modify User Info in request context. [Example of usage.](https://github.com/go-pkgz/auth/blob/19c1b6d26608494955a4480f8f6165af85b1deab/_example/main.go#L189)
1. `AuthErrorHTTPHandler` - interface with `ServeAuthError(w http.ResponseWriter, r *http.Request, and other params)` method. It is possible to change how authentication errors are written into HTTP responses by configuring custom implementations of this interface for the middlewares.

All of the interfaces above have corresponding Func adapters - `SecretFunc`, `ClaimsUpdFunc`, `ValidatorFunc` and `UserUpdFunc`.
Some of the interfaces above have corresponding Func adapters - `SecretFunc`, `ClaimsUpdFunc`, `ValidatorFunc` and `UserUpdFunc`.
Comment thread
paskal marked this conversation as resolved.
Outdated

### Implementing black list logic or some other filters

Expand Down
22 changes: 12 additions & 10 deletions auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,13 @@ type Opts struct {
AvatarRoutePath string // avatar routing prefix, i.e. "/api/v1/avatar", default `/avatar`
UseGravatar bool // for email based auth (verified provider) use gravatar service

AdminPasswd string // if presented, allows basic auth with user admin and given password
BasicAuthChecker middleware.BasicAuthFunc // user custom checker for basic auth, if one defined then "AdminPasswd" will ignored
AudienceReader token.Audience // list of allowed aud values, default (empty) allows any
AudSecrets bool // allow multiple secrets (secret per aud)
Logger logger.L // logger interface, default is no logging at all
RefreshCache middleware.RefreshCache // optional cache to keep refreshed tokens
AdminPasswd string // if presented, allows basic auth with user admin and given password
BasicAuthChecker middleware.BasicAuthFunc // user custom checker for basic auth, if one defined then "AdminPasswd" will ignored
AudienceReader token.Audience // list of allowed aud values, default (empty) allows any
AudSecrets bool // allow multiple secrets (secret per aud)
Logger logger.L // logger interface, default is no logging at all
RefreshCache middleware.RefreshCache // optional cache to keep refreshed tokens
AuthErrorHTTPHandler middleware.AuthErrorHTTPHandler // optional HTTP handler for authentication errors
}

// NewService initializes everything
Expand All @@ -81,10 +82,11 @@ func NewService(opts Opts) (res *Service) {
opts: opts,
logger: opts.Logger,
authMiddleware: middleware.Authenticator{
Validator: opts.Validator,
AdminPasswd: opts.AdminPasswd,
BasicAuthChecker: opts.BasicAuthChecker,
RefreshCache: opts.RefreshCache,
Validator: opts.Validator,
AdminPasswd: opts.AdminPasswd,
BasicAuthChecker: opts.BasicAuthChecker,
RefreshCache: opts.RefreshCache,
AuthErrorHTTPHandler: opts.AuthErrorHTTPHandler,
},
issuer: opts.Issuer,
useGravatar: opts.UseGravatar,
Expand Down
201 changes: 201 additions & 0 deletions auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package auth
import (
"context"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
Expand Down Expand Up @@ -246,6 +247,206 @@ func TestIntegrationList(t *testing.T) {
assert.Equal(t, `["dev","github","custom123"]`+"\n", string(b))
}

type testAuthErrorHTTPHandler struct {
wasCalled bool
statusCode int
contentType string
responseBody string
}

func (h *testAuthErrorHTTPHandler) ServeAuthError(
w http.ResponseWriter,
_ *http.Request,
authError error,
reason string,
statusCode int,
) {
h.wasCalled = true
w.Header().Set("Content-Type", h.contentType)
w.WriteHeader(h.statusCode)
fmt.Fprint(w, h.responseBody)
}

func TestIntegrationAuthErrorHTTPHandler(t *testing.T) {
Comment thread
paskal marked this conversation as resolved.
testErrorHandler1 := &testAuthErrorHTTPHandler{
statusCode: 401,
contentType: "application/json",
responseBody: `{"code": 401, "message": "from general error handler"}`,
}
testErrorHandler2 := &testAuthErrorHTTPHandler{
statusCode: 403,
contentType: "text/html",
responseBody: `<html><body><h1>from private2 error handler</h1></body></html>`,
}
testErrorHandler3 := &testAuthErrorHTTPHandler{
statusCode: 403,
contentType: "application/json",
responseBody: `{"code": 401, "message": "from admin error handler"}`,
}
testErrorHandler4 := &testAuthErrorHTTPHandler{
statusCode: 403,
contentType: "text/html",
responseBody: `<html><body><h1>from RBAC error handler</h1></body></html>`,
}

options := Opts{
SecretReader: token.SecretFunc(func(string) (string, error) { return "secret", nil }),
Issuer: "my-test-app",
URL: "http://127.0.0.1:8089",
}

svc := NewService(options)
svc.AddDevProvider("localhost", 18084) // add dev provider on 18084
svc.authMiddleware.AuthErrorHTTPHandler = testErrorHandler1

// setup http server
m := svc.Middleware()
mux := http.NewServeMux()
mux.Handle("/private1",
m.Auth(
http.HandlerFunc(
func(w http.ResponseWriter, _ *http.Request) { // token required
_, _ = w.Write([]byte("protected route1\n"))
},
),
),
)
mux.Handle("/private2",
m.AuthWithErrorHTTPHandler(
http.HandlerFunc(
func(w http.ResponseWriter, _ *http.Request) { // token required
_, _ = w.Write([]byte("protected route2\n"))
},
),
testErrorHandler2,
),
)
mux.Handle("/admin1",
m.AdminOnly(
http.HandlerFunc(
func(w http.ResponseWriter, _ *http.Request) { // token required
_, _ = w.Write([]byte("admin route1\n"))
},
),
),
)
mux.Handle("/admin2",
m.AdminOnlyWithErrorHTTPHandler(
http.HandlerFunc(
func(w http.ResponseWriter, _ *http.Request) { // token required
_, _ = w.Write([]byte("admin route2\n"))
},
),
testErrorHandler3,
),
)
mux.Handle("/rbac1",
m.RBAC("role1", "role2")(
http.HandlerFunc(
func(w http.ResponseWriter, _ *http.Request) { // token required
_, _ = w.Write([]byte("rbac route1\n"))
},
),
),
)
mux.Handle("/rbac2",
m.RBACwithErrorHTTPHandler(testErrorHandler4, "role1", "role2")(
http.HandlerFunc(
func(w http.ResponseWriter, _ *http.Request) { // token required
_, _ = w.Write([]byte("rbac route2\n"))
},
),
),
)

l, listenErr := net.Listen("tcp", "127.0.0.1:8089")
require.Nil(t, listenErr)
ts := httptest.NewUnstartedServer(mux)
assert.NoError(t, ts.Listener.Close())
ts.Listener = l
ts.Start()
defer func() {
ts.Close()
}()

assertBodyEquals := func(t *testing.T, r *http.Response, expectedBody string) {
b, err := io.ReadAll(r.Body)
require.NoError(t, err)
assert.Equal(t, expectedBody, string(b))
}
assertContentTypeEquals := func(t *testing.T, r *http.Response, expectedContentType string) {
assert.Equal(t, expectedContentType, r.Header.Get("Content-Type"))
}

// private1
resp, err := http.Get("http://127.0.0.1:8089/private1")
require.NoError(t, err)
defer resp.Body.Close()

require.True(t, testErrorHandler1.wasCalled)

assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
assertContentTypeEquals(t, resp, "application/json")
assertBodyEquals(t, resp, `{"code": 401, "message": "from general error handler"}`)

// private2
resp, err = http.Get("http://127.0.0.1:8089/private2")
require.NoError(t, err)
defer resp.Body.Close()

require.True(t, testErrorHandler2.wasCalled)

assert.Equal(t, http.StatusForbidden, resp.StatusCode)
assertContentTypeEquals(t, resp, "text/html")
assertBodyEquals(t, resp, `<html><body><h1>from private2 error handler</h1></body></html>`)

// admin1
testErrorHandler1.wasCalled = false
resp, err = http.Get("http://127.0.0.1:8089/admin1")
require.NoError(t, err)
defer resp.Body.Close()

require.True(t, testErrorHandler1.wasCalled)

assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
assertContentTypeEquals(t, resp, "application/json")
assertBodyEquals(t, resp, `{"code": 401, "message": "from general error handler"}`)

// admin2
resp, err = http.Get("http://127.0.0.1:8089/admin2")
require.NoError(t, err)
defer resp.Body.Close()

require.True(t, testErrorHandler3.wasCalled)

assert.Equal(t, http.StatusForbidden, resp.StatusCode)
assertContentTypeEquals(t, resp, "application/json")
assertBodyEquals(t, resp, `{"code": 401, "message": "from admin error handler"}`)

// rbac1
testErrorHandler1.wasCalled = false
resp, err = http.Get("http://127.0.0.1:8089/rbac1")
require.NoError(t, err)
defer resp.Body.Close()

require.True(t, testErrorHandler1.wasCalled)

assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
assertContentTypeEquals(t, resp, "application/json")
assertBodyEquals(t, resp, `{"code": 401, "message": "from general error handler"}`)

// rbac2
resp, err = http.Get("http://127.0.0.1:8089/rbac2")
require.NoError(t, err)
defer resp.Body.Close()

require.True(t, testErrorHandler4.wasCalled)

assert.Equal(t, http.StatusForbidden, resp.StatusCode)
assertContentTypeEquals(t, resp, "text/html")
assertBodyEquals(t, resp, `<html><body><h1>from RBAC error handler</h1></body></html>`)
}

func TestIntegrationUserInfo(t *testing.T) {
_, teardown := prepService(t)
defer teardown()
Expand Down
Loading