diff --git a/db/db.go b/db/db.go index 0705399a..adbf2c81 100644 --- a/db/db.go +++ b/db/db.go @@ -701,14 +701,15 @@ func (db *DB) QueryStringStmt(query string) ([]*command.QueryRows, error) { // QueryStringStmtWithTimeout executes a single query that return rows, but don't modify database. // It also sets a timeout for the query. -func (db *DB) QueryStringStmtWithTimeout(query string, timeout time.Duration) ([]*command.QueryRows, error) { +func (db *DB) QueryStringStmtWithTimeout(query string, tx bool, timeout time.Duration) ([]*command.QueryRows, error) { r := &command.Request{ Statements: []*command.Statement{ { Sql: query, }, }, - DbTimeout: int64(timeout), + Transaction: tx, + DbTimeout: int64(timeout), } return db.Query(r, false) } diff --git a/db/db_test.go b/db/db_test.go index f24970d3..cb307266 100644 --- a/db/db_test.go +++ b/db/db_test.go @@ -1075,19 +1075,36 @@ func Test_QueryShouldTimeout(t *testing.T) { q := `SELECT key1, key_id, key2, key3, key4, key5, key6, data FROM test_table ORDER BY key2 ASC` - r, err := db.QueryStringStmtWithTimeout(q, 1*time.Microsecond) + + // Without tx.... + r, err := db.QueryStringStmtWithTimeout(q, false, 1*time.Microsecond) if err != nil { t.Fatalf("failed to run query: %s", err.Error()) } - if len(r) != 1 { t.Fatalf("expected one result, got %d: %s", len(r), asJSON(r)) } - res := r[0] if !strings.Contains(res.Error, "context deadline exceeded") { t.Fatalf("expected context.DeadlineExceeded, got %s", res.Error) } + + // ... and with tx + r, err = db.QueryStringStmtWithTimeout(q, true, 1*time.Microsecond) + if err != nil { + if !strings.Contains(err.Error(), "context deadline exceeded") && + !strings.Contains(err.Error(), "transaction has already been committed or rolled back") { + t.Fatalf("failed to run query: %s", err.Error()) + } + } else { + if len(r) != 1 { + t.Fatalf("expected one result, got %d: %s", len(r), asJSON(r)) + } + res = r[0] + if !strings.Contains(res.Error, "context deadline exceeded") { + t.Fatalf("expected context.DeadlineExceeded, got %s", res.Error) + } + } } func Test_RequestShouldTimeout(t *testing.T) {