Disclaimer: this is a hacky approach, but it works
(and if it's stupid but it works, it ain't stupid).
All relevant parts will always use the ClientRegistrationRepository
to find the ClientRegistration (and with that, the scopes).
So my hacky way to solve this was to build a wrapper around the `InMemoryClientRegistrationRepository`. In my case, I wanted to allow any additional scope to be requested from the client side, so the wrapper adds the scopes from the `scope` query parameter to the scopes of the configured registration.
Here's the example code for this solution:
/**
 * Creates a {@link ClientRegistrationRepository} that wraps the configuration-based
 * in-memory repository and extends a registration's scopes with the scopes requested
 * through the {@code scope} query parameter of the current servlet request.
 *
 * <p>For each lookup the returned repository:
 * <ul>
 *   <li>returns {@code null} when the wrapped repository does not know the id;</li>
 *   <li>returns the configured registration unchanged when no servlet request is bound
 *       to the current thread, the request has no query string, the query has no
 *       {@code scope} parameter, or every requested scope is already configured;</li>
 *   <li>otherwise returns a copy of the registration whose scopes are the union of the
 *       configured and the requested scopes.</li>
 * </ul>
 *
 * <p>NOTE(review): the requested scopes come straight from the request and are accepted
 * without any allow-list. That is the stated intent here, but confirm it is acceptable
 * for your authorization server before reusing this.
 *
 * @param properties the OAuth2 client properties the configured registrations are built from
 * @return the scope-extending repository wrapper
 */
@Bean
public ClientRegistrationRepository clientRegistrationRepository(OAuth2ClientProperties properties) {
    final List<ClientRegistration> registrations = new ArrayList<>(
            OAuth2ClientPropertiesRegistrationAdapter.getClientRegistrations(properties).values());
    // This is the ClientRegistrationRepository that would be used by the default configuration.
    final ClientRegistrationRepository parent = new InMemoryClientRegistrationRepository(registrations);
    // This lambda is our wrapper around the configuration-based ClientRegistrationRepository.
    return (registrationId) -> {
        final ClientRegistration clientRegistration = parent.findByRegistrationId(registrationId);
        if (clientRegistration == null) {
            return null;
        }
        // The repository API carries no request, so fetch the one bound to the current
        // thread; it is absent outside of a servlet request.
        final HttpServletRequest request = Optional.ofNullable(RequestContextHolder.getRequestAttributes())
                .filter(ServletRequestAttributes.class::isInstance)
                .map(ServletRequestAttributes.class::cast)
                .map(ServletRequestAttributes::getRequest)
                .orElse(null);
        if (request == null) {
            return clientRegistration;
        }
        final String query = request.getQueryString();
        if (query == null) {
            return clientRegistration;
        }
        final List<String> scopeQueryParam = parseQuery(query).get(OAuth2ParameterNames.SCOPE);
        if (scopeQueryParam == null) {
            return clientRegistration;
        }
        final Set<String> scopes = scopeQueryParam.stream()
                .flatMap((v) -> Arrays.stream(v.split(" ")))
                // Leading or consecutive spaces make split() emit empty tokens; an empty
                // string is not a scope and must not end up in the registration.
                .filter((scope) -> !scope.isEmpty())
                .collect(Collectors.toSet());
        // Nothing new requested (this also covers an empty scope set): keep the
        // configured instance instead of building an identical copy.
        if (clientRegistration.getScopes().containsAll(scopes)) {
            return clientRegistration;
        }
        final Set<String> resultingScopes = new HashSet<>(scopes);
        resultingScopes.addAll(clientRegistration.getScopes());
        return ClientRegistration.withClientRegistration(clientRegistration)
                .scope(resultingScopes)
                .build();
    };
}
/**
 * Parses a raw (still URL-encoded) query string into a multi-value map.
 *
 * <p>Pairs are separated by {@code &}; each pair is split at its <em>first</em> {@code =}
 * only, so a value that itself contains {@code =} is kept intact. Names and values are
 * URL-decoded as UTF-8. A parameter without a value ({@code "a"} or {@code "a="}) is
 * registered with no value added for it. Pairs with an empty name (for example the
 * empty pair in {@code "a=1&&b=2"}) are skipped.
 *
 * @param query the raw query string, must not be {@code null}
 * @return the decoded parameters; repeated names accumulate their values in order
 * @throws IllegalArgumentException if a name or value contains a malformed percent
 *         escape (raised by {@link URLDecoder#decode(String, java.nio.charset.Charset)})
 */
private static MultiValueMap<String, String> parseQuery(String query) {
    final MultiValueMap<String, String> result = new LinkedMultiValueMap<>();
    for (String pair : query.split("&")) {
        // Split on the first '=' only: '=' is a legal character inside a value, and
        // splitting on every '=' would silently truncate such a value.
        final int separator = pair.indexOf('=');
        final String rawName = (separator < 0) ? pair : pair.substring(0, separator);
        final String rawValue = (separator < 0) ? "" : pair.substring(separator + 1);
        if (rawName.isEmpty()) {
            // No parameter name, nothing meaningful to record.
            continue;
        }
        final List<String> values = result.computeIfAbsent(
                URLDecoder.decode(rawName, StandardCharsets.UTF_8), (k) -> new ArrayList<>());
        // An empty value registers the name without adding a value for it.
        if (!rawValue.isEmpty()) {
            values.add(URLDecoder.decode(rawValue, StandardCharsets.UTF_8));
        }
    }
    return result;
}