net/http: simplify HTTP/1 request cancelation

HTTP requests have three separate user cancelation signals:

	Transport.CancelRequest
	Request.Cancel
	Request.Context()

In addition, a request can be canceled due to errors.

The Transport keeps a map of all in-flight requests,
with an associated func to run if CancelRequest is
called. Confusingly, this func is *not* run if
Request.Cancel is closed or the request context expires.

The map of in-flight requests is also used to communicate
between roundTrip and readLoop. In particular, if readLoop
reads a response immediately followed by an EOF, it may
send racing signals to roundTrip: The connection has
closed, but also there is a response available.
This race is resolved by readLoop communicating through
the request map that this request has successfully
completed.

This CL refactors all of this.

In-flight requests now have a context which is canceled
when any of the above cancelation events occurs.

The map of requests to cancel funcs remains, but is
used strictly for implementing Transport.CancelRequest.
It is not used to communicate information about the
state of a request.

Change-Id: Ie157edc0ce35f719866a0a2cb0e70514fd119ff8
Reviewed-on: https://go-review.googlesource.com/c/go/+/546676
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
changes/76/546676/6
Damien Neil 2023-12-01 16:26:14 -08:00
parent 33d725e575
commit 505000b2d3
2 changed files with 145 additions and 145 deletions

View File

@ -100,7 +100,7 @@ type Transport struct {
idleLRU connLRU
reqMu sync.Mutex
reqCanceler map[cancelKey]func(error)
reqCanceler map[*Request]context.CancelCauseFunc
altMu sync.Mutex // guards changing altProto only
altProto atomic.Value // of nil or map[string]RoundTripper, key is URI scheme
@ -294,13 +294,6 @@ type Transport struct {
ForceAttemptHTTP2 bool
}
// A cancelKey is the key of the reqCanceler map.
// We wrap the *Request in this type since we want to use the original request,
// not any transient one created by roundTrip.
type cancelKey struct {
req *Request
}
func (t *Transport) writeBufferSize() int {
if t.WriteBufferSize > 0 {
return t.WriteBufferSize
@ -466,10 +459,12 @@ func ProxyURL(fixedURL *url.URL) func(*Request) (*url.URL, error) {
// optional extra headers to write and stores any error to return
// from roundTrip.
type transportRequest struct {
*Request // original request, not to be mutated
extra Header // extra headers to write, or nil
trace *httptrace.ClientTrace // optional
cancelKey cancelKey
*Request // original request, not to be mutated
extra Header // extra headers to write, or nil
trace *httptrace.ClientTrace // optional
ctx context.Context // canceled when we are done with the request
cancel context.CancelCauseFunc
mu sync.Mutex // guards err
err error // first setError value for mapRoundTripError to consider
@ -531,7 +526,7 @@ func validateHeaders(hdrs Header) string {
}
// roundTrip implements a RoundTripper over HTTP.
func (t *Transport) roundTrip(req *Request) (*Response, error) {
func (t *Transport) roundTrip(req *Request) (_ *Response, err error) {
t.nextProtoOnce.Do(t.onceSetNextProtoDefaults)
ctx := req.Context()
trace := httptrace.ContextClientTrace(ctx)
@ -561,7 +556,6 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
}
origReq := req
cancelKey := cancelKey{origReq}
req = setupRewindBody(req)
if altRT := t.alternateRoundTripper(req); altRT != nil {
@ -587,16 +581,44 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
return nil, errors.New("http: no Host in request URL")
}
// Transport request context.
//
// If RoundTrip returns an error, it cancels this context before returning.
//
// If RoundTrip returns no error:
// - For an HTTP/1 request, persistConn.readLoop cancels this context
// after reading the request body.
// - For an HTTP/2 request, RoundTrip cancels this context after the HTTP/2
// RoundTripper returns.
ctx, cancel := context.WithCancelCause(req.Context())
// Convert Request.Cancel into context cancelation.
if origReq.Cancel != nil {
go awaitLegacyCancel(ctx, cancel, origReq)
}
// Convert Transport.CancelRequest into context cancelation.
//
// This is lamentably expensive. CancelRequest has been deprecated for a long time
// and doesn't work on HTTP/2 requests. Perhaps we should drop support for it entirely.
cancel = t.prepareTransportCancel(origReq, cancel)
defer func() {
if err != nil {
cancel(err)
}
}()
for {
select {
case <-ctx.Done():
req.closeBody()
return nil, ctx.Err()
return nil, context.Cause(ctx)
default:
}
// treq gets modified by roundTrip, so we need to recreate for each retry.
treq := &transportRequest{Request: req, trace: trace, cancelKey: cancelKey}
treq := &transportRequest{Request: req, trace: trace, ctx: ctx, cancel: cancel}
cm, err := t.connectMethodForRequest(treq)
if err != nil {
req.closeBody()
@ -609,7 +631,6 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
// to send it requests.
pconn, err := t.getConn(treq, cm)
if err != nil {
t.setReqCanceler(cancelKey, nil)
req.closeBody()
return nil, err
}
@ -617,12 +638,19 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
var resp *Response
if pconn.alt != nil {
// HTTP/2 path.
t.setReqCanceler(cancelKey, nil) // not cancelable with CancelRequest
resp, err = pconn.alt.RoundTrip(req)
} else {
resp, err = pconn.roundTrip(treq)
}
if err == nil {
if pconn.alt != nil {
// HTTP/2 requests are not cancelable with CancelRequest,
// so we have no further need for the request context.
//
// On the HTTP/1 path, roundTrip takes responsibility for
// canceling the context after the response body is read.
cancel(errRequestDone)
}
resp.Request = origReq
return resp, nil
}
@ -659,6 +687,14 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
}
}
func awaitLegacyCancel(ctx context.Context, cancel context.CancelCauseFunc, req *Request) {
select {
case <-req.Cancel:
cancel(errRequestCanceled)
case <-ctx.Done():
}
}
var errCannotRewind = errors.New("net/http: cannot rewind body after connection loss")
type readTrackingBody struct {
@ -820,30 +856,42 @@ func (t *Transport) CloseIdleConnections() {
}
}
// prepareTransportCancel sets up state to convert Transport.CancelRequest into context cancelation.
func (t *Transport) prepareTransportCancel(req *Request, origCancel context.CancelCauseFunc) context.CancelCauseFunc {
// Historically, RoundTrip has not modified the Request in any way.
// We could avoid the need to keep a map of all in-flight requests by adding
// a field to the Request containing its cancel func, and setting that field
// while the request is in-flight. Callers aren't supposed to reuse a Request
// until after the response body is closed, so this wouldn't violate any
// concurrency guarantees.
cancel := func(err error) {
origCancel(err)
t.reqMu.Lock()
delete(t.reqCanceler, req)
t.reqMu.Unlock()
}
t.reqMu.Lock()
if t.reqCanceler == nil {
t.reqCanceler = make(map[*Request]context.CancelCauseFunc)
}
t.reqCanceler[req] = cancel
t.reqMu.Unlock()
return cancel
}
// CancelRequest cancels an in-flight request by closing its connection.
// CancelRequest should only be called after [Transport.RoundTrip] has returned.
//
// Deprecated: Use [Request.WithContext] to create a request with a
// cancelable context instead. CancelRequest cannot cancel HTTP/2
// requests.
// requests. This may become a no-op in a future release of Go.
func (t *Transport) CancelRequest(req *Request) {
t.cancelRequest(cancelKey{req}, errRequestCanceled)
}
// Cancel an in-flight request, recording the error value.
// Returns whether the request was canceled.
func (t *Transport) cancelRequest(key cancelKey, err error) bool {
// This function must not return until the cancel func has completed.
// See: https://golang.org/issue/34658
t.reqMu.Lock()
defer t.reqMu.Unlock()
cancel := t.reqCanceler[key]
delete(t.reqCanceler, key)
cancel := t.reqCanceler[req]
t.reqMu.Unlock()
if cancel != nil {
cancel(err)
cancel(errRequestCanceled)
}
return cancel != nil
}
//
@ -1170,38 +1218,6 @@ func (t *Transport) removeIdleConnLocked(pconn *persistConn) bool {
return removed
}
func (t *Transport) setReqCanceler(key cancelKey, fn func(error)) {
t.reqMu.Lock()
defer t.reqMu.Unlock()
if t.reqCanceler == nil {
t.reqCanceler = make(map[cancelKey]func(error))
}
if fn != nil {
t.reqCanceler[key] = fn
} else {
delete(t.reqCanceler, key)
}
}
// replaceReqCanceler replaces an existing cancel function. If there is no cancel function
// for the request, we don't set the function and return false.
// Since CancelRequest will clear the canceler, we can use the return value to detect if
// the request was canceled since the last setReqCancel call.
func (t *Transport) replaceReqCanceler(key cancelKey, fn func(error)) bool {
t.reqMu.Lock()
defer t.reqMu.Unlock()
_, ok := t.reqCanceler[key]
if !ok {
return false
}
if fn != nil {
t.reqCanceler[key] = fn
} else {
delete(t.reqCanceler, key)
}
return true
}
var zeroDialer net.Dialer
func (t *Transport) dial(ctx context.Context, network, addr string) (net.Conn, error) {
@ -1442,19 +1458,8 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (_ *persis
}
}()
var cancelc chan error
// Queue for idle connection.
if delivered := t.queueForIdleConn(w); delivered {
// set request canceler to some non-nil function so we
// can detect whether it was cleared between now and when
// we enter roundTrip
t.setReqCanceler(treq.cancelKey, func(error) {})
} else {
cancelc = make(chan error, 1)
t.setReqCanceler(treq.cancelKey, func(err error) { cancelc <- err })
// Queue for permission to dial.
if delivered := t.queueForIdleConn(w); !delivered {
t.queueForDial(w)
}
@ -1479,11 +1484,8 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (_ *persis
// what caused r.err; if so, prefer to return the
// cancellation error (see golang.org/issue/16049).
select {
case <-req.Cancel:
return nil, errRequestCanceledConn
case <-req.Context().Done():
return nil, req.Context().Err()
case err := <-cancelc:
case <-treq.ctx.Done():
err := context.Cause(treq.ctx)
if err == errRequestCanceled {
err = errRequestCanceledConn
}
@ -1493,11 +1495,8 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (_ *persis
}
}
return r.pc, r.err
case <-req.Cancel:
return nil, errRequestCanceledConn
case <-req.Context().Done():
return nil, req.Context().Err()
case err := <-cancelc:
case <-treq.ctx.Done():
err := context.Cause(treq.ctx)
if err == errRequestCanceled {
err = errRequestCanceledConn
}
@ -2173,7 +2172,8 @@ func (pc *persistConn) readLoop() {
pc.t.removeIdleConn(pc)
}()
tryPutIdleConn := func(trace *httptrace.ClientTrace) bool {
tryPutIdleConn := func(treq *transportRequest) bool {
trace := treq.trace
if err := pc.t.tryPutIdleConn(pc); err != nil {
closeErr = err
if trace != nil && trace.PutIdleConn != nil && err != errKeepAlivesDisabled {
@ -2212,7 +2212,7 @@ func (pc *persistConn) readLoop() {
pc.mu.Unlock()
rc := <-pc.reqch
trace := httptrace.ContextClientTrace(rc.req.Context())
trace := rc.treq.trace
var resp *Response
if err == nil {
@ -2241,9 +2241,9 @@ func (pc *persistConn) readLoop() {
pc.mu.Unlock()
bodyWritable := resp.bodyIsWritable()
hasBody := rc.req.Method != "HEAD" && resp.ContentLength != 0
hasBody := rc.treq.Request.Method != "HEAD" && resp.ContentLength != 0
if resp.Close || rc.req.Close || resp.StatusCode <= 199 || bodyWritable {
if resp.Close || rc.treq.Request.Close || resp.StatusCode <= 199 || bodyWritable {
// Don't do keep-alive on error if either party requested a close
// or we get an unexpected informational (1xx) response.
// StatusCode 100 is already handled above.
@ -2251,8 +2251,6 @@ func (pc *persistConn) readLoop() {
}
if !hasBody || bodyWritable {
replaced := pc.t.replaceReqCanceler(rc.cancelKey, nil)
// Put the idle conn back into the pool before we send the response
// so if they process it quickly and make another request, they'll
// get this same conn. But we use the unbuffered channel 'rc'
@ -2261,7 +2259,7 @@ func (pc *persistConn) readLoop() {
alive = alive &&
!pc.sawEOF &&
pc.wroteRequest() &&
replaced && tryPutIdleConn(trace)
tryPutIdleConn(rc.treq)
if bodyWritable {
closeErr = errCallerOwnsConn
@ -2273,6 +2271,8 @@ func (pc *persistConn) readLoop() {
return
}
rc.treq.cancel(errRequestDone)
// Now that they've read from the unbuffered channel, they're safely
// out of the select that also waits on this goroutine to die, so
// we're allowed to exit now if needed (if alive is false)
@ -2323,26 +2323,22 @@ func (pc *persistConn) readLoop() {
// reading the response body. (or for cancellation or death)
select {
case bodyEOF := <-waitForBodyRead:
replaced := pc.t.replaceReqCanceler(rc.cancelKey, nil) // before pc might return to idle pool
alive = alive &&
bodyEOF &&
!pc.sawEOF &&
pc.wroteRequest() &&
replaced && tryPutIdleConn(trace)
tryPutIdleConn(rc.treq)
if bodyEOF {
eofc <- struct{}{}
}
case <-rc.req.Cancel:
case <-rc.treq.ctx.Done():
alive = false
pc.t.cancelRequest(rc.cancelKey, errRequestCanceled)
case <-rc.req.Context().Done():
alive = false
pc.t.cancelRequest(rc.cancelKey, rc.req.Context().Err())
pc.cancelRequest(errRequestCanceled)
case <-pc.closech:
alive = false
pc.t.setReqCanceler(rc.cancelKey, nil)
}
rc.treq.cancel(errRequestDone)
testHookReadLoopBeforeNextRead()
}
}
@ -2395,7 +2391,7 @@ func (pc *persistConn) readResponse(rc requestAndChan, trace *httptrace.ClientTr
continueCh := rc.continueCh
for {
resp, err = ReadResponse(pc.br, rc.req)
resp, err = ReadResponse(pc.br, rc.treq.Request)
if err != nil {
return
}
@ -2587,10 +2583,9 @@ type responseAndError struct {
}
type requestAndChan struct {
_ incomparable
req *Request
cancelKey cancelKey
ch chan responseAndError // unbuffered; always send in select on callerGone
_ incomparable
treq *transportRequest
ch chan responseAndError // unbuffered; always send in select on callerGone
// whether the Transport (as opposed to the user client code)
// added the Accept-Encoding gzip header. If the Transport
@ -2638,6 +2633,10 @@ var errTimeout error = &timeoutError{"net/http: timeout awaiting response header
var errRequestCanceled = http2errRequestCanceled
var errRequestCanceledConn = errors.New("net/http: request canceled while waiting for connection") // TODO: unify?
// errRequestDone is used to cancel the round trip Context after a request is successfully done.
// It should not be seen by the user.
var errRequestDone = errors.New("net/http: request completed")
func nop() {}
// testHooks. Always non-nil.
@ -2654,10 +2653,6 @@ var (
func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) {
testHookEnterRoundTrip()
if !pc.t.replaceReqCanceler(req.cancelKey, pc.cancelRequest) {
pc.t.putOrCloseIdleConn(pc)
return nil, errRequestCanceled
}
pc.mu.Lock()
pc.numExpectedResponses++
headerFn := pc.mutateHeaderFunc
@ -2706,12 +2701,6 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err
gone := make(chan struct{})
defer close(gone)
defer func() {
if err != nil {
pc.t.setReqCanceler(req.cancelKey, nil)
}
}()
const debugRoundTrip = false
// Write the request concurrently with waiting for a response,
@ -2723,19 +2712,29 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err
resc := make(chan responseAndError)
pc.reqch <- requestAndChan{
req: req.Request,
cancelKey: req.cancelKey,
treq: req,
ch: resc,
addedGzip: requestedGzip,
continueCh: continueCh,
callerGone: gone,
}
handleResponse := func(re responseAndError) (*Response, error) {
if (re.res == nil) == (re.err == nil) {
panic(fmt.Sprintf("internal error: exactly one of res or err should be set; nil=%v", re.res == nil))
}
if debugRoundTrip {
req.logf("resc recv: %p, %T/%#v", re.res, re.err, re.err)
}
if re.err != nil {
return nil, pc.mapRoundTripError(req, startBytesWritten, re.err)
}
return re.res, nil
}
var respHeaderTimer <-chan time.Time
cancelChan := req.Request.Cancel
ctxDoneChan := req.Context().Done()
ctxDoneChan := req.ctx.Done()
pcClosed := pc.closech
canceled := false
for {
testHookWaitResLoop()
select {
@ -2756,13 +2755,18 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err
respHeaderTimer = timer.C
}
case <-pcClosed:
pcClosed = nil
if canceled || pc.t.replaceReqCanceler(req.cancelKey, nil) {
if debugRoundTrip {
req.logf("closech recv: %T %#v", pc.closed, pc.closed)
}
return nil, pc.mapRoundTripError(req, startBytesWritten, pc.closed)
select {
case re := <-resc:
// The pconn closing raced with the response to the request,
// probably after the server wrote a response and immediately
// closed the connection. Use the response.
return handleResponse(re)
default:
}
if debugRoundTrip {
req.logf("closech recv: %T %#v", pc.closed, pc.closed)
}
return nil, pc.mapRoundTripError(req, startBytesWritten, pc.closed)
case <-respHeaderTimer:
if debugRoundTrip {
req.logf("timeout waiting for response headers.")
@ -2770,23 +2774,17 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err
pc.close(errTimeout)
return nil, errTimeout
case re := <-resc:
if (re.res == nil) == (re.err == nil) {
panic(fmt.Sprintf("internal error: exactly one of res or err should be set; nil=%v", re.res == nil))
}
if debugRoundTrip {
req.logf("resc recv: %p, %T/%#v", re.res, re.err, re.err)
}
if re.err != nil {
return nil, pc.mapRoundTripError(req, startBytesWritten, re.err)
}
return re.res, nil
case <-cancelChan:
canceled = pc.t.cancelRequest(req.cancelKey, errRequestCanceled)
cancelChan = nil
return handleResponse(re)
case <-ctxDoneChan:
canceled = pc.t.cancelRequest(req.cancelKey, req.Context().Err())
cancelChan = nil
ctxDoneChan = nil
select {
case re := <-resc:
// readLoop is responsible for canceling req.ctx after
// it reads the response body. Check for a response racing
// the context close, and use the response if available.
return handleResponse(re)
default:
}
pc.cancelRequest(context.Cause(req.ctx))
}
}
}

View File

@ -8,6 +8,7 @@ package http
import (
"bytes"
"context"
"crypto/tls"
"errors"
"io"
@ -36,7 +37,8 @@ func TestTransportPersistConnReadLoopEOF(t *testing.T) {
tr := new(Transport)
req, _ := NewRequest("GET", "http://"+ln.Addr().String(), nil)
req = req.WithT(t)
treq := &transportRequest{Request: req}
ctx, cancel := context.WithCancelCause(context.Background())
treq := &transportRequest{Request: req, ctx: ctx, cancel: cancel}
cm := connectMethod{targetScheme: "http", targetAddr: ln.Addr().String()}
pc, err := tr.getConn(treq, cm)
if err != nil {