module mysql.database;

public import std.variant;
import std.string;
import std.stdio;

import core.vararg;

interface Database {
	/// Actually implements the query for the database. The query() method
	/// below might be easier to use.
	ResultSet queryImpl(string sql, Variant[] args...);

	/// Escapes data for inclusion into an sql string literal
	string escape(string sqlData);

	/// query to start a transaction, only here because sqlite is apparently different in syntax...
	void startTransaction();

	// FIXME: this would be better as a template, but can't because it is an interface

	/// Just executes a query. It supports placeholders for parameters
	/// by using ? in the sql string.
	///
	/// Every variadic argument is converted to a string, wrapped in a Variant
	/// and forwarded to queryImpl. Accepted argument types, each in its
	/// mutable, const and immutable form: string, char, int, uint, long and
	/// ulong; a `null` argument is also accepted and forwarded as a null
	/// string. Any other type fails a runtime assert.
	final ResultSet query(string sql, ...) {
		// true when `arg` is exactly one of `candidates`
		static bool isAnyOf(const(TypeInfo) arg, const(TypeInfo)[] candidates...) {
			foreach(candidate; candidates)
				if(arg == candidate)
					return true;
			return false;
		}

		Variant[] args;
		foreach(arg; _arguments) {
			string a;
			if(isAnyOf(arg, typeid(string), typeid(immutable(string)), typeid(const(string))))
				a = va_arg!string(_argptr);
			else if(isAnyOf(arg, typeid(int), typeid(immutable(int)), typeid(const(int))))
				a = to!string(va_arg!int(_argptr));
			else if(isAnyOf(arg, typeid(uint), typeid(immutable(uint)), typeid(const(uint))))
				a = to!string(va_arg!uint(_argptr));
			// all three char qualifiers share the same representation, so they read the same way
			else if(isAnyOf(arg, typeid(char), typeid(immutable(char)), typeid(const(char))))
				a = to!string(va_arg!char(_argptr));
			else if(isAnyOf(arg, typeid(long), typeid(const(long)), typeid(immutable(long))))
				a = to!string(va_arg!long(_argptr));
			else if(isAnyOf(arg, typeid(ulong), typeid(const(ulong)), typeid(immutable(ulong))))
				a = to!string(va_arg!ulong(_argptr));
			else if(arg == typeid(null))
				a = null; // no value to read from _argptr; forwarded as a null string
			else
				assert(0, "invalid type " ~ arg.toString() );

			args ~= Variant(a);
		}

		return queryImpl(sql, args);
	}
}

/*
Ret queryOneColumn(Ret, string file = __FILE__, size_t line = __LINE__, T...)(Database db, string sql, T t) {
	auto res = db.query(sql, t);
	if(res.empty)
		throw new Exception("no row in result", file, line);
	auto row = res.front;
	return to!Ret(row[0]);
}
*/

/// Runs a query when constructed and lets foreach unpack every result row
/// into typed loop variables, converting column i with std.conv.to.
struct Query {
	ResultSet result;
	this(T...)(Database db, string sql, T t) if(T.length!=1 || !is(T[0]==Variant[])) {
		result = db.query(sql, t);
	}
	// Version for dynamic generation of args: (Needs to be a template for coexistence with other constructor.
	this(T...)(Database db, string sql, T args) if (T.length==1 && is(T[0] == Variant[])) {
		result = db.queryImpl(sql, args);
	}

	int opApply(T)(T dg) if(is(T == delegate)) {
		import std.traits;
		foreach(row; result) {
			ParameterTypeTuple!dg tuple;

			// loop variable i receives column i, converted to its declared type
			foreach(i, item; tuple) {
				tuple[i] = to!(typeof(item))(row[i]);
			}

			// a non-zero result from the loop body (break/return) must be propagated
			if(auto ret = dg(tuple))
				return ret;
		}

		return 0;
	}
}

/// One result row: the column values as strings, plus the ResultSet they came
/// from, which is used to map column names to indexes.
struct Row {
	package string[] row;
	package ResultSet resultSet;

	/// Returns column `idx`; throws an Exception if the row has no such column.
	string opIndex(size_t idx, string file = __FILE__, int line = __LINE__) {
		if(idx >= row.length)
			throw new Exception(text("index ", idx, " is out of bounds on result"), file, line);
		return row[idx];
	}

	/// Returns the column called `name`; throws an Exception if it is not in this row.
	string opIndex(string name, string file = __FILE__, int line = __LINE__) {
		auto idx = resultSet.getFieldIndex(name);
		// NOTE(review): assumes getFieldIndex reports an unknown name with a negative index — confirm
		if(idx < 0 || idx >= row.length)
			throw new Exception(text("no field ", name, " in result"), file, line);
		return row[idx];
	}

	string toString() {
		return to!string(row);
	}

	/// Maps each field name (from resultSet.fieldNames) to this row's value.
	string[string] toAA() {
		string[string] a;

		string[] fn = resultSet.fieldNames();

		foreach(i, r; row)
			a[fn[i]] = r;

		return a;
	}

	/// Iterates (field name, value) pairs as produced by toAA().
	int opApply(int delegate(ref string, ref string) dg) {
		// `yield` is not defined in this module; presumably it comes from an
		// imported module — confirm
		foreach(a, b; toAA())
			mixin(yield("a, b"));

		return 0;
	}



	string[] toStringArray() {
		return row;
	}
}
import std.conv;
interface ResultSet {
	// name for associative array to result index
	int getFieldIndex(string field);
	string[] fieldNames();

	// this is a range that can offer other ranges to access it
	bool empty() @property;
	Row front() @property;
	void popFront() ;
	int length() @property;

	/* deprecated */ final ResultSet byAssoc() { return this; }
}

class DatabaseException : Exception {
	this(string msg, string file = __FILE__, size_t line = __LINE__) {
		super(msg, file, line);
	}
}



abstract class SqlBuilder { }

/// WARNING: this is as susceptible to SQL injections as you would be writing it out by hand
class SelectBuilder : SqlBuilder {
	string[] fields;
	string table;
	string[] joins;
	string[] wheres;
	string[] orderBys;
	string[] groupBys;

	int limit;
	int limitStart;

	/// Values for `?name` placeholders; substituted by toString() when db is set.
	Variant[string] vars;
	void setVariable(T)(string name, T value) {
		vars[name] = Variant(value);
	}

	Database db;
	this(Database db = null) {
		this.db = db;
	}

	/// Returns a copy whose arrays and variables can be modified without
	/// affecting this builder (the Database reference is shared).
	SelectBuilder cloned() {
		auto s = new SelectBuilder(this.db);
		s.fields = this.fields.dup;
		s.table = this.table;
		s.joins = this.joins.dup;
		s.wheres = this.wheres.dup;
		s.orderBys = this.orderBys.dup;
		s.groupBys = this.groupBys.dup;
		s.limit = this.limit;
		s.limitStart = this.limitStart;
		s.vars = this.vars.dup;

		return s;
	}

	/// Builds the SELECT statement. Field, table, join, where, group-by and
	/// order-by strings are inserted verbatim. When db is set, `?name`
	/// placeholders are replaced with escaped values from vars; placeholders
	/// without a matching variable are left in the output unchanged.
	override string toString() {
		string sql = "SELECT " ~ joinWith(fields, ", ");

		sql ~= " FROM " ~ table;

		foreach(j; joins)
			sql ~= " " ~ j;

		if(wheres.length) {
			string[] parenthesized;
			foreach(w; wheres)
				parenthesized ~= "(" ~ w ~ ")";
			sql ~= " WHERE " ~ joinWith(parenthesized, " AND ");
		}

		if(groupBys.length)
			sql ~= " GROUP BY " ~ joinWith(groupBys, ", ");

		if(orderBys.length)
			sql ~= " ORDER BY " ~ joinWith(orderBys, ", ");

		if(limit) {
			sql ~= " LIMIT ";
			if(limitStart)
				sql ~= to!string(limitStart) ~ ", ";
			sql ~= to!string(limit);
		}

		if(db is null)
			return sql;

		return escapedVariants(db, sql, vars);
	}

	// concatenates `parts` with `separator` between consecutive elements
	private static string joinWith(string[] parts, string separator) {
		string joined;
		foreach(i, part; parts) {
			if(i)
				joined ~= separator;
			joined ~= part;
		}
		return joined;
	}
}


// /////////////////////sql//////////////////////////////////


/// Renders a Variant as an SQL literal for the internal placeholder code.
/// A null pointer or a null string becomes NULL (matching toSql(string, Database)
/// and the null-string convention Database.query uses for null arguments);
/// anything else is converted with to!string, escaped with db.escape and
/// single-quoted.
string toSql(Database db, Variant a) {
	auto ptr = a.peek!(void*);
	if(ptr !is null && *ptr is null)
		return "NULL";

	auto str = a.peek!string;
	if(str !is null && *str is null)
		return "NULL";

	return '\'' ~ db.escape(to!string(a)) ~ '\'';
}

// just for convenience; "str".toSql(db);
string toSql(string s, Database db) {
	if(s is null)
		return "NULL";
	return '\'' ~ db.escape(s) ~ '\'';
}

string toSql(long s, Database db) {
	return to!string(s);
}

/// Replaces each `?name` placeholder in sql (name: a letter or underscore
/// followed by letters, digits or underscores) with toSql(db, t[name]).
/// Placeholders whose name is not a key of t are left in place, since
/// another layer may fill them.
string escapedVariants(Database db, in string sql, Variant[string] t) {
	// t.length, unlike t.keys.length, does not allocate a keys array
	if(t.length == 0 || sql.indexOf("?") == -1) {
		return sql;
	}

	string fixedup;
	int currentStart = 0;
	// FIXME: let's make ?? render as ? so we have some escaping capability
	foreach(int i, dchar c; sql) {
		if(c != '?')
			continue;

		fixedup ~= sql[currentStart .. i];

		int nameStart = i + 1;
		int nameLength;

		bool isFirst = true;

		while(nameStart + nameLength < sql.length) {
			char ch = sql[nameStart + nameLength];

			if((ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || ch == '_' || (!isFirst && ch >= '0' && ch <= '9'))
				nameLength++;
			else
				break;

			isFirst = false;
		}

		auto name = sql[nameStart .. nameStart + nameLength];

		if(auto value = name in t) {
			fixedup ~= toSql(db, *value);
			currentStart = nameStart + nameLength;
		} else {
			// just leave it there, it might be done on another layer
			currentStart = i;
		}
	}

	fixedup ~= sql[currentStart .. $];

	return fixedup;
}

// TODO: cut this out
/// Replaces positional placeholders in sql with toSql(db, ...) of the values in t.
/// A bare `?` takes the next value in order; `?N` (a single digit N) selects
/// t[N] explicitly without advancing that order.
/// Note: ?n params are zero based!
/// Throws: Exception when a placeholder refers past the end of t.
string escapedVariants(Database db, in string sql, Variant[] t) {
	// FIXME: let's make ?? render as ? so we have some escaping capability
	// if nothing to escape or nothing to escape with, don't bother
	if(t.length == 0 || sql.indexOf("?") == -1)
		return sql;

	string fixedup;
	int currentIndex;
	int currentStart = 0;
	foreach(int i, dchar c; sql) {
		if(c != '?')
			continue;

		fixedup ~= sql[currentStart .. i];

		int idx = -1;
		currentStart = i + 1;
		if((i + 1) < sql.length) {
			auto n = sql[i + 1];
			if(n >= '0' && n <= '9') {
				currentStart = i + 2; // also consume the digit
				idx = n - '0';
			}
		}
		if(idx == -1) {
			idx = currentIndex;
			currentIndex++;
		}

		if(idx < 0 || idx >= t.length)
			throw new Exception("SQL Parameter index is out of bounds: " ~ to!string(idx) ~ " at `"~sql[0 .. i]~"`");

		fixedup ~= toSql(db, t[idx]);
	}

	fixedup ~= sql[currentStart .. $];

	return fixedup;
}


enum UpdateOrInsertMode {
	CheckForMe,
	AlwaysUpdate,
	AlwaysInsert
}

// Wraps an identifier in backticks, doubling any backtick inside it so the
// name cannot terminate the quoting early.
private string backtickQuote(string name) {
	string quoted = "`";
	foreach(char ch; name) {
		if(ch == '`')
			quoted ~= "``";
		else
			quoted ~= ch;
	}
	return quoted ~ "`";
}

// BIG FIXME: this should really use prepared statements
/// Writes the column => value pairs in `values` to `table` with an INSERT or
/// an UPDATE. In CheckForMe mode it first runs `SELECT key FROM table WHERE
/// where` and inserts when that returns no rows. `where` and `key` are pasted
/// into the SQL verbatim. Null values become SQL NULL; entries with a null
/// column name are skipped, and nothing is sent if none remain.
/// Returns: always 0 (the insert id is not reported).
int updateOrInsert(Database db, string table, string[string] values, string where, UpdateOrInsertMode mode = UpdateOrInsertMode.CheckForMe, string key = "id") {
	bool insert = false;

	final switch(mode) {
		case UpdateOrInsertMode.CheckForMe: {
			// braces keep `res` scoped to this case
			auto res = db.query("SELECT "~key~" FROM "~backtickQuote(db.escape(table))~" WHERE " ~ where);
			insert = res.empty;
			break;
		}
		case UpdateOrInsertMode.AlwaysInsert:
			insert = true;
			break;
		case UpdateOrInsertMode.AlwaysUpdate:
			insert = false;
			break;
	}


	if(insert) {
		string insertSql = "INSERT INTO " ~ backtickQuote(db.escape(table)) ~ " ";

		bool outputted = false;
		string vs, cs;
		foreach(column, value; values) {
			if(column is null)
				continue;
			if(outputted) {
				vs ~= ", ";
				cs ~= ", ";
			} else
				outputted = true;

			// insert column names are not passed through db.escape; doubling
			// backticks still keeps a name from breaking out of its quoting
			cs ~= backtickQuote(column);
			vs ~= value is null ? "NULL" : "'" ~ db.escape(value) ~ "'";
		}

		if(!outputted)
			return 0;


		insertSql ~= "(" ~ cs ~ ")";
		insertSql ~= " VALUES ";
		insertSql ~= "(" ~ vs ~ ")";

		db.query(insertSql);

		return 0; // db.lastInsertId;
	} else {
		string updateSql = "UPDATE "~backtickQuote(db.escape(table))~" SET ";

		bool outputted = false;
		foreach(column, value; values) {
			if(column is null)
				continue;
			if(outputted)
				updateSql ~= ", ";
			else
				outputted = true;

			updateSql ~= backtickQuote(db.escape(column)) ~ " = ";
			updateSql ~= value is null ? "NULL" : "'" ~ db.escape(value) ~ "'";
		}

		if(!outputted)
			return 0;

		updateSql ~= " WHERE " ~ where;

		db.query(updateSql);
		return 0;
	}
}




/// Adds `, tbl.key AS id_from_tbl` columns just before the first "FROM" in
/// sql for every table named after a "JOIN" and in the FROM list. The key
/// column is "id" unless keyMapping has an entry for that table. Returns sql
/// unchanged when it contains no "FROM".
string fixupSqlForDataObjectUse(string sql, string[string] keyMapping = null) {

	string[] tableNames;

	// the table name follows "JOIN" plus one separator character
	string piece = sql;
	sizediff_t idx;
	while((idx = piece.indexOf("JOIN")) != -1) {
		size_t start = idx + 5;
		if(start >= piece.length)
			break; // "JOIN" at the very end: no table name follows
		size_t i = start;
		while(i < piece.length && piece[i] != ' ' && piece[i] != '\n' && piece[i] != '\t' && piece[i] != ',')
			i++;

		tableNames ~= strip(piece[start..i]);

		piece = piece[i..$];
	}

	idx = sql.indexOf("FROM");
	if(idx != -1) {
		size_t start = idx + 5;
		if(start > sql.length)
			start = sql.length; // "FROM" at the very end
		size_t i = start;
		while(i < sql.length && !(sql[i] > 'A' && sql[i] <= 'Z')) // if not uppercase, except for A (for AS) to avoid SQL keywords (hack)
			i++;

		auto from = sql[start..i];
		auto pieces = from.split(",");
		foreach(p; pieces) {
			p = p.strip();
			size_t nameEnd = 0;
			while(nameEnd < p.length && p[nameEnd] != ' ' && p[nameEnd] != '\n' && p[nameEnd] != '\t' && p[nameEnd] != ',')
				nameEnd++;

			tableNames ~= strip(p[0..nameEnd]);
		}

		string sqlToAdd;
		foreach(tbl; tableNames) {
			if(tbl.length) {
				string keyName = "id";
				if(auto mapped = tbl in keyMapping)
					keyName = *mapped;
				sqlToAdd ~= ", " ~ tbl ~ "." ~ keyName ~ " AS " ~ "id_from_" ~ tbl;
			}
		}

		sqlToAdd ~= " ";

		sql = sql[0..idx] ~ sqlToAdd ~ sql[idx..$];
	}

	return sql;
}

import mysql.data_object;

/**
	Given some SQL, it finds the CREATE TABLE
	instruction for the given tableName.
	(this is so it can find one entry from
	a file with several SQL commands; other
	CREATE statements are skipped. But it
	may break on a complex file, so try to only
	feed it simple sql files.)

	From that, it pulls out the members to create a
	simple struct based on it.

	It's not terribly smart, so it will probably
	break on complex tables.

	Data types handled:
		INTEGER, INT, SMALLINT, MEDIUMINT -> D's int
		TINYINT, BOOLEAN -> D's bool
		BIGINT -> D's long
		TEXT, VARCHAR, CHAR (also lowercase text, varchar, char) -> D's string
		FLOAT, DOUBLE -> D's double

	It also reads DEFAULT values to pass to D; a DEFAULT NULL
	is skipped, leaving D's default initializer.
	It ignores any length restrictions.

	Bugs:
		Skips all constraints
		Doesn't handle nullable fields, except with strings
		It only handles SQL keywords if they are all caps
		Quoted or negative DEFAULT values are not parsed

	This, when combined with SimpleDataObject!(),
	can automatically create usable D classes from
	SQL input.
*/
struct StructFromCreateTable(string sql, string tableName) {
	mixin(getCreateTable(sql, tableName));
}

/// Returns D field declarations (one per line) for the columns of the
/// CREATE TABLE statement for `tableName` in `sql`. Intended for CTFE.
/// Fails an assert if the table is not found, the input ends early, or a
/// column has an unsupported type.
string getCreateTable(string sql, string tableName) {
	// advance `sql` just past "CREATE TABLE <tableName>"
	for(;;) {
		auto word = readWord(sql);
		if(word.length == 0)
			assert(0, "CREATE TABLE " ~ tableName ~ " not found");
		if(word != "CREATE")
			continue;
		if(readWord(sql) != "TABLE")
			continue; // CREATE INDEX, CREATE VIEW, ...: keep looking
		if(readWord(sql) == tableName)
			break;
	}

	assert(readWord(sql) == "(");

	// states: 0 = expecting a column name, 1 = expecting its type,
	// 2 = skipping to the next ",", 3 = reading a DEFAULT value,
	// 4 = skipping a constraint line (starts with an uppercase keyword)
	int state;
	int parens;

	struct Field {
		string name;
		string type;
		string defaultValue;
	}
	Field*[] fields;

	string word = readWord(sql);
	Field* current = new Field(); // well, this is interesting... under new DMD, not using new breaks it in CTFE because it overwrites the one entry!
	while(word != ")" || parens) {
		if(word.length == 0)
			assert(0, "unexpected end of input inside CREATE TABLE " ~ tableName);
		if(word == ")") {
			parens --;
			word = readWord(sql);
			continue;
		}
		if(word == "(") {
			parens ++;
			word = readWord(sql);
			continue;
		}
		switch(state) {
			default: assert(0);
			case 0:
				if(word[0] >= 'A' && word[0] <= 'Z') {
					state = 4;
					break; // we want to skip this since it starts with a keyword (we hope)
				}
				current.name = word;
				state = 1;
				break;
			case 1:
				current.type ~= word;
				state = 2;
				break;
			case 2:
				if(word == "DEFAULT")
					state = 3;
				else if (word == ",") {
					fields ~= current;
					current = new Field();
					state = 0; // next
				}
				break;
			case 3:
				// D has no NULL literal; leave the field at its default initializer
				if(word != "NULL")
					current.defaultValue = word;
				state = 2; // back to skipping
				break;
			case 4:
				if(word == ",")
					state = 0;
				break;
		}

		word = readWord(sql);
	}

	if(current.name !is null)
		fields ~= current;


	string structCode;
	foreach(field; fields) {
		structCode ~= "\t";

		switch(field.type) {
			case "INTEGER":
			case "INT":
			case "SMALLINT":
			case "MEDIUMINT":
				structCode ~= "int";
			break;
			case "BOOLEAN":
			case "TINYINT":
				structCode ~= "bool";
			break;
			case "BIGINT":
				structCode ~= "long";
			break;
			case "CHAR":
			case "char":
			case "VARCHAR":
			case "varchar":
			case "TEXT":
			case "text":
				structCode ~= "string";
			break;
			case "FLOAT":
			case "DOUBLE":
				structCode ~= "double";
			break;
			default:
				assert(0, "unknown type " ~ field.type ~ " for " ~ field.name);
		}

		structCode ~= " ";
		structCode ~= field.name;

		if(field.defaultValue !is null) {
			structCode ~= " = " ~ field.defaultValue;
		}

		structCode ~= ";\n";
	}

	return structCode;
}

/// Consumes and returns the next token of `src`, advancing `src` past it.
/// Whitespace (including '\r') and `--` line comments are skipped. A token is
/// a run of [A-Za-z0-9_], a backtick-quoted name (returned without its
/// quotes), or else a single character such as "(" or ",". Returns an empty
/// string once `src` is exhausted.
string readWord(ref string src) {
	for(;;) {
		while(src.length && (src[0] == ' ' || src[0] == '\t' || src[0] == '\n' || src[0] == '\r'))
			src = src[1..$];
		if(src.length >= 2 && src[0] == '-' && src[1] == '-') { // a comment, skip it
			while(src.length && src[0] != '\n')
				src = src[1..$];
			continue;
		}
		break;
	}

	if(src.length == 0)
		return null; // end of input

	size_t pos;
	if(src[0] == '`') {
		src = src[1..$];
		while(pos < src.length && src[pos] != '`')
			pos++;
	} else {
		while(pos < src.length && (
			(src[pos] >= 'A' && src[pos] <= 'Z')
			||
			(src[pos] >= 'a' && src[pos] <= 'z')
			||
			(src[pos] >= '0' && src[pos] <= '9')
			||
			src[pos] == '_'
		))
			pos++;
	}

	if(pos == 0) {
		if(src.length == 0)
			return null; // the input ended right after an opening backtick
		pos = 1; // a lone punctuation character is a token of its own
	}

	string tmp = src[0..pos];

	if(pos < src.length && src[pos] == '`')
		pos++; // skip the ending quote;

	src = src[pos..$];

	return tmp;
}

/// Combines StructFromCreateTable and SimpleDataObject into a one-stop template.
/// Example:
/// ---
/// alias DataObjectFromSqlCreateTable!(import("file.sql"), "my_table") MyTable;
/// ---
template DataObjectFromSqlCreateTable(string sql, string tableName) {
	alias DataObjectFromSqlCreateTable = SimpleDataObject!(tableName, StructFromCreateTable!(sql, tableName));
}

/+
class MyDataObject : DataObject {
	this() {
		super(new Database("localhost", "root", "pass", "social"), null);
	}

	mixin StrictDataObject!();

	mixin(DataObjectField!(int, "users", "id"));
}

void main() {
	auto a = new MyDataObject;

	a.fields["id"] = "10";

	a.id = 34;

	a.commitChanges;
}
+/

/*
alias DataObjectFromSqlCreateTable!(import("db.sql"), "users") Test;

void main() {
	auto a = new Test(null);

	a.cool = "way";
	a.value = 100;
}
*/

/// Not meant to be called: it always fails its assert. It references the
/// TypeInfo of `immutable(char[])[immutable(char)[]]`, presumably so the
/// compiler emits that TypeInfo as a bug workaround — confirm before removing.
void typeinfoBugWorkaround() {
	auto aaTypeInfo = typeid(immutable(char[])[immutable(char)[]]);
	assert(0, to!string(aaTypeInfo));
}