receiver.go 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235
  1. package tftp
  2. import (
  3. "encoding/binary"
  4. "fmt"
  5. "io"
  6. "net"
  7. "strconv"
  8. "time"
  9. "github.com/pin/tftp/netascii"
  10. )
  // IncomingTransfer exposes details of an incoming transfer: where it comes
  // from and, when the request announced it, how large it is.
  type IncomingTransfer interface {
  	// RemoteAddr reports the IP address and port of the remote peer.
  	RemoteAddr() net.UDPAddr
  	// Size reports the size of the incoming file when the request carried the
  	// tsize option (RFC 2349). Use the boolean "ok" result to tell a
  	// zero-length file (0, true) apart from a request without tsize (0, false).
  	Size() (n int64, ok bool)
  }
  21. func (r *receiver) RemoteAddr() net.UDPAddr { return *r.addr }
  22. func (r *receiver) Size() (n int64, ok bool) {
  23. if r.opts != nil {
  24. if s, ok := r.opts["tsize"]; ok {
  25. n, err := strconv.ParseInt(s, 10, 64)
  26. if err != nil {
  27. return 0, false
  28. }
  29. return n, true
  30. }
  31. }
  32. return 0, false
  33. }
  34. type receiver struct {
  35. send []byte
  36. receive []byte
  37. addr *net.UDPAddr
  38. tid int
  39. conn *net.UDPConn
  40. block uint16
  41. retry *backoff
  42. timeout time.Duration
  43. retries int
  44. l int
  45. autoTerm bool
  46. dally bool
  47. mode string
  48. opts options
  49. }
  50. func (r *receiver) WriteTo(w io.Writer) (n int64, err error) {
  51. if r.mode == "netascii" {
  52. w = netascii.FromWriter(w)
  53. }
  54. if r.opts != nil {
  55. err := r.sendOptions()
  56. if err != nil {
  57. r.abort(err)
  58. return 0, err
  59. }
  60. }
  61. binary.BigEndian.PutUint16(r.send[0:2], opACK)
  62. for {
  63. if r.l > 0 {
  64. l, err := w.Write(r.receive[4:r.l])
  65. n += int64(l)
  66. if err != nil {
  67. r.abort(err)
  68. return n, err
  69. }
  70. if r.l < len(r.receive) {
  71. if r.autoTerm {
  72. r.terminate()
  73. r.conn.Close()
  74. }
  75. return n, nil
  76. }
  77. }
  78. binary.BigEndian.PutUint16(r.send[2:4], r.block)
  79. r.block++ // send ACK for current block and expect next one
  80. ll, _, err := r.receiveWithRetry(4)
  81. if err != nil {
  82. r.abort(err)
  83. return n, err
  84. }
  85. r.l = ll
  86. }
  87. }
  88. func (r *receiver) sendOptions() error {
  89. for name, value := range r.opts {
  90. if name == "blksize" {
  91. err := r.setBlockSize(value)
  92. if err != nil {
  93. delete(r.opts, name)
  94. continue
  95. }
  96. } else {
  97. delete(r.opts, name)
  98. }
  99. }
  100. if len(r.opts) > 0 {
  101. m := packOACK(r.send, r.opts)
  102. r.block = 1 // expect data block number 1
  103. ll, _, err := r.receiveWithRetry(m)
  104. if err != nil {
  105. r.abort(err)
  106. return err
  107. }
  108. r.l = ll
  109. }
  110. return nil
  111. }
  112. func (r *receiver) setBlockSize(blksize string) error {
  113. n, err := strconv.Atoi(blksize)
  114. if err != nil {
  115. return err
  116. }
  117. if n < 512 {
  118. return fmt.Errorf("blkzise too small: %d", n)
  119. }
  120. if n > 65464 {
  121. return fmt.Errorf("blksize too large: %d", n)
  122. }
  123. r.receive = make([]byte, n+4)
  124. return nil
  125. }
  126. func (r *receiver) receiveWithRetry(l int) (int, *net.UDPAddr, error) {
  127. r.retry.reset()
  128. for {
  129. n, addr, err := r.receiveDatagram(l)
  130. if _, ok := err.(net.Error); ok && r.retry.count() < r.retries {
  131. r.retry.backoff()
  132. continue
  133. }
  134. return n, addr, err
  135. }
  136. }
  137. func (r *receiver) receiveDatagram(l int) (int, *net.UDPAddr, error) {
  138. err := r.conn.SetReadDeadline(time.Now().Add(r.timeout))
  139. if err != nil {
  140. return 0, nil, err
  141. }
  142. _, err = r.conn.WriteToUDP(r.send[:l], r.addr)
  143. if err != nil {
  144. return 0, nil, err
  145. }
  146. for {
  147. c, addr, err := r.conn.ReadFromUDP(r.receive)
  148. if err != nil {
  149. return 0, nil, err
  150. }
  151. if !addr.IP.Equal(r.addr.IP) || (r.tid != 0 && addr.Port != r.tid) {
  152. continue
  153. }
  154. p, err := parsePacket(r.receive[:c])
  155. if err != nil {
  156. return 0, addr, err
  157. }
  158. r.tid = addr.Port
  159. switch p := p.(type) {
  160. case pDATA:
  161. if p.block() == r.block {
  162. return c, addr, nil
  163. }
  164. case pOACK:
  165. opts, err := unpackOACK(p)
  166. if r.block != 1 {
  167. continue
  168. }
  169. if err != nil {
  170. r.abort(err)
  171. return 0, addr, err
  172. }
  173. for name, value := range opts {
  174. if name == "blksize" {
  175. err := r.setBlockSize(value)
  176. if err != nil {
  177. continue
  178. }
  179. }
  180. }
  181. r.block = 0 // ACK with block number 0
  182. r.opts = opts
  183. return 0, addr, nil
  184. case pERROR:
  185. return 0, addr, fmt.Errorf("code: %d, message: %s",
  186. p.code(), p.message())
  187. }
  188. }
  189. }
  190. func (r *receiver) terminate() error {
  191. binary.BigEndian.PutUint16(r.send[2:4], r.block)
  192. if r.dally {
  193. for i := 0; i < 3; i++ {
  194. _, _, err := r.receiveDatagram(4)
  195. if err != nil {
  196. return nil
  197. }
  198. }
  199. return fmt.Errorf("dallying termination failed")
  200. } else {
  201. _, err := r.conn.WriteToUDP(r.send[:4], r.addr)
  202. if err != nil {
  203. return err
  204. }
  205. }
  206. return nil
  207. }
  208. func (r *receiver) abort(err error) error {
  209. if r.conn == nil {
  210. return nil
  211. }
  212. n := packERROR(r.send, 1, err.Error())
  213. _, err = r.conn.WriteToUDP(r.send[:n], r.addr)
  214. if err != nil {
  215. return err
  216. }
  217. r.conn.Close()
  218. r.conn = nil
  219. return nil
  220. }