diff --git a/db/db.go b/db/db.go index daa7bcca..0358254a 100644 --- a/db/db.go +++ b/db/db.go @@ -876,6 +876,18 @@ func (db *DB) RequestStringStmts(stmts []string) ([]*command.ExecuteQueryRespons return db.Request(req, false) } +// RequestStringStmtsWithTimeout processes a request that can contain both executes and queries. +func (db *DB) RequestStringStmtsWithTimeout(stmts []string, timeout time.Duration) ([]*command.ExecuteQueryResponse, error) { + req := &command.Request{} + for _, q := range stmts { + req.Statements = append(req.Statements, &command.Statement{ + Sql: q, + }) + } + req.DbTimeout = int64(timeout) + return db.Request(req, false) +} + // Request processes a request that can contain both executes and queries. func (db *DB) Request(req *command.Request, xTime bool) ([]*command.ExecuteQueryResponse, error) { stats.Add(numRequests, int64(len(req.Statements))) diff --git a/db/db_test.go b/db/db_test.go index e3cfadb0..f24970d3 100644 --- a/db/db_test.go +++ b/db/db_test.go @@ -1090,28 +1090,28 @@ func Test_QueryShouldTimeout(t *testing.T) { } } -// func Test_RequestShouldTimeout(t *testing.T) { -// db, path := mustSetupDBForTimeoutTests(t, 1000) -// defer db.Close() -// defer os.Remove(path) - -// q := `SELECT key1, key_id, key2, key3, key4, key5, key6, data -// FROM test_table -// ORDER BY key2 ASC` -// r, err := db.QueryStringStmtWithTimeout(q, 1*time.Microsecond) -// if err != nil { -// t.Fatalf("failed to run query: %s", err.Error()) -// } - -// if len(r) != 1 { -// t.Fatalf("expected one result, got %d: %s", len(r), asJSON(r)) -// } - -// res := r[0] -// if !strings.Contains(res.Error, "context deadline exceeded") { -// t.Fatalf("expected context.DeadlineExceeded, got %s", res.Error) -// } -// } +func Test_RequestShouldTimeout(t *testing.T) { + db, path := mustSetupDBForTimeoutTests(t, 1000) + defer db.Close() + defer os.Remove(path) + + q := `SELECT key1, key_id, key2, key3, key4, key5, key6, data + FROM test_table + ORDER BY key2 ASC` + res, err := db.RequestStringStmtsWithTimeout([]string{q}, 1*time.Microsecond) + if err != nil { + t.Fatalf("failed to run query: %s", err.Error()) + } + + if len(res) != 1 { + t.Fatalf("expected one result, got %d: %s", len(res), asJSON(res)) + } + + r := res[0] + if !strings.Contains(r.GetQ().Error, "context deadline exceeded") { + t.Fatalf("expected context.DeadlineExceeded, got %s", r.GetQ().Error) + } +} func mustCreateOnDiskDatabase() (*DB, string) { var err error