1
0
Fork 0

DB-level unit tests pass

master
Philip O'Toole 3 years ago
parent 0cbf96399a
commit aa8df67592

@ -75,23 +75,6 @@ type PoolStats struct {
MaxLifetimeClosed int64 `json:"max_lifetime_closed"`
}
// Result represents the outcome of an operation that changes rows.
type Result struct {
LastInsertID int64 `json:"last_insert_id,omitempty"`
RowsAffected int64 `json:"rows_affected,omitempty"`
Error string `json:"error,omitempty"`
Time float64 `json:"time,omitempty"`
}
// Rows represents the outcome of an operation that returns query data.
type Rows struct {
Columns []string `json:"columns,omitempty"`
Types []string `json:"types,omitempty"`
Values [][]interface{} `json:"values,omitempty"`
Error string `json:"error,omitempty"`
Time float64 `json:"time,omitempty"`
}
// Open opens a file-based database, creating it if it does not exist.
func Open(dbPath string, fkEnabled bool) (*DB, error) {
rwDSN := fmt.Sprintf("file:%s?_fk=%s", dbPath, strconv.FormatBool(fkEnabled))
@ -313,7 +296,7 @@ func (db *DB) Size() (int64, error) {
return 0, err
}
return rows[0].Values[0][0].(int64), nil
return rows[0].Values[0].Parameters[0].GetI(), nil
}
// FileSize returns the size of the SQLite file on disk. If running in
@ -341,14 +324,10 @@ func (db *DB) CompileOptions() ([]string, error) {
copts := make([]string, len(res[0].Values))
for i := range copts {
if len(res[0].Values[i]) != 1 {
if len(res[0].Values[i].Parameters) != 1 {
return nil, fmt.Errorf("compile options values wrong size (%d)", len(res))
}
if co, ok := res[0].Values[i][0].(string); !ok {
copts[i] = "not string!"
} else {
copts[i] = co
}
copts[i] = res[0].Values[i].Parameters[0].GetS()
}
return copts, nil
}
@ -372,7 +351,7 @@ func (db *DB) ConnectionPoolStats(sqlDB *sql.DB) *PoolStats {
// ExecuteStringStmt executes a single query that modifies the database. This is
// primarily a convenience function.
func (db *DB) ExecuteStringStmt(query string) ([]*Result, error) {
func (db *DB) ExecuteStringStmt(query string) ([]*command.ExecuteResult, error) {
r := &command.Request{
Statements: []*command.Statement{
{
@ -384,7 +363,7 @@ func (db *DB) ExecuteStringStmt(query string) ([]*Result, error) {
}
// Execute executes queries that modify the database.
func (db *DB) Execute(req *command.Request, xTime bool) ([]*Result, error) {
func (db *DB) Execute(req *command.Request, xTime bool) ([]*command.ExecuteResult, error) {
stats.Add(numExecutions, int64(len(req.Statements)))
conn, err := db.rwDB.Conn(context.Background())
@ -415,11 +394,11 @@ func (db *DB) Execute(req *command.Request, xTime bool) ([]*Result, error) {
execer = conn
}
var allResults []*Result
var allResults []*command.ExecuteResult
// handleError sets the error field on the given result. It returns
// whether the caller should continue processing or break.
handleError := func(result *Result, err error) bool {
handleError := func(result *command.ExecuteResult, err error) bool {
stats.Add(numExecutionErrors, 1)
result.Error = err.Error()
allResults = append(allResults, result)
@ -438,7 +417,7 @@ func (db *DB) Execute(req *command.Request, xTime bool) ([]*Result, error) {
continue
}
result := &Result{}
result := &command.ExecuteResult{}
start := time.Now()
parameters, err := parametersToValues(stmt.Parameters)
@ -468,7 +447,7 @@ func (db *DB) Execute(req *command.Request, xTime bool) ([]*Result, error) {
}
break
}
result.LastInsertID = lid
result.LastInsertId = lid
ra, err := r.RowsAffected()
if err != nil {
@ -491,7 +470,7 @@ func (db *DB) Execute(req *command.Request, xTime bool) ([]*Result, error) {
}
// QueryStringStmt executes a single query that return rows, but don't modify database.
func (db *DB) QueryStringStmt(query string) ([]*Rows, error) {
func (db *DB) QueryStringStmt(query string) ([]*command.QueryRows, error) {
r := &command.Request{
Statements: []*command.Statement{
{
@ -503,7 +482,7 @@ func (db *DB) QueryStringStmt(query string) ([]*Rows, error) {
}
// Query executes queries that return rows, but don't modify the database.
func (db *DB) Query(req *command.Request, xTime bool) ([]*Rows, error) {
func (db *DB) Query(req *command.Request, xTime bool) ([]*command.QueryRows, error) {
stats.Add(numQueries, int64(len(req.Statements)))
conn, err := db.roDB.Conn(context.Background())
if err != nil {
@ -513,7 +492,7 @@ func (db *DB) Query(req *command.Request, xTime bool) ([]*Rows, error) {
return db.queryWithConn(req, xTime, conn)
}
func (db *DB) queryWithConn(req *command.Request, xTime bool, conn *sql.Conn) ([]*Rows, error) {
func (db *DB) queryWithConn(req *command.Request, xTime bool, conn *sql.Conn) ([]*command.QueryRows, error) {
var err error
type Queryer interface {
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
@ -533,14 +512,14 @@ func (db *DB) queryWithConn(req *command.Request, xTime bool, conn *sql.Conn) ([
queryer = conn
}
var allRows []*Rows
var allRows []*command.QueryRows
for _, stmt := range req.Statements {
sql := stmt.Sql
if sql == "" {
continue
}
rows := &Rows{}
rows := &command.QueryRows{}
start := time.Now()
parameters, err := parametersToValues(stmt.Parameters)
@ -583,8 +562,9 @@ func (db *DB) queryWithConn(req *command.Request, xTime bool, conn *sql.Conn) ([
if err := rs.Scan(ptrs...); err != nil {
return nil, err
}
values := normalizeRowValues(dest, xTypes)
rows.Values = append(rows.Values, values)
rows.Values = append(rows.Values, &command.Values{
Parameters: normalizeRowValues(dest, xTypes),
})
}
// Check for errors from iterating over rows.
@ -696,7 +676,7 @@ func (db *DB) Dump(w io.Writer) error {
}
row := rows[0]
for _, v := range row.Values {
table := v[0].(string)
table := v.Parameters[0].GetS()
var stmt string
if table == "sqlite_sequence" {
@ -706,7 +686,7 @@ func (db *DB) Dump(w io.Writer) error {
} else if strings.HasPrefix(table, "sqlite_") {
continue
} else {
stmt = v[2].(string)
stmt = v.Parameters[2].GetS()
}
if _, err := w.Write([]byte(fmt.Sprintf("%s;\n", stmt))); err != nil {
@ -721,7 +701,7 @@ func (db *DB) Dump(w io.Writer) error {
}
var columnNames []string
for _, w := range r[0].Values {
columnNames = append(columnNames, fmt.Sprintf(`'||quote("%s")||'`, w[1].(string)))
columnNames = append(columnNames, fmt.Sprintf(`'||quote("%s")||'`, w.Parameters[1].GetS()))
}
query = fmt.Sprintf(`SELECT 'INSERT INTO "%s" VALUES(%s)' FROM "%s";`,
@ -734,7 +714,7 @@ func (db *DB) Dump(w io.Writer) error {
return err
}
for _, x := range r[0].Values {
y := fmt.Sprintf("%s;\n", x[0].(string))
y := fmt.Sprintf("%s;\n", x.Parameters[0].GetS())
if _, err := w.Write([]byte(y)); err != nil {
return err
}
@ -750,7 +730,7 @@ func (db *DB) Dump(w io.Writer) error {
}
row = rows[0]
for _, v := range row.Values {
if _, err := w.Write([]byte(fmt.Sprintf("%s;\n", v[2]))); err != nil {
if _, err := w.Write([]byte(fmt.Sprintf("%s;\n", v.Parameters[2].GetS()))); err != nil {
return err
}
}
@ -839,16 +819,50 @@ func parametersToValues(parameters []*command.Parameter) ([]interface{}, error)
// Text values come over (from sqlite-go) as []byte instead of strings
// for some reason, so we have explicitly convert (but only when type
// is "text" so we don't affect BLOB types)
func normalizeRowValues(row []interface{}, types []string) []interface{} {
func normalizeRowValues(row []interface{}, types []string) []*command.Parameter {
values := make([]*command.Parameter, len(types))
for i, v := range row {
if isTextType(types[i]) {
val, ok := v.([]byte)
if ok {
row[i] = string(val)
switch val := v.(type) {
case int:
case int64:
values[i] = &command.Parameter{
Value: &command.Parameter_I{
I: val,
},
}
case float64:
values[i] = &command.Parameter{
Value: &command.Parameter_D{
D: val,
},
}
case bool:
values[i] = &command.Parameter{
Value: &command.Parameter_B{
B: val,
},
}
case string:
values[i] = &command.Parameter{
Value: &command.Parameter_S{
S: val,
},
}
case []byte:
if isTextType(types[i]) {
values[i].Value = &command.Parameter_S{
S: string(val),
}
} else {
values[i] = &command.Parameter{
Value: &command.Parameter_Y{
Y: val,
},
}
}
}
}
return row
return values
}
// isTextType returns whether the given type has a SQLite text affinity.

@ -1,7 +1,6 @@
package db
import (
"encoding/json"
"fmt"
"io/ioutil"
"os"
@ -12,6 +11,7 @@ import (
"time"
"github.com/rqlite/rqlite/command"
"github.com/rqlite/rqlite/command/encoding"
"github.com/rqlite/rqlite/testdata/chinook"
)
@ -1529,9 +1529,9 @@ func mustQuery(db *DB, stmt string) {
}
func asJSON(v interface{}) string {
b, err := json.Marshal(v)
b, err := encoding.JSONMarshal(v)
if err != nil {
panic("failed to JSON marshal value")
panic(fmt.Sprintf("failed to JSON marshal value: %s", err.Error()))
}
return string(b)
}

Loading…
Cancel
Save