module mutils.container.hash_set_simd;

version (X86_64):

import core.bitop;
import core.simd : ushort8;
import std.meta;
import std.traits;

import mutils.benchmark;
import mutils.container.vector;
import mutils.traits;
import mutils.utils;

/// Returns the 1-based index of the LOWEST set bit of `i`
/// (e.g. `firstSetBit(0b0100) == 3`). `i` must be non-zero.
///
/// It must be the lowest bit: `HashSet.getIndex` clears every bit up to and
/// including the returned one after checking a candidate, so returning the
/// highest bit would skip the remaining candidates of a group.
/// `core.bitop.bsf` is an intrinsic on DMD and LDC, so one definition serves
/// every compiler.
int firstSetBit(int i) pure nothrow @nogc @safe {
    assert(i != 0);
    return bsf(cast(uint) i) + 1;
}

// Layout of a control word (`Control.b`):
//   bit 0      - set when the slot holds an element (despite the name `emptyMask`)
//   bits 1..15 - H2: the low 16 bits of the element hash with bit 0 masked off
// Values of a free slot (bit 0 clear):
//   neverUsedMask (2) - never occupied since the last rehash; a lookup may stop here
//   deletedMask   (0) - tombstone of a removed element; a lookup must keep probing
enum ushort emptyMask = 1;
enum ushort neverUsedMask = 2;
enum ushort deletedMask = 0;
enum ushort hashMask = ushort.max - 1;

/// Per-slot control word; see the layout description above.
struct Control {
nothrow @nogc @safe:

    ushort b = neverUsedMask;

    /// True when the slot holds no element (never used or deleted).
    bool isEmpty() {
        return (b & emptyMask) == 0;
    }

    /// Turns the slot into a tombstone: free for insertion, but it does not
    /// terminate lookups the way a never-used slot does.
    void setDeleted() {
        b = deletedMask;
    }

    /// Marks the slot as occupied and stores H2 of `hash` in it.
    void set(size_t hash) {
        union Tmp {
            size_t h;
            ushort[size_t.sizeof / 2] d;
        }

        Tmp t = Tmp(hash);
        b = (t.d[0] & hashMask) | emptyMask;
    }
}

// Hash helper struct
// hash is made out of two parts [ H1 48 bits ][ H2 16 bits ]
// the whole hash is used to find the starting group
// H2 is used to quickly (SIMD) find an element inside a group
struct Hash {
nothrow @nogc @safe:
    union {
        size_t h = void;
        ushort[size_t.sizeof / 2] d = void;
    }

    this(size_t hash) {
        h = hash;
    }

    size_t getH1() {
        Hash tmp = h;
        tmp.d[0] = d[0] & emptyMask; //clear H2 hash
        return tmp.h;
    }

    ushort getH2() {
        return d[0] & hashMask;
    }

    /// H2 with bit 0 set, i.e. the exact control value of an occupied slot
    /// holding an element with this hash.
    ushort getH2WithLastSet() {
        return d[0] | emptyMask;
    }

}

/// Default hash: integers are mixed with `hashInt`, other types use `hashOf`
/// followed by `hashInt`.
size_t defaultHashFunc(T)(auto ref T t) {
    static if (isIntegral!(T)) {
        return hashInt(t);
    } else {
        return hashInt(t.hashOf); // hashOf is not giving proper distribution between H1 and H2 hash parts
    }
}

// Can turn a bad hash function into a good one (splitmix64-style finalizer)
ulong hashInt(ulong x) nothrow @nogc @safe {
    x = (x ^ (x >> 30)) * 0xbf58476d1ce4e5b9;
    x = (x ^ (x >> 27)) * 0x94d049bb133111eb;
    x = x ^ (x >> 31);
    return x;
}

/// Open-addressing hash set storing elements in groups of 8 slots.
/// A lookup compares all 8 control words of a group at once (`matchSIMD`) and
/// probes groups linearly until it finds the element or a never-used slot.
/// Removed elements leave tombstones, which are cleared by `rehash`.
// ADV additional value - used to implement HashMap without unnecessary copies
struct HashSet(T, alias hashFunc = defaultHashFunc, ADV...) {
    static assert(ADV.length <= 1); // ADV is treated as an optional additional value type
    static assert(size_t.sizeof == 8); // Only 64 bit
    enum hasValue = ADV.length == 1;
    enum rehashFactor = 0.85;
    enum size_t getIndexEmptyValue = size_t.max;

    static struct Group {
        union {
            Control[8] control;
            ushort8 controlVec;
        }

        T[8] elements;
        static if (hasValue)
            ADV[0][8] values;

        // Prevent error in Vector!Group
        bool opEquals()(auto ref const Group r) const {
            assert(0);
        }
    }

    void clear() {
        groups.clear();
        addedElements = 0;
        deletedElements = 0;
    }

    void reset() {
        groups.reset();
        addedElements = 0;
        deletedElements = 0;
    }

    Vector!Group groups; // Length should be always power of 2
    size_t addedElements; // Number of stored elements, used to compute loadFactor
    size_t deletedElements; // Number of tombstones; they also occupy probe slots until the next rehash

    /// Fraction of slots that `forElementsNum` elements would occupy; 1 when
    /// no groups are allocated yet.
    float getLoadFactor(size_t forElementsNum) {
        if (groups.length == 0) {
            return 1;
        }
        return cast(float) forElementsNum / (groups.length * 8);
    }

    /// Rebuilds the table. Collects all stored elements (and values), resets
    /// every control word (dropping tombstones), and doubles the number of
    /// groups when one more element would exceed `rehashFactor`. Then it
    /// reinserts everything.
    void rehash() {
        mixin(doNotInline);
        // Get all elements
        Vector!T allElements;
        allElements.reserve(addedElements);
        static if (hasValue) {
            Vector!(ADV[0]) allValues;
            allValues.reserve(addedElements);
            foreach (ref Control c, ref T el, ref ADV[0] val; this) {
                allElements ~= el;
                allValues ~= val;
            }
        } else {
            foreach (ref Control c, ref T el; this) {
                allElements ~= el;
            }
        }

        // Reset every slot, tombstones included, so probe sequences start clean.
        foreach (ref Group gr; groups) {
            gr.control[] = Control.init;
        }
        deletedElements = 0;

        if (getLoadFactor(addedElements + 1) > rehashFactor) { // Reallocate
            groups.length = (groups.length ? groups.length : 1) << 1; // Power of two
        }

        // Reinsert. Reset the counter first: add() increments it and checks the
        // load factor, so a stale count could trigger a nested rehash.
        addedElements = 0;
        foreach (i, ref el; allElements) {
            static if (hasValue) {
                add(el, allValues[i]);
            } else {
                add(el);
            }
        }
        assert(addedElements == allElements.length);
    }

    size_t length() {
        return addedElements;
    }

    /// Removes `el` if present; returns whether it was present.
    bool tryRemove(T el) {
        size_t index = getIndex(el);
        if (index == getIndexEmptyValue) {
            return false;
        }
        addedElements--;
        deletedElements++;
        size_t group = index / 8;
        size_t elIndex = index % 8;
        // Leave a tombstone, not Control.init: a never-used marker here would
        // end the probe sequence of elements stored in later groups.
        groups[group].control[elIndex].setDeleted();
        //TODO value destructor
        return true;
    }

    /// Removes `el`; asserts that it was present.
    void remove(T el) {
        bool ok = tryRemove(el);
        assert(ok);
    }

    void add(T el, ADV value) {
        add(el, value);
    }

    /// Inserts `el` (and its value) unless it is already present.
    void add(ref T el, ADV value) {
        if (isIn(el)) {
            return;
        }

        // Tombstones count towards the load because they occupy probe slots.
        if (getLoadFactor(addedElements + deletedElements + 1) > rehashFactor) {
            assumeNoGC(&rehash)(); // rehash is @nogc but compiler cannot deduce that because rehash calls add internally
        }
        addedElements++;
        Hash hash = Hash(hashFunc(el));
        immutable size_t mask = groups.length - 1;
        size_t group = hash.h & mask; // Starting point
        while (true) {
            Group* gr = &groups[group];
            foreach (i, ref Control c; gr.control) {
                if (c.isEmpty) {
                    if (c.b == deletedMask) {
                        deletedElements--; // Reusing a tombstone
                    }
                    c.set(hash.h);
                    gr.elements[i] = el;
                    static if (hasValue)
                        gr.values[i] = value[0];
                    return;
                }
            }
            group = (group + 1) & mask; // groups.length is a power of two
        }
    }

    // Sets bits in ushort where value in control matches check value.
    // Every matching element i sets the two bits 2*i and 2*i+1.
    // Ex. control=[0,1,2,3,4,5,6,7], check=2, return=0b0000_0000_0011_0000
    static auto matchSIMD(ushort8 control, ushort check) @nogc {
        // DMD has bug in release mode
        // https://issues.dlang.org/show_bug.cgi?id=18034
        /*version(DigitalMars){
            import core.simd: __simd, ubyte16, XMM;
            ushort8 v=ushort8(check);
            ubyte16 ok=__simd(XMM.PCMPEQW, control, v);
            ubyte16 bitsMask=[1,2,4,8,16,32,64,128,1,2,4,8,16,32,64,128];
            ubyte16 bits=bitsMask&ok;
            ubyte16 zeros=0;
            ushort8 vv=__simd(XMM.PSADBW, bits, zeros);
            ushort num=cast(ushort)(vv[0]+vv[4]*256);
        }else*/
        version (LDC) {
            import ldc.simd;

            ushort8 v = ushort8(check);
            ushort8 ok = equalMask!ushort8(control, v);
            version (X86) {
                import ldc.gccbuiltins_x86;

                ushort num = cast(ushort) __builtin_ia32_pmovmskb128(ok);
            } else version (X86_64) {
                import ldc.gccbuiltins_x86;

                ushort num = cast(ushort) __builtin_ia32_pmovmskb128(ok);
            } else {
                ushort num = 0;
                foreach (i; 0 .. 8) {
                    if (control.ptr[i] == check) {
                        num |= (0b11 << i * 2);
                    }
                }
            }
        } else {
            ushort num = 0;
            foreach (i; 0 .. 8) {
                if (control.ptr[i] == check) {
                    num |= (0b11 << i * 2);
                }
            }
        }
        return num;
    }

    // Division is expensive but groups.length is power of two so use trick
    int hashMod(size_t hash) nothrow @nogc @system {
        return cast(int)(hash & (groups.length - 1));
    }

    bool isIn(ref T el) {
        return getIndex(el) != getIndexEmptyValue;
    }

    bool isIn(T el) {
        return getIndex(el) != getIndexEmptyValue;
    }

    // For debug
    /*int numA;
    int numB;
    int numC;*/

    size_t getIndex(T el) {
        return getIndex(el);
    }

    /// Returns the slot index (group * 8 + position) of `el`, or
    /// `getIndexEmptyValue` when it is not stored.
    size_t getIndex(ref T el) {
        //mixin(doNotInline);
        immutable size_t groupsLength = groups.length;
        if (groupsLength == 0) {
            return getIndexEmptyValue;
        }

        Hash hash = Hash(hashFunc(el));
        immutable size_t mask = groupsLength - 1;
        immutable ushort h2 = hash.getH2WithLastSet;
        size_t group = hash.h & mask; // Starting point
        //numA++;
        // Visit each group at most once: tombstones can leave a table without any
        // never-used slot, and an unbounded probe would then never terminate.
        foreach (_; 0 .. groupsLength) {
            //numB++;
            Group* gr = &groups[group];
            int cntrlV = matchSIMD(gr.controlVec, h2); // Compare 8 controls at once to h2
            while (cntrlV != 0) {
                //numC++;
                int ffInd = firstSetBit(cntrlV); // 1-based index of the lowest bit, i.e. 2*i+1
                int i = ffInd / 2; // Element index within the group
                if (gr.elements.ptr[i] == el) {
                    return group * 8 + i;
                }
                cntrlV &= 0xFFFF_FFFF << (ffInd + 1); // Drop both bits of this candidate
            }
            // If there is a neverUsed slot, the element was never placed further along
            if (matchSIMD(gr.controlVec, neverUsedMask) != 0) {
                return getIndexEmptyValue;
            }
            group = (group + 1) & mask;
        }
        return getIndexEmptyValue;
    }

    // foreach support; supported delegate shapes: (Control, T, ADV[0]) when a
    // value type is present, (Control, T), and (T).
    int opApply(DG)(scope DG dg) {
        foreach (ref Group gr; groups) {
            foreach (i, ref Control c; gr.control) {
                if (c.isEmpty) {
                    continue;
                }
                int result;
                static if (hasValue && isForeachDelegateWithTypes!(DG, Control, T, ADV[0])) {
                    result = dg(gr.control[i], gr.elements[i], gr.values[i]);
                } else static if (isForeachDelegateWithTypes!(DG, Control, T)) {
                    result = dg(gr.control[i], gr.elements[i]);
                } else static if (isForeachDelegateWithTypes!(DG, T)) {
                    result = dg(gr.elements[i]);
                } else {
                    static assert(0);
                }
                // A non-zero result means the foreach body did break/return:
                // stop the whole iteration, not only the current group.
                if (result) {
                    return result;
                }
            }
        }
        return 0;
    }

    void saveGroupDistributionPlot(string path) {
        BenchmarkData!(1, 8192) distr; // For now use benchmark as a plotter

        foreach (ref T el; this) {
            int group = hashMod(hashFunc(el));
            if (group >= 8192) {
                continue;
            }
            distr.times[0][group]++;

        }
        distr.plotUsingGnuplot(path, ["group distribution"]);
    }

}

version (unittest) {
    // Degenerate hash for tests: every element collides (same starting group,
    // same H2), forcing several SIMD matches per group and multi-group probing.
    private size_t constantHash(T)(auto ref T t) {
        return 0;
    }
}

@nogc nothrow pure unittest {
    ushort8 control = 15;
    control.array[0] = 10;
    control.array[7] = 10;
    ushort check = 15;
    ushort ret = HashSet!(int).matchSIMD(control, check);
    assert(ret == 0b0011_1111_1111_1100);
}

unittest {
    HashSet!(int) set;

    assert(set.isIn(123) == false);
    set.add(123);
    set.add(123);
    assert(set.isIn(123) == true);
    assert(set.isIn(122) == false);
    assert(set.addedElements == 1);
    set.remove(123);
    assert(set.isIn(123) == false);
    assert(set.addedElements == 0);
    assert(set.tryRemove(500) == false);
    set.add(123);
    assert(set.tryRemove(123) == true);

    foreach (i; 1 .. 130) {
        set.add(i);
    }

    foreach (i; 1 .. 130) {
        assert(set.isIn(i));
    }

    foreach (i; 130 .. 500) {
        assert(!set.isIn(i));
    }

    foreach (int el; set) {
        assert(set.isIn(el));
    }
}

// Colliding elements: every candidate in a group must be checked, and removals
// must not cut off elements stored in later groups.
unittest {
    HashSet!(int, constantHash) set;
    foreach (i; 0 .. 40) {
        set.add(i);
    }
    foreach (i; 0 .. 40) {
        assert(set.isIn(i));
    }
    assert(!set.isIn(40));

    set.remove(0);
    set.remove(5);
    assert(!set.isIn(0));
    foreach (i; 1 .. 40) {
        assert(set.isIn(i) == (i != 5));
    }

    set.add(0); // Reuses a tombstone
    assert(set.isIn(0));
    assert(set.length == 39);
}

// Breaking out of foreach must stop the iteration.
unittest {
    HashSet!(int) set;
    foreach (i; 0 .. 100) {
        set.add(i);
    }
    int visited;
    foreach (int el; set) {
        visited++;
        if (visited == 3) {
            break;
        }
    }
    assert(visited == 3);
}

// Add/remove churn: tombstones must neither hang lookups nor grow the table.
unittest {
    HashSet!(int) set;
    foreach (round; 0 .. 50) {
        foreach (i; 0 .. 100) {
            set.add(round * 100 + i);
        }
        foreach (i; 0 .. 100) {
            set.remove(round * 100 + i);
        }
    }
    assert(set.length == 0);
    assert(set.groups.length <= 16);
    foreach (i; 0 .. 5000) {
        assert(!set.isIn(i));
    }
}

void benchmarkHashSetInt() {
    HashSet!(int) set;
    byte[int] mapStandard;
    uint elementsNumToAdd = 200; //cast(uint)(64536*0.9);
    // Add elements
    foreach (int i; 0 .. elementsNumToAdd) {
        set.add(i);
        mapStandard[i] = true;
    }
    // Check if isIn is working
    foreach (int i; 0 .. elementsNumToAdd) {
        assert(set.isIn(i));
        assert((i in mapStandard) !is null);
    }
    // Check if isIn is returning false properly
    foreach (int i; elementsNumToAdd .. elementsNumToAdd + 10_000) {
        assert(!set.isIn(i));
        assert((i in mapStandard) is null);
    }
    //set.numA=set.numB=set.numC=0;
    enum itNum = 100;
    BenchmarkData!(2, itNum) bench;
    doNotOptimize(set); // Make some confusion for compiler
    doNotOptimize(mapStandard);
    ushort myResults;
    myResults = 0;
    //benchmark standard library implementation
    foreach (b; 0 .. itNum) {
        bench.start!(1)(b);
        foreach (i; 0 .. 1000_000) {
            auto ret = myResults in mapStandard;
            myResults += 1 + cast(bool) ret; //cast(typeof(myResults))(cast(bool)ret);
            doNotOptimize(ret);
        }
        bench.end!(1)(b);
    }

    auto stResult = myResults;
    //benchmark this implementation
    myResults = 0;
    foreach (b; 0 .. itNum) {
        bench.start!(0)(b);
        foreach (i; 0 .. 1000_000) {
            auto ret = set.isIn(myResults);
            myResults += 1 + ret; //cast(typeof(myResults))(ret);
            doNotOptimize(ret);
        }
        bench.end!(0)(b);
    }
    assert(myResults == stResult); // Same behavior as standard map
    //writeln(set.getLoadFactor(set.addedElements));
    //writeln(set.numA);
    //writeln(set.numB);
    //writeln(set.numC);

    doNotOptimize(myResults);
    bench.plotUsingGnuplot("test.png", ["my", "standard"]);
    set.saveGroupDistributionPlot("distr.png");
}

void benchmarkHashSetPerformancePerElement() {
    ushort trueResults;
    doNotOptimize(trueResults);
    enum itNum = 1000;
    BenchmarkData!(2, itNum) bench;
    HashSet!(int) set;
    byte[int] mapStandard;
    //writeln(set.getLoadFactor(set.addedElements));
    //set.numA=set.numB=set.numC=0;
    size_t lastAdded;
    size_t numToAdd = 16 * 8;

    foreach (b; 0 .. itNum) {
        foreach (i; lastAdded .. lastAdded + numToAdd) {
            mapStandard[cast(uint) i] = true;
        }
        lastAdded += numToAdd;
        bench.start!(1)(b);
        foreach (i; 0 .. 1000_00) {
            auto ret = trueResults in mapStandard;
            trueResults += 1; //cast(typeof(trueResults))(cast(bool)ret);
            doNotOptimize(ret);
        }
        bench.end!(1)(b);
    }
    lastAdded = 0;
    trueResults = 0;
    foreach (b; 0 .. itNum) {
        foreach (i; lastAdded .. lastAdded + numToAdd) {
            set.add(cast(uint) i);
        }
        lastAdded += numToAdd;
        bench.start!(0)(b);
        foreach (i; 0 .. 1000_00) {
            auto ret = set.isIn(trueResults);
            trueResults += 1; //cast(typeof(trueResults))(ret);
            doNotOptimize(ret);
        }
        bench.end!(0)(b);
    }
    //writeln(set.numA);
    //writeln(set.numB);
    // writeln(set.numC);
    doNotOptimize(trueResults);
    bench.plotUsingGnuplot("test.png", ["my", "standard"]);
    //set.saveGroupDistributionPlot("distr.png");

}