1
0
Fork 0

Move context creation to top-level

This ensures that the context object is used for transactional queries
too.
master
Philip O'Toole 8 months ago
parent c2776f3bbf
commit 9057cf0083

@ -568,21 +568,28 @@ func (db *DB) Execute(req *command.Request, xTime bool) ([]*command.ExecuteResul
return nil, err
}
defer conn.Close()
return db.executeWithConn(req, xTime, conn)
ctx := context.Background()
if req.DbTimeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, time.Duration(req.DbTimeout))
defer cancel()
}
return db.executeWithConn(ctx, req, xTime, conn)
}
type execer interface {
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
}
func (db *DB) executeWithConn(req *command.Request, xTime bool, conn *sql.Conn) ([]*command.ExecuteResult, error) {
func (db *DB) executeWithConn(ctx context.Context, req *command.Request, xTime bool, conn *sql.Conn) ([]*command.ExecuteResult, error) {
var err error
var execer execer
var tx *sql.Tx
if req.Transaction {
stats.Add(numETx, 1)
tx, err = conn.BeginTx(context.Background(), nil)
tx, err = conn.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
@ -619,7 +626,7 @@ func (db *DB) executeWithConn(req *command.Request, xTime bool, conn *sql.Conn)
continue
}
result, err := db.executeStmtWithConn(stmt, xTime, execer, time.Duration(req.DbTimeout))
result, err := db.executeStmtWithConn(ctx, stmt, xTime, execer, time.Duration(req.DbTimeout))
if err != nil {
if handleError(result, err) {
continue
@ -635,7 +642,7 @@ func (db *DB) executeWithConn(req *command.Request, xTime bool, conn *sql.Conn)
return allResults, err
}
func (db *DB) executeStmtWithConn(stmt *command.Statement, xTime bool, e execer, timeout time.Duration) (*command.ExecuteResult, error) {
func (db *DB) executeStmtWithConn(ctx context.Context, stmt *command.Statement, xTime bool, e execer, timeout time.Duration) (*command.ExecuteResult, error) {
result := &command.ExecuteResult{}
start := time.Now()
@ -645,7 +652,6 @@ func (db *DB) executeStmtWithConn(stmt *command.Statement, xTime bool, e execer,
return result, nil
}
ctx := context.Background()
if timeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, timeout)
@ -715,21 +721,28 @@ func (db *DB) Query(req *command.Request, xTime bool) ([]*command.QueryRows, err
return nil, err
}
defer conn.Close()
return db.queryWithConn(req, xTime, conn)
ctx := context.Background()
if req.DbTimeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, time.Duration(req.DbTimeout))
defer cancel()
}
return db.queryWithConn(ctx, req, xTime, conn)
}
type queryer interface {
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
}
func (db *DB) queryWithConn(req *command.Request, xTime bool, conn *sql.Conn) ([]*command.QueryRows, error) {
func (db *DB) queryWithConn(ctx context.Context, req *command.Request, xTime bool, conn *sql.Conn) ([]*command.QueryRows, error) {
var err error
var queryer queryer
var tx *sql.Tx
if req.Transaction {
stats.Add(numQTx, 1)
tx, err = conn.BeginTx(context.Background(), nil)
tx, err = conn.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
@ -767,7 +780,7 @@ func (db *DB) queryWithConn(req *command.Request, xTime bool, conn *sql.Conn) ([
continue
}
rows, err = db.queryStmtWithConn(stmt, xTime, queryer, time.Duration(req.DbTimeout))
rows, err = db.queryStmtWithConn(ctx, stmt, xTime, queryer, time.Duration(req.DbTimeout))
if err != nil {
stats.Add(numQueryErrors, 1)
rows = &command.QueryRows{
@ -783,7 +796,7 @@ func (db *DB) queryWithConn(req *command.Request, xTime bool, conn *sql.Conn) ([
return allRows, err
}
func (db *DB) queryStmtWithConn(stmt *command.Statement, xTime bool, q queryer, timeout time.Duration) (*command.QueryRows, error) {
func (db *DB) queryStmtWithConn(ctx context.Context, stmt *command.Statement, xTime bool, q queryer, timeout time.Duration) (*command.QueryRows, error) {
rows := &command.QueryRows{}
start := time.Now()
@ -794,13 +807,6 @@ func (db *DB) queryStmtWithConn(stmt *command.Statement, xTime bool, q queryer,
return rows, nil
}
ctx := context.Background()
if timeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, timeout)
defer cancel()
}
rs, err := q.QueryContext(ctx, stmt.Sql, parameters...)
if err != nil {
stats.Add(numQueryErrors, 1)
@ -897,12 +903,19 @@ func (db *DB) Request(req *command.Request, xTime bool) ([]*command.ExecuteQuery
}
defer conn.Close()
ctx := context.Background()
if req.DbTimeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, time.Duration(req.DbTimeout))
defer cancel()
}
var queryer queryer
var execer execer
var tx *sql.Tx
if req.Transaction {
stats.Add(numRTx, 1)
tx, err = conn.BeginTx(context.Background(), nil)
tx, err = conn.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
@ -943,13 +956,13 @@ func (db *DB) Request(req *command.Request, xTime bool) ([]*command.ExecuteQuery
}
if ro {
rows, opErr := db.queryStmtWithConn(stmt, xTime, queryer, time.Duration(req.DbTimeout))
rows, opErr := db.queryStmtWithConn(ctx, stmt, xTime, queryer, time.Duration(req.DbTimeout))
eqResponse = append(eqResponse, createEQQueryResponse(rows, opErr))
if abortOnError(opErr) {
break
}
} else {
result, opErr := db.executeStmtWithConn(stmt, xTime, execer, time.Duration(req.DbTimeout))
result, opErr := db.executeStmtWithConn(ctx, stmt, xTime, execer, time.Duration(req.DbTimeout))
eqResponse = append(eqResponse, createEQExecuteResponse(result, opErr))
if abortOnError(opErr) {
break
@ -1045,6 +1058,7 @@ func (db *DB) Dump(w io.Writer) error {
return err
}
defer conn.Close()
ctx := context.Background()
// Convenience function to convert string query to protobuf.
commReq := func(query string) *command.Request {
@ -1064,7 +1078,7 @@ func (db *DB) Dump(w io.Writer) error {
// Get the schema.
query := `SELECT "name", "type", "sql" FROM "sqlite_master"
WHERE "sql" NOT NULL AND "type" == 'table' ORDER BY "name"`
rows, err := db.queryWithConn(commReq(query), false, conn)
rows, err := db.queryWithConn(ctx, commReq(query), false, conn)
if err != nil {
return err
}
@ -1088,7 +1102,7 @@ func (db *DB) Dump(w io.Writer) error {
}
tableIndent := strings.Replace(table, `"`, `""`, -1)
r, err := db.queryWithConn(commReq(fmt.Sprintf(`PRAGMA table_info("%s")`, tableIndent)),
r, err := db.queryWithConn(ctx, commReq(fmt.Sprintf(`PRAGMA table_info("%s")`, tableIndent)),
false, conn)
if err != nil {
return err
@ -1102,7 +1116,7 @@ func (db *DB) Dump(w io.Writer) error {
tableIndent,
strings.Join(columnNames, ","),
tableIndent)
r, err = db.queryWithConn(commReq(query), false, conn)
r, err = db.queryWithConn(ctx, commReq(query), false, conn)
if err != nil {
return err
@ -1118,7 +1132,7 @@ func (db *DB) Dump(w io.Writer) error {
// Do indexes, triggers, and views.
query = `SELECT "name", "type", "sql" FROM "sqlite_master"
WHERE "sql" NOT NULL AND "type" IN ('index', 'trigger', 'view')`
rows, err = db.queryWithConn(commReq(query), false, conn)
rows, err = db.queryWithConn(ctx, commReq(query), false, conn)
if err != nil {
return err
}

Loading…
Cancel
Save