1 module mysql.query_interface;
2 
3 import mysql.mysql;
4 
5 import std.variant;
6 import std.string;
7 import std.conv;
8 import core.vararg;
9 
/// Helpers that build MySQL query strings by substituting values into
/// `?`-style placeholders, escaping them through `Mysql.escape`.
class QueryInterface {

    // Compile-time list of types, used to unroll the per-type argument decoding.
    private template Seq(T...) { alias Seq = T; }

    // Marker for a value list rendered by makeQuery from an array argument.
    // `rendered` is already safe SQL text (numbers, or individually quoted and
    // escaped strings), so `(?)` placeholders may emit it verbatim.
    private struct SqlList {
        string rendered;
    }

    // True when `ti` describes T, const(T) or immutable(T).
    private static bool matchesType(T)(TypeInfo ti) {
        return ti == typeid(T) || ti == typeid(const(T)) || ti == typeid(immutable(T));
    }

    // True when the Variant holds a null void*, this class's representation of SQL NULL.
    private static bool isNullVariant(Variant a) {
        auto v = a.peek!(void*);
        return v !is null && *v is null;
    }

    // Text of a value: the pre-rendered list for SqlList, to!string otherwise.
    private static string variantText(Variant a) {
        if (auto list = a.peek!SqlList)
            return list.rendered;
        return to!string(a);
    }

    // Escapes a name for use between backticks. db.escape does not touch
    // backticks, so they are doubled here to keep the value from closing the
    // surrounding `...` quote.
    private static string escapeName(string s, Mysql db) {
        import std.array : replace;
        return db.escape(s).replace("`", "``");
    }

    /// Builds (does not execute) a query by substituting the variadic
    /// arguments into the `?` / `?N` placeholders of `sql` via the
    /// `Variant[]` overload of `escapedVariants`.
    ///
    /// Accepted argument types, each also in its const/immutable form:
    /// string, char, byte, ubyte, short, ushort, int, uint, long, ulong,
    /// float, double, real; arrays of the numeric types; string[]; and null.
    /// Any other type fails a runtime assert.
    ///
    /// Scalars become quoted, escaped literals. A null argument becomes NULL.
    /// Arrays become a comma-separated list (string[] elements individually
    /// quoted and escaped). The list is emitted bare when the placeholder is
    /// written as `(?)`, e.g. `WHERE id IN (?)`; in a plain `?` the list text
    /// is rendered as one quoted literal.
    static string makeQuery(Mysql db, string sql, ...) {
        Variant[] args;

        nextArgument:
        foreach (arg; _arguments) {
            // Scalars: rendered to text now; escapedVariants quotes/escapes them.
            foreach (T; Seq!(string, char, byte, ubyte, short, ushort, int, uint,
                             long, ulong, float, double, real)) {
                if (matchesType!T(arg)) {
                    args ~= Variant(to!string(va_arg!T(_argptr)));
                    continue nextArgument;
                }
            }

            // Numeric arrays: a bare comma-separated list of numbers.
            foreach (T; Seq!(byte, ubyte, short, ushort, int, uint,
                             long, ulong, float, double, real)) {
                if (matchesType!(T[])(arg)) {
                    auto values = va_arg!(T[])(_argptr);
                    args ~= Variant(SqlList(to!(string[])(values).join(", ")));
                    continue nextArgument;
                }
            }

            if (matchesType!(string[])(arg)) {
                string[] escaped;
                foreach (el; va_arg!(string[])(_argptr))
                    escaped ~= "'" ~ db.escape(el) ~ "'";
                args ~= Variant(SqlList(escaped.join(", ")));
                continue;
            }

            if (arg == typeid(null)) {
                // The null still occupies an argument slot; consume it so the
                // following arguments are read from the right position.
                va_arg!(void*)(_argptr);
                args ~= Variant(cast(void*) null);
                continue;
            }

            assert(0, "invalid type " ~ arg.toString() );
        }

        return escapedVariants(db, sql, args);
    }

    /*
    static string escaped(T...)(string sql, T t) {
        static if(t.length > 0) {
            string fixedup;
            int pos = 0;


            void escAndAdd(string str, int q) {
                ubyte[] buffer = new ubyte[str.length * 2 + 1];
                buffer.length = mysql_real_escape_string(mysql, buffer.ptr, cast(cstring) str.ptr, str.length);

                fixedup ~= sql[pos..q] ~ '\'' ~ cast(string) buffer ~ '\'';

            }

            foreach(a; t) {
                int q = sql[pos..$].indexOf("?");
                if(q == -1)
                    break;
                q += pos;

                static if(__traits(compiles, t is null)) {
                    if(t is null)
                        fixedup  ~= sql[pos..q] ~ "NULL";
                    else
                        escAndAdd(to!string(*a), q);
                } else {
                    string str = to!string(a);
                    escAndAdd(str, q);
                }

                pos = q+1;
            }

            fixedup ~= sql[pos..$];

            sql = fixedup;

            //writefln("\n\nExecuting sql: %s", sql);
        }

        return sql;
    }
    */

    /// Renders a value as a SQL literal: NULL for a null void*, otherwise the
    /// value's text escaped with db.escape and wrapped in single quotes.
    static string toSql(Variant a, Mysql db) {
        if (isNullVariant(a))
            return "NULL";
        return '\'' ~ db.escape(variantText(a)) ~ '\'';
    }

    /// Convenience overload for strings; a null string renders as NULL.
    static string toSql(string s, Mysql db) {
        if(s is null) return "NULL";
        return '\'' ~ db.escape(s) ~ '\'';
    }

    /// Integers are rendered unquoted; `db` is unused.
    static string toSql(long s, Mysql db) {
        return to!string(s);
    }

    /// Renders a string for use inside backticks (an identifier); NULL for a
    /// null string. Backticks in `s` are doubled.
    static string toSqlName(string s, Mysql db) {
        if(s is null) return "NULL";
        return escapeName(s, db);
    }

    /// Integers are rendered unquoted, as in toSql.
    static string toSqlName(long s, Mysql db) {
        return toSql(s, db);
    }

    /// Variant form of toSqlName: NULL for a null void*, otherwise the value's
    /// text escaped for use inside backticks.
    static string toSqlName(Variant a, Mysql db) {
        if (isNullVariant(a))
            return "NULL";
        return escapeName(variantText(a), db);
    }

    /// Integers are rendered unquoted, as in toSql.
    static string toSqlArray(long s, Mysql db) {
        return toSql(s, db);
    }

    /// Renders the value for a `(?)` placeholder. Lists built by makeQuery
    /// from array arguments are emitted verbatim (their elements were rendered
    /// and escaped when the list was built); a null void* renders as NULL.
    /// Any other value is quoted and escaped like toSql, so a scalar string
    /// can no longer splice raw SQL into the query.
    static string toSqlArray(Variant a, Mysql db) {
        if (isNullVariant(a))
            return "NULL";
        if (auto list = a.peek!SqlList)
            return list.rendered;
        return toSql(a, db);
    }

    /// Substitutes named placeholders `?name` (a letter or underscore, then
    /// letters, digits or underscores) with toSql(t[name]). Placeholders whose
    /// name is not in `t` are left in place, since another layer may fill them.
    /// Returns `sql` unchanged when `t` is empty or `sql` has no `?`.
    static string escapedVariants(Mysql db, in string sql, Variant[string] t) {
        if(t.length == 0 || sql.indexOf("?") == -1) {
            return sql;
        }

        string fixedup;
        int currentStart = 0;
        // FIXME: let's make ?? render as ? so we have some escaping capability
        foreach(int i, dchar c; sql) {
            if (c == '?') {
                fixedup ~= sql[currentStart .. i];

                int idxStart = i + 1;
                int idxLength;

                bool isFirst = true;

                while(idxStart + idxLength < sql.length) {
                    char C = sql[idxStart + idxLength];

                    if((C >= 'a' && C <= 'z') || (C >= 'A' && C <= 'Z') || C == '_' || (!isFirst && C >= '0' && C <= '9'))
                        idxLength++;
                    else
                        break;

                    isFirst = false;
                }

                auto idx = sql[idxStart .. idxStart + idxLength];

                if(idx in t) {
                    fixedup ~= toSql(t[idx], db);
                    currentStart = idxStart + idxLength;
                } else {
                    // just leave it there, it might be done on another layer
                    currentStart = i;
                }
            }
        }

        fixedup ~= sql[currentStart .. $];

        return fixedup;
    }

    /// Substitutes positional placeholders. `?` takes the next argument in
    /// order; `?N` (a single digit) takes argument N. Note: ?n params are zero
    /// based! A placeholder directly surrounded by backticks is rendered with
    /// toSqlName, one surrounded by parentheses with toSqlArray, any other
    /// with toSql. Returns `sql` unchanged when `t` is empty or `sql` has no `?`.
    /// Throws: Exception when a placeholder refers past the end of `t`.
    static string escapedVariants(Mysql db, in string sql, Variant[] t) {
        // FIXME: let's make ?? render as ? so we have some escaping capability
        // if nothing to escape or nothing to escape with, don't bother
        if (t.length == 0 || sql.indexOf("?") == -1)
            return sql;

        string fixedup;
        int currentIndex;
        int currentStart = 0;
        foreach (int i, dchar c; sql) {
            if (c != '?')
                continue;

            fixedup ~= sql[currentStart .. i];

            // currentStart ends up just past the whole placeholder token.
            int idx = -1;
            currentStart = i + 1;
            if (currentStart < sql.length) {
                auto n = sql[currentStart];
                if (n >= '0' && n <= '9') {
                    idx = n - '0';
                    currentStart++;
                }
            }
            if (idx == -1) {
                idx = currentIndex;
                currentIndex++;
            }

            if(idx < 0 || idx >= t.length)
                throw new Exception("SQL Parameter index is out of bounds: " ~ to!string(idx) ~ " at `" ~ sql[0 .. i] ~ "`");

            // Characters around the whole token (so `?N` digits are skipped);
            // '\0' stands in at either end of the string to avoid indexing past it.
            immutable char before = i > 0 ? sql[i - 1] : '\0';
            immutable char after = currentStart < sql.length ? sql[currentStart] : '\0';

            if (before == '`' && after == '`') {
                fixedup ~= toSqlName(t[idx], db);
            } else if (before == '(' && after == ')') {
                fixedup ~= toSqlArray(t[idx], db);
            } else {
                fixedup ~= toSql(t[idx], db);
            }
        }

        fixedup ~= sql[currentStart .. $];

        return fixedup;
    }
}