We have an app that the Python module will write data to redis shards and the Java module will read data from redis shards, so I need to implement the exact same consistent hashing algorithm for Java and Python to make sure the data can be found.
I googled around and tried several implementations, but found the Java and Python implementations are always different, can't be used togather. Need your help.
Edit, online implementations I have tried:
Java: http://weblogs.java.net/blog/tomwhite/archive/2007/11/consistent_hash.html
Python: http://techspot.zzzeek.org/2012/07/07/the-absolutely-simplest-consistent-hashing-example/
http://amix.dk/blog/post/19367
Edit, attached Java (Google Guava lib used) and Python code I wrote. Code are based on the above articles.
import java.util.Collection;
import java.util.SortedMap;
import java.util.TreeMap;
import com.google.common.hash.HashFunction;
public class ConsistentHash<T> {
private final HashFunction hashFunction;
private final int numberOfReplicas;
private final SortedMap<Long, T> circle = new TreeMap<Long, T>();
public ConsistentHash(HashFunction hashFunction, int numberOfReplicas,
Collection<T> nodes) {
this.hashFunction = hashFunction;
this.numberOfReplicas = numberOfReplicas;
for (T node : nodes) {
add(node);
}
}
public void add(T node) {
for (int i = 0; i < numberOfReplicas; i++) {
circle.put(hashFunction.hashString(node.toString() + i).asLong(),
node);
}
}
public void remove(T node) {
for (int i = 0; i < numberOfReplicas; i++) {
circle.remove(hashFunction.hashString(node.toString() + i).asLong());
}
}
public T get(Object key) {
if (circle.isEmpty()) {
return null;
}
long hash = hashFunction.hashString(key.toString()).asLong();
if (!circle.containsKey(hash)) {
SortedMap<Long, T> tailMap = circle.tailMap(hash);
hash = tailMap.isEmpty() ? circle.firstKey() : tailMap.firstKey();
}
return circle.get(hash);
}
}
Test code:
ArrayList<String> al = new ArrayList<String>();
al.add("redis1");
al.add("redis2");
al.add("redis3");
al.add("redis4");
String[] userIds =
{"-84942321036308",
"-76029520310209",
"-68343931116147",
"-54921760962352"
};
HashFunction hf = Hashing.md5();
ConsistentHash<String> consistentHash = new ConsistentHash<String>(hf, 100, al);
for (String userId : userIds) {
System.out.println(consistentHash.get(userId));
}
Python code:
import bisect
import md5
class ConsistentHashRing(object):
"""Implement a consistent hashing ring."""
def __init__(self, replicas=100):
"""Create a new ConsistentHashRing.
:param replicas: number of replicas.
"""
self.replicas = replicas
self._keys = []
self._nodes = {}
def _hash(self, key):
"""Given a string key, return a hash value."""
return long(md5.md5(key).hexdigest(), 16)
def _repl_iterator(self, nodename):
"""Given a node name, return an iterable of replica hashes."""
return (self._hash("%s%s" % (nodename, i))
for i in xrange(self.replicas))
def __setitem__(self, nodename, node):
"""Add a node, given its name.
The given nodename is hashed
among the number of replicas.
"""
for hash_ in self._repl_iterator(nodename):
if hash_ in self._nodes:
raise ValueError("Node name %r is "
"already present" % nodename)
self._nodes[hash_] = node
bisect.insort(self._keys, hash_)
def __delitem__(self, nodename):
"""Remove a node, given its name."""
for hash_ in self._repl_iterator(nodename):
# will raise KeyError for nonexistent node name
del self._nodes[hash_]
index = bisect.bisect_left(self._keys, hash_)
del self._keys[index]
def __getitem__(self, key):
"""Return a node, given a key.
The node replica with a hash value nearest
but not less than that of the given
name is returned. If the hash of the
given name is greater than the greatest
hash, returns the lowest hashed node.
"""
hash_ = self._hash(key)
start = bisect.bisect(self._keys, hash_)
if start == len(self._keys):
start = 0
return self._nodes[self._keys[start]]
Test code:
import ConsistentHashRing
if __name__ == '__main__':
server_infos = ["redis1", "redis2", "redis3", "redis4"];
hash_ring = ConsistentHashRing()
test_keys = ["-84942321036308",
"-76029520310209",
"-68343931116147",
"-54921760962352",
"-53401599829545"
];
for server in server_infos:
hash_ring[server] = server
for key in test_keys:
print str(hash_ring[key])