diff --git a/app/services/account_getter.go b/app/services/account_getter.go index e5403ab77..6ba83fb85 100644 --- a/app/services/account_getter.go +++ b/app/services/account_getter.go @@ -6,16 +6,37 @@ import ( "github.com/pkg/errors" ) -func AccountGetter(store data.AccountStore, accountID int) (*models.Account, error) { - account, err := store.Find(accountID) - if err != nil { - return nil, errors.Wrap(err, "Find") +type AccountGetterParams struct { + AccountID *int + Username *string +} + +func AccountGetter(store data.AccountStore, params AccountGetterParams) (*models.Account, error) { + var account *models.Account + + if params.AccountID != nil { + ac, err := store.Find(*params.AccountID) + if err != nil { + return nil, errors.Wrap(err, "Find") + } + + account = ac } + + if params.Username != nil && params.AccountID == nil { + ac, err := store.FindByUsername(*params.Username) + if err != nil { + return nil, errors.Wrap(err, "FindByUsername") + } + + account = ac + } + if account == nil { return nil, FieldErrors{{"account", ErrNotFound}} } - oauthAccounts, err := store.GetOauthAccounts(accountID) + oauthAccounts, err := store.GetOauthAccounts(account.ID) if err != nil { return nil, errors.Wrap(err, "GetOauthAccounts") } diff --git a/app/services/account_getter_test.go b/app/services/account_getter_test.go index 8bec14a33..3d18f287d 100644 --- a/app/services/account_getter_test.go +++ b/app/services/account_getter_test.go @@ -11,9 +11,29 @@ import ( func TestAccountGetter(t *testing.T) { + t.Run("get username", func(t *testing.T) { + accountStore := mock.NewAccountStore() + acc, err := accountStore.Create("user@keratin.tech", []byte("password")) + require.NoError(t, err) + + accountID := acc.ID + + account, err := services.AccountGetter(accountStore, services.AccountGetterParams{ + Username: &acc.Username, + }) + require.NoError(t, err) + + require.Equal(t, accountID, account.ID) + }) + t.Run("get non existing account", func(t *testing.T) { accountStore := mock.NewAccountStore() - account, err := services.AccountGetter(accountStore, 9999) + + accountID := 9999 + + account, err := services.AccountGetter(accountStore, services.AccountGetterParams{ + AccountID: &accountID, + }) require.NotNil(t, err) require.Nil(t, account) @@ -24,7 +44,11 @@ func TestAccountGetter(t *testing.T) { acc, err := accountStore.Create("user@keratin.tech", []byte("password")) require.NoError(t, err) - account, err := services.AccountGetter(accountStore, acc.ID) + accountID := acc.ID + + account, err := services.AccountGetter(accountStore, services.AccountGetterParams{ + AccountID: &accountID, + }) require.NoError(t, err) require.Equal(t, 0, len(account.OauthAccounts)) @@ -41,7 +65,11 @@ func TestAccountGetter(t *testing.T) { err = accountStore.AddOauthAccount(acc.ID, "trial", "ID2", "email2", "TOKEN2") require.NoError(t, err) - account, err := services.AccountGetter(accountStore, acc.ID) + accountID := acc.ID + + account, err := services.AccountGetter(accountStore, services.AccountGetterParams{ + AccountID: &accountID, + }) require.NoError(t, err) oAccounts := account.OauthAccounts diff --git a/app/services/totp_creator.go b/app/services/totp_creator.go index cdc1fa4d9..f0ce99243 100644 --- a/app/services/totp_creator.go +++ b/app/services/totp_creator.go @@ -12,7 +12,9 @@ var ErrExistingTOTPSecret = errors.New("a OTP secret has already been establishe // TOTPCreator handles the creation and storage of new OTP tokens func TOTPCreator(accountStore data.AccountStore, totpCache data.TOTPCache, accountID int, audience *route.Domain) (*otp.Key, error) { - account, err := AccountGetter(accountStore, accountID) + account, err := AccountGetter(accountStore, AccountGetterParams{ + AccountID: &accountID, + }) if err != nil { return nil, err } diff --git a/app/services/totp_setter.go b/app/services/totp_setter.go index c7eac8f52..7e6aac1e5 100644 --- a/app/services/totp_setter.go +++ b/app/services/totp_setter.go @@ -14,7 +14,9 @@ func TOTPSetter(accountStore data.AccountStore, totpCache data.TOTPCache, cfg *a return FieldErrors{{"otp", ErrInvalidOrExpired}} } - account, err := AccountGetter(accountStore, accountID) + account, err := AccountGetter(accountStore, AccountGetterParams{ + AccountID: &accountID, + }) if err != nil { return err } diff --git a/server/handlers/get_account.go b/server/handlers/get_account.go index a47c87786..b41599f10 100644 --- a/server/handlers/get_account.go +++ b/server/handlers/get_account.go @@ -11,13 +11,25 @@ import ( func GetAccount(app *app.App) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - id, err := strconv.Atoi(mux.Vars(r)["id"]) - if err != nil { - WriteNotFound(w, "account") + var paramID *int + var paramUsername *string + + idOrUsername := mux.Vars(r)["id"] + if idOrUsername == "" { return } - account, err := services.AccountGetter(app.AccountStore, id) + id, err := strconv.Atoi(idOrUsername) + if err != nil { + paramUsername = &idOrUsername + } else { + paramID = &id + } + + account, err := services.AccountGetter(app.AccountStore, services.AccountGetterParams{ + AccountID: paramID, + Username: paramUsername, + }) if err != nil { if _, ok := err.(services.FieldErrors); ok { WriteNotFound(w, "account") diff --git a/server/handlers/get_account_test.go b/server/handlers/get_account_test.go index cf0668eec..2259b1170 100644 --- a/server/handlers/get_account_test.go +++ b/server/handlers/get_account_test.go @@ -52,6 +52,26 @@ func TestGetAccount(t *testing.T) { assertGetAccountResponse(t, res, account, oauthAccounts) }) + + t.Run("valid account username", func(t *testing.T) { + account, err := app.AccountStore.Create("unlocked2@test.com", []byte("bar")) + require.NoError(t, err) + + err = app.AccountStore.AddOauthAccount(account.ID, "test2", "ID21", "email21", "TOKEN21") + require.NoError(t, err) + + err = app.AccountStore.AddOauthAccount(account.ID, "trial2", "ID22", "email22", "TOKEN22") + require.NoError(t, err) + + oauthAccounts, err := app.AccountStore.GetOauthAccounts(account.ID) + require.NoError(t, err) + + res, err := client.Get(fmt.Sprintf("/accounts/%v", account.Username)) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, res.StatusCode) + + assertGetAccountResponse(t, res, account, oauthAccounts) + }) } func assertGetAccountResponse(t *testing.T, res *http.Response, acc *models.Account, oAccs []*models.OauthAccount) { diff --git a/server/handlers/get_oauth_accounts.go b/server/handlers/get_oauth_accounts.go index 31eead8e8..741afb48c 100644 --- a/server/handlers/get_oauth_accounts.go +++ b/server/handlers/get_oauth_accounts.go @@ -16,7 +16,9 @@ func GetOauthAccounts(app *app.App) http.HandlerFunc { return } - account, err := services.AccountGetter(app.AccountStore, accountID) + account, err := services.AccountGetter(app.AccountStore, services.AccountGetterParams{ + AccountID: &accountID, + }) if err != nil { WriteErrors(w, err) return diff --git a/server/private_routes.go b/server/private_routes.go index f394b359b..aa24024e1 100644 --- a/server/private_routes.go +++ b/server/private_routes.go @@ -32,7 +32,7 @@ func PrivateRoutes(app *app.App) []*route.HandledRoute { SecuredWith(authentication). Handle(handlers.PostAccountsImport(app)), - route.Get("/accounts/{id:[0-9]+}"). + route.Get("/accounts/{id}"). SecuredWith(authentication). Handle(handlers.GetAccount(app)),