nl_linux.go 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425
  1. // Package nl has low level primitives for making Netlink calls.
  2. package nl
  3. import (
  4. "bytes"
  5. "encoding/binary"
  6. "fmt"
  7. "net"
  8. "sync/atomic"
  9. "syscall"
  10. "unsafe"
  11. )
  12. const (
  13. // Family type definitions
  14. FAMILY_ALL = syscall.AF_UNSPEC
  15. FAMILY_V4 = syscall.AF_INET
  16. FAMILY_V6 = syscall.AF_INET6
  17. )
  18. var nextSeqNr uint32
  19. // GetIPFamily returns the family type of a net.IP.
  20. func GetIPFamily(ip net.IP) int {
  21. if len(ip) <= net.IPv4len {
  22. return FAMILY_V4
  23. }
  24. if ip.To4() != nil {
  25. return FAMILY_V4
  26. }
  27. return FAMILY_V6
  28. }
  29. var nativeEndian binary.ByteOrder
  30. // Get native endianness for the system
  31. func NativeEndian() binary.ByteOrder {
  32. if nativeEndian == nil {
  33. var x uint32 = 0x01020304
  34. if *(*byte)(unsafe.Pointer(&x)) == 0x01 {
  35. nativeEndian = binary.BigEndian
  36. } else {
  37. nativeEndian = binary.LittleEndian
  38. }
  39. }
  40. return nativeEndian
  41. }
  42. // Byte swap a 16 bit value if we aren't big endian
  43. func Swap16(i uint16) uint16 {
  44. if NativeEndian() == binary.BigEndian {
  45. return i
  46. }
  47. return (i&0xff00)>>8 | (i&0xff)<<8
  48. }
  49. // Byte swap a 32 bit value if aren't big endian
  50. func Swap32(i uint32) uint32 {
  51. if NativeEndian() == binary.BigEndian {
  52. return i
  53. }
  54. return (i&0xff000000)>>24 | (i&0xff0000)>>8 | (i&0xff00)<<8 | (i&0xff)<<24
  55. }
  56. type NetlinkRequestData interface {
  57. Len() int
  58. Serialize() []byte
  59. }
  60. // IfInfomsg is related to links, but it is used for list requests as well
  61. type IfInfomsg struct {
  62. syscall.IfInfomsg
  63. }
  64. // Create an IfInfomsg with family specified
  65. func NewIfInfomsg(family int) *IfInfomsg {
  66. return &IfInfomsg{
  67. IfInfomsg: syscall.IfInfomsg{
  68. Family: uint8(family),
  69. },
  70. }
  71. }
  72. func DeserializeIfInfomsg(b []byte) *IfInfomsg {
  73. return (*IfInfomsg)(unsafe.Pointer(&b[0:syscall.SizeofIfInfomsg][0]))
  74. }
  75. func (msg *IfInfomsg) Serialize() []byte {
  76. return (*(*[syscall.SizeofIfInfomsg]byte)(unsafe.Pointer(msg)))[:]
  77. }
  78. func (msg *IfInfomsg) Len() int {
  79. return syscall.SizeofIfInfomsg
  80. }
  81. func rtaAlignOf(attrlen int) int {
  82. return (attrlen + syscall.RTA_ALIGNTO - 1) & ^(syscall.RTA_ALIGNTO - 1)
  83. }
  84. func NewIfInfomsgChild(parent *RtAttr, family int) *IfInfomsg {
  85. msg := NewIfInfomsg(family)
  86. parent.children = append(parent.children, msg)
  87. return msg
  88. }
  89. // Extend RtAttr to handle data and children
  90. type RtAttr struct {
  91. syscall.RtAttr
  92. Data []byte
  93. children []NetlinkRequestData
  94. }
  95. // Create a new Extended RtAttr object
  96. func NewRtAttr(attrType int, data []byte) *RtAttr {
  97. return &RtAttr{
  98. RtAttr: syscall.RtAttr{
  99. Type: uint16(attrType),
  100. },
  101. children: []NetlinkRequestData{},
  102. Data: data,
  103. }
  104. }
  105. // Create a new RtAttr obj anc add it as a child of an existing object
  106. func NewRtAttrChild(parent *RtAttr, attrType int, data []byte) *RtAttr {
  107. attr := NewRtAttr(attrType, data)
  108. parent.children = append(parent.children, attr)
  109. return attr
  110. }
  111. func (a *RtAttr) Len() int {
  112. if len(a.children) == 0 {
  113. return (syscall.SizeofRtAttr + len(a.Data))
  114. }
  115. l := 0
  116. for _, child := range a.children {
  117. l += rtaAlignOf(child.Len())
  118. }
  119. l += syscall.SizeofRtAttr
  120. return rtaAlignOf(l + len(a.Data))
  121. }
  122. // Serialize the RtAttr into a byte array
  123. // This can't just unsafe.cast because it must iterate through children.
  124. func (a *RtAttr) Serialize() []byte {
  125. native := NativeEndian()
  126. length := a.Len()
  127. buf := make([]byte, rtaAlignOf(length))
  128. next := 4
  129. if a.Data != nil {
  130. copy(buf[next:], a.Data)
  131. next += rtaAlignOf(len(a.Data))
  132. }
  133. if len(a.children) > 0 {
  134. for _, child := range a.children {
  135. childBuf := child.Serialize()
  136. copy(buf[next:], childBuf)
  137. next += rtaAlignOf(len(childBuf))
  138. }
  139. }
  140. if l := uint16(length); l != 0 {
  141. native.PutUint16(buf[0:2], l)
  142. }
  143. native.PutUint16(buf[2:4], a.Type)
  144. return buf
  145. }
  146. type NetlinkRequest struct {
  147. syscall.NlMsghdr
  148. Data []NetlinkRequestData
  149. }
  150. // Serialize the Netlink Request into a byte array
  151. func (req *NetlinkRequest) Serialize() []byte {
  152. length := syscall.SizeofNlMsghdr
  153. dataBytes := make([][]byte, len(req.Data))
  154. for i, data := range req.Data {
  155. dataBytes[i] = data.Serialize()
  156. length = length + len(dataBytes[i])
  157. }
  158. req.Len = uint32(length)
  159. b := make([]byte, length)
  160. hdr := (*(*[syscall.SizeofNlMsghdr]byte)(unsafe.Pointer(req)))[:]
  161. next := syscall.SizeofNlMsghdr
  162. copy(b[0:next], hdr)
  163. for _, data := range dataBytes {
  164. for _, dataByte := range data {
  165. b[next] = dataByte
  166. next = next + 1
  167. }
  168. }
  169. return b
  170. }
  171. func (req *NetlinkRequest) AddData(data NetlinkRequestData) {
  172. if data != nil {
  173. req.Data = append(req.Data, data)
  174. }
  175. }
  176. // Execute the request against a the given sockType.
  177. // Returns a list of netlink messages in seriaized format, optionally filtered
  178. // by resType.
  179. func (req *NetlinkRequest) Execute(sockType int, resType uint16) ([][]byte, error) {
  180. s, err := getNetlinkSocket(sockType)
  181. if err != nil {
  182. return nil, err
  183. }
  184. defer s.Close()
  185. if err := s.Send(req); err != nil {
  186. return nil, err
  187. }
  188. pid, err := s.GetPid()
  189. if err != nil {
  190. return nil, err
  191. }
  192. var res [][]byte
  193. done:
  194. for {
  195. msgs, err := s.Receive()
  196. if err != nil {
  197. return nil, err
  198. }
  199. for _, m := range msgs {
  200. if m.Header.Seq != req.Seq {
  201. return nil, fmt.Errorf("Wrong Seq nr %d, expected 1", m.Header.Seq)
  202. }
  203. if m.Header.Pid != pid {
  204. return nil, fmt.Errorf("Wrong pid %d, expected %d", m.Header.Pid, pid)
  205. }
  206. if m.Header.Type == syscall.NLMSG_DONE {
  207. break done
  208. }
  209. if m.Header.Type == syscall.NLMSG_ERROR {
  210. native := NativeEndian()
  211. error := int32(native.Uint32(m.Data[0:4]))
  212. if error == 0 {
  213. break done
  214. }
  215. return nil, syscall.Errno(-error)
  216. }
  217. if resType != 0 && m.Header.Type != resType {
  218. continue
  219. }
  220. res = append(res, m.Data)
  221. if m.Header.Flags&syscall.NLM_F_MULTI == 0 {
  222. break done
  223. }
  224. }
  225. }
  226. return res, nil
  227. }
  228. // Create a new netlink request from proto and flags
  229. // Note the Len value will be inaccurate once data is added until
  230. // the message is serialized
  231. func NewNetlinkRequest(proto, flags int) *NetlinkRequest {
  232. return &NetlinkRequest{
  233. NlMsghdr: syscall.NlMsghdr{
  234. Len: uint32(syscall.SizeofNlMsghdr),
  235. Type: uint16(proto),
  236. Flags: syscall.NLM_F_REQUEST | uint16(flags),
  237. Seq: atomic.AddUint32(&nextSeqNr, 1),
  238. },
  239. }
  240. }
  241. type NetlinkSocket struct {
  242. fd int
  243. lsa syscall.SockaddrNetlink
  244. }
  245. func getNetlinkSocket(protocol int) (*NetlinkSocket, error) {
  246. fd, err := syscall.Socket(syscall.AF_NETLINK, syscall.SOCK_RAW, protocol)
  247. if err != nil {
  248. return nil, err
  249. }
  250. s := &NetlinkSocket{
  251. fd: fd,
  252. }
  253. s.lsa.Family = syscall.AF_NETLINK
  254. if err := syscall.Bind(fd, &s.lsa); err != nil {
  255. syscall.Close(fd)
  256. return nil, err
  257. }
  258. return s, nil
  259. }
  260. // Create a netlink socket with a given protocol (e.g. NETLINK_ROUTE)
  261. // and subscribe it to multicast groups passed in variable argument list.
  262. // Returns the netlink socket on which Receive() method can be called
  263. // to retrieve the messages from the kernel.
  264. func Subscribe(protocol int, groups ...uint) (*NetlinkSocket, error) {
  265. fd, err := syscall.Socket(syscall.AF_NETLINK, syscall.SOCK_RAW, protocol)
  266. if err != nil {
  267. return nil, err
  268. }
  269. s := &NetlinkSocket{
  270. fd: fd,
  271. }
  272. s.lsa.Family = syscall.AF_NETLINK
  273. for _, g := range groups {
  274. s.lsa.Groups |= (1 << (g - 1))
  275. }
  276. if err := syscall.Bind(fd, &s.lsa); err != nil {
  277. syscall.Close(fd)
  278. return nil, err
  279. }
  280. return s, nil
  281. }
  282. func (s *NetlinkSocket) Close() {
  283. syscall.Close(s.fd)
  284. }
  285. func (s *NetlinkSocket) GetFd() int {
  286. return s.fd
  287. }
  288. func (s *NetlinkSocket) Send(request *NetlinkRequest) error {
  289. if err := syscall.Sendto(s.fd, request.Serialize(), 0, &s.lsa); err != nil {
  290. return err
  291. }
  292. return nil
  293. }
  294. func (s *NetlinkSocket) Receive() ([]syscall.NetlinkMessage, error) {
  295. rb := make([]byte, syscall.Getpagesize())
  296. nr, _, err := syscall.Recvfrom(s.fd, rb, 0)
  297. if err != nil {
  298. return nil, err
  299. }
  300. if nr < syscall.NLMSG_HDRLEN {
  301. return nil, fmt.Errorf("Got short response from netlink")
  302. }
  303. rb = rb[:nr]
  304. return syscall.ParseNetlinkMessage(rb)
  305. }
  306. func (s *NetlinkSocket) GetPid() (uint32, error) {
  307. lsa, err := syscall.Getsockname(s.fd)
  308. if err != nil {
  309. return 0, err
  310. }
  311. switch v := lsa.(type) {
  312. case *syscall.SockaddrNetlink:
  313. return v.Pid, nil
  314. }
  315. return 0, fmt.Errorf("Wrong socket type")
  316. }
  317. func ZeroTerminated(s string) []byte {
  318. bytes := make([]byte, len(s)+1)
  319. for i := 0; i < len(s); i++ {
  320. bytes[i] = s[i]
  321. }
  322. bytes[len(s)] = 0
  323. return bytes
  324. }
  325. func NonZeroTerminated(s string) []byte {
  326. bytes := make([]byte, len(s))
  327. for i := 0; i < len(s); i++ {
  328. bytes[i] = s[i]
  329. }
  330. return bytes
  331. }
  332. func BytesToString(b []byte) string {
  333. n := bytes.Index(b, []byte{0})
  334. return string(b[:n])
  335. }
  336. func Uint8Attr(v uint8) []byte {
  337. return []byte{byte(v)}
  338. }
  339. func Uint16Attr(v uint16) []byte {
  340. native := NativeEndian()
  341. bytes := make([]byte, 2)
  342. native.PutUint16(bytes, v)
  343. return bytes
  344. }
  345. func Uint32Attr(v uint32) []byte {
  346. native := NativeEndian()
  347. bytes := make([]byte, 4)
  348. native.PutUint32(bytes, v)
  349. return bytes
  350. }
  351. func ParseRouteAttr(b []byte) ([]syscall.NetlinkRouteAttr, error) {
  352. var attrs []syscall.NetlinkRouteAttr
  353. for len(b) >= syscall.SizeofRtAttr {
  354. a, vbuf, alen, err := netlinkRouteAttrAndValue(b)
  355. if err != nil {
  356. return nil, err
  357. }
  358. ra := syscall.NetlinkRouteAttr{Attr: *a, Value: vbuf[:int(a.Len)-syscall.SizeofRtAttr]}
  359. attrs = append(attrs, ra)
  360. b = b[alen:]
  361. }
  362. return attrs, nil
  363. }
  364. func netlinkRouteAttrAndValue(b []byte) (*syscall.RtAttr, []byte, int, error) {
  365. a := (*syscall.RtAttr)(unsafe.Pointer(&b[0]))
  366. if int(a.Len) < syscall.SizeofRtAttr || int(a.Len) > len(b) {
  367. return nil, nil, 0, syscall.EINVAL
  368. }
  369. return a, b[syscall.SizeofRtAttr:], rtaAlignOf(int(a.Len)), nil
  370. }