1 module mutils.safe_union;
2 
3 import std.algorithm : max;
4 import std.conv : to;
5 import std.meta : staticIndexOf;
6 import std.traits : isArray,ForeachType,hasMember,ReturnType,Parameters;
7 
/**
 * Tagged union of ConTypes.
 *
 * One member is stored at a time and `currentType` records which one.
 * The typed accessor `get` checks the tag with `assert`, so reading an
 * inactive member fails in builds with assertions enabled.
 *
 * If a type appears more than once in ConTypes, `set`, `get` and `getEnum`
 * all use its first occurrence.
 *
 * Params:
 *   makeFirstParDefaultOne = if true, a default-initialized union reports the
 *                            first member type (`Types._e_0`) as active;
 *                            otherwise it reports `Types.none`.
 *   ConTypes = member types. The generated tag enum is `ubyte`-based and
 *              reserves one value for `none`, so at most 255 types fit.
 */
struct SafeUnion(bool makeFirstParDefaultOne, ConTypes...) {
	alias FromTypes=ConTypes;
	static assert(FromTypes.length>0,"Union has to have members.");
	// TypesM is ubyte-based and also needs the value FromTypes.length for `none`.
	static assert(FromTypes.length<=ubyte.max,"Union can have at most 255 members.");

	mixin(getCode!(FromTypes));
	//enum Types{...}    //from mixin
	alias Types=TypesM;// alias used to give better autocompletion in IDE-s

	/// Tag of the active member; `Types.none` means no member is active.
	Types currentType=(makeFirstParDefaultOne)?Types._e_0:Types.none;

	/**
	 * Constructor accepting a value of one of the member types directly.
	 */
	this(T)(T obj){
		static assert(properType!T,"Given Type is not present in union");
		set(obj);
	}

	/**
	 * Copies another union of the same type by assigning all of its fields
	 * (every overlapping union field and the tag) through `tupleof`.
	 */
	void opAssign(SafeUnion!(makeFirstParDefaultOne, ConTypes) obj){
		this.tupleof=obj.tupleof;
	}
	//void opAssign(this);

	/**
	 * Assigns a value of one of the member types and makes it active.
	 */
	void opAssign(T)(T obj){
		static assert(properType!T,"Given Type is not present in union");
		set(obj);
	}

	/**
	 * Returns a pointer to the stored value of type T.
	 * Asserts that T is the currently active member.
	 */
	@nogc nothrow auto get(T)(){
		static assert(properType!T,"Given Type is not present in union");
		enum index=staticIndexOf!(T,FromTypes);
		assert(currentType==index,"Got type which is not currently bound.");
		mixin("return &_"~index.to!string~";");
	}

	/**
	 * Returns true when the active member's tag corresponds to type T.
	 */
	@nogc nothrow bool isType(T)(){
		static assert(properType!T,"Given Type is not present in union");
		foreach(i,Type;FromTypes){
			static if(is(Type==T)){
				if(currentType==cast(Types)i){
					return true;
				}
			}
		}
		return false;
	}

	/**
	 * Returns the enum value (tag) for Type.
	 * Usable at compile time, e.g. as a `case` label.
	 */
	static Types getEnum(T)(){
		static assert(properType!T,"Given Type is not present in union");
		return cast(Types)staticIndexOf!(T,FromTypes);
	}

	/**
	 * Stores obj and marks its type as the active member.
	 */
	@nogc nothrow auto set(T)(T obj){
		static assert(properType!T,"Given Type is not present in union");
		enum index=staticIndexOf!(T,FromTypes);
		currentType=cast(Types)index;
		mixin("_"~index.to!string~"=obj;");
	}

	/**
	 * Calls `fun` with the active member and returns its result.
	 * Hits `assert(0)` when no member is active (`Types.none`).
	 */
	auto ref apply(alias fun)() {
		switch(currentType){
			mixin(getCaseCode("return fun(_%1$s);"));

			default:
				assert(0);
		}
	}



	import mutils.serializer.binary;
	/**
	 * Support for serialization.
	 * Passes the tag and then the active member (nothing more for `none`) to
	 * `serializer.serialize!(load)`; `load` is forwarded unchanged.
	 */
	void customSerialize(Load load, Serializer, ContainerOrSlice)(Serializer serializer,ref ContainerOrSlice con){
		serializer.serialize!(load)(currentType,con);
		final switch(currentType){
			mixin(getCaseCode("serializer.serialize!(load)(_%1$s,con);break;"));
			case Types.none:
				break;
		}
	}

	import std.range:put;
	import std.format:FormatSpec,formatValue;
	/**
	 * Pretty print: writes `SafeUnion(<active value>)`, or `SafeUnion(none)`.
	 */
	void toString(scope void delegate(const(char)[]) sink, FormatSpec!char fmt)
	{
		put(sink, "SafeUnion(");

		final switch(currentType){
			mixin(getCaseCode("formatValue(sink, _%1$s, fmt);break;"));
			case Types.none:
				put(sink, "none");
				break;
		}

		put(sink, ")");
	}

	/**
	 * Checks whether every member type has a member named funcName,
	 * which is required by `call`.
	 */
	static bool checkOpDispach(string funcName)(){
		bool ok=true;
		foreach(Type;FromTypes){
			ok=ok && hasMember!(Type, funcName);
		}
		return ok;
	}

	/**
	 * Forwards a call to the active union member.
	 * Works only if all union members have this function with the same
	 * return type and parameter types (checked with static asserts).
	 * Hits `assert(0)` when no member is active.
	 * Cannot be made opDispatch because that somehow breaks the hasMember trait.
	 */
	auto call(string funcName, Args...)(auto ref Args args)
		if(checkOpDispach!(funcName) )
	{
		mixin("alias CompareReturnType=ReturnType!(FromTypes[0]."~funcName~");");
		mixin("alias CompareParametersTypes=Parameters!(FromTypes[0]."~funcName~");");
		foreach(Type;FromTypes){
			mixin("enum bool typeOk=is(ReturnType!(Type."~funcName~")==CompareReturnType);");
			mixin("enum bool parametersOk=is(Parameters!(Type."~funcName~")==CompareParametersTypes);");
			static assert(typeOk,"Return type "~CompareReturnType.stringof~" of '"~funcName~"' has to be the same in every union member.");
			static assert(parametersOk,"Parameter types "~CompareParametersTypes.stringof~" of '"~funcName~"' have to be the same in every union member.");
		}
		switch(currentType){
			mixin(getCaseCode("return _%1$s."~funcName~"(args);"));
			default:
				assert(0);
		}
	}
package:

	/**
	 * Generates one `case Types._e_i:` label per member followed by `code`,
	 * where `%1$s` in `code` is replaced by the member index i
	 * (so `_%1$s` names the member's union field).
	 */
	private static string getCaseCode(string code){
		import std.format : format;
		string str;
		foreach(i;0..FromTypes.length){
			string istr=i.to!string;
			str~="case Types._e_"~istr~":";
			str~=format(code,istr);
		}
		return str;
	}

	/**
	 * Generates the source of the `ubyte` tag enum `TypesM`
	 * (`_e_0 = 0, _e_1 = 1, ..., none`) and of an anonymous private union
	 * with one field `_i` of type `FromTypes[i]` per member.
	 */
	private static string getCode(MemberTypes...)(){
		string codeEnum="enum TypesM:ubyte{\n";
		string code="private union{\n";
		foreach(i;0..MemberTypes.length){
			string istr=i.to!string;
			codeEnum~="_e_"~istr~"="~istr~",\n";
			code~="FromTypes["~istr~"] _"~istr~";\n";
		}
		codeEnum~="none\n}\n";
		return codeEnum~code~"}\n";
	}

	/**
	 *  Checks if T is one of the union's member types.
	 */
	private static  bool properType(T)(){
		return staticIndexOf!(T,FromTypes)!=-1;
	}
}
/// Example Usage
unittest{
	// Two member types exposing `add` with an identical signature,
	// as required by `call!"add"`.
	struct Triangle{
		int add(int a){
			return a+10;
		}
	}
	struct Rectangle {
		int add(int a){
			return a+100;
		}
	}
	// Maps each member type to a fixed id; used to exercise `apply`.
	static uint shapeId(S)(S value){
		static if(is(S==Triangle)) return 123;
		else static if(is(S==Rectangle)) return 14342;
		else assert(0);
	}
	alias Shape=SafeUnion!(false, Triangle,Rectangle);

	Shape shape;

	// Phase 1: Triangle is active.
	shape.set(Triangle());
	assert(shape.isType!Triangle);
	assert(!shape.isType!Rectangle);
	assert(shape.call!("add")(6)==16);
	assert(shape.apply!shapeId==123);
	//shape.get!(Rectangle);// would fail the assert in get: Rectangle is not active

	// Phase 2: Rectangle is active.
	shape.set(Rectangle());
	assert(shape.call!("add")(6)==106);
	assert(shape.apply!shapeId==14342);

	// Phase 3: no member is active.
	shape.currentType=shape.Types.none;
	//shape.apply!shapeId;// would hit assert(0) in apply
	//member calls go through call!"name"; there is no opDispatch forwarding

	// getEnum is usable at compile time, e.g. as case labels.
	final switch(shape.currentType){
		case shape.getEnum!Triangle:
			break;
		case Shape.getEnum!Rectangle:
			break;
		case Shape.Types.none:
			break;
	}
}