module mutils.container.hash_set;

import core.bitop;
import core.simd: ushort8;
import std.meta;
import std.stdio;
import std.traits;

import mutils.benchmark;
import mutils.container.vector;
import mutils.traits;

// firstSetBit(i) returns the 1-based position of the LOWEST set bit of i.
// i must be non-zero. Both compiler branches must agree on this contract:
// HashSet.getIndex divides the result by 2 to get a slot index and then
// masks away all bits up to and including the returned position.
version(DigitalMars){
    import core.bitop;
    // bsf (scan from the least significant bit) is treated as an intrinsic by DMD.
    // The previous alias to bsr returned the HIGHEST set bit, which made lookups
    // skip every matching slot except the last one in a group.
    int firstSetBit(int i) nothrow @nogc @safe pure {
        return bsf(cast(uint)i)+1;
    }
}else version(LDC){
    import ldc.intrinsics;
    int firstSetBit(int i){
        return llvm_cttz(i, true)+1;
    }
}else{
    // A bare string literal is always "true", so the condition has to be an explicit 0.
    static assert(0, "Compiler not supported.");
}

enum ushort emptyMask=1;
enum ushort neverUsedMask=2;
// Explicit cast keeps the value 0xFFFE regardless of the integer promotion rules applied to '~'.
enum ushort hashMask=cast(ushort)~emptyMask;

// Per-slot control word.
// Bit 0 is set while the slot is occupied (isEmpty tests that it is clear),
// bits 1..15 hold a part of the element hash.
// A freshly initialised slot holds neverUsedMask (bit 0 clear).
struct Control{
nothrow @nogc @safe:

    ushort b=neverUsedMask;

    // True when the slot holds no element (bit 0 is clear).
    bool isEmpty(){
        return (b & emptyMask)==0;
    }

    /*void setEmpty(){
        b=emptyMask;
    }*/

    /*bool cmpHash(size_t hash){
        union Tmp{
            size_t h;
            ushort[size_t.sizeof/2] d;
        }
        Tmp t=Tmp(hash);
        return (t.d[0] & hashMask)==(b & hashMask);
    }*/

    // Marks the slot as occupied and stores the hash bits selected by hashMask
    // taken from the first ushort of the hash.
    void set(size_t hash){
        union Tmp{
            size_t h;
            ushort[size_t.sizeof/2] d;
        }
        Tmp t=Tmp(hash);
        b=(t.d[0] & hashMask) | emptyMask;
    }
}

// Hash helper struct
// hash is made out of two parts[ H1 48 bits ][ H2 16 bits]
// whole hash is used to find group
// H2 is used to quickly(SIMD) find element in group
struct Hash{
nothrow @nogc @safe:
    union{
        size_t h=void;
        ushort[size_t.sizeof/2] d=void;
    }
    this(size_t hash){
        h=hash;
    }

    // Hash with the H2 bits (those selected by hashMask) of the first ushort cleared.
    size_t getH1(){
        Hash tmp=h;
        tmp.d[0]=d[0] & emptyMask;//clear H2 hash
        return tmp.h;
    }

    // First ushort of the hash with bit 0 cleared.
    ushort getH2(){
        return d[0] & hashMask;
    }

    // First ushort of the hash with bit 0 set - the same value Control.set stores
    // for an occupied slot, so it can be compared against control words directly.
    ushort getH2WithLastSet(){
        return d[0] | emptyMask;
    }

}

// Default hash: integral values are mixed with hashInt directly,
// everything else goes through hashOf first and is then mixed.
size_t defaultHashFunc(T)(auto ref T t){
    static if (isIntegral!(T)){
        return hashInt(t);
    }else{
        return hashInt(t.hashOf);// hashOf is not giving proper distribution between H1 and H2 hash parts
    }
}

// Can turn bad hash function
// to good one
// 64-bit integer mixer: three xor-shift steps combined with two multiplications.
// Deterministic; hashInt(0)==0.
ulong hashInt(ulong x) nothrow @nogc @safe {
    x = (x ^ (x >> 30)) * 0xbf58476d1ce4e5b9;
    x = (x ^ (x >> 27)) * 0x94d049bb133111eb;
    x = x ^ (x >> 31);
    return x;
}

// Open addressing hash set storing elements in groups of 8 slots.
// Each slot has a 16-bit Control word, 8 of them are compared at once with SIMD.
// ADV additional value - used to implement HashMap without unnecessary copies
struct HashSet(T, alias hashFunc=defaultHashFunc, ADV...){
    static assert(ADV.length<=1);// ADV is treated as a optional additional value type
    static assert(size_t.sizeof==8);// Only 64 bit
    enum hasValue=ADV.length==1;
    enum rehashFactor=0.85;
    enum size_t getIndexEmptyValue=size_t.max;
    // Control value of a slot whose element was removed from a group without never-used slots.
    // Bit 0 is clear (slot is reusable), and the value differs from neverUsedMask,
    // so lookups keep probing past the group instead of stopping there.
    enum ushort tombstoneControl=0;

    static struct Group{
        union{
            Control[8] control;
            ushort8 controlVec;
        }
        T[8] elements;
        static if(hasValue)ADV[0][8] values;

        // Prevent error in Vector!Group
        bool opEquals()(auto ref const Group r) const {
            assert(0);
        }
    }


    Vector!Group groups;// Length should be always power of 2
    size_t addedElements;// Used to compute loadFactor

    // Load factor the set would have with forElementsNum elements.
    // Returns 1 when nothing is allocated, which forces the first add to rehash.
    float getLoadFactor(size_t forElementsNum) {
        if(groups.length==0){
            return 1;
        }
        return cast(float)forElementsNum/(groups.length*8);
    }

    // Collects all elements, resets every control word (tombstones included),
    // doubles the number of groups when the load factor requires it and inserts
    // the elements again.
    void rehash() {
        mixin(doNotInline);
        // Get all elements
        Vector!T allElements;
        allElements.reserve(addedElements);
        static if(hasValue)Vector!(ADV[0]) allValues;
        static if(hasValue)allValues.reserve(addedElements);
        foreach(ref Group gr; groups){
            foreach(i, ref Control c; gr.control){
                if(!c.isEmpty){
                    allElements~=gr.elements[i];
                    static if(hasValue)allValues~=gr.values[i];
                }
                c=Control.init;// Also wipes tombstones left by tryRemove
            }
        }

        if(getLoadFactor(addedElements+1)>rehashFactor){// Reallocate
            groups.length=(groups.length?groups.length:1)<<1;// Power of two
        }

        // add() counts every inserted element itself. Start from zero so that the load
        // factor checks inside add() see the real number of stored elements.
        addedElements=0;
        // Insert elements
        foreach(i, el;allElements){
            static if(hasValue){
                add(el, allValues[i]);
            }else{
                add(el);
            }
        }
        allElements.clear();
    }

    // Number of stored elements.
    size_t length(){
        return addedElements;
    }

    // Removes el when present. Returns false when el is not in the set.
    bool tryRemove(T el) {
        size_t index=getIndex(el);
        if(index==getIndexEmptyValue){
            return false;
        }
        addedElements--;
        Group* gr=&groups[index/8];
        size_t elIndex=index%8;
        // Lookups stop at the first group which contains a never-used slot.
        // If this group has one, it was never full, nothing was pushed past it and the
        // slot can safely go back to never-used. Otherwise elements may live in the
        // following groups and the slot has to become a tombstone to keep them reachable.
        bool groupHasNeverUsed=matchSIMD(gr.controlVec, neverUsedMask)!=0;
        gr.control[elIndex].b=groupHasNeverUsed ? neverUsedMask : tombstoneControl;
        //TODO value destructor
        return true;
    }

    // Removes el, el has to be present in the set.
    void remove(T el) {
        // Do the work outside of assert - assert expressions are not evaluated in release builds
        bool removed=tryRemove(el);
        assert(removed);
    }

    // Adds el (with optional value). Does nothing when el is already in the set.
    void add(T el, ADV value){
        if(isIn(el)){
            return;
        }

        if(getLoadFactor(addedElements+1)>rehashFactor){
            assumeNoGC(&rehash)();// rehash is @nogc but compiler cannot deduce that because rehash calls add internally
        }
        addedElements++;
        Hash hash=Hash(hashFunc(el));
        int group=hashMod(hash.h);// Starting point
        while(true){
            Group* gr=&groups[group];
            foreach(i, ref Control c; gr.control){
                if(c.isEmpty){// Never used slot or tombstone
                    c.set(hash.h);
                    gr.elements[i]=el;
                    static if(hasValue)gr.values[i]=value[0];
                    return;
                }
            }
            group++;
            if(group>=groups.length){
                group=0;
            }
        }
    }

    // Sets bits in ushort where value in control matches check value
    // Every matching 16-bit lane sets two neighbouring bits
    // Ex. control=[0,1,2,3,4,5,6,7], check=2, return=0b0000_0000_0011_0000
    static auto matchSIMD(ushort8 control, ushort check) @nogc {
        ushort8 v=ushort8(check);
        version(DigitalMars){
            import core.simd: __simd, ubyte16, XMM;
            ubyte16 ok=__simd(XMM.PCMPEQW, control, v);
            ubyte16 bitsMask=[1,2,4,8,16,32,64,128,1,2,4,8,16,32,64,128];
            ubyte16 bits=bitsMask&ok;
            ubyte16 zeros=0;
            ushort8 vv=__simd(XMM.PSADBW, bits, zeros);
            ushort num=cast(ushort)(vv[0]+vv[4]*256);
        }else version(LDC){
            import ldc.simd;
            import ldc.gccbuiltins_x86;
            ushort8 ok = equalMask!ushort8(control, v);
            ushort num=cast(ushort)__builtin_ia32_pmovmskb128(ok);
        }else{
            static assert(0);
        }
        return num;
    }

    // Division is expensive - number of groups is a power of two so a bit mask is used instead
    int hashMod(size_t hash) nothrow @nogc @system{
        return cast(int)(hash & (groups.length-1));
    }

    // True when el is stored in the set.
    bool isIn(T el){
        return getIndex(el)!=getIndexEmptyValue;
    }

    // For debug
    /*int numA;
    int numB;
    int numC;*/


    // By-value convenience overload, el is an lvalue here so the ref overload is called.
    size_t getIndex(T el) {
        return getIndex(el);
    }

    // Returns group*8+slot of el or getIndexEmptyValue when el is not stored.
    size_t getIndex(ref T el) {
        //mixin(doNotInline);
        size_t groupsLength=groups.length;
        if(groupsLength==0){
            return getIndexEmptyValue;
        }

        Hash hash=Hash(hashFunc(el));
        size_t mask=groupsLength-1;
        size_t group=hash.h & mask;// Starting point
        //numA++;
        // Visit every group at most once. Without never-used slots (only tombstones left)
        // an unbounded loop would never terminate for a missing element.
        foreach(probe; 0..groupsLength){
            //numB++;
            Group* gr=&groups[group];
            int cntrlV=matchSIMD(gr.controlVec, hash.getH2WithLastSet);// Compare 8 controls at once to h2
            while( cntrlV!=0 ){
                //numC++;
                int ffInd=firstSetBit(cntrlV);
                int i=ffInd/2;// Find first set bit and divide by 2 to get element index
                if( gr.elements.ptr[i]==el ){
                    return group*8+i;
                }
                cntrlV&=0xFFFF_FFFF<<(ffInd+1);
            }
            cntrlV=matchSIMD(gr.controlVec, neverUsedMask);// If there is neverUsed element, we will never find our element
            if( cntrlV!=0 ){
                return getIndexEmptyValue;
            }
            group=(group+1) & mask;
        }
        return getIndexEmptyValue;
    }

    // foreach support
    // Accepted delegate shapes: (Control, T, value), (Control, T) and (T).
    int opApply(DG)(scope DG dg) {
        foreach(ref Group gr; groups){
            foreach(i, ref Control c; gr.control){
                if(c.isEmpty){
                    continue;
                }
                int result;
                static if(hasValue && isForeachDelegateWithTypes!(DG, Control, T, ADV[0]) ){
                    result=dg(gr.control[i], gr.elements[i], gr.values[i]);
                }else static if( isForeachDelegateWithTypes!(DG, Control, T) ){
                    result=dg(gr.control[i], gr.elements[i]);
                }else static if( isForeachDelegateWithTypes!(DG, T) ){
                    result=dg(gr.elements[i]);
                }else{
                    static assert(0);
                }
                // Non-zero result has to end the whole iteration, not only the current group
                if (result){
                    return result;
                }
            }
        }

        return 0;
    }

    // Plots how many elements start their probing in each of the first 8192 groups.
    void saveGroupDistributionPlot(string path){
        BenchmarkData!(1, 8192) distr;// For now use benchmark as a plotter

        foreach(ref T el; this){
            int group=hashMod(hashFunc(el));
            if(group>=8192){
                continue;
            }
            distr.times[0][group]++;

        }
        distr.plotUsingGnuplot(path, ["group distribution"]);
    }

}



@nogc nothrow pure unittest{
    ushort8 control=15;
    control.array[0]=10;
    control.array[7]=10;
    ushort check=15;
    ushort ret=HashSet!(int).matchSIMD(control, check);
    assert(ret==0b0011_1111_1111_1100);
}

@nogc unittest{
    static struct KeyValue{
        int key;
        int value;
        bool opEquals()(auto ref const KeyValue r) @nogc{
            return key==r.key;
        }
    }
    static size_t hashFunc(KeyValue kv){
        return defaultHashFunc(kv.key);
    }

    HashSet!(KeyValue, hashFunc) set222;
    set222.add(KeyValue(1,2));

    HashSet!(int) set;

    assert(set.isIn(123)==false);
    set.add(123);
    set.add(123);
    assert(set.isIn(123)==true);
    assert(set.isIn(122)==false);
    assert(set.addedElements==1);
    set.remove(123);
    assert(set.isIn(123)==false);
    assert(set.addedElements==0);
    assert(set.tryRemove(500)==false);
    set.add(123);
    assert(set.tryRemove(123)==true);


    foreach(i;1..130){
        set.add(i);
    }

    foreach(i;1..130){
        assert(set.isIn(i));
    }

    foreach(i;130..500){
        assert(!set.isIn(i));
    }

    foreach(int el; set){
        assert(set.isIn(el));
    }
}

// Regression test: removal must not hide elements stored in following groups
// and break inside foreach has to stop the whole iteration
@nogc unittest{
    static size_t sameGroupHash(int v) nothrow @nogc{
        return 0;// Every element starts probing in group 0
    }

    HashSet!(int, sameGroupHash) set;
    foreach(i;0..10){
        set.add(i);// 0..7 fill group 0, 8 and 9 spill to group 1
    }
    assert(set.length==10);

    set.remove(3);
    assert(!set.isIn(3));
    assert(set.isIn(8));
    assert(set.isIn(9));
    assert(set.length==9);

    set.add(3);
    assert(set.isIn(3));
    assert(set.length==10);

    int visited;
    foreach(int el; set){
        visited++;
        break;
    }
    assert(visited==1);
}


// Compares lookup speed of HashSet!int with the built-in associative array
// and plots the results to test.png and distr.png
void benchmarkHashSetInt(){
    HashSet!(int) set;
    byte[int] mapStandard;
    uint elementsNumToAdd=200;//cast(uint)(64536*0.9);
    // Add elements
    foreach(int i;0..elementsNumToAdd){
        set.add(i);
        mapStandard[i]=true;
    }
    // Check if isIn is working
    foreach(int i;0..elementsNumToAdd){
        assert(set.isIn(i));
        assert((i in mapStandard) !is null);
    }
    // Check if isIn is returning false properly
    foreach(int i;elementsNumToAdd..elementsNumToAdd+10_000){
        assert(!set.isIn(i));
        assert((i in mapStandard) is null);
    }
    //set.numA=set.numB=set.numC=0;
    enum itNum=100;
    BenchmarkData!(2, itNum) bench;
    doNotOptimize(set);// Make some confusion for compiler
    doNotOptimize(mapStandard);
    ushort myResults;
    myResults=0;
    //benchmark standard library implementation
    foreach(b;0..itNum){
        bench.start!(1)(b);
        foreach(i;0..1000_000){
            auto ret=myResults in mapStandard;
            myResults+=1+cast(bool)ret;//cast(typeof(myResults))(cast(bool)ret);
            doNotOptimize(ret);
        }
        bench.end!(1)(b);
    }

    auto stResult=myResults;
    //benchmark this implementation
    myResults=0;
    foreach(b;0..itNum){
        bench.start!(0)(b);
        foreach(i;0..1000_000){
            auto ret=set.isIn(myResults);
            myResults+=1+ret;//cast(typeof(myResults))(ret);
            doNotOptimize(ret);
        }
        bench.end!(0)(b);
    }
    assert(myResults==stResult);// Same behavior as standard map
    //writeln(set.getLoadFactor(set.addedElements));
    //writeln(set.numA);
    //writeln(set.numB);
    //writeln(set.numC);

    doNotOptimize(myResults);
    bench.plotUsingGnuplot("test.png",["my", "standard"]);
    set.saveGroupDistributionPlot("distr.png");
}


// Measures how lookup time changes while the containers grow.
// Before every measured iteration another batch of keys is inserted, then a fixed
// number of lookups is timed. Series 1 is the built-in associative array,
// series 0 is HashSet!int. Results are plotted to test.png.
void benchmarkHashSetPerformancePerElement(){
    enum itNum=1000;
    enum lookupsPerIteration=1000_00;

    ushort probeKey;
    doNotOptimize(probeKey);
    BenchmarkData!(2, itNum) bench;
    HashSet!(int) set;
    byte[int] mapStandard;
    //writeln(set.getLoadFactor(set.addedElements));
    //set.numA=set.numB=set.numC=0;
    size_t nextKey;
    size_t batchSize=16*8;

    // Built-in associative array
    foreach(iter;0..itNum){
        foreach(key;nextKey..nextKey+batchSize){
            mapStandard[cast(uint)key]=true;
        }
        nextKey+=batchSize;
        bench.start!(1)(iter);
        foreach(_;0..lookupsPerIteration){
            auto found=probeKey in mapStandard;
            ++probeKey;//cast(typeof(trueResults))(cast(bool)ret);
            doNotOptimize(found);
        }
        bench.end!(1)(iter);
    }

    // Same workload for HashSet, starting again from key 0
    nextKey=0;
    probeKey=0;
    foreach(iter;0..itNum){
        foreach(key;nextKey..nextKey+batchSize){
            set.add(cast(uint)key);
        }
        nextKey+=batchSize;
        bench.start!(0)(iter);
        foreach(_;0..lookupsPerIteration){
            auto found=set.isIn(probeKey);
            ++probeKey;//cast(typeof(trueResults))(ret);
            doNotOptimize(found);
        }
        bench.end!(0)(iter);
    }
    //writeln(set.numA);
    //writeln(set.numB);
    // writeln(set.numC);
    doNotOptimize(probeKey);
    bench.plotUsingGnuplot("test.png",["my", "standard"]);
    //set.saveGroupDistributionPlot("distr.png");

}