Integrate BACKBEAT SDK and resolve KACHING license validation

Major integrations and fixes:
- Added BACKBEAT SDK integration for P2P operation timing
- Implemented beat-aware status tracking for distributed operations
- Added Docker secrets support for secure license management
- Resolved KACHING license validation via HTTPS/TLS
- Updated docker-compose configuration for clean stack deployment
- Disabled rollback policies to prevent deployment failures
- Added license credential storage (CHORUS-DEV-MULTI-001)

Technical improvements:
- BACKBEAT P2P operation tracking with phase management
- Enhanced configuration system with file-based secrets
- Improved error handling for license validation
- Clean separation of KACHING and CHORUS deployment stacks

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
anthonyrawlins
2025-09-06 07:56:26 +10:00
parent 543ab216f9
commit 9bdcbe0447
4730 changed files with 1480093 additions and 1916 deletions

View File

@@ -0,0 +1,91 @@
package libp2pquic
import (
"context"
ic "github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
tpt "github.com/libp2p/go-libp2p/core/transport"
ma "github.com/multiformats/go-multiaddr"
"github.com/quic-go/quic-go"
)
type conn struct {
quicConn quic.Connection
transport *transport
scope network.ConnManagementScope
localPeer peer.ID
localMultiaddr ma.Multiaddr
remotePeerID peer.ID
remotePubKey ic.PubKey
remoteMultiaddr ma.Multiaddr
}
var _ tpt.CapableConn = &conn{}
// Close closes the connection.
// It must be called even if the peer closed the connection in order for
// garbage collection to properly work in this package.
func (c *conn) Close() error {
return c.closeWithError(0, "")
}
func (c *conn) closeWithError(errCode quic.ApplicationErrorCode, errString string) error {
c.transport.removeConn(c.quicConn)
err := c.quicConn.CloseWithError(errCode, errString)
c.scope.Done()
return err
}
// IsClosed returns whether a connection is fully closed.
func (c *conn) IsClosed() bool {
return c.quicConn.Context().Err() != nil
}
func (c *conn) allowWindowIncrease(size uint64) bool {
return c.scope.ReserveMemory(int(size), network.ReservationPriorityMedium) == nil
}
// OpenStream creates a new stream.
func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) {
qstr, err := c.quicConn.OpenStreamSync(ctx)
return &stream{Stream: qstr}, err
}
// AcceptStream accepts a stream opened by the other side.
func (c *conn) AcceptStream() (network.MuxedStream, error) {
qstr, err := c.quicConn.AcceptStream(context.Background())
return &stream{Stream: qstr}, err
}
// LocalPeer returns our peer ID
func (c *conn) LocalPeer() peer.ID { return c.localPeer }
// RemotePeer returns the peer ID of the remote peer.
func (c *conn) RemotePeer() peer.ID { return c.remotePeerID }
// RemotePublicKey returns the public key of the remote peer.
func (c *conn) RemotePublicKey() ic.PubKey { return c.remotePubKey }
// LocalMultiaddr returns the local Multiaddr associated
func (c *conn) LocalMultiaddr() ma.Multiaddr { return c.localMultiaddr }
// RemoteMultiaddr returns the remote Multiaddr associated
func (c *conn) RemoteMultiaddr() ma.Multiaddr { return c.remoteMultiaddr }
func (c *conn) Transport() tpt.Transport { return c.transport }
func (c *conn) Scope() network.ConnScope { return c.scope }
// ConnState is the state of security connection.
func (c *conn) ConnState() network.ConnectionState {
t := "quic-v1"
if _, err := c.LocalMultiaddr().ValueForProtocol(ma.P_QUIC); err == nil {
t = "quic"
}
return network.ConnectionState{Transport: t}
}

View File

@@ -0,0 +1,147 @@
package libp2pquic
import (
"context"
"errors"
"net"
ic "github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
tpt "github.com/libp2p/go-libp2p/core/transport"
p2ptls "github.com/libp2p/go-libp2p/p2p/security/tls"
"github.com/libp2p/go-libp2p/p2p/transport/quicreuse"
ma "github.com/multiformats/go-multiaddr"
"github.com/quic-go/quic-go"
)
// A listener listens for QUIC connections.
type listener struct {
reuseListener quicreuse.Listener
transport *transport
rcmgr network.ResourceManager
privKey ic.PrivKey
localPeer peer.ID
localMultiaddrs map[quic.VersionNumber]ma.Multiaddr
}
func newListener(ln quicreuse.Listener, t *transport, localPeer peer.ID, key ic.PrivKey, rcmgr network.ResourceManager) (listener, error) {
localMultiaddrs := make(map[quic.VersionNumber]ma.Multiaddr)
for _, addr := range ln.Multiaddrs() {
if _, err := addr.ValueForProtocol(ma.P_QUIC_V1); err == nil {
localMultiaddrs[quic.Version1] = addr
}
}
return listener{
reuseListener: ln,
transport: t,
rcmgr: rcmgr,
privKey: key,
localPeer: localPeer,
localMultiaddrs: localMultiaddrs,
}, nil
}
// Accept accepts new connections.
func (l *listener) Accept() (tpt.CapableConn, error) {
for {
qconn, err := l.reuseListener.Accept(context.Background())
if err != nil {
return nil, err
}
c, err := l.setupConn(qconn)
if err != nil {
continue
}
l.transport.addConn(qconn, c)
if l.transport.gater != nil && !(l.transport.gater.InterceptAccept(c) && l.transport.gater.InterceptSecured(network.DirInbound, c.remotePeerID, c)) {
c.closeWithError(errorCodeConnectionGating, "connection gated")
continue
}
// return through active hole punching if any
key := holePunchKey{addr: qconn.RemoteAddr().String(), peer: c.remotePeerID}
var wasHolePunch bool
l.transport.holePunchingMx.Lock()
holePunch, ok := l.transport.holePunching[key]
if ok && !holePunch.fulfilled {
holePunch.connCh <- c
wasHolePunch = true
holePunch.fulfilled = true
}
l.transport.holePunchingMx.Unlock()
if wasHolePunch {
continue
}
return c, nil
}
}
func (l *listener) setupConn(qconn quic.Connection) (*conn, error) {
remoteMultiaddr, err := quicreuse.ToQuicMultiaddr(qconn.RemoteAddr(), qconn.ConnectionState().Version)
if err != nil {
return nil, err
}
connScope, err := l.rcmgr.OpenConnection(network.DirInbound, false, remoteMultiaddr)
if err != nil {
log.Debugw("resource manager blocked incoming connection", "addr", qconn.RemoteAddr(), "error", err)
return nil, err
}
c, err := l.setupConnWithScope(qconn, connScope, remoteMultiaddr)
if err != nil {
connScope.Done()
qconn.CloseWithError(1, "")
return nil, err
}
return c, nil
}
func (l *listener) setupConnWithScope(qconn quic.Connection, connScope network.ConnManagementScope, remoteMultiaddr ma.Multiaddr) (*conn, error) {
// The tls.Config used to establish this connection already verified the certificate chain.
// Since we don't have any way of knowing which tls.Config was used though,
// we have to re-determine the peer's identity here.
// Therefore, this is expected to never fail.
remotePubKey, err := p2ptls.PubKeyFromCertChain(qconn.ConnectionState().TLS.PeerCertificates)
if err != nil {
return nil, err
}
remotePeerID, err := peer.IDFromPublicKey(remotePubKey)
if err != nil {
return nil, err
}
if err := connScope.SetPeer(remotePeerID); err != nil {
log.Debugw("resource manager blocked incoming connection for peer", "peer", remotePeerID, "addr", qconn.RemoteAddr(), "error", err)
return nil, err
}
localMultiaddr, found := l.localMultiaddrs[qconn.ConnectionState().Version]
if !found {
return nil, errors.New("unknown QUIC version:" + qconn.ConnectionState().Version.String())
}
return &conn{
quicConn: qconn,
transport: l.transport,
scope: connScope,
localPeer: l.localPeer,
localMultiaddr: localMultiaddr,
remoteMultiaddr: remoteMultiaddr,
remotePeerID: remotePeerID,
remotePubKey: remotePubKey,
}, nil
}
// Close closes the listener.
func (l *listener) Close() error {
return l.reuseListener.Close()
}
// Addr returns the address of this listener.
func (l *listener) Addr() net.Addr {
return l.reuseListener.Addr()
}

View File

@@ -0,0 +1,55 @@
package libp2pquic
import (
"errors"
"github.com/libp2p/go-libp2p/core/network"
"github.com/quic-go/quic-go"
)
const (
reset quic.StreamErrorCode = 0
)
type stream struct {
quic.Stream
}
var _ network.MuxedStream = &stream{}
func (s *stream) Read(b []byte) (n int, err error) {
n, err = s.Stream.Read(b)
if err != nil && errors.Is(err, &quic.StreamError{}) {
err = network.ErrReset
}
return n, err
}
func (s *stream) Write(b []byte) (n int, err error) {
n, err = s.Stream.Write(b)
if err != nil && errors.Is(err, &quic.StreamError{}) {
err = network.ErrReset
}
return n, err
}
func (s *stream) Reset() error {
s.Stream.CancelRead(reset)
s.Stream.CancelWrite(reset)
return nil
}
func (s *stream) Close() error {
s.Stream.CancelRead(reset)
return s.Stream.Close()
}
func (s *stream) CloseRead() error {
s.Stream.CancelRead(reset)
return nil
}
func (s *stream) CloseWrite() error {
return s.Stream.Close()
}

View File

@@ -0,0 +1,399 @@
package libp2pquic
import (
"context"
"crypto/tls"
"errors"
"fmt"
"math/rand"
"net"
"sync"
"time"
"github.com/libp2p/go-libp2p/core/connmgr"
ic "github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/pnet"
tpt "github.com/libp2p/go-libp2p/core/transport"
p2ptls "github.com/libp2p/go-libp2p/p2p/security/tls"
"github.com/libp2p/go-libp2p/p2p/transport/quicreuse"
logging "github.com/ipfs/go-log/v2"
ma "github.com/multiformats/go-multiaddr"
mafmt "github.com/multiformats/go-multiaddr-fmt"
manet "github.com/multiformats/go-multiaddr/net"
"github.com/quic-go/quic-go"
)
var log = logging.Logger("quic-transport")
var ErrHolePunching = errors.New("hole punching attempted; no active dial")
var HolePunchTimeout = 5 * time.Second
const errorCodeConnectionGating = 0x47415445 // GATE in ASCII
// The Transport implements the tpt.Transport interface for QUIC connections.
type transport struct {
privKey ic.PrivKey
localPeer peer.ID
identity *p2ptls.Identity
connManager *quicreuse.ConnManager
gater connmgr.ConnectionGater
rcmgr network.ResourceManager
holePunchingMx sync.Mutex
holePunching map[holePunchKey]*activeHolePunch
rndMx sync.Mutex
rnd rand.Rand
connMx sync.Mutex
conns map[quic.Connection]*conn
listenersMu sync.Mutex
// map of UDPAddr as string to a virtualListeners
listeners map[string][]*virtualListener
}
var _ tpt.Transport = &transport{}
type holePunchKey struct {
addr string
peer peer.ID
}
type activeHolePunch struct {
connCh chan tpt.CapableConn
fulfilled bool
}
// NewTransport creates a new QUIC transport
func NewTransport(key ic.PrivKey, connManager *quicreuse.ConnManager, psk pnet.PSK, gater connmgr.ConnectionGater, rcmgr network.ResourceManager) (tpt.Transport, error) {
if len(psk) > 0 {
log.Error("QUIC doesn't support private networks yet.")
return nil, errors.New("QUIC doesn't support private networks yet")
}
localPeer, err := peer.IDFromPrivateKey(key)
if err != nil {
return nil, err
}
identity, err := p2ptls.NewIdentity(key)
if err != nil {
return nil, err
}
if rcmgr == nil {
rcmgr = &network.NullResourceManager{}
}
return &transport{
privKey: key,
localPeer: localPeer,
identity: identity,
connManager: connManager,
gater: gater,
rcmgr: rcmgr,
conns: make(map[quic.Connection]*conn),
holePunching: make(map[holePunchKey]*activeHolePunch),
rnd: *rand.New(rand.NewSource(time.Now().UnixNano())),
listeners: make(map[string][]*virtualListener),
}, nil
}
// Dial dials a new QUIC connection
func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (_c tpt.CapableConn, _err error) {
if ok, isClient, _ := network.GetSimultaneousConnect(ctx); ok && !isClient {
return t.holePunch(ctx, raddr, p)
}
scope, err := t.rcmgr.OpenConnection(network.DirOutbound, false, raddr)
if err != nil {
log.Debugw("resource manager blocked outgoing connection", "peer", p, "addr", raddr, "error", err)
return nil, err
}
c, err := t.dialWithScope(ctx, raddr, p, scope)
if err != nil {
scope.Done()
return nil, err
}
return c, nil
}
func (t *transport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, p peer.ID, scope network.ConnManagementScope) (tpt.CapableConn, error) {
if err := scope.SetPeer(p); err != nil {
log.Debugw("resource manager blocked outgoing connection for peer", "peer", p, "addr", raddr, "error", err)
return nil, err
}
tlsConf, keyCh := t.identity.ConfigForPeer(p)
pconn, err := t.connManager.DialQUIC(ctx, raddr, tlsConf, t.allowWindowIncrease)
if err != nil {
return nil, err
}
// Should be ready by this point, don't block.
var remotePubKey ic.PubKey
select {
case remotePubKey = <-keyCh:
default:
}
if remotePubKey == nil {
pconn.CloseWithError(1, "")
return nil, errors.New("p2p/transport/quic BUG: expected remote pub key to be set")
}
localMultiaddr, err := quicreuse.ToQuicMultiaddr(pconn.LocalAddr(), pconn.ConnectionState().Version)
if err != nil {
pconn.CloseWithError(1, "")
return nil, err
}
c := &conn{
quicConn: pconn,
transport: t,
scope: scope,
localPeer: t.localPeer,
localMultiaddr: localMultiaddr,
remotePubKey: remotePubKey,
remotePeerID: p,
remoteMultiaddr: raddr,
}
if t.gater != nil && !t.gater.InterceptSecured(network.DirOutbound, p, c) {
pconn.CloseWithError(errorCodeConnectionGating, "connection gated")
return nil, fmt.Errorf("secured connection gated")
}
t.addConn(pconn, c)
return c, nil
}
func (t *transport) addConn(conn quic.Connection, c *conn) {
t.connMx.Lock()
t.conns[conn] = c
t.connMx.Unlock()
}
func (t *transport) removeConn(conn quic.Connection) {
t.connMx.Lock()
delete(t.conns, conn)
t.connMx.Unlock()
}
func (t *transport) holePunch(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tpt.CapableConn, error) {
network, saddr, err := manet.DialArgs(raddr)
if err != nil {
return nil, err
}
addr, err := net.ResolveUDPAddr(network, saddr)
if err != nil {
return nil, err
}
tr, err := t.connManager.TransportForDial(network, addr)
if err != nil {
return nil, err
}
defer tr.DecreaseCount()
ctx, cancel := context.WithTimeout(ctx, HolePunchTimeout)
defer cancel()
key := holePunchKey{addr: addr.String(), peer: p}
t.holePunchingMx.Lock()
if _, ok := t.holePunching[key]; ok {
t.holePunchingMx.Unlock()
return nil, fmt.Errorf("already punching hole for %s", addr)
}
connCh := make(chan tpt.CapableConn, 1)
t.holePunching[key] = &activeHolePunch{connCh: connCh}
t.holePunchingMx.Unlock()
var timer *time.Timer
defer func() {
if timer != nil {
timer.Stop()
}
}()
payload := make([]byte, 64)
var punchErr error
loop:
for i := 0; ; i++ {
t.rndMx.Lock()
_, err := t.rnd.Read(payload)
t.rndMx.Unlock()
if err != nil {
punchErr = err
break
}
if _, err := tr.WriteTo(payload, addr); err != nil {
punchErr = err
break
}
maxSleep := 10 * (i + 1) * (i + 1) // in ms
if maxSleep > 200 {
maxSleep = 200
}
d := 10*time.Millisecond + time.Duration(rand.Intn(maxSleep))*time.Millisecond
if timer == nil {
timer = time.NewTimer(d)
} else {
timer.Reset(d)
}
select {
case c := <-connCh:
t.holePunchingMx.Lock()
delete(t.holePunching, key)
t.holePunchingMx.Unlock()
return c, nil
case <-timer.C:
case <-ctx.Done():
punchErr = ErrHolePunching
break loop
}
}
// we only arrive here if punchErr != nil
t.holePunchingMx.Lock()
defer func() {
delete(t.holePunching, key)
t.holePunchingMx.Unlock()
}()
select {
case c := <-t.holePunching[key].connCh:
return c, nil
default:
return nil, punchErr
}
}
// Don't use mafmt.QUIC as we don't want to dial DNS addresses. Just /ip{4,6}/udp/quic-v1
var dialMatcher = mafmt.And(mafmt.IP, mafmt.Base(ma.P_UDP), mafmt.Base(ma.P_QUIC_V1))
// CanDial determines if we can dial to an address
func (t *transport) CanDial(addr ma.Multiaddr) bool {
return dialMatcher.Matches(addr)
}
// Listen listens for new QUIC connections on the passed multiaddr.
func (t *transport) Listen(addr ma.Multiaddr) (tpt.Listener, error) {
var tlsConf tls.Config
tlsConf.GetConfigForClient = func(_ *tls.ClientHelloInfo) (*tls.Config, error) {
// return a tls.Config that verifies the peer's certificate chain.
// Note that since we have no way of associating an incoming QUIC connection with
// the peer ID calculated here, we don't actually receive the peer's public key
// from the key chan.
conf, _ := t.identity.ConfigForPeer("")
return conf, nil
}
tlsConf.NextProtos = []string{"libp2p"}
udpAddr, version, err := quicreuse.FromQuicMultiaddr(addr)
if err != nil {
return nil, err
}
t.listenersMu.Lock()
defer t.listenersMu.Unlock()
listeners := t.listeners[udpAddr.String()]
var underlyingListener *listener
var acceptRunner *acceptLoopRunner
if len(listeners) != 0 {
// We already have an underlying listener, let's use it
underlyingListener = listeners[0].listener
acceptRunner = listeners[0].acceptRunnner
// Make sure our underlying listener is listening on the specified QUIC version
if _, ok := underlyingListener.localMultiaddrs[version]; !ok {
return nil, fmt.Errorf("can't listen on quic version %v, underlying listener doesn't support it", version)
}
} else {
ln, err := t.connManager.ListenQUIC(addr, &tlsConf, t.allowWindowIncrease)
if err != nil {
return nil, err
}
l, err := newListener(ln, t, t.localPeer, t.privKey, t.rcmgr)
if err != nil {
_ = ln.Close()
return nil, err
}
underlyingListener = &l
acceptRunner = &acceptLoopRunner{
acceptSem: make(chan struct{}, 1),
muxer: make(map[quic.VersionNumber]chan acceptVal),
}
}
l := &virtualListener{
listener: underlyingListener,
version: version,
udpAddr: udpAddr.String(),
t: t,
acceptRunnner: acceptRunner,
acceptChan: acceptRunner.AcceptForVersion(version),
}
listeners = append(listeners, l)
t.listeners[udpAddr.String()] = listeners
return l, nil
}
func (t *transport) allowWindowIncrease(conn quic.Connection, size uint64) bool {
// If the QUIC connection tries to increase the window before we've inserted it
// into our connections map (which we do right after dialing / accepting it),
// we have no way to account for that memory. This should be very rare.
// Block this attempt. The connection can request more memory later.
t.connMx.Lock()
c, ok := t.conns[conn]
t.connMx.Unlock()
if !ok {
return false
}
return c.allowWindowIncrease(size)
}
// Proxy returns true if this transport proxies.
func (t *transport) Proxy() bool {
return false
}
// Protocols returns the set of protocols handled by this transport.
func (t *transport) Protocols() []int {
return t.connManager.Protocols()
}
func (t *transport) String() string {
return "QUIC"
}
func (t *transport) Close() error {
return nil
}
func (t *transport) CloseVirtualListener(l *virtualListener) error {
t.listenersMu.Lock()
defer t.listenersMu.Unlock()
var err error
listeners := t.listeners[l.udpAddr]
if len(listeners) == 1 {
// This is the last virtual listener here, so we can close the underlying listener
err = l.listener.Close()
delete(t.listeners, l.udpAddr)
return err
}
for i := 0; i < len(listeners); i++ {
// Swap remove
if l == listeners[i] {
listeners[i] = listeners[len(listeners)-1]
listeners = listeners[:len(listeners)-1]
t.listeners[l.udpAddr] = listeners
break
}
}
return nil
}

View File

@@ -0,0 +1,175 @@
package libp2pquic
import (
"sync"
tpt "github.com/libp2p/go-libp2p/core/transport"
"github.com/libp2p/go-libp2p/p2p/transport/quicreuse"
ma "github.com/multiformats/go-multiaddr"
"github.com/quic-go/quic-go"
)
const acceptBufferPerVersion = 4
// virtualListener is a listener that exposes a single multiaddr but uses another listener under the hood
type virtualListener struct {
*listener
udpAddr string
version quic.VersionNumber
t *transport
acceptRunnner *acceptLoopRunner
acceptChan chan acceptVal
}
var _ tpt.Listener = &virtualListener{}
func (l *virtualListener) Multiaddr() ma.Multiaddr {
return l.listener.localMultiaddrs[l.version]
}
func (l *virtualListener) Close() error {
l.acceptRunnner.RmAcceptForVersion(l.version, tpt.ErrListenerClosed)
return l.t.CloseVirtualListener(l)
}
func (l *virtualListener) Accept() (tpt.CapableConn, error) {
return l.acceptRunnner.Accept(l.listener, l.version, l.acceptChan)
}
type acceptVal struct {
conn tpt.CapableConn
err error
}
type acceptLoopRunner struct {
acceptSem chan struct{}
muxerMu sync.Mutex
muxer map[quic.VersionNumber]chan acceptVal
muxerClosed bool
}
func (r *acceptLoopRunner) AcceptForVersion(v quic.VersionNumber) chan acceptVal {
r.muxerMu.Lock()
defer r.muxerMu.Unlock()
ch := make(chan acceptVal, acceptBufferPerVersion)
if _, ok := r.muxer[v]; ok {
panic("unexpected chan already found in accept muxer")
}
r.muxer[v] = ch
return ch
}
func (r *acceptLoopRunner) RmAcceptForVersion(v quic.VersionNumber, err error) {
r.muxerMu.Lock()
defer r.muxerMu.Unlock()
if r.muxerClosed {
// Already closed, all versions are removed
return
}
ch, ok := r.muxer[v]
if !ok {
panic("expected chan in accept muxer")
}
ch <- acceptVal{err: err}
delete(r.muxer, v)
}
func (r *acceptLoopRunner) sendErrAndClose(err error) {
r.muxerMu.Lock()
defer r.muxerMu.Unlock()
r.muxerClosed = true
for k, ch := range r.muxer {
select {
case ch <- acceptVal{err: err}:
default:
}
delete(r.muxer, k)
close(ch)
}
}
// innerAccept is the inner logic of the Accept loop. Assume caller holds the
// acceptSemaphore. May return both a nil conn and nil error if it didn't find a
// conn with the expected version
func (r *acceptLoopRunner) innerAccept(l *listener, expectedVersion quic.VersionNumber, bufferedConnChan chan acceptVal) (tpt.CapableConn, error) {
select {
// Check if we have a buffered connection first from an earlier Accept call
case v, ok := <-bufferedConnChan:
if !ok {
return nil, tpt.ErrListenerClosed
}
return v.conn, v.err
default:
}
conn, err := l.Accept()
if err != nil {
r.sendErrAndClose(err)
return nil, err
}
_, version, err := quicreuse.FromQuicMultiaddr(conn.RemoteMultiaddr())
if err != nil {
r.sendErrAndClose(err)
return nil, err
}
if version == expectedVersion {
return conn, nil
}
// This wasn't the version we were expecting, lets queue it up for a
// future Accept call with a different version
r.muxerMu.Lock()
ch, ok := r.muxer[version]
r.muxerMu.Unlock()
if !ok {
// Nothing to handle this connection version. Close it
conn.Close()
return nil, nil
}
// Non blocking
select {
case ch <- acceptVal{conn: conn}:
default:
// accept queue filled up, drop the connection
conn.Close()
log.Warn("Accept queue filled. Dropping connection.")
}
return nil, nil
}
func (r *acceptLoopRunner) Accept(l *listener, expectedVersion quic.VersionNumber, bufferedConnChan chan acceptVal) (tpt.CapableConn, error) {
for {
var conn tpt.CapableConn
var err error
select {
case r.acceptSem <- struct{}{}:
conn, err = r.innerAccept(l, expectedVersion, bufferedConnChan)
<-r.acceptSem
if conn == nil && err == nil {
// Didn't find a conn for the expected version and there was no error, lets try again
continue
}
case v, ok := <-bufferedConnChan:
if !ok {
return nil, tpt.ErrListenerClosed
}
conn = v.conn
err = v.err
}
return conn, err
}
}

View File

@@ -0,0 +1,23 @@
package quicreuse
import (
"net"
"time"
"github.com/quic-go/quic-go"
)
var quicConfig = &quic.Config{
MaxIncomingStreams: 256,
MaxIncomingUniStreams: 5, // allow some unidirectional streams, in case we speak WebTransport
MaxStreamReceiveWindow: 10 * (1 << 20), // 10 MB
MaxConnectionReceiveWindow: 15 * (1 << 20), // 15 MB
RequireAddressValidation: func(net.Addr) bool {
// TODO(#1535): require source address validation when under load
return false
},
KeepAlivePeriod: 15 * time.Second,
Versions: []quic.VersionNumber{quic.Version1},
// We don't use datagrams (yet), but this is necessary for WebTransport
EnableDatagrams: true,
}

View File

@@ -0,0 +1,224 @@
package quicreuse
import (
"context"
"crypto/tls"
"errors"
"net"
"sync"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
"github.com/quic-go/quic-go"
quiclogging "github.com/quic-go/quic-go/logging"
)
type ConnManager struct {
reuseUDP4 *reuse
reuseUDP6 *reuse
enableReuseport bool
enableMetrics bool
serverConfig *quic.Config
clientConfig *quic.Config
quicListenersMu sync.Mutex
quicListeners map[string]quicListenerEntry
srk quic.StatelessResetKey
tokenKey quic.TokenGeneratorKey
}
type quicListenerEntry struct {
refCount int
ln *quicListener
}
func NewConnManager(statelessResetKey quic.StatelessResetKey, tokenKey quic.TokenGeneratorKey, opts ...Option) (*ConnManager, error) {
cm := &ConnManager{
enableReuseport: true,
quicListeners: make(map[string]quicListenerEntry),
srk: statelessResetKey,
tokenKey: tokenKey,
}
for _, o := range opts {
if err := o(cm); err != nil {
return nil, err
}
}
quicConf := quicConfig.Clone()
quicConf.Tracer = func(ctx context.Context, p quiclogging.Perspective, ci quic.ConnectionID) *quiclogging.ConnectionTracer {
var tracer *quiclogging.ConnectionTracer
if qlogTracerDir != "" {
tracer = qloggerForDir(qlogTracerDir, p, ci)
}
return tracer
}
serverConfig := quicConf.Clone()
cm.clientConfig = quicConf
cm.serverConfig = serverConfig
if cm.enableReuseport {
cm.reuseUDP4 = newReuse(&statelessResetKey, &tokenKey)
cm.reuseUDP6 = newReuse(&statelessResetKey, &tokenKey)
}
return cm, nil
}
func (c *ConnManager) getReuse(network string) (*reuse, error) {
switch network {
case "udp4":
return c.reuseUDP4, nil
case "udp6":
return c.reuseUDP6, nil
default:
return nil, errors.New("invalid network: must be either udp4 or udp6")
}
}
func (c *ConnManager) ListenQUIC(addr ma.Multiaddr, tlsConf *tls.Config, allowWindowIncrease func(conn quic.Connection, delta uint64) bool) (Listener, error) {
netw, host, err := manet.DialArgs(addr)
if err != nil {
return nil, err
}
laddr, err := net.ResolveUDPAddr(netw, host)
if err != nil {
return nil, err
}
c.quicListenersMu.Lock()
defer c.quicListenersMu.Unlock()
key := laddr.String()
entry, ok := c.quicListeners[key]
if !ok {
tr, err := c.transportForListen(netw, laddr)
if err != nil {
return nil, err
}
ln, err := newQuicListener(tr, c.serverConfig)
if err != nil {
return nil, err
}
key = tr.LocalAddr().String()
entry = quicListenerEntry{ln: ln}
}
l, err := entry.ln.Add(tlsConf, allowWindowIncrease, func() { c.onListenerClosed(key) })
if err != nil {
if entry.refCount <= 0 {
entry.ln.Close()
}
return nil, err
}
entry.refCount++
c.quicListeners[key] = entry
return l, nil
}
func (c *ConnManager) onListenerClosed(key string) {
c.quicListenersMu.Lock()
defer c.quicListenersMu.Unlock()
entry := c.quicListeners[key]
entry.refCount = entry.refCount - 1
if entry.refCount <= 0 {
delete(c.quicListeners, key)
entry.ln.Close()
} else {
c.quicListeners[key] = entry
}
}
func (c *ConnManager) transportForListen(network string, laddr *net.UDPAddr) (refCountedQuicTransport, error) {
if c.enableReuseport {
reuse, err := c.getReuse(network)
if err != nil {
return nil, err
}
return reuse.TransportForListen(network, laddr)
}
conn, err := net.ListenUDP(network, laddr)
if err != nil {
return nil, err
}
return &singleOwnerTransport{
packetConn: conn,
Transport: quic.Transport{
Conn: conn,
StatelessResetKey: &c.srk,
TokenGeneratorKey: &c.tokenKey,
},
}, nil
}
func (c *ConnManager) DialQUIC(ctx context.Context, raddr ma.Multiaddr, tlsConf *tls.Config, allowWindowIncrease func(conn quic.Connection, delta uint64) bool) (quic.Connection, error) {
naddr, v, err := FromQuicMultiaddr(raddr)
if err != nil {
return nil, err
}
netw, _, err := manet.DialArgs(raddr)
if err != nil {
return nil, err
}
quicConf := c.clientConfig.Clone()
quicConf.AllowConnectionWindowIncrease = allowWindowIncrease
if v == quic.Version1 {
// The endpoint has explicit support for QUIC v1, so we'll only use that version.
quicConf.Versions = []quic.VersionNumber{quic.Version1}
} else {
return nil, errors.New("unknown QUIC version")
}
tr, err := c.TransportForDial(netw, naddr)
if err != nil {
return nil, err
}
conn, err := tr.Dial(ctx, naddr, tlsConf, quicConf)
if err != nil {
tr.DecreaseCount()
return nil, err
}
return conn, nil
}
func (c *ConnManager) TransportForDial(network string, raddr *net.UDPAddr) (refCountedQuicTransport, error) {
if c.enableReuseport {
reuse, err := c.getReuse(network)
if err != nil {
return nil, err
}
return reuse.TransportForDial(network, raddr)
}
var laddr *net.UDPAddr
switch network {
case "udp4":
laddr = &net.UDPAddr{IP: net.IPv4zero, Port: 0}
case "udp6":
laddr = &net.UDPAddr{IP: net.IPv6zero, Port: 0}
}
conn, err := net.ListenUDP(network, laddr)
if err != nil {
return nil, err
}
return &singleOwnerTransport{Transport: quic.Transport{Conn: conn, StatelessResetKey: &c.srk}, packetConn: conn}, nil
}
func (c *ConnManager) Protocols() []int {
return []int{ma.P_QUIC_V1}
}
func (c *ConnManager) Close() error {
if !c.enableReuseport {
return nil
}
if err := c.reuseUDP6.Close(); err != nil {
return err
}
return c.reuseUDP4.Close()
}

View File

@@ -0,0 +1,219 @@
package quicreuse
import (
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"strings"
"sync"
"github.com/libp2p/go-libp2p/core/transport"
ma "github.com/multiformats/go-multiaddr"
"github.com/quic-go/quic-go"
)
type Listener interface {
Accept(context.Context) (quic.Connection, error)
Addr() net.Addr
Multiaddrs() []ma.Multiaddr
io.Closer
}
type protoConf struct {
ln *listener
tlsConf *tls.Config
allowWindowIncrease func(conn quic.Connection, delta uint64) bool
}
type quicListener struct {
l *quic.Listener
transport refCountedQuicTransport
running chan struct{}
addrs []ma.Multiaddr
protocolsMu sync.Mutex
protocols map[string]protoConf
}
func newQuicListener(tr refCountedQuicTransport, quicConfig *quic.Config) (*quicListener, error) {
localMultiaddrs := make([]ma.Multiaddr, 0, 2)
a, err := ToQuicMultiaddr(tr.LocalAddr(), quic.Version1)
if err != nil {
return nil, err
}
localMultiaddrs = append(localMultiaddrs, a)
cl := &quicListener{
protocols: map[string]protoConf{},
running: make(chan struct{}),
transport: tr,
addrs: localMultiaddrs,
}
tlsConf := &tls.Config{
SessionTicketsDisabled: true, // This is set for the config for client, but we set it here as well: https://github.com/quic-go/quic-go/issues/4029
GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) {
cl.protocolsMu.Lock()
defer cl.protocolsMu.Unlock()
for _, proto := range info.SupportedProtos {
if entry, ok := cl.protocols[proto]; ok {
conf := entry.tlsConf
if conf.GetConfigForClient != nil {
return conf.GetConfigForClient(info)
}
return conf, nil
}
}
return nil, fmt.Errorf("no supported protocol found. offered: %+v", info.SupportedProtos)
},
}
quicConf := quicConfig.Clone()
quicConf.AllowConnectionWindowIncrease = cl.allowWindowIncrease
ln, err := tr.Listen(tlsConf, quicConf)
if err != nil {
return nil, err
}
cl.l = ln
go cl.Run() // This go routine shuts down once the underlying quic.Listener is closed (or returns an error).
return cl, nil
}
func (l *quicListener) allowWindowIncrease(conn quic.Connection, delta uint64) bool {
l.protocolsMu.Lock()
defer l.protocolsMu.Unlock()
conf, ok := l.protocols[conn.ConnectionState().TLS.NegotiatedProtocol]
if !ok {
return false
}
return conf.allowWindowIncrease(conn, delta)
}
func (l *quicListener) Add(tlsConf *tls.Config, allowWindowIncrease func(conn quic.Connection, delta uint64) bool, onRemove func()) (Listener, error) {
l.protocolsMu.Lock()
defer l.protocolsMu.Unlock()
if len(tlsConf.NextProtos) == 0 {
return nil, errors.New("no ALPN found in tls.Config")
}
for _, proto := range tlsConf.NextProtos {
if _, ok := l.protocols[proto]; ok {
return nil, fmt.Errorf("already listening for protocol %s", proto)
}
}
ln := newSingleListener(l.l.Addr(), l.addrs, func() {
l.protocolsMu.Lock()
for _, proto := range tlsConf.NextProtos {
delete(l.protocols, proto)
}
l.protocolsMu.Unlock()
onRemove()
}, l.running)
for _, proto := range tlsConf.NextProtos {
l.protocols[proto] = protoConf{
ln: ln,
tlsConf: tlsConf,
allowWindowIncrease: allowWindowIncrease,
}
}
return ln, nil
}
func (l *quicListener) Run() error {
defer close(l.running)
defer l.transport.DecreaseCount()
for {
conn, err := l.l.Accept(context.Background())
if err != nil {
if errors.Is(err, quic.ErrServerClosed) || strings.Contains(err.Error(), "use of closed network connection") {
return transport.ErrListenerClosed
}
return err
}
proto := conn.ConnectionState().TLS.NegotiatedProtocol
l.protocolsMu.Lock()
ln, ok := l.protocols[proto]
if !ok {
l.protocolsMu.Unlock()
return fmt.Errorf("negotiated unknown protocol: %s", proto)
}
ln.ln.add(conn)
l.protocolsMu.Unlock()
}
}
func (l *quicListener) Close() error {
err := l.l.Close()
<-l.running // wait for Run to return
return err
}
const queueLen = 16
// A listener for a single ALPN protocol (set).
type listener struct {
queue chan quic.Connection
acceptLoopRunning chan struct{}
addr net.Addr
addrs []ma.Multiaddr
remove func()
closeOnce sync.Once
}
var _ Listener = &listener{}
func newSingleListener(addr net.Addr, addrs []ma.Multiaddr, remove func(), running chan struct{}) *listener {
return &listener{
queue: make(chan quic.Connection, queueLen),
acceptLoopRunning: running,
remove: remove,
addr: addr,
addrs: addrs,
}
}
func (l *listener) add(c quic.Connection) {
select {
case l.queue <- c:
default:
c.CloseWithError(1, "queue full")
}
}
func (l *listener) Accept(ctx context.Context) (quic.Connection, error) {
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-l.acceptLoopRunning:
return nil, transport.ErrListenerClosed
case c, ok := <-l.queue:
if !ok {
return nil, transport.ErrListenerClosed
}
return c, nil
}
}
func (l *listener) Addr() net.Addr {
return l.addr
}
func (l *listener) Multiaddrs() []ma.Multiaddr {
return l.addrs
}
func (l *listener) Close() error {
l.closeOnce.Do(func() {
l.remove()
close(l.queue)
// drain the queue
for conn := range l.queue {
conn.CloseWithError(1, "closing")
}
})
return nil
}

View File

@@ -0,0 +1,18 @@
package quicreuse
type Option func(*ConnManager) error
func DisableReuseport() Option {
return func(m *ConnManager) error {
m.enableReuseport = false
return nil
}
}
// EnableMetrics enables Prometheus metrics collection.
func EnableMetrics() Option {
return func(m *ConnManager) error {
m.enableMetrics = true
return nil
}
}

View File

@@ -0,0 +1,58 @@
package quicreuse
import (
"errors"
"net"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
"github.com/quic-go/quic-go"
)
var (
quicV1MA = ma.StringCast("/quic-v1")
)
func ToQuicMultiaddr(na net.Addr, version quic.VersionNumber) (ma.Multiaddr, error) {
udpMA, err := manet.FromNetAddr(na)
if err != nil {
return nil, err
}
switch version {
case quic.Version1:
return udpMA.Encapsulate(quicV1MA), nil
default:
return nil, errors.New("unknown QUIC version")
}
}
func FromQuicMultiaddr(addr ma.Multiaddr) (*net.UDPAddr, quic.VersionNumber, error) {
var version quic.VersionNumber
var partsBeforeQUIC []ma.Multiaddr
ma.ForEach(addr, func(c ma.Component) bool {
switch c.Protocol().Code {
case ma.P_QUIC_V1:
version = quic.Version1
return false
default:
partsBeforeQUIC = append(partsBeforeQUIC, &c)
return true
}
})
if len(partsBeforeQUIC) == 0 {
return nil, version, errors.New("no addr before QUIC component")
}
if version == 0 {
// Not found
return nil, version, errors.New("unknown QUIC version")
}
netAddr, err := manet.ToNetAddr(ma.Join(partsBeforeQUIC...))
if err != nil {
return nil, version, err
}
udpAddr, ok := netAddr.(*net.UDPAddr)
if !ok {
return nil, 0, errors.New("not a *net.UDPAddr")
}
return udpAddr, version, nil
}

View File

@@ -0,0 +1,353 @@
package quicreuse
import (
"context"
"crypto/tls"
"net"
"sync"
"time"
"github.com/google/gopacket/routing"
"github.com/libp2p/go-netroute"
"github.com/quic-go/quic-go"
)
type refCountedQuicTransport interface {
LocalAddr() net.Addr
// Used to send packets directly around QUIC. Useful for hole punching.
WriteTo([]byte, net.Addr) (int, error)
Close() error
// count transport reference
DecreaseCount()
IncreaseCount()
Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *quic.Config) (quic.Connection, error)
Listen(tlsConf *tls.Config, conf *quic.Config) (*quic.Listener, error)
}
type singleOwnerTransport struct {
quic.Transport
// Used to write packets directly around QUIC.
packetConn net.PacketConn
}
func (c *singleOwnerTransport) IncreaseCount() {}
func (c *singleOwnerTransport) DecreaseCount() {
c.Transport.Close()
}
func (c *singleOwnerTransport) LocalAddr() net.Addr {
return c.Transport.Conn.LocalAddr()
}
func (c *singleOwnerTransport) Close() error {
// TODO(when we drop support for go 1.19) use errors.Join
c.Transport.Close()
return c.packetConn.Close()
}
func (c *singleOwnerTransport) WriteTo(b []byte, addr net.Addr) (int, error) {
return c.Transport.WriteTo(b, addr)
}
// Constant. Defined as variables to simplify testing.
var (
garbageCollectInterval = 30 * time.Second
maxUnusedDuration = 10 * time.Second
)
type refcountedTransport struct {
quic.Transport
// Used to write packets directly around QUIC.
packetConn net.PacketConn
mutex sync.Mutex
refCount int
unusedSince time.Time
}
func (c *refcountedTransport) IncreaseCount() {
c.mutex.Lock()
c.refCount++
c.unusedSince = time.Time{}
c.mutex.Unlock()
}
func (c *refcountedTransport) Close() error {
// TODO(when we drop support for go 1.19) use errors.Join
c.Transport.Close()
return c.packetConn.Close()
}
func (c *refcountedTransport) WriteTo(b []byte, addr net.Addr) (int, error) {
return c.Transport.WriteTo(b, addr)
}
func (c *refcountedTransport) LocalAddr() net.Addr {
return c.Transport.Conn.LocalAddr()
}
func (c *refcountedTransport) DecreaseCount() {
c.mutex.Lock()
c.refCount--
if c.refCount == 0 {
c.unusedSince = time.Now()
}
c.mutex.Unlock()
}
func (c *refcountedTransport) ShouldGarbageCollect(now time.Time) bool {
c.mutex.Lock()
defer c.mutex.Unlock()
return !c.unusedSince.IsZero() && c.unusedSince.Add(maxUnusedDuration).Before(now)
}
type reuse struct {
mutex sync.Mutex
closeChan chan struct{}
gcStopChan chan struct{}
routes routing.Router
unicast map[string] /* IP.String() */ map[int] /* port */ *refcountedTransport
// globalListeners contains transports that are listening on 0.0.0.0 / ::
globalListeners map[int]*refcountedTransport
// globalDialers contains transports that we've dialed out from. These transports are listening on 0.0.0.0 / ::
// On Dial, transports are reused from this map if no transport is available in the globalListeners
// On Listen, transports are reused from this map if the requested port is 0, and then moved to globalListeners
globalDialers map[int]*refcountedTransport
statelessResetKey *quic.StatelessResetKey
tokenGeneratorKey *quic.TokenGeneratorKey
}
func newReuse(srk *quic.StatelessResetKey, tokenKey *quic.TokenGeneratorKey) *reuse {
r := &reuse{
unicast: make(map[string]map[int]*refcountedTransport),
globalListeners: make(map[int]*refcountedTransport),
globalDialers: make(map[int]*refcountedTransport),
closeChan: make(chan struct{}),
gcStopChan: make(chan struct{}),
statelessResetKey: srk,
tokenGeneratorKey: tokenKey,
}
go r.gc()
return r
}
func (r *reuse) gc() {
defer func() {
r.mutex.Lock()
for _, tr := range r.globalListeners {
tr.Close()
}
for _, tr := range r.globalDialers {
tr.Close()
}
for _, trs := range r.unicast {
for _, tr := range trs {
tr.Close()
}
}
r.mutex.Unlock()
close(r.gcStopChan)
}()
ticker := time.NewTicker(garbageCollectInterval)
defer ticker.Stop()
for {
select {
case <-r.closeChan:
return
case <-ticker.C:
now := time.Now()
r.mutex.Lock()
for key, tr := range r.globalListeners {
if tr.ShouldGarbageCollect(now) {
tr.Close()
delete(r.globalListeners, key)
}
}
for key, tr := range r.globalDialers {
if tr.ShouldGarbageCollect(now) {
tr.Close()
delete(r.globalDialers, key)
}
}
for ukey, trs := range r.unicast {
for key, tr := range trs {
if tr.ShouldGarbageCollect(now) {
tr.Close()
delete(trs, key)
}
}
if len(trs) == 0 {
delete(r.unicast, ukey)
// If we've dropped all transports with a unicast binding,
// assume our routes may have changed.
if len(r.unicast) == 0 {
r.routes = nil
} else {
// Ignore the error, there's nothing we can do about
// it.
r.routes, _ = netroute.New()
}
}
}
r.mutex.Unlock()
}
}
}
func (r *reuse) TransportForDial(network string, raddr *net.UDPAddr) (*refcountedTransport, error) {
var ip *net.IP
// Only bother looking up the source address if we actually _have_ non 0.0.0.0 listeners.
// Otherwise, save some time.
r.mutex.Lock()
router := r.routes
r.mutex.Unlock()
if router != nil {
_, _, src, err := router.Route(raddr.IP)
if err == nil && !src.IsUnspecified() {
ip = &src
}
}
r.mutex.Lock()
defer r.mutex.Unlock()
tr, err := r.transportForDialLocked(network, ip)
if err != nil {
return nil, err
}
tr.IncreaseCount()
return tr, nil
}
func (r *reuse) transportForDialLocked(network string, source *net.IP) (*refcountedTransport, error) {
if source != nil {
// We already have at least one suitable transport...
if trs, ok := r.unicast[source.String()]; ok {
// ... we don't care which port we're dialing from. Just use the first.
for _, tr := range trs {
return tr, nil
}
}
}
// Use a transport listening on 0.0.0.0 (or ::).
// Again, we don't care about the port number.
for _, tr := range r.globalListeners {
return tr, nil
}
// Use a transport we've previously dialed from
for _, tr := range r.globalDialers {
return tr, nil
}
// We don't have a transport that we can use for dialing.
// Dial a new connection from a random port.
var addr *net.UDPAddr
switch network {
case "udp4":
addr = &net.UDPAddr{IP: net.IPv4zero, Port: 0}
case "udp6":
addr = &net.UDPAddr{IP: net.IPv6zero, Port: 0}
}
conn, err := net.ListenUDP(network, addr)
if err != nil {
return nil, err
}
tr := &refcountedTransport{Transport: quic.Transport{
Conn: conn,
StatelessResetKey: r.statelessResetKey,
TokenGeneratorKey: r.tokenGeneratorKey,
}, packetConn: conn}
r.globalDialers[conn.LocalAddr().(*net.UDPAddr).Port] = tr
return tr, nil
}
func (r *reuse) TransportForListen(network string, laddr *net.UDPAddr) (*refcountedTransport, error) {
r.mutex.Lock()
defer r.mutex.Unlock()
// Check if we can reuse a transport we have already dialed out from.
// We reuse a transport from globalDialers when the requested port is 0 or the requested
// port is already in the globalDialers.
// If we are reusing a transport from globalDialers, we move the globalDialers entry to
// globalListeners
if laddr.IP.IsUnspecified() {
var rTr *refcountedTransport
var localAddr *net.UDPAddr
if laddr.Port == 0 {
// the requested port is 0, we can reuse any transport
for _, tr := range r.globalDialers {
rTr = tr
localAddr = rTr.LocalAddr().(*net.UDPAddr)
delete(r.globalDialers, localAddr.Port)
break
}
} else if _, ok := r.globalDialers[laddr.Port]; ok {
rTr = r.globalDialers[laddr.Port]
localAddr = rTr.LocalAddr().(*net.UDPAddr)
delete(r.globalDialers, localAddr.Port)
}
// found a match
if rTr != nil {
rTr.IncreaseCount()
r.globalListeners[localAddr.Port] = rTr
return rTr, nil
}
}
conn, err := net.ListenUDP(network, laddr)
if err != nil {
return nil, err
}
localAddr := conn.LocalAddr().(*net.UDPAddr)
tr := &refcountedTransport{
Transport: quic.Transport{
Conn: conn,
StatelessResetKey: r.statelessResetKey,
},
packetConn: conn,
}
tr.IncreaseCount()
// Deal with listen on a global address
if localAddr.IP.IsUnspecified() {
// The kernel already checked that the laddr is not already listen
// so we need not check here (when we create ListenUDP).
r.globalListeners[localAddr.Port] = tr
return tr, nil
}
// Deal with listen on a unicast address
if _, ok := r.unicast[localAddr.IP.String()]; !ok {
r.unicast[localAddr.IP.String()] = make(map[int]*refcountedTransport)
// Assume the system's routes may have changed if we're adding a new listener.
// Ignore the error, there's nothing we can do.
r.routes, _ = netroute.New()
}
// The kernel already checked that the laddr is not already listen
// so we need not check here (when we create ListenUDP).
r.unicast[localAddr.IP.String()][localAddr.Port] = tr
return tr, nil
}
func (r *reuse) Close() error {
close(r.closeChan)
<-r.gcStopChan
return nil
}

View File

@@ -0,0 +1,94 @@
package quicreuse
import (
"bufio"
"fmt"
"io"
"os"
"time"
golog "github.com/ipfs/go-log/v2"
"github.com/klauspost/compress/zstd"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/logging"
"github.com/quic-go/quic-go/qlog"
)
var log = golog.Logger("quic-utils")
// QLOGTracer holds a qlog tracer dir, if qlogging is enabled (enabled using the QLOGDIR environment variable).
// Otherwise it is an empty string.
var qlogTracerDir string
func init() {
qlogTracerDir = os.Getenv("QLOGDIR")
}
func qloggerForDir(qlogDir string, p logging.Perspective, ci quic.ConnectionID) *logging.ConnectionTracer {
// create the QLOGDIR, if it doesn't exist
if err := os.MkdirAll(qlogDir, 0777); err != nil {
log.Errorf("creating the QLOGDIR failed: %s", err)
return nil
}
return qlog.NewConnectionTracer(newQlogger(qlogDir, p, ci), p, ci)
}
// The qlogger logs qlog events to a temporary file: .<name>.qlog.swp.
// When it is closed, it compresses the temporary file and saves it as <name>.qlog.zst.
// It is not possible to compress on the fly, as compression algorithms keep a lot of internal state,
// which can easily exhaust the host system's memory when running a few hundred QUIC connections in parallel.
type qlogger struct {
f *os.File // QLOGDIR/.log_xxx.qlog.swp
filename string // QLOGDIR/log_xxx.qlog.zst
*bufio.Writer // buffering the f
}
func newQlogger(qlogDir string, role logging.Perspective, connID quic.ConnectionID) io.WriteCloser {
t := time.Now().UTC().Format("2006-01-02T15-04-05.999999999UTC")
r := "server"
if role == logging.PerspectiveClient {
r = "client"
}
finalFilename := fmt.Sprintf("%s%clog_%s_%s_%s.qlog.zst", qlogDir, os.PathSeparator, t, r, connID)
filename := fmt.Sprintf("%s%c.log_%s_%s_%s.qlog.swp", qlogDir, os.PathSeparator, t, r, connID)
f, err := os.Create(filename)
if err != nil {
log.Errorf("unable to create qlog file %s: %s", filename, err)
return nil
}
return &qlogger{
f: f,
filename: finalFilename,
// The size of a qlog file for a raw file download is ~2/3 of the amount of data transferred.
// bufio.NewWriter creates a buffer with a buffer of only 4 kB, leading to a large number of syscalls.
Writer: bufio.NewWriterSize(f, 128<<10),
}
}
func (l *qlogger) Close() error {
defer os.Remove(l.f.Name())
defer l.f.Close()
if err := l.Writer.Flush(); err != nil {
return err
}
if _, err := l.f.Seek(0, io.SeekStart); err != nil { // set the read position to the beginning of the file
return err
}
f, err := os.Create(l.filename)
if err != nil {
return err
}
defer f.Close()
buf := bufio.NewWriterSize(f, 128<<10)
c, err := zstd.NewWriter(buf, zstd.WithEncoderLevel(zstd.SpeedFastest), zstd.WithWindowSize(32*1024))
if err != nil {
return err
}
if _, err := io.Copy(c, l.f); err != nil {
return err
}
if err := c.Close(); err != nil {
return err
}
return buf.Flush()
}

View File

@@ -0,0 +1,268 @@
//go:build !windows && !riscv64
package tcp
import (
"strings"
"sync"
"time"
"github.com/marten-seemann/tcp"
"github.com/mikioh/tcpinfo"
manet "github.com/multiformats/go-multiaddr/net"
"github.com/prometheus/client_golang/prometheus"
)
var (
newConns *prometheus.CounterVec
closedConns *prometheus.CounterVec
segsSentDesc *prometheus.Desc
segsRcvdDesc *prometheus.Desc
bytesSentDesc *prometheus.Desc
bytesRcvdDesc *prometheus.Desc
)
const collectFrequency = 10 * time.Second
var collector *aggregatingCollector
var initMetricsOnce sync.Once
func initMetrics() {
segsSentDesc = prometheus.NewDesc("tcp_sent_segments_total", "TCP segments sent", nil, nil)
segsRcvdDesc = prometheus.NewDesc("tcp_rcvd_segments_total", "TCP segments received", nil, nil)
bytesSentDesc = prometheus.NewDesc("tcp_sent_bytes", "TCP bytes sent", nil, nil)
bytesRcvdDesc = prometheus.NewDesc("tcp_rcvd_bytes", "TCP bytes received", nil, nil)
collector = newAggregatingCollector()
prometheus.MustRegister(collector)
const direction = "direction"
newConns = prometheus.NewCounterVec(
prometheus.CounterOpts{
Name: "tcp_connections_new_total",
Help: "TCP new connections",
},
[]string{direction},
)
prometheus.MustRegister(newConns)
closedConns = prometheus.NewCounterVec(
prometheus.CounterOpts{
Name: "tcp_connections_closed_total",
Help: "TCP connections closed",
},
[]string{direction},
)
prometheus.MustRegister(closedConns)
}
type aggregatingCollector struct {
cronOnce sync.Once
mutex sync.Mutex
highestID uint64
conns map[uint64] /* id */ *tracingConn
rtts prometheus.Histogram
connDurations prometheus.Histogram
segsSent, segsRcvd uint64
bytesSent, bytesRcvd uint64
}
var _ prometheus.Collector = &aggregatingCollector{}
func newAggregatingCollector() *aggregatingCollector {
c := &aggregatingCollector{
conns: make(map[uint64]*tracingConn),
rtts: prometheus.NewHistogram(prometheus.HistogramOpts{
Name: "tcp_rtt",
Help: "TCP round trip time",
Buckets: prometheus.ExponentialBuckets(0.001, 1.25, 40), // 1ms to ~6000ms
}),
connDurations: prometheus.NewHistogram(prometheus.HistogramOpts{
Name: "tcp_connection_duration",
Help: "TCP Connection Duration",
Buckets: prometheus.ExponentialBuckets(1, 1.5, 40), // 1s to ~12 weeks
}),
}
return c
}
func (c *aggregatingCollector) AddConn(t *tracingConn) uint64 {
c.mutex.Lock()
defer c.mutex.Unlock()
c.highestID++
c.conns[c.highestID] = t
return c.highestID
}
func (c *aggregatingCollector) removeConn(id uint64) {
delete(c.conns, id)
}
func (c *aggregatingCollector) Describe(descs chan<- *prometheus.Desc) {
descs <- c.rtts.Desc()
descs <- c.connDurations.Desc()
if hasSegmentCounter {
descs <- segsSentDesc
descs <- segsRcvdDesc
}
if hasByteCounter {
descs <- bytesSentDesc
descs <- bytesRcvdDesc
}
}
func (c *aggregatingCollector) cron() {
ticker := time.NewTicker(collectFrequency)
defer ticker.Stop()
for now := range ticker.C {
c.gatherMetrics(now)
}
}
func (c *aggregatingCollector) gatherMetrics(now time.Time) {
c.mutex.Lock()
defer c.mutex.Unlock()
c.segsSent = 0
c.segsRcvd = 0
c.bytesSent = 0
c.bytesRcvd = 0
for _, conn := range c.conns {
info, err := conn.getTCPInfo()
if err != nil {
if strings.Contains(err.Error(), "use of closed network connection") {
continue
}
log.Errorf("Failed to get TCP info: %s", err)
continue
}
if hasSegmentCounter {
c.segsSent += getSegmentsSent(info)
c.segsRcvd += getSegmentsRcvd(info)
}
if hasByteCounter {
c.bytesSent += getBytesSent(info)
c.bytesRcvd += getBytesRcvd(info)
}
c.rtts.Observe(info.RTT.Seconds())
c.connDurations.Observe(now.Sub(conn.startTime).Seconds())
}
}
func (c *aggregatingCollector) Collect(metrics chan<- prometheus.Metric) {
// Start collecting the metrics collection the first time Collect is called.
c.cronOnce.Do(func() {
c.gatherMetrics(time.Now())
go c.cron()
})
c.mutex.Lock()
defer c.mutex.Unlock()
metrics <- c.rtts
metrics <- c.connDurations
if hasSegmentCounter {
segsSentMetric, err := prometheus.NewConstMetric(segsSentDesc, prometheus.CounterValue, float64(c.segsSent))
if err != nil {
log.Errorf("creating tcp_sent_segments_total metric failed: %v", err)
return
}
segsRcvdMetric, err := prometheus.NewConstMetric(segsRcvdDesc, prometheus.CounterValue, float64(c.segsRcvd))
if err != nil {
log.Errorf("creating tcp_rcvd_segments_total metric failed: %v", err)
return
}
metrics <- segsSentMetric
metrics <- segsRcvdMetric
}
if hasByteCounter {
bytesSentMetric, err := prometheus.NewConstMetric(bytesSentDesc, prometheus.CounterValue, float64(c.bytesSent))
if err != nil {
log.Errorf("creating tcp_sent_bytes metric failed: %v", err)
return
}
bytesRcvdMetric, err := prometheus.NewConstMetric(bytesRcvdDesc, prometheus.CounterValue, float64(c.bytesRcvd))
if err != nil {
log.Errorf("creating tcp_rcvd_bytes metric failed: %v", err)
return
}
metrics <- bytesSentMetric
metrics <- bytesRcvdMetric
}
}
func (c *aggregatingCollector) ClosedConn(conn *tracingConn, direction string) {
c.mutex.Lock()
collector.removeConn(conn.id)
c.mutex.Unlock()
closedConns.WithLabelValues(direction).Inc()
}
type tracingConn struct {
id uint64
startTime time.Time
isClient bool
manet.Conn
tcpConn *tcp.Conn
}
func newTracingConn(c manet.Conn, isClient bool) (*tracingConn, error) {
initMetricsOnce.Do(func() { initMetrics() })
conn, err := tcp.NewConn(c)
if err != nil {
return nil, err
}
tc := &tracingConn{
startTime: time.Now(),
isClient: isClient,
Conn: c,
tcpConn: conn,
}
tc.id = collector.AddConn(tc)
newConns.WithLabelValues(tc.getDirection()).Inc()
return tc, nil
}
func (c *tracingConn) getDirection() string {
if c.isClient {
return "outgoing"
}
return "incoming"
}
func (c *tracingConn) Close() error {
collector.ClosedConn(c, c.getDirection())
return c.Conn.Close()
}
func (c *tracingConn) getTCPInfo() (*tcpinfo.Info, error) {
var o tcpinfo.Info
var b [256]byte
i, err := c.tcpConn.Option(o.Level(), o.Name(), b[:])
if err != nil {
return nil, err
}
info := i.(*tcpinfo.Info)
return info, nil
}
type tracingListener struct {
manet.Listener
}
func newTracingListener(l manet.Listener) *tracingListener {
return &tracingListener{Listener: l}
}
func (l *tracingListener) Accept() (manet.Conn, error) {
conn, err := l.Listener.Accept()
if err != nil {
return nil, err
}
return newTracingConn(conn, false)
}

View File

@@ -0,0 +1,15 @@
//go:build darwin
package tcp
import "github.com/mikioh/tcpinfo"
const (
hasSegmentCounter = true
hasByteCounter = true
)
func getSegmentsSent(info *tcpinfo.Info) uint64 { return info.Sys.SegsSent }
func getSegmentsRcvd(info *tcpinfo.Info) uint64 { return info.Sys.SegsReceived }
func getBytesSent(info *tcpinfo.Info) uint64 { return info.Sys.BytesSent }
func getBytesRcvd(info *tcpinfo.Info) uint64 { return info.Sys.BytesReceived }

View File

@@ -0,0 +1,15 @@
//go:build !linux && !darwin && !windows && !riscv64
package tcp
import "github.com/mikioh/tcpinfo"
const (
hasSegmentCounter = false
hasByteCounter = false
)
func getSegmentsSent(info *tcpinfo.Info) uint64 { return 0 }
func getSegmentsRcvd(info *tcpinfo.Info) uint64 { return 0 }
func getBytesSent(info *tcpinfo.Info) uint64 { return 0 }
func getBytesRcvd(info *tcpinfo.Info) uint64 { return 0 }

View File

@@ -0,0 +1,15 @@
//go:build linux
package tcp
import "github.com/mikioh/tcpinfo"
const (
hasSegmentCounter = true
hasByteCounter = false
)
func getSegmentsSent(info *tcpinfo.Info) uint64 { return uint64(info.Sys.SegsOut) }
func getSegmentsRcvd(info *tcpinfo.Info) uint64 { return uint64(info.Sys.SegsIn) }
func getBytesSent(info *tcpinfo.Info) uint64 { return 0 }
func getBytesRcvd(info *tcpinfo.Info) uint64 { return 0 }

View File

@@ -0,0 +1,10 @@
// riscv64 see: https://github.com/marten-seemann/tcp/pull/1
//go:build windows || riscv64
package tcp
import manet "github.com/multiformats/go-multiaddr/net"
func newTracingConn(c manet.Conn, _ bool) (manet.Conn, error) { return c, nil }
func newTracingListener(l manet.Listener) manet.Listener { return l }

View File

@@ -0,0 +1,35 @@
package tcp
import (
"os"
"strings"
"github.com/libp2p/go-reuseport"
)
// envReuseport is the env variable name used to turn off reuse port.
// It default to true.
const envReuseport = "LIBP2P_TCP_REUSEPORT"
// envReuseportVal stores the value of envReuseport. defaults to true.
var envReuseportVal = true
func init() {
v := strings.ToLower(os.Getenv(envReuseport))
if v == "false" || v == "f" || v == "0" {
envReuseportVal = false
log.Infof("REUSEPORT disabled (LIBP2P_TCP_REUSEPORT=%s)", v)
}
}
// ReuseportIsAvailable returns whether reuseport is available to be used. This
// is here because we want to be able to turn reuseport on and off selectively.
// For now we use an ENV variable, as this handles our pressing need:
//
// LIBP2P_TCP_REUSEPORT=false ipfs daemon
//
// If this becomes a sought after feature, we could add this to the config.
// In the end, reuseport is a stop-gap.
func ReuseportIsAvailable() bool {
return envReuseportVal && reuseport.Available()
}

View File

@@ -0,0 +1,270 @@
package tcp
import (
"context"
"errors"
"net"
"os"
"runtime"
"syscall"
"time"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/transport"
"github.com/libp2p/go-libp2p/p2p/net/reuseport"
logging "github.com/ipfs/go-log/v2"
ma "github.com/multiformats/go-multiaddr"
mafmt "github.com/multiformats/go-multiaddr-fmt"
manet "github.com/multiformats/go-multiaddr/net"
)
const defaultConnectTimeout = 5 * time.Second
var log = logging.Logger("tcp-tpt")
const keepAlivePeriod = 30 * time.Second
type canKeepAlive interface {
SetKeepAlive(bool) error
SetKeepAlivePeriod(time.Duration) error
}
var _ canKeepAlive = &net.TCPConn{}
func tryKeepAlive(conn net.Conn, keepAlive bool) {
keepAliveConn, ok := conn.(canKeepAlive)
if !ok {
log.Errorf("Can't set TCP keepalives.")
return
}
if err := keepAliveConn.SetKeepAlive(keepAlive); err != nil {
// Sometimes we seem to get "invalid argument" results from this function on Darwin.
// This might be due to a closed connection, but I can't reproduce that on Linux.
//
// But there's nothing we can do about invalid arguments, so we'll drop this to a
// debug.
if errors.Is(err, os.ErrInvalid) || errors.Is(err, syscall.EINVAL) {
log.Debugw("failed to enable TCP keepalive", "error", err)
} else {
log.Errorw("failed to enable TCP keepalive", "error", err)
}
return
}
if runtime.GOOS != "openbsd" {
if err := keepAliveConn.SetKeepAlivePeriod(keepAlivePeriod); err != nil {
log.Errorw("failed set keepalive period", "error", err)
}
}
}
// try to set linger on the connection, if possible.
func tryLinger(conn net.Conn, sec int) {
type canLinger interface {
SetLinger(int) error
}
if lingerConn, ok := conn.(canLinger); ok {
_ = lingerConn.SetLinger(sec)
}
}
type tcpListener struct {
manet.Listener
sec int
}
func (ll *tcpListener) Accept() (manet.Conn, error) {
c, err := ll.Listener.Accept()
if err != nil {
return nil, err
}
tryLinger(c, ll.sec)
tryKeepAlive(c, true)
// We're not calling OpenConnection in the resource manager here,
// since the manet.Conn doesn't allow us to save the scope.
// It's the caller's (usually the p2p/net/upgrader) responsibility
// to call the resource manager.
return c, nil
}
type Option func(*TcpTransport) error
func DisableReuseport() Option {
return func(tr *TcpTransport) error {
tr.disableReuseport = true
return nil
}
}
func WithConnectionTimeout(d time.Duration) Option {
return func(tr *TcpTransport) error {
tr.connectTimeout = d
return nil
}
}
func WithMetrics() Option {
return func(tr *TcpTransport) error {
tr.enableMetrics = true
return nil
}
}
// TcpTransport is the TCP transport.
type TcpTransport struct {
// Connection upgrader for upgrading insecure stream connections to
// secure multiplex connections.
upgrader transport.Upgrader
disableReuseport bool // Explicitly disable reuseport.
enableMetrics bool
// TCP connect timeout
connectTimeout time.Duration
rcmgr network.ResourceManager
reuse reuseport.Transport
}
var _ transport.Transport = &TcpTransport{}
var _ transport.DialUpdater = &TcpTransport{}
// NewTCPTransport creates a tcp transport object that tracks dialers and listeners
// created. It represents an entire TCP stack (though it might not necessarily be).
func NewTCPTransport(upgrader transport.Upgrader, rcmgr network.ResourceManager, opts ...Option) (*TcpTransport, error) {
if rcmgr == nil {
rcmgr = &network.NullResourceManager{}
}
tr := &TcpTransport{
upgrader: upgrader,
connectTimeout: defaultConnectTimeout, // can be set by using the WithConnectionTimeout option
rcmgr: rcmgr,
}
for _, o := range opts {
if err := o(tr); err != nil {
return nil, err
}
}
return tr, nil
}
var dialMatcher = mafmt.And(mafmt.IP, mafmt.Base(ma.P_TCP))
// CanDial returns true if this transport believes it can dial the given
// multiaddr.
func (t *TcpTransport) CanDial(addr ma.Multiaddr) bool {
return dialMatcher.Matches(addr)
}
func (t *TcpTransport) maDial(ctx context.Context, raddr ma.Multiaddr) (manet.Conn, error) {
// Apply the deadline iff applicable
if t.connectTimeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, t.connectTimeout)
defer cancel()
}
if t.UseReuseport() {
return t.reuse.DialContext(ctx, raddr)
}
var d manet.Dialer
return d.DialContext(ctx, raddr)
}
// Dial dials the peer at the remote address.
func (t *TcpTransport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (transport.CapableConn, error) {
return t.DialWithUpdates(ctx, raddr, p, nil)
}
func (t *TcpTransport) DialWithUpdates(ctx context.Context, raddr ma.Multiaddr, p peer.ID, updateChan chan<- transport.DialUpdate) (transport.CapableConn, error) {
connScope, err := t.rcmgr.OpenConnection(network.DirOutbound, true, raddr)
if err != nil {
log.Debugw("resource manager blocked outgoing connection", "peer", p, "addr", raddr, "error", err)
return nil, err
}
c, err := t.dialWithScope(ctx, raddr, p, connScope, updateChan)
if err != nil {
connScope.Done()
return nil, err
}
return c, nil
}
func (t *TcpTransport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, p peer.ID, connScope network.ConnManagementScope, updateChan chan<- transport.DialUpdate) (transport.CapableConn, error) {
if err := connScope.SetPeer(p); err != nil {
log.Debugw("resource manager blocked outgoing connection for peer", "peer", p, "addr", raddr, "error", err)
return nil, err
}
conn, err := t.maDial(ctx, raddr)
if err != nil {
return nil, err
}
// Set linger to 0 so we never get stuck in the TIME-WAIT state. When
// linger is 0, connections are _reset_ instead of closed with a FIN.
// This means we can immediately reuse the 5-tuple and reconnect.
tryLinger(conn, 0)
tryKeepAlive(conn, true)
c := conn
if t.enableMetrics {
var err error
c, err = newTracingConn(conn, true)
if err != nil {
return nil, err
}
}
if updateChan != nil {
select {
case updateChan <- transport.DialUpdate{Kind: transport.UpdateKindHandshakeProgressed, Addr: raddr}:
default:
// It is better to skip the update than to delay upgrading the connection
}
}
direction := network.DirOutbound
if ok, isClient, _ := network.GetSimultaneousConnect(ctx); ok && !isClient {
direction = network.DirInbound
}
return t.upgrader.Upgrade(ctx, t, c, direction, p, connScope)
}
// UseReuseport returns true if reuseport is enabled and available.
func (t *TcpTransport) UseReuseport() bool {
return !t.disableReuseport && ReuseportIsAvailable()
}
func (t *TcpTransport) maListen(laddr ma.Multiaddr) (manet.Listener, error) {
if t.UseReuseport() {
return t.reuse.Listen(laddr)
}
return manet.Listen(laddr)
}
// Listen listens on the given multiaddr.
func (t *TcpTransport) Listen(laddr ma.Multiaddr) (transport.Listener, error) {
list, err := t.maListen(laddr)
if err != nil {
return nil, err
}
if t.enableMetrics {
list = newTracingListener(&tcpListener{list, 0})
}
return t.upgrader.UpgradeListener(t, list), nil
}
// Protocols returns the list of terminal protocols this transport can dial.
func (t *TcpTransport) Protocols() []int {
return []int{ma.P_TCP}
}
// Proxy always returns false for the TCP transport.
func (t *TcpTransport) Proxy() bool {
return false
}
func (t *TcpTransport) String() string {
return "TCP"
}

View File

@@ -0,0 +1,5 @@
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

View File

@@ -0,0 +1,19 @@
The MIT License (MIT)
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

View File

@@ -0,0 +1,175 @@
package websocket
import (
"fmt"
"net"
"net/url"
"strconv"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
)
// Addr is an implementation of net.Addr for WebSocket.
type Addr struct {
*url.URL
}
var _ net.Addr = (*Addr)(nil)
// Network returns the network type for a WebSocket, "websocket".
func (addr *Addr) Network() string {
return "websocket"
}
// NewAddr creates an Addr with `ws` scheme (insecure).
//
// Deprecated. Use NewAddrWithScheme.
func NewAddr(host string) *Addr {
// Older versions of the transport only supported insecure connections (i.e.
// WS instead of WSS). Assume that is the case here.
return NewAddrWithScheme(host, false)
}
// NewAddrWithScheme creates a new Addr using the given host string. isSecure
// should be true for WSS connections and false for WS.
func NewAddrWithScheme(host string, isSecure bool) *Addr {
scheme := "ws"
if isSecure {
scheme = "wss"
}
return &Addr{
URL: &url.URL{
Scheme: scheme,
Host: host,
},
}
}
func ConvertWebsocketMultiaddrToNetAddr(maddr ma.Multiaddr) (net.Addr, error) {
url, err := parseMultiaddr(maddr)
if err != nil {
return nil, err
}
return &Addr{URL: url}, nil
}
func ParseWebsocketNetAddr(a net.Addr) (ma.Multiaddr, error) {
wsa, ok := a.(*Addr)
if !ok {
return nil, fmt.Errorf("not a websocket address")
}
var (
tcpma ma.Multiaddr
err error
port int
host = wsa.Hostname()
)
// Get the port
if portStr := wsa.Port(); portStr != "" {
port, err = strconv.Atoi(portStr)
if err != nil {
return nil, fmt.Errorf("failed to parse port '%q': %s", portStr, err)
}
} else {
return nil, fmt.Errorf("invalid port in url: '%q'", wsa.URL)
}
// NOTE: Ignoring IPv6 zones...
// Detect if host is IP address or DNS
if ip := net.ParseIP(host); ip != nil {
// Assume IP address
tcpma, err = manet.FromNetAddr(&net.TCPAddr{
IP: ip,
Port: port,
})
if err != nil {
return nil, err
}
} else {
// Assume DNS name
tcpma, err = ma.NewMultiaddr(fmt.Sprintf("/dns/%s/tcp/%d", host, port))
if err != nil {
return nil, err
}
}
wsma, err := ma.NewMultiaddr("/" + wsa.Scheme)
if err != nil {
return nil, err
}
return tcpma.Encapsulate(wsma), nil
}
func parseMultiaddr(maddr ma.Multiaddr) (*url.URL, error) {
parsed, err := parseWebsocketMultiaddr(maddr)
if err != nil {
return nil, err
}
scheme := "ws"
if parsed.isWSS {
scheme = "wss"
}
network, host, err := manet.DialArgs(parsed.restMultiaddr)
if err != nil {
return nil, err
}
switch network {
case "tcp", "tcp4", "tcp6":
default:
return nil, fmt.Errorf("unsupported websocket network %s", network)
}
return &url.URL{
Scheme: scheme,
Host: host,
}, nil
}
type parsedWebsocketMultiaddr struct {
isWSS bool
// sni is the SNI value for the TLS handshake, and for setting HTTP Host header
sni *ma.Component
// the rest of the multiaddr before the /tls/sni/example.com/ws or /ws or /wss
restMultiaddr ma.Multiaddr
}
func parseWebsocketMultiaddr(a ma.Multiaddr) (parsedWebsocketMultiaddr, error) {
out := parsedWebsocketMultiaddr{}
// First check if we have a WSS component. If so we'll canonicalize it into a /tls/ws
withoutWss := a.Decapsulate(wssComponent)
if !withoutWss.Equal(a) {
a = withoutWss.Encapsulate(tlsWsComponent)
}
// Remove the ws component
withoutWs := a.Decapsulate(wsComponent)
if withoutWs.Equal(a) {
return out, fmt.Errorf("not a websocket multiaddr")
}
rest := withoutWs
// If this is not a wss then withoutWs is the rest of the multiaddr
out.restMultiaddr = withoutWs
for {
var head *ma.Component
rest, head = ma.SplitLast(rest)
if head == nil || rest == nil {
break
}
if head.Protocol().Code == ma.P_SNI {
out.sni = head
} else if head.Protocol().Code == ma.P_TLS {
out.isWSS = true
out.restMultiaddr = rest
break
}
}
return out, nil
}

View File

@@ -0,0 +1,164 @@
package websocket
import (
"io"
"net"
"sync"
"time"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/transport"
ws "github.com/gorilla/websocket"
)
// GracefulCloseTimeout is the time to wait trying to gracefully close a
// connection before simply cutting it.
var GracefulCloseTimeout = 100 * time.Millisecond
// Conn implements net.Conn interface for gorilla/websocket.
type Conn struct {
*ws.Conn
secure bool
DefaultMessageType int
reader io.Reader
closeOnce sync.Once
readLock, writeLock sync.Mutex
}
var _ net.Conn = (*Conn)(nil)
// NewConn creates a Conn given a regular gorilla/websocket Conn.
func NewConn(raw *ws.Conn, secure bool) *Conn {
return &Conn{
Conn: raw,
secure: secure,
DefaultMessageType: ws.BinaryMessage,
}
}
func (c *Conn) Read(b []byte) (int, error) {
c.readLock.Lock()
defer c.readLock.Unlock()
if c.reader == nil {
if err := c.prepNextReader(); err != nil {
return 0, err
}
}
for {
n, err := c.reader.Read(b)
switch err {
case io.EOF:
c.reader = nil
if n > 0 {
return n, nil
}
if err := c.prepNextReader(); err != nil {
return 0, err
}
// explicitly looping
default:
return n, err
}
}
}
func (c *Conn) prepNextReader() error {
t, r, err := c.Conn.NextReader()
if err != nil {
if wserr, ok := err.(*ws.CloseError); ok {
if wserr.Code == 1000 || wserr.Code == 1005 {
return io.EOF
}
}
return err
}
if t == ws.CloseMessage {
return io.EOF
}
c.reader = r
return nil
}
func (c *Conn) Write(b []byte) (n int, err error) {
c.writeLock.Lock()
defer c.writeLock.Unlock()
if err := c.Conn.WriteMessage(c.DefaultMessageType, b); err != nil {
return 0, err
}
return len(b), nil
}
// Close closes the connection. Only the first call to Close will receive the
// close error, subsequent and concurrent calls will return nil.
// This method is thread-safe.
func (c *Conn) Close() error {
var err error
c.closeOnce.Do(func() {
err1 := c.Conn.WriteControl(
ws.CloseMessage,
ws.FormatCloseMessage(ws.CloseNormalClosure, "closed"),
time.Now().Add(GracefulCloseTimeout),
)
err2 := c.Conn.Close()
switch {
case err1 != nil:
err = err1
case err2 != nil:
err = err2
}
})
return err
}
func (c *Conn) LocalAddr() net.Addr {
return NewAddrWithScheme(c.Conn.LocalAddr().String(), c.secure)
}
func (c *Conn) RemoteAddr() net.Addr {
return NewAddrWithScheme(c.Conn.RemoteAddr().String(), c.secure)
}
func (c *Conn) SetDeadline(t time.Time) error {
if err := c.SetReadDeadline(t); err != nil {
return err
}
return c.SetWriteDeadline(t)
}
func (c *Conn) SetReadDeadline(t time.Time) error {
// Don't lock when setting the read deadline. That would prevent us from
// interrupting an in-progress read.
return c.Conn.SetReadDeadline(t)
}
func (c *Conn) SetWriteDeadline(t time.Time) error {
// Unlike the read deadline, we need to lock when setting the write
// deadline.
c.writeLock.Lock()
defer c.writeLock.Unlock()
return c.Conn.SetWriteDeadline(t)
}
type capableConn struct {
transport.CapableConn
}
func (c *capableConn) ConnState() network.ConnectionState {
cs := c.CapableConn.ConnState()
cs.Transport = "websocket"
return cs
}

View File

@@ -0,0 +1,160 @@
package websocket
import (
"crypto/tls"
"fmt"
"net"
"net/http"
"strings"
"github.com/libp2p/go-libp2p/core/transport"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
)
type listener struct {
nl net.Listener
server http.Server
// The Go standard library sets the http.Server.TLSConfig no matter if this is a WS or WSS,
// so we can't rely on checking if server.TLSConfig is set.
isWss bool
laddr ma.Multiaddr
closed chan struct{}
incoming chan *Conn
}
func (pwma *parsedWebsocketMultiaddr) toMultiaddr() ma.Multiaddr {
if !pwma.isWSS {
return pwma.restMultiaddr.Encapsulate(wsComponent)
}
if pwma.sni == nil {
return pwma.restMultiaddr.Encapsulate(tlsComponent).Encapsulate(wsComponent)
}
return pwma.restMultiaddr.Encapsulate(tlsComponent).Encapsulate(pwma.sni).Encapsulate(wsComponent)
}
// newListener creates a new listener from a raw net.Listener.
// tlsConf may be nil (for unencrypted websockets).
func newListener(a ma.Multiaddr, tlsConf *tls.Config) (*listener, error) {
parsed, err := parseWebsocketMultiaddr(a)
if err != nil {
return nil, err
}
if parsed.isWSS && tlsConf == nil {
return nil, fmt.Errorf("cannot listen on wss address %s without a tls.Config", a)
}
lnet, lnaddr, err := manet.DialArgs(parsed.restMultiaddr)
if err != nil {
return nil, err
}
nl, err := net.Listen(lnet, lnaddr)
if err != nil {
return nil, err
}
laddr, err := manet.FromNetAddr(nl.Addr())
if err != nil {
return nil, err
}
first, _ := ma.SplitFirst(a)
// Don't resolve dns addresses.
// We want to be able to announce domain names, so the peer can validate the TLS certificate.
if c := first.Protocol().Code; c == ma.P_DNS || c == ma.P_DNS4 || c == ma.P_DNS6 || c == ma.P_DNSADDR {
_, last := ma.SplitFirst(laddr)
laddr = first.Encapsulate(last)
}
parsed.restMultiaddr = laddr
ln := &listener{
nl: nl,
laddr: parsed.toMultiaddr(),
incoming: make(chan *Conn),
closed: make(chan struct{}),
}
ln.server = http.Server{Handler: ln}
if parsed.isWSS {
ln.isWss = true
ln.server.TLSConfig = tlsConf
}
return ln, nil
}
func (l *listener) serve() {
defer close(l.closed)
if !l.isWss {
l.server.Serve(l.nl)
} else {
l.server.ServeTLS(l.nl, "", "")
}
}
func (l *listener) ServeHTTP(w http.ResponseWriter, r *http.Request) {
c, err := upgrader.Upgrade(w, r, nil)
if err != nil {
// The upgrader writes a response for us.
return
}
select {
case l.incoming <- NewConn(c, l.isWss):
case <-l.closed:
c.Close()
}
// The connection has been hijacked, it's safe to return.
}
func (l *listener) Accept() (manet.Conn, error) {
select {
case c, ok := <-l.incoming:
if !ok {
return nil, transport.ErrListenerClosed
}
mnc, err := manet.WrapNetConn(c)
if err != nil {
c.Close()
return nil, err
}
return mnc, nil
case <-l.closed:
return nil, transport.ErrListenerClosed
}
}
func (l *listener) Addr() net.Addr {
return l.nl.Addr()
}
func (l *listener) Close() error {
l.server.Close()
err := l.nl.Close()
<-l.closed
if strings.Contains(err.Error(), "use of closed network connection") {
return transport.ErrListenerClosed
}
return err
}
func (l *listener) Multiaddr() ma.Multiaddr {
return l.laddr
}
type transportListener struct {
transport.Listener
}
func (l *transportListener) Accept() (transport.CapableConn, error) {
conn, err := l.Listener.Accept()
if err != nil {
return nil, err
}
return &capableConn{CapableConn: conn}, nil
}

View File

@@ -0,0 +1,246 @@
// Package websocket implements a websocket based transport for go-libp2p.
package websocket
import (
"context"
"crypto/tls"
"net"
"net/http"
"time"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/transport"
ma "github.com/multiformats/go-multiaddr"
mafmt "github.com/multiformats/go-multiaddr-fmt"
manet "github.com/multiformats/go-multiaddr/net"
ws "github.com/gorilla/websocket"
)
// WsFmt is multiaddr formatter for WsProtocol
var WsFmt = mafmt.And(mafmt.TCP, mafmt.Base(ma.P_WS))
var dialMatcher = mafmt.And(
mafmt.Or(mafmt.IP, mafmt.DNS),
mafmt.Base(ma.P_TCP),
mafmt.Or(
mafmt.Base(ma.P_WS),
mafmt.And(
mafmt.Or(
mafmt.And(
mafmt.Base(ma.P_TLS),
mafmt.Base(ma.P_SNI)),
mafmt.Base(ma.P_TLS),
),
mafmt.Base(ma.P_WS)),
mafmt.Base(ma.P_WSS)))
var (
wssComponent = ma.StringCast("/wss")
tlsWsComponent = ma.StringCast("/tls/ws")
tlsComponent = ma.StringCast("/tls")
wsComponent = ma.StringCast("/ws")
)
func init() {
manet.RegisterFromNetAddr(ParseWebsocketNetAddr, "websocket")
manet.RegisterToNetAddr(ConvertWebsocketMultiaddrToNetAddr, "ws")
manet.RegisterToNetAddr(ConvertWebsocketMultiaddrToNetAddr, "wss")
}
// Default gorilla upgrader
var upgrader = ws.Upgrader{
// Allow requests from *all* origins.
CheckOrigin: func(r *http.Request) bool {
return true
},
}
type Option func(*WebsocketTransport) error
// WithTLSClientConfig sets a TLS client configuration on the WebSocket Dialer. Only
// relevant for non-browser usages.
//
// Some useful use cases include setting InsecureSkipVerify to `true`, or
// setting user-defined trusted CA certificates.
func WithTLSClientConfig(c *tls.Config) Option {
return func(t *WebsocketTransport) error {
t.tlsClientConf = c
return nil
}
}
// WithTLSConfig sets a TLS configuration for the WebSocket listener.
func WithTLSConfig(conf *tls.Config) Option {
return func(t *WebsocketTransport) error {
t.tlsConf = conf
return nil
}
}
// WebsocketTransport is the actual go-libp2p transport
type WebsocketTransport struct {
upgrader transport.Upgrader
rcmgr network.ResourceManager
tlsClientConf *tls.Config
tlsConf *tls.Config
}
var _ transport.Transport = (*WebsocketTransport)(nil)
func New(u transport.Upgrader, rcmgr network.ResourceManager, opts ...Option) (*WebsocketTransport, error) {
if rcmgr == nil {
rcmgr = &network.NullResourceManager{}
}
t := &WebsocketTransport{
upgrader: u,
rcmgr: rcmgr,
tlsClientConf: &tls.Config{},
}
for _, opt := range opts {
if err := opt(t); err != nil {
return nil, err
}
}
return t, nil
}
func (t *WebsocketTransport) CanDial(a ma.Multiaddr) bool {
return dialMatcher.Matches(a)
}
func (t *WebsocketTransport) Protocols() []int {
return []int{ma.P_WS, ma.P_WSS}
}
func (t *WebsocketTransport) Proxy() bool {
return false
}
func (t *WebsocketTransport) Resolve(_ context.Context, maddr ma.Multiaddr) ([]ma.Multiaddr, error) {
parsed, err := parseWebsocketMultiaddr(maddr)
if err != nil {
return nil, err
}
if !parsed.isWSS {
// No /tls/ws component, this isn't a secure websocket multiaddr. We can just return it here
return []ma.Multiaddr{maddr}, nil
}
if parsed.sni == nil {
var err error
// We don't have an sni component, we'll use dns/dnsaddr
ma.ForEach(parsed.restMultiaddr, func(c ma.Component) bool {
switch c.Protocol().Code {
case ma.P_DNS, ma.P_DNS4, ma.P_DNS6:
// err shouldn't happen since this means we couldn't parse a dns hostname for an sni value.
parsed.sni, err = ma.NewComponent("sni", c.Value())
return false
}
return true
})
if err != nil {
return nil, err
}
}
if parsed.sni == nil {
// we didn't find anything to set the sni with. So we just return the given multiaddr
return []ma.Multiaddr{maddr}, nil
}
return []ma.Multiaddr{parsed.toMultiaddr()}, nil
}
func (t *WebsocketTransport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (transport.CapableConn, error) {
connScope, err := t.rcmgr.OpenConnection(network.DirOutbound, true, raddr)
if err != nil {
return nil, err
}
c, err := t.dialWithScope(ctx, raddr, p, connScope)
if err != nil {
connScope.Done()
return nil, err
}
return c, nil
}
func (t *WebsocketTransport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, p peer.ID, connScope network.ConnManagementScope) (transport.CapableConn, error) {
macon, err := t.maDial(ctx, raddr)
if err != nil {
return nil, err
}
conn, err := t.upgrader.Upgrade(ctx, t, macon, network.DirOutbound, p, connScope)
if err != nil {
return nil, err
}
return &capableConn{CapableConn: conn}, nil
}
func (t *WebsocketTransport) maDial(ctx context.Context, raddr ma.Multiaddr) (manet.Conn, error) {
wsurl, err := parseMultiaddr(raddr)
if err != nil {
return nil, err
}
isWss := wsurl.Scheme == "wss"
dialer := ws.Dialer{HandshakeTimeout: 30 * time.Second}
if isWss {
sni := ""
sni, err = raddr.ValueForProtocol(ma.P_SNI)
if err != nil {
sni = ""
}
if sni != "" {
copytlsClientConf := t.tlsClientConf.Clone()
copytlsClientConf.ServerName = sni
dialer.TLSClientConfig = copytlsClientConf
ipAddr := wsurl.Host
// Setting the NetDial because we already have the resolved IP address, so we don't want to do another resolution.
// We set the `.Host` to the sni field so that the host header gets properly set.
dialer.NetDial = func(network, address string) (net.Conn, error) {
tcpAddr, err := net.ResolveTCPAddr(network, ipAddr)
if err != nil {
return nil, err
}
return net.DialTCP("tcp", nil, tcpAddr)
}
wsurl.Host = sni + ":" + wsurl.Port()
} else {
dialer.TLSClientConfig = t.tlsClientConf
}
}
wscon, _, err := dialer.DialContext(ctx, wsurl.String(), nil)
if err != nil {
return nil, err
}
mnc, err := manet.WrapNetConn(NewConn(wscon, isWss))
if err != nil {
wscon.Close()
return nil, err
}
return mnc, nil
}
func (t *WebsocketTransport) maListen(a ma.Multiaddr) (manet.Listener, error) {
l, err := newListener(a, t.tlsConf)
if err != nil {
return nil, err
}
go l.serve()
return l, nil
}
func (t *WebsocketTransport) Listen(a ma.Multiaddr) (transport.Listener, error) {
malist, err := t.maListen(a)
if err != nil {
return nil, err
}
return &transportListener{Listener: t.upgrader.UpgradeListener(t, malist)}, nil
}

View File

@@ -0,0 +1,213 @@
package libp2pwebtransport
import (
"context"
"crypto/sha256"
"crypto/tls"
"encoding/binary"
"fmt"
"sync"
"time"
"github.com/benbjohnson/clock"
ic "github.com/libp2p/go-libp2p/core/crypto"
ma "github.com/multiformats/go-multiaddr"
"github.com/multiformats/go-multihash"
)
// Allow for a bit of clock skew.
// When we generate a certificate, the NotBefore time is set to clockSkewAllowance before the current time.
// Similarly, we stop using a certificate one clockSkewAllowance before its expiry time.
const clockSkewAllowance = time.Hour
const validityMinusTwoSkew = certValidity - (2 * clockSkewAllowance)
type certConfig struct {
tlsConf *tls.Config
sha256 [32]byte // cached from the tlsConf
}
func (c *certConfig) Start() time.Time { return c.tlsConf.Certificates[0].Leaf.NotBefore }
func (c *certConfig) End() time.Time { return c.tlsConf.Certificates[0].Leaf.NotAfter }
func newCertConfig(key ic.PrivKey, start, end time.Time) (*certConfig, error) {
conf, err := getTLSConf(key, start, end)
if err != nil {
return nil, err
}
return &certConfig{
tlsConf: conf,
sha256: sha256.Sum256(conf.Certificates[0].Leaf.Raw),
}, nil
}
// Certificate renewal logic:
// 1. On startup, we generate one cert that is valid from now (-1h, to allow for clock skew), and another
// cert that is valid from the expiry date of the first certificate (again, with allowance for clock skew).
// 2. Once we reach 1h before expiry of the first certificate, we switch over to the second certificate.
// At the same time, we stop advertising the certhash of the first cert and generate the next cert.
type certManager struct {
clock clock.Clock
ctx context.Context
ctxCancel context.CancelFunc
refCount sync.WaitGroup
mx sync.RWMutex
lastConfig *certConfig // initially nil
currentConfig *certConfig
nextConfig *certConfig // nil until we have passed half the certValidity of the current config
addrComp ma.Multiaddr
serializedCertHashes [][]byte
}
func newCertManager(hostKey ic.PrivKey, clock clock.Clock) (*certManager, error) {
m := &certManager{clock: clock}
m.ctx, m.ctxCancel = context.WithCancel(context.Background())
if err := m.init(hostKey); err != nil {
return nil, err
}
m.background(hostKey)
return m, nil
}
// getCurrentTimeBucket returns the canonical start time of the given time as
// bucketed by ranges of certValidity since unix epoch (plus an offset). This
// lets you get the same time ranges across reboots without having to persist
// state.
// ```
// ... v--- epoch + offset
// ... |--------| |--------| ...
// ... |--------| |--------| ...
// ```
func getCurrentBucketStartTime(now time.Time, offset time.Duration) time.Time {
currentBucket := (now.UnixMilli() - offset.Milliseconds()) / validityMinusTwoSkew.Milliseconds()
return time.UnixMilli(offset.Milliseconds() + currentBucket*validityMinusTwoSkew.Milliseconds())
}
func (m *certManager) init(hostKey ic.PrivKey) error {
start := m.clock.Now()
pubkeyBytes, err := hostKey.GetPublic().Raw()
if err != nil {
return err
}
// We want to add a random offset to each start time so that not all certs
// rotate at the same time across the network. The offset represents moving
// the bucket start time some `offset` earlier.
offset := (time.Duration(binary.LittleEndian.Uint16(pubkeyBytes)) * time.Minute) % certValidity
// We want the certificate have been valid for at least one clockSkewAllowance
start = start.Add(-clockSkewAllowance)
startTime := getCurrentBucketStartTime(start, offset)
m.nextConfig, err = newCertConfig(hostKey, startTime, startTime.Add(certValidity))
if err != nil {
return err
}
return m.rollConfig(hostKey)
}
func (m *certManager) rollConfig(hostKey ic.PrivKey) error {
// We stop using the current certificate clockSkewAllowance before its expiry time.
// At this point, the next certificate needs to be valid for one clockSkewAllowance.
nextStart := m.nextConfig.End().Add(-2 * clockSkewAllowance)
c, err := newCertConfig(hostKey, nextStart, nextStart.Add(certValidity))
if err != nil {
return err
}
m.lastConfig = m.currentConfig
m.currentConfig = m.nextConfig
m.nextConfig = c
if err := m.cacheSerializedCertHashes(); err != nil {
return err
}
return m.cacheAddrComponent()
}
func (m *certManager) background(hostKey ic.PrivKey) {
d := m.currentConfig.End().Add(-clockSkewAllowance).Sub(m.clock.Now())
log.Debugw("setting timer", "duration", d.String())
t := m.clock.Timer(d)
m.refCount.Add(1)
go func() {
defer m.refCount.Done()
defer t.Stop()
for {
select {
case <-m.ctx.Done():
return
case <-t.C:
now := m.clock.Now()
m.mx.Lock()
if err := m.rollConfig(hostKey); err != nil {
log.Errorw("rolling config failed", "error", err)
}
d := m.currentConfig.End().Add(-clockSkewAllowance).Sub(now)
log.Debugw("rolling certificates", "next", d.String())
t.Reset(d)
m.mx.Unlock()
}
}
}()
}
func (m *certManager) GetConfig() *tls.Config {
m.mx.RLock()
defer m.mx.RUnlock()
return m.currentConfig.tlsConf
}
func (m *certManager) AddrComponent() ma.Multiaddr {
m.mx.RLock()
defer m.mx.RUnlock()
return m.addrComp
}
func (m *certManager) SerializedCertHashes() [][]byte {
return m.serializedCertHashes
}
func (m *certManager) cacheSerializedCertHashes() error {
hashes := make([][32]byte, 0, 3)
if m.lastConfig != nil {
hashes = append(hashes, m.lastConfig.sha256)
}
hashes = append(hashes, m.currentConfig.sha256)
if m.nextConfig != nil {
hashes = append(hashes, m.nextConfig.sha256)
}
m.serializedCertHashes = m.serializedCertHashes[:0]
for _, certHash := range hashes {
h, err := multihash.Encode(certHash[:], multihash.SHA2_256)
if err != nil {
return fmt.Errorf("failed to encode certificate hash: %w", err)
}
m.serializedCertHashes = append(m.serializedCertHashes, h)
}
return nil
}
func (m *certManager) cacheAddrComponent() error {
addr, err := addrComponentForCert(m.currentConfig.sha256[:])
if err != nil {
return err
}
if m.nextConfig != nil {
comp, err := addrComponentForCert(m.nextConfig.sha256[:])
if err != nil {
return err
}
addr = addr.Encapsulate(comp)
}
m.addrComp = addr
return nil
}
func (m *certManager) Close() error {
m.ctxCancel()
m.refCount.Wait()
return nil
}

View File

@@ -0,0 +1,82 @@
package libp2pwebtransport
import (
"context"
"github.com/libp2p/go-libp2p/core/network"
tpt "github.com/libp2p/go-libp2p/core/transport"
ma "github.com/multiformats/go-multiaddr"
"github.com/quic-go/webtransport-go"
)
type connSecurityMultiaddrs struct {
network.ConnSecurity
network.ConnMultiaddrs
}
type connMultiaddrs struct {
local, remote ma.Multiaddr
}
var _ network.ConnMultiaddrs = &connMultiaddrs{}
func (c *connMultiaddrs) LocalMultiaddr() ma.Multiaddr { return c.local }
func (c *connMultiaddrs) RemoteMultiaddr() ma.Multiaddr { return c.remote }
type conn struct {
*connSecurityMultiaddrs
transport *transport
session *webtransport.Session
scope network.ConnManagementScope
}
var _ tpt.CapableConn = &conn{}
func newConn(tr *transport, sess *webtransport.Session, sconn *connSecurityMultiaddrs, scope network.ConnManagementScope) *conn {
return &conn{
connSecurityMultiaddrs: sconn,
transport: tr,
session: sess,
scope: scope,
}
}
func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) {
str, err := c.session.OpenStreamSync(ctx)
if err != nil {
return nil, err
}
return &stream{str}, nil
}
func (c *conn) AcceptStream() (network.MuxedStream, error) {
str, err := c.session.AcceptStream(context.Background())
if err != nil {
return nil, err
}
return &stream{str}, nil
}
func (c *conn) allowWindowIncrease(size uint64) bool {
return c.scope.ReserveMemory(int(size), network.ReservationPriorityMedium) == nil
}
// Close closes the connection.
// It must be called even if the peer closed the connection in order for
// garbage collection to properly work in this package.
func (c *conn) Close() error {
c.scope.Done()
c.transport.removeConn(c.session)
return c.session.CloseWithError(0, "")
}
func (c *conn) IsClosed() bool { return c.session.Context().Err() != nil }
func (c *conn) Scope() network.ConnScope { return c.scope }
func (c *conn) Transport() tpt.Transport { return c.transport }
func (c *conn) ConnState() network.ConnectionState {
return network.ConnectionState{Transport: "webtransport"}
}

View File

@@ -0,0 +1,155 @@
package libp2pwebtransport
import (
"bytes"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/binary"
"errors"
"fmt"
"io"
"math/big"
"time"
"golang.org/x/crypto/hkdf"
ic "github.com/libp2p/go-libp2p/core/crypto"
"github.com/multiformats/go-multihash"
"github.com/quic-go/quic-go/http3"
)
const deterministicCertInfo = "determinisitic cert"
func getTLSConf(key ic.PrivKey, start, end time.Time) (*tls.Config, error) {
cert, priv, err := generateCert(key, start, end)
if err != nil {
return nil, err
}
return &tls.Config{
Certificates: []tls.Certificate{{
Certificate: [][]byte{cert.Raw},
PrivateKey: priv,
Leaf: cert,
}},
NextProtos: []string{http3.NextProtoH3},
}, nil
}
// generateCert generates certs deterministically based on the `key` and start
// time passed in. Uses `golang.org/x/crypto/hkdf`.
func generateCert(key ic.PrivKey, start, end time.Time) (*x509.Certificate, *ecdsa.PrivateKey, error) {
keyBytes, err := key.Raw()
if err != nil {
return nil, nil, err
}
startTimeSalt := make([]byte, 8)
binary.LittleEndian.PutUint64(startTimeSalt, uint64(start.UnixNano()))
deterministicHKDFReader := newDeterministicReader(keyBytes, startTimeSalt, deterministicCertInfo)
b := make([]byte, 8)
if _, err := deterministicHKDFReader.Read(b); err != nil {
return nil, nil, err
}
serial := int64(binary.BigEndian.Uint64(b))
if serial < 0 {
serial = -serial
}
certTempl := &x509.Certificate{
SerialNumber: big.NewInt(serial),
Subject: pkix.Name{},
NotBefore: start,
NotAfter: end,
IsCA: true,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
BasicConstraintsValid: true,
}
caPrivateKey, err := ecdsa.GenerateKey(elliptic.P256(), deterministicHKDFReader)
if err != nil {
return nil, nil, err
}
caBytes, err := x509.CreateCertificate(deterministicHKDFReader, certTempl, certTempl, caPrivateKey.Public(), caPrivateKey)
if err != nil {
return nil, nil, err
}
ca, err := x509.ParseCertificate(caBytes)
if err != nil {
return nil, nil, err
}
return ca, caPrivateKey, nil
}
func verifyRawCerts(rawCerts [][]byte, certHashes []multihash.DecodedMultihash) error {
if len(rawCerts) < 1 {
return errors.New("no cert")
}
leaf := rawCerts[len(rawCerts)-1]
// The W3C WebTransport specification currently only allows SHA-256 certificates for serverCertificateHashes.
hash := sha256.Sum256(leaf)
var verified bool
for _, h := range certHashes {
if h.Code == multihash.SHA2_256 && bytes.Equal(h.Digest, hash[:]) {
verified = true
break
}
}
if !verified {
digests := make([][]byte, 0, len(certHashes))
for _, h := range certHashes {
digests = append(digests, h.Digest)
}
return fmt.Errorf("cert hash not found: %#x (expected: %#x)", hash, digests)
}
cert, err := x509.ParseCertificate(leaf)
if err != nil {
return err
}
// TODO: is this the best (and complete?) way to identify RSA certificates?
switch cert.SignatureAlgorithm {
case x509.SHA1WithRSA, x509.SHA256WithRSA, x509.SHA384WithRSA, x509.SHA512WithRSA, x509.MD2WithRSA, x509.MD5WithRSA:
return errors.New("cert uses RSA")
}
if l := cert.NotAfter.Sub(cert.NotBefore); l > 14*24*time.Hour {
return fmt.Errorf("cert must not be valid for longer than 14 days (NotBefore: %s, NotAfter: %s, Length: %s)", cert.NotBefore, cert.NotAfter, l)
}
now := time.Now()
if now.Before(cert.NotBefore) || now.After(cert.NotAfter) {
return fmt.Errorf("cert not valid (NotBefore: %s, NotAfter: %s)", cert.NotBefore, cert.NotAfter)
}
return nil
}
// deterministicReader is a hack. It counter-acts the Go library's attempt at
// making ECDSA signatures non-deterministic. Go adds non-determinism by
// randomly dropping a singly byte from the reader stream. This counteracts this
// by detecting when a read is a single byte and using a different reader
// instead.
type deterministicReader struct {
reader io.Reader
singleByteReader io.Reader
}
func newDeterministicReader(seed []byte, salt []byte, info string) io.Reader {
reader := hkdf.New(sha256.New, seed, salt, []byte(info))
singleByteReader := hkdf.New(sha256.New, seed, salt, []byte(info+" single byte"))
return &deterministicReader{
reader: reader,
singleByteReader: singleByteReader,
}
}
func (r *deterministicReader) Read(p []byte) (n int, err error) {
if len(p) == 1 {
return r.singleByteReader.Read(p)
}
return r.reader.Read(p)
}

View File

@@ -0,0 +1,216 @@
package libp2pwebtransport
import (
"context"
"errors"
"fmt"
"net"
"net/http"
"time"
"github.com/libp2p/go-libp2p/core/network"
tpt "github.com/libp2p/go-libp2p/core/transport"
"github.com/libp2p/go-libp2p/p2p/security/noise"
"github.com/libp2p/go-libp2p/p2p/security/noise/pb"
"github.com/libp2p/go-libp2p/p2p/transport/quicreuse"
ma "github.com/multiformats/go-multiaddr"
"github.com/quic-go/webtransport-go"
)
const queueLen = 16
const handshakeTimeout = 10 * time.Second
type listener struct {
transport *transport
isStaticTLSConf bool
reuseListener quicreuse.Listener
server webtransport.Server
ctx context.Context
ctxCancel context.CancelFunc
serverClosed chan struct{} // is closed when server.Serve returns
addr net.Addr
multiaddr ma.Multiaddr
queue chan tpt.CapableConn
}
var _ tpt.Listener = &listener{}
func newListener(reuseListener quicreuse.Listener, t *transport, isStaticTLSConf bool) (tpt.Listener, error) {
localMultiaddr, err := toWebtransportMultiaddr(reuseListener.Addr())
if err != nil {
return nil, err
}
ln := &listener{
reuseListener: reuseListener,
transport: t,
isStaticTLSConf: isStaticTLSConf,
queue: make(chan tpt.CapableConn, queueLen),
serverClosed: make(chan struct{}),
addr: reuseListener.Addr(),
multiaddr: localMultiaddr,
server: webtransport.Server{
CheckOrigin: func(r *http.Request) bool { return true },
},
}
ln.ctx, ln.ctxCancel = context.WithCancel(context.Background())
mux := http.NewServeMux()
mux.HandleFunc(webtransportHTTPEndpoint, ln.httpHandler)
ln.server.H3.Handler = mux
go func() {
defer close(ln.serverClosed)
for {
conn, err := ln.reuseListener.Accept(context.Background())
if err != nil {
log.Debugw("serving failed", "addr", ln.Addr(), "error", err)
return
}
go ln.server.ServeQUICConn(conn)
}
}()
return ln, nil
}
func (l *listener) httpHandler(w http.ResponseWriter, r *http.Request) {
typ, ok := r.URL.Query()["type"]
if !ok || len(typ) != 1 || typ[0] != "noise" {
w.WriteHeader(http.StatusBadRequest)
return
}
remoteMultiaddr, err := stringToWebtransportMultiaddr(r.RemoteAddr)
if err != nil {
// This should never happen.
log.Errorw("converting remote address failed", "remote", r.RemoteAddr, "error", err)
w.WriteHeader(http.StatusBadRequest)
return
}
if l.transport.gater != nil && !l.transport.gater.InterceptAccept(&connMultiaddrs{local: l.multiaddr, remote: remoteMultiaddr}) {
w.WriteHeader(http.StatusForbidden)
return
}
connScope, err := l.transport.rcmgr.OpenConnection(network.DirInbound, false, remoteMultiaddr)
if err != nil {
log.Debugw("resource manager blocked incoming connection", "addr", r.RemoteAddr, "error", err)
w.WriteHeader(http.StatusServiceUnavailable)
return
}
err = l.httpHandlerWithConnScope(w, r, connScope)
if err != nil {
connScope.Done()
}
}
func (l *listener) httpHandlerWithConnScope(w http.ResponseWriter, r *http.Request, connScope network.ConnManagementScope) error {
sess, err := l.server.Upgrade(w, r)
if err != nil {
log.Debugw("upgrade failed", "error", err)
// TODO: think about the status code to use here
w.WriteHeader(500)
return err
}
ctx, cancel := context.WithTimeout(l.ctx, handshakeTimeout)
sconn, err := l.handshake(ctx, sess)
if err != nil {
cancel()
log.Debugw("handshake failed", "error", err)
sess.CloseWithError(1, "")
return err
}
cancel()
if l.transport.gater != nil && !l.transport.gater.InterceptSecured(network.DirInbound, sconn.RemotePeer(), sconn) {
// TODO: can we close with a specific error here?
sess.CloseWithError(errorCodeConnectionGating, "")
return errors.New("gater blocked connection")
}
if err := connScope.SetPeer(sconn.RemotePeer()); err != nil {
log.Debugw("resource manager blocked incoming connection for peer", "peer", sconn.RemotePeer(), "addr", r.RemoteAddr, "error", err)
sess.CloseWithError(1, "")
return err
}
conn := newConn(l.transport, sess, sconn, connScope)
l.transport.addConn(sess, conn)
select {
case l.queue <- conn:
default:
log.Debugw("accept queue full, dropping incoming connection", "peer", sconn.RemotePeer(), "addr", r.RemoteAddr, "error", err)
sess.CloseWithError(1, "")
return errors.New("accept queue full")
}
return nil
}
func (l *listener) Accept() (tpt.CapableConn, error) {
select {
case <-l.ctx.Done():
return nil, tpt.ErrListenerClosed
case c := <-l.queue:
return c, nil
}
}
func (l *listener) handshake(ctx context.Context, sess *webtransport.Session) (*connSecurityMultiaddrs, error) {
local, err := toWebtransportMultiaddr(sess.LocalAddr())
if err != nil {
return nil, fmt.Errorf("error determiniting local addr: %w", err)
}
remote, err := toWebtransportMultiaddr(sess.RemoteAddr())
if err != nil {
return nil, fmt.Errorf("error determiniting remote addr: %w", err)
}
str, err := sess.AcceptStream(ctx)
if err != nil {
return nil, err
}
var earlyData [][]byte
if !l.isStaticTLSConf {
earlyData = l.transport.certManager.SerializedCertHashes()
}
n, err := l.transport.noise.WithSessionOptions(noise.EarlyData(
nil,
newEarlyDataSender(&pb.NoiseExtensions{WebtransportCerthashes: earlyData}),
))
if err != nil {
return nil, fmt.Errorf("failed to initialize Noise session: %w", err)
}
c, err := n.SecureInbound(ctx, &webtransportStream{Stream: str, wsess: sess}, "")
if err != nil {
return nil, err
}
return &connSecurityMultiaddrs{
ConnSecurity: c,
ConnMultiaddrs: &connMultiaddrs{local: local, remote: remote},
}, nil
}
func (l *listener) Addr() net.Addr {
return l.addr
}
func (l *listener) Multiaddr() ma.Multiaddr {
if l.transport.certManager == nil {
return l.multiaddr
}
return l.multiaddr.Encapsulate(l.transport.certManager.AddrComponent())
}
func (l *listener) Close() error {
l.ctxCancel()
l.reuseListener.Close()
err := l.server.Close()
<-l.serverClosed
return err
}

View File

@@ -0,0 +1,113 @@
package libp2pwebtransport
import (
"errors"
"fmt"
"net"
"strconv"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
"github.com/multiformats/go-multibase"
"github.com/multiformats/go-multihash"
)
var webtransportMA = ma.StringCast("/quic-v1/webtransport")
func toWebtransportMultiaddr(na net.Addr) (ma.Multiaddr, error) {
addr, err := manet.FromNetAddr(na)
if err != nil {
return nil, err
}
if _, err := addr.ValueForProtocol(ma.P_UDP); err != nil {
return nil, errors.New("not a UDP address")
}
return addr.Encapsulate(webtransportMA), nil
}
func stringToWebtransportMultiaddr(str string) (ma.Multiaddr, error) {
host, portStr, err := net.SplitHostPort(str)
if err != nil {
return nil, err
}
port, err := strconv.ParseInt(portStr, 10, 32)
if err != nil {
return nil, err
}
ip := net.ParseIP(host)
if ip == nil {
return nil, errors.New("failed to parse IP")
}
return toWebtransportMultiaddr(&net.UDPAddr{IP: ip, Port: int(port)})
}
func extractCertHashes(addr ma.Multiaddr) ([]multihash.DecodedMultihash, error) {
certHashesStr := make([]string, 0, 2)
ma.ForEach(addr, func(c ma.Component) bool {
if c.Protocol().Code == ma.P_CERTHASH {
certHashesStr = append(certHashesStr, c.Value())
}
return true
})
certHashes := make([]multihash.DecodedMultihash, 0, len(certHashesStr))
for _, s := range certHashesStr {
_, ch, err := multibase.Decode(s)
if err != nil {
return nil, fmt.Errorf("failed to multibase-decode certificate hash: %w", err)
}
dh, err := multihash.Decode(ch)
if err != nil {
return nil, fmt.Errorf("failed to multihash-decode certificate hash: %w", err)
}
certHashes = append(certHashes, *dh)
}
return certHashes, nil
}
func addrComponentForCert(hash []byte) (ma.Multiaddr, error) {
mh, err := multihash.Encode(hash, multihash.SHA2_256)
if err != nil {
return nil, err
}
certStr, err := multibase.Encode(multibase.Base58BTC, mh)
if err != nil {
return nil, err
}
return ma.NewComponent(ma.ProtocolWithCode(ma.P_CERTHASH).Name, certStr)
}
// IsWebtransportMultiaddr returns true if the given multiaddr is a well formed
// webtransport multiaddr. Returns the number of certhashes found.
func IsWebtransportMultiaddr(multiaddr ma.Multiaddr) (bool, int) {
const (
init = iota
foundUDP
foundQuicV1
foundWebTransport
)
state := init
certhashCount := 0
ma.ForEach(multiaddr, func(c ma.Component) bool {
switch c.Protocol().Code {
case ma.P_UDP:
if state == init {
state = foundUDP
}
case ma.P_QUIC_V1:
if state == foundUDP {
state = foundQuicV1
}
case ma.P_WEBTRANSPORT:
if state == foundQuicV1 {
state = foundWebTransport
}
case ma.P_CERTHASH:
if state == foundWebTransport {
certhashCount++
}
}
return true
})
return state == foundWebTransport, certhashCount
}

View File

@@ -0,0 +1,36 @@
package libp2pwebtransport
import (
"context"
"net"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/p2p/security/noise"
"github.com/libp2p/go-libp2p/p2p/security/noise/pb"
)
type earlyDataHandler struct {
earlyData *pb.NoiseExtensions
receive func(extensions *pb.NoiseExtensions) error
}
var _ noise.EarlyDataHandler = &earlyDataHandler{}
func newEarlyDataSender(earlyData *pb.NoiseExtensions) noise.EarlyDataHandler {
return &earlyDataHandler{earlyData: earlyData}
}
func newEarlyDataReceiver(receive func(*pb.NoiseExtensions) error) noise.EarlyDataHandler {
return &earlyDataHandler{receive: receive}
}
func (e *earlyDataHandler) Send(context.Context, net.Conn, peer.ID) *pb.NoiseExtensions {
return e.earlyData
}
func (e *earlyDataHandler) Received(_ context.Context, _ net.Conn, ext *pb.NoiseExtensions) error {
if e.receive == nil {
return nil
}
return e.receive(ext)
}

View File

@@ -0,0 +1,71 @@
package libp2pwebtransport
import (
"errors"
"net"
"github.com/libp2p/go-libp2p/core/network"
"github.com/quic-go/webtransport-go"
)
const (
reset webtransport.StreamErrorCode = 0
)
type webtransportStream struct {
webtransport.Stream
wsess *webtransport.Session
}
var _ net.Conn = &webtransportStream{}
func (s *webtransportStream) LocalAddr() net.Addr {
return s.wsess.LocalAddr()
}
func (s *webtransportStream) RemoteAddr() net.Addr {
return s.wsess.RemoteAddr()
}
type stream struct {
webtransport.Stream
}
var _ network.MuxedStream = &stream{}
func (s *stream) Read(b []byte) (n int, err error) {
n, err = s.Stream.Read(b)
if err != nil && errors.Is(err, &webtransport.StreamError{}) {
err = network.ErrReset
}
return n, err
}
func (s *stream) Write(b []byte) (n int, err error) {
n, err = s.Stream.Write(b)
if err != nil && errors.Is(err, &webtransport.StreamError{}) {
err = network.ErrReset
}
return n, err
}
func (s *stream) Reset() error {
s.Stream.CancelRead(reset)
s.Stream.CancelWrite(reset)
return nil
}
func (s *stream) Close() error {
s.Stream.CancelRead(reset)
return s.Stream.Close()
}
func (s *stream) CloseRead() error {
s.Stream.CancelRead(reset)
return nil
}
func (s *stream) CloseWrite() error {
return s.Stream.Close()
}

View File

@@ -0,0 +1,419 @@
package libp2pwebtransport
import (
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io"
"sync"
"sync/atomic"
"time"
"github.com/libp2p/go-libp2p/core/connmgr"
ic "github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/pnet"
tpt "github.com/libp2p/go-libp2p/core/transport"
"github.com/libp2p/go-libp2p/p2p/security/noise"
"github.com/libp2p/go-libp2p/p2p/security/noise/pb"
"github.com/libp2p/go-libp2p/p2p/transport/quicreuse"
"github.com/benbjohnson/clock"
logging "github.com/ipfs/go-log/v2"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
"github.com/multiformats/go-multihash"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/http3"
"github.com/quic-go/webtransport-go"
)
var log = logging.Logger("webtransport")
const webtransportHTTPEndpoint = "/.well-known/libp2p-webtransport"
const errorCodeConnectionGating = 0x47415445 // GATE in ASCII
const certValidity = 14 * 24 * time.Hour
type Option func(*transport) error
func WithClock(cl clock.Clock) Option {
return func(t *transport) error {
t.clock = cl
return nil
}
}
// WithTLSClientConfig sets a custom tls.Config used for dialing.
// This option is most useful for setting a custom tls.Config.RootCAs certificate pool.
// When dialing a multiaddr that contains a /certhash component, this library will set InsecureSkipVerify and
// overwrite the VerifyPeerCertificate callback.
func WithTLSClientConfig(c *tls.Config) Option {
return func(t *transport) error {
t.tlsClientConf = c
return nil
}
}
type transport struct {
privKey ic.PrivKey
pid peer.ID
clock clock.Clock
connManager *quicreuse.ConnManager
rcmgr network.ResourceManager
gater connmgr.ConnectionGater
listenOnce sync.Once
listenOnceErr error
certManager *certManager
hasCertManager atomic.Bool // set to true once the certManager is initialized
staticTLSConf *tls.Config
tlsClientConf *tls.Config
noise *noise.Transport
connMx sync.Mutex
conns map[uint64]*conn // using quic-go's ConnectionTracingKey as map key
}
var _ tpt.Transport = &transport{}
var _ tpt.Resolver = &transport{}
var _ io.Closer = &transport{}
func New(key ic.PrivKey, psk pnet.PSK, connManager *quicreuse.ConnManager, gater connmgr.ConnectionGater, rcmgr network.ResourceManager, opts ...Option) (tpt.Transport, error) {
if len(psk) > 0 {
log.Error("WebTransport doesn't support private networks yet.")
return nil, errors.New("WebTransport doesn't support private networks yet")
}
if rcmgr == nil {
rcmgr = &network.NullResourceManager{}
}
id, err := peer.IDFromPrivateKey(key)
if err != nil {
return nil, err
}
t := &transport{
pid: id,
privKey: key,
rcmgr: rcmgr,
gater: gater,
clock: clock.New(),
connManager: connManager,
conns: map[uint64]*conn{},
}
for _, opt := range opts {
if err := opt(t); err != nil {
return nil, err
}
}
n, err := noise.New(noise.ID, key, nil)
if err != nil {
return nil, err
}
t.noise = n
return t, nil
}
func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tpt.CapableConn, error) {
scope, err := t.rcmgr.OpenConnection(network.DirOutbound, false, raddr)
if err != nil {
log.Debugw("resource manager blocked outgoing connection", "peer", p, "addr", raddr, "error", err)
return nil, err
}
c, err := t.dialWithScope(ctx, raddr, p, scope)
if err != nil {
scope.Done()
return nil, err
}
return c, nil
}
func (t *transport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, p peer.ID, scope network.ConnManagementScope) (tpt.CapableConn, error) {
_, addr, err := manet.DialArgs(raddr)
if err != nil {
return nil, err
}
url := fmt.Sprintf("https://%s%s?type=noise", addr, webtransportHTTPEndpoint)
certHashes, err := extractCertHashes(raddr)
if err != nil {
return nil, err
}
if len(certHashes) == 0 {
return nil, errors.New("can't dial webtransport without certhashes")
}
sni, _ := extractSNI(raddr)
if err := scope.SetPeer(p); err != nil {
log.Debugw("resource manager blocked outgoing connection for peer", "peer", p, "addr", raddr, "error", err)
return nil, err
}
maddr, _ := ma.SplitFunc(raddr, func(c ma.Component) bool { return c.Protocol().Code == ma.P_WEBTRANSPORT })
sess, err := t.dial(ctx, maddr, url, sni, certHashes)
if err != nil {
return nil, err
}
sconn, err := t.upgrade(ctx, sess, p, certHashes)
if err != nil {
sess.CloseWithError(1, "")
return nil, err
}
if t.gater != nil && !t.gater.InterceptSecured(network.DirOutbound, p, sconn) {
sess.CloseWithError(errorCodeConnectionGating, "")
return nil, fmt.Errorf("secured connection gated")
}
conn := newConn(t, sess, sconn, scope)
t.addConn(sess, conn)
return conn, nil
}
func (t *transport) dial(ctx context.Context, addr ma.Multiaddr, url, sni string, certHashes []multihash.DecodedMultihash) (*webtransport.Session, error) {
var tlsConf *tls.Config
if t.tlsClientConf != nil {
tlsConf = t.tlsClientConf.Clone()
} else {
tlsConf = &tls.Config{}
}
tlsConf.NextProtos = append(tlsConf.NextProtos, http3.NextProtoH3)
if sni != "" {
tlsConf.ServerName = sni
}
if len(certHashes) > 0 {
// This is not insecure. We verify the certificate ourselves.
// See https://www.w3.org/TR/webtransport/#certificate-hashes.
tlsConf.InsecureSkipVerify = true
tlsConf.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
return verifyRawCerts(rawCerts, certHashes)
}
}
conn, err := t.connManager.DialQUIC(ctx, addr, tlsConf, t.allowWindowIncrease)
if err != nil {
return nil, err
}
dialer := webtransport.Dialer{
RoundTripper: &http3.RoundTripper{
Dial: func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
return conn.(quic.EarlyConnection), nil
},
},
}
rsp, sess, err := dialer.Dial(ctx, url, nil)
if err != nil {
return nil, err
}
if rsp.StatusCode < 200 || rsp.StatusCode > 299 {
return nil, fmt.Errorf("invalid response status code: %d", rsp.StatusCode)
}
return sess, err
}
func (t *transport) upgrade(ctx context.Context, sess *webtransport.Session, p peer.ID, certHashes []multihash.DecodedMultihash) (*connSecurityMultiaddrs, error) {
local, err := toWebtransportMultiaddr(sess.LocalAddr())
if err != nil {
return nil, fmt.Errorf("error determining local addr: %w", err)
}
remote, err := toWebtransportMultiaddr(sess.RemoteAddr())
if err != nil {
return nil, fmt.Errorf("error determining remote addr: %w", err)
}
str, err := sess.OpenStreamSync(ctx)
if err != nil {
return nil, err
}
defer str.Close()
// Now run a Noise handshake (using early data) and get all the certificate hashes from the server.
// We will verify that the certhashes we used to dial is a subset of the certhashes we received from the server.
var verified bool
n, err := t.noise.WithSessionOptions(noise.EarlyData(newEarlyDataReceiver(func(b *pb.NoiseExtensions) error {
decodedCertHashes, err := decodeCertHashesFromProtobuf(b.WebtransportCerthashes)
if err != nil {
return err
}
for _, sent := range certHashes {
var found bool
for _, rcvd := range decodedCertHashes {
if sent.Code == rcvd.Code && bytes.Equal(sent.Digest, rcvd.Digest) {
found = true
break
}
}
if !found {
return fmt.Errorf("missing cert hash: %v", sent)
}
}
verified = true
return nil
}), nil))
if err != nil {
return nil, fmt.Errorf("failed to create Noise transport: %w", err)
}
c, err := n.SecureOutbound(ctx, &webtransportStream{Stream: str, wsess: sess}, p)
if err != nil {
return nil, err
}
defer c.Close()
// The Noise handshake _should_ guarantee that our verification callback is called.
// Double-check just in case.
if !verified {
return nil, errors.New("didn't verify")
}
return &connSecurityMultiaddrs{
ConnSecurity: c,
ConnMultiaddrs: &connMultiaddrs{local: local, remote: remote},
}, nil
}
func decodeCertHashesFromProtobuf(b [][]byte) ([]multihash.DecodedMultihash, error) {
hashes := make([]multihash.DecodedMultihash, 0, len(b))
for _, h := range b {
dh, err := multihash.Decode(h)
if err != nil {
return nil, fmt.Errorf("failed to decode hash: %w", err)
}
hashes = append(hashes, *dh)
}
return hashes, nil
}
func (t *transport) CanDial(addr ma.Multiaddr) bool {
ok, _ := IsWebtransportMultiaddr(addr)
return ok
}
func (t *transport) Listen(laddr ma.Multiaddr) (tpt.Listener, error) {
isWebTransport, certhashCount := IsWebtransportMultiaddr(laddr)
if !isWebTransport {
return nil, fmt.Errorf("cannot listen on non-WebTransport addr: %s", laddr)
}
if certhashCount > 0 {
return nil, fmt.Errorf("cannot listen on a specific certhash non-WebTransport addr: %s", laddr)
}
if t.staticTLSConf == nil {
t.listenOnce.Do(func() {
t.certManager, t.listenOnceErr = newCertManager(t.privKey, t.clock)
t.hasCertManager.Store(true)
})
if t.listenOnceErr != nil {
return nil, t.listenOnceErr
}
} else {
return nil, errors.New("static TLS config not supported on WebTransport")
}
tlsConf := t.staticTLSConf.Clone()
if tlsConf == nil {
tlsConf = &tls.Config{GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
return t.certManager.GetConfig(), nil
}}
}
tlsConf.NextProtos = append(tlsConf.NextProtos, http3.NextProtoH3)
ln, err := t.connManager.ListenQUIC(laddr, tlsConf, t.allowWindowIncrease)
if err != nil {
return nil, err
}
return newListener(ln, t, t.staticTLSConf != nil)
}
func (t *transport) Protocols() []int {
return []int{ma.P_WEBTRANSPORT}
}
func (t *transport) Proxy() bool {
return false
}
func (t *transport) Close() error {
t.listenOnce.Do(func() {})
if t.certManager != nil {
return t.certManager.Close()
}
return nil
}
func (t *transport) allowWindowIncrease(conn quic.Connection, size uint64) bool {
t.connMx.Lock()
defer t.connMx.Unlock()
c, ok := t.conns[conn.Context().Value(quic.ConnectionTracingKey).(uint64)]
if !ok {
return false
}
return c.allowWindowIncrease(size)
}
func (t *transport) addConn(sess *webtransport.Session, c *conn) {
t.connMx.Lock()
t.conns[sess.Context().Value(quic.ConnectionTracingKey).(uint64)] = c
t.connMx.Unlock()
}
func (t *transport) removeConn(sess *webtransport.Session) {
t.connMx.Lock()
delete(t.conns, sess.Context().Value(quic.ConnectionTracingKey).(uint64))
t.connMx.Unlock()
}
// extractSNI returns what the SNI should be for the given maddr. If there is an
// SNI component in the multiaddr, then it will be returned and
// foundSniComponent will be true. If there's no SNI component, but there is a
// DNS-like component, then that will be returned for the sni and
// foundSniComponent will be false (since we didn't find an actual sni component).
func extractSNI(maddr ma.Multiaddr) (sni string, foundSniComponent bool) {
ma.ForEach(maddr, func(c ma.Component) bool {
switch c.Protocol().Code {
case ma.P_SNI:
sni = c.Value()
foundSniComponent = true
return false
case ma.P_DNS, ma.P_DNS4, ma.P_DNS6, ma.P_DNSADDR:
sni = c.Value()
// Keep going in case we find an `sni` component
return true
}
return true
})
return sni, foundSniComponent
}
// Resolve implements transport.Resolver
func (t *transport) Resolve(_ context.Context, maddr ma.Multiaddr) ([]ma.Multiaddr, error) {
sni, foundSniComponent := extractSNI(maddr)
if foundSniComponent || sni == "" {
// The multiaddr already had an sni field, we can keep using it. Or we don't have any sni like thing
return []ma.Multiaddr{maddr}, nil
}
beforeQuicMA, afterIncludingQuicMA := ma.SplitFunc(maddr, func(c ma.Component) bool {
return c.Protocol().Code == ma.P_QUIC_V1
})
quicComponent, afterQuicMA := ma.SplitFirst(afterIncludingQuicMA)
sniComponent, err := ma.NewComponent(ma.ProtocolWithCode(ma.P_SNI).Name, sni)
if err != nil {
return nil, err
}
return []ma.Multiaddr{beforeQuicMA.Encapsulate(quicComponent).Encapsulate(sniComponent).Encapsulate(afterQuicMA)}, nil
}
// AddCertHashes adds the current certificate hashes to a multiaddress.
// If called before Listen, it's a no-op.
func (t *transport) AddCertHashes(m ma.Multiaddr) (ma.Multiaddr, bool) {
if !t.hasCertManager.Load() {
return m, false
}
return m.Encapsulate(t.certManager.AddrComponent()), true
}