server.go 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. package tftp
  2. import (
  3. "fmt"
  4. "io"
  5. "net"
  6. "sync"
  7. "time"
  8. )
  9. // NewServer creates TFTP server. It requires two functions to handle
  10. // read and write requests.
  11. // In case nil is provided for read or write handler the respective
  12. // operation is disabled.
  13. func NewServer(readHandler func(filename string, rf io.ReaderFrom) error,
  14. writeHandler func(filename string, wt io.WriterTo) error) *Server {
  15. return &Server{
  16. readHandler: readHandler,
  17. writeHandler: writeHandler,
  18. timeout: defaultTimeout,
  19. retries: defaultRetries,
  20. }
  21. }
  22. type Server struct {
  23. readHandler func(filename string, rf io.ReaderFrom) error
  24. writeHandler func(filename string, wt io.WriterTo) error
  25. backoff backoffFunc
  26. conn *net.UDPConn
  27. quit chan chan struct{}
  28. wg sync.WaitGroup
  29. timeout time.Duration
  30. retries int
  31. }
  32. // SetTimeout sets maximum time server waits for single network
  33. // round-trip to succeed.
  34. // Default is 5 seconds.
  35. func (s *Server) SetTimeout(t time.Duration) {
  36. if t <= 0 {
  37. s.timeout = defaultTimeout
  38. } else {
  39. s.timeout = t
  40. }
  41. }
  42. // SetRetries sets maximum number of attempts server made to transmit a
  43. // packet.
  44. // Default is 5 attempts.
  45. func (s *Server) SetRetries(count int) {
  46. if count < 1 {
  47. s.retries = defaultRetries
  48. } else {
  49. s.retries = count
  50. }
  51. }
  52. // SetBackoff sets a user provided function that is called to provide a
  53. // backoff duration prior to retransmitting an unacknowledged packet.
  54. func (s *Server) SetBackoff(h backoffFunc) {
  55. s.backoff = h
  56. }
  57. // ListenAndServe binds to address provided and start the server.
  58. // ListenAndServe returns when Shutdown is called.
  59. func (s *Server) ListenAndServe(addr string) error {
  60. a, err := net.ResolveUDPAddr("udp", addr)
  61. if err != nil {
  62. return err
  63. }
  64. conn, err := net.ListenUDP("udp", a)
  65. if err != nil {
  66. return err
  67. }
  68. s.Serve(conn)
  69. return nil
  70. }
  71. // Serve starts server provided already opened UDP connecton. It is
  72. // useful for the case when you want to run server in separate goroutine
  73. // but still want to be able to handle any errors opening connection.
  74. // Serve returns when Shutdown is called or connection is closed.
  75. func (s *Server) Serve(conn *net.UDPConn) {
  76. s.conn = conn
  77. s.quit = make(chan chan struct{})
  78. for {
  79. select {
  80. case q := <-s.quit:
  81. q <- struct{}{}
  82. return
  83. default:
  84. err := s.processRequest(s.conn)
  85. if err != nil {
  86. // TODO: add logging handler
  87. }
  88. }
  89. }
  90. }
  91. // Shutdown make server stop listening for new requests, allows
  92. // server to finish outstanding transfers and stops server.
  93. func (s *Server) Shutdown() {
  94. s.conn.Close()
  95. q := make(chan struct{})
  96. s.quit <- q
  97. <-q
  98. s.wg.Wait()
  99. }
  100. func (s *Server) processRequest(conn *net.UDPConn) error {
  101. var buffer []byte
  102. buffer = make([]byte, datagramLength)
  103. n, remoteAddr, err := conn.ReadFromUDP(buffer)
  104. if err != nil {
  105. return fmt.Errorf("reading UDP: %v", err)
  106. }
  107. p, err := parsePacket(buffer[:n])
  108. if err != nil {
  109. return err
  110. }
  111. switch p := p.(type) {
  112. case pWRQ:
  113. filename, mode, opts, err := unpackRQ(p)
  114. if err != nil {
  115. return fmt.Errorf("unpack WRQ: %v", err)
  116. }
  117. //fmt.Printf("got WRQ (filename=%s, mode=%s, opts=%v)\n", filename, mode, opts)
  118. conn, err := net.ListenUDP("udp", &net.UDPAddr{})
  119. if err != nil {
  120. return err
  121. }
  122. if err != nil {
  123. return fmt.Errorf("open transmission: %v", err)
  124. }
  125. wt := &receiver{
  126. send: make([]byte, datagramLength),
  127. receive: make([]byte, datagramLength),
  128. conn: conn,
  129. retry: &backoff{handler: s.backoff},
  130. timeout: s.timeout,
  131. retries: s.retries,
  132. addr: remoteAddr,
  133. mode: mode,
  134. opts: opts,
  135. }
  136. s.wg.Add(1)
  137. go func() {
  138. if s.writeHandler != nil {
  139. err := s.writeHandler(filename, wt)
  140. if err != nil {
  141. wt.abort(err)
  142. } else {
  143. wt.terminate()
  144. wt.conn.Close()
  145. }
  146. } else {
  147. wt.abort(fmt.Errorf("server does not support write requests"))
  148. }
  149. s.wg.Done()
  150. }()
  151. case pRRQ:
  152. filename, mode, opts, err := unpackRQ(p)
  153. if err != nil {
  154. return fmt.Errorf("unpack RRQ: %v", err)
  155. }
  156. //fmt.Printf("got RRQ (filename=%s, mode=%s, opts=%v)\n", filename, mode, opts)
  157. conn, err := net.ListenUDP("udp", &net.UDPAddr{})
  158. if err != nil {
  159. return err
  160. }
  161. rf := &sender{
  162. send: make([]byte, datagramLength),
  163. receive: make([]byte, datagramLength),
  164. tid: remoteAddr.Port,
  165. conn: conn,
  166. retry: &backoff{handler: s.backoff},
  167. timeout: s.timeout,
  168. retries: s.retries,
  169. addr: remoteAddr,
  170. mode: mode,
  171. opts: opts,
  172. }
  173. s.wg.Add(1)
  174. go func() {
  175. if s.readHandler != nil {
  176. err := s.readHandler(filename, rf)
  177. if err != nil {
  178. rf.abort(err)
  179. }
  180. } else {
  181. rf.abort(fmt.Errorf("server does not support read requests"))
  182. }
  183. s.wg.Done()
  184. }()
  185. default:
  186. return fmt.Errorf("unexpected %T", p)
  187. }
  188. return nil
  189. }