1
0
Fork 0

DB-level unit tests pass

master
Philip O'Toole 3 years ago
parent 0cbf96399a
commit aa8df67592

@ -75,23 +75,6 @@ type PoolStats struct {
MaxLifetimeClosed int64 `json:"max_lifetime_closed"` MaxLifetimeClosed int64 `json:"max_lifetime_closed"`
} }
// Result represents the outcome of an operation that changes rows.
type Result struct {
LastInsertID int64 `json:"last_insert_id,omitempty"`
RowsAffected int64 `json:"rows_affected,omitempty"`
Error string `json:"error,omitempty"`
Time float64 `json:"time,omitempty"`
}
// Rows represents the outcome of an operation that returns query data.
type Rows struct {
Columns []string `json:"columns,omitempty"`
Types []string `json:"types,omitempty"`
Values [][]interface{} `json:"values,omitempty"`
Error string `json:"error,omitempty"`
Time float64 `json:"time,omitempty"`
}
// Open opens a file-based database, creating it if it does not exist. // Open opens a file-based database, creating it if it does not exist.
func Open(dbPath string, fkEnabled bool) (*DB, error) { func Open(dbPath string, fkEnabled bool) (*DB, error) {
rwDSN := fmt.Sprintf("file:%s?_fk=%s", dbPath, strconv.FormatBool(fkEnabled)) rwDSN := fmt.Sprintf("file:%s?_fk=%s", dbPath, strconv.FormatBool(fkEnabled))
@ -313,7 +296,7 @@ func (db *DB) Size() (int64, error) {
return 0, err return 0, err
} }
return rows[0].Values[0][0].(int64), nil return rows[0].Values[0].Parameters[0].GetI(), nil
} }
// FileSize returns the size of the SQLite file on disk. If running in // FileSize returns the size of the SQLite file on disk. If running in
@ -341,14 +324,10 @@ func (db *DB) CompileOptions() ([]string, error) {
copts := make([]string, len(res[0].Values)) copts := make([]string, len(res[0].Values))
for i := range copts { for i := range copts {
if len(res[0].Values[i]) != 1 { if len(res[0].Values[i].Parameters) != 1 {
return nil, fmt.Errorf("compile options values wrong size (%d)", len(res)) return nil, fmt.Errorf("compile options values wrong size (%d)", len(res))
} }
if co, ok := res[0].Values[i][0].(string); !ok { copts[i] = res[0].Values[i].Parameters[0].GetS()
copts[i] = "not string!"
} else {
copts[i] = co
}
} }
return copts, nil return copts, nil
} }
@ -372,7 +351,7 @@ func (db *DB) ConnectionPoolStats(sqlDB *sql.DB) *PoolStats {
// ExecuteStringStmt executes a single query that modifies the database. This is // ExecuteStringStmt executes a single query that modifies the database. This is
// primarily a convenience function. // primarily a convenience function.
func (db *DB) ExecuteStringStmt(query string) ([]*Result, error) { func (db *DB) ExecuteStringStmt(query string) ([]*command.ExecuteResult, error) {
r := &command.Request{ r := &command.Request{
Statements: []*command.Statement{ Statements: []*command.Statement{
{ {
@ -384,7 +363,7 @@ func (db *DB) ExecuteStringStmt(query string) ([]*Result, error) {
} }
// Execute executes queries that modify the database. // Execute executes queries that modify the database.
func (db *DB) Execute(req *command.Request, xTime bool) ([]*Result, error) { func (db *DB) Execute(req *command.Request, xTime bool) ([]*command.ExecuteResult, error) {
stats.Add(numExecutions, int64(len(req.Statements))) stats.Add(numExecutions, int64(len(req.Statements)))
conn, err := db.rwDB.Conn(context.Background()) conn, err := db.rwDB.Conn(context.Background())
@ -415,11 +394,11 @@ func (db *DB) Execute(req *command.Request, xTime bool) ([]*Result, error) {
execer = conn execer = conn
} }
var allResults []*Result var allResults []*command.ExecuteResult
// handleError sets the error field on the given result. It returns // handleError sets the error field on the given result. It returns
// whether the caller should continue processing or break. // whether the caller should continue processing or break.
handleError := func(result *Result, err error) bool { handleError := func(result *command.ExecuteResult, err error) bool {
stats.Add(numExecutionErrors, 1) stats.Add(numExecutionErrors, 1)
result.Error = err.Error() result.Error = err.Error()
allResults = append(allResults, result) allResults = append(allResults, result)
@ -438,7 +417,7 @@ func (db *DB) Execute(req *command.Request, xTime bool) ([]*Result, error) {
continue continue
} }
result := &Result{} result := &command.ExecuteResult{}
start := time.Now() start := time.Now()
parameters, err := parametersToValues(stmt.Parameters) parameters, err := parametersToValues(stmt.Parameters)
@ -468,7 +447,7 @@ func (db *DB) Execute(req *command.Request, xTime bool) ([]*Result, error) {
} }
break break
} }
result.LastInsertID = lid result.LastInsertId = lid
ra, err := r.RowsAffected() ra, err := r.RowsAffected()
if err != nil { if err != nil {
@ -491,7 +470,7 @@ func (db *DB) Execute(req *command.Request, xTime bool) ([]*Result, error) {
} }
// QueryStringStmt executes a single query that return rows, but don't modify database. // QueryStringStmt executes a single query that return rows, but don't modify database.
func (db *DB) QueryStringStmt(query string) ([]*Rows, error) { func (db *DB) QueryStringStmt(query string) ([]*command.QueryRows, error) {
r := &command.Request{ r := &command.Request{
Statements: []*command.Statement{ Statements: []*command.Statement{
{ {
@ -503,7 +482,7 @@ func (db *DB) QueryStringStmt(query string) ([]*Rows, error) {
} }
// Query executes queries that return rows, but don't modify the database. // Query executes queries that return rows, but don't modify the database.
func (db *DB) Query(req *command.Request, xTime bool) ([]*Rows, error) { func (db *DB) Query(req *command.Request, xTime bool) ([]*command.QueryRows, error) {
stats.Add(numQueries, int64(len(req.Statements))) stats.Add(numQueries, int64(len(req.Statements)))
conn, err := db.roDB.Conn(context.Background()) conn, err := db.roDB.Conn(context.Background())
if err != nil { if err != nil {
@ -513,7 +492,7 @@ func (db *DB) Query(req *command.Request, xTime bool) ([]*Rows, error) {
return db.queryWithConn(req, xTime, conn) return db.queryWithConn(req, xTime, conn)
} }
func (db *DB) queryWithConn(req *command.Request, xTime bool, conn *sql.Conn) ([]*Rows, error) { func (db *DB) queryWithConn(req *command.Request, xTime bool, conn *sql.Conn) ([]*command.QueryRows, error) {
var err error var err error
type Queryer interface { type Queryer interface {
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
@ -533,14 +512,14 @@ func (db *DB) queryWithConn(req *command.Request, xTime bool, conn *sql.Conn) ([
queryer = conn queryer = conn
} }
var allRows []*Rows var allRows []*command.QueryRows
for _, stmt := range req.Statements { for _, stmt := range req.Statements {
sql := stmt.Sql sql := stmt.Sql
if sql == "" { if sql == "" {
continue continue
} }
rows := &Rows{} rows := &command.QueryRows{}
start := time.Now() start := time.Now()
parameters, err := parametersToValues(stmt.Parameters) parameters, err := parametersToValues(stmt.Parameters)
@ -583,8 +562,9 @@ func (db *DB) queryWithConn(req *command.Request, xTime bool, conn *sql.Conn) ([
if err := rs.Scan(ptrs...); err != nil { if err := rs.Scan(ptrs...); err != nil {
return nil, err return nil, err
} }
values := normalizeRowValues(dest, xTypes) rows.Values = append(rows.Values, &command.Values{
rows.Values = append(rows.Values, values) Parameters: normalizeRowValues(dest, xTypes),
})
} }
// Check for errors from iterating over rows. // Check for errors from iterating over rows.
@ -696,7 +676,7 @@ func (db *DB) Dump(w io.Writer) error {
} }
row := rows[0] row := rows[0]
for _, v := range row.Values { for _, v := range row.Values {
table := v[0].(string) table := v.Parameters[0].GetS()
var stmt string var stmt string
if table == "sqlite_sequence" { if table == "sqlite_sequence" {
@ -706,7 +686,7 @@ func (db *DB) Dump(w io.Writer) error {
} else if strings.HasPrefix(table, "sqlite_") { } else if strings.HasPrefix(table, "sqlite_") {
continue continue
} else { } else {
stmt = v[2].(string) stmt = v.Parameters[2].GetS()
} }
if _, err := w.Write([]byte(fmt.Sprintf("%s;\n", stmt))); err != nil { if _, err := w.Write([]byte(fmt.Sprintf("%s;\n", stmt))); err != nil {
@ -721,7 +701,7 @@ func (db *DB) Dump(w io.Writer) error {
} }
var columnNames []string var columnNames []string
for _, w := range r[0].Values { for _, w := range r[0].Values {
columnNames = append(columnNames, fmt.Sprintf(`'||quote("%s")||'`, w[1].(string))) columnNames = append(columnNames, fmt.Sprintf(`'||quote("%s")||'`, w.Parameters[1].GetS()))
} }
query = fmt.Sprintf(`SELECT 'INSERT INTO "%s" VALUES(%s)' FROM "%s";`, query = fmt.Sprintf(`SELECT 'INSERT INTO "%s" VALUES(%s)' FROM "%s";`,
@ -734,7 +714,7 @@ func (db *DB) Dump(w io.Writer) error {
return err return err
} }
for _, x := range r[0].Values { for _, x := range r[0].Values {
y := fmt.Sprintf("%s;\n", x[0].(string)) y := fmt.Sprintf("%s;\n", x.Parameters[0].GetS())
if _, err := w.Write([]byte(y)); err != nil { if _, err := w.Write([]byte(y)); err != nil {
return err return err
} }
@ -750,7 +730,7 @@ func (db *DB) Dump(w io.Writer) error {
} }
row = rows[0] row = rows[0]
for _, v := range row.Values { for _, v := range row.Values {
if _, err := w.Write([]byte(fmt.Sprintf("%s;\n", v[2]))); err != nil { if _, err := w.Write([]byte(fmt.Sprintf("%s;\n", v.Parameters[2].GetS()))); err != nil {
return err return err
} }
} }
@ -839,16 +819,50 @@ func parametersToValues(parameters []*command.Parameter) ([]interface{}, error)
// Text values come over (from sqlite-go) as []byte instead of strings // Text values come over (from sqlite-go) as []byte instead of strings
// for some reason, so we have explicitly convert (but only when type // for some reason, so we have explicitly convert (but only when type
// is "text" so we don't affect BLOB types) // is "text" so we don't affect BLOB types)
func normalizeRowValues(row []interface{}, types []string) []interface{} { func normalizeRowValues(row []interface{}, types []string) []*command.Parameter {
values := make([]*command.Parameter, len(types))
for i, v := range row { for i, v := range row {
if isTextType(types[i]) { switch val := v.(type) {
val, ok := v.([]byte) case int:
if ok { case int64:
row[i] = string(val) values[i] = &command.Parameter{
Value: &command.Parameter_I{
I: val,
},
}
case float64:
values[i] = &command.Parameter{
Value: &command.Parameter_D{
D: val,
},
}
case bool:
values[i] = &command.Parameter{
Value: &command.Parameter_B{
B: val,
},
}
case string:
values[i] = &command.Parameter{
Value: &command.Parameter_S{
S: val,
},
}
case []byte:
if isTextType(types[i]) {
values[i].Value = &command.Parameter_S{
S: string(val),
}
} else {
values[i] = &command.Parameter{
Value: &command.Parameter_Y{
Y: val,
},
}
} }
} }
} }
return row return values
} }
// isTextType returns whether the given type has a SQLite text affinity. // isTextType returns whether the given type has a SQLite text affinity.

@ -1,7 +1,6 @@
package db package db
import ( import (
"encoding/json"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"os" "os"
@ -12,6 +11,7 @@ import (
"time" "time"
"github.com/rqlite/rqlite/command" "github.com/rqlite/rqlite/command"
"github.com/rqlite/rqlite/command/encoding"
"github.com/rqlite/rqlite/testdata/chinook" "github.com/rqlite/rqlite/testdata/chinook"
) )
@ -1529,9 +1529,9 @@ func mustQuery(db *DB, stmt string) {
} }
func asJSON(v interface{}) string { func asJSON(v interface{}) string {
b, err := json.Marshal(v) b, err := encoding.JSONMarshal(v)
if err != nil { if err != nil {
panic("failed to JSON marshal value") panic(fmt.Sprintf("failed to JSON marshal value: %s", err.Error()))
} }
return string(b) return string(b)
} }

Loading…
Cancel
Save