The code implemented in KeycloakSecurityContext and other related classes was included to mitigate the vulnerability described in CVE-2020-1714:
A flaw was found in Keycloak, where the code base contains usages of ObjectInputStream without type checks. This flaw allows an attacker to inject arbitrarily serialized Java Objects, which would then get deserialized in a privileged context and potentially lead to remote code execution.
The solution addresses the above-mentioned vulnerability by implementing a custom ObjectInputFilter.
As indicated in the bug report, the idea behind the Java serialization filtering mechanism is to prevent deserialization vulnerabilities that could lead to remote code execution.
In many situations, such as when dealing with sessions, several objects (in this example, those stored in the session) are serialized and later deserialized together; in other words, they are written to the same ObjectOutputStream and then read back from the same ObjectInputStream.
When an ObjectInputFilter is applied, which can be done at several levels (process-wide, application, or a specific ObjectInputStream), only the objects that satisfy the configured filter pattern are deserialized; depending on the pattern, the rest are either rejected or the decision is delegated to the process-wide filter, if one exists.
Consider the following example, where A, B, C and D are classes and the filter A;D;!* is applied over a hypothetical object input stream containing instances of all four classes: only the instances of A and D satisfy the filter and are deserialized, while the instances of B and C are rejected.
This pattern can be defined in terms of modules, packages and/or individual classes, and can be set with the jdk.serialFilter system property or by editing the java.security properties file.
You can create custom filters as well. They are implemented using the API provided by ObjectInputFilter and allow more granular control over serialization, because they can be applied to one specific ObjectInputStream.
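For instance, on Java 9 or newer a stream-specific filter can be created from a pattern and attached directly through the standard API. The following is only an illustrative sketch (the class name, the serialized payload and the pattern are made up for the example); the DelegatingSerializationFilter utility discussed below essentially wraps this kind of filtering so that it stays portable across JDK versions:

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.ObjectInputFilter;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.List;

public class StreamFilterExample {

    public static void main(String[] args) throws Exception {
        // Serialize a simple java.util collection
        ByteArrayOutputStream bos = new ByteArrayOutputStream();
        try (ObjectOutputStream oos = new ObjectOutputStream(bos)) {
            List<String> data = new ArrayList<>();
            data.add("a");
            data.add("b");
            oos.writeObject(data);
        }

        // Deserialize it with a stream-specific filter attached before any object is read
        try (ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(bos.toByteArray()))) {
            // Allow the java.util and java.lang classes, reject everything else
            ois.setObjectInputFilter(ObjectInputFilter.Config.createFilter("java.util.*;java.lang.*;!*"));
            System.out.println(ois.readObject()); // prints [a, b]
        }
    }
}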
See the relevant Oracle serialization documentation for more information.
The Keycloak serialization filters use the utility class DelegatingSerializationFilter.
In the implementation provided, the filter is applied inside the readObject method of KeycloakSecurityContext:
private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
    DelegatingSerializationFilter.builder()
            .addAllowedClass(KeycloakSecurityContext.class)
            .setFilter(in);
    in.defaultReadObject();
    token = parseToken(tokenString, AccessToken.class);
    idToken = parseToken(idTokenString, IDToken.class);
}
As a consequence, for it to work properly, as you pointed out, it is necessary to set the jdk.serialSetFilterAfterRead system property to true when running your program.
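If you cannot modify the launch command, the same property can also be set programmatically (the test class at the end of this answer does exactly that), but it only takes effect if it is set before the first ObjectInputStream is used, so passing -Djdk.serialSetFilterAfterRead=true to the JVM is the safer option:

// Equivalent to launching the JVM with -Djdk.serialSetFilterAfterRead=true;
// it must run before java.io.ObjectInputStream is used for the first time.
System.setProperty("jdk.serialSetFilterAfterRead", "true");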
This filter will always be applied unless a previous, non-process-wide filter (see the section labeled Setting a Process-Wide Custom Filter in the Oracle documentation) has already been applied to the ObjectInputStream; this is guaranteed by the setObjectInputFilter method and is checked by the DelegatingSerializationFilter class as well:
private void setFilter(ObjectInputStream ois, String filterPattern) {
    LOG.debug("Using: " + serializationFilterAdapter.getClass().getSimpleName());
    if (serializationFilterAdapter.getObjectInputFilter(ois) == null) {
        serializationFilterAdapter.setObjectInputFilter(ois, filterPattern);
    }
}
In other words, the only way to prevent this custom filter from being applied is to first provide an initial filter over the target ObjectInputStream, in this case the one that contains your session data. As indicated in the docs:
The filter mechanism is called for each new object in the stream. If more than one active filter (process-wide filter, application filter, or stream-specific filter) exists, only the most specific filter is called.
How this initial filter should be created depends heavily on the code that actually performs the deserialization.
In your use case, you are probably using memcached-session-manager, either the original version or one of the more up-to-date forks on GitHub.
In a normal use case, the sessions in memcached-session-manager are mainly handled by the code defined in MemcachedSessionService.
This class uses TranscoderService to handle the Java serialization.
TranscoderService in turn delegates that responsibility to appropriate implementations of TranscoderFactory and SessionAttributesTranscoder.
JavaSerializationTranscoderFactory and the associated class JavaSerializationTranscoder are the default implementations of these interfaces.
Pay attention to the deserializeAttributes method of JavaSerializationTranscoder; it defines the logic for session deserialization:
/**
* Get the object represented by the given serialized bytes.
*
* @param in
* the bytes to deserialize
* @return the resulting object
*/
@Override
public ConcurrentMap<String, Object> deserializeAttributes(final byte[] in ) {
ByteArrayInputStream bis = null;
ObjectInputStream ois = null;
try {
bis = new ByteArrayInputStream( in );
ois = createObjectInputStream( bis );
final ConcurrentMap<String, Object> attributes = new ConcurrentHashMap<String, Object>();
final int n = ( (Integer) ois.readObject() ).intValue();
for ( int i = 0; i < n; i++ ) {
final String name = (String) ois.readObject();
final Object value = ois.readObject();
if ( ( value instanceof String ) && ( value.equals( NOT_SERIALIZED ) ) ) {
continue;
}
if ( LOG.isDebugEnabled() ) {
LOG.debug( " loading attribute '" + name + "' with value '" + value + "'" );
}
attributes.put( name, value );
}
return attributes;
} catch ( final ClassNotFoundException e ) {
LOG.warn( "Caught CNFE decoding "+ in.length +" bytes of data", e );
throw new TranscoderDeserializationException( "Caught CNFE decoding data", e );
} catch ( final IOException e ) {
LOG.warn( "Caught IOException decoding "+ in.length +" bytes of data", e );
throw new TranscoderDeserializationException( "Caught IOException decoding data", e );
} finally {
closeSilently( bis );
closeSilently( ois );
}
}
As you can see, the problem is that the session information, represented by the input byte array, can contain several attributes, all of which are deserialized from the same ObjectInputStream. Once the Keycloak ObjectInputFilter is applied to this ObjectInputStream, as you indicated, it rejects every remaining class that is not allowed by the filter. The reason is that DelegatingSerializationFilter appends a final !* to the filter pattern being constructed, excluding everything except the explicitly provided classes and text-based patterns (plus the java.util.* classes, to allow collections).
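Illustratively, and assuming KeycloakSecurityContext is the only explicitly registered class (the exact string built by DelegatingSerializationFilter may differ), the effective pattern behaves roughly like:

org.keycloak.KeycloakSecurityContext;java.util.*;!*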
To avoid this problem, try providing your own implementation of SessionAttributesTranscoder, including a method similar to deserializeAttributes but one that defines an initial filter over the constructed ObjectInputStream.
For example (forgive me for defining the whole class; you can probably reuse the code of JavaSerializationTranscoder in some way):
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.NotSerializableException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import org.apache.catalina.util.CustomObjectInputStream;
import org.apache.juli.logging.Log;
import org.apache.juli.logging.LogFactory;
// msm classes and the Keycloak filter helper; adjust the packages below if your library versions place them elsewhere
import de.javakaffee.web.msm.MemcachedBackupSession;
import de.javakaffee.web.msm.MemcachedSessionService.SessionManager;
import de.javakaffee.web.msm.SessionAttributesTranscoder;
import de.javakaffee.web.msm.TranscoderDeserializationException;
import org.keycloak.common.util.DelegatingSerializationFilter;
public class CustomJavaSerializationTranscoder implements SessionAttributesTranscoder {
private static final Log LOG = LogFactory.getLog( CustomJavaSerializationTranscoder.class );
private static final String EMPTY_ARRAY[] = new String[0];
/**
* The dummy attribute value serialized when a NotSerializableException is
* encountered in <code>writeObject()</code>.
*/
protected static final String NOT_SERIALIZED = "___NOT_SERIALIZABLE_EXCEPTION___";
private final SessionManager _manager;
/**
* Constructor without a session manager.
*/
public CustomJavaSerializationTranscoder() {
this( null );
}
/**
* Constructor.
*
* @param manager
* the manager
*/
public CustomJavaSerializationTranscoder( final SessionManager manager ) {
_manager = manager;
}
/**
* {@inheritDoc}
*/
@Override
public byte[] serializeAttributes( final MemcachedBackupSession session, final ConcurrentMap<String, Object> attributes ) {
if ( attributes == null ) {
throw new NullPointerException( "Can't serialize null" );
}
ByteArrayOutputStream bos = null;
ObjectOutputStream oos = null;
try {
bos = new ByteArrayOutputStream();
oos = new ObjectOutputStream( bos );
writeAttributes( session, attributes, oos );
return bos.toByteArray();
} catch ( final IOException e ) {
throw new IllegalArgumentException( "Non-serializable object", e );
} finally {
closeSilently( bos );
closeSilently( oos );
}
}
private void writeAttributes( final MemcachedBackupSession session, final Map<String, Object> attributes,
final ObjectOutputStream oos ) throws IOException {
// Accumulate the names of serializable and non-serializable attributes
final String keys[] = attributes.keySet().toArray( EMPTY_ARRAY );
final List<String> saveNames = new ArrayList<String>();
final List<Object> saveValues = new ArrayList<Object>();
for ( int i = 0; i < keys.length; i++ ) {
final Object value = attributes.get( keys[i] );
if ( value == null || session.exclude( keys[i], value ) ) {
continue;
} else if ( value instanceof Serializable ) {
saveNames.add( keys[i] );
saveValues.add( value );
} else {
if ( LOG.isDebugEnabled() ) {
LOG.debug( "Ignoring attribute '" + keys[i] + "' as it does not implement Serializable" );
}
}
}
// Serialize the attribute count and the Serializable attributes
final int n = saveNames.size();
oos.writeObject( Integer.valueOf( n ) );
for ( int i = 0; i < n; i++ ) {
oos.writeObject( saveNames.get( i ) );
try {
oos.writeObject( saveValues.get( i ) );
if ( LOG.isDebugEnabled() ) {
LOG.debug( " storing attribute '" + saveNames.get( i ) + "' with value '" + saveValues.get( i ) + "'" );
}
} catch ( final NotSerializableException e ) {
LOG.warn( _manager.getString( "standardSession.notSerializable", saveNames.get( i ), session.getIdInternal() ), e );
oos.writeObject( NOT_SERIALIZED );
if ( LOG.isDebugEnabled() ) {
LOG.debug( " storing attribute '" + saveNames.get( i ) + "' with value NOT_SERIALIZED" );
}
}
}
}
/**
* Get the object represented by the given serialized bytes.
*
* @param in
* the bytes to deserialize
* @return the resulting object
*/
@Override
public ConcurrentMap<String, Object> deserializeAttributes(final byte[] in ) {
ByteArrayInputStream bis = null;
ObjectInputStream ois = null;
try {
bis = new ByteArrayInputStream( in );
ois = createObjectInputStream( bis );
// Fix deserialization
fixDeserialization(ois);
final ConcurrentMap<String, Object> attributes = new ConcurrentHashMap<String, Object>();
final int n = ( (Integer) ois.readObject() ).intValue();
for ( int i = 0; i < n; i++ ) {
final String name = (String) ois.readObject();
final Object value = ois.readObject();
if ( ( value instanceof String ) && ( value.equals( NOT_SERIALIZED ) ) ) {
continue;
}
if ( LOG.isDebugEnabled() ) {
LOG.debug( " loading attribute '" + name + "' with value '" + value + "'" );
}
attributes.put( name, value );
}
return attributes;
} catch ( final ClassNotFoundException e ) {
LOG.warn( "Caught CNFE decoding "+ in.length +" bytes of data", e );
throw new TranscoderDeserializationException( "Caught CNFE decoding data", e );
} catch ( final IOException e ) {
LOG.warn( "Caught IOException decoding "+ in.length +" bytes of data", e );
throw new TranscoderDeserializationException( "Caught IOException decoding data", e );
} finally {
closeSilently( bis );
closeSilently( ois );
}
}
private ObjectInputStream createObjectInputStream( final ByteArrayInputStream bis ) throws IOException {
final ObjectInputStream ois;
ClassLoader classLoader = null;
if ( _manager != null && _manager.getContext() != null ) {
classLoader = _manager.getContainerClassLoader();
}
if ( classLoader != null ) {
ois = new CustomObjectInputStream( bis, classLoader );
} else {
ois = new ObjectInputStream( bis );
}
return ois;
}
private void closeSilently( final OutputStream os ) {
if ( os != null ) {
try {
os.close();
} catch ( final IOException f ) {
// fail silently
}
}
}
private void closeSilently( final InputStream is ) {
if ( is != null ) {
try {
is.close();
} catch ( final IOException f ) {
// fail silently
}
}
}
// Helper method reusing the `DelegatingSerializationFilter` class, which is convenient
// because of its portability across JDK versions, to define an allow-everything pattern.
// It should probably be improved to restrict the allowed patterns in order to
// prevent security vulnerabilities.
private void fixDeserialization(ObjectInputStream ois) {
DelegatingSerializationFilter.builder()
.addAllowedPattern("*")
.setFilter(ois);
}
}
Now define a custom TranscoderFactory; let's extend the class JavaSerializationTranscoderFactory this time:
// Adjust the msm package imports below if your library version places these classes elsewhere
import de.javakaffee.web.msm.JavaSerializationTranscoderFactory;
import de.javakaffee.web.msm.MemcachedSessionService.SessionManager;
import de.javakaffee.web.msm.SessionAttributesTranscoder;
public class CustomJavaSerializationTranscoderFactory extends JavaSerializationTranscoderFactory {
/**
* {@inheritDoc}
*/
@Override
public SessionAttributesTranscoder createTranscoder( final SessionManager manager ) {
return new CustomJavaSerializationTranscoder( manager );
}
}
Place these classes on your classpath, together with the rest of the memcached-session-manager libraries, and provide a suitable value for the transcoderFactoryClass memcached-session-manager configuration property, as indicated in the docs:
The class name of the factory that creates the transcoder to use for serializing/deserializing sessions to/from memcached. The specified class must implement de.javakaffee.web.msm.TranscoderFactory and provide a no-args constructor.
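For example, assuming the usual context.xml based setup (the memcached node address and the factory's package are placeholders you need to adapt), the Manager element could look roughly like this:

<Manager className="de.javakaffee.web.msm.MemcachedBackupSessionManager"
    memcachedNodes="n1:localhost:11211"
    transcoderFactoryClass="com.example.CustomJavaSerializationTranscoderFactory" />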
I do not have the ability to test the solution in your environment, although a simple test seems to work properly:
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.NotSerializableException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
// Keycloak's filter helper; adjust the package if your Keycloak version places it elsewhere
import org.keycloak.common.util.DelegatingSerializationFilter;
public class SerializationTest {
private static final Log LOG = LogFactory.getLog( SerializationTest.class );
protected static final String NOT_SERIALIZED = "___NOT_SERIALIZABLE_EXCEPTION___";
private static final String EMPTY_ARRAY[] = new String[0];
public byte[] serializeAttributes( final ConcurrentMap<String, Object> attributes ) {
if ( attributes == null ) {
throw new NullPointerException( "Can't serialize null" );
}
ByteArrayOutputStream bos = null;
ObjectOutputStream oos = null;
try {
bos = new ByteArrayOutputStream();
oos = new ObjectOutputStream( bos );
writeAttributes( attributes, oos );
return bos.toByteArray();
} catch ( final IOException e ) {
throw new IllegalArgumentException( "Non-serializable object", e );
} finally {
closeSilently( bos );
closeSilently( oos );
}
}
private void writeAttributes(final Map<String, Object> attributes, final ObjectOutputStream oos ) throws IOException {
// Accumulate the names of serializable and non-serializable attributes
final String keys[] = attributes.keySet().toArray( EMPTY_ARRAY );
final List<String> saveNames = new ArrayList<String>();
final List<Object> saveValues = new ArrayList<Object>();
for ( int i = 0; i < keys.length; i++ ) {
final Object value = attributes.get( keys[i] );
if ( value == null ) {
continue;
} else if ( value instanceof Serializable) {
saveNames.add( keys[i] );
saveValues.add( value );
} else {
if ( LOG.isDebugEnabled() ) {
LOG.debug( "Ignoring attribute '" + keys[i] + "' as it does not implement Serializable" );
}
}
}
// Serialize the attribute count and the Serializable attributes
final int n = saveNames.size();
oos.writeObject( Integer.valueOf( n ) );
for ( int i = 0; i < n; i++ ) {
oos.writeObject( saveNames.get( i ) );
try {
oos.writeObject( saveValues.get( i ) );
if ( LOG.isDebugEnabled() ) {
LOG.debug( " storing attribute '" + saveNames.get( i ) + "' with value '" + saveValues.get( i ) + "'" );
}
} catch ( final NotSerializableException e ) {
LOG.warn( "standardSession.notSerializable" + saveNames.get( i ), e );
oos.writeObject( NOT_SERIALIZED );
if ( LOG.isDebugEnabled() ) {
LOG.debug( " storing attribute '" + saveNames.get( i ) + "' with value NOT_SERIALIZED" );
}
}
}
}
public ConcurrentMap<String, Object> deserializeAttributes(final byte[] in ) {
ByteArrayInputStream bis = null;
ObjectInputStream ois = null;
try {
bis = new ByteArrayInputStream( in );
ois = new ObjectInputStream( bis );
fixDeserialization(ois);
final ConcurrentMap<String, Object> attributes = new ConcurrentHashMap<String, Object>();
final int n = ( (Integer) ois.readObject() ).intValue();
for ( int i = 0; i < n; i++ ) {
final String name = (String) ois.readObject();
final Object value = ois.readObject();
if ( ( value instanceof String ) && ( value.equals( NOT_SERIALIZED ) ) ) {
continue;
}
if ( LOG.isDebugEnabled() ) {
LOG.debug( " loading attribute '" + name + "' with value '" + value + "'" );
}
attributes.put( name, value );
}
return attributes;
} catch ( final ClassNotFoundException e ) {
LOG.warn( "Caught CNFE decoding "+ in.length +" bytes of data", e );
throw new RuntimeException( "Caught CNFE decoding data", e );
} catch ( final IOException e ) {
LOG.warn( "Caught IOException decoding "+ in.length +" bytes of data", e );
throw new RuntimeException( "Caught IOException decoding data", e );
} finally {
closeSilently( bis );
closeSilently( ois );
}
}
private void fixDeserialization(ObjectInputStream ois) {
DelegatingSerializationFilter.builder()
.addAllowedPattern("*")
.setFilter(ois);
}
private void closeSilently( final OutputStream os ) {
if ( os != null ) {
try {
os.close();
} catch ( final IOException f ) {
// fail silently
}
}
}
private void closeSilently( final InputStream is ) {
if ( is != null ) {
try {
is.close();
} catch ( final IOException f ) {
// fail silently
}
}
}
public static void main(String[] args) throws Exception{
Person person = new Person("Sherlock Holmes","Consulting detective");
Address address = new Address("221B Baker Street");
ConcurrentMap<String, Object> attributes = new ConcurrentHashMap<String, Object>();
attributes.put("person", person);
attributes.put("address", address);
SerializationTest test = new SerializationTest();
byte[] in = test.serializeAttributes(attributes);
System.setProperty("jdk.serialSetFilterAfterRead", "true");
ConcurrentMap<String, Object> attributesAfter = test.deserializeAttributes(in);
System.out.println(attributesAfter);
}
}
Person and Address are two simple POJOs. Pay attention to the readObject method defined in Person, which registers its own filter just as KeycloakSecurityContext does:
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
// Keycloak's filter helper, used below to mimic the KeycloakSecurityContext behaviour
import org.keycloak.common.util.DelegatingSerializationFilter;
public class Person implements Serializable {
private static final long serialVersionUID = 1L;
String name;
String title;
public Person() {
}
public Person(String name, String title) {
this.name = name;
this.title = title;
}
private void writeObject(ObjectOutputStream out) throws IOException {
out.defaultWriteObject();
}
private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
DelegatingSerializationFilter.builder()
.addAllowedClass(Person.class)
.setFilter(in);
in.defaultReadObject();
}
@Override
public String toString() {
return "Person{" +
"name='" + name + '\'' +
", title='" + title + '\'' +
'}';
}
}
import java.io.Serializable;
public class Address implements Serializable {
private static final long serialVersionUID = 1L;
String address;
public Address() {
}
public Address(String address) {
this.address = address;
}
@Override
public String toString() {
return "Address{" +
"address='" + address + '\'' +
'}';
}
}
Please be aware that the filter is there to prevent security flaws: as also suggested in the code comments, it would be advisable to improve the logic implemented in fixDeserialization so that it restricts deserialization to the classes your session is actually supposed to contain, instead of using the wildcard * as in the example.
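For instance, an untested sketch of a more restrictive variant could look like the following, where com.example.myapp is just a placeholder for the packages of the attributes your application actually stores in the session:

private void fixDeserialization(ObjectInputStream ois) {
    DelegatingSerializationFilter.builder()
            // java.lang.* covers the Integer attribute count written by serializeAttributes
            .addAllowedPattern("java.lang.*")
            // the Keycloak security context stored in the session
            // (the java.util collections are already allowed by DelegatingSerializationFilter itself)
            .addAllowedClass(org.keycloak.KeycloakSecurityContext.class)
            // placeholder: replace with the packages of your own session attribute classes,
            // and add any other adapter classes your session may contain
            .addAllowedPattern("com.example.myapp.*")
            .setFilter(ois);
}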
In fact, this functionality could be included in the memcached-session-manager library itself, probably by defining some kind of configuration property, serialFilter for instance, whose value, a valid filter pattern, would be passed to the underlying Java deserialization mechanisms described above.
I created a fork of the project: it is still a WIP, but please see this commit; I hope you get the idea. I will try to open a pull request against the upstream repository once it is finished.
KeycloakSecurityContext is not final/sealed by default, so why not override its readObject method to translate addAllowedClass into addAllowedPattern with com.whalin.MemCached.* / org.keycloak.* / [your classes]*? – Jene