1 module mysql.database;
2 
3 public import std.variant;
4 import std.string;
5 import std.stdio;
6 
7 import core.vararg;
8 
/// A database connection. Implementations supply queryImpl, escape and
/// startTransaction; query() is a convenience front end built on queryImpl.
interface Database {
    /// Actually implements the query for the database. The query() method
    /// below might be easier to use.
    ResultSet queryImpl(string sql, Variant[] args...);

    /// Escapes data for inclusion into an sql string literal
    string escape(string sqlData);

    /// query to start a transaction, only here because sqlite is apparently different in syntax...
    void startTransaction();

    // FIXME: this would be better as a template, but can't because it is an interface

    /**
        Just executes a query. It supports placeholders for parameters
        by using ? in the sql string.

        Every argument is converted to a string, wrapped in a Variant and
        handed to queryImpl. Accepted argument types (each in its plain,
        const or immutable form): string, char, int, uint, long, ulong,
        and null. Anything else fails a runtime assert.
    */
    final ResultSet query(string sql, ...) {
        Variant[] args;
        foreach(arg; _arguments) {
            string a;
            if(arg == typeid(string) || arg == typeid(immutable(string)) || arg == typeid(const(string)))
                a = va_arg!string(_argptr);
            else if (arg == typeid(int) || arg == typeid(immutable(int)) || arg == typeid(const(int))) {
                auto e = va_arg!int(_argptr);
                a = to!string(e);
            } else if (arg == typeid(uint) || arg == typeid(immutable(uint)) || arg == typeid(const(uint))) {
                auto e = va_arg!uint(_argptr);
                a = to!string(e);
            } else if (arg == typeid(char) || arg == typeid(const(char)) || arg == typeid(immutable(char))) {
                // plain char is what a character literal like 'x' passes; it used to be rejected
                auto e = va_arg!char(_argptr);
                a = to!string(e);
            } else if (arg == typeid(long) || arg == typeid(const(long)) || arg == typeid(immutable(long))) {
                auto e = va_arg!long(_argptr);
                a = to!string(e);
            } else if (arg == typeid(ulong) || arg == typeid(const(ulong)) || arg == typeid(immutable(ulong))) {
                auto e = va_arg!ulong(_argptr);
                a = to!string(e);
            } else if (arg == typeid(null)) {
                a = null; // arrives in queryImpl as a null string Variant
            } else assert(0, "invalid type " ~ arg.toString() );

            args ~= Variant(a);
        }

        return queryImpl(sql, args);
    }
}
56 
57 /*
58 Ret queryOneColumn(Ret, string file = __FILE__, size_t line = __LINE__, T...)(Database db, string sql, T t) {
59     auto res = db.query(sql, t);
60     if(res.empty)
61         throw new Exception("no row in result", file, line);
62     auto row = res.front;
63     return to!Ret(row[0]);
64 }
65 */
66 
/// Runs a query when constructed and lets you foreach over its rows with
/// typed loop variables; each column is converted positionally with to!.
struct Query {
    ResultSet result;

    /// Runs `sql` through Database.query with `t` as the ? placeholder values.
    this(T...)(Database db, string sql, T t) if(T.length!=1 || !is(T[0]==Variant[])) {
        result = db.query(sql, t);
    }

    /// Version for dynamically generated args: an already-built Variant[] goes
    /// straight to queryImpl. (It needs to be a template so it can coexist
    /// with the other constructor.)
    this(T...)(Database db, string sql, T args) if (T.length==1 && is(T[0] == Variant[])) {
        result = db.queryImpl(sql, args);
    }

    /// Calls `dg` once per row, with column i converted to the type of the
    /// delegate's i-th parameter; stops early if `dg` returns nonzero.
    int opApply(T)(T dg) if(is(T == delegate)) {
        import std.traits;
        alias ParameterTypeTuple!dg Columns;

        foreach(row; result) {
            Columns values;
            foreach(i, Column; Columns)
                values[i] = to!Column(row[i]);

            if(auto rc = dg(values))
                return rc;
        }

        return 0;
    }
}
93 
/// One row of a result set; every column value is held as a string.
struct Row {
    package string[] row;
    package ResultSet resultSet;

    /// Column at position `idx`. Throws an Exception, attributed to the
    /// caller's file/line, when idx is past the end of the row.
    string opIndex(size_t idx, string file = __FILE__, int line = __LINE__) {
        if(idx < row.length)
            return row[idx];
        throw new Exception(text("index ", idx, " is out of bounds on result"), file, line);
    }

    /// Column called `name`, located through resultSet.getFieldIndex. Throws
    /// an Exception when the index it reports does not fall inside the row.
    string opIndex(string name, string file = __FILE__, int line = __LINE__) {
        auto position = resultSet.getFieldIndex(name);
        if(position < row.length)
            return row[position];
        throw new Exception(text("no field ", name, " in result"), file, line);
    }

    /// The row's values formatted as an array literal.
    string toString() {
        return row.to!string;
    }

    /// Field name -> value, pairing resultSet.fieldNames() with the row by position.
    string[string] toAA() {
        auto names = resultSet.fieldNames();
        string[string] byName;

        foreach(i; 0 .. row.length)
            byName[names[i]] = row[i];

        return byName;
    }

    int opApply(int delegate(ref string, ref string) dg) {
        foreach(a, b; toAA())
            mixin(yield("a, b"));

        return 0;
    }

    /// The underlying column values (not a copy).
    string[] toStringArray() {
        return row;
    }
}
139 import std.conv;
140 
/// The rows returned by a query, exposed as an input range of Row.
interface ResultSet {
    /// Column index for the field called `field`. (Row.opIndex(string)
    /// treats an index outside the row as "no such field".)
    int getFieldIndex(string field);
    /// Field names; Row.toAA pairs entry i with column i of each row.
    string[] fieldNames();

    // this is a range that can offer other ranges to access it
    @property bool empty();
    @property Row front();
    void popFront();
    @property int length();

    /// Returns this same object.
    /* deprecated */ final ResultSet byAssoc() { return this; }
}
154 
/// Exception type for database errors; `file` and `line` default to the
/// location of the `new DatabaseException(...)` expression.
class DatabaseException : Exception {
    this(string msg, string file = __FILE__, size_t line = __LINE__) { super(msg, file, line); }
}
160 
161 
162 
/// Common abstract base for SQL builder classes (SelectBuilder derives from it).
abstract class SqlBuilder {
}
164 
/// Assembles a SELECT statement from its parts; call toString() for the SQL.
/// WARNING: this is as susceptible to SQL injections as you would be writing it out by hand
class SelectBuilder : SqlBuilder {
    string[] fields;   /// selected expressions, separated by ", "
    string table;      /// text placed after FROM
    string[] joins;    /// raw JOIN clauses, appended in order
    string[] wheres;   /// conditions; each is parenthesized and they are ANDed
    string[] orderBys; /// ORDER BY terms
    string[] groupBys; /// GROUP BY terms

    int limit;      /// LIMIT row count; 0 means no LIMIT clause
    int limitStart; /// LIMIT offset; emitted only when both it and limit are nonzero

    /// values for ?name placeholders, substituted by toString() when db is set
    Variant[string] vars;
    /// Stores `value` as the substitution for the placeholder ?`name`.
    void setVariable(T)(string name, T value) {
        vars[name] = Variant(value);
    }

    Database db; /// used by toString() to escape vars; may be null
    this(Database db = null) {
        this.db = db;
    }

    /*
        It would be nice to put variables right here in the builder

        ?name

        will prolly be the syntax, and we'll do a Variant[string] of them.

        Anything not translated here will of course be in the ending string too
    */

    /// Returns an independent copy: the arrays and vars are duplicated,
    /// while the same Database reference is shared.
    SelectBuilder cloned() {
        auto copy = new SelectBuilder(db);
        copy.fields   = fields.dup;
        copy.table    = table;
        copy.joins    = joins.dup;
        copy.wheres   = wheres.dup;
        copy.orderBys = orderBys.dup;
        copy.groupBys = groupBys.dup;
        copy.limit      = limit;
        copy.limitStart = limitStart;

        foreach(name, value; vars)
            copy.vars[name] = value;

        return copy;
    }

    /// Renders the statement. With db set, ?name placeholders that appear in
    /// vars are replaced by escaped literals via escapedVariants; without db
    /// the SQL is returned with placeholders untouched.
    override string toString() {
        string sql = "SELECT " ~ fields.join(", ") ~ " FROM " ~ table;

        foreach(clause; joins)
            sql ~= " " ~ clause;

        if(wheres.length) {
            string[] conditions;
            foreach(w; wheres)
                conditions ~= "(" ~ w ~ ")";
            sql ~= " WHERE " ~ conditions.join(" AND ");
        }

        if(groupBys.length)
            sql ~= " GROUP BY " ~ groupBys.join(", ");

        if(orderBys.length)
            sql ~= " ORDER BY " ~ orderBys.join(", ");

        if(limit) {
            sql ~= " LIMIT ";
            if(limitStart)
                sql ~= to!string(limitStart) ~ ", ";
            sql ~= to!string(limit);
        }

        return db is null ? sql : escapedVariants(db, sql, vars);
    }
}
286 
287 
288 // /////////////////////sql//////////////////////////////////
289 
290 
// used in the internal placeholder thing
/**
    Renders one Variant as an SQL literal.

    Returns `NULL` for a Variant holding a null `void*` or a null `string`.
    The string case matches toSql(string, Database), and covers `null`
    arguments given to Database.query, which wraps them as null strings.
    Any other value is converted with to!string, escaped with db.escape and
    wrapped in single quotes.
*/
string toSql(Database db, Variant a) {
    if(auto ptr = a.peek!(void*)) {
        if(*ptr is null)
            return "NULL";
    }

    if(auto str = a.peek!string) {
        if(*str is null)
            return "NULL";
    }

    return '\'' ~ db.escape(to!string(a)) ~ '\'';
}
303 
// just for convenience; "str".toSql(db);
/// A null string becomes NULL; anything else is escaped with db.escape and single-quoted.
string toSql(string s, Database db) {
    return s is null ? "NULL" : '\'' ~ db.escape(s) ~ '\'';
}
310 
/// Integers are emitted as bare decimal digits. `db` is unused here,
/// presumably kept for symmetry with the string overload.
string toSql(long s, Database db) {
    return s.to!string;
}
314 
/**
    Substitutes named placeholders of the form `?name` in `sql`.

    A name starts with an ASCII letter or underscore, and may continue with
    letters, digits and underscores. A placeholder whose name is a key of `t`
    is replaced by toSql(db, t[name]). Any other `?` is left in the output
    as is, because another layer might handle it; this includes a name that
    is not a key of `t`.

    Returns `sql` untouched when `t` is empty or `sql` contains no `?`.
*/
string escapedVariants(Database db, in string sql, Variant[string] t) {
    // t.length avoids allocating a keys array just to test for emptiness
    if(t.length == 0 || sql.indexOf("?") == -1) {
        return sql;
    }

    // letters and '_' may appear anywhere in a name; digits only after the first char
    static bool isNameChar(char c, bool first) {
        return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || c == '_'
            || (!first && c >= '0' && c <= '9');
    }

    string fixedup;
    size_t currentStart = 0;
// FIXME: let's make ?? render as ? so we have some escaping capability
    foreach(size_t i, dchar c; sql) {
        if(c != '?')
            continue;

        fixedup ~= sql[currentStart .. i];

        immutable nameStart = i + 1;
        size_t nameEnd = nameStart;
        while(nameEnd < sql.length && isNameChar(sql[nameEnd], nameEnd == nameStart))
            nameEnd++;

        if(auto value = sql[nameStart .. nameEnd] in t) {
            fixedup ~= toSql(db, *value);
            currentStart = nameEnd;
        } else {
            // just leave it there, it might be done on another layer
            currentStart = i;
        }
    }

    fixedup ~= sql[currentStart .. $];

    return fixedup;
}
359 
// TODO: cut this out
/**
    Substitutes positional placeholders in `sql` with the values in `t`.

    A bare `?` takes the next value in order. `?n`, where n is a single digit,
    takes t[n] and does not advance that running position.
    Note: ?n params are zero based!
    Each value is rendered with toSql(db, value).

    Returns `sql` untouched when `t` is empty or `sql` contains no `?`.
    Throws: Exception when a placeholder refers past the end of `t`.
*/
string escapedVariants(Database db, in string sql, Variant[] t) {
// FIXME: let's make ?? render as ? so we have some escaping capability
    // if nothing to escape or nothing to escape with, don't bother
    if(t.length == 0 || sql.indexOf("?") == -1)
        return sql;

    string output;
    int nextPositional;
    int copyFrom = 0;
    foreach(int i, dchar c; sql) {
        if(c != '?')
            continue;

        output ~= sql[copyFrom .. i];
        copyFrom = i + 1;

        int which = -1;
        if((i + 1) < sql.length) {
            immutable digit = sql[i + 1];
            if(digit >= '0' && digit <= '9') {
                which = digit - '0';
                copyFrom = i + 2; // the digit is part of the placeholder
            }
        }
        if(which == -1)
            which = nextPositional++;

        if(which < 0 || which >= t.length)
            throw new Exception("SQL Parameter index is out of bounds: " ~ to!string(which) ~ " at `"~sql[0 .. i]~"`");

        output ~= toSql(db, t[which]);
    }

    output ~= sql[copyFrom .. $];

    return output;
}
432 
433 
/// How updateOrInsert decides between UPDATE and INSERT.
enum UpdateOrInsertMode {
    CheckForMe,   /// run a SELECT first; INSERT if it finds no rows, otherwise UPDATE
    AlwaysUpdate, /// always issue an UPDATE
    AlwaysInsert  /// always issue an INSERT
}
439 
440 
// BIG FIXME: this should really use prepared statements
/**
    Writes `values` (column name -> value) into `table`.

    The write is an INSERT, or an UPDATE restricted by `where`, depending on
    `mode`. With CheckForMe, `SELECT key FROM table WHERE where` runs first:
    no rows means INSERT, otherwise UPDATE.

    Null values are written as SQL NULL. Other values are escaped with
    db.escape and single-quoted. Entries with a null column name are skipped,
    and nothing is sent when no columns remain.

    WARNING: `where` and `key` are pasted into the SQL verbatim, and INSERT
    column names are not escaped.

    Returns: always 0.
*/
int updateOrInsert(Database db, string table, string[string] values, string where, UpdateOrInsertMode mode = UpdateOrInsertMode.CheckForMe, string key = "id") {
    bool insert;

    final switch(mode) {
        case UpdateOrInsertMode.CheckForMe:
            insert = db.query("SELECT "~key~" FROM `"~db.escape(table)~"` WHERE " ~ where).empty;
            break;
        case UpdateOrInsertMode.AlwaysInsert:
            insert = true;
            break;
        case UpdateOrInsertMode.AlwaysUpdate:
            insert = false;
            break;
    }

    // SQL literal for one value
    string quoted(string value) {
        return value is null ? "NULL" : "'" ~ db.escape(value) ~ "'";
    }

    if(insert) {
        immutable insertSql = "INSERT INTO `" ~ db.escape(table) ~ "` ";

        string[] columns, literals;
        foreach(column, value; values) {
            if(column is null)
                continue;
            //columns ~= "`" ~ db.escape(column) ~ "`";
            columns ~= "`" ~ column ~ "`"; // FIXME: possible insecure
            literals ~= quoted(value);
        }

        if(columns.length == 0)
            return 0;

        db.query(insertSql ~ "(" ~ columns.join(", ") ~ ")" ~ " VALUES " ~ "(" ~ literals.join(", ") ~ ")");
        return 0; // db.lastInsertId;
    }

    immutable updateSql = "UPDATE `"~db.escape(table)~"` SET ";

    string[] assignments;
    foreach(column, value; values) {
        if(column is null)
            continue;
        assignments ~= "`" ~ db.escape(column) ~ "` = " ~ quoted(value);
    }

    if(assignments.length == 0)
        return 0;

    db.query(updateSql ~ assignments.join(", ") ~ " WHERE " ~ where);
    return 0;
}
520 
521 
522 
523 
524 
/**
    Rewrites a SELECT so that the key column of every table it names is also
    selected, under the alias `id_from_<table>`.

    Tables are the word after each "JOIN" plus the comma-separated entries
    after "FROM". The FROM list ends at the first uppercase letter B-Z, a
    hack that assumes the next SQL keyword is uppercase. The extra columns
    are inserted just before the first "FROM".

    Params:
        sql = the SELECT statement to rewrite
        keyMapping = table name -> key column; unlisted tables use "id"
    Returns:
        the rewritten SQL, or `sql` unchanged when it contains no "FROM".
        A "JOIN" or "FROM" at the very end of `sql` no longer overruns the
        string.
*/
string fixupSqlForDataObjectUse(string sql, string[string] keyMapping = null) {
    // characters that end a table name
    static bool isNameTerminator(char c) {
        return c == ' ' || c == '\n' || c == '\t' || c == ',';
    }

    string[] tableNames;

    // the word following each JOIN
    string piece = sql;
    sizediff_t idx;
    while((idx = piece.indexOf("JOIN")) != -1) {
        // skip "JOIN" and one separator, clamped so a trailing JOIN stays in bounds
        size_t start = cast(size_t) idx + 5;
        if(start > piece.length)
            start = piece.length;
        size_t end = start;
        while(end < piece.length && !isNameTerminator(piece[end]))
            end++;

        tableNames ~= strip(piece[start .. end]);

        piece = piece[end .. $];
    }

    idx = sql.indexOf("FROM");
    if(idx == -1)
        return sql;

    size_t fromStart = cast(size_t) idx + 5;
    if(fromStart > sql.length)
        fromStart = sql.length;
    size_t fromEnd = fromStart;
    // if not uppercase, except for A (for AS) to avoid SQL keywords (hack)
    while(fromEnd < sql.length && !(sql[fromEnd] > 'A' && sql[fromEnd] <= 'Z'))
        fromEnd++;

    foreach(p; sql[fromStart .. fromEnd].split(",")) {
        auto entry = p.strip();
        size_t len = 0;
        while(len < entry.length && !isNameTerminator(entry[len]))
            len++;

        tableNames ~= strip(entry[0 .. len]);
    }

    string sqlToAdd;
    foreach(tbl; tableNames) {
        if(tbl.length == 0)
            continue;
        string keyName = "id";
        if(auto mapped = tbl in keyMapping)
            keyName = *mapped;
        sqlToAdd ~= ", " ~ tbl ~ "." ~ keyName ~ " AS " ~ "id_from_" ~ tbl;
    }

    sqlToAdd ~= " ";

    return sql[0 .. idx] ~ sqlToAdd ~ sql[idx .. $];
}
580 
581 import mysql.data_object;
582 
/**
    Given some SQL, it finds the CREATE TABLE
    instruction for the given tableName.
    (this is so it can find one entry from
    a file with several SQL commands. But it
    may break on a complex file, so try to only
    feed it simple sql files.)

    From that, it pulls out the members to create a
    simple struct based on it: the struct's body is exactly
    the field declarations returned by getCreateTable(sql, tableName),
    mixed in at compile time.

    It's not terribly smart, so it will probably
    break on complex tables.

    Data types handled:
        INTEGER, SMALLINT, MEDIUMINT -> D's int
        TINYINT, BOOLEAN -> D's bool
        BIGINT -> D's long
        TEXT, VARCHAR, CHAR (also lowercase text, varchar, char) -> D's string
        FLOAT, DOUBLE -> D's double
    Any other type fails an assert during compilation.

    It also reads DEFAULT values to pass to D as field initializers,
    except for NULL. (NOTE(review): confirm that getCreateTable really
    skips DEFAULT NULL; an emitted `= NULL` is not valid D.)
    It ignores any length restrictions.

    Bugs:
        Skips all constraints
        Doesn't handle nullable fields, except with strings
        It only handles SQL keywords if they are all caps
        DEFAULT values are read as a single token, so quoted strings
        and negative numbers do not come through correctly

    This, when combined with SimpleDataObject!(),
    can automatically create usable D classes from
    SQL input.
*/
struct StructFromCreateTable(string sql, string tableName) {
    mixin(getCreateTable(sql, tableName));
}
619 
/**
    Finds `CREATE TABLE tableName (...)` in `sql` and returns D source with
    one field declaration per column, each formatted as
    "\t<type> <name>[ = <default>];\n". The result is meant to be mixed into
    a struct (see StructFromCreateTable). Usable in CTFE.

    Column types map as:
        INTEGER, INT, int, SMALLINT, MEDIUMINT -> int
        BOOLEAN, TINYINT -> bool
        BIGINT -> long
        CHAR, char, VARCHAR, varchar, TEXT, text -> string
        FLOAT, DOUBLE -> double
    Any other type fails an assert.

    Column-list entries starting with an uppercase letter (PRIMARY KEY,
    UNIQUE, ...) are skipped. `DEFAULT x` becomes the field initializer
    unless x is NULL. x is a single token, so quoted or negative defaults
    are not handled.

    Other CREATE statements (CREATE INDEX, CREATE VIEW, ...) and tables with
    other names are skipped. If the table is never found, readWord throws
    at the end of the input.
*/
string getCreateTable(string sql, string tableName) {
    // advance past "CREATE TABLE tableName"
    for(;;) {
        while(readWord(sql) != "CREATE") {}
        // not every CREATE makes a table; skip those instead of failing
        if(readWord(sql) != "TABLE")
            continue;
        if(readWord(sql) == tableName)
            break;
    }

    // read outside the assert so the token is still consumed when asserts are compiled out (-release)
    immutable openParen = readWord(sql);
    assert(openParen == "(", "expected ( after CREATE TABLE " ~ tableName);

    struct Field {
        string name;
        string type;
        string defaultValue;
    }
    Field*[] fields;

    // parser states:
    //   0 - expecting a column name (or a constraint keyword)
    //   1 - expecting the column type
    //   2 - skipping column attributes until DEFAULT or ','
    //   3 - expecting the DEFAULT value
    //   4 - skipping a constraint entry until ','
    int state;
    int parens; // nesting depth inside the column list, e.g. VARCHAR(20)

    string word = readWord(sql);
    Field* current = new Field(); // well, this is interesting... under new DMD, not using new breaks it in CTFE because it overwrites the one entry!
    while(word != ")" || parens) {
        if(word == ")") {
            parens--;
            word = readWord(sql);
            continue;
        }
        if(word == "(") {
            parens++;
            word = readWord(sql);
            continue;
        }

        switch(state) {
            default: assert(0);
            case 0:
                if(word[0] >= 'A' && word[0] <= 'Z') {
                    state = 4;
                    break; // we want to skip this since it starts with a keyword (we hope)
                }
                current.name = word;
                state = 1;
                break;
            case 1:
                current.type ~= word;
                state = 2;
                break;
            case 2:
                if(word == "DEFAULT")
                    state = 3;
                else if(word == ",") {
                    fields ~= current;
                    current = new Field();
                    state = 0; // next
                }
                break;
            case 3:
                // NULL is not valid D; keep the type's own default initializer instead
                if(word != "NULL")
                    current.defaultValue = word;
                state = 2; // back to skipping
                break;
            case 4:
                if(word == ",")
                    state = 0;
                break;
        }

        word = readWord(sql);
    }

    if(current.name !is null)
        fields ~= current;

    string structCode;
    foreach(field; fields) {
        string dType;
        switch(field.type) {
            case "INTEGER":
            case "INT":
            case "int":
            case "SMALLINT":
            case "MEDIUMINT":
                dType = "int";
                break;
            case "BOOLEAN":
            case "TINYINT":
                dType = "bool";
                break;
            case "BIGINT":
                dType = "long";
                break;
            case "CHAR":
            case "char":
            case "VARCHAR":
            case "varchar":
            case "TEXT":
            case "text":
                dType = "string";
                break;
            case "FLOAT":
            case "DOUBLE":
                dType = "double";
                break;
            default:
                assert(0, "unknown type " ~ field.type ~ " for " ~ field.name);
        }

        structCode ~= "\t" ~ dType ~ " " ~ field.name;

        if(field.defaultValue !is null)
            structCode ~= " = " ~ field.defaultValue;

        structCode ~= ";\n";
    }

    return structCode;
}

/**
    Reads one token from the front of `src` and advances `src` past it.

    Leading spaces, tabs, carriage returns and newlines are skipped, along
    with `--` line comments. A token is then one of:
        a backtick-quoted identifier (returned without the backticks),
        a run of ASCII letters, digits and underscores, or
        any other single character, such as `(`, `)` or `,`.

    Every index is bounds-checked, so the last token of the input can be read.

    Throws: Exception if only whitespace and comments remain in `src`.
*/
string readWord(ref string src) {
    static bool isSpace(char c) {
        return c == ' ' || c == '\t' || c == '\n' || c == '\r';
    }
    static bool isWordChar(char c) {
        return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '_';
    }

    // skip whitespace and -- comments
    for(;;) {
        while(src.length && isSpace(src[0]))
            src = src[1 .. $];
        if(src.length >= 2 && src[0] == '-' && src[1] == '-') { // a comment, skip it
            while(src.length && src[0] != '\n')
                src = src[1 .. $];
            continue;
        }
        break;
    }

    if(src.length == 0)
        throw new Exception("readWord: unexpected end of SQL input");

    size_t pos;
    if(src[0] == '`') {
        src = src[1 .. $];
        while(pos < src.length && src[pos] != '`')
            pos++;
    } else {
        while(pos < src.length && isWordChar(src[pos]))
            pos++;
    }

    // not a word: take the single punctuation character instead
    if(pos == 0 && src.length)
        pos = 1;

    string tmp = src[0 .. pos];

    if(pos < src.length && src[pos] == '`')
        pos++; // skip the ending quote

    src = src[pos .. $];

    return tmp;
}
781 
/// Combines StructFromCreateTable and SimpleDataObject into a one-stop template:
/// the SimpleDataObject for `tableName` whose fields come from that table's
/// CREATE TABLE statement in `sql`.
/// Example: alias DataObjectFromSqlCreateTable!(import("file.sql"), "my_table") MyTable;
template DataObjectFromSqlCreateTable(string sql, string tableName) {
    alias DataObjectFromSqlCreateTable = SimpleDataObject!(tableName, StructFromCreateTable!(sql, tableName));
}
787 
788 /+
789 class MyDataObject : DataObject {
790     this() {
791         super(new Database("localhost", "root", "pass", "social"), null);
792     }
793 
794     mixin StrictDataObject!();
795 
796     mixin(DataObjectField!(int, "users", "id"));
797 }
798 
799 void main() {
800     auto a = new MyDataObject;
801 
802     a.fields["id"] = "10";
803 
804     a.id = 34;
805 
806     a.commitChanges;
807 }
808 +/
809 
810 /*
811 alias DataObjectFromSqlCreateTable!(import("db.sql"), "users") Test;
812 
813 void main() {
814     auto a = new Test(null);
815 
816     a.cool = "way";
817     a.value = 100;
818 }
819 */
820 
/// Not meant to be called: it always halts via assert(0). Judging by its
/// name, it exists only to make the program reference the TypeInfo for
/// immutable(char[])[immutable(char)[]] (NOTE(review): confirm which
/// compiler bug this works around before removing it).
void typeinfoBugWorkaround() {
    auto info = typeid(immutable(char[])[immutable(char)[]]);
    assert(0, to!string(info));
}