tftp_test.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623
  1. package tftp
  2. import (
  3. "bytes"
  4. "fmt"
  5. "io"
  6. "io/ioutil"
  7. "math/rand"
  8. "net"
  9. "os"
  10. "strconv"
  11. "sync"
  12. "testing"
  13. "testing/iotest"
  14. "time"
  15. )
  // localhost is the host part of the local address of a TCP connection dialed
  // to this machine, as found by determineLocalhost: an IPv6 address when a
  // tcp6 dial succeeds, otherwise an IPv4 one. localSystem pairs it with a
  // server's port to build the address test clients dial.
  var localhost = determineLocalhost()
  
  // determineLocalhost reports which address the local system uses for
  // connections to itself. It listens on an ephemeral TCP port, dials that port
  // first over tcp6 and then over tcp4, and returns the host part of the local
  // address of the first dial that succeeds.
  //
  // It runs during package initialization, where no *testing.T exists, so any
  // setup failure panics. The panic is raised in the calling goroutine and
  // includes the underlying errors.
  func determineLocalhost() string {
  	l, err := net.ListenTCP("tcp", nil)
  	if err != nil {
  		panic(fmt.Sprintf("ListenTCP error: %s", err))
  	}
  	defer l.Close()
  
  	_, lport, err := net.SplitHostPort(l.Addr().String())
  	if err != nil {
  		panic(fmt.Sprintf("parsing listener address: %s", err))
  	}
  	port, err := strconv.Atoi(lport)
  	if err != nil {
  		panic(fmt.Sprintf("parsing listener port %q: %s", lport, err))
  	}
  
  	// Accept and discard incoming connections. The loop ends once the
  	// deferred Close makes Accept return an error.
  	go func() {
  		for {
  			conn, err := l.Accept()
  			if err != nil {
  				return
  			}
  			conn.Close()
  		}
  	}()
  
  	var dialErrs []error
  	for _, af := range []string{"tcp6", "tcp4"} {
  		conn, err := net.DialTCP(af, &net.TCPAddr{}, &net.TCPAddr{Port: port})
  		if err != nil {
  			dialErrs = append(dialErrs, err)
  			continue
  		}
  		// Capture the local address while the connection is still open.
  		laddr := conn.LocalAddr().String()
  		conn.Close()
  		host, _, err := net.SplitHostPort(laddr)
  		if err != nil {
  			panic(fmt.Sprintf("parsing local address %q: %s", laddr, err))
  		}
  		return host
  	}
  	panic(fmt.Sprintf("could not determine address family: %v", dialErrs))
  }
  49. func localSystem(c *net.UDPConn) string {
  50. _, port, _ := net.SplitHostPort(c.LocalAddr().String())
  51. return net.JoinHostPort(localhost, port)
  52. }
  53. func TestPackUnpack(t *testing.T) {
  54. v := []string{"test-filename/with-subdir"}
  55. testOptsList := []options{
  56. nil,
  57. options{
  58. "tsize": "1234",
  59. "blksize": "22",
  60. },
  61. }
  62. for _, filename := range v {
  63. for _, mode := range []string{"octet", "netascii"} {
  64. for _, opts := range testOptsList {
  65. packUnpack(t, filename, mode, opts)
  66. }
  67. }
  68. }
  69. }
  70. func packUnpack(t *testing.T, filename, mode string, opts options) {
  71. b := make([]byte, datagramLength)
  72. for _, op := range []uint16{opRRQ, opWRQ} {
  73. n := packRQ(b, op, filename, mode, opts)
  74. f, m, o, err := unpackRQ(b[:n])
  75. if err != nil {
  76. t.Errorf("%s pack/unpack: %v", filename, err)
  77. }
  78. if f != filename {
  79. t.Errorf("filename mismatch (%s): '%x' vs '%x'",
  80. filename, f, filename)
  81. }
  82. if m != mode {
  83. t.Errorf("mode mismatch (%s): '%x' vs '%x'",
  84. mode, m, mode)
  85. }
  86. if opts != nil {
  87. for name, value := range opts {
  88. v, ok := o[name]
  89. if !ok {
  90. t.Errorf("missing %s option", name)
  91. }
  92. if v != value {
  93. t.Errorf("option %s mismatch: '%x' vs '%x'", name, v, value)
  94. }
  95. }
  96. }
  97. }
  98. }
  99. func TestZeroLength(t *testing.T) {
  100. s, c := makeTestServer()
  101. defer s.Shutdown()
  102. testSendReceive(t, c, 0)
  103. }
  104. func Test900(t *testing.T) {
  105. s, c := makeTestServer()
  106. defer s.Shutdown()
  107. for i := 600; i < 4000; i += 1 {
  108. c.blksize = i
  109. testSendReceive(t, c, 9000+int64(i))
  110. }
  111. }
  112. func Test1000(t *testing.T) {
  113. s, c := makeTestServer()
  114. defer s.Shutdown()
  115. for i := int64(0); i < 5000; i++ {
  116. filename := fmt.Sprintf("length-%d-bytes-%d", i, time.Now().UnixNano())
  117. rf, err := c.Send(filename, "octet")
  118. if err != nil {
  119. t.Fatalf("requesting %s write: %v", filename, err)
  120. }
  121. r := io.LimitReader(newRandReader(rand.NewSource(i)), i)
  122. n, err := rf.ReadFrom(r)
  123. if err != nil {
  124. t.Fatalf("sending %s: %v", filename, err)
  125. }
  126. if n != i {
  127. t.Errorf("%s length mismatch: %d != %d", filename, n, i)
  128. }
  129. }
  130. }
  131. func Test1810(t *testing.T) {
  132. s, c := makeTestServer()
  133. defer s.Shutdown()
  134. c.blksize = 1810
  135. testSendReceive(t, c, 9000+1810)
  136. }
  137. func TestTSize(t *testing.T) {
  138. s, c := makeTestServer()
  139. defer s.Shutdown()
  140. c.tsize = true
  141. testSendReceive(t, c, 640)
  142. }
  143. func TestNearBlockLength(t *testing.T) {
  144. s, c := makeTestServer()
  145. defer s.Shutdown()
  146. for i := 450; i < 520; i++ {
  147. testSendReceive(t, c, int64(i))
  148. }
  149. }
  150. func TestBlockWrapsAround(t *testing.T) {
  151. s, c := makeTestServer()
  152. defer s.Shutdown()
  153. n := 65535 * 512
  154. for i := n - 2; i < n+2; i++ {
  155. testSendReceive(t, c, int64(i))
  156. }
  157. }
  158. func TestRandomLength(t *testing.T) {
  159. s, c := makeTestServer()
  160. defer s.Shutdown()
  161. r := rand.New(rand.NewSource(42))
  162. for i := 0; i < 100; i++ {
  163. testSendReceive(t, c, r.Int63n(100000))
  164. }
  165. }
  166. func TestBigFile(t *testing.T) {
  167. s, c := makeTestServer()
  168. defer s.Shutdown()
  169. testSendReceive(t, c, 3*1000*1000)
  170. }
  171. func TestByOneByte(t *testing.T) {
  172. s, c := makeTestServer()
  173. defer s.Shutdown()
  174. filename := "test-by-one-byte"
  175. mode := "octet"
  176. const length = 80000
  177. sender, err := c.Send(filename, mode)
  178. if err != nil {
  179. t.Fatalf("requesting write: %v", err)
  180. }
  181. r := iotest.OneByteReader(io.LimitReader(
  182. newRandReader(rand.NewSource(42)), length))
  183. n, err := sender.ReadFrom(r)
  184. if err != nil {
  185. t.Fatalf("send error: %v", err)
  186. }
  187. if n != length {
  188. t.Errorf("%s read length mismatch: %d != %d", filename, n, length)
  189. }
  190. readTransfer, err := c.Receive(filename, mode)
  191. if err != nil {
  192. t.Fatalf("requesting read %s: %v", filename, err)
  193. }
  194. buf := &bytes.Buffer{}
  195. n, err = readTransfer.WriteTo(buf)
  196. if err != nil {
  197. t.Fatalf("%s read error: %v", filename, err)
  198. }
  199. if n != length {
  200. t.Errorf("%s read length mismatch: %d != %d", filename, n, length)
  201. }
  202. bs, _ := ioutil.ReadAll(io.LimitReader(
  203. newRandReader(rand.NewSource(42)), length))
  204. if !bytes.Equal(bs, buf.Bytes()) {
  205. t.Errorf("\nsent: %x\nrcvd: %x", bs, buf)
  206. }
  207. }
  208. func TestDuplicate(t *testing.T) {
  209. s, c := makeTestServer()
  210. defer s.Shutdown()
  211. filename := "test-duplicate"
  212. mode := "octet"
  213. bs := []byte("lalala")
  214. sender, err := c.Send(filename, mode)
  215. if err != nil {
  216. t.Fatalf("requesting write: %v", err)
  217. }
  218. buf := bytes.NewBuffer(bs)
  219. _, err = sender.ReadFrom(buf)
  220. if err != nil {
  221. t.Fatalf("send error: %v", err)
  222. }
  223. sender, err = c.Send(filename, mode)
  224. if err == nil {
  225. t.Fatalf("file already exists")
  226. }
  227. t.Logf("sending file that already exists: %v", err)
  228. }
  229. func TestNotFound(t *testing.T) {
  230. s, c := makeTestServer()
  231. defer s.Shutdown()
  232. filename := "test-not-exists"
  233. mode := "octet"
  234. _, err := c.Receive(filename, mode)
  235. if err == nil {
  236. t.Fatalf("file not exists", err)
  237. }
  238. t.Logf("receiving file that does not exist: %v", err)
  239. }
  240. func testSendReceive(t *testing.T, client *Client, length int64) {
  241. filename := fmt.Sprintf("length-%d-bytes", length)
  242. mode := "octet"
  243. writeTransfer, err := client.Send(filename, mode)
  244. if err != nil {
  245. t.Fatalf("requesting write %s: %v", filename, err)
  246. }
  247. r := io.LimitReader(newRandReader(rand.NewSource(42)), length)
  248. n, err := writeTransfer.ReadFrom(r)
  249. if err != nil {
  250. t.Fatalf("%s write error: %v", filename, err)
  251. }
  252. if n != length {
  253. t.Errorf("%s write length mismatch: %d != %d", filename, n, length)
  254. }
  255. readTransfer, err := client.Receive(filename, mode)
  256. if err != nil {
  257. t.Fatalf("requesting read %s: %v", filename, err)
  258. }
  259. if it, ok := readTransfer.(IncomingTransfer); ok {
  260. if n, ok := it.Size(); ok {
  261. fmt.Printf("Transfer size: %d\n", n)
  262. if n != length {
  263. t.Errorf("tsize mismatch: %d vs %d", n, length)
  264. }
  265. }
  266. }
  267. buf := &bytes.Buffer{}
  268. n, err = readTransfer.WriteTo(buf)
  269. if err != nil {
  270. t.Fatalf("%s read error: %v", filename, err)
  271. }
  272. if n != length {
  273. t.Errorf("%s read length mismatch: %d != %d", filename, n, length)
  274. }
  275. bs, _ := ioutil.ReadAll(io.LimitReader(
  276. newRandReader(rand.NewSource(42)), length))
  277. if !bytes.Equal(bs, buf.Bytes()) {
  278. t.Errorf("\nsent: %x\nrcvd: %x", bs, buf)
  279. }
  280. }
  281. func TestSendTsizeFromSeek(t *testing.T) {
  282. // create read-only server
  283. s := NewServer(func(filename string, rf io.ReaderFrom) error {
  284. b := make([]byte, 100)
  285. rr := newRandReader(rand.NewSource(42))
  286. rr.Read(b)
  287. // bytes.Reader implements io.Seek
  288. r := bytes.NewReader(b)
  289. _, err := rf.ReadFrom(r)
  290. if err != nil {
  291. t.Errorf("sending bytes: %v", err)
  292. }
  293. return nil
  294. }, nil)
  295. conn, err := net.ListenUDP("udp", &net.UDPAddr{})
  296. if err != nil {
  297. t.Fatalf("listening: %v", err)
  298. }
  299. go s.Serve(conn)
  300. defer s.Shutdown()
  301. c, _ := NewClient(localSystem(conn))
  302. c.tsize = true
  303. r, _ := c.Receive("f", "octet")
  304. var size int64
  305. if t, ok := r.(IncomingTransfer); ok {
  306. if n, ok := t.Size(); ok {
  307. size = n
  308. fmt.Printf("Transfer size: %d\n", n)
  309. }
  310. }
  311. if size != 100 {
  312. t.Errorf("size expected: 100, got %d", size)
  313. }
  314. r.WriteTo(ioutil.Discard)
  315. }
  // testBackend is an in-memory file store whose handleRead and handleWrite
  // methods serve as the test server's handlers.
  type testBackend struct {
  	mu sync.Mutex        // guards m
  	m  map[string][]byte // file contents keyed by filename
  }
  320. func makeTestServer() (*Server, *Client) {
  321. b := &testBackend{}
  322. b.m = make(map[string][]byte)
  323. // Create server
  324. s := NewServer(b.handleRead, b.handleWrite)
  325. conn, err := net.ListenUDP("udp", &net.UDPAddr{})
  326. if err != nil {
  327. panic(err)
  328. }
  329. go s.Serve(conn)
  330. // Create client for that server
  331. c, err := NewClient(localSystem(conn))
  332. if err != nil {
  333. panic(err)
  334. }
  335. return s, c
  336. }
  337. func TestNoHandlers(t *testing.T) {
  338. s := NewServer(nil, nil)
  339. conn, err := net.ListenUDP("udp", &net.UDPAddr{})
  340. if err != nil {
  341. panic(err)
  342. }
  343. go s.Serve(conn)
  344. c, err := NewClient(localSystem(conn))
  345. if err != nil {
  346. panic(err)
  347. }
  348. _, err = c.Send("test", "octet")
  349. if err == nil {
  350. t.Errorf("error expected")
  351. }
  352. _, err = c.Receive("test", "octet")
  353. if err == nil {
  354. t.Errorf("error expected")
  355. }
  356. }
  357. func (b *testBackend) handleWrite(filename string, wt io.WriterTo) error {
  358. b.mu.Lock()
  359. defer b.mu.Unlock()
  360. _, ok := b.m[filename]
  361. if ok {
  362. fmt.Fprintf(os.Stderr, "File %s already exists\n", filename)
  363. return fmt.Errorf("file already exists")
  364. }
  365. if t, ok := wt.(IncomingTransfer); ok {
  366. if n, ok := t.Size(); ok {
  367. fmt.Printf("Transfer size: %d\n", n)
  368. }
  369. }
  370. buf := &bytes.Buffer{}
  371. _, err := wt.WriteTo(buf)
  372. if err != nil {
  373. fmt.Fprintf(os.Stderr, "Can't receive %s: %v\n", filename, err)
  374. return err
  375. }
  376. b.m[filename] = buf.Bytes()
  377. return nil
  378. }
  379. func (b *testBackend) handleRead(filename string, rf io.ReaderFrom) error {
  380. b.mu.Lock()
  381. defer b.mu.Unlock()
  382. bs, ok := b.m[filename]
  383. if !ok {
  384. fmt.Fprintf(os.Stderr, "File %s not found\n", filename)
  385. return fmt.Errorf("file not found")
  386. }
  387. if t, ok := rf.(OutgoingTransfer); ok {
  388. t.SetSize(int64(len(bs)))
  389. }
  390. _, err := rf.ReadFrom(bytes.NewBuffer(bs))
  391. if err != nil {
  392. fmt.Fprintf(os.Stderr, "Can't send %s: %v\n", filename, err)
  393. return err
  394. }
  395. return nil
  396. }
  // randReader is a deterministic, never-failing io.Reader that turns a
  // rand.Source into a byte stream. Each Int63 value supplies seven bytes,
  // least-significant byte first. Its top seven bits are discarded.
  type randReader struct {
  	src  rand.Source
  	next int64 // unread bits of the current Int63 value
  	i    int8  // bytes already taken from the current value (0..7)
  }
  
  // newRandReader returns a randReader over src, drawing the first value
  // immediately.
  func newRandReader(src rand.Source) io.Reader {
  	return &randReader{src: src, next: src.Int63()}
  }
  
  // Read fills all of p and always returns len(p), nil.
  func (r *randReader) Read(p []byte) (n int, err error) {
  	for k := range p {
  		if r.i == 7 {
  			r.next = r.src.Int63()
  			r.i = 0
  		}
  		p[k] = byte(r.next)
  		r.next >>= 8
  		r.i++
  	}
  	return len(p), nil
  }
  422. func TestServerSendTimeout(t *testing.T) {
  423. s, c := makeTestServer()
  424. s.SetTimeout(time.Second)
  425. s.SetRetries(2)
  426. var serverErr error
  427. s.readHandler = func(filename string, rf io.ReaderFrom) error {
  428. r := io.LimitReader(newRandReader(rand.NewSource(42)), 80000)
  429. _, serverErr = rf.ReadFrom(r)
  430. return serverErr
  431. }
  432. defer s.Shutdown()
  433. filename := "test-server-send-timeout"
  434. mode := "octet"
  435. readTransfer, err := c.Receive(filename, mode)
  436. if err != nil {
  437. t.Fatalf("requesting read %s: %v", filename, err)
  438. }
  439. w := &slowWriter{
  440. n: 3,
  441. delay: 8 * time.Second,
  442. }
  443. _, _ = readTransfer.WriteTo(w)
  444. netErr, ok := serverErr.(net.Error)
  445. if !ok {
  446. t.Fatalf("network error expected: %T", serverErr)
  447. }
  448. if !netErr.Timeout() {
  449. t.Fatalf("timout is expected: %v", serverErr)
  450. }
  451. }
  452. func TestServerReceiveTimeout(t *testing.T) {
  453. s, c := makeTestServer()
  454. s.SetTimeout(time.Second)
  455. s.SetRetries(2)
  456. var serverErr error
  457. s.writeHandler = func(filename string, wt io.WriterTo) error {
  458. buf := &bytes.Buffer{}
  459. _, serverErr = wt.WriteTo(buf)
  460. return serverErr
  461. }
  462. defer s.Shutdown()
  463. filename := "test-server-receive-timeout"
  464. mode := "octet"
  465. writeTransfer, err := c.Send(filename, mode)
  466. if err != nil {
  467. t.Fatalf("requesting write %s: %v", filename, err)
  468. }
  469. r := &slowReader{
  470. r: io.LimitReader(newRandReader(rand.NewSource(42)), 80000),
  471. n: 3,
  472. delay: 8 * time.Second,
  473. }
  474. _, _ = writeTransfer.ReadFrom(r)
  475. netErr, ok := serverErr.(net.Error)
  476. if !ok {
  477. t.Fatalf("network error expected: %T", serverErr)
  478. }
  479. if !netErr.Timeout() {
  480. t.Fatalf("timout is expected: %v", serverErr)
  481. }
  482. }
  483. func TestClientReceiveTimeout(t *testing.T) {
  484. s, c := makeTestServer()
  485. c.SetTimeout(time.Second)
  486. c.SetRetries(2)
  487. s.readHandler = func(filename string, rf io.ReaderFrom) error {
  488. r := &slowReader{
  489. r: io.LimitReader(newRandReader(rand.NewSource(42)), 80000),
  490. n: 3,
  491. delay: 8 * time.Second,
  492. }
  493. _, err := rf.ReadFrom(r)
  494. return err
  495. }
  496. defer s.Shutdown()
  497. filename := "test-client-receive-timeout"
  498. mode := "octet"
  499. readTransfer, err := c.Receive(filename, mode)
  500. if err != nil {
  501. t.Fatalf("requesting read %s: %v", filename, err)
  502. }
  503. buf := &bytes.Buffer{}
  504. _, err = readTransfer.WriteTo(buf)
  505. netErr, ok := err.(net.Error)
  506. if !ok {
  507. t.Fatalf("network error expected: %T", err)
  508. }
  509. if !netErr.Timeout() {
  510. t.Fatalf("timout is expected: %v", err)
  511. }
  512. }
  513. func TestClientSendTimeout(t *testing.T) {
  514. s, c := makeTestServer()
  515. c.SetTimeout(time.Second)
  516. c.SetRetries(2)
  517. s.writeHandler = func(filename string, wt io.WriterTo) error {
  518. w := &slowWriter{
  519. n: 3,
  520. delay: 8 * time.Second,
  521. }
  522. _, err := wt.WriteTo(w)
  523. return err
  524. }
  525. defer s.Shutdown()
  526. filename := "test-client-send-timeout"
  527. mode := "octet"
  528. writeTransfer, err := c.Send(filename, mode)
  529. if err != nil {
  530. t.Fatalf("requesting write %s: %v", filename, err)
  531. }
  532. r := io.LimitReader(newRandReader(rand.NewSource(42)), 80000)
  533. _, err = writeTransfer.ReadFrom(r)
  534. netErr, ok := err.(net.Error)
  535. if !ok {
  536. t.Fatalf("network error expected: %T", err)
  537. }
  538. if !netErr.Timeout() {
  539. t.Fatalf("timout is expected: %v", err)
  540. }
  541. }
  // slowReader wraps r. Its first n Read calls pass straight through; every
  // later call sleeps for delay before delegating to r.
  type slowReader struct {
  	r     io.Reader
  	n     int64         // remaining fast reads
  	delay time.Duration // pause before each read once n is exhausted
  }
  
  func (r *slowReader) Read(p []byte) (n int, err error) {
  	if r.n <= 0 {
  		time.Sleep(r.delay)
  	} else {
  		r.n--
  	}
  	return r.r.Read(p)
  }
  // slowWriter is an io.Writer that discards its input. Its first n Write calls
  // return immediately; every later call sleeps for delay first. It always
  // reports the whole of p as written.
  type slowWriter struct {
  	n     int64         // remaining fast writes
  	delay time.Duration // pause before each write once n is exhausted
  }
  
  func (w *slowWriter) Write(p []byte) (n int, err error) {
  	if w.n > 0 {
  		w.n--
  		return len(p), nil
  	}
  	time.Sleep(w.delay)
  	return len(p), nil
  }