route_linux.go 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328
  1. package netlink
  2. import (
  3. "fmt"
  4. "net"
  5. "syscall"
  6. "github.com/vishvananda/netlink/nl"
  7. )
  8. // RtAttr is shared so it is in netlink_linux.go
  9. const (
  10. RT_FILTER_PROTOCOL uint64 = 1 << (1 + iota)
  11. RT_FILTER_SCOPE
  12. RT_FILTER_TYPE
  13. RT_FILTER_TOS
  14. RT_FILTER_IIF
  15. RT_FILTER_OIF
  16. RT_FILTER_DST
  17. RT_FILTER_SRC
  18. RT_FILTER_GW
  19. RT_FILTER_TABLE
  20. )
  21. // RouteAdd will add a route to the system.
  22. // Equivalent to: `ip route add $route`
  23. func RouteAdd(route *Route) error {
  24. req := nl.NewNetlinkRequest(syscall.RTM_NEWROUTE, syscall.NLM_F_CREATE|syscall.NLM_F_EXCL|syscall.NLM_F_ACK)
  25. return routeHandle(route, req, nl.NewRtMsg())
  26. }
  27. // RouteDel will delete a route from the system.
  28. // Equivalent to: `ip route del $route`
  29. func RouteDel(route *Route) error {
  30. req := nl.NewNetlinkRequest(syscall.RTM_DELROUTE, syscall.NLM_F_ACK)
  31. return routeHandle(route, req, nl.NewRtDelMsg())
  32. }
  33. func routeHandle(route *Route, req *nl.NetlinkRequest, msg *nl.RtMsg) error {
  34. if (route.Dst == nil || route.Dst.IP == nil) && route.Src == nil && route.Gw == nil {
  35. return fmt.Errorf("one of Dst.IP, Src, or Gw must not be nil")
  36. }
  37. family := -1
  38. var rtAttrs []*nl.RtAttr
  39. if route.Dst != nil && route.Dst.IP != nil {
  40. dstLen, _ := route.Dst.Mask.Size()
  41. msg.Dst_len = uint8(dstLen)
  42. dstFamily := nl.GetIPFamily(route.Dst.IP)
  43. family = dstFamily
  44. var dstData []byte
  45. if dstFamily == FAMILY_V4 {
  46. dstData = route.Dst.IP.To4()
  47. } else {
  48. dstData = route.Dst.IP.To16()
  49. }
  50. rtAttrs = append(rtAttrs, nl.NewRtAttr(syscall.RTA_DST, dstData))
  51. }
  52. if route.Src != nil {
  53. srcFamily := nl.GetIPFamily(route.Src)
  54. if family != -1 && family != srcFamily {
  55. return fmt.Errorf("source and destination ip are not the same IP family")
  56. }
  57. family = srcFamily
  58. var srcData []byte
  59. if srcFamily == FAMILY_V4 {
  60. srcData = route.Src.To4()
  61. } else {
  62. srcData = route.Src.To16()
  63. }
  64. // The commonly used src ip for routes is actually PREFSRC
  65. rtAttrs = append(rtAttrs, nl.NewRtAttr(syscall.RTA_PREFSRC, srcData))
  66. }
  67. if route.Gw != nil {
  68. gwFamily := nl.GetIPFamily(route.Gw)
  69. if family != -1 && family != gwFamily {
  70. return fmt.Errorf("gateway, source, and destination ip are not the same IP family")
  71. }
  72. family = gwFamily
  73. var gwData []byte
  74. if gwFamily == FAMILY_V4 {
  75. gwData = route.Gw.To4()
  76. } else {
  77. gwData = route.Gw.To16()
  78. }
  79. rtAttrs = append(rtAttrs, nl.NewRtAttr(syscall.RTA_GATEWAY, gwData))
  80. }
  81. if route.Table > 0 {
  82. if route.Table >= 256 {
  83. msg.Table = syscall.RT_TABLE_UNSPEC
  84. b := make([]byte, 4)
  85. native.PutUint32(b, uint32(route.Table))
  86. rtAttrs = append(rtAttrs, nl.NewRtAttr(syscall.RTA_TABLE, b))
  87. } else {
  88. msg.Table = uint8(route.Table)
  89. }
  90. }
  91. if route.Priority > 0 {
  92. b := make([]byte, 4)
  93. native.PutUint32(b, uint32(route.Priority))
  94. rtAttrs = append(rtAttrs, nl.NewRtAttr(syscall.RTA_PRIORITY, b))
  95. }
  96. if route.Tos > 0 {
  97. msg.Tos = uint8(route.Tos)
  98. }
  99. if route.Protocol > 0 {
  100. msg.Protocol = uint8(route.Protocol)
  101. }
  102. if route.Type > 0 {
  103. msg.Type = uint8(route.Type)
  104. }
  105. msg.Scope = uint8(route.Scope)
  106. msg.Family = uint8(family)
  107. req.AddData(msg)
  108. for _, attr := range rtAttrs {
  109. req.AddData(attr)
  110. }
  111. var (
  112. b = make([]byte, 4)
  113. native = nl.NativeEndian()
  114. )
  115. native.PutUint32(b, uint32(route.LinkIndex))
  116. req.AddData(nl.NewRtAttr(syscall.RTA_OIF, b))
  117. _, err := req.Execute(syscall.NETLINK_ROUTE, 0)
  118. return err
  119. }
  120. // RouteList gets a list of routes in the system.
  121. // Equivalent to: `ip route show`.
  122. // The list can be filtered by link and ip family.
  123. func RouteList(link Link, family int) ([]Route, error) {
  124. var routeFilter *Route
  125. if link != nil {
  126. routeFilter = &Route{
  127. LinkIndex: link.Attrs().Index,
  128. }
  129. }
  130. return RouteListFiltered(family, routeFilter, RT_FILTER_OIF)
  131. }
  132. // RouteListFiltered gets a list of routes in the system filtered with specified rules.
  133. // All rules must be defined in RouteFilter struct
  134. func RouteListFiltered(family int, filter *Route, filterMask uint64) ([]Route, error) {
  135. req := nl.NewNetlinkRequest(syscall.RTM_GETROUTE, syscall.NLM_F_DUMP)
  136. infmsg := nl.NewIfInfomsg(family)
  137. req.AddData(infmsg)
  138. msgs, err := req.Execute(syscall.NETLINK_ROUTE, syscall.RTM_NEWROUTE)
  139. if err != nil {
  140. return nil, err
  141. }
  142. var res []Route
  143. for _, m := range msgs {
  144. msg := nl.DeserializeRtMsg(m)
  145. if msg.Flags&syscall.RTM_F_CLONED != 0 {
  146. // Ignore cloned routes
  147. continue
  148. }
  149. if msg.Table != syscall.RT_TABLE_MAIN {
  150. if filter == nil || filter != nil && filterMask&RT_FILTER_TABLE == 0 {
  151. // Ignore non-main tables
  152. continue
  153. }
  154. }
  155. route, err := deserializeRoute(m)
  156. if err != nil {
  157. return nil, err
  158. }
  159. if filter != nil {
  160. switch {
  161. case filterMask&RT_FILTER_TABLE != 0 && route.Table != filter.Table:
  162. continue
  163. case filterMask&RT_FILTER_PROTOCOL != 0 && route.Protocol != filter.Protocol:
  164. continue
  165. case filterMask&RT_FILTER_SCOPE != 0 && route.Scope != filter.Scope:
  166. continue
  167. case filterMask&RT_FILTER_TYPE != 0 && route.Type != filter.Type:
  168. continue
  169. case filterMask&RT_FILTER_TOS != 0 && route.Tos != filter.Tos:
  170. continue
  171. case filterMask&RT_FILTER_OIF != 0 && route.LinkIndex != filter.LinkIndex:
  172. continue
  173. case filterMask&RT_FILTER_IIF != 0 && route.ILinkIndex != filter.ILinkIndex:
  174. continue
  175. case filterMask&RT_FILTER_GW != 0 && !route.Gw.Equal(filter.Gw):
  176. continue
  177. case filterMask&RT_FILTER_SRC != 0 && !route.Src.Equal(filter.Src):
  178. continue
  179. case filterMask&RT_FILTER_DST != 0 && filter.Dst != nil:
  180. if route.Dst == nil {
  181. continue
  182. }
  183. aMaskLen, aMaskBits := route.Dst.Mask.Size()
  184. bMaskLen, bMaskBits := filter.Dst.Mask.Size()
  185. if !(route.Dst.IP.Equal(filter.Dst.IP) && aMaskLen == bMaskLen && aMaskBits == bMaskBits) {
  186. continue
  187. }
  188. }
  189. }
  190. res = append(res, route)
  191. }
  192. return res, nil
  193. }
  194. // deserializeRoute decodes a binary netlink message into a Route struct
  195. func deserializeRoute(m []byte) (Route, error) {
  196. msg := nl.DeserializeRtMsg(m)
  197. attrs, err := nl.ParseRouteAttr(m[msg.Len():])
  198. if err != nil {
  199. return Route{}, err
  200. }
  201. route := Route{
  202. Scope: Scope(msg.Scope),
  203. Protocol: int(msg.Protocol),
  204. Table: int(msg.Table),
  205. Type: int(msg.Type),
  206. Tos: int(msg.Tos),
  207. Flags: int(msg.Flags),
  208. }
  209. native := nl.NativeEndian()
  210. for _, attr := range attrs {
  211. switch attr.Attr.Type {
  212. case syscall.RTA_GATEWAY:
  213. route.Gw = net.IP(attr.Value)
  214. case syscall.RTA_PREFSRC:
  215. route.Src = net.IP(attr.Value)
  216. case syscall.RTA_DST:
  217. route.Dst = &net.IPNet{
  218. IP: attr.Value,
  219. Mask: net.CIDRMask(int(msg.Dst_len), 8*len(attr.Value)),
  220. }
  221. case syscall.RTA_OIF:
  222. route.LinkIndex = int(native.Uint32(attr.Value[0:4]))
  223. case syscall.RTA_IIF:
  224. route.ILinkIndex = int(native.Uint32(attr.Value[0:4]))
  225. case syscall.RTA_PRIORITY:
  226. route.Priority = int(native.Uint32(attr.Value[0:4]))
  227. case syscall.RTA_TABLE:
  228. route.Table = int(native.Uint32(attr.Value[0:4]))
  229. }
  230. }
  231. return route, nil
  232. }
  233. // RouteGet gets a route to a specific destination from the host system.
  234. // Equivalent to: 'ip route get'.
  235. func RouteGet(destination net.IP) ([]Route, error) {
  236. req := nl.NewNetlinkRequest(syscall.RTM_GETROUTE, syscall.NLM_F_REQUEST)
  237. family := nl.GetIPFamily(destination)
  238. var destinationData []byte
  239. var bitlen uint8
  240. if family == FAMILY_V4 {
  241. destinationData = destination.To4()
  242. bitlen = 32
  243. } else {
  244. destinationData = destination.To16()
  245. bitlen = 128
  246. }
  247. msg := &nl.RtMsg{}
  248. msg.Family = uint8(family)
  249. msg.Dst_len = bitlen
  250. req.AddData(msg)
  251. rtaDst := nl.NewRtAttr(syscall.RTA_DST, destinationData)
  252. req.AddData(rtaDst)
  253. msgs, err := req.Execute(syscall.NETLINK_ROUTE, syscall.RTM_NEWROUTE)
  254. if err != nil {
  255. return nil, err
  256. }
  257. var res []Route
  258. for _, m := range msgs {
  259. route, err := deserializeRoute(m)
  260. if err != nil {
  261. return nil, err
  262. }
  263. res = append(res, route)
  264. }
  265. return res, nil
  266. }
  267. // RouteSubscribe takes a chan down which notifications will be sent
  268. // when routes are added or deleted. Close the 'done' chan to stop subscription.
  269. func RouteSubscribe(ch chan<- RouteUpdate, done <-chan struct{}) error {
  270. s, err := nl.Subscribe(syscall.NETLINK_ROUTE, syscall.RTNLGRP_IPV4_ROUTE, syscall.RTNLGRP_IPV6_ROUTE)
  271. if err != nil {
  272. return err
  273. }
  274. if done != nil {
  275. go func() {
  276. <-done
  277. s.Close()
  278. }()
  279. }
  280. go func() {
  281. defer close(ch)
  282. for {
  283. msgs, err := s.Receive()
  284. if err != nil {
  285. return
  286. }
  287. for _, m := range msgs {
  288. route, err := deserializeRoute(m.Data)
  289. if err != nil {
  290. return
  291. }
  292. ch <- RouteUpdate{Type: m.Header.Type, Route: route}
  293. }
  294. }
  295. }()
  296. return nil
  297. }