diff --git a/db/db.go b/db/db.go index e10d850b..7809e3bf 100644 --- a/db/db.go +++ b/db/db.go @@ -75,23 +75,6 @@ type PoolStats struct { MaxLifetimeClosed int64 `json:"max_lifetime_closed"` } -// Result represents the outcome of an operation that changes rows. -type Result struct { - LastInsertID int64 `json:"last_insert_id,omitempty"` - RowsAffected int64 `json:"rows_affected,omitempty"` - Error string `json:"error,omitempty"` - Time float64 `json:"time,omitempty"` -} - -// Rows represents the outcome of an operation that returns query data. -type Rows struct { - Columns []string `json:"columns,omitempty"` - Types []string `json:"types,omitempty"` - Values [][]interface{} `json:"values,omitempty"` - Error string `json:"error,omitempty"` - Time float64 `json:"time,omitempty"` -} - // Open opens a file-based database, creating it if it does not exist. func Open(dbPath string, fkEnabled bool) (*DB, error) { rwDSN := fmt.Sprintf("file:%s?_fk=%s", dbPath, strconv.FormatBool(fkEnabled)) @@ -313,7 +296,7 @@ func (db *DB) Size() (int64, error) { return 0, err } - return rows[0].Values[0][0].(int64), nil + return rows[0].Values[0].Parameters[0].GetI(), nil } // FileSize returns the size of the SQLite file on disk. If running in @@ -341,14 +324,10 @@ func (db *DB) CompileOptions() ([]string, error) { copts := make([]string, len(res[0].Values)) for i := range copts { - if len(res[0].Values[i]) != 1 { + if len(res[0].Values[i].Parameters) != 1 { return nil, fmt.Errorf("compile options values wrong size (%d)", len(res)) } - if co, ok := res[0].Values[i][0].(string); !ok { - copts[i] = "not string!" - } else { - copts[i] = co - } + copts[i] = res[0].Values[i].Parameters[0].GetS() } return copts, nil } @@ -372,7 +351,7 @@ func (db *DB) ConnectionPoolStats(sqlDB *sql.DB) *PoolStats { // ExecuteStringStmt executes a single query that modifies the database. This is // primarily a convenience function. -func (db *DB) ExecuteStringStmt(query string) ([]*Result, error) { +func (db *DB) ExecuteStringStmt(query string) ([]*command.ExecuteResult, error) { r := &command.Request{ Statements: []*command.Statement{ { @@ -384,7 +363,7 @@ func (db *DB) ExecuteStringStmt(query string) ([]*Result, error) { } // Execute executes queries that modify the database. -func (db *DB) Execute(req *command.Request, xTime bool) ([]*Result, error) { +func (db *DB) Execute(req *command.Request, xTime bool) ([]*command.ExecuteResult, error) { stats.Add(numExecutions, int64(len(req.Statements))) conn, err := db.rwDB.Conn(context.Background()) @@ -415,11 +394,11 @@ func (db *DB) Execute(req *command.Request, xTime bool) ([]*Result, error) { execer = conn } - var allResults []*Result + var allResults []*command.ExecuteResult // handleError sets the error field on the given result. It returns // whether the caller should continue processing or break. - handleError := func(result *Result, err error) bool { + handleError := func(result *command.ExecuteResult, err error) bool { stats.Add(numExecutionErrors, 1) result.Error = err.Error() allResults = append(allResults, result) @@ -438,7 +417,7 @@ func (db *DB) Execute(req *command.Request, xTime bool) ([]*Result, error) { continue } - result := &Result{} + result := &command.ExecuteResult{} start := time.Now() parameters, err := parametersToValues(stmt.Parameters) @@ -468,7 +447,7 @@ func (db *DB) Execute(req *command.Request, xTime bool) ([]*Result, error) { } break } - result.LastInsertID = lid + result.LastInsertId = lid ra, err := r.RowsAffected() if err != nil { @@ -491,7 +470,7 @@ func (db *DB) Execute(req *command.Request, xTime bool) ([]*Result, error) { } // QueryStringStmt executes a single query that return rows, but don't modify database. -func (db *DB) QueryStringStmt(query string) ([]*Rows, error) { +func (db *DB) QueryStringStmt(query string) ([]*command.QueryRows, error) { r := &command.Request{ Statements: []*command.Statement{ { @@ -503,7 +482,7 @@ func (db *DB) QueryStringStmt(query string) ([]*Rows, error) { } // Query executes queries that return rows, but don't modify the database. -func (db *DB) Query(req *command.Request, xTime bool) ([]*Rows, error) { +func (db *DB) Query(req *command.Request, xTime bool) ([]*command.QueryRows, error) { stats.Add(numQueries, int64(len(req.Statements))) conn, err := db.roDB.Conn(context.Background()) if err != nil { @@ -513,7 +492,7 @@ func (db *DB) Query(req *command.Request, xTime bool) ([]*Rows, error) { return db.queryWithConn(req, xTime, conn) } -func (db *DB) queryWithConn(req *command.Request, xTime bool, conn *sql.Conn) ([]*Rows, error) { +func (db *DB) queryWithConn(req *command.Request, xTime bool, conn *sql.Conn) ([]*command.QueryRows, error) { var err error type Queryer interface { QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) @@ -533,14 +512,14 @@ func (db *DB) queryWithConn(req *command.Request, xTime bool, conn *sql.Conn) ([ queryer = conn } - var allRows []*Rows + var allRows []*command.QueryRows for _, stmt := range req.Statements { sql := stmt.Sql if sql == "" { continue } - rows := &Rows{} + rows := &command.QueryRows{} start := time.Now() parameters, err := parametersToValues(stmt.Parameters) @@ -583,8 +562,9 @@ func (db *DB) queryWithConn(req *command.Request, xTime bool, conn *sql.Conn) ([ if err := rs.Scan(ptrs...); err != nil { return nil, err } - values := normalizeRowValues(dest, xTypes) - rows.Values = append(rows.Values, values) + rows.Values = append(rows.Values, &command.Values{ + Parameters: normalizeRowValues(dest, xTypes), + }) } // Check for errors from iterating over rows. @@ -696,7 +676,7 @@ func (db *DB) Dump(w io.Writer) error { } row := rows[0] for _, v := range row.Values { - table := v[0].(string) + table := v.Parameters[0].GetS() var stmt string if table == "sqlite_sequence" { @@ -706,7 +686,7 @@ func (db *DB) Dump(w io.Writer) error { } else if strings.HasPrefix(table, "sqlite_") { continue } else { - stmt = v[2].(string) + stmt = v.Parameters[2].GetS() } if _, err := w.Write([]byte(fmt.Sprintf("%s;\n", stmt))); err != nil { @@ -721,7 +701,7 @@ func (db *DB) Dump(w io.Writer) error { } var columnNames []string for _, w := range r[0].Values { - columnNames = append(columnNames, fmt.Sprintf(`'||quote("%s")||'`, w[1].(string))) + columnNames = append(columnNames, fmt.Sprintf(`'||quote("%s")||'`, w.Parameters[1].GetS())) } query = fmt.Sprintf(`SELECT 'INSERT INTO "%s" VALUES(%s)' FROM "%s";`, @@ -734,7 +714,7 @@ func (db *DB) Dump(w io.Writer) error { return err } for _, x := range r[0].Values { - y := fmt.Sprintf("%s;\n", x[0].(string)) + y := fmt.Sprintf("%s;\n", x.Parameters[0].GetS()) if _, err := w.Write([]byte(y)); err != nil { return err } @@ -750,7 +730,7 @@ func (db *DB) Dump(w io.Writer) error { } row = rows[0] for _, v := range row.Values { - if _, err := w.Write([]byte(fmt.Sprintf("%s;\n", v[2]))); err != nil { + if _, err := w.Write([]byte(fmt.Sprintf("%s;\n", v.Parameters[2].GetS()))); err != nil { return err } } @@ -839,16 +819,50 @@ func parametersToValues(parameters []*command.Parameter) ([]interface{}, error) // Text values come over (from sqlite-go) as []byte instead of strings // for some reason, so we have explicitly convert (but only when type // is "text" so we don't affect BLOB types) -func normalizeRowValues(row []interface{}, types []string) []interface{} { +func normalizeRowValues(row []interface{}, types []string) []*command.Parameter { + values := make([]*command.Parameter, len(types)) for i, v := range row { - if isTextType(types[i]) { - val, ok := v.([]byte) - if ok { - row[i] = string(val) + switch val := v.(type) { + case int: + case int64: + values[i] = &command.Parameter{ + Value: &command.Parameter_I{ + I: val, + }, + } + case float64: + values[i] = &command.Parameter{ + Value: &command.Parameter_D{ + D: val, + }, + } + case bool: + values[i] = &command.Parameter{ + Value: &command.Parameter_B{ + B: val, + }, + } + case string: + values[i] = &command.Parameter{ + Value: &command.Parameter_S{ + S: val, + }, + } + case []byte: + if isTextType(types[i]) { + values[i].Value = &command.Parameter_S{ + S: string(val), + } + } else { + values[i] = &command.Parameter{ + Value: &command.Parameter_Y{ + Y: val, + }, + } } } } - return row + return values } // isTextType returns whether the given type has a SQLite text affinity. diff --git a/db/db_test.go b/db/db_test.go index ca880dfc..7888b429 100644 --- a/db/db_test.go +++ b/db/db_test.go @@ -1,7 +1,6 @@ package db import ( - "encoding/json" "fmt" "io/ioutil" "os" @@ -12,6 +11,7 @@ import ( "time" "github.com/rqlite/rqlite/command" + "github.com/rqlite/rqlite/command/encoding" "github.com/rqlite/rqlite/testdata/chinook" ) @@ -1529,9 +1529,9 @@ func mustQuery(db *DB, stmt string) { } func asJSON(v interface{}) string { - b, err := json.Marshal(v) + b, err := encoding.JSONMarshal(v) if err != nil { - panic("failed to JSON marshal value") + panic(fmt.Sprintf("failed to JSON marshal value: %s", err.Error())) } return string(b) }