diff --git a/services/auth/go/controller/sign_in_provider_callback_get.go b/services/auth/go/controller/sign_in_provider_callback_get.go index 6c22d9ace..0c58ccfa7 100644 --- a/services/auth/go/controller/sign_in_provider_callback_get.go +++ b/services/auth/go/controller/sign_in_provider_callback_get.go @@ -47,6 +47,15 @@ func (ctrl *Controller) getStateData( return stateData, nil } +func attachURLValues(u *url.URL, values map[string]string) { + q := u.Query() + for k, v := range values { + q.Set(k, v) + } + + u.RawQuery = q.Encode() +} + func (ctrl *Controller) signinProviderProviderCallbackValidate( ctx context.Context, req providerCallbackData, @@ -56,6 +65,10 @@ func (ctrl *Controller) signinProviderProviderCallbackValidate( stateData, apiErr := ctrl.getStateData(ctx, req.State, logger) if apiErr != nil { + attachURLValues(redirectTo, map[string]string{ + "provider_state": req.State, + }) + return nil, nil, redirectTo, apiErr } @@ -72,16 +85,17 @@ func (ctrl *Controller) signinProviderProviderCallbackValidate( } if req.Error != nil && *req.Error != "" { - values := redirectTo.Query() - values.Add("provider_error", deptr(req.Error)) - values.Add("provider_error_description", deptr(req.ErrorDescription)) - values.Add("provider_error_url", deptr(req.ErrorURI)) - - if stateData.State != nil && *stateData.State != "" { - values.Add("state", *stateData.State) + values := map[string]string{ + "provider_error": deptr(req.Error), + "provider_error_description": deptr(req.ErrorDescription), + "provider_error_url": deptr(req.ErrorURI), } - redirectTo.RawQuery = values.Encode() + if stateData.State != nil && *stateData.State != "" { + values["state"] = *stateData.State + } + + attachURLValues(redirectTo, values) return nil, nil, redirectTo, ErrOauthProviderError } @@ -93,9 +107,9 @@ func (ctrl *Controller) signinProviderProviderCallbackValidate( } if stateData.State != nil && *stateData.State != "" { - values := optionsRedirectTo.Query() - values.Add("state", *stateData.State) - optionsRedirectTo.RawQuery = values.Encode() + attachURLValues(optionsRedirectTo, map[string]string{ + "state": *stateData.State, + }) } return stateData.Options, stateData.Connect, optionsRedirectTo, nil diff --git a/services/auth/go/controller/sign_in_provider_callback_get_test.go b/services/auth/go/controller/sign_in_provider_callback_get_test.go index 78419c852..c993a8465 100644 --- a/services/auth/go/controller/sign_in_provider_callback_get_test.go +++ b/services/auth/go/controller/sign_in_provider_callback_get_test.go @@ -666,7 +666,7 @@ func TestSignInProviderCallback(t *testing.T) { //nolint:maintidx }, expectedResponse: controller.ErrorRedirectResponse{ Headers: struct{ Location string }{ - Location: `http://localhost:3000?error=invalid-state&errorDescription=Invalid+state`, + Location: `^http://localhost:3000\?error=invalid-state&errorDescription=Invalid\+state&provider_state=wrong-state$`, //nolint:lll }, }, expectedJWT: nil,