1 module mutils.container.spatialtree;
2 
3 import std.conv : emplace;
4 import std.experimental.allocator.building_blocks.allocator_list : AllocatorList;
5 import std.experimental.allocator.building_blocks.null_allocator : NullAllocator;
6 import std.experimental.allocator.building_blocks.region : Region;
7 import std.experimental.allocator.common : unbounded;
8 import std.experimental.allocator.mallocator : Mallocator;
9 import std.traits : ForeachType, hasMember;
10 
11 import mutils.allocator.free_list : ContiguousFreeList;
12 import mutils.container.vector : DataContainer = Vector;
13 
/// Two-dimensional SpatialTree (each node has 4 children).
alias QuadTree(T, bool loose = false, ubyte maxLevel = 8) = SpatialTree!(2, T, loose, maxLevel);
17 
/// Three-dimensional SpatialTree (each node has 8 children).
alias OcTree(T, bool loose = false, ubyte maxLevel = 8) = SpatialTree!(3, T, loose, maxLevel);
21 
22 /***
23  * Implementation of QuadTree and OcTree, with loose bounds and without
24  * Loose octree requires from type T to have member pos and radius, it can be function or variable.
25  * */
26 struct SpatialTree(ubyte dimension, T, bool loose = false, ubyte maxLevel = 8) {
27 	static assert(dimension == 2 || dimension == 3,
28 			"Only QuadTrees and OcTrees are supported (dimension value: 2 or 3).");
29 	static assert(!loose || hasMember!(T, "pos") || hasMember!(T, "radius"),
30 			"Loose SpatialTree has to have members: pos and radius.");
31 
32 	alias Point = float[dimension];
33 	alias isLoose = loose;
34 
35 	enum nodesNumInArray = 2 ^^ dimension;
36 	enum allocationSize = nodesNumInArray * Node.sizeof;
37 	alias QuadAllocator = AllocatorList!((n) => ContiguousFreeList!(Mallocator,
38 			0, unbounded)(4096, allocationSize), Mallocator);
39 
40 	QuadAllocator allocator;
41 	//Region!Mallocator
42 	float size = 100;
43 
44 	/* Example T
45 	 struct SpatialTreeData{
46 	 float[2] pos;
47 	 float radius;
48 	 MyData1 data1;
49 	 MyData* data2;
50 	 }
51 	 */
52 
53 	static struct Node {
54 		//static if(loose){
55 		DataContainer!T dataContainer;
56 		Node* child;
57 		/*}else{
58 		 union{
59 		 DataContainer!T dataContainer;// Used only at lowest level
60 		 Node* child;
61 		 }
62 		 }*/
63 
64 	}
65 
66 	Node root;
67 
68 	void initialize() {
69 	}
70 
71 	~this() {
72 		clear();
73 	}
74 
75 	void clear() {
76 		int deallocateContainers(Point pos, Node* quad, float halfSize, int level) {
77 			quad.dataContainer.clear();
78 			return 0;
79 		}
80 
81 		visitAll(&deallocateContainers);
82 		allocator.deallocateAll();
83 		root.child = null;
84 		root.dataContainer.clear();
85 	}
86 
87 	void remove(Point posRemove, T data) {
88 		int levelFrom = 0;
89 
90 		bool removeFormQuad(Point pos, Node* node, float halfSize, int level) {
91 
92 			if (hasElements(level) && level >= levelFrom) {
93 				bool ok = node.dataContainer.tryRemoveElement(data);
94 				if (ok)
95 					return true;
96 				//	else writeln("MISS");
97 			}
98 
99 			if (hasChildren(node, level)) {
100 				float quarterSize = halfSize / 2;
101 
102 				bool[Point.length] direction;
103 				foreach (i; 0 .. Point.length) {
104 					direction[i] = posRemove[i] > pos[i];
105 				}
106 				Point xy = pos[] + quarterSize * (direction[] * 2 - 1);
107 				uint index = directionToIndex(direction);
108 
109 				bool ok = removeFormQuad(xy, &node.child[index], quarterSize, level + 1);
110 				if (ok)
111 					return true;
112 
113 				foreach (ubyte i; 0 .. nodesNumInArray) {
114 					if (i == index)
115 						continue;
116 					direction = indexToDirection(i);
117 					xy = pos[] + quarterSize * (direction[] * 2 - 1);
118 					ok = removeFormQuad(xy, &node.child[i], quarterSize, level + 1);
119 					if (ok)
120 						return true;
121 				}
122 			}
123 			if (hasElements(level) && level < levelFrom) {
124 				bool ok = node.dataContainer.tryRemoveElement(data);
125 				if (ok)
126 					return true;
127 				//else writeln("MISS");
128 			}
129 			return false;
130 		}
131 
132 		static if (loose) {
133 			foreach (i, el; posRemove) {
134 				if (el > size / 2 || el < -size / 2) {
135 					bool ok = root.dataContainer.tryRemoveElement(data);
136 					if (ok) {
137 						return;
138 					}
139 				}
140 			}
141 		}
142 		Point pos = 0;
143 		bool ok = removeFormQuad(pos, &root, size / 2, 0);
144 		assert(ok, "Unable to find element in Tree");
145 	}
146 
147 	/////////////////////////
148 	///// Add functions /////
149 	/////////////////////////
150 
151 	void add(Point pos, T data) {
152 		static if (loose) {
153 			foreach (i, el; pos) {
154 				if (el > size / 2 || el < -size / 2) {
155 					root.dataContainer.add(data);
156 					return;
157 				}
158 			}
159 		}
160 		static if (loose) {
161 			float diam = data.radius * 2;
162 			byte level = -1;
163 			float sizeTmp = size / 2;
164 			while (sizeTmp > diam / 2 && level < maxLevel) {
165 				level++;
166 				sizeTmp /= 2;
167 			}
168 			addToQuad(pos, data, &root, size / 2, cast(ubyte)(maxLevel - level));
169 		} else {
170 			addToQuad(pos, data, &root, size / 2, 0);
171 		}
172 	}
173 
174 	void addToQuad(Point pos, T data, Node* quad, float halfSize, ubyte level) {
175 		while (level < maxLevel) {
176 			if (quad.child is null) {
177 				allocateQuads(quad, level);
178 			}
179 			float quarterSize = halfSize / 2;
180 			bool[Point.length] direction;
181 			foreach (i; 0 .. Point.length) {
182 				direction[i] = pos[i] > 0;
183 			}
184 			pos = pos[] - quarterSize * (direction[] * 2 - 1);
185 
186 			quad = &quad.child[directionToIndex(direction)];
187 			halfSize = quarterSize;
188 			level++;
189 		}
190 		quad.dataContainer.add(data);
191 	}
192 
193 	/////////////////////////
194 	//// Visit functions ////
195 	/////////////////////////
196 	// 0 - continue, !0 - break
197 	int visitAll(scope int delegate(Point pos, Node* quad, float halfSize, int level) visitor) {
198 		Point pos = 0;
199 		return visitAll(visitor, pos, &root, size / 2, 0);
200 	}
201 
202 	int visitAll(scope int delegate(Point pos, Node* quad, float halfSize,
203 			int level) visitor, Point pos, Node* quad, float halfSize, int level) {
204 		int visitAllImpl(Point pos, Node* quad, float halfSize, int level) {
205 			int res = visitor(pos, quad, halfSize, level);
206 			int mmm = maxLevel;
207 			if (quad.child !is null && level < maxLevel) {
208 				float quarterSize = halfSize / 2;
209 				foreach (i; 0 .. nodesNumInArray) {
210 					auto direction = indexToDirection(i);
211 					Point xy = pos[] + quarterSize * (direction[] * 2 - 1);
212 					res = visitAllImpl(xy, &quad.child[i], quarterSize, level + 1);
213 					if (res)
214 						return res;
215 				}
216 			}
217 			return res;
218 		}
219 
220 		return visitAllImpl(pos, quad, halfSize, level);
221 	}
222 
223 	void visitAllNodesIn(scope int delegate(Point pos, Node* quad, float halfSize,
224 			int level) visitor, Point downLeft, Point upRight) {
225 
226 		int check(Point pos, Node* quad, float halfSize, int level) {
227 			if (quad.dataContainer.length > 0) {
228 				foreach (obj; quad.dataContainer) {
229 					Point myDownLeft = pos[] - halfSize * (1 + loose); //loose size is twice of a normal tree
230 					Point myUpRight = pos[] + halfSize * (1 + loose);
231 					if (notInBox(downLeft, upRight, myDownLeft, myUpRight)) {
232 						return 0;
233 					}
234 				}
235 			}
236 			bool hasElements = quad.dataContainer.length > 0;
237 			static if (!loose) {
238 				hasElements = hasElements && level >= maxLevel; //only leafs have data				
239 			}
240 			if (hasElements) {
241 				int res = visitor(pos, quad, halfSize, level);
242 				if (res)
243 					return res;
244 			}
245 			return 0;
246 		}
247 
248 		Point pos = 0;
249 		visitAll(&check, pos, &root, size / 2, 0);
250 
251 	}
252 
253 	void visitAllDataIn(scope void delegate(ref T data) visitor, Point downLeft, Point upRight) {
254 
255 		void visitAllDataNoCheck(Node* node, int level) {
256 			if (hasElements(level)) {
257 				foreach (ref pData; node.dataContainer) {
258 					visitor(pData);
259 				}
260 			}
261 
262 			if (hasChildren(node, level)) {
263 				foreach (i; 0 .. nodesNumInArray) {
264 					visitAllDataNoCheck(&node.child[i], level + 1);
265 				}
266 			}
267 		}
268 
269 		void visitAllDataImpl(Point pos, Node* node, float halfSize, int level) {
270 			Point myDownLeft = pos[] - halfSize * (1 + loose); //loose size is twice of a normal tree
271 			Point myUpRight = pos[] + halfSize * (1 + loose);
272 			if (notInBox(downLeft, upRight, myDownLeft, myUpRight)) {
273 				return;
274 			}
275 			if (inBox(downLeft, upRight, myDownLeft, myUpRight)) {
276 				visitAllDataNoCheck(node, level);
277 				return;
278 			}
279 			if (hasElements(level)) {
280 				foreach (ref pData; node.dataContainer) {
281 					static if (loose) {
282 						if (circleNotInBox(downLeft, upRight, pData.pos, pData.radius)) {
283 							continue;
284 						}
285 					}
286 					visitor(pData);
287 				}
288 			}
289 
290 			if (hasChildren(node, level)) {
291 				float quarterSize = halfSize / 2;
292 				foreach (i; 0 .. nodesNumInArray) {
293 					auto direction = indexToDirection(i);
294 					Point xy = pos[] + quarterSize * (direction[] * 2 - 1);
295 					visitAllDataImpl(xy, &node.child[i], quarterSize, level + 1);
296 				}
297 			}
298 		}
299 
300 		Point pos = 0;
301 		visitAllDataImpl(pos, &root, size / 2, 0);
302 	}
303 
304 	static if (hasMember!(T, "pos"))
305 		void updatePositions() {
306 			void updatePositionsImpl(Point pos, Node* node, float halfSize, int level) {
307 				if (hasElements(level)) {
308 					Point myDownLeft = pos[] - halfSize * (1 + loose); //loose size is twice of a normal tree
309 					Point myUpRight = pos[] + halfSize * (1 + loose);
310 					foreach_reverse (i, pData; node.dataContainer) {
311 						static if (loose) {
312 							float radius = pData.radius;
313 						} else {
314 							float radius = 0;
315 						}
316 						if (!circleInBox(myDownLeft, myUpRight, pData.pos,
317 								radius) || (level != maxLevel && (radius * 2) < halfSize)) {
318 							node.dataContainer.remove(i);
319 							add(pData.pos, pData);
320 						}
321 					}
322 				}
323 
324 				if (hasChildren(node, level)) {
325 					float quarterSize = halfSize / 2;
326 					foreach (i; 0 .. nodesNumInArray) {
327 						auto direction = indexToDirection(i);
328 						Point xy = pos[] + quarterSize * (direction[] * 2 - 1);
329 						updatePositionsImpl(xy, &node.child[i], quarterSize, level + 1);
330 					}
331 				}
332 			}
333 
334 			Point pos = 0;
335 			updatePositionsImpl(pos, &root, size / 2, 0);
336 		}
337 
338 	/////////////////////////
339 	//// Helper functions ///
340 	/////////////////////////
341 
342 	void allocateQuads(Node* quad, int level) {
343 		void[] xx = allocator.allocate(allocationSize);
344 		Node* nodes = cast(Node*) xx.ptr;
345 		foreach (i; 0 .. nodesNumInArray) {
346 			emplace(&nodes[i]);
347 		}
348 		quad.child = cast(Node*) xx.ptr;
349 	}
350 
351 	static bool notInBox(Point left, Point right, Point myLeft, Point myRight) pure nothrow {
352 		bool b0 = (myRight[0] < left[0]) | (myLeft[0] > right[0]);
353 		bool b1 = (myRight[1] < left[1]) | (myLeft[1] > right[1]);
354 		static if (Point.length == 2) {
355 			return b0 | b1;
356 		} else static if (Point.length == 3) {
357 			bool b2 = (myRight[2] < left[2]) | (myLeft[2] > right[2]);
358 			return b0 | b1 | b2;
359 		}
360 	}
361 
362 	static bool inBox(Point left, Point right, Point myLeft, Point myRight) pure nothrow {
363 		bool b0 = (myLeft[0] > left[0]) & (myRight[0] < right[0]);
364 		bool b1 = (myLeft[1] > left[1]) & (myRight[1] < right[1]);
365 		static if (Point.length == 2) {
366 			return b0 & b1;
367 		} else static if (Point.length == 3) {
368 			bool b2 = (myLeft[2] > left[2]) & (myRight[2] < right[2]);
369 			return b0 & b1 & b2;
370 		}
371 	}
372 
373 	static bool hasChildren(Node* node, int level) {
374 		static if (loose) {
375 			return node.child !is null;
376 		} else {
377 			return node.child !is null && level < maxLevel;
378 		}
379 	}
380 
381 	static bool hasElements(int level) {
382 		static if (loose) {
383 			return true;
384 		} else {
385 			return level == maxLevel;
386 		}
387 	}
388 
389 	static bool circleInBox(Point left, Point right, Point pos, float radius) {
390 		bool ok = true;
391 		foreach (i; 0 .. Point.length) {
392 			ok &= (pos[i] + radius < right[i]) & (pos[i] - radius > left[i]);
393 		}
394 		return ok;
395 	}
396 
397 	static bool circleNotInBox(Point left, Point right, Point pos, float radius) {
398 		bool ok = false;
399 		foreach (i; 0 .. Point.length) {
400 			ok |= (pos[i] > radius + right[i]) | (left[i] > pos[i] + radius);
401 		}
402 		return ok;
403 	}
404 
405 	static uint directionToIndex(bool[Point.length] dir) pure nothrow {
406 		static if (Point.length == 2) {
407 			return dir[0] + 2 * dir[1];
408 		} else static if (Point.length == 3) {
409 			return dir[0] + 2 * dir[1] + 4 * dir[2];
410 		}
411 	}
412 
413 	static bool[Point.length] indexToDirection(uint index) pure nothrow {
414 		static if (Point.length == 2) {
415 			return [index == 1 || index == 3, index >= 2];
416 		} else static if (Point.length == 3) {
417 			bool i3 = index == 3;
418 			bool i7 = index == 7;
419 			return [index == 1 || i3 || index == 5 || i7, index == 2 || i3
420 				|| index == 6 || i7, index >= 4];
421 		}
422 	}
423 
424 }
425 
unittest {
	import mutils.container.vector;
	import mutils.linalg.vec;

	mixin(checkVectorAllocations);

	alias vec2 = Vec!(float, 2);
	alias vec3 = Vec!(float, 3);

	// Shared counters: elements visited by a query, and how many of them carried the expected tag.
	int numFound;
	int numOk;

	// visitAllDataIn callback: records the visit and whether `elem.data` equals `expected`.
	void expectTag(int expected, T)(ref T elem) {
		numFound++;
		numOk += elem.data == expected;
	}

	// Loose quad tree: one small element per quadrant, one query per quadrant.
	{
		numFound = 0;
		numOk = 0;
		struct QuadTreeData1 {
			float[2] pos;
			float radius;
			int data;
		}

		alias TestTree = SpatialTree!(2, QuadTreeData1, true);
		TestTree quadTree;
		quadTree.initialize();
		quadTree.add(vec2(-1, +1), QuadTreeData1(vec2(-1, +1), 0.1, 0));
		quadTree.add(vec2(+1, +1), QuadTreeData1(vec2(+1, +1), 0.1, 1));
		quadTree.add(vec2(+1, -1), QuadTreeData1(vec2(+1, -1), 0.1, 2));
		quadTree.add(vec2(-1, -1), QuadTreeData1(vec2(-1, -1), 0.1, 3));

		quadTree.visitAllDataIn(&expectTag!(0, QuadTreeData1), vec2(-10, 0.1), vec2(-0.1, 10));
		quadTree.visitAllDataIn(&expectTag!(1, QuadTreeData1), vec2(0.1, 0.1), vec2(10, 10));
		quadTree.visitAllDataIn(&expectTag!(2, QuadTreeData1), vec2(0.1, -10), vec2(10, -0.1));
		quadTree.visitAllDataIn(&expectTag!(3, QuadTreeData1), vec2(-10, -10), vec2(-0.1, -0.1));
		assert(numFound == 4);
		assert(numOk == 4);
	}

	// Regular quad tree: elements live in leaves only.
	{
		numFound = 0;
		numOk = 0;
		struct QuadTreeData2 {
			int data;
		}

		alias TestTree = SpatialTree!(2, QuadTreeData2, false);
		TestTree quadTree;
		quadTree.initialize();

		quadTree.add(vec2(-5, +5), QuadTreeData2(0));
		quadTree.add(vec2(+5, +5), QuadTreeData2(1));
		quadTree.add(vec2(+5, -5), QuadTreeData2(2));
		quadTree.add(vec2(-5, -5), QuadTreeData2(3));

		quadTree.visitAllDataIn(&expectTag!(0, QuadTreeData2), vec2(-10, 0.1), vec2(-0.1, 10));
		quadTree.visitAllDataIn(&expectTag!(1, QuadTreeData2), vec2(0.1, 0.1), vec2(10, 10));
		quadTree.visitAllDataIn(&expectTag!(2, QuadTreeData2), vec2(0.1, -10), vec2(10, -0.1));
		quadTree.visitAllDataIn(&expectTag!(3, QuadTreeData2), vec2(-10, -10), vec2(-0.1, -0.1));
		assert(numFound == 4);
		assert(numOk == 4);
	}

	// Loose octree: one small element per octant, one query per octant.
	{
		numFound = 0;
		numOk = 0;
		struct OctTreeData1 {
			float[3] pos;
			float radius;
			int data;
		}

		alias TestTree = SpatialTree!(3, OctTreeData1, true);
		TestTree octTree;
		octTree.initialize();

		octTree.add(vec3(-5, +5, +5), OctTreeData1(vec3(-5, +5, +5), 0.1, 0));
		octTree.add(vec3(+5, +5, +5), OctTreeData1(vec3(+5, +5, +5), 0.1, 1));
		octTree.add(vec3(+5, -5, +5), OctTreeData1(vec3(+5, -5, +5), 0.1, 2));
		octTree.add(vec3(-5, -5, +5), OctTreeData1(vec3(-5, -5, +5), 0.1, 3));
		octTree.add(vec3(-5, +5, -5), OctTreeData1(vec3(-5, +5, -5), 0.1, 4));
		octTree.add(vec3(+5, +5, -5), OctTreeData1(vec3(+5, +5, -5), 0.1, 5));
		octTree.add(vec3(+5, -5, -5), OctTreeData1(vec3(+5, -5, -5), 0.1, 6));
		octTree.add(vec3(-5, -5, -5), OctTreeData1(vec3(-5, -5, -5), 0.1, 7));

		octTree.visitAllDataIn(&expectTag!(0, OctTreeData1), vec3(-10, 0.1, 0.1), vec3(-0.1, 10, 10));
		octTree.visitAllDataIn(&expectTag!(1, OctTreeData1), vec3(0.1, 0.1, 0.1), vec3(10, 10, 10));
		octTree.visitAllDataIn(&expectTag!(2, OctTreeData1), vec3(0.1, -10, 0.1), vec3(10, -0.1, 10));
		octTree.visitAllDataIn(&expectTag!(3, OctTreeData1), vec3(-10, -10, 0.1), vec3(-0.1, -0.1, 10));
		octTree.visitAllDataIn(&expectTag!(4, OctTreeData1), vec3(-10, 0.1, -10), vec3(-0.1, 10, -0.1));
		octTree.visitAllDataIn(&expectTag!(5, OctTreeData1), vec3(0.1, 0.1, -10), vec3(10, 10, -0.1));
		octTree.visitAllDataIn(&expectTag!(6, OctTreeData1), vec3(0.1, -10, -10), vec3(10, -0.1, -0.1));
		octTree.visitAllDataIn(&expectTag!(7, OctTreeData1), vec3(-10, -10, -10),
				vec3(-0.1, -0.1, -0.1));
		assert(numFound == 8);
		assert(numOk == 8);
	}

	// Regular octree: elements live in leaves only.
	{
		numFound = 0;
		numOk = 0;
		struct OctTreeData2 {
			int data;
		}

		alias TestTree = SpatialTree!(3, OctTreeData2, false);
		TestTree octTree;
		octTree.initialize();

		octTree.add(vec3(-5, +5, +5), OctTreeData2(0));
		octTree.add(vec3(+5, +5, +5), OctTreeData2(1));
		octTree.add(vec3(+5, -5, +5), OctTreeData2(2));
		octTree.add(vec3(-5, -5, +5), OctTreeData2(3));
		octTree.add(vec3(-5, +5, -5), OctTreeData2(4));
		octTree.add(vec3(+5, +5, -5), OctTreeData2(5));
		octTree.add(vec3(+5, -5, -5), OctTreeData2(6));
		octTree.add(vec3(-5, -5, -5), OctTreeData2(7));

		octTree.visitAllDataIn(&expectTag!(0, OctTreeData2), vec3(-10, 0.1, 0.1), vec3(-0.1, 10, 10));
		octTree.visitAllDataIn(&expectTag!(1, OctTreeData2), vec3(0.1, 0.1, 0.1), vec3(10, 10, 10));
		octTree.visitAllDataIn(&expectTag!(2, OctTreeData2), vec3(0.1, -10, 0.1), vec3(10, -0.1, 10));
		octTree.visitAllDataIn(&expectTag!(3, OctTreeData2), vec3(-10, -10, 0.1), vec3(-0.1, -0.1, 10));
		octTree.visitAllDataIn(&expectTag!(4, OctTreeData2), vec3(-10, 0.1, -10), vec3(-0.1, 10, -0.1));
		octTree.visitAllDataIn(&expectTag!(5, OctTreeData2), vec3(0.1, 0.1, -10), vec3(10, 10, -0.1));
		octTree.visitAllDataIn(&expectTag!(6, OctTreeData2), vec3(0.1, -10, -10), vec3(10, -0.1, -0.1));
		octTree.visitAllDataIn(&expectTag!(7, OctTreeData2), vec3(-10, -10, -10),
				vec3(-0.1, -0.1, -0.1));
		assert(numFound == 8);
		assert(numOk == 8);
	}

}