socket_linux.go 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. package netlink
  2. import (
  3. "errors"
  4. "fmt"
  5. "net"
  6. "syscall"
  7. "github.com/vishvananda/netlink/nl"
  8. )
  9. const (
  10. sizeofSocketID = 0x30
  11. sizeofSocketRequest = sizeofSocketID + 0x8
  12. sizeofSocket = sizeofSocketID + 0x18
  13. )
  14. type socketRequest struct {
  15. Family uint8
  16. Protocol uint8
  17. Ext uint8
  18. pad uint8
  19. States uint32
  20. ID SocketID
  21. }
  22. type writeBuffer struct {
  23. Bytes []byte
  24. pos int
  25. }
  26. func (b *writeBuffer) Write(c byte) {
  27. b.Bytes[b.pos] = c
  28. b.pos++
  29. }
  30. func (b *writeBuffer) Next(n int) []byte {
  31. s := b.Bytes[b.pos : b.pos+n]
  32. b.pos += n
  33. return s
  34. }
  35. func (r *socketRequest) Serialize() []byte {
  36. b := writeBuffer{Bytes: make([]byte, sizeofSocketRequest)}
  37. b.Write(r.Family)
  38. b.Write(r.Protocol)
  39. b.Write(r.Ext)
  40. b.Write(r.pad)
  41. native.PutUint32(b.Next(4), r.States)
  42. networkOrder.PutUint16(b.Next(2), r.ID.SourcePort)
  43. networkOrder.PutUint16(b.Next(2), r.ID.DestinationPort)
  44. copy(b.Next(4), r.ID.Source.To4())
  45. b.Next(12)
  46. copy(b.Next(4), r.ID.Destination.To4())
  47. b.Next(12)
  48. native.PutUint32(b.Next(4), r.ID.Interface)
  49. native.PutUint32(b.Next(4), r.ID.Cookie[0])
  50. native.PutUint32(b.Next(4), r.ID.Cookie[1])
  51. return b.Bytes
  52. }
  53. func (r *socketRequest) Len() int { return sizeofSocketRequest }
  54. type readBuffer struct {
  55. Bytes []byte
  56. pos int
  57. }
  58. func (b *readBuffer) Read() byte {
  59. c := b.Bytes[b.pos]
  60. b.pos++
  61. return c
  62. }
  63. func (b *readBuffer) Next(n int) []byte {
  64. s := b.Bytes[b.pos : b.pos+n]
  65. b.pos += n
  66. return s
  67. }
  68. func (s *Socket) deserialize(b []byte) error {
  69. if len(b) < sizeofSocket {
  70. return fmt.Errorf("socket data short read (%d); want %d", len(b), sizeofSocket)
  71. }
  72. rb := readBuffer{Bytes: b}
  73. s.Family = rb.Read()
  74. s.State = rb.Read()
  75. s.Timer = rb.Read()
  76. s.Retrans = rb.Read()
  77. s.ID.SourcePort = networkOrder.Uint16(rb.Next(2))
  78. s.ID.DestinationPort = networkOrder.Uint16(rb.Next(2))
  79. s.ID.Source = net.IPv4(rb.Read(), rb.Read(), rb.Read(), rb.Read())
  80. rb.Next(12)
  81. s.ID.Destination = net.IPv4(rb.Read(), rb.Read(), rb.Read(), rb.Read())
  82. rb.Next(12)
  83. s.ID.Interface = native.Uint32(rb.Next(4))
  84. s.ID.Cookie[0] = native.Uint32(rb.Next(4))
  85. s.ID.Cookie[1] = native.Uint32(rb.Next(4))
  86. s.Expires = native.Uint32(rb.Next(4))
  87. s.RQueue = native.Uint32(rb.Next(4))
  88. s.WQueue = native.Uint32(rb.Next(4))
  89. s.UID = native.Uint32(rb.Next(4))
  90. s.INode = native.Uint32(rb.Next(4))
  91. return nil
  92. }
  93. // SocketGet returns the Socket identified by its local and remote addresses.
  94. func SocketGet(local, remote net.Addr) (*Socket, error) {
  95. localTCP, ok := local.(*net.TCPAddr)
  96. if !ok {
  97. return nil, ErrNotImplemented
  98. }
  99. remoteTCP, ok := remote.(*net.TCPAddr)
  100. if !ok {
  101. return nil, ErrNotImplemented
  102. }
  103. localIP := localTCP.IP.To4()
  104. if localIP == nil {
  105. return nil, ErrNotImplemented
  106. }
  107. remoteIP := remoteTCP.IP.To4()
  108. if remoteIP == nil {
  109. return nil, ErrNotImplemented
  110. }
  111. s, err := nl.Subscribe(syscall.NETLINK_INET_DIAG)
  112. if err != nil {
  113. return nil, err
  114. }
  115. defer s.Close()
  116. req := nl.NewNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, 0)
  117. req.AddData(&socketRequest{
  118. Family: syscall.AF_INET,
  119. Protocol: syscall.IPPROTO_TCP,
  120. ID: SocketID{
  121. SourcePort: uint16(localTCP.Port),
  122. DestinationPort: uint16(remoteTCP.Port),
  123. Source: localIP,
  124. Destination: remoteIP,
  125. Cookie: [2]uint32{nl.TCPDIAG_NOCOOKIE, nl.TCPDIAG_NOCOOKIE},
  126. },
  127. })
  128. s.Send(req)
  129. msgs, err := s.Receive()
  130. if err != nil {
  131. return nil, err
  132. }
  133. if len(msgs) == 0 {
  134. return nil, errors.New("no message nor error from netlink")
  135. }
  136. if len(msgs) > 2 {
  137. return nil, fmt.Errorf("multiple (%d) matching sockets", len(msgs))
  138. }
  139. sock := &Socket{}
  140. if err := sock.deserialize(msgs[0].Data); err != nil {
  141. return nil, err
  142. }
  143. return sock, nil
  144. }