@ -48,6 +48,12 @@ var (
// ErrWALReplayDirectoryMismatch is returned when the WAL file(s) are not in the same
// directory as the database file.
ErrWALReplayDirectoryMismatch = errors . New ( "WAL file(s) not in same directory as database file" )
// ErrQueryTimeout is returned when a query times out.
ErrQueryTimeout = errors . New ( "query timeout" )
// ErrExecuteTimeout is returned when an execute times out.
ErrExecuteTimeout = errors . New ( "execute timeout" )
)
// CheckpointMode is the mode in which a checkpoint runs.
@ -568,21 +574,28 @@ func (db *DB) Execute(req *command.Request, xTime bool) ([]*command.ExecuteResul
return nil , err
}
defer conn . Close ( )
return db . executeWithConn ( req , xTime , conn )
ctx := context . Background ( )
if req . DbTimeout > 0 {
var cancel context . CancelFunc
ctx , cancel = context . WithTimeout ( ctx , time . Duration ( req . DbTimeout ) )
defer cancel ( )
}
return db . executeWithConn ( ctx , req , xTime , conn )
}
type execer interface {
ExecContext ( ctx context . Context , query string , args ... interface { } ) ( sql . Result , error )
}
func ( db * DB ) executeWithConn ( req * command . Request , xTime bool , conn * sql . Conn ) ( [ ] * command . ExecuteResult , error ) {
func ( db * DB ) executeWithConn ( ctx context . Context , req * command . Request , xTime bool , conn * sql . Conn ) ( [ ] * command . ExecuteResult , error ) {
var err error
var execer execer
var tx * sql . Tx
if req . Transaction {
stats . Add ( numETx , 1 )
tx , err = conn . BeginTx ( c on te xt. Background ( ) , nil )
tx , err = conn . BeginTx ( c tx, nil )
if err != nil {
return nil , err
}
@ -619,7 +632,7 @@ func (db *DB) executeWithConn(req *command.Request, xTime bool, conn *sql.Conn)
continue
}
result , err := db . executeStmtWithConn ( stmt, xTime , execer , time . Duration ( req . DbTimeout ) )
result , err := db . executeStmtWithConn ( ctx, stmt, xTime , execer , time . Duration ( req . DbTimeout ) )
if err != nil {
if handleError ( result , err ) {
continue
@ -635,7 +648,15 @@ func (db *DB) executeWithConn(req *command.Request, xTime bool, conn *sql.Conn)
return allResults , err
}
func ( db * DB ) executeStmtWithConn ( stmt * command . Statement , xTime bool , e execer , timeout time . Duration ) ( * command . ExecuteResult , error ) {
func ( db * DB ) executeStmtWithConn ( ctx context . Context , stmt * command . Statement , xTime bool , e execer , timeout time . Duration ) ( res * command . ExecuteResult , retErr error ) {
defer func ( ) {
if retErr != nil {
retErr = rewriteContextTimeout ( retErr , ErrExecuteTimeout )
if res != nil {
res . Error = retErr . Error ( )
}
}
} ( )
result := & command . ExecuteResult { }
start := time . Now ( )
@ -645,7 +666,6 @@ func (db *DB) executeStmtWithConn(stmt *command.Statement, xTime bool, e execer,
return result , nil
}
ctx := context . Background ( )
if timeout > 0 {
var cancel context . CancelFunc
ctx , cancel = context . WithTimeout ( ctx , timeout )
@ -695,13 +715,14 @@ func (db *DB) QueryStringStmt(query string) ([]*command.QueryRows, error) {
// QueryStringStmtWithTimeout executes a single query that return rows, but don't modify database.
// It also sets a timeout for the query.
func ( db * DB ) QueryStringStmtWithTimeout ( query string , t imeout time . Duration ) ( [ ] * command . QueryRows , error ) {
func ( db * DB ) QueryStringStmtWithTimeout ( query string , t x bool , t imeout time . Duration ) ( [ ] * command . QueryRows , error ) {
r := & command . Request {
Statements : [ ] * command . Statement {
{
Sql : query ,
} ,
} ,
Transaction : tx ,
DbTimeout : int64 ( timeout ) ,
}
return db . Query ( r , false )
@ -715,21 +736,28 @@ func (db *DB) Query(req *command.Request, xTime bool) ([]*command.QueryRows, err
return nil , err
}
defer conn . Close ( )
return db . queryWithConn ( req , xTime , conn )
ctx := context . Background ( )
if req . DbTimeout > 0 {
var cancel context . CancelFunc
ctx , cancel = context . WithTimeout ( ctx , time . Duration ( req . DbTimeout ) )
defer cancel ( )
}
return db . queryWithConn ( ctx , req , xTime , conn )
}
type queryer interface {
QueryContext ( ctx context . Context , query string , args ... interface { } ) ( * sql . Rows , error )
}
func ( db * DB ) queryWithConn ( req * command . Request , xTime bool , conn * sql . Conn ) ( [ ] * command . QueryRows , error ) {
func ( db * DB ) queryWithConn ( ctx context . Context , req * command . Request , xTime bool , conn * sql . Conn ) ( [ ] * command . QueryRows , error ) {
var err error
var queryer queryer
var tx * sql . Tx
if req . Transaction {
stats . Add ( numQTx , 1 )
tx , err = conn . BeginTx ( c on te xt. Background ( ) , nil )
tx , err = conn . BeginTx ( c tx, nil )
if err != nil {
return nil , err
}
@ -767,7 +795,7 @@ func (db *DB) queryWithConn(req *command.Request, xTime bool, conn *sql.Conn) ([
continue
}
rows , err = db . queryStmtWithConn ( stmt, xTime , queryer , time . Duration ( req . DbTimeout ) )
rows , err = db . queryStmtWithConn ( ctx, stmt, xTime , queryer , time . Duration ( req . DbTimeout ) )
if err != nil {
stats . Add ( numQueryErrors , 1 )
rows = & command . QueryRows {
@ -783,7 +811,15 @@ func (db *DB) queryWithConn(req *command.Request, xTime bool, conn *sql.Conn) ([
return allRows , err
}
func ( db * DB ) queryStmtWithConn ( stmt * command . Statement , xTime bool , q queryer , timeout time . Duration ) ( * command . QueryRows , error ) {
func ( db * DB ) queryStmtWithConn ( ctx context . Context , stmt * command . Statement , xTime bool , q queryer , timeout time . Duration ) ( retRows * command . QueryRows , retErr error ) {
defer func ( ) {
if retErr != nil {
retErr = rewriteContextTimeout ( retErr , ErrQueryTimeout )
if retRows != nil {
retRows . Error = retErr . Error ( )
}
}
} ( )
rows := & command . QueryRows { }
start := time . Now ( )
@ -794,18 +830,11 @@ func (db *DB) queryStmtWithConn(stmt *command.Statement, xTime bool, q queryer,
return rows , nil
}
ctx := context . Background ( )
if timeout > 0 {
var cancel context . CancelFunc
ctx , cancel = context . WithTimeout ( ctx , timeout )
defer cancel ( )
}
rs , err := q . QueryContext ( ctx , stmt . Sql , parameters ... )
if err != nil {
stats . Add ( numQueryErrors , 1 )
rows . Error = err . Error ( )
return rows , nil
return rows , err
}
defer rs . Close ( )
@ -852,8 +881,7 @@ func (db *DB) queryStmtWithConn(stmt *command.Statement, xTime bool, q queryer,
// Check for errors from iterating over rows.
if err := rs . Err ( ) ; err != nil {
stats . Add ( numQueryErrors , 1 )
rows . Error = err . Error ( )
return rows , nil
return rows , err
}
if xTime {
@ -897,12 +925,19 @@ func (db *DB) Request(req *command.Request, xTime bool) ([]*command.ExecuteQuery
}
defer conn . Close ( )
ctx := context . Background ( )
if req . DbTimeout > 0 {
var cancel context . CancelFunc
ctx , cancel = context . WithTimeout ( ctx , time . Duration ( req . DbTimeout ) )
defer cancel ( )
}
var queryer queryer
var execer execer
var tx * sql . Tx
if req . Transaction {
stats . Add ( numRTx , 1 )
tx , err = conn . BeginTx ( context . Background ( ) , nil )
tx , err = conn . BeginTx ( c tx, nil )
if err != nil {
return nil , err
}
@ -943,13 +978,13 @@ func (db *DB) Request(req *command.Request, xTime bool) ([]*command.ExecuteQuery
}
if ro {
rows , opErr := db . queryStmtWithConn ( stmt, xTime , queryer , time . Duration ( req . DbTimeout ) )
rows , opErr := db . queryStmtWithConn ( ctx, stmt, xTime , queryer , time . Duration ( req . DbTimeout ) )
eqResponse = append ( eqResponse , createEQQueryResponse ( rows , opErr ) )
if abortOnError ( opErr ) {
break
}
} else {
result , opErr := db . executeStmtWithConn ( stmt, xTime , execer , time . Duration ( req . DbTimeout ) )
result , opErr := db . executeStmtWithConn ( ctx, stmt, xTime , execer , time . Duration ( req . DbTimeout ) )
eqResponse = append ( eqResponse , createEQExecuteResponse ( result , opErr ) )
if abortOnError ( opErr ) {
break
@ -1045,6 +1080,7 @@ func (db *DB) Dump(w io.Writer) error {
return err
}
defer conn . Close ( )
ctx := context . Background ( )
// Convenience function to convert string query to protobuf.
commReq := func ( query string ) * command . Request {
@ -1064,7 +1100,7 @@ func (db *DB) Dump(w io.Writer) error {
// Get the schema.
query := ` SELECT "name" , "type" , "sql" FROM "sqlite_master"
WHERE "sql" NOT NULL AND "type" == ' table ' ORDER BY "name" `
rows , err := db . queryWithConn ( commReq ( query ) , false , conn )
rows , err := db . queryWithConn ( ctx , commReq ( query ) , false , conn )
if err != nil {
return err
}
@ -1088,7 +1124,7 @@ func (db *DB) Dump(w io.Writer) error {
}
tableIndent := strings . Replace ( table , ` " ` , ` "" ` , - 1 )
r , err := db . queryWithConn ( commReq ( fmt . Sprintf ( ` PRAGMA table_info("%s") ` , tableIndent ) ) ,
r , err := db . queryWithConn ( ctx , commReq ( fmt . Sprintf ( ` PRAGMA table_info("%s") ` , tableIndent ) ) ,
false , conn )
if err != nil {
return err
@ -1102,7 +1138,7 @@ func (db *DB) Dump(w io.Writer) error {
tableIndent ,
strings . Join ( columnNames , "," ) ,
tableIndent )
r , err = db . queryWithConn ( commReq ( query ) , false , conn )
r , err = db . queryWithConn ( ctx , commReq ( query ) , false , conn )
if err != nil {
return err
@ -1118,7 +1154,7 @@ func (db *DB) Dump(w io.Writer) error {
// Do indexes, triggers, and views.
query = ` SELECT "name" , "type" , "sql" FROM "sqlite_master"
WHERE "sql" NOT NULL AND "type" IN ( ' index ' , ' trigger ' , ' view ' ) `
rows , err = db . queryWithConn ( commReq ( query ) , false , conn )
rows , err = db . queryWithConn ( ctx , commReq ( query ) , false , conn )
if err != nil {
return err
}
@ -1497,3 +1533,10 @@ func lastModified(path string) (t time.Time, retError error) {
}
return info . ModTime ( ) , nil
}
func rewriteContextTimeout ( err , retErr error ) error {
if err == context . DeadlineExceeded {
return retErr
}
return err
}