skipReader.go 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. package obfuscator
  2. import (
  3. "bytes"
  4. "io"
  5. "net"
  6. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
  7. )
  // SkipReader wraps a net.Conn so that incoming bytes can be discarded up
  // to and including a token (see SkipUpToToken). Bytes read from the conn
  // past the token are kept in buf and returned by Read before any further
  // reads from the conn.
  type SkipReader struct {
  	net.Conn

  	// buf[offset:end] holds bytes already read from Conn but not yet
  	// returned by Read; offset == end means nothing is buffered.
  	offset, end int
  	buf         []byte
  }
  14. func WrapConnWithSkipReader(conn net.Conn) net.Conn {
  15. return &SkipReader{
  16. Conn: conn,
  17. offset: 0,
  18. end: 0,
  19. buf: nil,
  20. }
  21. }
  22. func (sr *SkipReader) Read(b []byte) (int, error) {
  23. // read buffered bytes first
  24. if sr.offset < sr.end {
  25. n := copy(b, sr.buf[sr.offset:sr.end])
  26. if n == 0 {
  27. // should never happen if len(b) > 0
  28. return 0, errors.TraceNew("read failed")
  29. }
  30. sr.offset += n
  31. // clear resources if all buffered bytes are read
  32. if sr.offset == sr.end {
  33. sr.offset = 0
  34. sr.end = 0
  35. sr.buf = nil
  36. }
  37. return n, nil
  38. }
  39. return sr.Conn.Read(b)
  40. }
  41. // SkipUpToToken reads from the underlying conn initially len(token) bytes,
  42. // and then readSize bytes at a time up to maxSearchSize until token is found,
  43. // or error. If the token is found, stream is rewound to end of the token.
  44. //
  45. // Note that maxSearchSize is not a strict limit on the total number of bytes read.
  46. func (sr *SkipReader) SkipUpToToken(
  47. token []byte, readSize, maxSearchSize int) error {
  48. if len(token) == 0 {
  49. return nil
  50. }
  51. if readSize < 1 {
  52. return errors.TraceNew("readSize too small")
  53. }
  54. if maxSearchSize < readSize {
  55. return errors.TraceNew("maxSearchSize too small")
  56. }
  57. sr.offset = 0
  58. sr.end = 0
  59. sr.buf = make([]byte, readSize+len(token))
  60. // Reads at least len(token) bytes.
  61. nTotal, err := io.ReadFull(sr.Conn, sr.buf[:len(token)])
  62. if err == io.ErrUnexpectedEOF {
  63. return errors.TraceNew("token not found")
  64. }
  65. if err != nil {
  66. return err
  67. }
  68. if bytes.Equal(sr.buf[:len(token)], token) {
  69. return nil
  70. }
  71. for nTotal < maxSearchSize {
  72. // The underlying conn is read into buf[len(token):].
  73. // buf[:len(token)] stores bytes from the previous read.
  74. n, err := sr.Conn.Read(sr.buf[len(token):])
  75. if err != nil && err != io.EOF {
  76. return err
  77. }
  78. if idx := bytes.Index(sr.buf[:n+len(token)], token); idx != -1 {
  79. // Found match, sets offset and end for next Read to start after the token.
  80. sr.offset = idx + len(token)
  81. sr.end = n + len(token)
  82. return err
  83. }
  84. if err == io.EOF {
  85. // Reached the end of stream, token not found.
  86. return errors.TraceNew("token not found")
  87. }
  88. // Copies last len(token) bytes to the beginning of the buffer.
  89. copy(sr.buf, sr.buf[n:n+len(token)])
  90. nTotal += n
  91. }
  92. return errors.TraceNew("exceeded max search size")
  93. }