sender.go 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244
  1. package tftp
  2. import (
  3. "encoding/binary"
  4. "fmt"
  5. "io"
  6. "net"
  7. "strconv"
  8. "time"
  9. "github.com/pin/tftp/netascii"
  10. )
  11. // OutgoingTransfer provides methods to set the outgoing transfer size and
  12. // retrieve the remote address of the peer.
  13. type OutgoingTransfer interface {
  14. // SetSize is used to set the outgoing transfer size (tsize option: RFC2349)
  15. // manually in a server write transfer handler.
  16. //
  17. // It is not necessary in most cases; when the io.Reader provided to
  18. // ReadFrom also satisfies io.Seeker (e.g. os.File) the transfer size will
  19. // be determined automatically. Seek will not be attempted when the
  20. // transfer size option is set with SetSize.
  21. //
  22. // The value provided will be used only if SetSize is called before ReadFrom
  23. // and only on in a server read handler.
  24. SetSize(n int64)
  25. // RemoteAddr returns the remote peer's IP address and port.
  26. RemoteAddr() net.UDPAddr
  27. }
  28. type sender struct {
  29. conn *net.UDPConn
  30. addr *net.UDPAddr
  31. tid int
  32. send []byte
  33. receive []byte
  34. retry *backoff
  35. timeout time.Duration
  36. retries int
  37. block uint16
  38. mode string
  39. opts options
  40. }
  41. func (s *sender) RemoteAddr() net.UDPAddr { return *s.addr }
  42. func (s *sender) SetSize(n int64) {
  43. if s.opts != nil {
  44. if _, ok := s.opts["tsize"]; ok {
  45. s.opts["tsize"] = strconv.FormatInt(n, 10)
  46. }
  47. }
  48. }
  49. func (s *sender) ReadFrom(r io.Reader) (n int64, err error) {
  50. if s.mode == "netascii" {
  51. r = netascii.ToReader(r)
  52. }
  53. if s.opts != nil {
  54. // check that tsize is set
  55. if ts, ok := s.opts["tsize"]; ok {
  56. // check that tsize is not set with SetSize already
  57. i, err := strconv.ParseInt(ts, 10, 64)
  58. if err == nil && i == 0 {
  59. if rs, ok := r.(io.Seeker); ok {
  60. pos, err := rs.Seek(0, 1)
  61. if err != nil {
  62. return 0, err
  63. }
  64. size, err := rs.Seek(0, 2)
  65. if err != nil {
  66. return 0, err
  67. }
  68. s.opts["tsize"] = strconv.FormatInt(size, 10)
  69. _, err = rs.Seek(pos, 0)
  70. if err != nil {
  71. return 0, err
  72. }
  73. }
  74. }
  75. }
  76. err = s.sendOptions()
  77. if err != nil {
  78. s.abort(err)
  79. return 0, err
  80. }
  81. }
  82. s.block = 1 // start data transmission with block 1
  83. binary.BigEndian.PutUint16(s.send[0:2], opDATA)
  84. for {
  85. l, err := io.ReadFull(r, s.send[4:])
  86. n += int64(l)
  87. if err != nil && err != io.ErrUnexpectedEOF {
  88. if err == io.EOF {
  89. binary.BigEndian.PutUint16(s.send[2:4], s.block)
  90. _, err = s.sendWithRetry(4)
  91. if err != nil {
  92. s.abort(err)
  93. return n, err
  94. }
  95. s.conn.Close()
  96. return n, nil
  97. }
  98. s.abort(err)
  99. return n, err
  100. }
  101. binary.BigEndian.PutUint16(s.send[2:4], s.block)
  102. _, err = s.sendWithRetry(4 + l)
  103. if err != nil {
  104. s.abort(err)
  105. return n, err
  106. }
  107. if l < len(s.send)-4 {
  108. s.conn.Close()
  109. return n, nil
  110. }
  111. s.block++
  112. }
  113. }
  114. func (s *sender) sendOptions() error {
  115. for name, value := range s.opts {
  116. if name == "blksize" {
  117. err := s.setBlockSize(value)
  118. if err != nil {
  119. delete(s.opts, name)
  120. continue
  121. }
  122. } else if name == "tsize" {
  123. if value != "0" {
  124. s.opts["tsize"] = value
  125. } else {
  126. delete(s.opts, name)
  127. continue
  128. }
  129. } else {
  130. delete(s.opts, name)
  131. }
  132. }
  133. if len(s.opts) > 0 {
  134. m := packOACK(s.send, s.opts)
  135. _, err := s.sendWithRetry(m)
  136. if err != nil {
  137. return err
  138. }
  139. }
  140. return nil
  141. }
  142. func (s *sender) setBlockSize(blksize string) error {
  143. n, err := strconv.Atoi(blksize)
  144. if err != nil {
  145. return err
  146. }
  147. if n < 512 {
  148. return fmt.Errorf("blkzise too small: %d", n)
  149. }
  150. if n > 65464 {
  151. return fmt.Errorf("blksize too large: %d", n)
  152. }
  153. s.send = make([]byte, n+4)
  154. return nil
  155. }
  156. func (s *sender) sendWithRetry(l int) (*net.UDPAddr, error) {
  157. s.retry.reset()
  158. for {
  159. addr, err := s.sendDatagram(l)
  160. if _, ok := err.(net.Error); ok && s.retry.count() < s.retries {
  161. s.retry.backoff()
  162. continue
  163. }
  164. return addr, err
  165. }
  166. }
  167. func (s *sender) sendDatagram(l int) (*net.UDPAddr, error) {
  168. err := s.conn.SetReadDeadline(time.Now().Add(s.timeout))
  169. if err != nil {
  170. return nil, err
  171. }
  172. _, err = s.conn.WriteToUDP(s.send[:l], s.addr)
  173. if err != nil {
  174. return nil, err
  175. }
  176. for {
  177. n, addr, err := s.conn.ReadFromUDP(s.receive)
  178. if err != nil {
  179. return nil, err
  180. }
  181. if !addr.IP.Equal(s.addr.IP) || (s.tid != 0 && addr.Port != s.tid) {
  182. continue
  183. }
  184. p, err := parsePacket(s.receive[:n])
  185. if err != nil {
  186. continue
  187. }
  188. s.tid = addr.Port
  189. switch p := p.(type) {
  190. case pACK:
  191. if p.block() == s.block {
  192. return addr, nil
  193. }
  194. case pOACK:
  195. opts, err := unpackOACK(p)
  196. if s.block != 0 {
  197. continue
  198. }
  199. if err != nil {
  200. s.abort(err)
  201. return addr, err
  202. }
  203. for name, value := range opts {
  204. if name == "blksize" {
  205. err := s.setBlockSize(value)
  206. if err != nil {
  207. continue
  208. }
  209. }
  210. }
  211. return addr, nil
  212. case pERROR:
  213. return nil, fmt.Errorf("sending block %d: code=%d, error: %s",
  214. s.block, p.code(), p.message())
  215. }
  216. }
  217. }
  218. func (s *sender) abort(err error) error {
  219. if s.conn == nil {
  220. return nil
  221. }
  222. n := packERROR(s.send, 1, err.Error())
  223. _, err = s.conn.WriteToUDP(s.send[:n], s.addr)
  224. if err != nil {
  225. return err
  226. }
  227. s.conn.Close()
  228. s.conn = nil
  229. return nil
  230. }