|
|
@@ -2,19 +2,23 @@ package dtls
|
|
|
|
|
|
import (
|
|
|
"bytes"
|
|
|
- "net"
|
|
|
+ "errors"
|
|
|
"sync/atomic"
|
|
|
"time"
|
|
|
)
|
|
|
|
|
|
-var maxMessageSize = 65535
|
|
|
+var ErrInsufficientBuffer = errors.New("buffer too small to hold the received data")
|
|
|
+
|
|
|
+const recvChBufSize = 64
|
|
|
|
|
|
type hbConn struct {
|
|
|
- conn net.Conn
|
|
|
- recvCh chan errBytes
|
|
|
- waiting uint32
|
|
|
- hb []byte
|
|
|
- timeout time.Duration
|
|
|
+ stream msgStream
|
|
|
+
|
|
|
+ recvCh chan errBytes
|
|
|
+ waiting uint32
|
|
|
+ hb []byte
|
|
|
+ timeout time.Duration
|
|
|
+ maxMessageSize int
|
|
|
}
|
|
|
|
|
|
type errBytes struct {
|
|
|
@@ -23,13 +27,14 @@ type errBytes struct {
|
|
|
}
|
|
|
|
|
|
// heartbeatServer listens for heartbeat over conn with config
|
|
|
-func heartbeatServer(conn net.Conn, config *heartbeatConfig) (net.Conn, error) {
|
|
|
+func heartbeatServer(stream msgStream, config *heartbeatConfig, maxMessageSize int) (*hbConn, error) {
|
|
|
conf := validate(config)
|
|
|
|
|
|
- c := &hbConn{conn: conn,
|
|
|
- recvCh: make(chan errBytes),
|
|
|
- timeout: conf.Interval,
|
|
|
- hb: conf.Heartbeat,
|
|
|
+ c := &hbConn{stream: stream,
|
|
|
+ recvCh: make(chan errBytes, recvChBufSize),
|
|
|
+ timeout: conf.Interval,
|
|
|
+ hb: conf.Heartbeat,
|
|
|
+ maxMessageSize: maxMessageSize,
|
|
|
}
|
|
|
|
|
|
atomic.StoreUint32(&c.waiting, 2)
|
|
|
@@ -43,7 +48,7 @@ func heartbeatServer(conn net.Conn, config *heartbeatConfig) (net.Conn, error) {
|
|
|
func (c *hbConn) hbLoop() {
|
|
|
for {
|
|
|
if atomic.LoadUint32(&c.waiting) == 0 {
|
|
|
- c.conn.Close()
|
|
|
+ c.stream.Close()
|
|
|
return
|
|
|
}
|
|
|
|
|
|
@@ -55,58 +60,65 @@ func (c *hbConn) hbLoop() {
|
|
|
|
|
|
func (c *hbConn) recvLoop() {
|
|
|
for {
|
|
|
- // create a buffer to hold your data
|
|
|
- buffer := make([]byte, maxMessageSize)
|
|
|
+ buffer := make([]byte, c.maxMessageSize)
|
|
|
|
|
|
- n, err := c.conn.Read(buffer)
|
|
|
+ n, err := c.stream.Read(buffer)
|
|
|
|
|
|
if bytes.Equal(c.hb, buffer[:n]) {
|
|
|
atomic.AddUint32(&c.waiting, 1)
|
|
|
continue
|
|
|
}
|
|
|
|
|
|
+ if err != nil {
|
|
|
+ c.recvCh <- errBytes{nil, err}
|
|
|
+ }
|
|
|
+
|
|
|
c.recvCh <- errBytes{buffer[:n], err}
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
func (c *hbConn) Close() error {
|
|
|
- return c.conn.Close()
|
|
|
+ return c.stream.Close()
|
|
|
}
|
|
|
|
|
|
func (c *hbConn) Write(b []byte) (n int, err error) {
|
|
|
- return c.conn.Write(b)
|
|
|
+ return c.stream.Write(b)
|
|
|
}
|
|
|
|
|
|
-func (c *hbConn) Read(b []byte) (n int, err error) {
|
|
|
+func (c *hbConn) Read(b []byte) (int, error) {
|
|
|
readBytes := <-c.recvCh
|
|
|
- copy(b, readBytes.b)
|
|
|
+ if readBytes.err != nil {
|
|
|
+ return 0, readBytes.err
|
|
|
+ }
|
|
|
|
|
|
- return len(readBytes.b), readBytes.err
|
|
|
-}
|
|
|
+ if len(b) < len(readBytes.b) {
|
|
|
+ return 0, ErrInsufficientBuffer
|
|
|
+ }
|
|
|
+
|
|
|
+ n := copy(b, readBytes.b)
|
|
|
|
|
|
-func (c *hbConn) LocalAddr() net.Addr {
|
|
|
- return c.conn.LocalAddr()
|
|
|
+ return n, nil
|
|
|
}
|
|
|
|
|
|
-func (c *hbConn) RemoteAddr() net.Addr {
|
|
|
- return c.conn.RemoteAddr()
|
|
|
+func (c *hbConn) BufferedAmount() uint64 {
|
|
|
+ return c.stream.BufferedAmount()
|
|
|
}
|
|
|
|
|
|
-func (c *hbConn) SetDeadline(t time.Time) error {
|
|
|
- return c.conn.SetDeadline(t)
|
|
|
+func (c *hbConn) SetReadDeadline(deadline time.Time) error {
|
|
|
+ return c.stream.SetReadDeadline(deadline)
|
|
|
}
|
|
|
|
|
|
-func (c *hbConn) SetReadDeadline(t time.Time) error {
|
|
|
- return c.conn.SetReadDeadline(t)
|
|
|
+func (c *hbConn) SetBufferedAmountLowThreshold(th uint64) {
|
|
|
+ c.stream.SetBufferedAmountLowThreshold(th)
|
|
|
}
|
|
|
|
|
|
-func (c *hbConn) SetWriteDeadline(t time.Time) error {
|
|
|
- return c.conn.SetWriteDeadline(t)
|
|
|
+func (c *hbConn) OnBufferedAmountLow(f func()) {
|
|
|
+ c.stream.OnBufferedAmountLow(f)
|
|
|
}
|
|
|
|
|
|
// heartbeatClient sends heartbeats over conn with config
|
|
|
-func heartbeatClient(conn net.Conn, config *heartbeatConfig) error {
|
|
|
+func heartbeatClient(conn msgStream, config *heartbeatConfig) error {
|
|
|
conf := validate(config)
|
|
|
go func() {
|
|
|
for {
|