Here is a full implementation of a linear-probing hash map with two keys and a single value. It should outperform java.util.HashMap nicely as well.
Warning: it was written from scratch in the very early hours of the day, so it might contain bugs. Please feel free to edit it.
This approach should beat any wrapper-key or concatenated-key solution at any time. Because get/put do not allocate, it also makes a quick general-purpose map.
Hope this solves the issue. (The code comes with some simple tests, which are not needed in production.)
package bestsss.util;
/**
 * A hash map keyed by a pair {@code (key1, key2)} that maps each pair to a single value.
 *
 * <p>The map uses open addressing with linear probing and keeps both keys and the value inline in
 * one {@code Object[]}, so {@code get}/{@code put}/{@code remove} never allocate a composite key.
 * Slot {@code i} stores its hash in {@code hashes[i]} and its entry in {@code kvs[3*i]} (key1),
 * {@code kvs[3*i+1]} (key2) and {@code kvs[3*i+2]} (value). A slot whose key1 is {@code null} has
 * never been used; a slot whose value is {@code TOMBSTONE} held an entry that has been removed
 * (deletion marks the slot instead of shifting the cluster as in Knuth 6.4).
 *
 * <p>Keys must not be {@code null}: their {@code hashCode()} is called unconditionally. Passing a
 * {@code null} value to {@link #put} removes the mapping, so stored values are never {@code null}.
 * This class is not synchronized.
 *
 * @param <K1> type of the first key
 * @param <K2> type of the second key
 * @param <V> type of the value
 */
public class DoubleKeyMap<K1, K2, V> {
    private static final int MAX_CAPACITY = 1 << 29;
    /** Marker value of a removed entry; a distinct instance so it is only ever identity-compared. */
    private static final Object TOMBSTONE = new String("TOMBSTONE");

    /** Keys and values, three consecutive cells per slot. */
    Object[] kvs;
    /** Cached hash of each slot; its length is the capacity (a power of two). */
    int[] hashes;
    /** Number of live (non-removed) mappings. */
    int count = 0;
    /** Probe length after which an insert resizes the table, provided it is about half full. */
    final int rehashOnProbes;
    /**
     * Number of slots whose key1 is non-null (live entries plus tombstones). Kept at or below
     * {@code maxUsed(capacity)} so at least one never-used slot always ends every probe chain.
     */
    private int used = 0;

    /** Creates a map with an initial capacity of 8 that resizes after 5 probes. */
    public DoubleKeyMap() {
        this(8, 5);
    }

    /**
     * Creates an empty map.
     *
     * @param capacity requested initial capacity; rounded up to a power of two, at least 4
     * @param rehashOnProbes probe length on insert after which the table is grown (when it holds
     *     at least about half its capacity in live entries)
     * @throws IllegalArgumentException if {@code rehashOnProbes} exceeds the rounded capacity
     */
    public DoubleKeyMap(int capacity, int rehashOnProbes) {
        capacity = nextCapacity(Math.max(2, capacity - 1));
        if (rehashOnProbes > capacity) {
            throw new IllegalArgumentException("rehashOnProbes too high");
        }
        hashes = new int[capacity];
        kvs = new Object[kvsIndex(capacity)];
        count = 0;
        this.rehashOnProbes = rehashOnProbes;
    }

    /** Returns the power of two just above the highest set bit of {@code c}; fails if too large. */
    private static int nextCapacity(int c) {
        int n = Integer.highestOneBit(c) << 1;
        if (n < 0 || n > MAX_CAPACITY) {
            throw new Error("map too large");
        }
        return n;
    }

    /** Largest number of used slots allowed for {@code capacity} (75% load). */
    private static int maxUsed(int capacity) {
        return capacity - (capacity >>> 2);
    }

    // Alternatively this method could become non-static, protected and overridable; performance may
    // drop a little, but a better spread of the lowest bits could then be supplied.
    private static int hash(Object key1, Object key2) {
        int h1 = key1.hashCode();
        int h2 = key2.hashCode();
        return h1 + (h2 << 4) + h2; // h1 + h2*17
    }

    /** Maps a slot index to the index of its key1 cell in {@code kvs} (slot * 3). */
    private static int kvsIndex(int baseIdx) {
        return baseIdx + (baseIdx << 1);
    }

    /** Home slot of {@code hash} in the current table. */
    private int baseIdx(int hash) {
        return hash & (hashes.length - 1);
    }

    /** Key comparison used for lookups: identity first, then {@code equals}. */
    private static boolean sameKey(Object probe, Object stored) {
        return probe == stored || probe.equals(stored);
    }

    /** Returns the number of live mappings. */
    public int size() {
        return count;
    }

    /**
     * Returns the value mapped to {@code (key1, key2)}, or {@code null} if there is none.
     *
     * @throws NullPointerException if either key is {@code null}
     */
    @SuppressWarnings("unchecked")
    public V get(K1 key1, K2 key2) {
        final int hash = hash(key1, key2);
        final int[] hashes = this.hashes;
        final Object[] kvs = this.kvs;
        final int mask = hashes.length - 1;
        // Terminates: doPut guarantees the table always keeps at least one never-used slot.
        for (int base = baseIdx(hash);; base = (base + 1) & mask) {
            final int k = kvsIndex(base);
            final Object k1 = kvs[k];
            if (k1 == null) {
                return null; // end of the probe chain: no such mapping
            }
            if (hashes[base] != hash) {
                continue;
            }
            final Object value = kvs[k + 2];
            if (value == TOMBSTONE) {
                continue; // removed entry; the key may still live further along the chain
            }
            if (sameKey(key1, k1) && sameKey(key2, kvs[k + 1])) {
                return (V) value;
            }
        }
    }

    /** Returns {@code true} if a mapping for {@code (key1, key2)} exists. */
    public boolean contains(K1 key1, K2 key2) {
        return get(key1, key2) != null;
    }

    /**
     * Returns {@code true} if any live mapping has a value equal to {@code value}. Linear scan.
     * A {@code null} argument always yields {@code false} since null values are never stored.
     */
    public boolean containsValue(final V value) {
        if (value == null) {
            return false;
        }
        final Object[] kvs = this.kvs;
        // Value cells sit at offsets 2, 5, 8, ... (third cell of every slot).
        for (int i = 2; i < kvs.length; i += 3) {
            final Object v = kvs[i];
            if (v == null || v == TOMBSTONE) {
                continue;
            }
            if (value == v || value.equals(v)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Associates {@code value} with {@code (key1, key2)}; a {@code null} value removes the mapping.
     *
     * @return the previous value, or {@code null} if there was none
     * @throws NullPointerException if either key is {@code null}
     */
    public V put(K1 key1, K2 key2, V value) {
        int hash = hash(key1, key2);
        return doPut(key1, key2, value, hash);
    }

    /**
     * Removes the mapping for {@code (key1, key2)}.
     *
     * @return the removed value, or {@code null} if there was none
     */
    public V remove(K1 key1, K2 key2) {
        int hash = hash(key1, key2);
        return doPut(key1, key2, null, hash);
    }

    /**
     * Inserts, replaces or (when {@code value} is {@code null}) removes the mapping for the keys.
     *
     * <p>Removal marks the slot with a tombstone. The whole probe chain is scanned for an existing
     * entry before a tombstone is reused, so a key is never stored twice. Inserting into a
     * never-used slot may first grow the table (too many probes or too many used slots) or purge
     * tombstones in place, which keeps at least one never-used slot so lookups always terminate.
     *
     * @param hash the value of {@code hash(key1, key2)}
     * @return the previous value, or {@code null} if there was none
     */
    @SuppressWarnings("unchecked")
    protected V doPut(final K1 key1, final K2 key2, Object value, final int hash) {
        final int[] hashes = this.hashes;
        final Object[] kvs = this.kvs;
        final int mask = hashes.length - 1;
        int firstTombstone = -1; // first removed slot on the probe path, reusable for an insert
        for (int base = baseIdx(hash), probes = 0;; base = (base + 1) & mask, probes++) {
            final int k = kvsIndex(base);
            final Object k1 = kvs[k];
            if (k1 == null) {
                // End of the probe chain: the key is not present.
                if (value == null) {
                    return null; // remove() of an absent key: nothing to do
                }
                int slot = firstTombstone;
                if (slot < 0) {
                    // A never-used slot would be consumed; make room first if needed.
                    final boolean tooManyProbes = probes >= rehashOnProbes && count >= (mask >> 1);
                    if (tooManyProbes || used + 1 > maxUsed(hashes.length)) {
                        if (count >= (mask >> 1)) {
                            resize(); // genuinely full: double the capacity
                        } else {
                            rehash(hashes.length); // mostly tombstones: purge them in place
                        }
                        return doPut(key1, key2, value, hash); // retry on the new table
                    }
                    slot = base;
                    used++;
                }
                final int s = kvsIndex(slot);
                hashes[slot] = hash;
                kvs[s] = key1;
                kvs[s + 1] = key2;
                kvs[s + 2] = value;
                count++;
                return null;
            }
            final Object old = kvs[k + 2];
            if (old == TOMBSTONE) {
                if (firstTombstone < 0) {
                    firstTombstone = base;
                }
                continue;
            }
            if (hashes[base] == hash && sameKey(key1, k1) && sameKey(key2, kvs[k + 1])) {
                if (value == null) {
                    kvs[k + 2] = TOMBSTONE; // keys stay in place; the slot remains "used"
                    count--;
                } else {
                    kvs[k + 2] = value; // replacement: count unchanged
                }
                return (V) old;
            }
        }
    }

    /**
     * Stores an entry in the first never-used slot of its probe chain without any equality checks.
     * Only valid while rebuilding a table that contains no tombstones and no entry with these keys.
     */
    protected void doPutForResize(K1 key1, K2 key2, V value, final int hash) {
        final int[] hashes = this.hashes;
        final Object[] kvs = this.kvs;
        final int mask = hashes.length - 1;
        int base = baseIdx(hash);
        while (kvs[kvsIndex(base)] != null) {
            base = (base + 1) & mask;
        }
        final int k = kvsIndex(base);
        hashes[base] = hash;
        kvs[k] = key1;
        kvs[k + 1] = key2;
        kvs[k + 2] = value;
    }

    /** Doubles the capacity, dropping all tombstones. */
    protected void resize() {
        rehash(nextCapacity(hashes.length));
    }

    /**
     * Rebuilds the table at {@code capacity}, re-inserting only live entries (tombstones vanish).
     * Afterwards {@code count} and {@code used} both equal the number of live entries.
     */
    @SuppressWarnings("unchecked")
    private void rehash(final int capacity) {
        final int[] oldHashes = this.hashes;
        final Object[] oldKvs = this.kvs;
        this.hashes = new int[capacity];
        this.kvs = new Object[kvsIndex(capacity)];
        int live = 0;
        for (int i = 0; i < oldHashes.length; i++) {
            final int k = kvsIndex(i);
            final Object key1 = oldKvs[k];
            final Object value = oldKvs[k + 2];
            if (key1 != null && value != TOMBSTONE) {
                doPutForResize((K1) key1, (K2) oldKvs[k + 1], (V) value, oldHashes[i]);
                live++;
            }
        }
        this.count = live;
        this.used = live;
    }

    public static void main(String[] args) {
        DoubleKeyMap<String, String, Integer> map = new DoubleKeyMap<String, String, Integer>(4, 2);
        map.put("eur/usd", "usd/jpy", 1);
        map.put("eur/usd", "usd/jpy", 2);
        map.put("eur/jpy", "usd/jpy", 3);
        System.out.println(map.get("eur/jpy", "usd/jpy"));
        System.out.println(map.get("eur/usd", "usd/jpy"));
        System.out.println("======");
        map.remove("eur/usd", "usd/jpy");
        System.out.println(map.get("eur/jpy", "usd/jpy"));
        System.out.println(map.get("eur/usd", "usd/jpy"));
        System.out.println("======");
        testResize();
    }

    /** Inserts 14000 keys, removes every odd one, and verifies contents, sums and size. */
    static void testResize() {
        DoubleKeyMap<String, Integer, Integer> map = new DoubleKeyMap<String, Integer, Integer>(18, 17);
        long s = 0;
        String pref = "xxx";
        for (int i = 0; i < 14000; i++) {
            map.put(pref + i, i, i);
            if ((i & 1) == 1) {
                map.remove(pref + i, i);
            } else {
                s += i;
            }
        }
        System.out.println("sum: " + s);
        long sum = 0;
        for (int i = 0; i < 14000; i++) {
            Integer n = map.get(pref + i, i);
            if ((i & 1) == 1) {
                if (n != null) {
                    throw new AssertionError("removed entry still present: " + i);
                }
            } else if (n == null || n != i) {
                throw new AssertionError("wrong value for " + i + ": " + n);
            } else {
                sum += n;
            }
        }
        System.out.println("1st sum: " + s);
        System.out.println("2nd sum: " + sum);
        if (sum != s || map.size() != 7000) {
            throw new AssertionError("sum/size mismatch: " + sum + " vs " + s + ", size " + map.size());
        }
    }
}