1
0
Fork 0

Merge pull request #1669 from rqlite/system-test-query-timeout

System-level testing of Query Timeouts
master
Philip O'Toole 8 months ago committed by GitHub
commit 89e189a893
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -1,6 +1,6 @@
## 8.19.0 (unreleased)
### New features
- [PR #1666](https://github.com/rqlite/rqlite/pull/1667), [PR #1667](https://github.com/rqlite/rqlite/pull/1667): Support timing out if query doesn't finish within specified interval. Fixes issue [#1657](https://github.com/rqlite/rqlite/issues/1657). Thanks @mauri870
- [PR #1666](https://github.com/rqlite/rqlite/pull/1667), [PR #1667](https://github.com/rqlite/rqlite/pull/1667), [PR #1669](https://github.com/rqlite/rqlite/pull/1669): Support timing out if query doesn't finish within specified interval. Fixes issue [#1657](https://github.com/rqlite/rqlite/issues/1657). Thanks @mauri870
### Implementation changes and bug fixes
- [PR #1665](https://github.com/rqlite/rqlite/pull/1665): Minor improvements to `random` module.

@ -134,6 +134,11 @@ func Test_UploaderFailThenOK(t *testing.T) {
wg.Add(2)
sc := &mockStorageClient{
uploadFn: func(ctx context.Context, reader io.Reader, id string) error {
if uploadCount == 2 {
// uploadFn can be called a third time before the cancel kicks in.
// This would push waitGroup into negative numbers and panic.
return nil
}
defer wg.Done()
if uploadCount == 0 {
uploadCount++
@ -175,6 +180,11 @@ func Test_UploaderOKThenFail(t *testing.T) {
wg.Add(2)
sc := &mockStorageClient{
uploadFn: func(ctx context.Context, reader io.Reader, id string) error {
if uploadCount == 2 {
// uploadFn can be called a third time before the cancel kicks in.
// This would push waitGroup into negative numbers and panic.
return nil
}
defer wg.Done()
if uploadCount == 1 {

@ -14,6 +14,32 @@ import (
"github.com/rqlite/rqlite/v8/aws"
)
// stats captures stats for the Uploader service.
var stats *expvar.Map
var (
gzipMagic = []byte{0x1f, 0x8b, 0x08}
)
const (
numDownloadsOK = "num_downloads_ok"
numDownloadsFail = "num_downloads_fail"
numDownloadBytes = "download_bytes"
)
func init() {
stats = expvar.NewMap("downloader")
ResetStats()
}
// ResetStats resets the expvar stats for this module. Mostly for test purposes.
func ResetStats() {
stats.Init()
stats.Add(numDownloadsOK, 0)
stats.Add(numDownloadsFail, 0)
stats.Add(numDownloadBytes, 0)
}
// DownloadFile downloads the auto-restore file from the given URL, and returns the path to
// the downloaded file. If the download fails, and the config is marked as continue-on-failure, then
// the error is returned, but errOK is set to true. If the download fails, and the file is not
@ -64,32 +90,6 @@ type StorageClient interface {
fmt.Stringer
}
// stats captures stats for the Uploader service.
var stats *expvar.Map
var (
gzipMagic = []byte{0x1f, 0x8b, 0x08}
)
const (
numDownloadsOK = "num_downloads_ok"
numDownloadsFail = "num_downloads_fail"
numDownloadBytes = "download_bytes"
)
func init() {
stats = expvar.NewMap("downloader")
ResetStats()
}
// ResetStats resets the expvar stats for this module. Mostly for test purposes.
func ResetStats() {
stats.Init()
stats.Add(numDownloadsOK, 0)
stats.Add(numDownloadsFail, 0)
stats.Add(numDownloadBytes, 0)
}
type Downloader struct {
storageClient StorageClient
logger *log.Logger

@ -48,6 +48,12 @@ var (
// ErrWALReplayDirectoryMismatch is returned when the WAL file(s) are not in the same
// directory as the database file.
ErrWALReplayDirectoryMismatch = errors.New("WAL file(s) not in same directory as database file")
// ErrQueryTimeout is returned when a query times out.
ErrQueryTimeout = errors.New("query timeout")
// ErrExecuteTimeout is returned when an execute times out.
ErrExecuteTimeout = errors.New("execute timeout")
)
// CheckpointMode is the mode in which a checkpoint runs.
@ -568,21 +574,28 @@ func (db *DB) Execute(req *command.Request, xTime bool) ([]*command.ExecuteResul
return nil, err
}
defer conn.Close()
return db.executeWithConn(req, xTime, conn)
ctx := context.Background()
if req.DbTimeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, time.Duration(req.DbTimeout))
defer cancel()
}
return db.executeWithConn(ctx, req, xTime, conn)
}
type execer interface {
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
}
func (db *DB) executeWithConn(req *command.Request, xTime bool, conn *sql.Conn) ([]*command.ExecuteResult, error) {
func (db *DB) executeWithConn(ctx context.Context, req *command.Request, xTime bool, conn *sql.Conn) ([]*command.ExecuteResult, error) {
var err error
var execer execer
var tx *sql.Tx
if req.Transaction {
stats.Add(numETx, 1)
tx, err = conn.BeginTx(context.Background(), nil)
tx, err = conn.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
@ -619,7 +632,7 @@ func (db *DB) executeWithConn(req *command.Request, xTime bool, conn *sql.Conn)
continue
}
result, err := db.executeStmtWithConn(stmt, xTime, execer, time.Duration(req.DbTimeout))
result, err := db.executeStmtWithConn(ctx, stmt, xTime, execer, time.Duration(req.DbTimeout))
if err != nil {
if handleError(result, err) {
continue
@ -635,7 +648,15 @@ func (db *DB) executeWithConn(req *command.Request, xTime bool, conn *sql.Conn)
return allResults, err
}
func (db *DB) executeStmtWithConn(stmt *command.Statement, xTime bool, e execer, timeout time.Duration) (*command.ExecuteResult, error) {
func (db *DB) executeStmtWithConn(ctx context.Context, stmt *command.Statement, xTime bool, e execer, timeout time.Duration) (res *command.ExecuteResult, retErr error) {
defer func() {
if retErr != nil {
retErr = rewriteContextTimeout(retErr, ErrExecuteTimeout)
if res != nil {
res.Error = retErr.Error()
}
}
}()
result := &command.ExecuteResult{}
start := time.Now()
@ -645,7 +666,6 @@ func (db *DB) executeStmtWithConn(stmt *command.Statement, xTime bool, e execer,
return result, nil
}
ctx := context.Background()
if timeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, timeout)
@ -695,14 +715,15 @@ func (db *DB) QueryStringStmt(query string) ([]*command.QueryRows, error) {
// QueryStringStmtWithTimeout executes a single query that return rows, but don't modify database.
// It also sets a timeout for the query.
func (db *DB) QueryStringStmtWithTimeout(query string, timeout time.Duration) ([]*command.QueryRows, error) {
func (db *DB) QueryStringStmtWithTimeout(query string, tx bool, timeout time.Duration) ([]*command.QueryRows, error) {
r := &command.Request{
Statements: []*command.Statement{
{
Sql: query,
},
},
DbTimeout: int64(timeout),
Transaction: tx,
DbTimeout: int64(timeout),
}
return db.Query(r, false)
}
@ -715,21 +736,28 @@ func (db *DB) Query(req *command.Request, xTime bool) ([]*command.QueryRows, err
return nil, err
}
defer conn.Close()
return db.queryWithConn(req, xTime, conn)
ctx := context.Background()
if req.DbTimeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, time.Duration(req.DbTimeout))
defer cancel()
}
return db.queryWithConn(ctx, req, xTime, conn)
}
type queryer interface {
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
}
func (db *DB) queryWithConn(req *command.Request, xTime bool, conn *sql.Conn) ([]*command.QueryRows, error) {
func (db *DB) queryWithConn(ctx context.Context, req *command.Request, xTime bool, conn *sql.Conn) ([]*command.QueryRows, error) {
var err error
var queryer queryer
var tx *sql.Tx
if req.Transaction {
stats.Add(numQTx, 1)
tx, err = conn.BeginTx(context.Background(), nil)
tx, err = conn.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
@ -767,7 +795,7 @@ func (db *DB) queryWithConn(req *command.Request, xTime bool, conn *sql.Conn) ([
continue
}
rows, err = db.queryStmtWithConn(stmt, xTime, queryer, time.Duration(req.DbTimeout))
rows, err = db.queryStmtWithConn(ctx, stmt, xTime, queryer, time.Duration(req.DbTimeout))
if err != nil {
stats.Add(numQueryErrors, 1)
rows = &command.QueryRows{
@ -783,7 +811,15 @@ func (db *DB) queryWithConn(req *command.Request, xTime bool, conn *sql.Conn) ([
return allRows, err
}
func (db *DB) queryStmtWithConn(stmt *command.Statement, xTime bool, q queryer, timeout time.Duration) (*command.QueryRows, error) {
func (db *DB) queryStmtWithConn(ctx context.Context, stmt *command.Statement, xTime bool, q queryer, timeout time.Duration) (retRows *command.QueryRows, retErr error) {
defer func() {
if retErr != nil {
retErr = rewriteContextTimeout(retErr, ErrQueryTimeout)
if retRows != nil {
retRows.Error = retErr.Error()
}
}
}()
rows := &command.QueryRows{}
start := time.Now()
@ -794,18 +830,11 @@ func (db *DB) queryStmtWithConn(stmt *command.Statement, xTime bool, q queryer,
return rows, nil
}
ctx := context.Background()
if timeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, timeout)
defer cancel()
}
rs, err := q.QueryContext(ctx, stmt.Sql, parameters...)
if err != nil {
stats.Add(numQueryErrors, 1)
rows.Error = err.Error()
return rows, nil
return rows, err
}
defer rs.Close()
@ -852,8 +881,7 @@ func (db *DB) queryStmtWithConn(stmt *command.Statement, xTime bool, q queryer,
// Check for errors from iterating over rows.
if err := rs.Err(); err != nil {
stats.Add(numQueryErrors, 1)
rows.Error = err.Error()
return rows, nil
return rows, err
}
if xTime {
@ -897,12 +925,19 @@ func (db *DB) Request(req *command.Request, xTime bool) ([]*command.ExecuteQuery
}
defer conn.Close()
ctx := context.Background()
if req.DbTimeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, time.Duration(req.DbTimeout))
defer cancel()
}
var queryer queryer
var execer execer
var tx *sql.Tx
if req.Transaction {
stats.Add(numRTx, 1)
tx, err = conn.BeginTx(context.Background(), nil)
tx, err = conn.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
@ -943,13 +978,13 @@ func (db *DB) Request(req *command.Request, xTime bool) ([]*command.ExecuteQuery
}
if ro {
rows, opErr := db.queryStmtWithConn(stmt, xTime, queryer, time.Duration(req.DbTimeout))
rows, opErr := db.queryStmtWithConn(ctx, stmt, xTime, queryer, time.Duration(req.DbTimeout))
eqResponse = append(eqResponse, createEQQueryResponse(rows, opErr))
if abortOnError(opErr) {
break
}
} else {
result, opErr := db.executeStmtWithConn(stmt, xTime, execer, time.Duration(req.DbTimeout))
result, opErr := db.executeStmtWithConn(ctx, stmt, xTime, execer, time.Duration(req.DbTimeout))
eqResponse = append(eqResponse, createEQExecuteResponse(result, opErr))
if abortOnError(opErr) {
break
@ -1045,6 +1080,7 @@ func (db *DB) Dump(w io.Writer) error {
return err
}
defer conn.Close()
ctx := context.Background()
// Convenience function to convert string query to protobuf.
commReq := func(query string) *command.Request {
@ -1064,7 +1100,7 @@ func (db *DB) Dump(w io.Writer) error {
// Get the schema.
query := `SELECT "name", "type", "sql" FROM "sqlite_master"
WHERE "sql" NOT NULL AND "type" == 'table' ORDER BY "name"`
rows, err := db.queryWithConn(commReq(query), false, conn)
rows, err := db.queryWithConn(ctx, commReq(query), false, conn)
if err != nil {
return err
}
@ -1088,7 +1124,7 @@ func (db *DB) Dump(w io.Writer) error {
}
tableIndent := strings.Replace(table, `"`, `""`, -1)
r, err := db.queryWithConn(commReq(fmt.Sprintf(`PRAGMA table_info("%s")`, tableIndent)),
r, err := db.queryWithConn(ctx, commReq(fmt.Sprintf(`PRAGMA table_info("%s")`, tableIndent)),
false, conn)
if err != nil {
return err
@ -1102,7 +1138,7 @@ func (db *DB) Dump(w io.Writer) error {
tableIndent,
strings.Join(columnNames, ","),
tableIndent)
r, err = db.queryWithConn(commReq(query), false, conn)
r, err = db.queryWithConn(ctx, commReq(query), false, conn)
if err != nil {
return err
@ -1118,7 +1154,7 @@ func (db *DB) Dump(w io.Writer) error {
// Do indexes, triggers, and views.
query = `SELECT "name", "type", "sql" FROM "sqlite_master"
WHERE "sql" NOT NULL AND "type" IN ('index', 'trigger', 'view')`
rows, err = db.queryWithConn(commReq(query), false, conn)
rows, err = db.queryWithConn(ctx, commReq(query), false, conn)
if err != nil {
return err
}
@ -1497,3 +1533,10 @@ func lastModified(path string) (t time.Time, retError error) {
}
return info.ModTime(), nil
}
func rewriteContextTimeout(err, retErr error) error {
if err == context.DeadlineExceeded {
return retErr
}
return err
}

@ -1027,16 +1027,23 @@ func mustSetupDBForTimeoutTests(t *testing.T, n int) (*DB, string) {
})
}
// Insert the records, and confirm that they were inserted.
_, err := db.Execute(req, false)
if err != nil {
t.Fatalf("failed to insert records: %s", err.Error())
}
qr, err := db.QueryStringStmt("SELECT COUNT(*) FROM test_table")
if err != nil {
t.Fatalf("error counting rows: %s", err.Error())
}
if want, got := fmt.Sprintf(`[{"columns":["COUNT(*)"],"types":["integer"],"values":[[%d]]}]`, n), asJSON(qr); want != got {
t.Fatalf("want response %s, got %s", want, got)
}
return db, path
}
func Test_ExecShouldTimeout(t *testing.T) {
db, path := mustSetupDBForTimeoutTests(t, 1000)
db, path := mustSetupDBForTimeoutTests(t, 5000)
defer db.Close()
defer os.Remove(path)
@ -1054,51 +1061,60 @@ FROM test_table t1 LEFT OUTER JOIN test_table t2`
}
res := r[0]
if !strings.Contains(res.Error, "context deadline exceeded") {
t.Fatalf("expected context.DeadlineExceeded, got %s", res.Error)
}
qr, err := db.QueryStringStmt("SELECT COUNT(*) FROM test_table")
if err != nil {
t.Fatalf("error counting rows: %s", err.Error())
}
if want, got := `[{"columns":["COUNT(*)"],"types":["integer"],"values":[[1000]]}]`, asJSON(qr); want != got {
t.Fatalf("want response %s, got %s", want, got)
if !strings.Contains(res.Error, ErrExecuteTimeout.Error()) {
t.Fatalf("expected execute timeout, got %s", res.Error)
}
}
func Test_QueryShouldTimeout(t *testing.T) {
db, path := mustSetupDBForTimeoutTests(t, 1000)
db, path := mustSetupDBForTimeoutTests(t, 5000)
defer db.Close()
defer os.Remove(path)
q := `SELECT key1, key_id, key2, key3, key4, key5, key6, data
FROM test_table
ORDER BY key2 ASC`
r, err := db.QueryStringStmtWithTimeout(q, 1*time.Microsecond)
// Without tx....
r, err := db.QueryStringStmtWithTimeout(q, false, 1*time.Millisecond)
if err != nil {
t.Fatalf("failed to run query: %s", err.Error())
}
if len(r) != 1 {
t.Fatalf("expected one result, got %d: %s", len(r), asJSON(r))
}
res := r[0]
if !strings.Contains(res.Error, "context deadline exceeded") {
t.Fatalf("expected context.DeadlineExceeded, got %s", res.Error)
if !strings.Contains(res.Error, ErrQueryTimeout.Error()) {
t.Fatalf("expected query timeout, got %s", res.Error)
}
// ... and with tx
r, err = db.QueryStringStmtWithTimeout(q, true, 1*time.Millisecond)
if err != nil {
if !strings.Contains(err.Error(), "context deadline exceeded") &&
!strings.Contains(err.Error(), "transaction has already been committed or rolled back") {
t.Fatalf("failed to run query: %s", err.Error())
}
} else {
if len(r) != 1 {
t.Fatalf("expected one result, got %d: %s", len(r), asJSON(r))
}
res = r[0]
if !strings.Contains(res.Error, ErrQueryTimeout.Error()) {
t.Fatalf("expected query timeout, got %s", res.Error)
}
}
}
func Test_RequestShouldTimeout(t *testing.T) {
db, path := mustSetupDBForTimeoutTests(t, 1000)
db, path := mustSetupDBForTimeoutTests(t, 5000)
defer db.Close()
defer os.Remove(path)
q := `SELECT key1, key_id, key2, key3, key4, key5, key6, data
FROM test_table
ORDER BY key2 ASC`
res, err := db.RequestStringStmtsWithTimeout([]string{q}, 1*time.Microsecond)
res, err := db.RequestStringStmtsWithTimeout([]string{q}, 1*time.Millisecond)
if err != nil {
t.Fatalf("failed to run query: %s", err.Error())
}
@ -1108,7 +1124,7 @@ func Test_RequestShouldTimeout(t *testing.T) {
}
r := res[0]
if !strings.Contains(r.GetQ().Error, "context deadline exceeded") {
if !strings.Contains(r.GetQ().Error, ErrQueryTimeout.Error()) {
t.Fatalf("expected context.DeadlineExceeded, got %s", r.GetQ().Error)
}
}

@ -150,6 +150,8 @@ class TestAutoRestoreS3(unittest.TestCase):
n0.start()
n0.wait_for_ready()
n0.execute('CREATE TABLE bar (id INTEGER NOT NULL PRIMARY KEY, name TEXT)')
j = n0.query('SELECT * FROM bar', level='strong')
self.assertEqual(j, d_("{'results': [{'types': ['integer', 'text'], 'columns': ['id', 'name']}]}"))
n0.stop()
# Create a new node, using the directory from the previous node, but check

@ -37,6 +37,8 @@ const (
// ElectionTimeout is the period between elections. It's longer than
// the default to allow for slow CI systems.
ElectionTimeout = 2 * time.Second
NoQueryTimeout = 0
)
var (
@ -132,17 +134,22 @@ func (n *Node) ExecuteQueuedMulti(stmts []string, wait bool) (string, error) {
// Query runs a single query against the node.
func (n *Node) Query(stmt string) (string, error) {
return n.query(stmt, "weak")
return n.query(stmt, "weak", NoQueryTimeout)
}
// QueryWithTimeout runs a single query against the node, with a timeout.
func (n *Node) QueryWithTimeout(stmt string, timeout time.Duration) (string, error) {
return n.query(stmt, "weak", timeout)
}
// QueryNoneConsistency runs a single query against the node, with no read consistency.
func (n *Node) QueryNoneConsistency(stmt string) (string, error) {
return n.query(stmt, "none")
return n.query(stmt, "none", NoQueryTimeout)
}
// QueryStrongConsistency runs a single query against the node, with Strong read consistency.
func (n *Node) QueryStrongConsistency(stmt string) (string, error) {
return n.query(stmt, "strong")
return n.query(stmt, "strong", NoQueryTimeout)
}
// QueryMulti runs multiple queries against the node.
@ -444,11 +451,12 @@ func (n *Node) postExecuteQueued(stmt string, wait bool) (string, error) {
return string(body), nil
}
func (n *Node) query(stmt, consistency string) (string, error) {
func (n *Node) query(stmt, consistency string, timeout time.Duration) (string, error) {
v, _ := url.Parse("http://" + n.APIAddr + "/db/query")
v.RawQuery = url.Values{
"q": []string{stmt},
"level": []string{consistency},
"q": []string{stmt},
"level": []string{consistency},
"db_timeout": []string{timeout.String()},
}.Encode()
resp, err := http.Get(v.String())

@ -10,12 +10,14 @@ import (
"os"
"path/filepath"
"regexp"
"strings"
"sync"
"testing"
"time"
"github.com/rqlite/rqlite/v8/cluster"
httpd "github.com/rqlite/rqlite/v8/http"
"github.com/rqlite/rqlite/v8/random"
"github.com/rqlite/rqlite/v8/store"
"github.com/rqlite/rqlite/v8/tcp"
)
@ -520,6 +522,70 @@ func Test_SingleNodeParameterizedNamedConstraints(t *testing.T) {
}
}
func Test_SingleNodeQueryTimeout(t *testing.T) {
node := mustNewLeaderNode("leader1")
defer node.Deprovision()
sql := `CREATE TABLE IF NOT EXISTS test_table (
key1 VARCHAR(64) PRIMARY KEY,
key_id VARCHAR(64) NOT NULL,
key2 VARCHAR(64) NOT NULL,
key3 VARCHAR(64) NOT NULL,
key4 VARCHAR(64) NOT NULL,
key5 VARCHAR(64) NOT NULL,
key6 VARCHAR(64) NOT NULL,
data BLOB NOT NULL
)`
if _, err := node.Execute(sql); err != nil {
t.Fatalf("failed to create table: %s", err.Error())
}
// Bulk insert rows (for speed and to avoid snapshotting)
sqls := make([]string, 5000)
for i := 0; i < cap(sqls); i++ {
args := []any{
random.String(),
fmt.Sprint(i),
random.String(),
random.String(),
random.String(),
random.String(),
random.String(),
random.String(),
}
sql := fmt.Sprintf(`INSERT INTO test_table
(key1, key_id, key2, key3, key4, key5, key6, data)
VALUES
(%q, %q, %q, %q, %q, %q, %q, %q);`, args...)
sqls[i] = sql
}
if _, err := node.ExecuteMulti(sqls); err != nil {
t.Fatalf("failed to insert data: %s", err.Error())
}
r, err := node.Query(`SELECT COUNT(*) FROM test_table`)
if err != nil {
t.Fatalf("failed to count records: %s", err.Error())
}
exp := fmt.Sprintf(`{"results":[{"columns":["COUNT(*)"],"types":["integer"],"values":[[%d]]}]}`, len(sqls))
if r != exp {
t.Fatalf("test received wrong result\nexp: %s\ngot: %s\n", exp, r)
}
q := `SELECT key1, key_id, key2, key3, key4, key5, key6, data
FROM test_table
ORDER BY key2 ASC`
r, err = node.QueryWithTimeout(q, 1*time.Millisecond)
if err != nil {
t.Fatalf("failed to query with timeout: %s", err.Error())
}
if !strings.Contains(r, `"error":"query timeout"`) {
// This test is brittle, but it's the best we can do, as we can't be sure
// how much of the query will actually be executed. We just know it should
// time out at some point.
t.Fatalf("query ran to completion, but should have timed out")
}
}
func Test_SingleNodeRewriteRandom(t *testing.T) {
node := mustNewLeaderNode("leader1")
defer node.Deprovision()

Loading…
Cancel
Save