transport_unix.go 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. //+build !windows
  2. package dbus
  3. import (
  4. "bytes"
  5. "encoding/binary"
  6. "errors"
  7. "io"
  8. "net"
  9. "syscall"
  10. )
  // oobReader is an io.Reader on top of a *net.UnixConn that performs every
// read with ReadMsgUnix, so ancillary (out-of-band) data arriving together
// with the payload, such as SCM_RIGHTS file descriptors, is kept in oob
// instead of being discarded.
type oobReader struct {
	conn *net.UnixConn
	buf  [4096]byte // scratch space for the control data of one ReadMsgUnix call
	oob  []byte     // control data accumulated over all successful reads
}

// Read reads payload bytes into b and appends any control data received with
// them to o.oob. If the kernel reports MSG_CTRUNC (the control data did not
// fit into o.buf), an error is returned and that control data is not recorded.
func (o *oobReader) Read(b []byte) (int, error) {
	got, ctlLen, msgFlags, _, err := o.conn.ReadMsgUnix(b, o.buf[:])
	switch {
	case err != nil:
		return got, err
	case msgFlags&syscall.MSG_CTRUNC != 0:
		return got, errors.New("dbus: control data truncated (too many fds received)")
	}
	o.oob = append(o.oob, o.buf[:ctlLen]...)
	return got, nil
}
  // unixTransport is the transport registered under "unix": a connected unix
// domain socket plus a flag recording whether fd passing has been enabled.
type unixTransport struct {
	// hasUnixFDs is set by EnableUnixFDs; while it is false, ReadMessage
	// rejects messages announcing fds and SendMessage refuses to send them.
	hasUnixFDs bool

	// The embedded connection supplies Read, Write and Close.
	*net.UnixConn
}
  31. func newUnixTransport(keys string) (transport, error) {
  32. var err error
  33. t := new(unixTransport)
  34. abstract := getKey(keys, "abstract")
  35. path := getKey(keys, "path")
  36. switch {
  37. case abstract == "" && path == "":
  38. return nil, errors.New("dbus: invalid address (neither path nor abstract set)")
  39. case abstract != "" && path == "":
  40. t.UnixConn, err = net.DialUnix("unix", nil, &net.UnixAddr{Name: "@" + abstract, Net: "unix"})
  41. if err != nil {
  42. return nil, err
  43. }
  44. return t, nil
  45. case abstract == "" && path != "":
  46. t.UnixConn, err = net.DialUnix("unix", nil, &net.UnixAddr{Name: path, Net: "unix"})
  47. if err != nil {
  48. return nil, err
  49. }
  50. return t, nil
  51. default:
  52. return nil, errors.New("dbus: invalid address (both path and abstract set)")
  53. }
  54. }
  55. func init() {
  56. transports["unix"] = newUnixTransport
  57. }
  58. func (t *unixTransport) EnableUnixFDs() {
  59. t.hasUnixFDs = true
  60. }
  61. func (t *unixTransport) ReadMessage() (*Message, error) {
  62. var (
  63. blen, hlen uint32
  64. csheader [16]byte
  65. headers []header
  66. order binary.ByteOrder
  67. unixfds uint32
  68. )
  69. // To be sure that all bytes of out-of-band data are read, we use a special
  70. // reader that uses ReadUnix on the underlying connection instead of Read
  71. // and gathers the out-of-band data in a buffer.
  72. rd := &oobReader{conn: t.UnixConn}
  73. // read the first 16 bytes (the part of the header that has a constant size),
  74. // from which we can figure out the length of the rest of the message
  75. if _, err := io.ReadFull(rd, csheader[:]); err != nil {
  76. return nil, err
  77. }
  78. switch csheader[0] {
  79. case 'l':
  80. order = binary.LittleEndian
  81. case 'B':
  82. order = binary.BigEndian
  83. default:
  84. return nil, InvalidMessageError("invalid byte order")
  85. }
  86. // csheader[4:8] -> length of message body, csheader[12:16] -> length of
  87. // header fields (without alignment)
  88. binary.Read(bytes.NewBuffer(csheader[4:8]), order, &blen)
  89. binary.Read(bytes.NewBuffer(csheader[12:]), order, &hlen)
  90. if hlen%8 != 0 {
  91. hlen += 8 - (hlen % 8)
  92. }
  93. // decode headers and look for unix fds
  94. headerdata := make([]byte, hlen+4)
  95. copy(headerdata, csheader[12:])
  96. if _, err := io.ReadFull(t, headerdata[4:]); err != nil {
  97. return nil, err
  98. }
  99. dec := newDecoder(bytes.NewBuffer(headerdata), order)
  100. dec.pos = 12
  101. vs, err := dec.Decode(Signature{"a(yv)"})
  102. if err != nil {
  103. return nil, err
  104. }
  105. Store(vs, &headers)
  106. for _, v := range headers {
  107. if v.Field == byte(FieldUnixFDs) {
  108. unixfds, _ = v.Variant.value.(uint32)
  109. }
  110. }
  111. all := make([]byte, 16+hlen+blen)
  112. copy(all, csheader[:])
  113. copy(all[16:], headerdata[4:])
  114. if _, err := io.ReadFull(rd, all[16+hlen:]); err != nil {
  115. return nil, err
  116. }
  117. if unixfds != 0 {
  118. if !t.hasUnixFDs {
  119. return nil, errors.New("dbus: got unix fds on unsupported transport")
  120. }
  121. // read the fds from the OOB data
  122. scms, err := syscall.ParseSocketControlMessage(rd.oob)
  123. if err != nil {
  124. return nil, err
  125. }
  126. if len(scms) != 1 {
  127. return nil, errors.New("dbus: received more than one socket control message")
  128. }
  129. fds, err := syscall.ParseUnixRights(&scms[0])
  130. if err != nil {
  131. return nil, err
  132. }
  133. msg, err := DecodeMessage(bytes.NewBuffer(all))
  134. if err != nil {
  135. return nil, err
  136. }
  137. // substitute the values in the message body (which are indices for the
  138. // array receiver via OOB) with the actual values
  139. for i, v := range msg.Body {
  140. if j, ok := v.(UnixFDIndex); ok {
  141. if uint32(j) >= unixfds {
  142. return nil, InvalidMessageError("invalid index for unix fd")
  143. }
  144. msg.Body[i] = UnixFD(fds[j])
  145. }
  146. }
  147. return msg, nil
  148. }
  149. return DecodeMessage(bytes.NewBuffer(all))
  150. }
  151. func (t *unixTransport) SendMessage(msg *Message) error {
  152. fds := make([]int, 0)
  153. for i, v := range msg.Body {
  154. if fd, ok := v.(UnixFD); ok {
  155. msg.Body[i] = UnixFDIndex(len(fds))
  156. fds = append(fds, int(fd))
  157. }
  158. }
  159. if len(fds) != 0 {
  160. if !t.hasUnixFDs {
  161. return errors.New("dbus: unix fd passing not enabled")
  162. }
  163. msg.Headers[FieldUnixFDs] = MakeVariant(uint32(len(fds)))
  164. oob := syscall.UnixRights(fds...)
  165. buf := new(bytes.Buffer)
  166. msg.EncodeTo(buf, binary.LittleEndian)
  167. n, oobn, err := t.UnixConn.WriteMsgUnix(buf.Bytes(), oob, nil)
  168. if err != nil {
  169. return err
  170. }
  171. if n != buf.Len() || oobn != len(oob) {
  172. return io.ErrShortWrite
  173. }
  174. } else {
  175. if err := msg.EncodeTo(t, binary.LittleEndian); err != nil {
  176. return nil
  177. }
  178. }
  179. return nil
  180. }
  181. func (t *unixTransport) SupportsUnixFDs() bool {
  182. return true
  183. }