1
0
Fork 0

Factor out read-only check

master
Philip O'Toole 1 year ago
parent bb868595c1
commit 8f45c4c8c9

@ -579,19 +579,8 @@ func (db *DB) queryWithConn(req *command.Request, xTime bool, conn *sql.Conn) ([
// Do best-effort check that the statement won't try to change
// the database. As per the SQLite documentation, this will not
// cover 100% of possibilities, but should cover most.
var readOnly bool
f := func(driverConn interface{}) error {
c := driverConn.(*sqlite3.SQLiteConn)
drvStmt, err := c.Prepare(sql)
if err != nil {
return err
}
defer drvStmt.Close()
sqliteStmt := drvStmt.(*sqlite3.SQLiteStmt)
readOnly = sqliteStmt.Readonly()
return nil
}
if err := conn.Raw(f); err != nil {
readOnly, err := db.StmtReadOnly(sql)
if err != nil {
stats.Add(numQueryErrors, 1)
rows.Error = err.Error()
allRows = append(allRows, rows)
@ -834,6 +823,32 @@ func (db *DB) Dump(w io.Writer) error {
return nil
}
// StmtReadOnly returns whether the given SQL statement is read-only.
func (db *DB) StmtReadOnly(sql string) (bool, error) {
var readOnly bool
f := func(driverConn interface{}) error {
c := driverConn.(*sqlite3.SQLiteConn)
drvStmt, err := c.Prepare(sql)
if err != nil {
return err
}
defer drvStmt.Close()
sqliteStmt := drvStmt.(*sqlite3.SQLiteStmt)
readOnly = sqliteStmt.Readonly()
return nil
}
conn, err := db.roDB.Conn(context.Background())
if err != nil {
return false, err
}
defer conn.Close()
if err := conn.Raw(f); err != nil {
return false, err
}
return readOnly, nil
}
func (db *DB) memStats() (map[string]int64, error) {
ms := make(map[string]int64)
for _, p := range []string{

@ -2,6 +2,7 @@ package db
import (
"database/sql"
"errors"
"fmt"
"io/ioutil"
"os"
@ -747,7 +748,7 @@ func Test_SimplePragmaTableInfo(t *testing.T) {
}
}
func Test_WriteOnQueryOnDiskDatabase(t *testing.T) {
func Test_WriteOnQueryOnDiskDatabaseShouldFail(t *testing.T) {
db, path := mustCreateDatabase()
defer db.Close()
defer os.Remove(path)
@ -785,7 +786,7 @@ func Test_WriteOnQueryOnDiskDatabase(t *testing.T) {
}
}
func Test_WriteOnQueryInMemDatabase(t *testing.T) {
func Test_WriteOnQueryInMemDatabaseShouldFail(t *testing.T) {
db := mustCreateInMemoryDatabase()
defer db.Close()
@ -1983,6 +1984,107 @@ func Test_TableCreationInMemoryLoadRaw(t *testing.T) {
}
}
func Test_StmtReadOnly(t *testing.T) {
db := mustCreateInMemoryDatabase()
defer db.Close()
r, err := db.ExecuteStringStmt(`CREATE TABLE foo (id INTEGER NOT NULL PRIMARY KEY, name TEXT)`)
if err != nil {
t.Fatalf("failed to create table: %s", err.Error())
}
if exp, got := `[{}]`, asJSON(r); exp != got {
t.Fatalf("unexpected results for query\nexp: %s\ngot: %s", exp, got)
}
tests := []struct {
name string
sql string
readOnly bool
err error
}{
{
name: "CREATE TABLE statement",
sql: "CREATE TABLE bar (id INTEGER NOT NULL PRIMARY KEY, name TEXT)",
readOnly: false,
},
{
name: "SELECT statement",
sql: "SELECT * FROM foo",
readOnly: true,
},
{
name: "INSERT statement",
sql: "INSERT INTO foo VALUES (1, 'test')",
readOnly: false,
},
{
name: "UPDATE statement",
sql: "UPDATE foo SET name='test' WHERE id=1",
readOnly: false,
},
{
name: "DELETE statement",
sql: "DELETE FROM foo WHERE id=1",
readOnly: false,
},
{
name: "SELECT statement with positional parameters",
sql: "SELECT * FROM foo WHERE id = ?",
readOnly: true,
},
{
name: "SELECT statement with named parameters",
sql: "SELECT * FROM foo WHERE id = @id AND name = @name",
readOnly: true,
},
{
name: "INSERT statement with positional parameters",
sql: "INSERT INTO foo VALUES (?, ?)",
readOnly: false,
},
{
name: "INSERT statement with named parameters",
sql: "INSERT INTO foo VALUES (@id, @name)",
readOnly: false,
},
{
name: "WITH clause, read-only",
sql: "WITH bar AS (SELECT * FROM foo WHERE id = ?) SELECT * FROM bar",
readOnly: true,
},
{
name: "WITH clause, not read-only",
sql: "WITH bar AS (SELECT * FROM foo WHERE id = ?) DELETE FROM foo WHERE id IN (SELECT id FROM bar)",
readOnly: false,
},
{
name: "Invalid statement",
sql: "INVALID SQL STATEMENT",
err: errors.New(`near "INVALID": syntax error`),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
readOnly, err := db.StmtReadOnly(tt.sql)
// Check if error is as expected
if err != nil && tt.err == nil {
t.Fatalf("unexpected error: got %v", err)
} else if err == nil && tt.err != nil {
t.Fatalf("expected error: got nil")
} else if err != nil && tt.err != nil && err.Error() != tt.err.Error() {
t.Fatalf("unexpected error: expected %v, got %v", tt.err, err)
}
// Check if result is as expected
if readOnly != tt.readOnly {
t.Fatalf("unexpected readOnly: expected %v, got %v", tt.readOnly, readOnly)
}
})
}
}
func mustCreateDatabase() (*DB, string) {
var err error
f := mustTempFile()

Loading…
Cancel
Save