diff --git a/samlsp/middleware.go b/samlsp/middleware.go index 93a3cf95..9f1e3e81 100644 --- a/samlsp/middleware.go +++ b/samlsp/middleware.go @@ -3,6 +3,7 @@ package samlsp import ( "bytes" "encoding/xml" + "errors" "net/http" "github.com/crewjam/saml" @@ -115,17 +116,16 @@ func (m *Middleware) ServeACS(w http.ResponseWriter, r *http.Request) { func (m *Middleware) RequireAccount(handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { session, err := m.Session.GetSession(r) - if session != nil { - r = r.WithContext(ContextWithSession(r.Context(), session)) - handler.ServeHTTP(w, r) - return - } - if err == ErrNoSession { + if err != nil && errors.Is(err, ErrNoSession) { m.HandleStartAuthFlow(w, r) return + } else if err != nil || session == nil { + m.OnError(w, r, err) + return } - m.OnError(w, r, err) + r = r.WithContext(ContextWithSession(r.Context(), session)) + handler.ServeHTTP(w, r) }) }