package command

import (
	"regexp"
	"testing"

	"github.com/rqlite/rqlite/v8/command/proto"
)

// Test_NoRewrites checks that Rewrite, called with its second argument set to
// false, returns no error and leaves the SQL text of a single statement
// untouched. Each input is submitted on its own.
func Test_NoRewrites(t *testing.T) {
	inputs := []string{
		`INSERT INTO "names" VALUES (1, 'bob', '123-45-678')`,
		`INSERT INTO "names" VALUES (RANDOM(), 'bob', '123-45-678')`,
		`SELECT title FROM albums ORDER BY RANDOM()`,
		`INSERT INTO foo(name, age) VALUES(?, ?)`,
	}

	for _, sql := range inputs {
		stmts := []*proto.Statement{{Sql: sql}}
		if err := Rewrite(stmts, false); err != nil {
			t.Fatalf("failed to not rewrite: %s", err)
		}
		if stmts[0].Sql != sql {
			t.Fatalf("SQL is modified: %s", stmts[0].Sql)
		}
	}
}

// Test_NoRewritesMulti checks that Rewrite, called with its second argument
// set to false, leaves every statement of a multi-statement request untouched
// and does not change the number of statements.
func Test_NoRewritesMulti(t *testing.T) {
	want := []string{
		`INSERT INTO "names" VALUES (1, 'bob', '123-45-678')`,
		`INSERT INTO "names" VALUES (RANDOM(), 'bob', '123-45-678')`,
		`SELECT title FROM albums ORDER BY RANDOM()`,
	}

	stmts := make([]*proto.Statement, 0, len(want))
	for _, sql := range want {
		stmts = append(stmts, &proto.Statement{Sql: sql})
	}

	if err := Rewrite(stmts, false); err != nil {
		t.Fatalf("failed to not rewrite: %s", err)
	}
	if len(stmts) != 3 {
		t.Fatalf("returned stmts is wrong length: %d", len(stmts))
	}

	// Report the statement that actually differs, not always the first one.
	for i, sql := range want {
		if stmts[i].Sql != sql {
			t.Fatalf("SQL is modified: %s", stmts[i].Sql)
		}
	}
}

// Test_Rewrites checks the SQL produced by Rewrite when its second argument is
// true. Each case pairs an input statement with a regular expression that the
// resulting SQL must match; the patterns are unanchored, so a match anywhere in
// the result is accepted.
//
// Per the patterns below, RANDOM() in a VALUES list or as a SELECT result
// column is expected to become an integer literal, while RANDOM() in an
// ORDER BY clause and CURRENT_TIMESTAMP in a column default are expected to be
// left as written.
func Test_Rewrites(t *testing.T) {
	cases := []struct {
		sql     string // statement handed to Rewrite
		pattern string // regular expression the rewritten SQL must match
	}{
		{
			sql:     `INSERT INTO "names" VALUES (1, 'bob', '123-45-678')`,
			pattern: `INSERT INTO "names" VALUES \(1, 'bob', '123-45-678'\)`,
		},
		{
			sql:     `INSERT INTO "names" VALUES (RANDOM(), 'bob', '123-45-678')`,
			pattern: `INSERT INTO "names" VALUES \(-?[0-9]+, 'bob', '123-45-678'\)`,
		},
		{
			sql:     `SELECT title FROM albums ORDER BY RANDOM()`,
			pattern: `SELECT title FROM albums ORDER BY RANDOM\(\)`,
		},
		{
			sql:     `SELECT RANDOM()`,
			pattern: `SELECT -?[0-9]+`,
		},
		{
			sql:     `CREATE TABLE tbl (col1 TEXT, ts DATETIME DEFAULT CURRENT_TIMESTAMP)`,
			pattern: `CREATE TABLE tbl \(col1 TEXT, ts DATETIME DEFAULT CURRENT_TIMESTAMP\)`,
		},
	}

	for i, tc := range cases {
		stmts := []*proto.Statement{{Sql: tc.sql}}
		if err := Rewrite(stmts, true); err != nil {
			t.Fatalf("failed to rewrite: %s", err)
		}

		re := regexp.MustCompile(tc.pattern)
		if !re.MatchString(stmts[0].Sql) {
			t.Fatalf("test %d failed, %s (rewritten as %s) does not regex-match with %s",
				i, tc.sql, stmts[0].Sql, tc.pattern)
		}
	}
}