123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244 |
- package tftp
- import (
- "encoding/binary"
- "fmt"
- "io"
- "net"
- "strconv"
- "time"
- "github.com/pin/tftp/netascii"
- )
// OutgoingTransfer provides methods to set the outgoing transfer size and
// retrieve the remote address of the peer.
type OutgoingTransfer interface {
	// SetSize is used to set the outgoing transfer size (tsize option: RFC2349)
	// manually in a server read handler.
	//
	// It is not necessary in most cases; when the io.Reader provided to
	// ReadFrom also satisfies io.Seeker (e.g. os.File) the transfer size will
	// be determined automatically. Seek will not be attempted when the
	// transfer size option has been set to a non-zero value with SetSize.
	//
	// The value provided is used only if SetSize is called before ReadFrom,
	// only in a server read handler, and only when the peer's request
	// included the tsize option; otherwise the call has no effect.
	SetSize(n int64)

	// RemoteAddr returns the remote peer's IP address and port.
	RemoteAddr() net.UDPAddr
}
- type sender struct {
- conn *net.UDPConn
- addr *net.UDPAddr
- tid int
- send []byte
- receive []byte
- retry *backoff
- timeout time.Duration
- retries int
- block uint16
- mode string
- opts options
- }
- func (s *sender) RemoteAddr() net.UDPAddr { return *s.addr }
- func (s *sender) SetSize(n int64) {
- if s.opts != nil {
- if _, ok := s.opts["tsize"]; ok {
- s.opts["tsize"] = strconv.FormatInt(n, 10)
- }
- }
- }
- func (s *sender) ReadFrom(r io.Reader) (n int64, err error) {
- if s.mode == "netascii" {
- r = netascii.ToReader(r)
- }
- if s.opts != nil {
- // check that tsize is set
- if ts, ok := s.opts["tsize"]; ok {
- // check that tsize is not set with SetSize already
- i, err := strconv.ParseInt(ts, 10, 64)
- if err == nil && i == 0 {
- if rs, ok := r.(io.Seeker); ok {
- pos, err := rs.Seek(0, 1)
- if err != nil {
- return 0, err
- }
- size, err := rs.Seek(0, 2)
- if err != nil {
- return 0, err
- }
- s.opts["tsize"] = strconv.FormatInt(size, 10)
- _, err = rs.Seek(pos, 0)
- if err != nil {
- return 0, err
- }
- }
- }
- }
- err = s.sendOptions()
- if err != nil {
- s.abort(err)
- return 0, err
- }
- }
- s.block = 1 // start data transmission with block 1
- binary.BigEndian.PutUint16(s.send[0:2], opDATA)
- for {
- l, err := io.ReadFull(r, s.send[4:])
- n += int64(l)
- if err != nil && err != io.ErrUnexpectedEOF {
- if err == io.EOF {
- binary.BigEndian.PutUint16(s.send[2:4], s.block)
- _, err = s.sendWithRetry(4)
- if err != nil {
- s.abort(err)
- return n, err
- }
- s.conn.Close()
- return n, nil
- }
- s.abort(err)
- return n, err
- }
- binary.BigEndian.PutUint16(s.send[2:4], s.block)
- _, err = s.sendWithRetry(4 + l)
- if err != nil {
- s.abort(err)
- return n, err
- }
- if l < len(s.send)-4 {
- s.conn.Close()
- return n, nil
- }
- s.block++
- }
- }
- func (s *sender) sendOptions() error {
- for name, value := range s.opts {
- if name == "blksize" {
- err := s.setBlockSize(value)
- if err != nil {
- delete(s.opts, name)
- continue
- }
- } else if name == "tsize" {
- if value != "0" {
- s.opts["tsize"] = value
- } else {
- delete(s.opts, name)
- continue
- }
- } else {
- delete(s.opts, name)
- }
- }
- if len(s.opts) > 0 {
- m := packOACK(s.send, s.opts)
- _, err := s.sendWithRetry(m)
- if err != nil {
- return err
- }
- }
- return nil
- }
- func (s *sender) setBlockSize(blksize string) error {
- n, err := strconv.Atoi(blksize)
- if err != nil {
- return err
- }
- if n < 512 {
- return fmt.Errorf("blkzise too small: %d", n)
- }
- if n > 65464 {
- return fmt.Errorf("blksize too large: %d", n)
- }
- s.send = make([]byte, n+4)
- return nil
- }
- func (s *sender) sendWithRetry(l int) (*net.UDPAddr, error) {
- s.retry.reset()
- for {
- addr, err := s.sendDatagram(l)
- if _, ok := err.(net.Error); ok && s.retry.count() < s.retries {
- s.retry.backoff()
- continue
- }
- return addr, err
- }
- }
- func (s *sender) sendDatagram(l int) (*net.UDPAddr, error) {
- err := s.conn.SetReadDeadline(time.Now().Add(s.timeout))
- if err != nil {
- return nil, err
- }
- _, err = s.conn.WriteToUDP(s.send[:l], s.addr)
- if err != nil {
- return nil, err
- }
- for {
- n, addr, err := s.conn.ReadFromUDP(s.receive)
- if err != nil {
- return nil, err
- }
- if !addr.IP.Equal(s.addr.IP) || (s.tid != 0 && addr.Port != s.tid) {
- continue
- }
- p, err := parsePacket(s.receive[:n])
- if err != nil {
- continue
- }
- s.tid = addr.Port
- switch p := p.(type) {
- case pACK:
- if p.block() == s.block {
- return addr, nil
- }
- case pOACK:
- opts, err := unpackOACK(p)
- if s.block != 0 {
- continue
- }
- if err != nil {
- s.abort(err)
- return addr, err
- }
- for name, value := range opts {
- if name == "blksize" {
- err := s.setBlockSize(value)
- if err != nil {
- continue
- }
- }
- }
- return addr, nil
- case pERROR:
- return nil, fmt.Errorf("sending block %d: code=%d, error: %s",
- s.block, p.code(), p.message())
- }
- }
- }
- func (s *sender) abort(err error) error {
- if s.conn == nil {
- return nil
- }
- n := packERROR(s.send, 1, err.Error())
- _, err = s.conn.WriteToUDP(s.send[:n], s.addr)
- if err != nil {
- return err
- }
- s.conn.Close()
- s.conn = nil
- return nil
- }
|