123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200 |
- package tftp
- import (
- "fmt"
- "io"
- "net"
- "sync"
- "time"
- )
- // NewServer creates TFTP server. It requires two functions to handle
- // read and write requests.
- // In case nil is provided for read or write handler the respective
- // operation is disabled.
- func NewServer(readHandler func(filename string, rf io.ReaderFrom) error,
- writeHandler func(filename string, wt io.WriterTo) error) *Server {
- return &Server{
- readHandler: readHandler,
- writeHandler: writeHandler,
- timeout: defaultTimeout,
- retries: defaultRetries,
- }
- }
- type Server struct {
- readHandler func(filename string, rf io.ReaderFrom) error
- writeHandler func(filename string, wt io.WriterTo) error
- backoff backoffFunc
- conn *net.UDPConn
- quit chan chan struct{}
- wg sync.WaitGroup
- timeout time.Duration
- retries int
- }
- // SetTimeout sets maximum time server waits for single network
- // round-trip to succeed.
- // Default is 5 seconds.
- func (s *Server) SetTimeout(t time.Duration) {
- if t <= 0 {
- s.timeout = defaultTimeout
- } else {
- s.timeout = t
- }
- }
- // SetRetries sets maximum number of attempts server made to transmit a
- // packet.
- // Default is 5 attempts.
- func (s *Server) SetRetries(count int) {
- if count < 1 {
- s.retries = defaultRetries
- } else {
- s.retries = count
- }
- }
- // SetBackoff sets a user provided function that is called to provide a
- // backoff duration prior to retransmitting an unacknowledged packet.
- func (s *Server) SetBackoff(h backoffFunc) {
- s.backoff = h
- }
- // ListenAndServe binds to address provided and start the server.
- // ListenAndServe returns when Shutdown is called.
- func (s *Server) ListenAndServe(addr string) error {
- a, err := net.ResolveUDPAddr("udp", addr)
- if err != nil {
- return err
- }
- conn, err := net.ListenUDP("udp", a)
- if err != nil {
- return err
- }
- s.Serve(conn)
- return nil
- }
- // Serve starts server provided already opened UDP connecton. It is
- // useful for the case when you want to run server in separate goroutine
- // but still want to be able to handle any errors opening connection.
- // Serve returns when Shutdown is called or connection is closed.
- func (s *Server) Serve(conn *net.UDPConn) {
- s.conn = conn
- s.quit = make(chan chan struct{})
- for {
- select {
- case q := <-s.quit:
- q <- struct{}{}
- return
- default:
- err := s.processRequest(s.conn)
- if err != nil {
- // TODO: add logging handler
- }
- }
- }
- }
- // Shutdown make server stop listening for new requests, allows
- // server to finish outstanding transfers and stops server.
- func (s *Server) Shutdown() {
- s.conn.Close()
- q := make(chan struct{})
- s.quit <- q
- <-q
- s.wg.Wait()
- }
- func (s *Server) processRequest(conn *net.UDPConn) error {
- var buffer []byte
- buffer = make([]byte, datagramLength)
- n, remoteAddr, err := conn.ReadFromUDP(buffer)
- if err != nil {
- return fmt.Errorf("reading UDP: %v", err)
- }
- p, err := parsePacket(buffer[:n])
- if err != nil {
- return err
- }
- switch p := p.(type) {
- case pWRQ:
- filename, mode, opts, err := unpackRQ(p)
- if err != nil {
- return fmt.Errorf("unpack WRQ: %v", err)
- }
- //fmt.Printf("got WRQ (filename=%s, mode=%s, opts=%v)\n", filename, mode, opts)
- conn, err := net.ListenUDP("udp", &net.UDPAddr{})
- if err != nil {
- return err
- }
- if err != nil {
- return fmt.Errorf("open transmission: %v", err)
- }
- wt := &receiver{
- send: make([]byte, datagramLength),
- receive: make([]byte, datagramLength),
- conn: conn,
- retry: &backoff{handler: s.backoff},
- timeout: s.timeout,
- retries: s.retries,
- addr: remoteAddr,
- mode: mode,
- opts: opts,
- }
- s.wg.Add(1)
- go func() {
- if s.writeHandler != nil {
- err := s.writeHandler(filename, wt)
- if err != nil {
- wt.abort(err)
- } else {
- wt.terminate()
- wt.conn.Close()
- }
- } else {
- wt.abort(fmt.Errorf("server does not support write requests"))
- }
- s.wg.Done()
- }()
- case pRRQ:
- filename, mode, opts, err := unpackRQ(p)
- if err != nil {
- return fmt.Errorf("unpack RRQ: %v", err)
- }
- //fmt.Printf("got RRQ (filename=%s, mode=%s, opts=%v)\n", filename, mode, opts)
- conn, err := net.ListenUDP("udp", &net.UDPAddr{})
- if err != nil {
- return err
- }
- rf := &sender{
- send: make([]byte, datagramLength),
- receive: make([]byte, datagramLength),
- tid: remoteAddr.Port,
- conn: conn,
- retry: &backoff{handler: s.backoff},
- timeout: s.timeout,
- retries: s.retries,
- addr: remoteAddr,
- mode: mode,
- opts: opts,
- }
- s.wg.Add(1)
- go func() {
- if s.readHandler != nil {
- err := s.readHandler(filename, rf)
- if err != nil {
- rf.abort(err)
- }
- } else {
- rf.abort(fmt.Errorf("server does not support read requests"))
- }
- s.wg.Done()
- }()
- default:
- return fmt.Errorf("unexpected %T", p)
- }
- return nil
- }
|