From 9057cf008371ebf658c2553a3e1db7ac1e83c589 Mon Sep 17 00:00:00 2001 From: Philip O'Toole Date: Sat, 3 Feb 2024 10:03:27 -0500 Subject: [PATCH] Move context creation to top-level This ensures that the context object is used for transactional queries too. --- db/db.go | 64 ++++++++++++++++++++++++++++++++++---------------------- 1 file changed, 39 insertions(+), 25 deletions(-) diff --git a/db/db.go b/db/db.go index 0358254a..0705399a 100644 --- a/db/db.go +++ b/db/db.go @@ -568,21 +568,28 @@ func (db *DB) Execute(req *command.Request, xTime bool) ([]*command.ExecuteResul return nil, err } defer conn.Close() - return db.executeWithConn(req, xTime, conn) + + ctx := context.Background() + if req.DbTimeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, time.Duration(req.DbTimeout)) + defer cancel() + } + return db.executeWithConn(ctx, req, xTime, conn) } type execer interface { ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) } -func (db *DB) executeWithConn(req *command.Request, xTime bool, conn *sql.Conn) ([]*command.ExecuteResult, error) { +func (db *DB) executeWithConn(ctx context.Context, req *command.Request, xTime bool, conn *sql.Conn) ([]*command.ExecuteResult, error) { var err error var execer execer var tx *sql.Tx if req.Transaction { stats.Add(numETx, 1) - tx, err = conn.BeginTx(context.Background(), nil) + tx, err = conn.BeginTx(ctx, nil) if err != nil { return nil, err } @@ -619,7 +626,7 @@ func (db *DB) executeWithConn(req *command.Request, xTime bool, conn *sql.Conn) continue } - result, err := db.executeStmtWithConn(stmt, xTime, execer, time.Duration(req.DbTimeout)) + result, err := db.executeStmtWithConn(ctx, stmt, xTime, execer, time.Duration(req.DbTimeout)) if err != nil { if handleError(result, err) { continue @@ -635,7 +642,7 @@ func (db *DB) executeWithConn(req *command.Request, xTime bool, conn *sql.Conn) return allResults, err } -func (db *DB) executeStmtWithConn(stmt *command.Statement, xTime bool, e execer, timeout time.Duration) (*command.ExecuteResult, error) { +func (db *DB) executeStmtWithConn(ctx context.Context, stmt *command.Statement, xTime bool, e execer, timeout time.Duration) (*command.ExecuteResult, error) { result := &command.ExecuteResult{} start := time.Now() @@ -645,7 +652,6 @@ func (db *DB) executeStmtWithConn(stmt *command.Statement, xTime bool, e execer, return result, nil } - ctx := context.Background() if timeout > 0 { var cancel context.CancelFunc ctx, cancel = context.WithTimeout(ctx, timeout) @@ -715,21 +721,28 @@ func (db *DB) Query(req *command.Request, xTime bool) ([]*command.QueryRows, err return nil, err } defer conn.Close() - return db.queryWithConn(req, xTime, conn) + + ctx := context.Background() + if req.DbTimeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, time.Duration(req.DbTimeout)) + defer cancel() + } + return db.queryWithConn(ctx, req, xTime, conn) } type queryer interface { QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) } -func (db *DB) queryWithConn(req *command.Request, xTime bool, conn *sql.Conn) ([]*command.QueryRows, error) { +func (db *DB) queryWithConn(ctx context.Context, req *command.Request, xTime bool, conn *sql.Conn) ([]*command.QueryRows, error) { var err error var queryer queryer var tx *sql.Tx if req.Transaction { stats.Add(numQTx, 1) - tx, err = conn.BeginTx(context.Background(), nil) + tx, err = conn.BeginTx(ctx, nil) if err != nil { return nil, err } @@ -767,7 +780,7 @@ func (db *DB) queryWithConn(req *command.Request, xTime bool, conn *sql.Conn) ([ continue } - rows, err = db.queryStmtWithConn(stmt, xTime, queryer, time.Duration(req.DbTimeout)) + rows, err = db.queryStmtWithConn(ctx, stmt, xTime, queryer, time.Duration(req.DbTimeout)) if err != nil { stats.Add(numQueryErrors, 1) rows = &command.QueryRows{ @@ -783,7 +796,7 @@ func (db *DB) queryWithConn(req *command.Request, xTime bool, conn *sql.Conn) ([ return allRows, err } -func (db *DB) queryStmtWithConn(stmt *command.Statement, xTime bool, q queryer, timeout time.Duration) (*command.QueryRows, error) { +func (db *DB) queryStmtWithConn(ctx context.Context, stmt *command.Statement, xTime bool, q queryer, timeout time.Duration) (*command.QueryRows, error) { rows := &command.QueryRows{} start := time.Now() @@ -794,13 +807,6 @@ func (db *DB) queryStmtWithConn(stmt *command.Statement, xTime bool, q queryer, return rows, nil } - ctx := context.Background() - if timeout > 0 { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, timeout) - defer cancel() - } - rs, err := q.QueryContext(ctx, stmt.Sql, parameters...) if err != nil { stats.Add(numQueryErrors, 1) @@ -897,12 +903,19 @@ func (db *DB) Request(req *command.Request, xTime bool) ([]*command.ExecuteQuery } defer conn.Close() + ctx := context.Background() + if req.DbTimeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, time.Duration(req.DbTimeout)) + defer cancel() + } + var queryer queryer var execer execer var tx *sql.Tx if req.Transaction { stats.Add(numRTx, 1) - tx, err = conn.BeginTx(context.Background(), nil) + tx, err = conn.BeginTx(ctx, nil) if err != nil { return nil, err } @@ -943,13 +956,13 @@ func (db *DB) Request(req *command.Request, xTime bool) ([]*command.ExecuteQuery } if ro { - rows, opErr := db.queryStmtWithConn(stmt, xTime, queryer, time.Duration(req.DbTimeout)) + rows, opErr := db.queryStmtWithConn(ctx, stmt, xTime, queryer, time.Duration(req.DbTimeout)) eqResponse = append(eqResponse, createEQQueryResponse(rows, opErr)) if abortOnError(opErr) { break } } else { - result, opErr := db.executeStmtWithConn(stmt, xTime, execer, time.Duration(req.DbTimeout)) + result, opErr := db.executeStmtWithConn(ctx, stmt, xTime, execer, time.Duration(req.DbTimeout)) eqResponse = append(eqResponse, createEQExecuteResponse(result, opErr)) if abortOnError(opErr) { break @@ -1045,6 +1058,7 @@ func (db *DB) Dump(w io.Writer) error { return err } defer conn.Close() + ctx := context.Background() // Convenience function to convert string query to protobuf. commReq := func(query string) *command.Request { @@ -1064,7 +1078,7 @@ func (db *DB) Dump(w io.Writer) error { // Get the schema. query := `SELECT "name", "type", "sql" FROM "sqlite_master" WHERE "sql" NOT NULL AND "type" == 'table' ORDER BY "name"` - rows, err := db.queryWithConn(commReq(query), false, conn) + rows, err := db.queryWithConn(ctx, commReq(query), false, conn) if err != nil { return err } @@ -1088,7 +1102,7 @@ func (db *DB) Dump(w io.Writer) error { } tableIndent := strings.Replace(table, `"`, `""`, -1) - r, err := db.queryWithConn(commReq(fmt.Sprintf(`PRAGMA table_info("%s")`, tableIndent)), + r, err := db.queryWithConn(ctx, commReq(fmt.Sprintf(`PRAGMA table_info("%s")`, tableIndent)), false, conn) if err != nil { return err @@ -1102,7 +1116,7 @@ func (db *DB) Dump(w io.Writer) error { tableIndent, strings.Join(columnNames, ","), tableIndent) - r, err = db.queryWithConn(commReq(query), false, conn) + r, err = db.queryWithConn(ctx, commReq(query), false, conn) if err != nil { return err @@ -1118,7 +1132,7 @@ func (db *DB) Dump(w io.Writer) error { // Do indexes, triggers, and views. query = `SELECT "name", "type", "sql" FROM "sqlite_master" WHERE "sql" NOT NULL AND "type" IN ('index', 'trigger', 'view')` - rows, err = db.queryWithConn(commReq(query), false, conn) + rows, err = db.queryWithConn(ctx, commReq(query), false, conn) if err != nil { return err }