123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328 |
- package netlink
- import (
- "fmt"
- "net"
- "syscall"
- "github.com/vishvananda/netlink/nl"
- )
- // RtAttr is shared so it is in netlink_linux.go
// Bit flags for the filterMask argument of RouteListFiltered. Each set bit
// makes RouteListFiltered compare the named field of the filter route with
// the same field of every listed route. Bit 0 is unused.
const (
	RT_FILTER_PROTOCOL uint64 = 1 << 1  // compare Route.Protocol
	RT_FILTER_SCOPE    uint64 = 1 << 2  // compare Route.Scope
	RT_FILTER_TYPE     uint64 = 1 << 3  // compare Route.Type
	RT_FILTER_TOS      uint64 = 1 << 4  // compare Route.Tos
	RT_FILTER_IIF      uint64 = 1 << 5  // compare Route.ILinkIndex
	RT_FILTER_OIF      uint64 = 1 << 6  // compare Route.LinkIndex
	RT_FILTER_DST      uint64 = 1 << 7  // compare Route.Dst (only when the filter's Dst is non-nil)
	RT_FILTER_SRC      uint64 = 1 << 8  // compare Route.Src
	RT_FILTER_GW       uint64 = 1 << 9  // compare Route.Gw
	RT_FILTER_TABLE    uint64 = 1 << 10 // compare Route.Table; also lets non-main tables through
)
- // RouteAdd will add a route to the system.
- // Equivalent to: `ip route add $route`
- func RouteAdd(route *Route) error {
- req := nl.NewNetlinkRequest(syscall.RTM_NEWROUTE, syscall.NLM_F_CREATE|syscall.NLM_F_EXCL|syscall.NLM_F_ACK)
- return routeHandle(route, req, nl.NewRtMsg())
- }
- // RouteDel will delete a route from the system.
- // Equivalent to: `ip route del $route`
- func RouteDel(route *Route) error {
- req := nl.NewNetlinkRequest(syscall.RTM_DELROUTE, syscall.NLM_F_ACK)
- return routeHandle(route, req, nl.NewRtDelMsg())
- }
- func routeHandle(route *Route, req *nl.NetlinkRequest, msg *nl.RtMsg) error {
- if (route.Dst == nil || route.Dst.IP == nil) && route.Src == nil && route.Gw == nil {
- return fmt.Errorf("one of Dst.IP, Src, or Gw must not be nil")
- }
- family := -1
- var rtAttrs []*nl.RtAttr
- if route.Dst != nil && route.Dst.IP != nil {
- dstLen, _ := route.Dst.Mask.Size()
- msg.Dst_len = uint8(dstLen)
- dstFamily := nl.GetIPFamily(route.Dst.IP)
- family = dstFamily
- var dstData []byte
- if dstFamily == FAMILY_V4 {
- dstData = route.Dst.IP.To4()
- } else {
- dstData = route.Dst.IP.To16()
- }
- rtAttrs = append(rtAttrs, nl.NewRtAttr(syscall.RTA_DST, dstData))
- }
- if route.Src != nil {
- srcFamily := nl.GetIPFamily(route.Src)
- if family != -1 && family != srcFamily {
- return fmt.Errorf("source and destination ip are not the same IP family")
- }
- family = srcFamily
- var srcData []byte
- if srcFamily == FAMILY_V4 {
- srcData = route.Src.To4()
- } else {
- srcData = route.Src.To16()
- }
- // The commonly used src ip for routes is actually PREFSRC
- rtAttrs = append(rtAttrs, nl.NewRtAttr(syscall.RTA_PREFSRC, srcData))
- }
- if route.Gw != nil {
- gwFamily := nl.GetIPFamily(route.Gw)
- if family != -1 && family != gwFamily {
- return fmt.Errorf("gateway, source, and destination ip are not the same IP family")
- }
- family = gwFamily
- var gwData []byte
- if gwFamily == FAMILY_V4 {
- gwData = route.Gw.To4()
- } else {
- gwData = route.Gw.To16()
- }
- rtAttrs = append(rtAttrs, nl.NewRtAttr(syscall.RTA_GATEWAY, gwData))
- }
- if route.Table > 0 {
- if route.Table >= 256 {
- msg.Table = syscall.RT_TABLE_UNSPEC
- b := make([]byte, 4)
- native.PutUint32(b, uint32(route.Table))
- rtAttrs = append(rtAttrs, nl.NewRtAttr(syscall.RTA_TABLE, b))
- } else {
- msg.Table = uint8(route.Table)
- }
- }
- if route.Priority > 0 {
- b := make([]byte, 4)
- native.PutUint32(b, uint32(route.Priority))
- rtAttrs = append(rtAttrs, nl.NewRtAttr(syscall.RTA_PRIORITY, b))
- }
- if route.Tos > 0 {
- msg.Tos = uint8(route.Tos)
- }
- if route.Protocol > 0 {
- msg.Protocol = uint8(route.Protocol)
- }
- if route.Type > 0 {
- msg.Type = uint8(route.Type)
- }
- msg.Scope = uint8(route.Scope)
- msg.Family = uint8(family)
- req.AddData(msg)
- for _, attr := range rtAttrs {
- req.AddData(attr)
- }
- var (
- b = make([]byte, 4)
- native = nl.NativeEndian()
- )
- native.PutUint32(b, uint32(route.LinkIndex))
- req.AddData(nl.NewRtAttr(syscall.RTA_OIF, b))
- _, err := req.Execute(syscall.NETLINK_ROUTE, 0)
- return err
- }
- // RouteList gets a list of routes in the system.
- // Equivalent to: `ip route show`.
- // The list can be filtered by link and ip family.
- func RouteList(link Link, family int) ([]Route, error) {
- var routeFilter *Route
- if link != nil {
- routeFilter = &Route{
- LinkIndex: link.Attrs().Index,
- }
- }
- return RouteListFiltered(family, routeFilter, RT_FILTER_OIF)
- }
- // RouteListFiltered gets a list of routes in the system filtered with specified rules.
- // All rules must be defined in RouteFilter struct
- func RouteListFiltered(family int, filter *Route, filterMask uint64) ([]Route, error) {
- req := nl.NewNetlinkRequest(syscall.RTM_GETROUTE, syscall.NLM_F_DUMP)
- infmsg := nl.NewIfInfomsg(family)
- req.AddData(infmsg)
- msgs, err := req.Execute(syscall.NETLINK_ROUTE, syscall.RTM_NEWROUTE)
- if err != nil {
- return nil, err
- }
- var res []Route
- for _, m := range msgs {
- msg := nl.DeserializeRtMsg(m)
- if msg.Flags&syscall.RTM_F_CLONED != 0 {
- // Ignore cloned routes
- continue
- }
- if msg.Table != syscall.RT_TABLE_MAIN {
- if filter == nil || filter != nil && filterMask&RT_FILTER_TABLE == 0 {
- // Ignore non-main tables
- continue
- }
- }
- route, err := deserializeRoute(m)
- if err != nil {
- return nil, err
- }
- if filter != nil {
- switch {
- case filterMask&RT_FILTER_TABLE != 0 && route.Table != filter.Table:
- continue
- case filterMask&RT_FILTER_PROTOCOL != 0 && route.Protocol != filter.Protocol:
- continue
- case filterMask&RT_FILTER_SCOPE != 0 && route.Scope != filter.Scope:
- continue
- case filterMask&RT_FILTER_TYPE != 0 && route.Type != filter.Type:
- continue
- case filterMask&RT_FILTER_TOS != 0 && route.Tos != filter.Tos:
- continue
- case filterMask&RT_FILTER_OIF != 0 && route.LinkIndex != filter.LinkIndex:
- continue
- case filterMask&RT_FILTER_IIF != 0 && route.ILinkIndex != filter.ILinkIndex:
- continue
- case filterMask&RT_FILTER_GW != 0 && !route.Gw.Equal(filter.Gw):
- continue
- case filterMask&RT_FILTER_SRC != 0 && !route.Src.Equal(filter.Src):
- continue
- case filterMask&RT_FILTER_DST != 0 && filter.Dst != nil:
- if route.Dst == nil {
- continue
- }
- aMaskLen, aMaskBits := route.Dst.Mask.Size()
- bMaskLen, bMaskBits := filter.Dst.Mask.Size()
- if !(route.Dst.IP.Equal(filter.Dst.IP) && aMaskLen == bMaskLen && aMaskBits == bMaskBits) {
- continue
- }
- }
- }
- res = append(res, route)
- }
- return res, nil
- }
- // deserializeRoute decodes a binary netlink message into a Route struct
- func deserializeRoute(m []byte) (Route, error) {
- msg := nl.DeserializeRtMsg(m)
- attrs, err := nl.ParseRouteAttr(m[msg.Len():])
- if err != nil {
- return Route{}, err
- }
- route := Route{
- Scope: Scope(msg.Scope),
- Protocol: int(msg.Protocol),
- Table: int(msg.Table),
- Type: int(msg.Type),
- Tos: int(msg.Tos),
- Flags: int(msg.Flags),
- }
- native := nl.NativeEndian()
- for _, attr := range attrs {
- switch attr.Attr.Type {
- case syscall.RTA_GATEWAY:
- route.Gw = net.IP(attr.Value)
- case syscall.RTA_PREFSRC:
- route.Src = net.IP(attr.Value)
- case syscall.RTA_DST:
- route.Dst = &net.IPNet{
- IP: attr.Value,
- Mask: net.CIDRMask(int(msg.Dst_len), 8*len(attr.Value)),
- }
- case syscall.RTA_OIF:
- route.LinkIndex = int(native.Uint32(attr.Value[0:4]))
- case syscall.RTA_IIF:
- route.ILinkIndex = int(native.Uint32(attr.Value[0:4]))
- case syscall.RTA_PRIORITY:
- route.Priority = int(native.Uint32(attr.Value[0:4]))
- case syscall.RTA_TABLE:
- route.Table = int(native.Uint32(attr.Value[0:4]))
- }
- }
- return route, nil
- }
- // RouteGet gets a route to a specific destination from the host system.
- // Equivalent to: 'ip route get'.
- func RouteGet(destination net.IP) ([]Route, error) {
- req := nl.NewNetlinkRequest(syscall.RTM_GETROUTE, syscall.NLM_F_REQUEST)
- family := nl.GetIPFamily(destination)
- var destinationData []byte
- var bitlen uint8
- if family == FAMILY_V4 {
- destinationData = destination.To4()
- bitlen = 32
- } else {
- destinationData = destination.To16()
- bitlen = 128
- }
- msg := &nl.RtMsg{}
- msg.Family = uint8(family)
- msg.Dst_len = bitlen
- req.AddData(msg)
- rtaDst := nl.NewRtAttr(syscall.RTA_DST, destinationData)
- req.AddData(rtaDst)
- msgs, err := req.Execute(syscall.NETLINK_ROUTE, syscall.RTM_NEWROUTE)
- if err != nil {
- return nil, err
- }
- var res []Route
- for _, m := range msgs {
- route, err := deserializeRoute(m)
- if err != nil {
- return nil, err
- }
- res = append(res, route)
- }
- return res, nil
- }
- // RouteSubscribe takes a chan down which notifications will be sent
- // when routes are added or deleted. Close the 'done' chan to stop subscription.
- func RouteSubscribe(ch chan<- RouteUpdate, done <-chan struct{}) error {
- s, err := nl.Subscribe(syscall.NETLINK_ROUTE, syscall.RTNLGRP_IPV4_ROUTE, syscall.RTNLGRP_IPV6_ROUTE)
- if err != nil {
- return err
- }
- if done != nil {
- go func() {
- <-done
- s.Close()
- }()
- }
- go func() {
- defer close(ch)
- for {
- msgs, err := s.Receive()
- if err != nil {
- return
- }
- for _, m := range msgs {
- route, err := deserializeRoute(m.Data)
- if err != nil {
- return
- }
- ch <- RouteUpdate{Type: m.Header.Type, Route: route}
- }
- }
- }()
- return nil
- }
|