configure_transport.go 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. // Copyright 2015 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. // +build go1.6
  5. package http2
  6. import (
  7. "crypto/tls"
  8. "fmt"
  9. "net/http"
  10. )
  11. func configureTransport(t1 *http.Transport) (*Transport, error) {
  12. connPool := new(clientConnPool)
  13. t2 := &Transport{
  14. ConnPool: noDialClientConnPool{connPool},
  15. t1: t1,
  16. }
  17. connPool.t = t2
  18. if err := registerHTTPSProtocol(t1, noDialH2RoundTripper{t2}); err != nil {
  19. return nil, err
  20. }
  21. if t1.TLSClientConfig == nil {
  22. t1.TLSClientConfig = new(tls.Config)
  23. }
  24. if !strSliceContains(t1.TLSClientConfig.NextProtos, "h2") {
  25. t1.TLSClientConfig.NextProtos = append([]string{"h2"}, t1.TLSClientConfig.NextProtos...)
  26. }
  27. if !strSliceContains(t1.TLSClientConfig.NextProtos, "http/1.1") {
  28. t1.TLSClientConfig.NextProtos = append(t1.TLSClientConfig.NextProtos, "http/1.1")
  29. }
  30. upgradeFn := func(authority string, c *tls.Conn) http.RoundTripper {
  31. addr := authorityAddr(authority)
  32. if used, err := connPool.addConnIfNeeded(addr, t2, c); err != nil {
  33. go c.Close()
  34. return erringRoundTripper{err}
  35. } else if !used {
  36. // Turns out we don't need this c.
  37. // For example, two goroutines made requests to the same host
  38. // at the same time, both kicking off TCP dials. (since protocol
  39. // was unknown)
  40. go c.Close()
  41. }
  42. return t2
  43. }
  44. if m := t1.TLSNextProto; len(m) == 0 {
  45. t1.TLSNextProto = map[string]func(string, *tls.Conn) http.RoundTripper{
  46. "h2": upgradeFn,
  47. }
  48. } else {
  49. m["h2"] = upgradeFn
  50. }
  51. return t2, nil
  52. }
  53. // registerHTTPSProtocol calls Transport.RegisterProtocol but
  54. // convering panics into errors.
  55. func registerHTTPSProtocol(t *http.Transport, rt http.RoundTripper) (err error) {
  56. defer func() {
  57. if e := recover(); e != nil {
  58. err = fmt.Errorf("%v", e)
  59. }
  60. }()
  61. t.RegisterProtocol("https", rt)
  62. return nil
  63. }
  64. // noDialClientConnPool is an implementation of http2.ClientConnPool
  65. // which never dials. We let the HTTP/1.1 client dial and use its TLS
  66. // connection instead.
  67. type noDialClientConnPool struct{ *clientConnPool }
  68. func (p noDialClientConnPool) GetClientConn(req *http.Request, addr string) (*ClientConn, error) {
  69. return p.getClientConn(req, addr, noDialOnMiss)
  70. }
  71. // noDialH2RoundTripper is a RoundTripper which only tries to complete the request
  72. // if there's already has a cached connection to the host.
  73. type noDialH2RoundTripper struct{ t *Transport }
  74. func (rt noDialH2RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
  75. res, err := rt.t.RoundTrip(req)
  76. if err == ErrNoCachedConn {
  77. return nil, http.ErrSkipAltProtocol
  78. }
  79. return res, err
  80. }