123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990 |
- // Copyright 2015 The Go Authors. All rights reserved.
- // Use of this source code is governed by a BSD-style
- // license that can be found in the LICENSE file.
- // +build go1.6
- package http2
- import (
- "crypto/tls"
- "fmt"
- "net/http"
- )
- func configureTransport(t1 *http.Transport) (*Transport, error) {
- connPool := new(clientConnPool)
- t2 := &Transport{
- ConnPool: noDialClientConnPool{connPool},
- t1: t1,
- }
- connPool.t = t2
- if err := registerHTTPSProtocol(t1, noDialH2RoundTripper{t2}); err != nil {
- return nil, err
- }
- if t1.TLSClientConfig == nil {
- t1.TLSClientConfig = new(tls.Config)
- }
- if !strSliceContains(t1.TLSClientConfig.NextProtos, "h2") {
- t1.TLSClientConfig.NextProtos = append([]string{"h2"}, t1.TLSClientConfig.NextProtos...)
- }
- if !strSliceContains(t1.TLSClientConfig.NextProtos, "http/1.1") {
- t1.TLSClientConfig.NextProtos = append(t1.TLSClientConfig.NextProtos, "http/1.1")
- }
- upgradeFn := func(authority string, c *tls.Conn) http.RoundTripper {
- addr := authorityAddr(authority)
- if used, err := connPool.addConnIfNeeded(addr, t2, c); err != nil {
- go c.Close()
- return erringRoundTripper{err}
- } else if !used {
- // Turns out we don't need this c.
- // For example, two goroutines made requests to the same host
- // at the same time, both kicking off TCP dials. (since protocol
- // was unknown)
- go c.Close()
- }
- return t2
- }
- if m := t1.TLSNextProto; len(m) == 0 {
- t1.TLSNextProto = map[string]func(string, *tls.Conn) http.RoundTripper{
- "h2": upgradeFn,
- }
- } else {
- m["h2"] = upgradeFn
- }
- return t2, nil
- }
// registerHTTPSProtocol calls t.RegisterProtocol("https", rt), converting
// any panic raised by that call into an error (formatted with %v). It
// returns nil when the registration completes without panicking.
func registerHTTPSProtocol(t *http.Transport, rt http.RoundTripper) (err error) {
	defer func() {
		recovered := recover()
		if recovered == nil {
			return
		}
		// The named result lets the deferred function replace the return value.
		err = fmt.Errorf("%v", recovered)
	}()

	t.RegisterProtocol("https", rt)
	return nil
}
- // noDialClientConnPool is an implementation of http2.ClientConnPool
- // which never dials. We let the HTTP/1.1 client dial and use its TLS
- // connection instead.
- type noDialClientConnPool struct{ *clientConnPool }
- func (p noDialClientConnPool) GetClientConn(req *http.Request, addr string) (*ClientConn, error) {
- return p.getClientConn(req, addr, noDialOnMiss)
- }
- // noDialH2RoundTripper is a RoundTripper which only tries to complete the request
- // if there's already has a cached connection to the host.
- type noDialH2RoundTripper struct{ t *Transport }
- func (rt noDialH2RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
- res, err := rt.t.RoundTrip(req)
- if err == ErrNoCachedConn {
- return nil, http.ErrSkipAltProtocol
- }
- return res, err
- }
|