1 module mutils.container.hash_set_simd;
2 
3 version (X86_64)  : import core.bitop;
4 import core.simd : ushort8;
5 import std.meta;
6 import std.traits;
7 
8 import mutils.benchmark;
9 import mutils.container.vector;
10 import mutils.traits;
11 import mutils.utils;
12 
13 //There is some bug, maybe with not tracking deleting entries (ending is never true)
14 
// firstSetBit(i): position of the LOWEST set bit of `i`, counted from 1
// (that is: index of the lowest set bit + 1). `i` must be non-zero.
// Both compiler variants below follow this same convention.
version (DigitalMars) {
	import core.bitop;

	int firstSetBit(int i) {
		// bsf scans forward from the least significant bit. (bsr would return
		// the HIGHEST set bit, which is not what callers of firstSetBit expect.)
		return bsf(cast(uint) i) + 1;
	}
} else version (LDC) {
	import ldc.intrinsics;

	int firstSetBit(int i) {
		// cttz = count trailing zeros; second argument: result is undefined for i == 0.
		return llvm_cttz(i, true) + 1;
	}
} else {
	// A bare string literal is "true" in a static assert, so the condition
	// must be an explicit 0 for this to actually stop compilation.
	static assert(0, "Compiler not supported.");
}
28 
// Bit 0 of a control word. Despite the name, a SET bit means the slot is
// occupied; a cleared bit means the slot is empty (see Control.isEmpty).
enum ushort emptyMask = 1;
// Default value of a control word (Control.init); its occupied bit is clear.
enum ushort neverUsedMask = 2;
// Upper 15 bits of a control word: they hold a fragment of the element hash.
enum ushort hashMask = ushort.max - 1;

// One 16-bit control word per slot:
// bit 0 - occupied flag, bits 1..15 - fragment of the element hash.
struct Control {
nothrow @nogc @safe:

	ushort b = neverUsedMask;

	// Returns true when the occupied flag (bit 0) is not set.
	bool isEmpty() {
		immutable bool occupied = (b & emptyMask) != 0;
		return !occupied;
	}

	// Marks the slot as occupied and stores the hash fragment taken from
	// the first 16-bit word of `hash`.
	void set(size_t hash) {
		union Words {
			size_t whole;
			ushort[size_t.sizeof / 2] parts;
		}

		Words w = Words(hash);
		immutable ushort firstWord = w.parts[0];
		b = cast(ushort)((firstWord & hashMask) | emptyMask);
	}
}
65 
// Hash helper struct
// The hash is viewed as two parts: [     H1 48 bits      ][ H2 16 bits]
// The whole hash is used to find the group,
// H2 is used to quickly (SIMD) find the element inside the group.
struct Hash {
nothrow @nogc @safe:
	// The same 64 bits, viewed either as one integer or as 16-bit words.
	union {
		size_t h = void;
		ushort[size_t.sizeof / 2] d = void;
	}

	this(size_t hash) {
		h = hash;
	}

	// Returns the hash with the H2 fragment bits (hashMask) of the first word cleared;
	// bit 0 of that word is kept.
	size_t getH1() {
		auto copy = Hash(h);
		copy.d[0] &= emptyMask; // drop the H2 fragment
		return copy.h;
	}

	// Returns the first 16-bit word with bit 0 cleared.
	ushort getH2() {
		immutable ushort firstWord = d[0];
		return cast(ushort)(firstWord & hashMask);
	}

	// Returns the first 16-bit word with bit 0 forced to 1
	// (the same value Control.set stores for this hash).
	ushort getH2WithLastSet() {
		immutable ushort firstWord = d[0];
		return cast(ushort)(firstWord | emptyMask);
	}

}
96 
// Default hash: integral values are mixed directly with hashInt,
// everything else goes through hashOf first and is then mixed.
size_t defaultHashFunc(T)(auto ref T t) {
	static if (!isIntegral!(T)) {
		// hashOf alone is not giving proper distribution between H1 and H2 hash parts
		return hashInt(t.hashOf);
	} else {
		return hashInt(t);
	}
}
104 
// Bit mixer: three xor-shift steps interleaved with two multiplications.
// Can turn a bad hash function into a good one.
ulong hashInt(ulong x) nothrow @nogc @safe {
	x ^= x >> 30;
	x *= 0xbf58476d1ce4e5b9;
	x ^= x >> 27;
	x *= 0x94d049bb133111eb;
	return x ^ (x >> 31);
}
112 
// Open-addressing hash set. Slots are organised in groups of 8; every slot has
// a 16-bit control word, and the 8 control words of a group are compared at
// once (SIMD where available) against the H2 fragment of the searched hash.
//
// ADV additional value - used to implement HashMap without unnecessary copies
struct HashSet(T, alias hashFunc = defaultHashFunc, ADV...) {
	static assert(ADV.length <= 1); // ADV is treated as a optional additional value type
	static assert(size_t.sizeof == 8); // Only 64 bit
	enum hasValue = ADV.length == 1;
	enum rehashFactor = 0.85;
	enum size_t getIndexEmptyValue = size_t.max;

	// Control word written into a slot when its element is removed (tombstone).
	// It must read as "empty" (occupied bit clear) so the slot can be reused,
	// but must differ from neverUsedMask so lookups keep probing past it.
	private enum ushort deletedMarker = 0;
	static assert((deletedMarker & emptyMask) == 0 && deletedMarker != neverUsedMask);

	static struct Group {
		// The same 8 control words, addressable one by one or as one vector.
		union {
			Control[8] control;
			ushort8 controlVec;
		}

		T[8] elements;
		static if (hasValue)
			ADV[0][8] values;

		// Prevent error in Vector!Group
		bool opEquals()(auto ref const Group r) const {
			assert(0);
		}
	}

	void clear() {
		groups.clear();
		addedElements = 0;
	}

	void reset() {
		groups.reset();
		addedElements = 0;
	}

	Vector!Group groups; // Length should be always power of 2
	size_t addedElements; // Used to compute loadFactor

	// Load factor the table would have with `forElementsNum` elements.
	// A table without groups reports 1, which forces a rehash on first add.
	float getLoadFactor(size_t forElementsNum) {
		if (groups.length == 0) {
			return 1;
		}
		return cast(float) forElementsNum / (groups.length * 8);
	}

	// Rebuilds the table: collects all stored elements, resets every control
	// word, doubles the number of groups when the load factor requires it and
	// inserts the elements again.
	void rehash() {
		mixin(doNotInline);
		// Get all elements
		Vector!T allElements;
		allElements.reserve(groups.length);
		static if (hasValue) {
			Vector!(ADV[0]) allValues;
			allValues.reserve(groups.length);
			foreach (ref Control c, ref T el, ref ADV[0] val; this) {
				allElements ~= el;
				allValues ~= val;
			}
		} else {
			foreach (ref Control c, ref T el; this) {
				allElements ~= el;
			}
		}

		// Reset every slot, including the ones marked as deleted which the
		// iteration above skips because they read as empty.
		foreach (ref Group gr; groups) {
			gr.control[] = Control.init;
		}

		if (getLoadFactor(addedElements + 1) > rehashFactor) { // Reallocate
			groups.length = (groups.length ? groups.length : 1) << 1; // Power of two
		}

		// Insert elements. The counter restarts from zero so that add() sees the
		// real load factor while the elements are being inserted again.
		addedElements = 0;
		foreach (i, ref el; allElements) {
			static if (hasValue) {
				add(el, allValues[i]);
			} else {
				add(el);
			}
		}
		addedElements = allElements.length;
	}

	size_t length() {
		return addedElements;
	}

	// Removes `el` if present. Returns false when the element was not stored.
	bool tryRemove(T el) {
		size_t index = getIndex(el);
		if (index == getIndexEmptyValue) {
			return false;
		}
		addedElements--;
		size_t group = index / 8;
		size_t elIndex = index % 8;
		// Leave a tombstone instead of a never-used marker: elements that were
		// pushed to later groups while this group was full must stay reachable.
		groups[group].control[elIndex].b = deletedMarker;
		//TODO value destructor
		return true;
	}

	// Removes `el`; the element is expected to be present (checked by assert).
	void remove(T el) {
		bool ok = tryRemove(el);
		assert(ok);
	}

	// By-value overload: `el` is an lvalue here, so this forwards to the ref overload.
	void add(T el, ADV value) {
		add(el, value);
	}

	// Inserts `el` (and its optional value) unless it is already stored.
	void add(ref T el, ADV value) {
		if (isIn(el)) {
			return;
		}

		if (getLoadFactor(addedElements + 1) > rehashFactor) {
			assumeNoGC(&rehash)(); // rehash is @nogc but compiler cannot deduce that because rehash calls add internally
		}
		addedElements++;
		Hash hash = Hash(hashFunc(el));
		immutable size_t mask = groups.length - 1; // groups.length is a power of two
		size_t group = hash.h & mask; // Starting point
		while (true) {
			Group* gr = &groups[group];
			// Take the first free slot (never used or deleted) of the group
			foreach (i, ref Control c; gr.control) {
				if (c.isEmpty) {
					c.set(hash.h);
					gr.elements[i] = el;
					static if (hasValue)
						gr.values[i] = value[0];
					return;
				}
			}
			group = (group + 1) & mask; // Group is full, try the next one (wraps around)
		}
	}

	// Sets bits in ushort where value in control matches check value
	// Every matching element sets two adjacent bits (one per byte of the element).
	// Ex. control=[0,1,2,3,4,5,6,7], check=2, return=0b0000_0000_0011_0000
	static auto matchSIMD(ushort8 control, ushort check) @nogc {
		// No hand written SIMD path for DMD: DMD has bug in release mode
		// https://issues.dlang.org/show_bug.cgi?id=18034
		version (LDC) {
			import ldc.simd;

			ushort8 v = ushort8(check);
			ushort8 ok = equalMask!ushort8(control, v);
			version (X86) {
				import ldc.gccbuiltins_x86;

				ushort num = cast(ushort) __builtin_ia32_pmovmskb128(ok);
			} else version (X86_64) {
				import ldc.gccbuiltins_x86;

				ushort num = cast(ushort) __builtin_ia32_pmovmskb128(ok);
			} else {
				// Scalar fallback producing the same two-bits-per-element layout
				ushort num = 0;
				foreach (i; 0 .. 8) {
					if (control.ptr[i] == check) {
						num |= (0b11 << i * 2);
					}
				}
			}
		} else {
			// Scalar fallback producing the same two-bits-per-element layout
			ushort num = 0;
			foreach (i; 0 .. 8) {
				if (control.ptr[i] == check) {
					num |= (0b11 << i * 2);
				}
			}
		}
		return num;
	}
	// Division is expensive but groups.length is power of two so use trick
	int hashMod(size_t hash) nothrow @nogc @system {
		return cast(int)(hash & (groups.length - 1));
	}

	bool isIn(ref T el) {
		return getIndex(el) != getIndexEmptyValue;
	}

	bool isIn(T el) {
		return getIndex(el) != getIndexEmptyValue;
	}

	// By-value overload: `el` is an lvalue here, so this forwards to the ref overload.
	size_t getIndex(T el) {
		return getIndex(el);
	}

	// Returns the slot index (group * 8 + position in group) of `el`,
	// or getIndexEmptyValue when the element is not stored.
	size_t getIndex(ref T el) {
		immutable size_t groupsLength = groups.length;
		if (groupsLength == 0) {
			return getIndexEmptyValue;
		}

		Hash hash = Hash(hashFunc(el));
		immutable size_t mask = groupsLength - 1;
		immutable ushort h2 = hash.getH2WithLastSet;
		size_t group = hash.h & mask; // Starting point
		// Visit every group at most once: when deleted slots have replaced all
		// never-used ones there is no other stop condition for a missing element.
		foreach (probe; 0 .. groupsLength) {
			Group* gr = &groups[group];
			int cntrlV = matchSIMD(gr.controlVec, h2); // Compare 8 controls at once to h2
			while (cntrlV != 0) {
				// Each element owns two adjacent mask bits, so bit position / 2 is its index
				int i = firstSetBit(cntrlV) / 2;
				if (gr.elements.ptr[i] == el) {
					return group * 8 + i;
				}
				cntrlV &= ~(0b11 << (i * 2)); // Drop both bits of the element just checked
			}
			// If there is neverUsed element, we will never find our element
			if (matchSIMD(gr.controlVec, neverUsedMask) != 0) {
				return getIndexEmptyValue;
			}
			group = (group + 1) & mask;
		}
		return getIndexEmptyValue;
	}

	// foreach support. Accepted loop variables: (Control, T, value) when a value
	// type is present, (Control, T), or (T).
	// A non-zero delegate result (break/return in the loop body) stops the
	// iteration immediately and is passed back to the caller.
	int opApply(DG)(scope DG dg) {
		foreach (ref Group gr; groups) {
			foreach (i, ref Control c; gr.control) {
				if (c.isEmpty) {
					continue;
				}
				int result;
				static if (hasValue && isForeachDelegateWithTypes!(DG, Control, T, ADV[0])) {
					result = dg(gr.control[i], gr.elements[i], gr.values[i]);
				} else static if (isForeachDelegateWithTypes!(DG, Control, T)) {
					result = dg(gr.control[i], gr.elements[i]);
				} else static if (isForeachDelegateWithTypes!(DG, T)) {
					result = dg(gr.elements[i]);
				} else {
					static assert(0);
				}
				if (result) {
					return result;
				}
			}
		}

		return 0;
	}

	// Counts how many stored elements start their probe in each group
	// (only groups below 8192 are counted) and plots the result to `path`.
	void saveGroupDistributionPlot(string path) {
		BenchmarkData!(1, 8192) distr; // For now use benchmark as a plotter

		foreach (ref T el; this) {
			int group = hashMod(hashFunc(el));
			if (group >= 8192) {
				continue;
			}
			distr.times[0][group]++;

		}
		distr.plotUsingGnuplot(path, ["group distribution"]);
	}

}
392 
// matchSIMD: every matching element contributes two adjacent bits to the mask
@nogc nothrow pure unittest {
	// All eight elements are 15, except the first and the last one
	ushort8 lanes = 15;
	lanes.array[0] = 10;
	lanes.array[7] = 10;
	ushort wanted = 15;
	ushort mask = HashSet!(int).matchSIMD(lanes, wanted);
	// Elements 1..6 match -> bits 2..13 are set
	assert(mask == 0b0011_1111_1111_1100);
}
401 
// Basic add / lookup / remove / iteration checks for HashSet!int
unittest {
	HashSet!(int) set;

	// Adding the same element twice stores it once
	assert(!set.isIn(123));
	set.add(123);
	set.add(123);
	assert(set.isIn(123));
	assert(!set.isIn(122));
	assert(set.addedElements == 1);

	// Removing
	set.remove(123);
	assert(!set.isIn(123));
	assert(set.addedElements == 0);
	assert(!set.tryRemove(500));
	set.add(123);
	assert(set.tryRemove(123));

	// Enough elements to force several rehashes
	foreach (i; 1 .. 130) {
		set.add(i);
	}

	foreach (i; 1 .. 130) {
		assert(set.isIn(i));
	}

	foreach (i; 130 .. 500) {
		assert(!set.isIn(i));
	}

	// Everything reported by iteration must be found by lookup
	foreach (int el; set) {
		assert(set.isIn(el));
	}
}
434 
// Compares lookup time of HashSet!int with the built-in associative array
// for the same keys and plots the measurements to "test.png"; the group
// distribution of the set is plotted to "distr.png".
void benchmarkHashSetInt() {
	enum itNum = 100;
	enum lookupsPerRun = 1000_000;
	uint elementsNumToAdd = 200; //cast(uint)(64536*0.9);

	HashSet!(int) set;
	byte[int] mapStandard;

	// Add elements
	foreach (int key; 0 .. elementsNumToAdd) {
		set.add(key);
		mapStandard[key] = true;
	}
	// Check if isIn is working
	foreach (int key; 0 .. elementsNumToAdd) {
		assert(set.isIn(key));
		assert((key in mapStandard) !is null);
	}
	// Check if isIn is returning false properly
	foreach (int key; elementsNumToAdd .. elementsNumToAdd + 10_000) {
		assert(!set.isIn(key));
		assert((key in mapStandard) is null);
	}

	BenchmarkData!(2, itNum) bench;
	doNotOptimize(set); // Make some confusion for compiler
	doNotOptimize(mapStandard);

	// Benchmark standard library implementation (series 1).
	// The looked-up key advances by 2 after a hit and by 1 after a miss.
	ushort myResults;
	myResults = 0;
	foreach (b; 0 .. itNum) {
		bench.start!(1)(b);
		foreach (i; 0 .. lookupsPerRun) {
			auto ret = myResults in mapStandard;
			myResults += 1 + (ret !is null);
			doNotOptimize(ret);
		}
		bench.end!(1)(b);
	}
	auto stResult = myResults;

	// Benchmark this implementation (series 0) with the same key sequence
	myResults = 0;
	foreach (b; 0 .. itNum) {
		bench.start!(0)(b);
		foreach (i; 0 .. lookupsPerRun) {
			auto ret = set.isIn(myResults);
			myResults += 1 + ret;
			doNotOptimize(ret);
		}
		bench.end!(0)(b);
	}
	assert(myResults == stResult); // Same behavior as standard map

	doNotOptimize(myResults);
	bench.plotUsingGnuplot("test.png", ["my", "standard"]);
	set.saveGroupDistributionPlot("distr.png");
}
494 
// Measures lookup time while the containers grow: before every measured run
// another 16 * 8 keys are added. The built-in associative array is measured
// first (series 1), then HashSet!int (series 0); results go to "test.png".
void benchmarkHashSetPerformancePerElement() {
	enum itNum = 1000;
	enum lookupsPerRun = 1000_00;

	ushort trueResults;
	doNotOptimize(trueResults);
	BenchmarkData!(2, itNum) bench;
	HashSet!(int) set;
	byte[int] mapStandard;
	size_t lastAdded;
	size_t numToAdd = 16 * 8;

	// Standard associative array
	foreach (b; 0 .. itNum) {
		foreach (key; lastAdded .. lastAdded + numToAdd) {
			mapStandard[cast(uint) key] = true;
		}
		lastAdded += numToAdd;
		bench.start!(1)(b);
		foreach (i; 0 .. lookupsPerRun) {
			auto ret = trueResults in mapStandard;
			trueResults += 1;
			doNotOptimize(ret);
		}
		bench.end!(1)(b);
	}

	// This implementation, starting again from the first key
	lastAdded = 0;
	trueResults = 0;
	foreach (b; 0 .. itNum) {
		foreach (key; lastAdded .. lastAdded + numToAdd) {
			set.add(cast(uint) key);
		}
		lastAdded += numToAdd;
		bench.start!(0)(b);
		foreach (i; 0 .. lookupsPerRun) {
			auto ret = set.isIn(trueResults);
			trueResults += 1;
			doNotOptimize(ret);
		}
		bench.end!(0)(b);
	}

	doNotOptimize(trueResults);
	bench.plotUsingGnuplot("test.png", ["my", "standard"]);
	//set.saveGroupDistributionPlot("distr.png");

}