|
|
@@ -58,7 +58,9 @@ func ListenTLSTunnel(
|
|
|
return nil, errors.Trace(err)
|
|
|
}
|
|
|
|
|
|
- return tris.NewListener(server.listener, server.tlsConfig), nil
|
|
|
+ listener = tris.NewListener(server.listener, server.tlsConfig)
|
|
|
+
|
|
|
+ return NewTLSTunnelListener(listener, server), nil
|
|
|
}
|
|
|
|
|
|
// NewTLSTunnelServer initializes a new TLSTunnelServer.
|
|
|
@@ -195,3 +197,60 @@ func (server *TLSTunnelServer) makeTLSTunnelConfig() (*tris.Config, error) {
|
|
|
|
|
|
return config, nil
|
|
|
}
|
|
|
+
|
|
|
+// TLSTunnelListener implements the net.Listener interface. Accept returns a
|
|
|
+// net.Conn which implements the common.MetricsSource interface.
|
|
|
+type TLSTunnelListener struct {
|
|
|
+ net.Listener
|
|
|
+ server *TLSTunnelServer
|
|
|
+}
|
|
|
+
|
|
|
+// NewTLSTunnelListener initializes a new TLSTunnelListener.
|
|
|
+func NewTLSTunnelListener(listener net.Listener, server *TLSTunnelServer) *TLSTunnelListener {
|
|
|
+ return &TLSTunnelListener{
|
|
|
+ Listener: listener,
|
|
|
+ server: server,
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func (l *TLSTunnelListener) Accept() (net.Conn, error) {
|
|
|
+ conn, err := l.Listener.Accept()
|
|
|
+ if err != nil {
|
|
|
+ return nil, errors.Trace(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ return NewTLSTunnelConn(conn, l.server), nil
|
|
|
+}
|
|
|
+
|
|
|
+// TLSTunnelConn implements the net.Conn and common.MetricsSource interfaces.
|
|
|
+type TLSTunnelConn struct {
|
|
|
+ net.Conn
|
|
|
+ server *TLSTunnelServer
|
|
|
+}
|
|
|
+
|
|
|
+// NewTLSTunnelConn initializes a new TLSTunnelConn.
|
|
|
+func NewTLSTunnelConn(conn net.Conn, server *TLSTunnelServer) *TLSTunnelConn {
|
|
|
+ return &TLSTunnelConn{
|
|
|
+ Conn: conn,
|
|
|
+ server: server,
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// GetMetrics implements the common.MetricsSource interface.
|
|
|
+func (conn *TLSTunnelConn) GetMetrics() common.LogFields {
|
|
|
+
|
|
|
+ var logFields common.LogFields
|
|
|
+
|
|
|
+ // Relay any metrics from the underlying conn.
|
|
|
+ if m, ok := conn.Conn.(common.MetricsSource); ok {
|
|
|
+ logFields = m.GetMetrics()
|
|
|
+ } else {
|
|
|
+ logFields = make(common.LogFields)
|
|
|
+ }
|
|
|
+
|
|
|
+ if conn.server.passthroughAddress != "" {
|
|
|
+ logFields["passthrough_address"] = conn.server.passthroughAddress
|
|
|
+ }
|
|
|
+
|
|
|
+ return logFields
|
|
|
+}
|