123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197 |
- //+build !windows
- package dbus
- import (
- "bytes"
- "encoding/binary"
- "errors"
- "io"
- "net"
- "syscall"
- )
// oobReader is an io.Reader over a unix socket that, in addition to the
// regular payload, keeps every chunk of out-of-band (control) data it sees.
type oobReader struct {
	// conn is the socket that is read from.
	conn *net.UnixConn
	// oob accumulates the control data of all successful reads so far.
	oob []byte
	// buf is the scratch space handed to ReadMsgUnix for control data.
	buf [4096]byte
}

// Read fills b using ReadMsgUnix and appends the control data received with
// that read to o.oob. It fails if the connection reports an error or if the
// kernel flags the control data as truncated (MSG_CTRUNC); in both cases the
// number of payload bytes read is still returned.
func (o *oobReader) Read(b []byte) (n int, err error) {
	var ctrlLen, msgFlags int
	n, ctrlLen, msgFlags, _, err = o.conn.ReadMsgUnix(b, o.buf[:])
	switch {
	case err != nil:
		return n, err
	case msgFlags&syscall.MSG_CTRUNC != 0:
		return n, errors.New("dbus: control data truncated (too many fds received)")
	}
	o.oob = append(o.oob, o.buf[:ctrlLen]...)
	return n, nil
}
// unixTransport is the transport backed by a unix domain socket. The socket
// is embedded, so the connection's own methods are promoted to the transport.
type unixTransport struct {
	*net.UnixConn

	// hasUnixFDs is set by EnableUnixFDs; while it is false, messages that
	// carry unix file descriptors are rejected when sending and receiving.
	hasUnixFDs bool
}
- func newUnixTransport(keys string) (transport, error) {
- var err error
- t := new(unixTransport)
- abstract := getKey(keys, "abstract")
- path := getKey(keys, "path")
- switch {
- case abstract == "" && path == "":
- return nil, errors.New("dbus: invalid address (neither path nor abstract set)")
- case abstract != "" && path == "":
- t.UnixConn, err = net.DialUnix("unix", nil, &net.UnixAddr{Name: "@" + abstract, Net: "unix"})
- if err != nil {
- return nil, err
- }
- return t, nil
- case abstract == "" && path != "":
- t.UnixConn, err = net.DialUnix("unix", nil, &net.UnixAddr{Name: path, Net: "unix"})
- if err != nil {
- return nil, err
- }
- return t, nil
- default:
- return nil, errors.New("dbus: invalid address (both path and abstract set)")
- }
- }
- func init() {
- transports["unix"] = newUnixTransport
- }
- func (t *unixTransport) EnableUnixFDs() {
- t.hasUnixFDs = true
- }
- func (t *unixTransport) ReadMessage() (*Message, error) {
- var (
- blen, hlen uint32
- csheader [16]byte
- headers []header
- order binary.ByteOrder
- unixfds uint32
- )
- // To be sure that all bytes of out-of-band data are read, we use a special
- // reader that uses ReadUnix on the underlying connection instead of Read
- // and gathers the out-of-band data in a buffer.
- rd := &oobReader{conn: t.UnixConn}
- // read the first 16 bytes (the part of the header that has a constant size),
- // from which we can figure out the length of the rest of the message
- if _, err := io.ReadFull(rd, csheader[:]); err != nil {
- return nil, err
- }
- switch csheader[0] {
- case 'l':
- order = binary.LittleEndian
- case 'B':
- order = binary.BigEndian
- default:
- return nil, InvalidMessageError("invalid byte order")
- }
- // csheader[4:8] -> length of message body, csheader[12:16] -> length of
- // header fields (without alignment)
- binary.Read(bytes.NewBuffer(csheader[4:8]), order, &blen)
- binary.Read(bytes.NewBuffer(csheader[12:]), order, &hlen)
- if hlen%8 != 0 {
- hlen += 8 - (hlen % 8)
- }
- // decode headers and look for unix fds
- headerdata := make([]byte, hlen+4)
- copy(headerdata, csheader[12:])
- if _, err := io.ReadFull(t, headerdata[4:]); err != nil {
- return nil, err
- }
- dec := newDecoder(bytes.NewBuffer(headerdata), order)
- dec.pos = 12
- vs, err := dec.Decode(Signature{"a(yv)"})
- if err != nil {
- return nil, err
- }
- Store(vs, &headers)
- for _, v := range headers {
- if v.Field == byte(FieldUnixFDs) {
- unixfds, _ = v.Variant.value.(uint32)
- }
- }
- all := make([]byte, 16+hlen+blen)
- copy(all, csheader[:])
- copy(all[16:], headerdata[4:])
- if _, err := io.ReadFull(rd, all[16+hlen:]); err != nil {
- return nil, err
- }
- if unixfds != 0 {
- if !t.hasUnixFDs {
- return nil, errors.New("dbus: got unix fds on unsupported transport")
- }
- // read the fds from the OOB data
- scms, err := syscall.ParseSocketControlMessage(rd.oob)
- if err != nil {
- return nil, err
- }
- if len(scms) != 1 {
- return nil, errors.New("dbus: received more than one socket control message")
- }
- fds, err := syscall.ParseUnixRights(&scms[0])
- if err != nil {
- return nil, err
- }
- msg, err := DecodeMessage(bytes.NewBuffer(all))
- if err != nil {
- return nil, err
- }
- // substitute the values in the message body (which are indices for the
- // array receiver via OOB) with the actual values
- for i, v := range msg.Body {
- if j, ok := v.(UnixFDIndex); ok {
- if uint32(j) >= unixfds {
- return nil, InvalidMessageError("invalid index for unix fd")
- }
- msg.Body[i] = UnixFD(fds[j])
- }
- }
- return msg, nil
- }
- return DecodeMessage(bytes.NewBuffer(all))
- }
- func (t *unixTransport) SendMessage(msg *Message) error {
- fds := make([]int, 0)
- for i, v := range msg.Body {
- if fd, ok := v.(UnixFD); ok {
- msg.Body[i] = UnixFDIndex(len(fds))
- fds = append(fds, int(fd))
- }
- }
- if len(fds) != 0 {
- if !t.hasUnixFDs {
- return errors.New("dbus: unix fd passing not enabled")
- }
- msg.Headers[FieldUnixFDs] = MakeVariant(uint32(len(fds)))
- oob := syscall.UnixRights(fds...)
- buf := new(bytes.Buffer)
- msg.EncodeTo(buf, binary.LittleEndian)
- n, oobn, err := t.UnixConn.WriteMsgUnix(buf.Bytes(), oob, nil)
- if err != nil {
- return err
- }
- if n != buf.Len() || oobn != len(oob) {
- return io.ErrShortWrite
- }
- } else {
- if err := msg.EncodeTo(t, binary.LittleEndian); err != nil {
- return nil
- }
- }
- return nil
- }
- func (t *unixTransport) SupportsUnixFDs() bool {
- return true
- }
|