Refreshing Access Tokens in a Reactive Environment with Spring Boot and WebFlux

Let's say we have a typical OAuth2 setup with a frontend service, an authorization server and a resource server, based on Spring Boot and using WebFlux. The end user logs in via the "authorization_code" grant type.

Here is how it works:

1. The user opens the URL of the frontend service in their browser (e.g. https://www.mysite.com). The frontend service redirects the user to the authorization server (e.g. https://auth.mysite.com).

2. The browser follows the redirect, opens the authorization server's URL and the user logs in with their credentials. The authorization server redirects the user back to the frontend service, including a code (authorization_code) parameter in the URL.

3. The browser follows the redirect and opens the frontend service URL. The frontend service internally calls the authorization server with the code provided and receives the access and refresh tokens. It stores them in the user's HTTP session. The login is complete.

4. The user opens a page or a view that requests a resource from the resource server. The frontend service internally calls the resource server, including the access token saved in the user's HTTP session.

5. The resource server checks the validity of the access token with the authorization server and, if it is valid, returns the requested resource.

6. The frontend service forwards the resource to the browser.
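Spring Security's OAuth2 client performs the code-for-token exchange of step 3 for you. Just to make that step concrete, here is a minimal JDK-only sketch (Java 11+) of the request the frontend service sends to the token endpoint, assuming the client_secret_basic authentication method and the example values from the configuration further below. The class and method names are hypothetical, and the request is only built, not sent.

```java
import java.net.URI;
import java.net.URLEncoder;
import java.net.http.HttpRequest;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.stream.Collectors;

// Hypothetical illustration of step 3; Spring Security builds and sends this request for you.
final class AuthorizationCodeExchange {

    private AuthorizationCodeExchange() {
    }

    // Form parameters of an authorization_code token request (RFC 6749, section 4.1.3).
    static Map<String, String> tokenParams(String code, String redirectUri) {
        Map<String, String> params = new LinkedHashMap<>(); // keeps the parameter order stable
        params.put("grant_type", "authorization_code");
        params.put("code", code);
        params.put("redirect_uri", redirectUri);
        return params;
    }

    // Encodes the parameters as application/x-www-form-urlencoded.
    static String formBody(Map<String, String> params) {
        return params.entrySet().stream()
                .map(e -> URLEncoder.encode(e.getKey(), StandardCharsets.UTF_8) + "="
                        + URLEncoder.encode(e.getValue(), StandardCharsets.UTF_8))
                .collect(Collectors.joining("&"));
    }

    // client_secret_basic: client id and secret are sent as HTTP Basic credentials.
    static String basicAuthHeader(String clientId, String clientSecret) {
        String credentials = URLEncoder.encode(clientId, StandardCharsets.UTF_8) + ":"
                + URLEncoder.encode(clientSecret, StandardCharsets.UTF_8);
        return "Basic " + Base64.getEncoder().encodeToString(credentials.getBytes(StandardCharsets.UTF_8));
    }

    // Builds, but does not send, the POST request to the token endpoint.
    static HttpRequest tokenRequest(URI tokenUri, String code, String redirectUri,
                                    String clientId, String clientSecret) {
        return HttpRequest.newBuilder(tokenUri)
                .header("Content-Type", "application/x-www-form-urlencoded")
                .header("Authorization", basicAuthHeader(clientId, clientSecret))
                .POST(HttpRequest.BodyPublishers.ofString(formBody(tokenParams(code, redirectUri))))
                .build();
    }
}
```

The response to this request is a JSON document containing the access token, its lifetime and (usually) the refresh token, which is what ends up in the HTTP session.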




Delegation of the Access Token

Step 4 above says that the call to the resource server includes the access token saved in the session. To do that easily, we can configure a WebClient with an ExchangeFilterFunction that inserts the access token into every call.

First we define the client repository:

@Bean
public OAuth2AuthorizedClientRepository authorizedClientRepository() {
    return new HttpSessionOAuth2AuthorizedClientRepository();
}

Then the WebClient, autowiring the repository above:

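@Bean
public WebClient rest(OAuth2AuthorizedClientRepository authorizedClientRepository) {
    // The repository bean is injected as a parameter; a field @Autowired in the same
    // @Configuration class that defines it would create a circular reference.
    return WebClient.builder()
            .filter(new WebClientAccessTokenExchangeFilterFunction(authorizedClientRepository))
            .build();
}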

The ExchangeFilterFunction above looks like this:

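import jakarta.servlet.http.HttpServletRequest; // javax.servlet.http.* on Spring Boot 2.x
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
import org.springframework.web.context.request.RequestAttributes;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import org.springframework.web.reactive.function.client.ClientRequest;
import org.springframework.web.reactive.function.client.ClientResponse;
import org.springframework.web.reactive.function.client.ExchangeFilterFunction;
import org.springframework.web.reactive.function.client.ExchangeFunction;
import reactor.core.publisher.Mono;

public class WebClientAccessTokenExchangeFilterFunction implements ExchangeFilterFunction {

    private final OAuth2AuthorizedClientRepository authorizedClientRepository;

    public WebClientAccessTokenExchangeFilterFunction(OAuth2AuthorizedClientRepository authorizedClientRepository) {
        this.authorizedClientRepository = authorizedClientRepository;
    }

    @Override
    public Mono<ClientResponse> filter(ClientRequest request, ExchangeFunction next) {
        // Both holders are thread-bound: they are only populated on the servlet request thread.
        RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes();
        Authentication authentication = SecurityContextHolder.getContext().getAuthentication();

        if (requestAttributes instanceof ServletRequestAttributes
                && authentication instanceof OAuth2AuthenticationToken) {
            HttpServletRequest servletRequest = ((ServletRequestAttributes) requestAttributes).getRequest();
            OAuth2AuthenticationToken oauth2Token = (OAuth2AuthenticationToken) authentication;

            // Load the authorized client (access and refresh token) stored in the HTTP session at login.
            OAuth2AuthorizedClient client = authorizedClientRepository
                    .loadAuthorizedClient(oauth2Token.getAuthorizedClientRegistrationId(), oauth2Token, servletRequest);

            if (client != null) {
                // Copy the request and add "Authorization: Bearer <access token>".
                return next.exchange(ClientRequest.from(request).headers(headers ->
                        headers.setBearerAuth(client.getAccessToken().getTokenValue())).build());
            }
        }

        // No logged-in OAuth2 user available: send the request unchanged.
        return next.exchange(request);
    }
}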

Here is how to use the WebClient:

@Autowired
private WebClient webClient;

public Mono<ResponseEntity<String>> getMyResource(String id) {
    return webClient.get()
            .uri("http://resource-service/resources/{id}", id)
            .retrieve()
            .toEntity(String.class);
}

Here is an example OAuth2 configuration for the frontend service:

spring:
  security:
    oauth2:
      client:
        registration:
          mysite:
            client-name: mysite
            client-id: myclient
            client-secret: mysecret
            client-authentication-method: client_secret_basic
            authorization-grant-type: authorization_code
            redirect-uri: https://www.mysite.com/login
            scope: some:scope
        provider:
          mysite:
            authorization-uri: https://auth.mysite.com/authorize
            token-uri: https://auth.mysite.com/token
            user-info-uri: https://auth.mysite.com/me
            user-info-authentication-method: header
            user-name-attribute: user


Refreshing the Access Token

You may have noticed that the code above always inserts the same access token, the one stored in the HTTP session during login, even if it has expired. In that case the call to the resource server returns an HTTP 401 Unauthorized error. To handle this, we must refresh the access token once it has expired.
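Whether a stored token "has expired" is usually decided from its expiry timestamp minus a small clock skew, so that a token about to expire is refreshed before it is sent. Spring's RefreshTokenReactiveOAuth2AuthorizedClientProvider applies a comparable check (with a default skew of 60 seconds) and does not refresh a token it still considers valid. The following framework-free sketch shows the check; the class name AccessTokenExpiry is hypothetical.

```java
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;

// Hypothetical helper: decides whether a stored access token should be refreshed.
final class AccessTokenExpiry {

    private AccessTokenExpiry() {
    }

    // True once "now" is past (expiresAt - clockSkew), i.e. the token has expired or is about to.
    // A token without a known expiry is treated as valid; the 401 fallback covers that case.
    static boolean needsRefresh(Instant expiresAt, Duration clockSkew, Clock clock) {
        if (expiresAt == null) {
            return false;
        }
        return clock.instant().isAfter(expiresAt.minus(clockSkew));
    }
}
```

Passing a Clock instead of calling Instant.now() keeps the check testable, for example with Clock.fixed(...).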

You could create a timer that regularly checks for expired tokens and refreshes them, but there would still be a chance of getting an error. So it is a good idea to also refresh the token whenever the resource server returns a 401 error. In the WebClient you can achieve that with a custom ExchangeFunction:

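import jakarta.servlet.http.HttpServletRequest; // javax.servlet.http.* on Spring Boot 2.x
import jakarta.servlet.http.HttpServletResponse;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.HttpStatus;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.client.OAuth2AuthorizationContext;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.RefreshTokenReactiveOAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
import org.springframework.web.context.request.RequestAttributes;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import org.springframework.web.reactive.function.client.ClientRequest;
import org.springframework.web.reactive.function.client.ClientResponse;
import org.springframework.web.reactive.function.client.ExchangeFunction;
import reactor.core.publisher.Mono;

public class Oauth2AwareExchangeFunction implements ExchangeFunction {

    private static final Logger LOGGER = LoggerFactory.getLogger(Oauth2AwareExchangeFunction.class);

    private final ExchangeFunction defaultFunction;
    private final OAuth2AuthorizedClientRepository authorizedClientRepository;
    private final RefreshTokenReactiveOAuth2AuthorizedClientProvider refreshTokenProvider;

    public Oauth2AwareExchangeFunction(ExchangeFunction defaultFunction,
                                       OAuth2AuthorizedClientRepository authorizedClientRepository,
                                       RefreshTokenReactiveOAuth2AuthorizedClientProvider refreshTokenProvider) {
        this.defaultFunction = defaultFunction;
        this.authorizedClientRepository = authorizedClientRepository;
        this.refreshTokenProvider = refreshTokenProvider;
    }

    // filter(...) is deliberately not overridden: the inherited default wraps *this* function,
    // whereas delegating to defaultFunction.filter(...) would bypass the refresh logic.

    @Override
    public Mono<ClientResponse> exchange(ClientRequest request) {
        // Both holders are thread-bound, so capture them before the call goes asynchronous.
        RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes();
        Authentication authentication = SecurityContextHolder.getContext().getAuthentication();

        return defaultFunction.exchange(request)
                .flatMap(res -> handleError(request, res, requestAttributes, authentication));
    }

    private Mono<ClientResponse> handleError(ClientRequest request, ClientResponse response,
                                             RequestAttributes requestAttributes, Authentication authentication) {
        if (response.statusCode().value() == HttpStatus.UNAUTHORIZED.value()) {
            return refreshAccessTokenAndRetry(request, response, requestAttributes, authentication);
        }

        return Mono.just(response);
    }

    private Mono<ClientResponse> refreshAccessTokenAndRetry(ClientRequest request, ClientResponse response,
                                                            RequestAttributes requestAttributes,
                                                            Authentication authentication) {
        if (!(requestAttributes instanceof ServletRequestAttributes)
                || !(authentication instanceof OAuth2AuthenticationToken)) {
            // Release the unused 401 body before failing, so the connection is not leaked.
            return response.releaseBody().then(Mono.error(
                    new IllegalStateException("No servlet request or OAuth2 authentication available")));
        }

        LOGGER.debug("Refreshing access token");

        ServletRequestAttributes servletAttributes = (ServletRequestAttributes) requestAttributes;
        HttpServletRequest servletRequest = servletAttributes.getRequest();
        HttpServletResponse servletResponse = servletAttributes.getResponse();
        OAuth2AuthenticationToken oauth2Token = (OAuth2AuthenticationToken) authentication;

        OAuth2AuthorizedClient client = authorizedClientRepository
                .loadAuthorizedClient(oauth2Token.getAuthorizedClientRegistrationId(), oauth2Token, servletRequest);
        if (client == null) {
            return Mono.just(response); // nothing stored to refresh: pass the 401 on
        }

        OAuth2AuthorizationContext oauth2Context = OAuth2AuthorizationContext
                .withAuthorizedClient(client)
                .principal(oauth2Token)
                .build();

        // authorize(...) completes empty if there is no refresh token or the access token is not
        // (about to be) expired yet; in that case the original 401 response is passed on.
        return refreshTokenProvider.authorize(oauth2Context)
                .flatMap(c -> {
                    // Store the refreshed client (new tokens) in the HTTP session.
                    authorizedClientRepository.saveAuthorizedClient(c, oauth2Token, servletRequest, servletResponse);
                    ClientRequest newRequest = ClientRequest.from(request).headers(headers ->
                            headers.setBearerAuth(c.getAccessToken().getTokenValue())).build();
                    // Discard the 401 body, then retry once with the new access token.
                    return response.releaseBody().then(Mono.defer(() -> defaultFunction.exchange(newRequest)));
                })
                .defaultIfEmpty(response);
    }
}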

And this is how to add it to the WebClient configuration:

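@Bean
public WebClient rest(OAuth2AuthorizedClientRepository authorizedClientRepository,
                      ClientHttpConnector connector) {
    // Filters wrap the exchange function: the filter adds the bearer token,
    // then Oauth2AwareExchangeFunction sends the request and retries once on a 401.
    return WebClient.builder()
            .filter(new WebClientAccessTokenExchangeFilterFunction(authorizedClientRepository))
            .exchangeFunction(new Oauth2AwareExchangeFunction(ExchangeFunctions.create(connector),
                    authorizedClientRepository, new RefreshTokenReactiveOAuth2AuthorizedClientProvider()))
            .build();
}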
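Stripped of the Spring types, the retry logic of Oauth2AwareExchangeFunction boils down to "send, on 401 refresh once, retry once". The following framework-free sketch models it with plain functions, purely as an illustration; RetryOnUnauthorized and its parameters are hypothetical.

```java
import java.util.function.Function;
import java.util.function.Supplier;

// Hypothetical, framework-free model of the "refresh on 401 and retry once" logic.
final class RetryOnUnauthorized {

    private RetryOnUnauthorized() {
    }

    // send: performs the call with the given access token and returns the HTTP status.
    // refresh: obtains a new access token, or null if no refresh is possible.
    static int callWithRefresh(String accessToken, Function<String, Integer> send, Supplier<String> refresh) {
        int status = send.apply(accessToken);
        if (status != 401) {
            return status; // success or an unrelated error: no refresh needed
        }
        String newToken = refresh.get();
        if (newToken == null) {
            return status; // nothing to retry with: surface the original 401
        }
        return send.apply(newToken); // retry exactly once, never loop
    }
}
```

Retrying only once is deliberate: if the resource server still answers 401 with a freshly refreshed token, the token's age is not the problem, and looping would only put load on the authorization server.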


Synchronizing the Timeouts

As you may have noticed, there are several entities with different validity periods: the frontend service's HTTP session, the authorization server's HTTP session, the access token and the refresh token.

If they are not synchronized properly, you might either "never be able to log out" or be "logged out partially", which leads to "Unauthorized" errors.

The rule is simple:

- the timeouts of the HTTP sessions and of the access token are up to you.

- the timeout of the refresh token should be longer than the timeout of the frontend service's HTTP session plus the timeout of the access token.

Example:

- authorization server HTTP session timeout: 5min

- frontend service HTTP session timeout: 30min

- access token timeout: 10min

- refresh token timeout: 41min
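One way to read the rule, assuming a new refresh token is issued (and stored in the session) with every refresh: the access token is refreshed at time T, used until just before it expires, and then the user stays idle for a full frontend session timeout. When they come back, the refresh token issued at T must still be valid, which gives session timeout + access token timeout, i.e. 30min + 10min = 40min in the example, so 41min satisfies the rule. A hypothetical startup check (TimeoutRule is not an existing API) could look like this:

```java
import java.time.Duration;

// Hypothetical helper for the timeout rule; the durations would come from your own configuration.
final class TimeoutRule {

    private TimeoutRule() {
    }

    // Lower bound for the refresh token lifetime: frontend session timeout + access token timeout.
    static Duration minimumRefreshTokenTimeout(Duration frontendSession, Duration accessToken) {
        return frontendSession.plus(accessToken);
    }

    // The refresh token must be valid for longer than that bound (strictly greater).
    static boolean satisfiesRule(Duration frontendSession, Duration accessToken, Duration refreshToken) {
        return refreshToken.compareTo(minimumRefreshTokenTimeout(frontendSession, accessToken)) > 0;
    }

    // Fails fast, e.g. at application startup, if the configured timeouts break the rule.
    static void check(Duration frontendSession, Duration accessToken, Duration refreshToken) {
        if (!satisfiesRule(frontendSession, accessToken, refreshToken)) {
            throw new IllegalStateException("Refresh token timeout " + refreshToken
                    + " must be longer than " + minimumRefreshTokenTimeout(frontendSession, accessToken));
        }
    }
}
```

The authorization server's own session timeout does not appear in the check, since it only matters during the login redirect.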


The Logout

A nice-to-have:

You can set the logout success URL of the frontend service to the logout URL of the authorization server, and the logout success URL of the authorization server to the login (or base) URL of the frontend service. This way, when the user logs out of the frontend service, they are redirected to the authorization server and logged out there as well, and then redirected back to the frontend service.

In the frontend service it looks like this:

@Bean
public SecurityFilterChain filterChain(HttpSecurity http) throws Exception {
    return http
        ...
        .logout()
            .logoutRequestMatcher(new AntPathRequestMatcher("/logout"))
            .logoutSuccessUrl("https://auth.mysite.com/logout")
            .invalidateHttpSession(true)
            .deleteCookies(COOKIE_NAME)
            .and()
        ...
        .build();
}


Thread-Safety

Is the above thread-safe? No. :-) 

Usually it is not required, or at least not critical. If the browser calls the frontend service several times simultaneously, the access token may be refreshed multiple times. Most authorization servers can cope with that.
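If duplicate refreshes do become a problem, for example with an authorization server that rotates refresh tokens and rejects a reused one, one option is to let concurrent callers share a single in-flight refresh. Below is a minimal JDK-only sketch of that idea, assuming sessions are identified by a string ID. It only coordinates callers within one JVM, and SingleFlightRefresher is a hypothetical name, not a Spring API.

```java
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;

// Hypothetical helper: at most one refresh in flight per session within this JVM.
final class SingleFlightRefresher {

    private final ConcurrentHashMap<String, CompletableFuture<String>> inFlight = new ConcurrentHashMap<>();
    private final Function<String, CompletableFuture<String>> refreshCall;

    // refreshCall starts the actual token refresh for a session ID and returns its future.
    SingleFlightRefresher(Function<String, CompletableFuture<String>> refreshCall) {
        this.refreshCall = refreshCall;
    }

    // Concurrent callers for the same session share one refresh future.
    CompletableFuture<String> refresh(String sessionId) {
        CompletableFuture<String> future = inFlight.computeIfAbsent(sessionId, refreshCall);
        // Forget the future once it completes, so that a later expiry triggers a new refresh.
        future.whenComplete((token, error) -> inFlight.remove(sessionId, future));
        return future;
    }
}
```

Because refreshCall runs inside computeIfAbsent, it should only start the refresh and return the future immediately rather than block.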