1 module mutils.container.spatialtree;
2 
3 import std.stdio;
4 import std.traits : ForeachType,hasMember;
5 
6 import mutils.container.buckets_chain;
7 import mutils.container.vector : DataContainer = Vector;
8 
/// Two-dimensional SpatialTree (4 children per inner node).
alias QuadTree(T, bool loose=false, ubyte maxLevel=8) = SpatialTree!(2, T, loose, maxLevel);
12 
/// Three-dimensional SpatialTree (8 children per inner node).
alias OcTree(T, bool loose=false, ubyte maxLevel=8) = SpatialTree!(3, T, loose, maxLevel);
16 
/***
 * Implementation of QuadTree (dimension 2) and OcTree (dimension 3), in a loose and in a regular variant.
 *
 * The tree covers the square/cube [-size/2, size/2] on every axis, centered at the origin.
 *
 * Regular tree: elements are stored only in the leaves (nodes at depth maxLevel). Inside a node
 * dataContainer and child share storage (a union), so dataContainer is valid only at the leaves.
 * Elements added outside the tree bounds are ignored.
 *
 * Loose tree: every node has its own dataContainer. An element is stored at the deepest level
 * whose node half size is greater than the element radius (at most maxLevel), and a node's query
 * bounds are twice as large as its regular bounds. Elements added outside the tree bounds are
 * kept in the root. A loose tree requires type T to have members pos and radius
 * (variables or functions).
 *
 * Params:
 *     dimension = 2 for a QuadTree, 3 for an OcTree
 *     T = type of the stored elements
 *     loose = whether the tree uses loose bounds
 *     maxLevel = depth of the deepest nodes (the root has depth 0)
 * */
struct SpatialTree(ubyte dimension, T, bool loose=false, ubyte maxLevel=8){
	static assert(dimension==2 || dimension==3, "Only QuadTrees and OcTrees are supported (dimension value: 2 or 3).");
	// Both members are required: pos locates an element, radius selects its level.
	static assert(!loose || (hasMember!(T, "pos") && hasMember!(T, "radius")), "Loose SpatialTree has to have members: pos and radius.");

	alias Point=float[dimension];
	alias isLoose=loose;
	/// Number of children of every inner node (4 for a QuadTree, 8 for an OcTree).
	enum uint nodesNumber=2^^dimension;
	alias QuadContainer=BucketsListChain!(Node[nodesNumber], 128, false);

	/// Edge length of the square/cube covered by the tree.
	float size=100;
	/// Storage of child groups: an inner node's child points at nodesNumber consecutive Nodes from here.
	QuadContainer quadContainer;

	/* Example T
	 struct SpatialTreeData{
		 float[2] pos;
		 float radius;
		 MyData1 data1;
		 MyData* data2;
	 }
	 */

	static struct Node{
		static if(loose){
			DataContainer!T dataContainer;
			Node* child;
		}else{
			// Only one member is in use per node: dataContainer at depth maxLevel,
			// child at every shallower depth.
			union{
				DataContainer!T dataContainer;// Used only at lowest level
				Node* child;
			}
		}
	}

	/// Root node: center at the origin, half size size/2.
	Node root;

	/// Currently a no-op.
	void initialize(){}

	/***
	 * Detaches all child nodes from the root and clears the root container and quadContainer.
	 * NOTE(review): the leaf containers inside quadContainer are not cleared one by one here;
	 * whether their memory is released depends on BucketsListChain.clear - confirm.
	 */
	void clear(){
		// In a regular tree child and dataContainer share storage, so child is reset first.
		root.child=null;
		root.dataContainer.clear();
		quadContainer.clear();
	}

	/////////////////////////
	///// Add functions /////
	/////////////////////////

	/***
	 * Removes data from the tree.
	 *
	 * The child containing posRemove is searched first, then every other child, so the element
	 * is found wherever it is stored. If posRemove lies outside the tree bounds, a loose tree
	 * removes data from the root container and a regular tree does nothing.
	 * Asserts if an in-bounds element cannot be found.
	 */
	void remove(Point posRemove, T data){
		// Depth at which data is expected (loose tree only); nodes at or below it are checked
		// before their children, nodes above it after their children.
		int levelFrom=0;

		bool removeFromQuad(Point pos, Node* node, float halfSize, int level){
			if(hasElements(level) && level>=levelFrom){
				if(node.dataContainer.tryRemoveElement(data))return true;
			}

			if(hasChildren(node, level)){
				float quarterSize=halfSize/2;

				// Most likely location first: the child containing posRemove.
				bool[Point.length] direction;
				foreach(i;0..Point.length){
					direction[i]=posRemove[i]>pos[i];
				}
				Point xy=pos[]+quarterSize*(direction[]*2-1);
				uint index=directionToIndex(direction);
				if(removeFromQuad(xy,&node.child[index],quarterSize,level+1))return true;

				// Fallback: every other child.
				foreach(uint i;0..nodesNumber){
					if(i==index)continue;
					direction=indexToDirection(i);
					xy=pos[]+quarterSize*(direction[]*2-1);
					if(removeFromQuad(xy,&node.child[i],quarterSize,level+1))return true;
				}
			}
			if(hasElements(level) && level<levelFrom){
				if(node.dataContainer.tryRemoveElement(data))return true;
			}
			return false;
		}

		foreach(el;posRemove){
			if(el>size/2 || el<-size/2){
				static if(loose){
					root.dataContainer.removeElement(data);
				}
				return;
			}
		}
		static if(loose){
			levelFrom=looseLevel(data.radius);
		}
		Point pos=0;
		bool ok=removeFromQuad(pos, &root, size/2, 0);
		assert(ok,"Unable to find element in Tree");
	}

	/***
	 * Adds data at position pos.
	 *
	 * Regular tree: data goes to the leaf containing pos; if pos lies outside the tree bounds
	 * data is ignored. Loose tree: data goes to the node at depth looseLevel(data.radius) on
	 * the path to pos; if pos lies outside the tree bounds data goes to the root.
	 */
	void add(Point pos, T data){
		foreach(el;pos){
			if(el>size/2 || el<-size/2){
				static if(loose){
					root.dataContainer.add(data);
				}
				return;
			}
		}
		static if(loose){
			// addToQuad descends until its level counter reaches maxLevel, so starting the
			// counter at maxLevel-depth makes it stop at the wanted depth.
			addToQuad(pos, data, &root, size/2, cast(ubyte)(maxLevel-looseLevel(data.radius)));
		}else{
			addToQuad(pos, data, &root, size/2, 0);
		}
	}

	/***
	 * Descends from quad toward pos, allocating missing child groups, until level reaches
	 * maxLevel, then stores data in the reached node.
	 *
	 * Params:
	 *     pos = position relative to the center of quad
	 *     quad = node to start from
	 *     halfSize = half size of quad
	 *     level = level counter; the descent stops when it reaches maxLevel
	 */
	void addToQuad(Point pos, T data, Node* quad, float halfSize, ubyte level){
		while(level<maxLevel){
			if(quad.child is null){
				allocateQuads(quad, level);
			}
			float quarterSize=halfSize/2;
			bool[Point.length] direction;
			foreach(i;0..Point.length){
				direction[i]=pos[i]>0;
			}
			// Re-express pos relative to the center of the chosen child.
			pos=pos[]-quarterSize*(direction[]*2-1);

			quad=&quad.child[directionToIndex(direction)];
			halfSize=quarterSize;
			level++;
		}
		quad.dataContainer.add(data);
	}

	/// Depth at which a loose tree stores an element with the given radius: the deepest level
	/// whose node half size is greater than radius, clamped to [0, maxLevel].
	int looseLevel(float radius) const{
		int level=-1;
		float halfSize=size/2;
		while(halfSize>radius && level<maxLevel){
			level++;
			halfSize/=2;
		}
		return level<0 ? 0 : level;
	}

	/////////////////////////
	//// Visit functions ////
	/////////////////////////

	/// Visits every node depth-first, starting at the root. A nonzero visitor result stops the
	/// traversal and is returned.
	int visitAll(scope int delegate(Point pos, Node* quad, float halfSize, int level) visitor){
		Point pos=0;
		return visitAll(visitor,pos,&root, size/2,0);
	}

	/***
	 * Visits quad and its subtree depth-first (node before its children).
	 * A nonzero visitor result stops the traversal and is returned; otherwise returns 0.
	 *
	 * Params:
	 *     pos = absolute center of quad
	 *     halfSize = half size of quad
	 *     level = depth of quad
	 */
	int visitAll(scope int delegate(Point pos, Node* quad, float halfSize, int level) visitor, Point pos, Node* quad, float halfSize, int level){
		int visitAllImpl(Point pos, Node* quad, float halfSize, int level){
			int res=visitor(pos, quad, halfSize, level);
			if(res)return res;// nonzero: stop before visiting the children
			if(quad.child !is null && level<maxLevel){
				float quarterSize=halfSize/2;
				foreach(i;0..nodesNumber){
					auto direction=indexToDirection(i);
					Point xy=pos[]+quarterSize*(direction[]*2-1);
					res=visitAllImpl(xy,&quad.child[i],quarterSize,level+1);
					if(res)return res;
				}
			}
			return 0;
		}
		return visitAllImpl(pos, quad, halfSize, level);
	}

	/***
	 * Calls visitor for every node that stores elements and whose bounds (loose bounds in a
	 * loose tree) intersect the box [downLeft, upRight]. A nonzero visitor result stops the
	 * traversal.
	 */
	void visitAllNodesIn(scope int delegate(Point pos, Node* quad, float halfSize, int level) visitor, Point downLeft, Point upRight){

		int check(Point pos, Node* quad, float halfSize, int level){
			// Test the level before touching dataContainer: in a regular tree it shares
			// storage with child at every non-leaf node.
			if(!hasElements(level) || quad.dataContainer.length==0){
				return 0;
			}
			Point myDownLeft=pos[]-halfSize*(1+loose);//loose size is twice of a normal tree
			Point myUpRight =pos[]+halfSize*(1+loose);
			if(notInBox(downLeft, upRight, myDownLeft, myUpRight)){
				return 0;
			}
			return visitor(pos, quad, halfSize, level);
		}
		Point pos=0;
		visitAll(&check,pos,&root, size/2,0);
	}

	/***
	 * Calls visitor for every element that may lie in the box [downLeft, upRight].
	 * Regular tree: every element of every leaf whose bounds intersect the box.
	 * Loose tree: every element whose circle (pos, radius) intersects the box.
	 */
	void visitAllDataIn(scope void delegate(ref T data) visitor, Point downLeft, Point upRight){

		// Visits the elements stored directly in node (loose tree: only those touching the box).
		void visitNodeData(Node* node, int level){
			if(!hasElements(level))return;
			foreach(ref pData;node.dataContainer){
				static if(loose){
					if(circleNotInBox(downLeft, upRight, pData.pos,pData.radius )){
						continue;
					}
				}
				visitor(pData);
			}
		}

		// Visits node and its whole subtree without testing node bounds.
		void visitAllDataNoCheck(Node* node, int level){
			visitNodeData(node, level);
			if(hasChildren(node, level)){
				foreach(i;0..nodesNumber){
					visitAllDataNoCheck(&node.child[i], level+1);
				}
			}
		}

		void visitAllDataImpl(Point pos, Node* node, float halfSize, int level){
			Point myDownLeft=pos[]-halfSize*(1+loose);//loose size is twice of a normal tree
			Point myUpRight =pos[]+halfSize*(1+loose);
			if(notInBox(downLeft, upRight, myDownLeft, myUpRight)){
				static if(loose){
					// The loose root also stores elements outside the tree bounds and elements
					// with radius >= size/2; the root box does not bound them, so test them one by one.
					if(level==0)visitNodeData(node, level);
				}
				return;
			}
			if(inBox(downLeft, upRight, myDownLeft, myUpRight)){
				visitAllDataNoCheck(node,level);
				return;
			}
			visitNodeData(node, level);

			if(hasChildren(node, level)){
				float quarterSize=halfSize/2;
				foreach(i;0..nodesNumber){
					auto direction=indexToDirection(i);
					Point xy=pos[]+quarterSize*(direction[]*2-1);
					visitAllDataImpl(xy,&node.child[i],quarterSize,level+1);
				}
			}
		}
		Point pos=0;
		visitAllDataImpl(pos, &root, size/2, 0);
	}

	/***
	 * Re-inserts (with add, using the element's current pos) every element whose circle is no
	 * longer inside its node's bounds, or that is stored above maxLevel although its diameter
	 * is smaller than the node half size.
	 * NOTE(review): for the loose root, add may append to the very container being iterated
	 * (out-of-bounds elements go back to the root) - confirm Vector tolerates this.
	 */
	static if(hasMember!(T,"pos"))
	void updatePositions(){
		void updatePositionsImpl(Point pos, Node* node, float halfSize, int level){
			if(hasElements(level)){
				Point myDownLeft=pos[]-halfSize*(1+loose);//loose size is twice of a normal tree
				Point myUpRight =pos[]+halfSize*(1+loose);
				// Reverse order so removing index i does not disturb the indices still to visit.
				foreach_reverse(i,pData;node.dataContainer){
					static if(loose){
						float radius=pData.radius;
					}else{
						float radius=0;
					}
					if(!circleInBox(myDownLeft, myUpRight, pData.pos, radius ) || (level!=maxLevel && (radius*2)<halfSize)){
						node.dataContainer.remove(i);
						add(pData.pos,pData);
					}
				}
			}

			if(hasChildren(node, level)){
				float quarterSize=halfSize/2;
				foreach(i;0..nodesNumber){
					auto direction=indexToDirection(i);
					Point xy=pos[]+quarterSize*(direction[]*2-1);
					updatePositionsImpl(xy,&node.child[i],quarterSize,level+1);
				}
			}
		}
		Point pos=0;
		updatePositionsImpl(pos, &root, size/2, 0);
	}

	/////////////////////////
	//// Helper functions ///
	/////////////////////////

	/// Points quad.child at a new group of nodesNumber nodes taken from quadContainer.
	/// level is unused. NOTE(review): the new nodes' initial state depends on BucketsListChain.add.
	void allocateQuads(Node* quad, int level){
		auto xx=quadContainer.add();
		quad.child=cast(Node*)xx.ptr;
	}

	/// True when box [myLeft, myRight] and box [left, right] do not intersect (touching counts as intersecting).
	static bool notInBox(Point left, Point right,Point myLeft, Point myRight) pure nothrow{
		bool b0=(myRight[0]<left[0]) | (myLeft[0]>right[0]);
		bool b1=(myRight[1]<left[1]) | (myLeft[1]>right[1]);
		static if(Point.length==2){
			return b0 | b1;
		}else static if(Point.length==3){
			bool b2=(myRight[2]<left[2]) | (myLeft[2]>right[2]);
			return b0 | b1 | b2;
		}
	}

	/// True when box [myLeft, myRight] lies strictly inside box [left, right].
	static bool inBox(Point left, Point right,Point myLeft, Point myRight) pure nothrow{
		bool b0=(myLeft[0]>left[0]) & (myRight[0]<right[0]);
		bool b1=(myLeft[1]>left[1]) & (myRight[1]<right[1]);
		static if(Point.length==2){
			return b0 & b1;
		}else static if(Point.length==3){
			bool b2=(myLeft[2]>left[2]) & (myRight[2]<right[2]);
			return b0 & b1 & b2;
		}
	}

	/// Whether node has a valid child group (in a regular tree child is meaningless at maxLevel).
	static bool hasChildren(Node* node, int level)  {
		static if(loose){
			return node.child !is null;
		}else{
			return node.child !is null && level<maxLevel;
		}
	}

	/// Whether nodes at depth level may store elements (every depth in a loose tree, only maxLevel otherwise).
	static bool hasElements(int level)  {
		static if(loose){
			return true;
		}else{
			return level==maxLevel;
		}
	}

	/// True when the axis-aligned bounds of circle (pos, radius) lie strictly inside box [left, right].
	static bool circleInBox(Point left, Point right,Point pos, float radius)  {
		bool ok=true;
		foreach(i;0..Point.length){
			ok&=(pos[i]+radius<right[i]) & (pos[i]-radius>left[i]);
		}
		return ok;
	}

	/// True when the axis-aligned bounds of circle (pos, radius) do not intersect box [left, right].
	static bool circleNotInBox(Point left, Point right,Point pos, float radius)  {
		bool ok=false;
		foreach(i;0..Point.length){
			ok|=(pos[i]>radius+right[i]) | (left[i]>pos[i]+radius);
		}
		return ok;
	}

	/// Child index for per-axis directions (true = positive side); axis i contributes bit i.
	static uint directionToIndex(bool[Point.length] dir) pure nothrow{
		static if(Point.length==2){
			return dir[0]+2*dir[1];
		}else static if(Point.length==3){
			return dir[0]+2*dir[1]+4*dir[2];
		}
	}

	/// Inverse of directionToIndex.
	static bool[Point.length] indexToDirection(uint index) pure nothrow{
		static if(Point.length==2){
			return [index==1 || index==3, index>=2];
		}else static if(Point.length==3){
			bool i3=index==3;
			bool i7=index==7;
			return [index==1 || i3 || index==5 || i7, index==2 || i3 || index==6 || i7, index>=4];
		}
	}

}
412 import mutils.linalg.vec;
413 
unittest{
	import std.meta;
	alias vec2=Vec!(float, 2);
	alias vec3=Vec!(float, 3);

	// Every query box below must contain exactly one element, the one whose `data`
	// equals `expected`. numFound counts all visited elements, numOk only those with
	// the expected payload.
	int numFound;
	int numOk;
	int expected;
	void expectData(T)(ref T el){
		numFound++;
		numOk+=el.data==expected;
	}

	// Query boxes as (downLeft, upRight); box i encloses only the element with data==i.
	vec2[2][4] quadBoxes=[
		[vec2(-10,0.1), vec2(-0.1, 10)],
		[vec2(0.1,0.1), vec2(10, 10)],
		[vec2(0.1,-10), vec2(10, -0.1)],
		[vec2(-10,-10), vec2(-0.1, -0.1)],
	];
	vec3[2][8] octBoxes=[
		[vec3(-10,0.1,0.1), vec3(-0.1, 10,10)],
		[vec3(0.1,0.1,0.1), vec3(10, 10,10)],
		[vec3(0.1,-10,0.1), vec3(10, -0.1,10)],
		[vec3(-10,-10,0.1), vec3(-0.1, -0.1,10)],
		[vec3(-10,0.1,-10), vec3(-0.1, 10,-0.1)],
		[vec3(0.1,0.1,-10), vec3(10, 10,-0.1)],
		[vec3(0.1,-10,-10), vec3(10, -0.1,-0.1)],
		[vec3(-10,-10,-10), vec3(-0.1, -0.1,-0.1)],
	];

	// Loose QuadTree
	{
		numFound=numOk=0;
		struct QuadTreeData1{
			float[2] pos;
			float radius;
			int data;
		}
		alias TestTree=SpatialTree!(2,QuadTreeData1,true);
		TestTree tree;
		tree.initialize();
		vec2[4] points=[vec2(-1,+1), vec2(+1,+1), vec2(+1,-1), vec2(-1,-1)];
		foreach(i, p; points){
			tree.add(p, QuadTreeData1(p, 0.1, cast(int)i));
		}
		foreach(i, box; quadBoxes){
			expected=cast(int)i;
			tree.visitAllDataIn(&expectData!QuadTreeData1, box[0], box[1]);
		}
		assert(numFound==4);
		assert(numOk==4);
	}

	// Regular QuadTree
	{
		numFound=numOk=0;
		struct QuadTreeData2{
			int data;
		}
		alias TestTree=SpatialTree!(2,QuadTreeData2,false);
		TestTree tree;
		tree.initialize();
		vec2[4] points=[vec2(-5,+5), vec2(+5,+5), vec2(+5,-5), vec2(-5,-5)];
		foreach(i, p; points){
			tree.add(p, QuadTreeData2(cast(int)i));
		}
		foreach(i, box; quadBoxes){
			expected=cast(int)i;
			tree.visitAllDataIn(&expectData!QuadTreeData2, box[0], box[1]);
		}
		assert(numFound==4);
		assert(numOk==4);
	}

	// Loose OctTree
	{
		numFound=numOk=0;
		struct OctTreeData1{
			float[3] pos;
			float radius;
			int data;
		}
		alias TestTree=SpatialTree!(3,OctTreeData1,true);
		TestTree tree;
		tree.initialize();
		vec3[8] points=[
			vec3(-5,+5,+5), vec3(+5,+5,+5), vec3(+5,-5,+5), vec3(-5,-5,+5),
			vec3(-5,+5,-5), vec3(+5,+5,-5), vec3(+5,-5,-5), vec3(-5,-5,-5),
		];
		foreach(i, p; points){
			tree.add(p, OctTreeData1(p, 0.1, cast(int)i));
		}
		foreach(i, box; octBoxes){
			expected=cast(int)i;
			tree.visitAllDataIn(&expectData!OctTreeData1, box[0], box[1]);
		}
		assert(numFound==8);
		assert(numOk==8);
	}

	// Regular OctTree
	{
		numFound=numOk=0;
		struct OctTreeData2{
			int data;
		}
		alias TestTree=SpatialTree!(3,OctTreeData2,false);
		TestTree tree;
		tree.initialize();
		vec3[8] points=[
			vec3(-5,+5,+5), vec3(+5,+5,+5), vec3(+5,-5,+5), vec3(-5,-5,+5),
			vec3(-5,+5,-5), vec3(+5,+5,-5), vec3(+5,-5,-5), vec3(-5,-5,-5),
		];
		foreach(i, p; points){
			tree.add(p, OctTreeData2(cast(int)i));
		}
		foreach(i, box; octBoxes){
			expected=cast(int)i;
			tree.visitAllDataIn(&expectData!OctTreeData2, box[0], box[1]);
		}
		assert(numFound==8);
		assert(numOk==8);
	}
}
546