1 module mutils.container.hash_set;
2 
3 import core.bitop;
4 import core.simd: ushort8;
5 import std.meta;
6 import std.stdio;
7 import std.traits;
8 
9 import mutils.benchmark;
10 import mutils.container.vector;
11 import mutils.traits;
12 
// firstSetBit(i): 1-based position of the least significant set bit of i,
// e.g. firstSetBit(1)==1, firstSetBit(0b1000)==4. The result is undefined for i==0.
// Both compiler branches must return the same value; the DMD branch previously aliased
// bsr (most significant bit), which disagreed with the LDC branch.
version(DigitalMars){
	import core.bitop;
	int firstSetBit(int i){
		return bsf(cast(uint)i)+1;// bsf is treated as an intrinsic by DMD
	}
}else version(LDC){
	import ldc.intrinsics;
	int firstSetBit(int i){
		return llvm_cttz(i, true)+1;// true: zero input is undefined, matching bsf
	}
}else{
	static assert(0, "Compiler not supported.");// a bare string literal is always true and would never fire
}
24 
// Control word bits (used by struct Control and HashSet):
// emptyMask     - bit 0; SET means the slot holds an element (despite the name), clear means it is free
// neverUsedMask - control value of a slot that was never filled (bit 0 clear, so the slot is free)
// hashMask      - bits 1..15; the part of the hash stored in a control word
enum ushort emptyMask=1;
enum ushort neverUsedMask=2;
enum ushort hashMask=cast(ushort)~emptyMask;// ~ promotes to int (-2); the cast narrows it back to 0xFFFE
// Control word kept for every slot of a group:
//   bit 0 (emptyMask)     - 1 when the slot holds an element, 0 when it is free
//   bits 1..15 (hashMask) - hash bits of the stored element, compared 8 slots at a time
// A default-constructed Control equals neverUsedMask: free and never filled.
struct Control{
	nothrow @nogc @safe:

	ushort b=neverUsedMask;

	/// True when the slot is free, i.e. the occupied bit (emptyMask) is clear.
	bool isEmpty(){
		return !(b & emptyMask);
	}

	/// Marks the slot as occupied and stores the hashMask bits of the first 16-bit word of `hash`.
	void set(size_t hash){
		union Words{
			size_t whole;
			ushort[size_t.sizeof/2] parts;
		}
		immutable ushort firstWord=Words(hash).parts[0];
		b=cast(ushort)(emptyMask | (firstWord & hashMask));
	}
}
60 
// Hash helper struct.
// A 64-bit hash is used in two ways:
//   H1 - the whole value; HashSet masks it with (groups.length-1) to pick the starting group
//   H2 - the first 16-bit word (d[0]) with bit 0 adjusted; stored in Control and compared
//        8 slots at a time with SIMD to find candidates inside a group
struct Hash{
	nothrow @nogc @safe:
	union{
		size_t h=void;
		ushort[size_t.sizeof/2] d=void;
	}

	this(size_t hash){
		h=hash;
	}

	/// The hash with bits 1..15 of its first 16-bit word cleared (bit 0 of that word is kept).
	size_t getH1(){
		Hash copy=Hash(h);
		copy.d[0]&=emptyMask;// clear H2 bits
		return copy.h;
	}

	/// First 16-bit word with bit 0 cleared.
	ushort getH2(){
		return cast(ushort)(hashMask & d[0]);
	}

	/// First 16-bit word with bit 0 set: the control value Control.set stores for this hash.
	ushort getH2WithLastSet(){
		return cast(ushort)(emptyMask | d[0]);
	}
}
90 
// Default hash used by HashSet.
// Integral values go straight into hashInt; every other type is hashed with druntime's hashOf first.
// hashInt is applied in both cases because hashOf alone does not give a proper distribution
// between the H1 and H2 hash parts.
size_t defaultHashFunc(T)(auto ref T t){
	static if (isIntegral!(T)){
		ulong seed=t;
	}else{
		ulong seed=t.hashOf;
	}
	return hashInt(seed);
}
98 
// Can turn bad hash function to good one:
// this is the SplitMix64 finalizer - two xor-shift/multiply rounds and a final xor-shift
// that mix the bits of x across the whole 64-bit result.
ulong hashInt(ulong x) nothrow @nogc @safe {
	enum ulong mulA=0xbf58476d1ce4e5b9;
	enum ulong mulB=0x94d049bb133111eb;
	ulong z=x ^ (x >> 30);
	z*=mulA;
	z^=z >> 27;
	z*=mulB;
	return z ^ (z >> 31);
}
106 
// Open-addressing hash set in the style of SwissTable.
// Elements live in groups of 8 slots. Each slot has a Control word, and the 8 words of a group are
// compared with one SIMD operation (matchSIMD) to find candidate slots.
// Collisions are resolved by linear probing over whole groups; groups.length is always a power of two.
//
// T        - element type
// hashFunc - hash function for T
// ADV      - optional additional value type (at most one); used to implement HashMap without unnecessary copies
struct HashSet(T, alias hashFunc=defaultHashFunc, ADV...){
	static assert(ADV.length<=1);// ADV is treated as a optional additional value type
	static assert(size_t.sizeof==8);// Only 64 bit
	enum hasValue=ADV.length==1;
	enum rehashFactor=0.85;
	enum size_t getIndexEmptyValue=size_t.max;
	// Control value of a removed slot (tombstone). add() may reuse it, but unlike neverUsedMask it
	// does not end a lookup. It is even, so it never equals the (odd) control of an occupied slot.
	enum ushort deletedControl=0;

	static struct Group{
		union{
			Control[8] control;
			ushort8 controlVec;
		}
		T[8] elements;
		static if(hasValue)ADV[0][8] values;

		// Prevent error in Vector!Group
		bool opEquals()(auto ref const Group r) const {
			assert(0);
		}
	}

	Vector!Group groups;// Length should be always power of 2
	size_t addedElements;// Number of elements currently stored
	size_t deletedElements;// Number of tombstone slots; they lengthen probe chains, so they count towards the load

	/// Fraction of slots that would be used by `forElementsNum` elements (1 when nothing is allocated).
	float getLoadFactor(size_t forElementsNum) {
		if(groups.length==0){
			return 1;
		}
		return cast(float)forElementsNum/(groups.length*8);
	}

	/// Rebuilds the table and re-inserts every element (and value).
	/// Tombstones are dropped. The number of groups is doubled when the live elements alone exceed
	/// half of rehashFactor; otherwise the table is rebuilt in place. Growing at half (not at the full
	/// factor) stops remove/add churn near the limit from rebuilding the table on every add.
	void rehash() {
		mixin(doNotInline);
		// Get all elements
		Vector!T allElements;
		allElements.reserve(addedElements);
		static if(hasValue){
			Vector!(ADV[0]) allValues;
			allValues.reserve(addedElements);
			foreach(ref Control c, ref T el, ref ADV[0] val; this){
				allElements~=el;
				allValues~=val;
			}
		}else{
			foreach(ref Control c, ref T el; this){
				allElements~=el;
			}
		}

		// Reset every slot, including tombstones, which the foreach above does not visit
		foreach(ref Group gr; groups){
			gr.control[]=Control.init;
		}

		if(getLoadFactor(addedElements+1)>rehashFactor/2){// Reallocate
			groups.length=(groups.length?groups.length:1)<<1;// Power of two
		}
		// add() counts the re-inserted elements again
		addedElements=0;
		deletedElements=0;

		// Insert elements
		foreach(i, el;allElements){
			static if(hasValue){
				add(el, allValues[i]);
			}else{
				add(el);
			}
		}
		allElements.clear();
		static if(hasValue)allValues.clear();
	}

	/// Number of elements in the set.
	size_t length(){
		return addedElements;
	}

	/// Removes `el` when present. Returns true if an element was removed.
	bool tryRemove(T el) {
		size_t index=getIndex(el);
		if(index==getIndexEmptyValue){
			return false;
		}
		addedElements--;
		Group* gr=&groups[index/8];
		size_t elIndex=index%8;
		// A group that still has a never-used slot has never been full since the last rehash, so no
		// element was pushed past it and no lookup probes beyond it: the slot may become never-used.
		// Otherwise elements from earlier groups may live further on; a tombstone keeps their chain intact.
		if(matchSIMD(gr.controlVec, neverUsedMask)!=0){
			gr.control[elIndex]=Control.init;
		}else{
			gr.control[elIndex]=Control(deletedControl);
			deletedElements++;
		}
		//TODO value destructor
		return true;
	}

	/// Removes `el`, which must be present.
	void remove(T el) {
		bool removed=tryRemove(el);// Kept outside assert: asserts are compiled out with -release
		assert(removed);
	}

	/// Inserts `el` (with `value` when ADV is given). Does nothing if an equal element is already present.
	void add(T el, ADV value){
		if(isIn(el)){
			return;
		}

		// Tombstones occupy probe positions too, so they count towards the load
		if(getLoadFactor(addedElements+deletedElements+1)>rehashFactor){
			assumeNoGC(&rehash)();// rehash is @nogc but compiler cannot deduce that because rehash calls add internally
		}
		addedElements++;
		Hash hash=Hash(hashFunc(el));
		size_t mask=groups.length-1;
		size_t group=hashMod(hash.h);// Starting point
		while(true){
			Group* gr=&groups[group];
			foreach(i, ref Control c; gr.control){
				if(c.isEmpty){
					if(c.b==deletedControl){
						deletedElements--;// Reusing a tombstone
					}
					c.set(hash.h);
					gr.elements[i]=el;
					static if(hasValue)gr.values[i]=value[0];
					return;
				}
			}
			group=(group+1) & mask;
		}
	}

	// Sets bits in ushort where value in control matches check value
	// Every matching 16-bit lane sets two adjacent bits (2*lane and 2*lane+1)
	// Ex. control=[0,1,2,3,4,5,6,7], check=2, return=0b0000_0000_0011_0000
	static auto matchSIMD(ushort8 control, ushort check) @nogc {
		ushort8 v=ushort8(check);
		version(DigitalMars){
			import core.simd: __simd, ubyte16, XMM;
			ubyte16 ok=__simd(XMM.PCMPEQW, control, v);
			ubyte16 bitsMask=[1,2,4,8,16,32,64,128,1,2,4,8,16,32,64,128];
			ubyte16 bits=bitsMask&ok;
			ubyte16 zeros=0;
			ushort8 vv=__simd(XMM.PSADBW, bits, zeros);
			ushort num=cast(ushort)(vv[0]+vv[4]*256);
		}else version(LDC){
			import ldc.simd;
			import ldc.gccbuiltins_x86;
			ushort8 ok = equalMask!ushort8(control, v);
			ushort num=cast(ushort)__builtin_ia32_pmovmskb128(ok);
		}else{
			static assert(0);
		}
		return num;
	}

	// Maps a hash to a group index. groups.length is a power of two, so a mask replaces the expensive modulo
	int hashMod(size_t hash) nothrow @nogc @system{
		return cast(int)(hash & (groups.length-1));
	}

	/// True when an element equal to `el` is stored.
	bool isIn(T el){
		return getIndex(el)!=getIndexEmptyValue;
	}

	/// Overload for rvalues; `el` is an lvalue here, so the call forwards to the ref overload.
	size_t getIndex(T el) {
		return getIndex(el);
	}

	/// Returns the slot index (group*8 + position in group) of `el`, or getIndexEmptyValue when absent.
	size_t getIndex(ref T el) {
		//mixin(doNotInline);
		size_t groupsLength=groups.length;
		if(groupsLength==0){
			return getIndexEmptyValue;
		}

		Hash hash=Hash(hashFunc(el));
		ushort h2=hash.getH2WithLastSet;// Control value of an occupied slot with this hash
		size_t mask=groupsLength-1;
		size_t group=hash.h & mask;// Starting point
		while( true ){
			Group* gr=&groups[group];
			uint candidates=matchSIMD(gr.controlVec, h2);// Compare 8 controls at once to h2
			while( candidates!=0 ){
				// Take the lowest matching lane; each lane owns two adjacent bits of the mask
				int i=bsf(candidates)/2;
				if( gr.elements.ptr[i]==el ){
					return group*8+i;
				}
				candidates&=~(3u<<(i*2));// Drop both bits of the lane just checked
			}
			// If there is neverUsed element, we will never find our element (see tryRemove)
			if( matchSIMD(gr.controlVec, neverUsedMask)!=0 ){
				return getIndexEmptyValue;
			}
			group=(group+1) & mask;
		}
	}

	/// foreach support over occupied slots. Accepted delegate shapes: (T), (Control, T) and,
	/// when a value type is present, (Control, T, ADV[0]).
	int opApply(DG)(scope DG  dg) {
		foreach(ref Group gr; groups){
			foreach(i, ref Control c; gr.control){
				if(c.isEmpty){
					continue;
				}
				int result;
				static if(hasValue && isForeachDelegateWithTypes!(DG, Control, T, ADV[0]) ){
					result=dg(gr.control[i], gr.elements[i], gr.values[i]);
				}else static if( isForeachDelegateWithTypes!(DG, Control, T) ){
					result=dg(gr.control[i], gr.elements[i]);
				}else static if( isForeachDelegateWithTypes!(DG, T) ){
					result=dg(gr.elements[i]);
				}else{
					static assert(0);
				}
				// Non-zero: the foreach body used break/return; stop immediately and propagate it
				if(result){
					return result;
				}
			}
		}
		return 0;
	}

	/// Plots how many elements start in each of the first 8192 groups.
	void saveGroupDistributionPlot(string path){
		BenchmarkData!(1, 8192) distr;// For now use benchamrk as a plotter

		foreach(ref T el; this){
			int group=hashMod(hashFunc(el));
			if(group>=8192){
				continue;
			}
			distr.times[0][group]++;

		}
		distr.plotUsingGnuplot(path, ["group distribution"]);
	}

}
340 
341 
342 
// matchSIMD reports every matching 16-bit lane as two adjacent set bits
@nogc nothrow pure unittest{
	ushort8 lanes=15;
	lanes.array[7]=10;
	lanes.array[0]=10;
	enum ushort expected=0b0011_1111_1111_1100;// lanes 1..6 match -> bits 2..13
	immutable ushort got=HashSet!(int).matchSIMD(lanes, cast(ushort)15);
	assert(got==expected);
}
351 
// Basic HashSet behavior: custom hash/equality, add, duplicates, remove, growth and iteration
@nogc unittest{
	static struct KeyValue{
		int key;
		int value;
		bool opEquals()(auto ref const KeyValue r) @nogc{
			return key==r.key;
		}
	}
	static size_t hashFunc(KeyValue kv){
		return defaultHashFunc(kv.key);
	}

	HashSet!(KeyValue, hashFunc) kvSet;
	kvSet.add(KeyValue(1,2));

	HashSet!(int) ints;

	assert(!ints.isIn(123));
	// A duplicate add must not change the element count
	ints.add(123);
	ints.add(123);
	assert(ints.isIn(123));
	assert(!ints.isIn(122));
	assert(ints.addedElements==1);
	ints.remove(123);
	assert(!ints.isIn(123));
	assert(ints.addedElements==0);
	assert(!ints.tryRemove(500));
	ints.add(123);
	assert(ints.tryRemove(123));

	// Enough elements to force several rehashes
	for(int v=1; v<130; v++){
		ints.add(v);
	}
	for(int v=1; v<130; v++){
		assert(ints.isIn(v));
	}
	for(int v=130; v<500; v++){
		assert(!ints.isIn(v));
	}

	foreach(int stored; ints){
		assert(ints.isIn(stored));
	}
}
399 
400 
// Compares lookup speed of HashSet!int with the builtin associative array (byte[int]).
// Both containers get the keys 0..199. Their membership answers are checked against each other.
// Then each is timed over 100 iterations of 1_000_000 lookups: slot 1 is the builtin map, slot 0 is HashSet.
// Timings are plotted to "test.png" and the group distribution of the set to "distr.png".
void benchmarkHashSetInt(){
	HashSet!(int) set;
	byte[int] mapStandard;
	uint elementsNumToAdd=200;//cast(uint)(64536*0.9);
	// Add elements
	foreach(int i;0..elementsNumToAdd){
		set.add(i);
		mapStandard[i]=true;
	}
	// Check if isIn is working
	foreach(int i;0..elementsNumToAdd){
		assert(set.isIn(i));
		assert((i in mapStandard) !is null);
	}
	// Check if isIn is returning false properly
	foreach(int i;elementsNumToAdd..elementsNumToAdd+10_000){
		assert(!set.isIn(i));
		assert((i in mapStandard) is null);
	}
	//set.numA=set.numB=set.numC=0;
	enum itNum=100;
	BenchmarkData!(2, itNum) bench;
	doNotOptimize(set);// Make some confusion for compiler
	doNotOptimize(mapStandard);
	// The next key depends on the previous result (+1 on miss, +2 on hit), so lookups cannot be hoisted.
	// myResults is a ushort, so the probed keys wrap around at 65536.
	ushort myResults;
	myResults=0;
	//benchmark standard library implementation
	foreach(b;0..itNum){
		bench.start!(1)(b);
		foreach(i;0..1000_000){
			auto ret=myResults in mapStandard;
			myResults+=1+cast(bool)ret;//cast(typeof(myResults))(cast(bool)ret);
			doNotOptimize(ret);
		}
		bench.end!(1)(b);
	}
	
	auto stResult=myResults;
	//benchmark this implementation
	myResults=0;
	foreach(b;0..itNum){
		bench.start!(0)(b);
		foreach(i;0..1000_000){
			auto ret=set.isIn(myResults);
			myResults+=1+ret;//cast(typeof(myResults))(ret);
			doNotOptimize(ret);
		}
		bench.end!(0)(b);
	}
	assert(myResults==stResult);// Same behavior as standard map
	 //writeln(set.getLoadFactor(set.addedElements));
	 //writeln(set.numA);
	 //writeln(set.numB);
	 //writeln(set.numC);
	
	doNotOptimize(myResults);
	bench.plotUsingGnuplot("test.png",["my", "standard"]);
	set.saveGroupDistributionPlot("distr.png");	
}
460 
461 
// Measures how lookup time changes while the containers grow.
// Each of the 1000 iterations first inserts the next 128 (16*8) consecutive keys and then times 100_000 lookups.
// The builtin associative array (slot 1) runs all its iterations first, then HashSet!int (slot 0).
// Timings are plotted to "test.png".
void benchmarkHashSetPerformancePerElement(){
	// Lookup key; a ushort, so the probed keys cycle through 0..65535
	ushort trueResults;
	doNotOptimize(trueResults);
	enum itNum=1000;
	BenchmarkData!(2, itNum) bench;
	HashSet!(int) set;
	byte[int] mapStandard;
	//writeln(set.getLoadFactor(set.addedElements));
	//set.numA=set.numB=set.numC=0;
	size_t lastAdded;
	size_t numToAdd=16*8;

	foreach(b;0..itNum){
		// Grow the builtin map outside of the timed region
		foreach(i;lastAdded..lastAdded+numToAdd){
			mapStandard[cast(uint)i]=true;
		}
		lastAdded+=numToAdd;
		bench.start!(1)(b);
		foreach(i;0..1000_00){
			auto ret=trueResults in mapStandard;
			trueResults+=1;//cast(typeof(trueResults))(cast(bool)ret);
			doNotOptimize(ret);
		}
		bench.end!(1)(b);
	}
	// Restart key generation so HashSet sees the same insert/lookup sequence
	lastAdded=0;
	trueResults=0;
	foreach(b;0..itNum){
		foreach(i;lastAdded..lastAdded+numToAdd){
			set.add(cast(uint)i);
		}
		lastAdded+=numToAdd;
		bench.start!(0)(b);
		foreach(i;0..1000_00){
			auto ret=set.isIn(trueResults);
			trueResults+=1;//cast(typeof(trueResults))(ret);
			doNotOptimize(ret);
		}
		bench.end!(0)(b);
	}
	//writeln(set.numA);
	//writeln(set.numB);
	// writeln(set.numC);
	doNotOptimize(trueResults);
	bench.plotUsingGnuplot("test.png",["my", "standard"]);
	//set.saveGroupDistributionPlot("distr.png");

}