diff --git a/db/db.go b/db/db.go index 02ab8cf0..8424b983 100644 --- a/db/db.go +++ b/db/db.go @@ -51,6 +51,9 @@ var ( // ErrQueryTimeout is returned when a query times out. ErrQueryTimeout = errors.New("query timeout") + + // ErrExecuteTimeout is returned when an execute times out. + ErrExecuteTimeout = errors.New("execute timeout") ) // CheckpointMode is the mode in which a checkpoint runs. @@ -663,6 +666,7 @@ func (db *DB) executeStmtWithConn(ctx context.Context, stmt *command.Statement, r, err := e.ExecContext(ctx, stmt.Sql, parameters...) if err != nil { + err = rewriteContextTimeout(err, ErrExecuteTimeout) result.Error = err.Error() return result, err } @@ -862,7 +866,7 @@ func (db *DB) queryStmtWithConn(ctx context.Context, stmt *command.Statement, xT // Check for errors from iterating over rows. if err := rs.Err(); err != nil { stats.Add(numQueryErrors, 1) - rows.Error = rewriteContextTimeout(err).Error() + rows.Error = rewriteContextTimeout(err, ErrQueryTimeout).Error() return rows, nil } @@ -1516,9 +1520,9 @@ func lastModified(path string) (t time.Time, retError error) { return info.ModTime(), nil } -func rewriteContextTimeout(err error) error { +func rewriteContextTimeout(err, retErr error) error { if err == context.DeadlineExceeded { - return ErrQueryTimeout + return retErr } return err } diff --git a/db/db_test.go b/db/db_test.go index 8551f630..4b0fd448 100644 --- a/db/db_test.go +++ b/db/db_test.go @@ -1054,8 +1054,8 @@ FROM test_table t1 LEFT OUTER JOIN test_table t2` } res := r[0] - if !strings.Contains(res.Error, "context deadline exceeded") { - t.Fatalf("expected context.DeadlineExceeded, got %s", res.Error) + if !strings.Contains(res.Error, ErrExecuteTimeout.Error()) { + t.Fatalf("expected execute timeout, got %s", res.Error) } qr, err := db.QueryStringStmt("SELECT COUNT(*) FROM test_table") @@ -1125,7 +1125,7 @@ func Test_RequestShouldTimeout(t *testing.T) { } r := res[0] - if !strings.Contains(r.GetQ().Error, "context deadline exceeded") { + if !strings.Contains(r.GetQ().Error, ErrQueryTimeout.Error()) { t.Fatalf("expected context.DeadlineExceeded, got %s", r.GetQ().Error) } }