|
|
|
@ -2,6 +2,7 @@ package db
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
"database/sql"
|
|
|
|
|
"errors"
|
|
|
|
|
"fmt"
|
|
|
|
|
"io/ioutil"
|
|
|
|
|
"os"
|
|
|
|
@ -747,7 +748,7 @@ func Test_SimplePragmaTableInfo(t *testing.T) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func Test_WriteOnQueryOnDiskDatabase(t *testing.T) {
|
|
|
|
|
func Test_WriteOnQueryOnDiskDatabaseShouldFail(t *testing.T) {
|
|
|
|
|
db, path := mustCreateDatabase()
|
|
|
|
|
defer db.Close()
|
|
|
|
|
defer os.Remove(path)
|
|
|
|
@ -785,7 +786,7 @@ func Test_WriteOnQueryOnDiskDatabase(t *testing.T) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func Test_WriteOnQueryInMemDatabase(t *testing.T) {
|
|
|
|
|
func Test_WriteOnQueryInMemDatabaseShouldFail(t *testing.T) {
|
|
|
|
|
db := mustCreateInMemoryDatabase()
|
|
|
|
|
defer db.Close()
|
|
|
|
|
|
|
|
|
@ -1983,6 +1984,107 @@ func Test_TableCreationInMemoryLoadRaw(t *testing.T) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func Test_StmtReadOnly(t *testing.T) {
|
|
|
|
|
db := mustCreateInMemoryDatabase()
|
|
|
|
|
defer db.Close()
|
|
|
|
|
|
|
|
|
|
r, err := db.ExecuteStringStmt(`CREATE TABLE foo (id INTEGER NOT NULL PRIMARY KEY, name TEXT)`)
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatalf("failed to create table: %s", err.Error())
|
|
|
|
|
}
|
|
|
|
|
if exp, got := `[{}]`, asJSON(r); exp != got {
|
|
|
|
|
t.Fatalf("unexpected results for query\nexp: %s\ngot: %s", exp, got)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
tests := []struct {
|
|
|
|
|
name string
|
|
|
|
|
sql string
|
|
|
|
|
readOnly bool
|
|
|
|
|
err error
|
|
|
|
|
}{
|
|
|
|
|
{
|
|
|
|
|
name: "CREATE TABLE statement",
|
|
|
|
|
sql: "CREATE TABLE bar (id INTEGER NOT NULL PRIMARY KEY, name TEXT)",
|
|
|
|
|
readOnly: false,
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
name: "SELECT statement",
|
|
|
|
|
sql: "SELECT * FROM foo",
|
|
|
|
|
readOnly: true,
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
name: "INSERT statement",
|
|
|
|
|
sql: "INSERT INTO foo VALUES (1, 'test')",
|
|
|
|
|
readOnly: false,
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
name: "UPDATE statement",
|
|
|
|
|
sql: "UPDATE foo SET name='test' WHERE id=1",
|
|
|
|
|
readOnly: false,
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
name: "DELETE statement",
|
|
|
|
|
sql: "DELETE FROM foo WHERE id=1",
|
|
|
|
|
readOnly: false,
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
name: "SELECT statement with positional parameters",
|
|
|
|
|
sql: "SELECT * FROM foo WHERE id = ?",
|
|
|
|
|
readOnly: true,
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
name: "SELECT statement with named parameters",
|
|
|
|
|
sql: "SELECT * FROM foo WHERE id = @id AND name = @name",
|
|
|
|
|
readOnly: true,
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
name: "INSERT statement with positional parameters",
|
|
|
|
|
sql: "INSERT INTO foo VALUES (?, ?)",
|
|
|
|
|
readOnly: false,
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
name: "INSERT statement with named parameters",
|
|
|
|
|
sql: "INSERT INTO foo VALUES (@id, @name)",
|
|
|
|
|
readOnly: false,
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
name: "WITH clause, read-only",
|
|
|
|
|
sql: "WITH bar AS (SELECT * FROM foo WHERE id = ?) SELECT * FROM bar",
|
|
|
|
|
readOnly: true,
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
name: "WITH clause, not read-only",
|
|
|
|
|
sql: "WITH bar AS (SELECT * FROM foo WHERE id = ?) DELETE FROM foo WHERE id IN (SELECT id FROM bar)",
|
|
|
|
|
readOnly: false,
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
name: "Invalid statement",
|
|
|
|
|
sql: "INVALID SQL STATEMENT",
|
|
|
|
|
err: errors.New(`near "INVALID": syntax error`),
|
|
|
|
|
},
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for _, tt := range tests {
|
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
|
readOnly, err := db.StmtReadOnly(tt.sql)
|
|
|
|
|
|
|
|
|
|
// Check if error is as expected
|
|
|
|
|
if err != nil && tt.err == nil {
|
|
|
|
|
t.Fatalf("unexpected error: got %v", err)
|
|
|
|
|
} else if err == nil && tt.err != nil {
|
|
|
|
|
t.Fatalf("expected error: got nil")
|
|
|
|
|
} else if err != nil && tt.err != nil && err.Error() != tt.err.Error() {
|
|
|
|
|
t.Fatalf("unexpected error: expected %v, got %v", tt.err, err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Check if result is as expected
|
|
|
|
|
if readOnly != tt.readOnly {
|
|
|
|
|
t.Fatalf("unexpected readOnly: expected %v, got %v", tt.readOnly, readOnly)
|
|
|
|
|
}
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func mustCreateDatabase() (*DB, string) {
|
|
|
|
|
var err error
|
|
|
|
|
f := mustTempFile()
|
|
|
|
|