Integrate BACKBEAT SDK and resolve KACHING license validation
Major integrations and fixes: - Added BACKBEAT SDK integration for P2P operation timing - Implemented beat-aware status tracking for distributed operations - Added Docker secrets support for secure license management - Resolved KACHING license validation via HTTPS/TLS - Updated docker-compose configuration for clean stack deployment - Disabled rollback policies to prevent deployment failures - Added license credential storage (CHORUS-DEV-MULTI-001) Technical improvements: - BACKBEAT P2P operation tracking with phase management - Enhanced configuration system with file-based secrets - Improved error handling for license validation - Clean separation of KACHING and CHORUS deployment stacks 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
17
vendor/github.com/quic-go/quic-go/.gitignore
generated
vendored
Normal file
17
vendor/github.com/quic-go/quic-go/.gitignore
generated
vendored
Normal file
@@ -0,0 +1,17 @@
|
||||
debug
|
||||
debug.test
|
||||
main
|
||||
mockgen_tmp.go
|
||||
*.qtr
|
||||
*.qlog
|
||||
*.txt
|
||||
race.[0-9]*
|
||||
|
||||
fuzzing/*/*.zip
|
||||
fuzzing/*/coverprofile
|
||||
fuzzing/*/crashers
|
||||
fuzzing/*/sonarprofile
|
||||
fuzzing/*/suppressions
|
||||
fuzzing/*/corpus/
|
||||
|
||||
gomock_reflect_*/
|
||||
44
vendor/github.com/quic-go/quic-go/.golangci.yml
generated
vendored
Normal file
44
vendor/github.com/quic-go/quic-go/.golangci.yml
generated
vendored
Normal file
@@ -0,0 +1,44 @@
|
||||
run:
|
||||
skip-files:
|
||||
- internal/handshake/cipher_suite.go
|
||||
linters-settings:
|
||||
depguard:
|
||||
type: blacklist
|
||||
packages:
|
||||
- github.com/marten-seemann/qtls
|
||||
- github.com/quic-go/qtls-go1-19
|
||||
- github.com/quic-go/qtls-go1-20
|
||||
packages-with-error-message:
|
||||
- github.com/marten-seemann/qtls: "importing qtls only allowed in internal/qtls"
|
||||
- github.com/quic-go/qtls-go1-19: "importing qtls only allowed in internal/qtls"
|
||||
- github.com/quic-go/qtls-go1-20: "importing qtls only allowed in internal/qtls"
|
||||
misspell:
|
||||
ignore-words:
|
||||
- ect
|
||||
|
||||
linters:
|
||||
disable-all: true
|
||||
enable:
|
||||
- asciicheck
|
||||
- depguard
|
||||
- exhaustive
|
||||
- exportloopref
|
||||
- goimports
|
||||
- gofmt # redundant, since gofmt *should* be a no-op after gofumpt
|
||||
- gofumpt
|
||||
- gosimple
|
||||
- ineffassign
|
||||
- misspell
|
||||
- prealloc
|
||||
- staticcheck
|
||||
- stylecheck
|
||||
- unconvert
|
||||
- unparam
|
||||
- unused
|
||||
- vet
|
||||
|
||||
issues:
|
||||
exclude-rules:
|
||||
- path: internal/qtls
|
||||
linters:
|
||||
- depguard
|
||||
109
vendor/github.com/quic-go/quic-go/Changelog.md
generated
vendored
Normal file
109
vendor/github.com/quic-go/quic-go/Changelog.md
generated
vendored
Normal file
@@ -0,0 +1,109 @@
|
||||
# Changelog
|
||||
|
||||
## v0.22.0 (2021-07-25)
|
||||
|
||||
- Use `ReadBatch` to read multiple UDP packets from the socket with a single syscall
|
||||
- Add a config option (`Config.DisableVersionNegotiationPackets`) to disable sending of Version Negotiation packets
|
||||
- Drop support for QUIC draft versions 32 and 34
|
||||
- Remove the `RetireBugBackwardsCompatibilityMode`, which was intended to mitigate a bug when retiring connection IDs in quic-go in v0.17.2 and ealier
|
||||
|
||||
## v0.21.2 (2021-07-15)
|
||||
|
||||
- Update qtls (for Go 1.15, 1.16 and 1.17rc1) to include the fix for the crypto/tls panic (see https://groups.google.com/g/golang-dev/c/5LJ2V7rd-Ag/m/YGLHVBZ6AAAJ for details)
|
||||
|
||||
## v0.21.0 (2021-06-01)
|
||||
|
||||
- quic-go now supports RFC 9000!
|
||||
|
||||
## v0.20.0 (2021-03-19)
|
||||
|
||||
- Remove the `quic.Config.HandshakeTimeout`. Introduce a `quic.Config.HandshakeIdleTimeout`.
|
||||
|
||||
## v0.17.1 (2020-06-20)
|
||||
|
||||
- Supports QUIC WG draft-29.
|
||||
- Improve bundling of ACK frames (#2543).
|
||||
|
||||
## v0.16.0 (2020-05-31)
|
||||
|
||||
- Supports QUIC WG draft-28.
|
||||
|
||||
## v0.15.0 (2020-03-01)
|
||||
|
||||
- Supports QUIC WG draft-27.
|
||||
- Add support for 0-RTT.
|
||||
- Remove `Session.Close()`. Applications need to pass an application error code to the transport using `Session.CloseWithError()`.
|
||||
- Make the TLS Cipher Suites configurable (via `tls.Config.CipherSuites`).
|
||||
|
||||
## v0.14.0 (2019-12-04)
|
||||
|
||||
- Supports QUIC WG draft-24.
|
||||
|
||||
## v0.13.0 (2019-11-05)
|
||||
|
||||
- Supports QUIC WG draft-23.
|
||||
- Add an `EarlyListener` that allows sending of 0.5-RTT data.
|
||||
- Add a `TokenStore` to store address validation tokens.
|
||||
- Issue and use new connection IDs during a connection.
|
||||
|
||||
## v0.12.0 (2019-08-05)
|
||||
|
||||
- Implement HTTP/3.
|
||||
- Rename `quic.Cookie` to `quic.Token` and `quic.Config.AcceptCookie` to `quic.Config.AcceptToken`.
|
||||
- Distinguish between Retry tokens and tokens sent in NEW_TOKEN frames.
|
||||
- Enforce application protocol negotiation (via `tls.Config.NextProtos`).
|
||||
- Use a varint for error codes.
|
||||
- Add support for [quic-trace](https://github.com/google/quic-trace).
|
||||
- Add a context to `Listener.Accept`, `Session.Accept{Uni}Stream` and `Session.Open{Uni}StreamSync`.
|
||||
- Implement TLS key updates.
|
||||
|
||||
## v0.11.0 (2019-04-05)
|
||||
|
||||
- Drop support for gQUIC. For qQUIC support, please switch to the *gquic* branch.
|
||||
- Implement QUIC WG draft-19.
|
||||
- Use [qtls](https://github.com/marten-seemann/qtls) for TLS 1.3.
|
||||
- Return a `tls.ConnectionState` from `quic.Session.ConnectionState()`.
|
||||
- Remove the error return values from `quic.Stream.CancelRead()` and `quic.Stream.CancelWrite()`
|
||||
|
||||
## v0.10.0 (2018-08-28)
|
||||
|
||||
- Add support for QUIC 44, drop support for QUIC 42.
|
||||
|
||||
## v0.9.0 (2018-08-15)
|
||||
|
||||
- Add a `quic.Config` option for the length of the connection ID (for IETF QUIC).
|
||||
- Split Session.Close into one method for regular closing and one for closing with an error.
|
||||
|
||||
## v0.8.0 (2018-06-26)
|
||||
|
||||
- Add support for unidirectional streams (for IETF QUIC).
|
||||
- Add a `quic.Config` option for the maximum number of incoming streams.
|
||||
- Add support for QUIC 42 and 43.
|
||||
- Add dial functions that use a context.
|
||||
- Multiplex clients on a net.PacketConn, when using Dial(conn).
|
||||
|
||||
## v0.7.0 (2018-02-03)
|
||||
|
||||
- The lower boundary for packets included in ACKs is now derived, and the value sent in STOP_WAITING frames is ignored.
|
||||
- Remove `DialNonFWSecure` and `DialAddrNonFWSecure`.
|
||||
- Expose the `ConnectionState` in the `Session` (experimental API).
|
||||
- Implement packet pacing.
|
||||
|
||||
## v0.6.0 (2017-12-12)
|
||||
|
||||
- Add support for QUIC 39, drop support for QUIC 35 - 37
|
||||
- Added `quic.Config` options for maximal flow control windows
|
||||
- Add a `quic.Config` option for QUIC versions
|
||||
- Add a `quic.Config` option to request omission of the connection ID from a server
|
||||
- Add a `quic.Config` option to configure the source address validation
|
||||
- Add a `quic.Config` option to configure the handshake timeout
|
||||
- Add a `quic.Config` option to configure the idle timeout
|
||||
- Add a `quic.Config` option to configure keep-alive
|
||||
- Rename the STK to Cookie
|
||||
- Implement `net.Conn`-style deadlines for streams
|
||||
- Remove the `tls.Config` from the `quic.Config`. The `tls.Config` must now be passed to the `Dial` and `Listen` functions as a separate parameter. See the [Godoc](https://godoc.org/github.com/quic-go/quic-go) for details.
|
||||
- Changed the log level environment variable to only accept strings ("DEBUG", "INFO", "ERROR"), see [the wiki](https://github.com/quic-go/quic-go/wiki/Logging) for more details.
|
||||
- Rename the `h2quic.QuicRoundTripper` to `h2quic.RoundTripper`
|
||||
- Changed `h2quic.Server.Serve()` to accept a `net.PacketConn`
|
||||
- Drop support for Go 1.7 and 1.8.
|
||||
- Various bugfixes
|
||||
21
vendor/github.com/quic-go/quic-go/LICENSE
generated
vendored
Normal file
21
vendor/github.com/quic-go/quic-go/LICENSE
generated
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2016 the quic-go authors & Google, Inc.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
230
vendor/github.com/quic-go/quic-go/README.md
generated
vendored
Normal file
230
vendor/github.com/quic-go/quic-go/README.md
generated
vendored
Normal file
@@ -0,0 +1,230 @@
|
||||
# A QUIC implementation in pure Go
|
||||
|
||||
<img src="docs/quic.png" width=303 height=124>
|
||||
|
||||
[](https://pkg.go.dev/github.com/quic-go/quic-go)
|
||||
[](https://codecov.io/gh/quic-go/quic-go/)
|
||||
[](https://bugs.chromium.org/p/oss-fuzz/issues/list?sort=-opened&can=1&q=proj:quic-go)
|
||||
|
||||
quic-go is an implementation of the QUIC protocol ([RFC 9000](https://datatracker.ietf.org/doc/html/rfc9000), [RFC 9001](https://datatracker.ietf.org/doc/html/rfc9001), [RFC 9002](https://datatracker.ietf.org/doc/html/rfc9002)) in Go. It has support for HTTP/3 ([RFC 9114](https://datatracker.ietf.org/doc/html/rfc9114)), including QPACK ([RFC 9204](https://datatracker.ietf.org/doc/html/rfc9204)).
|
||||
|
||||
In addition to these base RFCs, it also implements the following RFCs:
|
||||
* Unreliable Datagram Extension ([RFC 9221](https://datatracker.ietf.org/doc/html/rfc9221))
|
||||
* Datagram Packetization Layer Path MTU Discovery (DPLPMTUD, [RFC 8899](https://datatracker.ietf.org/doc/html/rfc8899))
|
||||
* QUIC Version 2 ([RFC 9369](https://datatracker.ietf.org/doc/html/rfc9369))
|
||||
|
||||
## Using QUIC
|
||||
|
||||
### Running a Server
|
||||
|
||||
The central entry point is the `quic.Transport`. A transport manages QUIC connections running on a single UDP socket. Since QUIC uses Connection IDs, it can demultiplex a listener (accepting incoming connections) and an arbitrary number of outgoing QUIC connections on the same UDP socket.
|
||||
|
||||
```go
|
||||
udpConn, err := net.ListenUDP("udp4", &net.UDPAddr{Port: 1234})
|
||||
// ... error handling
|
||||
tr := quic.Transport{
|
||||
Conn: udpConn,
|
||||
}
|
||||
ln, err := tr.Listen(tlsConf, quicConf)
|
||||
// ... error handling
|
||||
go func() {
|
||||
for {
|
||||
conn, err := ln.Accept()
|
||||
// ... error handling
|
||||
// handle the connection, usually in a new Go routine
|
||||
}
|
||||
}()
|
||||
```
|
||||
|
||||
The listener `ln` can now be used to accept incoming QUIC connections by (repeatedly) calling the `Accept` method (see below for more information on the `quic.Connection`).
|
||||
|
||||
As a shortcut, `quic.Listen` and `quic.ListenAddr` can be used without explicitly initializing a `quic.Transport`:
|
||||
|
||||
```
|
||||
ln, err := quic.Listen(udpConn, tlsConf, quicConf)
|
||||
```
|
||||
|
||||
When using the shortcut, it's not possible to reuse the same UDP socket for outgoing connections.
|
||||
|
||||
### Running a Client
|
||||
|
||||
As mentioned above, multiple outgoing connections can share a single UDP socket, since QUIC uses Connection IDs to demultiplex connections.
|
||||
|
||||
```go
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) // 3s handshake timeout
|
||||
defer cancel()
|
||||
conn, err := tr.Dial(ctx, <server address>, <tls.Config>, <quic.Config>)
|
||||
// ... error handling
|
||||
```
|
||||
|
||||
As a shortcut, `quic.Dial` and `quic.DialAddr` can be used without explictly initializing a `quic.Transport`:
|
||||
|
||||
```go
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) // 3s handshake timeout
|
||||
defer cancel()
|
||||
conn, err := quic.Dial(ctx, conn, <server address>, <tls.Config>, <quic.Config>)
|
||||
```
|
||||
|
||||
Just as we saw before when used a similar shortcut to run a server, it's also not possible to reuse the same UDP socket for other outgoing connections, or to listen for incoming connections.
|
||||
|
||||
### Using a QUIC Connection
|
||||
|
||||
#### Accepting Streams
|
||||
|
||||
QUIC is a stream-multiplexed transport. A `quic.Connection` fundamentally differs from the `net.Conn` and the `net.PacketConn` interface defined in the standard library. Data is sent and received on (unidirectional and bidirectional) streams (and, if supported, in [datagrams](#quic-datagrams)), not on the connection itself. The stream state machine is described in detail in [Section 3 of RFC 9000](https://datatracker.ietf.org/doc/html/rfc9000#section-3).
|
||||
|
||||
Note: A unidirectional stream is a stream that the initiator can only write to (`quic.SendStream`), and the receiver can only read from (`quic.ReceiveStream`). A bidirectional stream (`quic.Stream`) allows reading from and writing to for both sides.
|
||||
|
||||
On the receiver side, streams are accepted using the `AcceptStream` (for bidirectional) and `AcceptUniStream` functions. For most user cases, it makes sense to call these functions in a loop:
|
||||
|
||||
```go
|
||||
for {
|
||||
str, err := conn.AcceptStream(context.Background()) // for bidirectional streams
|
||||
// ... error handling
|
||||
// handle the stream, usually in a new Go routine
|
||||
}
|
||||
```
|
||||
|
||||
These functions return an error when the underlying QUIC connection is closed.
|
||||
|
||||
#### Opening Streams
|
||||
|
||||
There are two slightly different ways to open streams, one synchronous and one (potentially) asynchronous. This API is necessary since the receiver grants us a certain number of streams that we're allowed to open. It may grant us additional streams later on (typically when existing streams are closed), but it means that at the time we want to open a new stream, we might not be able to do so.
|
||||
|
||||
Using the synchronous method `OpenStreamSync` for bidirectional streams, and `OpenUniStreamSync` for unidirectional streams, an application can block until the peer allows opening additional streams. In case that we're allowed to open a new stream, these methods return right away:
|
||||
|
||||
```go
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
str, err := conn.OpenStreamSync(ctx) // wait up to 5s to open a new bidirectional stream
|
||||
```
|
||||
|
||||
The asynchronous version never blocks. If it's currently not possible to open a new stream, it returns a `net.Error` timeout error:
|
||||
|
||||
```go
|
||||
str, err := conn.OpenStream()
|
||||
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
|
||||
// It's currently not possible to open another stream,
|
||||
// but it might be possible later, once the peer allowed us to do so.
|
||||
}
|
||||
```
|
||||
|
||||
These functions return an error when the underlying QUIC connection is closed.
|
||||
|
||||
#### Using Streams
|
||||
|
||||
Using QUIC streams is pretty straightforward. The `quic.ReceiveStream` implements the `io.Reader` interface, and the `quic.SendStream` implements the `io.Writer` interface. A bidirectional stream (`quic.Stream`) implements both these interfaces. Conceptually, a bidirectional stream can be thought of as the composition of two unidirectional streams in opposite directions.
|
||||
|
||||
Calling `Close` on a `quic.SendStream` or a `quic.Stream` closes the send side of the stream. On the receiver side, this will be surfaced as an `io.EOF` returned from the `io.Reader` once all data has been consumed. Note that for bidirectional streams, `Close` _only_ closes the send side of the stream. It is still possible to read from the stream until the peer closes or resets the stream.
|
||||
|
||||
In case the application wishes to abort sending on a `quic.SendStream` or a `quic.Stream` , it can reset the send side by calling `CancelWrite` with an application-defined error code (an unsigned 62-bit number). On the receiver side, this surfaced as a `quic.StreamError` containing that error code on the `io.Reader`. Note that for bidirectional streams, `CancelWrite` _only_ resets the send side of the stream. It is still possible to read from the stream until the peer closes or resets the stream.
|
||||
|
||||
Conversely, in case the application wishes to abort receiving from a `quic.ReceiveStream` or a `quic.Stream`, it can ask the sender to abort data transmission by calling `CancelRead` with an application-defined error code (an unsigned 62-bit number). On the receiver side, this surfaced as a `quic.StreamError` containing that error code on the `io.Writer`. Note that for bidirectional streams, `CancelWrite` _only_ resets the receive side of the stream. It is still possible to write to the stream.
|
||||
|
||||
A bidirectional stream is only closed once both the read and the write side of the stream have been either closed and reset. Only then the peer is granted a new stream according to the maximum number of concurrent streams configured via `quic.Config.MaxIncomingStreams`.
|
||||
|
||||
### Configuring QUIC
|
||||
|
||||
The `quic.Config` struct passed to both the listen and dial calls (see above) contains a wide range of configuration options for QUIC connections, incl. the ability to fine-tune flow control limits, the number of streams that the peer is allowed to open concurrently, keep-alives, idle timeouts, and many more. Please refer to the documentation for the `quic.Config` for details.
|
||||
|
||||
The `quic.Transport` contains a few configuration options that don't apply to any single QUIC connection, but to all connections handled by that transport. It is highly recommend to set the `StatelessResetToken`, which allows endpoints to quickly recover from crashes / reboots of our node (see [Section 10.3 of RFC 9000](https://datatracker.ietf.org/doc/html/rfc9000#section-10.3)).
|
||||
|
||||
### Closing a Connection
|
||||
|
||||
#### When the remote Peer closes the Connection
|
||||
|
||||
In case the peer closes the QUIC connection, all calls to open streams, accept streams, as well as all methods on streams immediately return an error. Additionally, it is set as cancellation cause of the connection context. Users can use errors assertions to find out what exactly went wrong:
|
||||
|
||||
* `quic.VersionNegotiationError`: Happens during the handshake, if there is no overlap between our and the remote's supported QUIC versions.
|
||||
* `quic.HandshakeTimeoutError`: Happens if the QUIC handshake doesn't complete within the time specified in `quic.Config.HandshakeTimeout`.
|
||||
* `quic.IdleTimeoutError`: Happens after completion of the handshake if the connection is idle for longer than the minimum of both peers idle timeouts (as configured by `quic.Config.IdleTimeout`). The connection is considered idle when no stream data (and datagrams, if applicable) are exchanged for that period. The QUIC connection can be instructed to regularly send a packet to prevent a connection from going idle by setting `quic.Config.KeepAlive`. However, this is no guarantee that the peer doesn't suddenly go away (e.g. by abruptly shutting down the node or by crashing), or by a NAT binding expiring, in which case this error might still occur.
|
||||
* `quic.StatelessResetError`: Happens when the remote peer lost the state required to decrypt the packet. This requires the `quic.Transport.StatelessResetToken` to be configured by the peer.
|
||||
* `quic.TransportError`: Happens if when the QUIC protocol is violated. Unless the error code is `APPLICATION_ERROR`, this will not happen unless one of the QUIC stacks involved is misbehaving. Please open an issue if you encounter this error.
|
||||
* `quic.ApplicationError`: Happens when the remote decides to close the connection, see below.
|
||||
|
||||
#### Initiated by the Application
|
||||
|
||||
A `quic.Connection` can be closed using `CloseWithError`:
|
||||
|
||||
```go
|
||||
conn.CloseWithError(0x42, "error 0x42 occurred")
|
||||
```
|
||||
|
||||
Applications can transmit both an error code (an unsigned 62-bit number) as well as a UTF-8 encoded human-readable reason. The error code allows the receiver to learn why the connection was closed, and the reason can be useful for debugging purposes.
|
||||
|
||||
On the receiver side, this is surfaced as a `quic.ApplicationError`.
|
||||
|
||||
### QUIC Datagrams
|
||||
|
||||
Unreliable datagrams are a QUIC extension ([RFC 9221](https://datatracker.ietf.org/doc/html/rfc9221)) that is negotiated during the handshake. Support can be enabled by setting the `quic.Config.EnableDatagram` flag. Note that this doesn't guarantee that the peer also supports datagrams. Whether or not the feature negotiation succeeded can be learned from the `quic.ConnectionState.SupportsDatagrams` obtained from `quic.Connection.ConnectionState()`.
|
||||
|
||||
QUIC DATAGRAMs are a new QUIC frame type sent in QUIC 1-RTT packets (i.e. after completion of the handshake). Therefore, they're end-to-end encrypted and congestion-controlled. However, if a DATAGRAM frame is deemed lost by QUIC's loss detection mechanism, they are not automatically retransmitted.
|
||||
|
||||
Datagrams are sent using the `SendMessage` method on the `quic.Connection`:
|
||||
|
||||
```go
|
||||
conn.SendMessage([]byte("foobar"))
|
||||
```
|
||||
|
||||
And received using `ReceiveMessage`:
|
||||
|
||||
```go
|
||||
msg, err := conn.ReceiveMessage()
|
||||
```
|
||||
|
||||
Note that this code path is currently not optimized. It works for datagrams that are sent occasionally, but it doesn't achieve the same throughput as writing data on a stream. Please get in touch on issue #3766 if your use case relies on high datagram throughput, or if you'd like to help fix this issue. There are also some restrictions regarding the maximum message size (see #3599).
|
||||
|
||||
|
||||
|
||||
## Using HTTP/3
|
||||
|
||||
### As a server
|
||||
|
||||
See the [example server](example/main.go). Starting a QUIC server is very similar to the standard library http package in Go:
|
||||
|
||||
```go
|
||||
http.Handle("/", http.FileServer(http.Dir(wwwDir)))
|
||||
http3.ListenAndServeQUIC("localhost:4242", "/path/to/cert/chain.pem", "/path/to/privkey.pem", nil)
|
||||
```
|
||||
|
||||
### As a client
|
||||
|
||||
See the [example client](example/client/main.go). Use a `http3.RoundTripper` as a `Transport` in a `http.Client`.
|
||||
|
||||
```go
|
||||
http.Client{
|
||||
Transport: &http3.RoundTripper{},
|
||||
}
|
||||
```
|
||||
|
||||
## Projects using quic-go
|
||||
|
||||
| Project | Description | Stars |
|
||||
| --------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------- |
|
||||
| [AdGuardHome](https://github.com/AdguardTeam/AdGuardHome) | Free and open source, powerful network-wide ads & trackers blocking DNS server. |  |
|
||||
| [algernon](https://github.com/xyproto/algernon) | Small self-contained pure-Go web server with Lua, Markdown, HTTP/2, QUIC, Redis and PostgreSQL support |  |
|
||||
| [caddy](https://github.com/caddyserver/caddy/) | Fast, multi-platform web server with automatic HTTPS |  |
|
||||
| [cloudflared](https://github.com/cloudflare/cloudflared) | A tunneling daemon that proxies traffic from the Cloudflare network to your origins |  |
|
||||
| [go-libp2p](https://github.com/libp2p/go-libp2p) | libp2p implementation in Go, powering [Kubo](https://github.com/ipfs/kubo) (IPFS) and [Lotus](https://github.com/filecoin-project/lotus) (Filecoin), among others |  |
|
||||
| [Hysteria](https://github.com/apernet/hysteria) | A powerful, lightning fast and censorship resistant proxy |  |
|
||||
| [Mercure](https://github.com/dunglas/mercure) | An open, easy, fast, reliable and battery-efficient solution for real-time communications |  |
|
||||
| [OONI Probe](https://github.com/ooni/probe-cli) | Next generation OONI Probe. Library and CLI tool. |  |
|
||||
| [syncthing](https://github.com/syncthing/syncthing/) | Open Source Continuous File Synchronization |  |
|
||||
| [traefik](https://github.com/traefik/traefik) | The Cloud Native Application Proxy |  |
|
||||
| [v2ray-core](https://github.com/v2fly/v2ray-core) | A platform for building proxies to bypass network restrictions |  |
|
||||
| [YoMo](https://github.com/yomorun/yomo) | Streaming Serverless Framework for Geo-distributed System |  |
|
||||
|
||||
If you'd like to see your project added to this list, please send us a PR.
|
||||
|
||||
## Release Policy
|
||||
|
||||
quic-go always aims to support the latest two Go releases.
|
||||
|
||||
### Dependency on forked crypto/tls
|
||||
|
||||
Since the standard library didn't provide any QUIC APIs before the Go 1.21 release, we had to fork crypto/tls to add the required APIs ourselves: [qtls for Go 1.20](https://github.com/quic-go/qtls-go1-20).
|
||||
This had led to a lot of pain in the Go ecosystem, and we're happy that we can rely on Go 1.21 going forward.
|
||||
|
||||
## Contributing
|
||||
|
||||
We are always happy to welcome new contributors! We have a number of self-contained issues that are suitable for first-time contributors, they are tagged with [help wanted](https://github.com/quic-go/quic-go/issues?q=is%3Aissue+is%3Aopen+label%3A%22help+wanted%22). If you have any questions, please feel free to reach out by opening an issue or leaving a comment.
|
||||
19
vendor/github.com/quic-go/quic-go/SECURITY.md
generated
vendored
Normal file
19
vendor/github.com/quic-go/quic-go/SECURITY.md
generated
vendored
Normal file
@@ -0,0 +1,19 @@
|
||||
# Security Policy
|
||||
|
||||
quic-go still in development. This means that there may be problems in our protocols,
|
||||
or there may be mistakes in our implementations.
|
||||
We take security vulnerabilities very seriously. If you discover a security issue,
|
||||
please bring it to our attention right away!
|
||||
|
||||
## Reporting a Vulnerability
|
||||
|
||||
If you find a vulnerability that may affect live deployments -- for example, by exposing
|
||||
a remote execution exploit -- please [**report privately**](https://github.com/quic-go/quic-go/security/advisories/new).
|
||||
Please **DO NOT file a public issue**.
|
||||
|
||||
If the issue is an implementation weakness that cannot be immediately exploited or
|
||||
something not yet deployed, just discuss it openly.
|
||||
|
||||
## Reporting a non security bug
|
||||
|
||||
For non-security bugs, please simply file a GitHub [issue](https://github.com/quic-go/quic-go/issues/new).
|
||||
92
vendor/github.com/quic-go/quic-go/buffer_pool.go
generated
vendored
Normal file
92
vendor/github.com/quic-go/quic-go/buffer_pool.go
generated
vendored
Normal file
@@ -0,0 +1,92 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
type packetBuffer struct {
|
||||
Data []byte
|
||||
|
||||
// refCount counts how many packets Data is used in.
|
||||
// It doesn't support concurrent use.
|
||||
// It is > 1 when used for coalesced packet.
|
||||
refCount int
|
||||
}
|
||||
|
||||
// Split increases the refCount.
|
||||
// It must be called when a packet buffer is used for more than one packet,
|
||||
// e.g. when splitting coalesced packets.
|
||||
func (b *packetBuffer) Split() {
|
||||
b.refCount++
|
||||
}
|
||||
|
||||
// Decrement decrements the reference counter.
|
||||
// It doesn't put the buffer back into the pool.
|
||||
func (b *packetBuffer) Decrement() {
|
||||
b.refCount--
|
||||
if b.refCount < 0 {
|
||||
panic("negative packetBuffer refCount")
|
||||
}
|
||||
}
|
||||
|
||||
// MaybeRelease puts the packet buffer back into the pool,
|
||||
// if the reference counter already reached 0.
|
||||
func (b *packetBuffer) MaybeRelease() {
|
||||
// only put the packetBuffer back if it's not used any more
|
||||
if b.refCount == 0 {
|
||||
b.putBack()
|
||||
}
|
||||
}
|
||||
|
||||
// Release puts back the packet buffer into the pool.
|
||||
// It should be called when processing is definitely finished.
|
||||
func (b *packetBuffer) Release() {
|
||||
b.Decrement()
|
||||
if b.refCount != 0 {
|
||||
panic("packetBuffer refCount not zero")
|
||||
}
|
||||
b.putBack()
|
||||
}
|
||||
|
||||
// Len returns the length of Data
|
||||
func (b *packetBuffer) Len() protocol.ByteCount { return protocol.ByteCount(len(b.Data)) }
|
||||
func (b *packetBuffer) Cap() protocol.ByteCount { return protocol.ByteCount(cap(b.Data)) }
|
||||
|
||||
func (b *packetBuffer) putBack() {
|
||||
if cap(b.Data) == protocol.MaxPacketBufferSize {
|
||||
bufferPool.Put(b)
|
||||
return
|
||||
}
|
||||
if cap(b.Data) == protocol.MaxLargePacketBufferSize {
|
||||
largeBufferPool.Put(b)
|
||||
return
|
||||
}
|
||||
panic("putPacketBuffer called with packet of wrong size!")
|
||||
}
|
||||
|
||||
var bufferPool, largeBufferPool sync.Pool
|
||||
|
||||
func getPacketBuffer() *packetBuffer {
|
||||
buf := bufferPool.Get().(*packetBuffer)
|
||||
buf.refCount = 1
|
||||
buf.Data = buf.Data[:0]
|
||||
return buf
|
||||
}
|
||||
|
||||
func getLargePacketBuffer() *packetBuffer {
|
||||
buf := largeBufferPool.Get().(*packetBuffer)
|
||||
buf.refCount = 1
|
||||
buf.Data = buf.Data[:0]
|
||||
return buf
|
||||
}
|
||||
|
||||
func init() {
|
||||
bufferPool.New = func() any {
|
||||
return &packetBuffer{Data: make([]byte, 0, protocol.MaxPacketBufferSize)}
|
||||
}
|
||||
largeBufferPool.New = func() any {
|
||||
return &packetBuffer{Data: make([]byte, 0, protocol.MaxLargePacketBufferSize)}
|
||||
}
|
||||
}
|
||||
251
vendor/github.com/quic-go/quic-go/client.go
generated
vendored
Normal file
251
vendor/github.com/quic-go/quic-go/client.go
generated
vendored
Normal file
@@ -0,0 +1,251 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
"github.com/quic-go/quic-go/logging"
|
||||
)
|
||||
|
||||
type client struct {
|
||||
sendConn sendConn
|
||||
|
||||
use0RTT bool
|
||||
|
||||
packetHandlers packetHandlerManager
|
||||
onClose func()
|
||||
|
||||
tlsConf *tls.Config
|
||||
config *Config
|
||||
|
||||
connIDGenerator ConnectionIDGenerator
|
||||
srcConnID protocol.ConnectionID
|
||||
destConnID protocol.ConnectionID
|
||||
|
||||
initialPacketNumber protocol.PacketNumber
|
||||
hasNegotiatedVersion bool
|
||||
version protocol.VersionNumber
|
||||
|
||||
handshakeChan chan struct{}
|
||||
|
||||
conn quicConn
|
||||
|
||||
tracer *logging.ConnectionTracer
|
||||
tracingID uint64
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
// make it possible to mock connection ID for initial generation in the tests
|
||||
var generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial
|
||||
|
||||
// DialAddr establishes a new QUIC connection to a server.
|
||||
// It resolves the address, and then creates a new UDP connection to dial the QUIC server.
|
||||
// When the QUIC connection is closed, this UDP connection is closed.
|
||||
// See Dial for more details.
|
||||
func DialAddr(ctx context.Context, addr string, tlsConf *tls.Config, conf *Config) (Connection, error) {
|
||||
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tr, err := setupTransport(udpConn, tlsConf, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return tr.dial(ctx, udpAddr, addr, tlsConf, conf, false)
|
||||
}
|
||||
|
||||
// DialAddrEarly establishes a new 0-RTT QUIC connection to a server.
|
||||
// See DialAddr for more details.
|
||||
func DialAddrEarly(ctx context.Context, addr string, tlsConf *tls.Config, conf *Config) (EarlyConnection, error) {
|
||||
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tr, err := setupTransport(udpConn, tlsConf, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn, err := tr.dial(ctx, udpAddr, addr, tlsConf, conf, true)
|
||||
if err != nil {
|
||||
tr.Close()
|
||||
return nil, err
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// DialEarly establishes a new 0-RTT QUIC connection to a server using a net.PacketConn.
|
||||
// See Dial for more details.
|
||||
func DialEarly(ctx context.Context, c net.PacketConn, addr net.Addr, tlsConf *tls.Config, conf *Config) (EarlyConnection, error) {
|
||||
dl, err := setupTransport(c, tlsConf, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn, err := dl.DialEarly(ctx, addr, tlsConf, conf)
|
||||
if err != nil {
|
||||
dl.Close()
|
||||
return nil, err
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// Dial establishes a new QUIC connection to a server using a net.PacketConn.
|
||||
// If the PacketConn satisfies the OOBCapablePacketConn interface (as a net.UDPConn does),
|
||||
// ECN and packet info support will be enabled. In this case, ReadMsgUDP and WriteMsgUDP
|
||||
// will be used instead of ReadFrom and WriteTo to read/write packets.
|
||||
// The tls.Config must define an application protocol (using NextProtos).
|
||||
//
|
||||
// This is a convenience function. More advanced use cases should instantiate a Transport,
|
||||
// which offers configuration options for a more fine-grained control of the connection establishment,
|
||||
// including reusing the underlying UDP socket for multiple QUIC connections.
|
||||
func Dial(ctx context.Context, c net.PacketConn, addr net.Addr, tlsConf *tls.Config, conf *Config) (Connection, error) {
|
||||
dl, err := setupTransport(c, tlsConf, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn, err := dl.Dial(ctx, addr, tlsConf, conf)
|
||||
if err != nil {
|
||||
dl.Close()
|
||||
return nil, err
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func setupTransport(c net.PacketConn, tlsConf *tls.Config, createdPacketConn bool) (*Transport, error) {
|
||||
if tlsConf == nil {
|
||||
return nil, errors.New("quic: tls.Config not set")
|
||||
}
|
||||
return &Transport{
|
||||
Conn: c,
|
||||
createdConn: createdPacketConn,
|
||||
isSingleUse: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func dial(
|
||||
ctx context.Context,
|
||||
conn sendConn,
|
||||
connIDGenerator ConnectionIDGenerator,
|
||||
packetHandlers packetHandlerManager,
|
||||
tlsConf *tls.Config,
|
||||
config *Config,
|
||||
onClose func(),
|
||||
use0RTT bool,
|
||||
) (quicConn, error) {
|
||||
c, err := newClient(conn, connIDGenerator, config, tlsConf, onClose, use0RTT)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.packetHandlers = packetHandlers
|
||||
|
||||
c.tracingID = nextConnTracingID()
|
||||
if c.config.Tracer != nil {
|
||||
c.tracer = c.config.Tracer(context.WithValue(ctx, ConnectionTracingKey, c.tracingID), protocol.PerspectiveClient, c.destConnID)
|
||||
}
|
||||
if c.tracer != nil && c.tracer.StartedConnection != nil {
|
||||
c.tracer.StartedConnection(c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID)
|
||||
}
|
||||
if err := c.dial(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.conn, nil
|
||||
}
|
||||
|
||||
func newClient(sendConn sendConn, connIDGenerator ConnectionIDGenerator, config *Config, tlsConf *tls.Config, onClose func(), use0RTT bool) (*client, error) {
|
||||
srcConnID, err := connIDGenerator.GenerateConnectionID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
destConnID, err := generateConnectionIDForInitial()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c := &client{
|
||||
connIDGenerator: connIDGenerator,
|
||||
srcConnID: srcConnID,
|
||||
destConnID: destConnID,
|
||||
sendConn: sendConn,
|
||||
use0RTT: use0RTT,
|
||||
onClose: onClose,
|
||||
tlsConf: tlsConf,
|
||||
config: config,
|
||||
version: config.Versions[0],
|
||||
handshakeChan: make(chan struct{}),
|
||||
logger: utils.DefaultLogger.WithPrefix("client"),
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *client) dial(ctx context.Context) error {
|
||||
c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
|
||||
|
||||
c.conn = newClientConnection(
|
||||
c.sendConn,
|
||||
c.packetHandlers,
|
||||
c.destConnID,
|
||||
c.srcConnID,
|
||||
c.connIDGenerator,
|
||||
c.config,
|
||||
c.tlsConf,
|
||||
c.initialPacketNumber,
|
||||
c.use0RTT,
|
||||
c.hasNegotiatedVersion,
|
||||
c.tracer,
|
||||
c.tracingID,
|
||||
c.logger,
|
||||
c.version,
|
||||
)
|
||||
c.packetHandlers.Add(c.srcConnID, c.conn)
|
||||
|
||||
errorChan := make(chan error, 1)
|
||||
recreateChan := make(chan errCloseForRecreating)
|
||||
go func() {
|
||||
err := c.conn.run()
|
||||
var recreateErr *errCloseForRecreating
|
||||
if errors.As(err, &recreateErr) {
|
||||
recreateChan <- *recreateErr
|
||||
return
|
||||
}
|
||||
if c.onClose != nil {
|
||||
c.onClose()
|
||||
}
|
||||
errorChan <- err // returns as soon as the connection is closed
|
||||
}()
|
||||
|
||||
// only set when we're using 0-RTT
|
||||
// Otherwise, earlyConnChan will be nil. Receiving from a nil chan blocks forever.
|
||||
var earlyConnChan <-chan struct{}
|
||||
if c.use0RTT {
|
||||
earlyConnChan = c.conn.earlyConnReady()
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
c.conn.shutdown()
|
||||
return context.Cause(ctx)
|
||||
case err := <-errorChan:
|
||||
return err
|
||||
case recreateErr := <-recreateChan:
|
||||
c.initialPacketNumber = recreateErr.nextPacketNumber
|
||||
c.version = recreateErr.nextVersion
|
||||
c.hasNegotiatedVersion = true
|
||||
return c.dial(ctx)
|
||||
case <-earlyConnChan:
|
||||
// ready to send 0-RTT data
|
||||
return nil
|
||||
case <-c.conn.HandshakeComplete():
|
||||
// handshake successfully completed
|
||||
return nil
|
||||
}
|
||||
}
|
||||
64
vendor/github.com/quic-go/quic-go/closed_conn.go
generated
vendored
Normal file
64
vendor/github.com/quic-go/quic-go/closed_conn.go
generated
vendored
Normal file
@@ -0,0 +1,64 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"math/bits"
|
||||
"net"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
// A closedLocalConn is a connection that we closed locally.
|
||||
// When receiving packets for such a connection, we need to retransmit the packet containing the CONNECTION_CLOSE frame,
|
||||
// with an exponential backoff.
|
||||
type closedLocalConn struct {
|
||||
counter uint32
|
||||
perspective protocol.Perspective
|
||||
logger utils.Logger
|
||||
|
||||
sendPacket func(net.Addr, packetInfo)
|
||||
}
|
||||
|
||||
var _ packetHandler = &closedLocalConn{}
|
||||
|
||||
// newClosedLocalConn creates a new closedLocalConn and runs it.
|
||||
func newClosedLocalConn(sendPacket func(net.Addr, packetInfo), pers protocol.Perspective, logger utils.Logger) packetHandler {
|
||||
return &closedLocalConn{
|
||||
sendPacket: sendPacket,
|
||||
perspective: pers,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *closedLocalConn) handlePacket(p receivedPacket) {
|
||||
c.counter++
|
||||
// exponential backoff
|
||||
// only send a CONNECTION_CLOSE for the 1st, 2nd, 4th, 8th, 16th, ... packet arriving
|
||||
if bits.OnesCount32(c.counter) != 1 {
|
||||
return
|
||||
}
|
||||
c.logger.Debugf("Received %d packets after sending CONNECTION_CLOSE. Retransmitting.", c.counter)
|
||||
c.sendPacket(p.remoteAddr, p.info)
|
||||
}
|
||||
|
||||
func (c *closedLocalConn) shutdown() {}
|
||||
func (c *closedLocalConn) destroy(error) {}
|
||||
func (c *closedLocalConn) getPerspective() protocol.Perspective { return c.perspective }
|
||||
|
||||
// A closedRemoteConn is a connection that was closed remotely.
|
||||
// For such a connection, we might receive reordered packets that were sent before the CONNECTION_CLOSE.
|
||||
// We can just ignore those packets.
|
||||
type closedRemoteConn struct {
|
||||
perspective protocol.Perspective
|
||||
}
|
||||
|
||||
var _ packetHandler = &closedRemoteConn{}
|
||||
|
||||
func newClosedRemoteConn(pers protocol.Perspective) packetHandler {
|
||||
return &closedRemoteConn{perspective: pers}
|
||||
}
|
||||
|
||||
func (s *closedRemoteConn) handlePacket(receivedPacket) {}
|
||||
func (s *closedRemoteConn) shutdown() {}
|
||||
func (s *closedRemoteConn) destroy(error) {}
|
||||
func (s *closedRemoteConn) getPerspective() protocol.Perspective { return s.perspective }
|
||||
14
vendor/github.com/quic-go/quic-go/codecov.yml
generated
vendored
Normal file
14
vendor/github.com/quic-go/quic-go/codecov.yml
generated
vendored
Normal file
@@ -0,0 +1,14 @@
|
||||
coverage:
|
||||
round: nearest
|
||||
ignore:
|
||||
- http3/gzip_reader.go
|
||||
- interop/
|
||||
- internal/handshake/cipher_suite.go
|
||||
- internal/utils/linkedlist/linkedlist.go
|
||||
- fuzzing/
|
||||
- metrics/
|
||||
status:
|
||||
project:
|
||||
default:
|
||||
threshold: 0.5
|
||||
patch: false
|
||||
129
vendor/github.com/quic-go/quic-go/config.go
generated
vendored
Normal file
129
vendor/github.com/quic-go/quic-go/config.go
generated
vendored
Normal file
@@ -0,0 +1,129 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/quicvarint"
|
||||
)
|
||||
|
||||
// Clone clones a Config
|
||||
func (c *Config) Clone() *Config {
|
||||
copy := *c
|
||||
return ©
|
||||
}
|
||||
|
||||
func (c *Config) handshakeTimeout() time.Duration {
|
||||
return 2 * c.HandshakeIdleTimeout
|
||||
}
|
||||
|
||||
func (c *Config) maxRetryTokenAge() time.Duration {
|
||||
return c.handshakeTimeout()
|
||||
}
|
||||
|
||||
func validateConfig(config *Config) error {
|
||||
if config == nil {
|
||||
return nil
|
||||
}
|
||||
const maxStreams = 1 << 60
|
||||
if config.MaxIncomingStreams > maxStreams {
|
||||
config.MaxIncomingStreams = maxStreams
|
||||
}
|
||||
if config.MaxIncomingUniStreams > maxStreams {
|
||||
config.MaxIncomingUniStreams = maxStreams
|
||||
}
|
||||
if config.MaxStreamReceiveWindow > quicvarint.Max {
|
||||
config.MaxStreamReceiveWindow = quicvarint.Max
|
||||
}
|
||||
if config.MaxConnectionReceiveWindow > quicvarint.Max {
|
||||
config.MaxConnectionReceiveWindow = quicvarint.Max
|
||||
}
|
||||
// check that all QUIC versions are actually supported
|
||||
for _, v := range config.Versions {
|
||||
if !protocol.IsValidVersion(v) {
|
||||
return fmt.Errorf("invalid QUIC version: %s", v)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// populateServerConfig populates fields in the quic.Config with their default values, if none are set
|
||||
// it may be called with nil
|
||||
func populateServerConfig(config *Config) *Config {
|
||||
config = populateConfig(config)
|
||||
if config.RequireAddressValidation == nil {
|
||||
config.RequireAddressValidation = func(net.Addr) bool { return false }
|
||||
}
|
||||
return config
|
||||
}
|
||||
|
||||
// populateConfig populates fields in the quic.Config with their default values, if none are set
|
||||
// it may be called with nil
|
||||
func populateConfig(config *Config) *Config {
|
||||
if config == nil {
|
||||
config = &Config{}
|
||||
}
|
||||
versions := config.Versions
|
||||
if len(versions) == 0 {
|
||||
versions = protocol.SupportedVersions
|
||||
}
|
||||
handshakeIdleTimeout := protocol.DefaultHandshakeIdleTimeout
|
||||
if config.HandshakeIdleTimeout != 0 {
|
||||
handshakeIdleTimeout = config.HandshakeIdleTimeout
|
||||
}
|
||||
idleTimeout := protocol.DefaultIdleTimeout
|
||||
if config.MaxIdleTimeout != 0 {
|
||||
idleTimeout = config.MaxIdleTimeout
|
||||
}
|
||||
initialStreamReceiveWindow := config.InitialStreamReceiveWindow
|
||||
if initialStreamReceiveWindow == 0 {
|
||||
initialStreamReceiveWindow = protocol.DefaultInitialMaxStreamData
|
||||
}
|
||||
maxStreamReceiveWindow := config.MaxStreamReceiveWindow
|
||||
if maxStreamReceiveWindow == 0 {
|
||||
maxStreamReceiveWindow = protocol.DefaultMaxReceiveStreamFlowControlWindow
|
||||
}
|
||||
initialConnectionReceiveWindow := config.InitialConnectionReceiveWindow
|
||||
if initialConnectionReceiveWindow == 0 {
|
||||
initialConnectionReceiveWindow = protocol.DefaultInitialMaxData
|
||||
}
|
||||
maxConnectionReceiveWindow := config.MaxConnectionReceiveWindow
|
||||
if maxConnectionReceiveWindow == 0 {
|
||||
maxConnectionReceiveWindow = protocol.DefaultMaxReceiveConnectionFlowControlWindow
|
||||
}
|
||||
maxIncomingStreams := config.MaxIncomingStreams
|
||||
if maxIncomingStreams == 0 {
|
||||
maxIncomingStreams = protocol.DefaultMaxIncomingStreams
|
||||
} else if maxIncomingStreams < 0 {
|
||||
maxIncomingStreams = 0
|
||||
}
|
||||
maxIncomingUniStreams := config.MaxIncomingUniStreams
|
||||
if maxIncomingUniStreams == 0 {
|
||||
maxIncomingUniStreams = protocol.DefaultMaxIncomingUniStreams
|
||||
} else if maxIncomingUniStreams < 0 {
|
||||
maxIncomingUniStreams = 0
|
||||
}
|
||||
|
||||
return &Config{
|
||||
GetConfigForClient: config.GetConfigForClient,
|
||||
Versions: versions,
|
||||
HandshakeIdleTimeout: handshakeIdleTimeout,
|
||||
MaxIdleTimeout: idleTimeout,
|
||||
RequireAddressValidation: config.RequireAddressValidation,
|
||||
KeepAlivePeriod: config.KeepAlivePeriod,
|
||||
InitialStreamReceiveWindow: initialStreamReceiveWindow,
|
||||
MaxStreamReceiveWindow: maxStreamReceiveWindow,
|
||||
InitialConnectionReceiveWindow: initialConnectionReceiveWindow,
|
||||
MaxConnectionReceiveWindow: maxConnectionReceiveWindow,
|
||||
AllowConnectionWindowIncrease: config.AllowConnectionWindowIncrease,
|
||||
MaxIncomingStreams: maxIncomingStreams,
|
||||
MaxIncomingUniStreams: maxIncomingUniStreams,
|
||||
TokenStore: config.TokenStore,
|
||||
EnableDatagrams: config.EnableDatagrams,
|
||||
DisablePathMTUDiscovery: config.DisablePathMTUDiscovery,
|
||||
Allow0RTT: config.Allow0RTT,
|
||||
Tracer: config.Tracer,
|
||||
}
|
||||
}
|
||||
139
vendor/github.com/quic-go/quic-go/conn_id_generator.go
generated
vendored
Normal file
139
vendor/github.com/quic-go/quic-go/conn_id_generator.go
generated
vendored
Normal file
@@ -0,0 +1,139 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/qerr"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
type connIDGenerator struct {
|
||||
generator ConnectionIDGenerator
|
||||
highestSeq uint64
|
||||
|
||||
activeSrcConnIDs map[uint64]protocol.ConnectionID
|
||||
initialClientDestConnID *protocol.ConnectionID // nil for the client
|
||||
|
||||
addConnectionID func(protocol.ConnectionID)
|
||||
getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken
|
||||
removeConnectionID func(protocol.ConnectionID)
|
||||
retireConnectionID func(protocol.ConnectionID)
|
||||
replaceWithClosed func([]protocol.ConnectionID, protocol.Perspective, []byte)
|
||||
queueControlFrame func(wire.Frame)
|
||||
}
|
||||
|
||||
func newConnIDGenerator(
|
||||
initialConnectionID protocol.ConnectionID,
|
||||
initialClientDestConnID *protocol.ConnectionID, // nil for the client
|
||||
addConnectionID func(protocol.ConnectionID),
|
||||
getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken,
|
||||
removeConnectionID func(protocol.ConnectionID),
|
||||
retireConnectionID func(protocol.ConnectionID),
|
||||
replaceWithClosed func([]protocol.ConnectionID, protocol.Perspective, []byte),
|
||||
queueControlFrame func(wire.Frame),
|
||||
generator ConnectionIDGenerator,
|
||||
) *connIDGenerator {
|
||||
m := &connIDGenerator{
|
||||
generator: generator,
|
||||
activeSrcConnIDs: make(map[uint64]protocol.ConnectionID),
|
||||
addConnectionID: addConnectionID,
|
||||
getStatelessResetToken: getStatelessResetToken,
|
||||
removeConnectionID: removeConnectionID,
|
||||
retireConnectionID: retireConnectionID,
|
||||
replaceWithClosed: replaceWithClosed,
|
||||
queueControlFrame: queueControlFrame,
|
||||
}
|
||||
m.activeSrcConnIDs[0] = initialConnectionID
|
||||
m.initialClientDestConnID = initialClientDestConnID
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *connIDGenerator) SetMaxActiveConnIDs(limit uint64) error {
|
||||
if m.generator.ConnectionIDLen() == 0 {
|
||||
return nil
|
||||
}
|
||||
// The active_connection_id_limit transport parameter is the number of
|
||||
// connection IDs the peer will store. This limit includes the connection ID
|
||||
// used during the handshake, and the one sent in the preferred_address
|
||||
// transport parameter.
|
||||
// We currently don't send the preferred_address transport parameter,
|
||||
// so we can issue (limit - 1) connection IDs.
|
||||
for i := uint64(len(m.activeSrcConnIDs)); i < utils.Min(limit, protocol.MaxIssuedConnectionIDs); i++ {
|
||||
if err := m.issueNewConnID(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *connIDGenerator) Retire(seq uint64, sentWithDestConnID protocol.ConnectionID) error {
|
||||
if seq > m.highestSeq {
|
||||
return &qerr.TransportError{
|
||||
ErrorCode: qerr.ProtocolViolation,
|
||||
ErrorMessage: fmt.Sprintf("retired connection ID %d (highest issued: %d)", seq, m.highestSeq),
|
||||
}
|
||||
}
|
||||
connID, ok := m.activeSrcConnIDs[seq]
|
||||
// We might already have deleted this connection ID, if this is a duplicate frame.
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
if connID == sentWithDestConnID {
|
||||
return &qerr.TransportError{
|
||||
ErrorCode: qerr.ProtocolViolation,
|
||||
ErrorMessage: fmt.Sprintf("retired connection ID %d (%s), which was used as the Destination Connection ID on this packet", seq, connID),
|
||||
}
|
||||
}
|
||||
m.retireConnectionID(connID)
|
||||
delete(m.activeSrcConnIDs, seq)
|
||||
// Don't issue a replacement for the initial connection ID.
|
||||
if seq == 0 {
|
||||
return nil
|
||||
}
|
||||
return m.issueNewConnID()
|
||||
}
|
||||
|
||||
func (m *connIDGenerator) issueNewConnID() error {
|
||||
connID, err := m.generator.GenerateConnectionID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.activeSrcConnIDs[m.highestSeq+1] = connID
|
||||
m.addConnectionID(connID)
|
||||
m.queueControlFrame(&wire.NewConnectionIDFrame{
|
||||
SequenceNumber: m.highestSeq + 1,
|
||||
ConnectionID: connID,
|
||||
StatelessResetToken: m.getStatelessResetToken(connID),
|
||||
})
|
||||
m.highestSeq++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *connIDGenerator) SetHandshakeComplete() {
|
||||
if m.initialClientDestConnID != nil {
|
||||
m.retireConnectionID(*m.initialClientDestConnID)
|
||||
m.initialClientDestConnID = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m *connIDGenerator) RemoveAll() {
|
||||
if m.initialClientDestConnID != nil {
|
||||
m.removeConnectionID(*m.initialClientDestConnID)
|
||||
}
|
||||
for _, connID := range m.activeSrcConnIDs {
|
||||
m.removeConnectionID(connID)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *connIDGenerator) ReplaceWithClosed(pers protocol.Perspective, connClose []byte) {
|
||||
connIDs := make([]protocol.ConnectionID, 0, len(m.activeSrcConnIDs)+1)
|
||||
if m.initialClientDestConnID != nil {
|
||||
connIDs = append(connIDs, *m.initialClientDestConnID)
|
||||
}
|
||||
for _, connID := range m.activeSrcConnIDs {
|
||||
connIDs = append(connIDs, connID)
|
||||
}
|
||||
m.replaceWithClosed(connIDs, pers, connClose)
|
||||
}
|
||||
214
vendor/github.com/quic-go/quic-go/conn_id_manager.go
generated
vendored
Normal file
214
vendor/github.com/quic-go/quic-go/conn_id_manager.go
generated
vendored
Normal file
@@ -0,0 +1,214 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/qerr"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
list "github.com/quic-go/quic-go/internal/utils/linkedlist"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
type newConnID struct {
|
||||
SequenceNumber uint64
|
||||
ConnectionID protocol.ConnectionID
|
||||
StatelessResetToken protocol.StatelessResetToken
|
||||
}
|
||||
|
||||
type connIDManager struct {
|
||||
queue list.List[newConnID]
|
||||
|
||||
handshakeComplete bool
|
||||
activeSequenceNumber uint64
|
||||
highestRetired uint64
|
||||
activeConnectionID protocol.ConnectionID
|
||||
activeStatelessResetToken *protocol.StatelessResetToken
|
||||
|
||||
// We change the connection ID after sending on average
|
||||
// protocol.PacketsPerConnectionID packets. The actual value is randomized
|
||||
// hide the packet loss rate from on-path observers.
|
||||
rand utils.Rand
|
||||
packetsSinceLastChange uint32
|
||||
packetsPerConnectionID uint32
|
||||
|
||||
addStatelessResetToken func(protocol.StatelessResetToken)
|
||||
removeStatelessResetToken func(protocol.StatelessResetToken)
|
||||
queueControlFrame func(wire.Frame)
|
||||
}
|
||||
|
||||
func newConnIDManager(
|
||||
initialDestConnID protocol.ConnectionID,
|
||||
addStatelessResetToken func(protocol.StatelessResetToken),
|
||||
removeStatelessResetToken func(protocol.StatelessResetToken),
|
||||
queueControlFrame func(wire.Frame),
|
||||
) *connIDManager {
|
||||
return &connIDManager{
|
||||
activeConnectionID: initialDestConnID,
|
||||
addStatelessResetToken: addStatelessResetToken,
|
||||
removeStatelessResetToken: removeStatelessResetToken,
|
||||
queueControlFrame: queueControlFrame,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *connIDManager) AddFromPreferredAddress(connID protocol.ConnectionID, resetToken protocol.StatelessResetToken) error {
|
||||
return h.addConnectionID(1, connID, resetToken)
|
||||
}
|
||||
|
||||
func (h *connIDManager) Add(f *wire.NewConnectionIDFrame) error {
|
||||
if err := h.add(f); err != nil {
|
||||
return err
|
||||
}
|
||||
if h.queue.Len() >= protocol.MaxActiveConnectionIDs {
|
||||
return &qerr.TransportError{ErrorCode: qerr.ConnectionIDLimitError}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *connIDManager) add(f *wire.NewConnectionIDFrame) error {
|
||||
// If the NEW_CONNECTION_ID frame is reordered, such that its sequence number is smaller than the currently active
|
||||
// connection ID or if it was already retired, send the RETIRE_CONNECTION_ID frame immediately.
|
||||
if f.SequenceNumber < h.activeSequenceNumber || f.SequenceNumber < h.highestRetired {
|
||||
h.queueControlFrame(&wire.RetireConnectionIDFrame{
|
||||
SequenceNumber: f.SequenceNumber,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// Retire elements in the queue.
|
||||
// Doesn't retire the active connection ID.
|
||||
if f.RetirePriorTo > h.highestRetired {
|
||||
var next *list.Element[newConnID]
|
||||
for el := h.queue.Front(); el != nil; el = next {
|
||||
if el.Value.SequenceNumber >= f.RetirePriorTo {
|
||||
break
|
||||
}
|
||||
next = el.Next()
|
||||
h.queueControlFrame(&wire.RetireConnectionIDFrame{
|
||||
SequenceNumber: el.Value.SequenceNumber,
|
||||
})
|
||||
h.queue.Remove(el)
|
||||
}
|
||||
h.highestRetired = f.RetirePriorTo
|
||||
}
|
||||
|
||||
if f.SequenceNumber == h.activeSequenceNumber {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := h.addConnectionID(f.SequenceNumber, f.ConnectionID, f.StatelessResetToken); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Retire the active connection ID, if necessary.
|
||||
if h.activeSequenceNumber < f.RetirePriorTo {
|
||||
// The queue is guaranteed to have at least one element at this point.
|
||||
h.updateConnectionID()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *connIDManager) addConnectionID(seq uint64, connID protocol.ConnectionID, resetToken protocol.StatelessResetToken) error {
|
||||
// insert a new element at the end
|
||||
if h.queue.Len() == 0 || h.queue.Back().Value.SequenceNumber < seq {
|
||||
h.queue.PushBack(newConnID{
|
||||
SequenceNumber: seq,
|
||||
ConnectionID: connID,
|
||||
StatelessResetToken: resetToken,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
// insert a new element somewhere in the middle
|
||||
for el := h.queue.Front(); el != nil; el = el.Next() {
|
||||
if el.Value.SequenceNumber == seq {
|
||||
if el.Value.ConnectionID != connID {
|
||||
return fmt.Errorf("received conflicting connection IDs for sequence number %d", seq)
|
||||
}
|
||||
if el.Value.StatelessResetToken != resetToken {
|
||||
return fmt.Errorf("received conflicting stateless reset tokens for sequence number %d", seq)
|
||||
}
|
||||
break
|
||||
}
|
||||
if el.Value.SequenceNumber > seq {
|
||||
h.queue.InsertBefore(newConnID{
|
||||
SequenceNumber: seq,
|
||||
ConnectionID: connID,
|
||||
StatelessResetToken: resetToken,
|
||||
}, el)
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *connIDManager) updateConnectionID() {
|
||||
h.queueControlFrame(&wire.RetireConnectionIDFrame{
|
||||
SequenceNumber: h.activeSequenceNumber,
|
||||
})
|
||||
h.highestRetired = utils.Max(h.highestRetired, h.activeSequenceNumber)
|
||||
if h.activeStatelessResetToken != nil {
|
||||
h.removeStatelessResetToken(*h.activeStatelessResetToken)
|
||||
}
|
||||
|
||||
front := h.queue.Remove(h.queue.Front())
|
||||
h.activeSequenceNumber = front.SequenceNumber
|
||||
h.activeConnectionID = front.ConnectionID
|
||||
h.activeStatelessResetToken = &front.StatelessResetToken
|
||||
h.packetsSinceLastChange = 0
|
||||
h.packetsPerConnectionID = protocol.PacketsPerConnectionID/2 + uint32(h.rand.Int31n(protocol.PacketsPerConnectionID))
|
||||
h.addStatelessResetToken(*h.activeStatelessResetToken)
|
||||
}
|
||||
|
||||
func (h *connIDManager) Close() {
|
||||
if h.activeStatelessResetToken != nil {
|
||||
h.removeStatelessResetToken(*h.activeStatelessResetToken)
|
||||
}
|
||||
}
|
||||
|
||||
// is called when the server performs a Retry
|
||||
// and when the server changes the connection ID in the first Initial sent
|
||||
func (h *connIDManager) ChangeInitialConnID(newConnID protocol.ConnectionID) {
|
||||
if h.activeSequenceNumber != 0 {
|
||||
panic("expected first connection ID to have sequence number 0")
|
||||
}
|
||||
h.activeConnectionID = newConnID
|
||||
}
|
||||
|
||||
// is called when the server provides a stateless reset token in the transport parameters
|
||||
func (h *connIDManager) SetStatelessResetToken(token protocol.StatelessResetToken) {
|
||||
if h.activeSequenceNumber != 0 {
|
||||
panic("expected first connection ID to have sequence number 0")
|
||||
}
|
||||
h.activeStatelessResetToken = &token
|
||||
h.addStatelessResetToken(token)
|
||||
}
|
||||
|
||||
func (h *connIDManager) SentPacket() {
|
||||
h.packetsSinceLastChange++
|
||||
}
|
||||
|
||||
func (h *connIDManager) shouldUpdateConnID() bool {
|
||||
if !h.handshakeComplete {
|
||||
return false
|
||||
}
|
||||
// initiate the first change as early as possible (after handshake completion)
|
||||
if h.queue.Len() > 0 && h.activeSequenceNumber == 0 {
|
||||
return true
|
||||
}
|
||||
// For later changes, only change if
|
||||
// 1. The queue of connection IDs is filled more than 50%.
|
||||
// 2. We sent at least PacketsPerConnectionID packets
|
||||
return 2*h.queue.Len() >= protocol.MaxActiveConnectionIDs &&
|
||||
h.packetsSinceLastChange >= h.packetsPerConnectionID
|
||||
}
|
||||
|
||||
func (h *connIDManager) Get() protocol.ConnectionID {
|
||||
if h.shouldUpdateConnID() {
|
||||
h.updateConnectionID()
|
||||
}
|
||||
return h.activeConnectionID
|
||||
}
|
||||
|
||||
func (h *connIDManager) SetHandshakeComplete() {
|
||||
h.handshakeComplete = true
|
||||
}
|
||||
2387
vendor/github.com/quic-go/quic-go/connection.go
generated
vendored
Normal file
2387
vendor/github.com/quic-go/quic-go/connection.go
generated
vendored
Normal file
File diff suppressed because it is too large
Load Diff
51
vendor/github.com/quic-go/quic-go/connection_timer.go
generated
vendored
Normal file
51
vendor/github.com/quic-go/quic-go/connection_timer.go
generated
vendored
Normal file
@@ -0,0 +1,51 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
var deadlineSendImmediately = time.Time{}.Add(42 * time.Millisecond) // any value > time.Time{} and before time.Now() is fine
|
||||
|
||||
type connectionTimer struct {
|
||||
timer *utils.Timer
|
||||
last time.Time
|
||||
}
|
||||
|
||||
func newTimer() *connectionTimer {
|
||||
return &connectionTimer{timer: utils.NewTimer()}
|
||||
}
|
||||
|
||||
func (t *connectionTimer) SetRead() {
|
||||
if deadline := t.timer.Deadline(); deadline != deadlineSendImmediately {
|
||||
t.last = deadline
|
||||
}
|
||||
t.timer.SetRead()
|
||||
}
|
||||
|
||||
func (t *connectionTimer) Chan() <-chan time.Time {
|
||||
return t.timer.Chan()
|
||||
}
|
||||
|
||||
// SetTimer resets the timer.
|
||||
// It makes sure that the deadline is strictly increasing.
|
||||
// This prevents busy-looping in cases where the timer fires, but we can't actually send out a packet.
|
||||
// This doesn't apply to the pacing deadline, which can be set multiple times to deadlineSendImmediately.
|
||||
func (t *connectionTimer) SetTimer(idleTimeoutOrKeepAlive, ackAlarm, lossTime, pacing time.Time) {
|
||||
deadline := idleTimeoutOrKeepAlive
|
||||
if !ackAlarm.IsZero() && ackAlarm.Before(deadline) && ackAlarm.After(t.last) {
|
||||
deadline = ackAlarm
|
||||
}
|
||||
if !lossTime.IsZero() && lossTime.Before(deadline) && lossTime.After(t.last) {
|
||||
deadline = lossTime
|
||||
}
|
||||
if !pacing.IsZero() && pacing.Before(deadline) {
|
||||
deadline = pacing
|
||||
}
|
||||
t.timer.Reset(deadline)
|
||||
}
|
||||
|
||||
func (t *connectionTimer) Stop() {
|
||||
t.timer.Stop()
|
||||
}
|
||||
107
vendor/github.com/quic-go/quic-go/crypto_stream.go
generated
vendored
Normal file
107
vendor/github.com/quic-go/quic-go/crypto_stream.go
generated
vendored
Normal file
@@ -0,0 +1,107 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/qerr"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
type cryptoStream interface {
|
||||
// for receiving data
|
||||
HandleCryptoFrame(*wire.CryptoFrame) error
|
||||
GetCryptoData() []byte
|
||||
Finish() error
|
||||
// for sending data
|
||||
io.Writer
|
||||
HasData() bool
|
||||
PopCryptoFrame(protocol.ByteCount) *wire.CryptoFrame
|
||||
}
|
||||
|
||||
type cryptoStreamImpl struct {
|
||||
queue *frameSorter
|
||||
msgBuf []byte
|
||||
|
||||
highestOffset protocol.ByteCount
|
||||
finished bool
|
||||
|
||||
writeOffset protocol.ByteCount
|
||||
writeBuf []byte
|
||||
}
|
||||
|
||||
func newCryptoStream() cryptoStream {
|
||||
return &cryptoStreamImpl{queue: newFrameSorter()}
|
||||
}
|
||||
|
||||
func (s *cryptoStreamImpl) HandleCryptoFrame(f *wire.CryptoFrame) error {
|
||||
highestOffset := f.Offset + protocol.ByteCount(len(f.Data))
|
||||
if maxOffset := highestOffset; maxOffset > protocol.MaxCryptoStreamOffset {
|
||||
return &qerr.TransportError{
|
||||
ErrorCode: qerr.CryptoBufferExceeded,
|
||||
ErrorMessage: fmt.Sprintf("received invalid offset %d on crypto stream, maximum allowed %d", maxOffset, protocol.MaxCryptoStreamOffset),
|
||||
}
|
||||
}
|
||||
if s.finished {
|
||||
if highestOffset > s.highestOffset {
|
||||
// reject crypto data received after this stream was already finished
|
||||
return &qerr.TransportError{
|
||||
ErrorCode: qerr.ProtocolViolation,
|
||||
ErrorMessage: "received crypto data after change of encryption level",
|
||||
}
|
||||
}
|
||||
// ignore data with a smaller offset than the highest received
|
||||
// could e.g. be a retransmission
|
||||
return nil
|
||||
}
|
||||
s.highestOffset = utils.Max(s.highestOffset, highestOffset)
|
||||
if err := s.queue.Push(f.Data, f.Offset, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
for {
|
||||
_, data, _ := s.queue.Pop()
|
||||
if data == nil {
|
||||
return nil
|
||||
}
|
||||
s.msgBuf = append(s.msgBuf, data...)
|
||||
}
|
||||
}
|
||||
|
||||
// GetCryptoData retrieves data that was received in CRYPTO frames
|
||||
func (s *cryptoStreamImpl) GetCryptoData() []byte {
|
||||
b := s.msgBuf
|
||||
s.msgBuf = nil
|
||||
return b
|
||||
}
|
||||
|
||||
func (s *cryptoStreamImpl) Finish() error {
|
||||
if s.queue.HasMoreData() {
|
||||
return &qerr.TransportError{
|
||||
ErrorCode: qerr.ProtocolViolation,
|
||||
ErrorMessage: "encryption level changed, but crypto stream has more data to read",
|
||||
}
|
||||
}
|
||||
s.finished = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// Writes writes data that should be sent out in CRYPTO frames
|
||||
func (s *cryptoStreamImpl) Write(p []byte) (int, error) {
|
||||
s.writeBuf = append(s.writeBuf, p...)
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (s *cryptoStreamImpl) HasData() bool {
|
||||
return len(s.writeBuf) > 0
|
||||
}
|
||||
|
||||
func (s *cryptoStreamImpl) PopCryptoFrame(maxLen protocol.ByteCount) *wire.CryptoFrame {
|
||||
f := &wire.CryptoFrame{Offset: s.writeOffset}
|
||||
n := utils.Min(f.MaxDataLen(maxLen), protocol.ByteCount(len(s.writeBuf)))
|
||||
f.Data = s.writeBuf[:n]
|
||||
s.writeBuf = s.writeBuf[n:]
|
||||
s.writeOffset += n
|
||||
return f
|
||||
}
|
||||
82
vendor/github.com/quic-go/quic-go/crypto_stream_manager.go
generated
vendored
Normal file
82
vendor/github.com/quic-go/quic-go/crypto_stream_manager.go
generated
vendored
Normal file
@@ -0,0 +1,82 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/handshake"
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
type cryptoDataHandler interface {
|
||||
HandleMessage([]byte, protocol.EncryptionLevel) error
|
||||
NextEvent() handshake.Event
|
||||
}
|
||||
|
||||
type cryptoStreamManager struct {
|
||||
cryptoHandler cryptoDataHandler
|
||||
|
||||
initialStream cryptoStream
|
||||
handshakeStream cryptoStream
|
||||
oneRTTStream cryptoStream
|
||||
}
|
||||
|
||||
func newCryptoStreamManager(
|
||||
cryptoHandler cryptoDataHandler,
|
||||
initialStream cryptoStream,
|
||||
handshakeStream cryptoStream,
|
||||
oneRTTStream cryptoStream,
|
||||
) *cryptoStreamManager {
|
||||
return &cryptoStreamManager{
|
||||
cryptoHandler: cryptoHandler,
|
||||
initialStream: initialStream,
|
||||
handshakeStream: handshakeStream,
|
||||
oneRTTStream: oneRTTStream,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) error {
|
||||
var str cryptoStream
|
||||
//nolint:exhaustive // CRYPTO frames cannot be sent in 0-RTT packets.
|
||||
switch encLevel {
|
||||
case protocol.EncryptionInitial:
|
||||
str = m.initialStream
|
||||
case protocol.EncryptionHandshake:
|
||||
str = m.handshakeStream
|
||||
case protocol.Encryption1RTT:
|
||||
str = m.oneRTTStream
|
||||
default:
|
||||
return fmt.Errorf("received CRYPTO frame with unexpected encryption level: %s", encLevel)
|
||||
}
|
||||
if err := str.HandleCryptoFrame(frame); err != nil {
|
||||
return err
|
||||
}
|
||||
for {
|
||||
data := str.GetCryptoData()
|
||||
if data == nil {
|
||||
return nil
|
||||
}
|
||||
if err := m.cryptoHandler.HandleMessage(data, encLevel); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *cryptoStreamManager) GetPostHandshakeData(maxSize protocol.ByteCount) *wire.CryptoFrame {
|
||||
if !m.oneRTTStream.HasData() {
|
||||
return nil
|
||||
}
|
||||
return m.oneRTTStream.PopCryptoFrame(maxSize)
|
||||
}
|
||||
|
||||
func (m *cryptoStreamManager) Drop(encLevel protocol.EncryptionLevel) error {
|
||||
//nolint:exhaustive // 1-RTT keys should never get dropped.
|
||||
switch encLevel {
|
||||
case protocol.EncryptionInitial:
|
||||
return m.initialStream.Finish()
|
||||
case protocol.EncryptionHandshake:
|
||||
return m.handshakeStream.Finish()
|
||||
default:
|
||||
panic(fmt.Sprintf("dropped unexpected encryption level: %s", encLevel))
|
||||
}
|
||||
}
|
||||
126
vendor/github.com/quic-go/quic-go/datagram_queue.go
generated
vendored
Normal file
126
vendor/github.com/quic-go/quic-go/datagram_queue.go
generated
vendored
Normal file
@@ -0,0 +1,126 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
type datagramQueue struct {
|
||||
sendQueue chan *wire.DatagramFrame
|
||||
nextFrame *wire.DatagramFrame
|
||||
|
||||
rcvMx sync.Mutex
|
||||
rcvQueue [][]byte
|
||||
rcvd chan struct{} // used to notify Receive that a new datagram was received
|
||||
|
||||
closeErr error
|
||||
closed chan struct{}
|
||||
|
||||
hasData func()
|
||||
|
||||
dequeued chan struct{}
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
func newDatagramQueue(hasData func(), logger utils.Logger) *datagramQueue {
|
||||
return &datagramQueue{
|
||||
hasData: hasData,
|
||||
sendQueue: make(chan *wire.DatagramFrame, 1),
|
||||
rcvd: make(chan struct{}, 1),
|
||||
dequeued: make(chan struct{}),
|
||||
closed: make(chan struct{}),
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// AddAndWait queues a new DATAGRAM frame for sending.
|
||||
// It blocks until the frame has been dequeued.
|
||||
func (h *datagramQueue) AddAndWait(f *wire.DatagramFrame) error {
|
||||
select {
|
||||
case h.sendQueue <- f:
|
||||
h.hasData()
|
||||
case <-h.closed:
|
||||
return h.closeErr
|
||||
}
|
||||
|
||||
select {
|
||||
case <-h.dequeued:
|
||||
return nil
|
||||
case <-h.closed:
|
||||
return h.closeErr
|
||||
}
|
||||
}
|
||||
|
||||
// Peek gets the next DATAGRAM frame for sending.
|
||||
// If actually sent out, Pop needs to be called before the next call to Peek.
|
||||
func (h *datagramQueue) Peek() *wire.DatagramFrame {
|
||||
if h.nextFrame != nil {
|
||||
return h.nextFrame
|
||||
}
|
||||
select {
|
||||
case h.nextFrame = <-h.sendQueue:
|
||||
h.dequeued <- struct{}{}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
return h.nextFrame
|
||||
}
|
||||
|
||||
func (h *datagramQueue) Pop() {
|
||||
if h.nextFrame == nil {
|
||||
panic("datagramQueue BUG: Pop called for nil frame")
|
||||
}
|
||||
h.nextFrame = nil
|
||||
}
|
||||
|
||||
// HandleDatagramFrame handles a received DATAGRAM frame.
|
||||
func (h *datagramQueue) HandleDatagramFrame(f *wire.DatagramFrame) {
|
||||
data := make([]byte, len(f.Data))
|
||||
copy(data, f.Data)
|
||||
var queued bool
|
||||
h.rcvMx.Lock()
|
||||
if len(h.rcvQueue) < protocol.DatagramRcvQueueLen {
|
||||
h.rcvQueue = append(h.rcvQueue, data)
|
||||
queued = true
|
||||
select {
|
||||
case h.rcvd <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
h.rcvMx.Unlock()
|
||||
if !queued && h.logger.Debug() {
|
||||
h.logger.Debugf("Discarding DATAGRAM frame (%d bytes payload)", len(f.Data))
|
||||
}
|
||||
}
|
||||
|
||||
// Receive gets a received DATAGRAM frame.
|
||||
func (h *datagramQueue) Receive(ctx context.Context) ([]byte, error) {
|
||||
for {
|
||||
h.rcvMx.Lock()
|
||||
if len(h.rcvQueue) > 0 {
|
||||
data := h.rcvQueue[0]
|
||||
h.rcvQueue = h.rcvQueue[1:]
|
||||
h.rcvMx.Unlock()
|
||||
return data, nil
|
||||
}
|
||||
h.rcvMx.Unlock()
|
||||
select {
|
||||
case <-h.rcvd:
|
||||
continue
|
||||
case <-h.closed:
|
||||
return nil, h.closeErr
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *datagramQueue) CloseWithError(e error) {
|
||||
h.closeErr = e
|
||||
close(h.closed)
|
||||
}
|
||||
63
vendor/github.com/quic-go/quic-go/errors.go
generated
vendored
Normal file
63
vendor/github.com/quic-go/quic-go/errors.go
generated
vendored
Normal file
@@ -0,0 +1,63 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/qerr"
|
||||
)
|
||||
|
||||
type (
|
||||
TransportError = qerr.TransportError
|
||||
ApplicationError = qerr.ApplicationError
|
||||
VersionNegotiationError = qerr.VersionNegotiationError
|
||||
StatelessResetError = qerr.StatelessResetError
|
||||
IdleTimeoutError = qerr.IdleTimeoutError
|
||||
HandshakeTimeoutError = qerr.HandshakeTimeoutError
|
||||
)
|
||||
|
||||
type (
|
||||
TransportErrorCode = qerr.TransportErrorCode
|
||||
ApplicationErrorCode = qerr.ApplicationErrorCode
|
||||
StreamErrorCode = qerr.StreamErrorCode
|
||||
)
|
||||
|
||||
const (
|
||||
NoError = qerr.NoError
|
||||
InternalError = qerr.InternalError
|
||||
ConnectionRefused = qerr.ConnectionRefused
|
||||
FlowControlError = qerr.FlowControlError
|
||||
StreamLimitError = qerr.StreamLimitError
|
||||
StreamStateError = qerr.StreamStateError
|
||||
FinalSizeError = qerr.FinalSizeError
|
||||
FrameEncodingError = qerr.FrameEncodingError
|
||||
TransportParameterError = qerr.TransportParameterError
|
||||
ConnectionIDLimitError = qerr.ConnectionIDLimitError
|
||||
ProtocolViolation = qerr.ProtocolViolation
|
||||
InvalidToken = qerr.InvalidToken
|
||||
ApplicationErrorErrorCode = qerr.ApplicationErrorErrorCode
|
||||
CryptoBufferExceeded = qerr.CryptoBufferExceeded
|
||||
KeyUpdateError = qerr.KeyUpdateError
|
||||
AEADLimitReached = qerr.AEADLimitReached
|
||||
NoViablePathError = qerr.NoViablePathError
|
||||
)
|
||||
|
||||
// A StreamError is used for Stream.CancelRead and Stream.CancelWrite.
|
||||
// It is also returned from Stream.Read and Stream.Write if the peer canceled reading or writing.
|
||||
type StreamError struct {
|
||||
StreamID StreamID
|
||||
ErrorCode StreamErrorCode
|
||||
Remote bool
|
||||
}
|
||||
|
||||
func (e *StreamError) Is(target error) bool {
|
||||
_, ok := target.(*StreamError)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (e *StreamError) Error() string {
|
||||
pers := "local"
|
||||
if e.Remote {
|
||||
pers = "remote"
|
||||
}
|
||||
return fmt.Sprintf("stream %d canceled by %s with error code %d", e.StreamID, pers, e.ErrorCode)
|
||||
}
|
||||
237
vendor/github.com/quic-go/quic-go/frame_sorter.go
generated
vendored
Normal file
237
vendor/github.com/quic-go/quic-go/frame_sorter.go
generated
vendored
Normal file
@@ -0,0 +1,237 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
list "github.com/quic-go/quic-go/internal/utils/linkedlist"
|
||||
)
|
||||
|
||||
// byteInterval is an interval from one ByteCount to the other
|
||||
type byteInterval struct {
|
||||
Start protocol.ByteCount
|
||||
End protocol.ByteCount
|
||||
}
|
||||
|
||||
var byteIntervalElementPool sync.Pool
|
||||
|
||||
func init() {
|
||||
byteIntervalElementPool = *list.NewPool[byteInterval]()
|
||||
}
|
||||
|
||||
type frameSorterEntry struct {
|
||||
Data []byte
|
||||
DoneCb func()
|
||||
}
|
||||
|
||||
type frameSorter struct {
|
||||
queue map[protocol.ByteCount]frameSorterEntry
|
||||
readPos protocol.ByteCount
|
||||
gaps *list.List[byteInterval]
|
||||
}
|
||||
|
||||
var errDuplicateStreamData = errors.New("duplicate stream data")
|
||||
|
||||
func newFrameSorter() *frameSorter {
|
||||
s := frameSorter{
|
||||
gaps: list.NewWithPool[byteInterval](&byteIntervalElementPool),
|
||||
queue: make(map[protocol.ByteCount]frameSorterEntry),
|
||||
}
|
||||
s.gaps.PushFront(byteInterval{Start: 0, End: protocol.MaxByteCount})
|
||||
return &s
|
||||
}
|
||||
|
||||
func (s *frameSorter) Push(data []byte, offset protocol.ByteCount, doneCb func()) error {
|
||||
err := s.push(data, offset, doneCb)
|
||||
if err == errDuplicateStreamData {
|
||||
if doneCb != nil {
|
||||
doneCb()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *frameSorter) push(data []byte, offset protocol.ByteCount, doneCb func()) error {
|
||||
if len(data) == 0 {
|
||||
return errDuplicateStreamData
|
||||
}
|
||||
|
||||
start := offset
|
||||
end := offset + protocol.ByteCount(len(data))
|
||||
|
||||
if end <= s.gaps.Front().Value.Start {
|
||||
return errDuplicateStreamData
|
||||
}
|
||||
|
||||
startGap, startsInGap := s.findStartGap(start)
|
||||
endGap, endsInGap := s.findEndGap(startGap, end)
|
||||
|
||||
startGapEqualsEndGap := startGap == endGap
|
||||
|
||||
if (startGapEqualsEndGap && end <= startGap.Value.Start) ||
|
||||
(!startGapEqualsEndGap && startGap.Value.End >= endGap.Value.Start && end <= startGap.Value.Start) {
|
||||
return errDuplicateStreamData
|
||||
}
|
||||
|
||||
startGapNext := startGap.Next()
|
||||
startGapEnd := startGap.Value.End // save it, in case startGap is modified
|
||||
endGapStart := endGap.Value.Start // save it, in case endGap is modified
|
||||
endGapEnd := endGap.Value.End // save it, in case endGap is modified
|
||||
var adjustedStartGapEnd bool
|
||||
var wasCut bool
|
||||
|
||||
pos := start
|
||||
var hasReplacedAtLeastOne bool
|
||||
for {
|
||||
oldEntry, ok := s.queue[pos]
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
oldEntryLen := protocol.ByteCount(len(oldEntry.Data))
|
||||
if end-pos > oldEntryLen || (hasReplacedAtLeastOne && end-pos == oldEntryLen) {
|
||||
// The existing frame is shorter than the new frame. Replace it.
|
||||
delete(s.queue, pos)
|
||||
pos += oldEntryLen
|
||||
hasReplacedAtLeastOne = true
|
||||
if oldEntry.DoneCb != nil {
|
||||
oldEntry.DoneCb()
|
||||
}
|
||||
} else {
|
||||
if !hasReplacedAtLeastOne {
|
||||
return errDuplicateStreamData
|
||||
}
|
||||
// The existing frame is longer than the new frame.
|
||||
// Cut the new frame such that the end aligns with the start of the existing frame.
|
||||
data = data[:pos-start]
|
||||
end = pos
|
||||
wasCut = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !startsInGap && !hasReplacedAtLeastOne {
|
||||
// cut the frame, such that it starts at the start of the gap
|
||||
data = data[startGap.Value.Start-start:]
|
||||
start = startGap.Value.Start
|
||||
wasCut = true
|
||||
}
|
||||
if start <= startGap.Value.Start {
|
||||
if end >= startGap.Value.End {
|
||||
// The frame covers the whole startGap. Delete the gap.
|
||||
s.gaps.Remove(startGap)
|
||||
} else {
|
||||
startGap.Value.Start = end
|
||||
}
|
||||
} else if !hasReplacedAtLeastOne {
|
||||
startGap.Value.End = start
|
||||
adjustedStartGapEnd = true
|
||||
}
|
||||
|
||||
if !startGapEqualsEndGap {
|
||||
s.deleteConsecutive(startGapEnd)
|
||||
var nextGap *list.Element[byteInterval]
|
||||
for gap := startGapNext; gap.Value.End < endGapStart; gap = nextGap {
|
||||
nextGap = gap.Next()
|
||||
s.deleteConsecutive(gap.Value.End)
|
||||
s.gaps.Remove(gap)
|
||||
}
|
||||
}
|
||||
|
||||
if !endsInGap && start != endGapEnd && end > endGapEnd {
|
||||
// cut the frame, such that it ends at the end of the gap
|
||||
data = data[:endGapEnd-start]
|
||||
end = endGapEnd
|
||||
wasCut = true
|
||||
}
|
||||
if end == endGapEnd {
|
||||
if !startGapEqualsEndGap {
|
||||
// The frame covers the whole endGap. Delete the gap.
|
||||
s.gaps.Remove(endGap)
|
||||
}
|
||||
} else {
|
||||
if startGapEqualsEndGap && adjustedStartGapEnd {
|
||||
// The frame split the existing gap into two.
|
||||
s.gaps.InsertAfter(byteInterval{Start: end, End: startGapEnd}, startGap)
|
||||
} else if !startGapEqualsEndGap {
|
||||
endGap.Value.Start = end
|
||||
}
|
||||
}
|
||||
|
||||
if wasCut && len(data) < protocol.MinStreamFrameBufferSize {
|
||||
newData := make([]byte, len(data))
|
||||
copy(newData, data)
|
||||
data = newData
|
||||
if doneCb != nil {
|
||||
doneCb()
|
||||
doneCb = nil
|
||||
}
|
||||
}
|
||||
|
||||
if s.gaps.Len() > protocol.MaxStreamFrameSorterGaps {
|
||||
return errors.New("too many gaps in received data")
|
||||
}
|
||||
|
||||
s.queue[start] = frameSorterEntry{Data: data, DoneCb: doneCb}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *frameSorter) findStartGap(offset protocol.ByteCount) (*list.Element[byteInterval], bool) {
|
||||
for gap := s.gaps.Front(); gap != nil; gap = gap.Next() {
|
||||
if offset >= gap.Value.Start && offset <= gap.Value.End {
|
||||
return gap, true
|
||||
}
|
||||
if offset < gap.Value.Start {
|
||||
return gap, false
|
||||
}
|
||||
}
|
||||
panic("no gap found")
|
||||
}
|
||||
|
||||
func (s *frameSorter) findEndGap(startGap *list.Element[byteInterval], offset protocol.ByteCount) (*list.Element[byteInterval], bool) {
|
||||
for gap := startGap; gap != nil; gap = gap.Next() {
|
||||
if offset >= gap.Value.Start && offset < gap.Value.End {
|
||||
return gap, true
|
||||
}
|
||||
if offset < gap.Value.Start {
|
||||
return gap.Prev(), false
|
||||
}
|
||||
}
|
||||
panic("no gap found")
|
||||
}
|
||||
|
||||
// deleteConsecutive deletes consecutive frames from the queue, starting at pos
|
||||
func (s *frameSorter) deleteConsecutive(pos protocol.ByteCount) {
|
||||
for {
|
||||
oldEntry, ok := s.queue[pos]
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
oldEntryLen := protocol.ByteCount(len(oldEntry.Data))
|
||||
delete(s.queue, pos)
|
||||
if oldEntry.DoneCb != nil {
|
||||
oldEntry.DoneCb()
|
||||
}
|
||||
pos += oldEntryLen
|
||||
}
|
||||
}
|
||||
|
||||
func (s *frameSorter) Pop() (protocol.ByteCount, []byte, func()) {
|
||||
entry, ok := s.queue[s.readPos]
|
||||
if !ok {
|
||||
return s.readPos, nil, nil
|
||||
}
|
||||
delete(s.queue, s.readPos)
|
||||
offset := s.readPos
|
||||
s.readPos += protocol.ByteCount(len(entry.Data))
|
||||
if s.gaps.Front().Value.End <= s.readPos {
|
||||
panic("frame sorter BUG: read position higher than a gap")
|
||||
}
|
||||
return offset, entry.Data, entry.DoneCb
|
||||
}
|
||||
|
||||
// HasMoreData says if there is any more data queued at *any* offset.
|
||||
func (s *frameSorter) HasMoreData() bool {
|
||||
return len(s.queue) > 0
|
||||
}
|
||||
165
vendor/github.com/quic-go/quic-go/framer.go
generated
vendored
Normal file
165
vendor/github.com/quic-go/quic-go/framer.go
generated
vendored
Normal file
@@ -0,0 +1,165 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/ackhandler"
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/utils/ringbuffer"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
"github.com/quic-go/quic-go/quicvarint"
|
||||
)
|
||||
|
||||
type framer interface {
|
||||
HasData() bool
|
||||
|
||||
QueueControlFrame(wire.Frame)
|
||||
AppendControlFrames([]ackhandler.Frame, protocol.ByteCount, protocol.VersionNumber) ([]ackhandler.Frame, protocol.ByteCount)
|
||||
|
||||
AddActiveStream(protocol.StreamID)
|
||||
AppendStreamFrames([]ackhandler.StreamFrame, protocol.ByteCount, protocol.VersionNumber) ([]ackhandler.StreamFrame, protocol.ByteCount)
|
||||
|
||||
Handle0RTTRejection() error
|
||||
}
|
||||
|
||||
type framerI struct {
|
||||
mutex sync.Mutex
|
||||
|
||||
streamGetter streamGetter
|
||||
|
||||
activeStreams map[protocol.StreamID]struct{}
|
||||
streamQueue ringbuffer.RingBuffer[protocol.StreamID]
|
||||
|
||||
controlFrameMutex sync.Mutex
|
||||
controlFrames []wire.Frame
|
||||
}
|
||||
|
||||
var _ framer = &framerI{}
|
||||
|
||||
func newFramer(streamGetter streamGetter) framer {
|
||||
return &framerI{
|
||||
streamGetter: streamGetter,
|
||||
activeStreams: make(map[protocol.StreamID]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (f *framerI) HasData() bool {
|
||||
f.mutex.Lock()
|
||||
hasData := !f.streamQueue.Empty()
|
||||
f.mutex.Unlock()
|
||||
if hasData {
|
||||
return true
|
||||
}
|
||||
f.controlFrameMutex.Lock()
|
||||
hasData = len(f.controlFrames) > 0
|
||||
f.controlFrameMutex.Unlock()
|
||||
return hasData
|
||||
}
|
||||
|
||||
func (f *framerI) QueueControlFrame(frame wire.Frame) {
|
||||
f.controlFrameMutex.Lock()
|
||||
f.controlFrames = append(f.controlFrames, frame)
|
||||
f.controlFrameMutex.Unlock()
|
||||
}
|
||||
|
||||
func (f *framerI) AppendControlFrames(frames []ackhandler.Frame, maxLen protocol.ByteCount, v protocol.VersionNumber) ([]ackhandler.Frame, protocol.ByteCount) {
|
||||
var length protocol.ByteCount
|
||||
f.controlFrameMutex.Lock()
|
||||
for len(f.controlFrames) > 0 {
|
||||
frame := f.controlFrames[len(f.controlFrames)-1]
|
||||
frameLen := frame.Length(v)
|
||||
if length+frameLen > maxLen {
|
||||
break
|
||||
}
|
||||
frames = append(frames, ackhandler.Frame{Frame: frame})
|
||||
length += frameLen
|
||||
f.controlFrames = f.controlFrames[:len(f.controlFrames)-1]
|
||||
}
|
||||
f.controlFrameMutex.Unlock()
|
||||
return frames, length
|
||||
}
|
||||
|
||||
func (f *framerI) AddActiveStream(id protocol.StreamID) {
|
||||
f.mutex.Lock()
|
||||
if _, ok := f.activeStreams[id]; !ok {
|
||||
f.streamQueue.PushBack(id)
|
||||
f.activeStreams[id] = struct{}{}
|
||||
}
|
||||
f.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (f *framerI) AppendStreamFrames(frames []ackhandler.StreamFrame, maxLen protocol.ByteCount, v protocol.VersionNumber) ([]ackhandler.StreamFrame, protocol.ByteCount) {
|
||||
startLen := len(frames)
|
||||
var length protocol.ByteCount
|
||||
f.mutex.Lock()
|
||||
// pop STREAM frames, until less than MinStreamFrameSize bytes are left in the packet
|
||||
numActiveStreams := f.streamQueue.Len()
|
||||
for i := 0; i < numActiveStreams; i++ {
|
||||
if protocol.MinStreamFrameSize+length > maxLen {
|
||||
break
|
||||
}
|
||||
id := f.streamQueue.PopFront()
|
||||
// This should never return an error. Better check it anyway.
|
||||
// The stream will only be in the streamQueue, if it enqueued itself there.
|
||||
str, err := f.streamGetter.GetOrOpenSendStream(id)
|
||||
// The stream can be nil if it completed after it said it had data.
|
||||
if str == nil || err != nil {
|
||||
delete(f.activeStreams, id)
|
||||
continue
|
||||
}
|
||||
remainingLen := maxLen - length
|
||||
// For the last STREAM frame, we'll remove the DataLen field later.
|
||||
// Therefore, we can pretend to have more bytes available when popping
|
||||
// the STREAM frame (which will always have the DataLen set).
|
||||
remainingLen += quicvarint.Len(uint64(remainingLen))
|
||||
frame, ok, hasMoreData := str.popStreamFrame(remainingLen, v)
|
||||
if hasMoreData { // put the stream back in the queue (at the end)
|
||||
f.streamQueue.PushBack(id)
|
||||
} else { // no more data to send. Stream is not active
|
||||
delete(f.activeStreams, id)
|
||||
}
|
||||
// The frame can be "nil"
|
||||
// * if the receiveStream was canceled after it said it had data
|
||||
// * the remaining size doesn't allow us to add another STREAM frame
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
frames = append(frames, frame)
|
||||
length += frame.Frame.Length(v)
|
||||
}
|
||||
f.mutex.Unlock()
|
||||
if len(frames) > startLen {
|
||||
l := frames[len(frames)-1].Frame.Length(v)
|
||||
// account for the smaller size of the last STREAM frame
|
||||
frames[len(frames)-1].Frame.DataLenPresent = false
|
||||
length += frames[len(frames)-1].Frame.Length(v) - l
|
||||
}
|
||||
return frames, length
|
||||
}
|
||||
|
||||
func (f *framerI) Handle0RTTRejection() error {
|
||||
f.mutex.Lock()
|
||||
defer f.mutex.Unlock()
|
||||
|
||||
f.controlFrameMutex.Lock()
|
||||
f.streamQueue.Clear()
|
||||
for id := range f.activeStreams {
|
||||
delete(f.activeStreams, id)
|
||||
}
|
||||
var j int
|
||||
for i, frame := range f.controlFrames {
|
||||
switch frame.(type) {
|
||||
case *wire.MaxDataFrame, *wire.MaxStreamDataFrame, *wire.MaxStreamsFrame:
|
||||
return errors.New("didn't expect MAX_DATA / MAX_STREAM_DATA / MAX_STREAMS frame to be sent in 0-RTT")
|
||||
case *wire.DataBlockedFrame, *wire.StreamDataBlockedFrame, *wire.StreamsBlockedFrame:
|
||||
continue
|
||||
default:
|
||||
f.controlFrames[j] = f.controlFrames[i]
|
||||
j++
|
||||
}
|
||||
}
|
||||
f.controlFrames = f.controlFrames[:j]
|
||||
f.controlFrameMutex.Unlock()
|
||||
return nil
|
||||
}
|
||||
136
vendor/github.com/quic-go/quic-go/http3/body.go
generated
vendored
Normal file
136
vendor/github.com/quic-go/quic-go/http3/body.go
generated
vendored
Normal file
@@ -0,0 +1,136 @@
|
||||
package http3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
)
|
||||
|
||||
// The HTTPStreamer allows taking over a HTTP/3 stream. The interface is implemented by:
|
||||
// * for the server: the http.Request.Body
|
||||
// * for the client: the http.Response.Body
|
||||
// On the client side, the stream will be closed for writing, unless the DontCloseRequestStream RoundTripOpt was set.
|
||||
// When a stream is taken over, it's the caller's responsibility to close the stream.
|
||||
type HTTPStreamer interface {
|
||||
HTTPStream() Stream
|
||||
}
|
||||
|
||||
type StreamCreator interface {
|
||||
// Context returns a context that is cancelled when the underlying connection is closed.
|
||||
Context() context.Context
|
||||
OpenStream() (quic.Stream, error)
|
||||
OpenStreamSync(context.Context) (quic.Stream, error)
|
||||
OpenUniStream() (quic.SendStream, error)
|
||||
OpenUniStreamSync(context.Context) (quic.SendStream, error)
|
||||
LocalAddr() net.Addr
|
||||
RemoteAddr() net.Addr
|
||||
ConnectionState() quic.ConnectionState
|
||||
}
|
||||
|
||||
var _ StreamCreator = quic.Connection(nil)
|
||||
|
||||
// A Hijacker allows hijacking of the stream creating part of a quic.Session from a http.Response.Body.
|
||||
// It is used by WebTransport to create WebTransport streams after a session has been established.
|
||||
type Hijacker interface {
|
||||
StreamCreator() StreamCreator
|
||||
}
|
||||
|
||||
// The body of a http.Request or http.Response.
|
||||
type body struct {
|
||||
str quic.Stream
|
||||
|
||||
wasHijacked bool // set when HTTPStream is called
|
||||
}
|
||||
|
||||
var (
|
||||
_ io.ReadCloser = &body{}
|
||||
_ HTTPStreamer = &body{}
|
||||
)
|
||||
|
||||
func newRequestBody(str Stream) *body {
|
||||
return &body{str: str}
|
||||
}
|
||||
|
||||
func (r *body) HTTPStream() Stream {
|
||||
r.wasHijacked = true
|
||||
return r.str
|
||||
}
|
||||
|
||||
func (r *body) wasStreamHijacked() bool {
|
||||
return r.wasHijacked
|
||||
}
|
||||
|
||||
func (r *body) Read(b []byte) (int, error) {
|
||||
n, err := r.str.Read(b)
|
||||
return n, maybeReplaceError(err)
|
||||
}
|
||||
|
||||
func (r *body) Close() error {
|
||||
r.str.CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled))
|
||||
return nil
|
||||
}
|
||||
|
||||
type hijackableBody struct {
|
||||
body
|
||||
conn quic.Connection // only needed to implement Hijacker
|
||||
|
||||
// only set for the http.Response
|
||||
// The channel is closed when the user is done with this response:
|
||||
// either when Read() errors, or when Close() is called.
|
||||
reqDone chan<- struct{}
|
||||
reqDoneClosed bool
|
||||
}
|
||||
|
||||
var (
|
||||
_ Hijacker = &hijackableBody{}
|
||||
_ HTTPStreamer = &hijackableBody{}
|
||||
)
|
||||
|
||||
func newResponseBody(str Stream, conn quic.Connection, done chan<- struct{}) *hijackableBody {
|
||||
return &hijackableBody{
|
||||
body: body{
|
||||
str: str,
|
||||
},
|
||||
reqDone: done,
|
||||
conn: conn,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *hijackableBody) StreamCreator() StreamCreator {
|
||||
return r.conn
|
||||
}
|
||||
|
||||
func (r *hijackableBody) Read(b []byte) (int, error) {
|
||||
n, err := r.str.Read(b)
|
||||
if err != nil {
|
||||
r.requestDone()
|
||||
}
|
||||
return n, maybeReplaceError(err)
|
||||
}
|
||||
|
||||
func (r *hijackableBody) requestDone() {
|
||||
if r.reqDoneClosed || r.reqDone == nil {
|
||||
return
|
||||
}
|
||||
if r.reqDone != nil {
|
||||
close(r.reqDone)
|
||||
}
|
||||
r.reqDoneClosed = true
|
||||
}
|
||||
|
||||
func (r *body) StreamID() quic.StreamID {
|
||||
return r.str.StreamID()
|
||||
}
|
||||
|
||||
func (r *hijackableBody) Close() error {
|
||||
r.requestDone()
|
||||
// If the EOF was read, CancelRead() is a no-op.
|
||||
r.str.CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *hijackableBody) HTTPStream() Stream {
|
||||
return r.str
|
||||
}
|
||||
55
vendor/github.com/quic-go/quic-go/http3/capsule.go
generated
vendored
Normal file
55
vendor/github.com/quic-go/quic-go/http3/capsule.go
generated
vendored
Normal file
@@ -0,0 +1,55 @@
|
||||
package http3
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/quic-go/quic-go/quicvarint"
|
||||
)
|
||||
|
||||
// CapsuleType is the type of the capsule.
|
||||
type CapsuleType uint64
|
||||
|
||||
type exactReader struct {
|
||||
R *io.LimitedReader
|
||||
}
|
||||
|
||||
func (r *exactReader) Read(b []byte) (int, error) {
|
||||
n, err := r.R.Read(b)
|
||||
if r.R.N > 0 {
|
||||
return n, io.ErrUnexpectedEOF
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
// ParseCapsule parses the header of a Capsule.
|
||||
// It returns an io.LimitedReader that can be used to read the Capsule value.
|
||||
// The Capsule value must be read entirely (i.e. until the io.EOF) before using r again.
|
||||
func ParseCapsule(r quicvarint.Reader) (CapsuleType, io.Reader, error) {
|
||||
ct, err := quicvarint.Read(r)
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
return 0, nil, io.ErrUnexpectedEOF
|
||||
}
|
||||
return 0, nil, err
|
||||
}
|
||||
l, err := quicvarint.Read(r)
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
return 0, nil, io.ErrUnexpectedEOF
|
||||
}
|
||||
return 0, nil, err
|
||||
}
|
||||
return CapsuleType(ct), &exactReader{R: io.LimitReader(r, int64(l)).(*io.LimitedReader)}, nil
|
||||
}
|
||||
|
||||
// WriteCapsule writes a capsule
|
||||
func WriteCapsule(w quicvarint.Writer, ct CapsuleType, value []byte) error {
|
||||
b := make([]byte, 0, 16)
|
||||
b = quicvarint.Append(b, uint64(ct))
|
||||
b = quicvarint.Append(b, uint64(len(value)))
|
||||
if _, err := w.Write(b); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := w.Write(value)
|
||||
return err
|
||||
}
|
||||
477
vendor/github.com/quic-go/quic-go/http3/client.go
generated
vendored
Normal file
477
vendor/github.com/quic-go/quic-go/http3/client.go
generated
vendored
Normal file
@@ -0,0 +1,477 @@
|
||||
package http3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
"github.com/quic-go/quic-go/quicvarint"
|
||||
|
||||
"github.com/quic-go/qpack"
|
||||
)
|
||||
|
||||
// MethodGet0RTT allows a GET request to be sent using 0-RTT.
|
||||
// Note that 0-RTT data doesn't provide replay protection.
|
||||
const MethodGet0RTT = "GET_0RTT"
|
||||
|
||||
const (
|
||||
defaultUserAgent = "quic-go HTTP/3"
|
||||
defaultMaxResponseHeaderBytes = 10 * 1 << 20 // 10 MB
|
||||
)
|
||||
|
||||
var defaultQuicConfig = &quic.Config{
|
||||
MaxIncomingStreams: -1, // don't allow the server to create bidirectional streams
|
||||
KeepAlivePeriod: 10 * time.Second,
|
||||
}
|
||||
|
||||
type dialFunc func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error)
|
||||
|
||||
var dialAddr dialFunc = quic.DialAddrEarly
|
||||
|
||||
type roundTripperOpts struct {
|
||||
DisableCompression bool
|
||||
EnableDatagram bool
|
||||
MaxHeaderBytes int64
|
||||
AdditionalSettings map[uint64]uint64
|
||||
StreamHijacker func(FrameType, quic.Connection, quic.Stream, error) (hijacked bool, err error)
|
||||
UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream, error) (hijacked bool)
|
||||
}
|
||||
|
||||
// client is a HTTP3 client doing requests
|
||||
type client struct {
|
||||
tlsConf *tls.Config
|
||||
config *quic.Config
|
||||
opts *roundTripperOpts
|
||||
|
||||
dialOnce sync.Once
|
||||
dialer dialFunc
|
||||
handshakeErr error
|
||||
|
||||
requestWriter *requestWriter
|
||||
|
||||
decoder *qpack.Decoder
|
||||
|
||||
hostname string
|
||||
conn atomic.Pointer[quic.EarlyConnection]
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
var _ roundTripCloser = &client{}
|
||||
|
||||
func newClient(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) {
|
||||
if conf == nil {
|
||||
conf = defaultQuicConfig.Clone()
|
||||
}
|
||||
if len(conf.Versions) == 0 {
|
||||
conf = conf.Clone()
|
||||
conf.Versions = []quic.VersionNumber{protocol.SupportedVersions[0]}
|
||||
}
|
||||
if len(conf.Versions) != 1 {
|
||||
return nil, errors.New("can only use a single QUIC version for dialing a HTTP/3 connection")
|
||||
}
|
||||
if conf.MaxIncomingStreams == 0 {
|
||||
conf.MaxIncomingStreams = -1 // don't allow any bidirectional streams
|
||||
}
|
||||
conf.EnableDatagrams = opts.EnableDatagram
|
||||
logger := utils.DefaultLogger.WithPrefix("h3 client")
|
||||
|
||||
if tlsConf == nil {
|
||||
tlsConf = &tls.Config{}
|
||||
} else {
|
||||
tlsConf = tlsConf.Clone()
|
||||
}
|
||||
if tlsConf.ServerName == "" {
|
||||
sni, _, err := net.SplitHostPort(hostname)
|
||||
if err != nil {
|
||||
// It's ok if net.SplitHostPort returns an error - it could be a hostname/IP address without a port.
|
||||
sni = hostname
|
||||
}
|
||||
tlsConf.ServerName = sni
|
||||
}
|
||||
// Replace existing ALPNs by H3
|
||||
tlsConf.NextProtos = []string{versionToALPN(conf.Versions[0])}
|
||||
|
||||
return &client{
|
||||
hostname: authorityAddr("https", hostname),
|
||||
tlsConf: tlsConf,
|
||||
requestWriter: newRequestWriter(logger),
|
||||
decoder: qpack.NewDecoder(func(hf qpack.HeaderField) {}),
|
||||
config: conf,
|
||||
opts: opts,
|
||||
dialer: dialer,
|
||||
logger: logger,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *client) dial(ctx context.Context) error {
|
||||
var err error
|
||||
var conn quic.EarlyConnection
|
||||
if c.dialer != nil {
|
||||
conn, err = c.dialer(ctx, c.hostname, c.tlsConf, c.config)
|
||||
} else {
|
||||
conn, err = dialAddr(ctx, c.hostname, c.tlsConf, c.config)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.conn.Store(&conn)
|
||||
|
||||
// send the SETTINGs frame, using 0-RTT data, if possible
|
||||
go func() {
|
||||
if err := c.setupConn(conn); err != nil {
|
||||
c.logger.Debugf("Setting up connection failed: %s", err)
|
||||
conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeInternalError), "")
|
||||
}
|
||||
}()
|
||||
|
||||
if c.opts.StreamHijacker != nil {
|
||||
go c.handleBidirectionalStreams(conn)
|
||||
}
|
||||
go c.handleUnidirectionalStreams(conn)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) setupConn(conn quic.EarlyConnection) error {
|
||||
// open the control stream
|
||||
str, err := conn.OpenUniStream()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
b := make([]byte, 0, 64)
|
||||
b = quicvarint.Append(b, streamTypeControlStream)
|
||||
// send the SETTINGS frame
|
||||
b = (&settingsFrame{Datagram: c.opts.EnableDatagram, Other: c.opts.AdditionalSettings}).Append(b)
|
||||
_, err = str.Write(b)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *client) handleBidirectionalStreams(conn quic.EarlyConnection) {
|
||||
for {
|
||||
str, err := conn.AcceptStream(context.Background())
|
||||
if err != nil {
|
||||
c.logger.Debugf("accepting bidirectional stream failed: %s", err)
|
||||
return
|
||||
}
|
||||
go func(str quic.Stream) {
|
||||
_, err := parseNextFrame(str, func(ft FrameType, e error) (processed bool, err error) {
|
||||
return c.opts.StreamHijacker(ft, conn, str, e)
|
||||
})
|
||||
if err == errHijacked {
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
c.logger.Debugf("error handling stream: %s", err)
|
||||
}
|
||||
conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "received HTTP/3 frame on bidirectional stream")
|
||||
}(str)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *client) handleUnidirectionalStreams(conn quic.EarlyConnection) {
|
||||
for {
|
||||
str, err := conn.AcceptUniStream(context.Background())
|
||||
if err != nil {
|
||||
c.logger.Debugf("accepting unidirectional stream failed: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
go func(str quic.ReceiveStream) {
|
||||
streamType, err := quicvarint.Read(quicvarint.NewReader(str))
|
||||
if err != nil {
|
||||
if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), conn, str, err) {
|
||||
return
|
||||
}
|
||||
c.logger.Debugf("reading stream type on stream %d failed: %s", str.StreamID(), err)
|
||||
return
|
||||
}
|
||||
// We're only interested in the control stream here.
|
||||
switch streamType {
|
||||
case streamTypeControlStream:
|
||||
case streamTypeQPACKEncoderStream, streamTypeQPACKDecoderStream:
|
||||
// Our QPACK implementation doesn't use the dynamic table yet.
|
||||
// TODO: check that only one stream of each type is opened.
|
||||
return
|
||||
case streamTypePushStream:
|
||||
// We never increased the Push ID, so we don't expect any push streams.
|
||||
conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeIDError), "")
|
||||
return
|
||||
default:
|
||||
if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), conn, str, nil) {
|
||||
return
|
||||
}
|
||||
str.CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError))
|
||||
return
|
||||
}
|
||||
f, err := parseNextFrame(str, nil)
|
||||
if err != nil {
|
||||
conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameError), "")
|
||||
return
|
||||
}
|
||||
sf, ok := f.(*settingsFrame)
|
||||
if !ok {
|
||||
conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeMissingSettings), "")
|
||||
return
|
||||
}
|
||||
if !sf.Datagram {
|
||||
return
|
||||
}
|
||||
// If datagram support was enabled on our side as well as on the server side,
|
||||
// we can expect it to have been negotiated both on the transport and on the HTTP/3 layer.
|
||||
// Note: ConnectionState() will block until the handshake is complete (relevant when using 0-RTT).
|
||||
if c.opts.EnableDatagram && !conn.ConnectionState().SupportsDatagrams {
|
||||
conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeSettingsError), "missing QUIC Datagram support")
|
||||
}
|
||||
}(str)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *client) Close() error {
|
||||
conn := c.conn.Load()
|
||||
if conn == nil {
|
||||
return nil
|
||||
}
|
||||
return (*conn).CloseWithError(quic.ApplicationErrorCode(ErrCodeNoError), "")
|
||||
}
|
||||
|
||||
func (c *client) maxHeaderBytes() uint64 {
|
||||
if c.opts.MaxHeaderBytes <= 0 {
|
||||
return defaultMaxResponseHeaderBytes
|
||||
}
|
||||
return uint64(c.opts.MaxHeaderBytes)
|
||||
}
|
||||
|
||||
// RoundTripOpt executes a request and returns a response
|
||||
func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
|
||||
if authorityAddr("https", hostnameFromRequest(req)) != c.hostname {
|
||||
return nil, fmt.Errorf("http3 client BUG: RoundTripOpt called for the wrong client (expected %s, got %s)", c.hostname, req.Host)
|
||||
}
|
||||
|
||||
c.dialOnce.Do(func() {
|
||||
c.handshakeErr = c.dial(req.Context())
|
||||
})
|
||||
if c.handshakeErr != nil {
|
||||
return nil, c.handshakeErr
|
||||
}
|
||||
|
||||
// At this point, c.conn is guaranteed to be set.
|
||||
conn := *c.conn.Load()
|
||||
|
||||
// Immediately send out this request, if this is a 0-RTT request.
|
||||
if req.Method == MethodGet0RTT {
|
||||
req.Method = http.MethodGet
|
||||
} else {
|
||||
// wait for the handshake to complete
|
||||
select {
|
||||
case <-conn.HandshakeComplete():
|
||||
case <-req.Context().Done():
|
||||
return nil, req.Context().Err()
|
||||
}
|
||||
}
|
||||
|
||||
str, err := conn.OpenStreamSync(req.Context())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Request Cancellation:
|
||||
// This go routine keeps running even after RoundTripOpt() returns.
|
||||
// It is shut down when the application is done processing the body.
|
||||
reqDone := make(chan struct{})
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
select {
|
||||
case <-req.Context().Done():
|
||||
str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled))
|
||||
str.CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled))
|
||||
case <-reqDone:
|
||||
}
|
||||
}()
|
||||
|
||||
doneChan := reqDone
|
||||
if opt.DontCloseRequestStream {
|
||||
doneChan = nil
|
||||
}
|
||||
rsp, rerr := c.doRequest(req, conn, str, opt, doneChan)
|
||||
if rerr.err != nil { // if any error occurred
|
||||
close(reqDone)
|
||||
<-done
|
||||
if rerr.streamErr != 0 { // if it was a stream error
|
||||
str.CancelWrite(quic.StreamErrorCode(rerr.streamErr))
|
||||
}
|
||||
if rerr.connErr != 0 { // if it was a connection error
|
||||
var reason string
|
||||
if rerr.err != nil {
|
||||
reason = rerr.err.Error()
|
||||
}
|
||||
conn.CloseWithError(quic.ApplicationErrorCode(rerr.connErr), reason)
|
||||
}
|
||||
return nil, maybeReplaceError(rerr.err)
|
||||
}
|
||||
if opt.DontCloseRequestStream {
|
||||
close(reqDone)
|
||||
<-done
|
||||
}
|
||||
return rsp, maybeReplaceError(rerr.err)
|
||||
}
|
||||
|
||||
// cancelingReader reads from the io.Reader.
|
||||
// It cancels writing on the stream if any error other than io.EOF occurs.
|
||||
type cancelingReader struct {
|
||||
r io.Reader
|
||||
str Stream
|
||||
}
|
||||
|
||||
func (r *cancelingReader) Read(b []byte) (int, error) {
|
||||
n, err := r.r.Read(b)
|
||||
if err != nil && err != io.EOF {
|
||||
r.str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled))
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (c *client) sendRequestBody(str Stream, body io.ReadCloser, contentLength int64) error {
|
||||
defer body.Close()
|
||||
buf := make([]byte, bodyCopyBufferSize)
|
||||
sr := &cancelingReader{str: str, r: body}
|
||||
if contentLength == -1 {
|
||||
_, err := io.CopyBuffer(str, sr, buf)
|
||||
return err
|
||||
}
|
||||
|
||||
// make sure we don't send more bytes than the content length
|
||||
n, err := io.CopyBuffer(str, io.LimitReader(sr, contentLength), buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var extra int64
|
||||
extra, err = io.CopyBuffer(io.Discard, sr, buf)
|
||||
n += extra
|
||||
if n > contentLength {
|
||||
str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled))
|
||||
return fmt.Errorf("http: ContentLength=%d with Body length %d", contentLength, n)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *client) doRequest(req *http.Request, conn quic.EarlyConnection, str quic.Stream, opt RoundTripOpt, reqDone chan<- struct{}) (*http.Response, requestError) {
|
||||
var requestGzip bool
|
||||
if !c.opts.DisableCompression && req.Method != "HEAD" && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" {
|
||||
requestGzip = true
|
||||
}
|
||||
if err := c.requestWriter.WriteRequestHeader(str, req, requestGzip); err != nil {
|
||||
return nil, newStreamError(ErrCodeInternalError, err)
|
||||
}
|
||||
|
||||
if req.Body == nil && !opt.DontCloseRequestStream {
|
||||
str.Close()
|
||||
}
|
||||
|
||||
hstr := newStream(str, func() { conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "") })
|
||||
if req.Body != nil {
|
||||
// send the request body asynchronously
|
||||
go func() {
|
||||
contentLength := int64(-1)
|
||||
// According to the documentation for http.Request.ContentLength,
|
||||
// a value of 0 with a non-nil Body is also treated as unknown content length.
|
||||
if req.ContentLength > 0 {
|
||||
contentLength = req.ContentLength
|
||||
}
|
||||
if err := c.sendRequestBody(hstr, req.Body, contentLength); err != nil {
|
||||
c.logger.Errorf("Error writing request: %s", err)
|
||||
}
|
||||
if !opt.DontCloseRequestStream {
|
||||
hstr.Close()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
frame, err := parseNextFrame(str, nil)
|
||||
if err != nil {
|
||||
return nil, newStreamError(ErrCodeFrameError, err)
|
||||
}
|
||||
hf, ok := frame.(*headersFrame)
|
||||
if !ok {
|
||||
return nil, newConnError(ErrCodeFrameUnexpected, errors.New("expected first frame to be a HEADERS frame"))
|
||||
}
|
||||
if hf.Length > c.maxHeaderBytes() {
|
||||
return nil, newStreamError(ErrCodeFrameError, fmt.Errorf("HEADERS frame too large: %d bytes (max: %d)", hf.Length, c.maxHeaderBytes()))
|
||||
}
|
||||
headerBlock := make([]byte, hf.Length)
|
||||
if _, err := io.ReadFull(str, headerBlock); err != nil {
|
||||
return nil, newStreamError(ErrCodeRequestIncomplete, err)
|
||||
}
|
||||
hfs, err := c.decoder.DecodeFull(headerBlock)
|
||||
if err != nil {
|
||||
// TODO: use the right error code
|
||||
return nil, newConnError(ErrCodeGeneralProtocolError, err)
|
||||
}
|
||||
|
||||
res, err := responseFromHeaders(hfs)
|
||||
if err != nil {
|
||||
return nil, newStreamError(ErrCodeMessageError, err)
|
||||
}
|
||||
connState := conn.ConnectionState().TLS
|
||||
res.TLS = &connState
|
||||
res.Request = req
|
||||
// Check that the server doesn't send more data in DATA frames than indicated by the Content-Length header (if set).
|
||||
// See section 4.1.2 of RFC 9114.
|
||||
var httpStr Stream
|
||||
if _, ok := res.Header["Content-Length"]; ok && res.ContentLength >= 0 {
|
||||
httpStr = newLengthLimitedStream(hstr, res.ContentLength)
|
||||
} else {
|
||||
httpStr = hstr
|
||||
}
|
||||
respBody := newResponseBody(httpStr, conn, reqDone)
|
||||
|
||||
// Rules for when to set Content-Length are defined in https://tools.ietf.org/html/rfc7230#section-3.3.2.
|
||||
_, hasTransferEncoding := res.Header["Transfer-Encoding"]
|
||||
isInformational := res.StatusCode >= 100 && res.StatusCode < 200
|
||||
isNoContent := res.StatusCode == http.StatusNoContent
|
||||
isSuccessfulConnect := req.Method == http.MethodConnect && res.StatusCode >= 200 && res.StatusCode < 300
|
||||
if !hasTransferEncoding && !isInformational && !isNoContent && !isSuccessfulConnect {
|
||||
res.ContentLength = -1
|
||||
if clens, ok := res.Header["Content-Length"]; ok && len(clens) == 1 {
|
||||
if clen64, err := strconv.ParseInt(clens[0], 10, 64); err == nil {
|
||||
res.ContentLength = clen64
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if requestGzip && res.Header.Get("Content-Encoding") == "gzip" {
|
||||
res.Header.Del("Content-Encoding")
|
||||
res.Header.Del("Content-Length")
|
||||
res.ContentLength = -1
|
||||
res.Body = newGzipReader(respBody)
|
||||
res.Uncompressed = true
|
||||
} else {
|
||||
res.Body = respBody
|
||||
}
|
||||
|
||||
return res, requestError{}
|
||||
}
|
||||
|
||||
func (c *client) HandshakeComplete() bool {
|
||||
conn := c.conn.Load()
|
||||
if conn == nil {
|
||||
return false
|
||||
}
|
||||
select {
|
||||
case <-(*conn).HandshakeComplete():
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
58
vendor/github.com/quic-go/quic-go/http3/error.go
generated
vendored
Normal file
58
vendor/github.com/quic-go/quic-go/http3/error.go
generated
vendored
Normal file
@@ -0,0 +1,58 @@
|
||||
package http3
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
)
|
||||
|
||||
// Error is returned from the round tripper (for HTTP clients)
|
||||
// and inside the HTTP handler (for HTTP servers) if an HTTP/3 error occurs.
|
||||
// See section 8 of RFC 9114.
|
||||
type Error struct {
|
||||
Remote bool
|
||||
ErrorCode ErrCode
|
||||
ErrorMessage string
|
||||
}
|
||||
|
||||
var _ error = &Error{}
|
||||
|
||||
func (e *Error) Error() string {
|
||||
s := e.ErrorCode.string()
|
||||
if s == "" {
|
||||
s = fmt.Sprintf("H3 error (%#x)", uint64(e.ErrorCode))
|
||||
}
|
||||
// Usually errors are remote. Only make it explicit for local errors.
|
||||
if !e.Remote {
|
||||
s += " (local)"
|
||||
}
|
||||
if e.ErrorMessage != "" {
|
||||
s += ": " + e.ErrorMessage
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func maybeReplaceError(err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
e Error
|
||||
strErr *quic.StreamError
|
||||
appErr *quic.ApplicationError
|
||||
)
|
||||
switch {
|
||||
default:
|
||||
return err
|
||||
case errors.As(err, &strErr):
|
||||
e.Remote = strErr.Remote
|
||||
e.ErrorCode = ErrCode(strErr.ErrorCode)
|
||||
case errors.As(err, &appErr):
|
||||
e.Remote = appErr.Remote
|
||||
e.ErrorCode = ErrCode(appErr.ErrorCode)
|
||||
e.ErrorMessage = appErr.ErrorMessage
|
||||
}
|
||||
return &e
|
||||
}
|
||||
81
vendor/github.com/quic-go/quic-go/http3/error_codes.go
generated
vendored
Normal file
81
vendor/github.com/quic-go/quic-go/http3/error_codes.go
generated
vendored
Normal file
@@ -0,0 +1,81 @@
|
||||
package http3
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
)
|
||||
|
||||
type ErrCode quic.ApplicationErrorCode
|
||||
|
||||
const (
|
||||
ErrCodeNoError ErrCode = 0x100
|
||||
ErrCodeGeneralProtocolError ErrCode = 0x101
|
||||
ErrCodeInternalError ErrCode = 0x102
|
||||
ErrCodeStreamCreationError ErrCode = 0x103
|
||||
ErrCodeClosedCriticalStream ErrCode = 0x104
|
||||
ErrCodeFrameUnexpected ErrCode = 0x105
|
||||
ErrCodeFrameError ErrCode = 0x106
|
||||
ErrCodeExcessiveLoad ErrCode = 0x107
|
||||
ErrCodeIDError ErrCode = 0x108
|
||||
ErrCodeSettingsError ErrCode = 0x109
|
||||
ErrCodeMissingSettings ErrCode = 0x10a
|
||||
ErrCodeRequestRejected ErrCode = 0x10b
|
||||
ErrCodeRequestCanceled ErrCode = 0x10c
|
||||
ErrCodeRequestIncomplete ErrCode = 0x10d
|
||||
ErrCodeMessageError ErrCode = 0x10e
|
||||
ErrCodeConnectError ErrCode = 0x10f
|
||||
ErrCodeVersionFallback ErrCode = 0x110
|
||||
ErrCodeDatagramError ErrCode = 0x33
|
||||
)
|
||||
|
||||
func (e ErrCode) String() string {
|
||||
s := e.string()
|
||||
if s != "" {
|
||||
return s
|
||||
}
|
||||
return fmt.Sprintf("unknown error code: %#x", uint16(e))
|
||||
}
|
||||
|
||||
func (e ErrCode) string() string {
|
||||
switch e {
|
||||
case ErrCodeNoError:
|
||||
return "H3_NO_ERROR"
|
||||
case ErrCodeGeneralProtocolError:
|
||||
return "H3_GENERAL_PROTOCOL_ERROR"
|
||||
case ErrCodeInternalError:
|
||||
return "H3_INTERNAL_ERROR"
|
||||
case ErrCodeStreamCreationError:
|
||||
return "H3_STREAM_CREATION_ERROR"
|
||||
case ErrCodeClosedCriticalStream:
|
||||
return "H3_CLOSED_CRITICAL_STREAM"
|
||||
case ErrCodeFrameUnexpected:
|
||||
return "H3_FRAME_UNEXPECTED"
|
||||
case ErrCodeFrameError:
|
||||
return "H3_FRAME_ERROR"
|
||||
case ErrCodeExcessiveLoad:
|
||||
return "H3_EXCESSIVE_LOAD"
|
||||
case ErrCodeIDError:
|
||||
return "H3_ID_ERROR"
|
||||
case ErrCodeSettingsError:
|
||||
return "H3_SETTINGS_ERROR"
|
||||
case ErrCodeMissingSettings:
|
||||
return "H3_MISSING_SETTINGS"
|
||||
case ErrCodeRequestRejected:
|
||||
return "H3_REQUEST_REJECTED"
|
||||
case ErrCodeRequestCanceled:
|
||||
return "H3_REQUEST_CANCELLED"
|
||||
case ErrCodeRequestIncomplete:
|
||||
return "H3_INCOMPLETE_REQUEST"
|
||||
case ErrCodeMessageError:
|
||||
return "H3_MESSAGE_ERROR"
|
||||
case ErrCodeConnectError:
|
||||
return "H3_CONNECT_ERROR"
|
||||
case ErrCodeVersionFallback:
|
||||
return "H3_VERSION_FALLBACK"
|
||||
case ErrCodeDatagramError:
|
||||
return "H3_DATAGRAM_ERROR"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
164
vendor/github.com/quic-go/quic-go/http3/frames.go
generated
vendored
Normal file
164
vendor/github.com/quic-go/quic-go/http3/frames.go
generated
vendored
Normal file
@@ -0,0 +1,164 @@
|
||||
package http3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/quicvarint"
|
||||
)
|
||||
|
||||
// FrameType is the frame type of a HTTP/3 frame
|
||||
type FrameType uint64
|
||||
|
||||
type unknownFrameHandlerFunc func(FrameType, error) (processed bool, err error)
|
||||
|
||||
type frame interface{}
|
||||
|
||||
var errHijacked = errors.New("hijacked")
|
||||
|
||||
func parseNextFrame(r io.Reader, unknownFrameHandler unknownFrameHandlerFunc) (frame, error) {
|
||||
qr := quicvarint.NewReader(r)
|
||||
for {
|
||||
t, err := quicvarint.Read(qr)
|
||||
if err != nil {
|
||||
if unknownFrameHandler != nil {
|
||||
hijacked, err := unknownFrameHandler(0, err)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if hijacked {
|
||||
return nil, errHijacked
|
||||
}
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
// Call the unknownFrameHandler for frames not defined in the HTTP/3 spec
|
||||
if t > 0xd && unknownFrameHandler != nil {
|
||||
hijacked, err := unknownFrameHandler(FrameType(t), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if hijacked {
|
||||
return nil, errHijacked
|
||||
}
|
||||
// If the unknownFrameHandler didn't process the frame, it is our responsibility to skip it.
|
||||
}
|
||||
l, err := quicvarint.Read(qr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch t {
|
||||
case 0x0:
|
||||
return &dataFrame{Length: l}, nil
|
||||
case 0x1:
|
||||
return &headersFrame{Length: l}, nil
|
||||
case 0x4:
|
||||
return parseSettingsFrame(r, l)
|
||||
case 0x3: // CANCEL_PUSH
|
||||
case 0x5: // PUSH_PROMISE
|
||||
case 0x7: // GOAWAY
|
||||
case 0xd: // MAX_PUSH_ID
|
||||
}
|
||||
// skip over unknown frames
|
||||
if _, err := io.CopyN(io.Discard, qr, int64(l)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type dataFrame struct {
|
||||
Length uint64
|
||||
}
|
||||
|
||||
func (f *dataFrame) Append(b []byte) []byte {
|
||||
b = quicvarint.Append(b, 0x0)
|
||||
return quicvarint.Append(b, f.Length)
|
||||
}
|
||||
|
||||
type headersFrame struct {
|
||||
Length uint64
|
||||
}
|
||||
|
||||
func (f *headersFrame) Append(b []byte) []byte {
|
||||
b = quicvarint.Append(b, 0x1)
|
||||
return quicvarint.Append(b, f.Length)
|
||||
}
|
||||
|
||||
const settingDatagram = 0x33
|
||||
|
||||
type settingsFrame struct {
|
||||
Datagram bool
|
||||
Other map[uint64]uint64 // all settings that we don't explicitly recognize
|
||||
}
|
||||
|
||||
func parseSettingsFrame(r io.Reader, l uint64) (*settingsFrame, error) {
|
||||
if l > 8*(1<<10) {
|
||||
return nil, fmt.Errorf("unexpected size for SETTINGS frame: %d", l)
|
||||
}
|
||||
buf := make([]byte, l)
|
||||
if _, err := io.ReadFull(r, buf); err != nil {
|
||||
if err == io.ErrUnexpectedEOF {
|
||||
return nil, io.EOF
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
frame := &settingsFrame{}
|
||||
b := bytes.NewReader(buf)
|
||||
var readDatagram bool
|
||||
for b.Len() > 0 {
|
||||
id, err := quicvarint.Read(b)
|
||||
if err != nil { // should not happen. We allocated the whole frame already.
|
||||
return nil, err
|
||||
}
|
||||
val, err := quicvarint.Read(b)
|
||||
if err != nil { // should not happen. We allocated the whole frame already.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch id {
|
||||
case settingDatagram:
|
||||
if readDatagram {
|
||||
return nil, fmt.Errorf("duplicate setting: %d", id)
|
||||
}
|
||||
readDatagram = true
|
||||
if val != 0 && val != 1 {
|
||||
return nil, fmt.Errorf("invalid value for H3_DATAGRAM: %d", val)
|
||||
}
|
||||
frame.Datagram = val == 1
|
||||
default:
|
||||
if _, ok := frame.Other[id]; ok {
|
||||
return nil, fmt.Errorf("duplicate setting: %d", id)
|
||||
}
|
||||
if frame.Other == nil {
|
||||
frame.Other = make(map[uint64]uint64)
|
||||
}
|
||||
frame.Other[id] = val
|
||||
}
|
||||
}
|
||||
return frame, nil
|
||||
}
|
||||
|
||||
func (f *settingsFrame) Append(b []byte) []byte {
|
||||
b = quicvarint.Append(b, 0x4)
|
||||
var l protocol.ByteCount
|
||||
for id, val := range f.Other {
|
||||
l += quicvarint.Len(id) + quicvarint.Len(val)
|
||||
}
|
||||
if f.Datagram {
|
||||
l += quicvarint.Len(settingDatagram) + quicvarint.Len(1)
|
||||
}
|
||||
b = quicvarint.Append(b, uint64(l))
|
||||
if f.Datagram {
|
||||
b = quicvarint.Append(b, settingDatagram)
|
||||
b = quicvarint.Append(b, 1)
|
||||
}
|
||||
for id, val := range f.Other {
|
||||
b = quicvarint.Append(b, id)
|
||||
b = quicvarint.Append(b, val)
|
||||
}
|
||||
return b
|
||||
}
|
||||
39
vendor/github.com/quic-go/quic-go/http3/gzip_reader.go
generated
vendored
Normal file
39
vendor/github.com/quic-go/quic-go/http3/gzip_reader.go
generated
vendored
Normal file
@@ -0,0 +1,39 @@
|
||||
package http3
|
||||
|
||||
// copied from net/transport.go
|
||||
|
||||
// gzipReader wraps a response body so it can lazily
|
||||
// call gzip.NewReader on the first call to Read
|
||||
import (
|
||||
"compress/gzip"
|
||||
"io"
|
||||
)
|
||||
|
||||
// call gzip.NewReader on the first call to Read
|
||||
type gzipReader struct {
|
||||
body io.ReadCloser // underlying Response.Body
|
||||
zr *gzip.Reader // lazily-initialized gzip reader
|
||||
zerr error // sticky error
|
||||
}
|
||||
|
||||
func newGzipReader(body io.ReadCloser) io.ReadCloser {
|
||||
return &gzipReader{body: body}
|
||||
}
|
||||
|
||||
func (gz *gzipReader) Read(p []byte) (n int, err error) {
|
||||
if gz.zerr != nil {
|
||||
return 0, gz.zerr
|
||||
}
|
||||
if gz.zr == nil {
|
||||
gz.zr, err = gzip.NewReader(gz.body)
|
||||
if err != nil {
|
||||
gz.zerr = err
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
return gz.zr.Read(p)
|
||||
}
|
||||
|
||||
func (gz *gzipReader) Close() error {
|
||||
return gz.body.Close()
|
||||
}
|
||||
198
vendor/github.com/quic-go/quic-go/http3/headers.go
generated
vendored
Normal file
198
vendor/github.com/quic-go/quic-go/http3/headers.go
generated
vendored
Normal file
@@ -0,0 +1,198 @@
|
||||
package http3
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/net/http/httpguts"
|
||||
|
||||
"github.com/quic-go/qpack"
|
||||
)
|
||||
|
||||
type header struct {
|
||||
// Pseudo header fields defined in RFC 9114
|
||||
Path string
|
||||
Method string
|
||||
Authority string
|
||||
Scheme string
|
||||
Status string
|
||||
// for Extended connect
|
||||
Protocol string
|
||||
// parsed and deduplicated
|
||||
ContentLength int64
|
||||
// all non-pseudo headers
|
||||
Headers http.Header
|
||||
}
|
||||
|
||||
func parseHeaders(headers []qpack.HeaderField, isRequest bool) (header, error) {
|
||||
hdr := header{Headers: make(http.Header, len(headers))}
|
||||
var readFirstRegularHeader, readContentLength bool
|
||||
var contentLengthStr string
|
||||
for _, h := range headers {
|
||||
// field names need to be lowercase, see section 4.2 of RFC 9114
|
||||
if strings.ToLower(h.Name) != h.Name {
|
||||
return header{}, fmt.Errorf("header field is not lower-case: %s", h.Name)
|
||||
}
|
||||
if !httpguts.ValidHeaderFieldValue(h.Value) {
|
||||
return header{}, fmt.Errorf("invalid header field value for %s: %q", h.Name, h.Value)
|
||||
}
|
||||
if h.IsPseudo() {
|
||||
if readFirstRegularHeader {
|
||||
// all pseudo headers must appear before regular header fields, see section 4.3 of RFC 9114
|
||||
return header{}, fmt.Errorf("received pseudo header %s after a regular header field", h.Name)
|
||||
}
|
||||
var isResponsePseudoHeader bool // pseudo headers are either valid for requests or for responses
|
||||
switch h.Name {
|
||||
case ":path":
|
||||
hdr.Path = h.Value
|
||||
case ":method":
|
||||
hdr.Method = h.Value
|
||||
case ":authority":
|
||||
hdr.Authority = h.Value
|
||||
case ":protocol":
|
||||
hdr.Protocol = h.Value
|
||||
case ":scheme":
|
||||
hdr.Scheme = h.Value
|
||||
case ":status":
|
||||
hdr.Status = h.Value
|
||||
isResponsePseudoHeader = true
|
||||
default:
|
||||
return header{}, fmt.Errorf("unknown pseudo header: %s", h.Name)
|
||||
}
|
||||
if isRequest && isResponsePseudoHeader {
|
||||
return header{}, fmt.Errorf("invalid request pseudo header: %s", h.Name)
|
||||
}
|
||||
if !isRequest && !isResponsePseudoHeader {
|
||||
return header{}, fmt.Errorf("invalid response pseudo header: %s", h.Name)
|
||||
}
|
||||
} else {
|
||||
if !httpguts.ValidHeaderFieldName(h.Name) {
|
||||
return header{}, fmt.Errorf("invalid header field name: %q", h.Name)
|
||||
}
|
||||
readFirstRegularHeader = true
|
||||
switch h.Name {
|
||||
case "content-length":
|
||||
// Ignore duplicate Content-Length headers.
|
||||
// Fail if the duplicates differ.
|
||||
if !readContentLength {
|
||||
readContentLength = true
|
||||
contentLengthStr = h.Value
|
||||
} else if contentLengthStr != h.Value {
|
||||
return header{}, fmt.Errorf("contradicting content lengths (%s and %s)", contentLengthStr, h.Value)
|
||||
}
|
||||
default:
|
||||
hdr.Headers.Add(h.Name, h.Value)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(contentLengthStr) > 0 {
|
||||
// use ParseUint instead of ParseInt, so that parsing fails on negative values
|
||||
cl, err := strconv.ParseUint(contentLengthStr, 10, 63)
|
||||
if err != nil {
|
||||
return header{}, fmt.Errorf("invalid content length: %w", err)
|
||||
}
|
||||
hdr.Headers.Set("Content-Length", contentLengthStr)
|
||||
hdr.ContentLength = int64(cl)
|
||||
}
|
||||
return hdr, nil
|
||||
}
|
||||
|
||||
func requestFromHeaders(headerFields []qpack.HeaderField) (*http.Request, error) {
|
||||
hdr, err := parseHeaders(headerFields, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// concatenate cookie headers, see https://tools.ietf.org/html/rfc6265#section-5.4
|
||||
if len(hdr.Headers["Cookie"]) > 0 {
|
||||
hdr.Headers.Set("Cookie", strings.Join(hdr.Headers["Cookie"], "; "))
|
||||
}
|
||||
|
||||
isConnect := hdr.Method == http.MethodConnect
|
||||
// Extended CONNECT, see https://datatracker.ietf.org/doc/html/rfc8441#section-4
|
||||
isExtendedConnected := isConnect && hdr.Protocol != ""
|
||||
if isExtendedConnected {
|
||||
if hdr.Scheme == "" || hdr.Path == "" || hdr.Authority == "" {
|
||||
return nil, errors.New("extended CONNECT: :scheme, :path and :authority must not be empty")
|
||||
}
|
||||
} else if isConnect {
|
||||
if hdr.Path != "" || hdr.Authority == "" { // normal CONNECT
|
||||
return nil, errors.New(":path must be empty and :authority must not be empty")
|
||||
}
|
||||
} else if len(hdr.Path) == 0 || len(hdr.Authority) == 0 || len(hdr.Method) == 0 {
|
||||
return nil, errors.New(":path, :authority and :method must not be empty")
|
||||
}
|
||||
|
||||
var u *url.URL
|
||||
var requestURI string
|
||||
var protocol string
|
||||
|
||||
if isConnect {
|
||||
u = &url.URL{}
|
||||
if isExtendedConnected {
|
||||
u, err = url.ParseRequestURI(hdr.Path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
u.Path = hdr.Path
|
||||
}
|
||||
u.Scheme = hdr.Scheme
|
||||
u.Host = hdr.Authority
|
||||
requestURI = hdr.Authority
|
||||
protocol = hdr.Protocol
|
||||
} else {
|
||||
protocol = "HTTP/3.0"
|
||||
u, err = url.ParseRequestURI(hdr.Path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid content length: %w", err)
|
||||
}
|
||||
requestURI = hdr.Path
|
||||
}
|
||||
|
||||
return &http.Request{
|
||||
Method: hdr.Method,
|
||||
URL: u,
|
||||
Proto: protocol,
|
||||
ProtoMajor: 3,
|
||||
ProtoMinor: 0,
|
||||
Header: hdr.Headers,
|
||||
Body: nil,
|
||||
ContentLength: hdr.ContentLength,
|
||||
Host: hdr.Authority,
|
||||
RequestURI: requestURI,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func hostnameFromRequest(req *http.Request) string {
|
||||
if req.URL != nil {
|
||||
return req.URL.Host
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func responseFromHeaders(headerFields []qpack.HeaderField) (*http.Response, error) {
|
||||
hdr, err := parseHeaders(headerFields, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if hdr.Status == "" {
|
||||
return nil, errors.New("missing status field")
|
||||
}
|
||||
rsp := &http.Response{
|
||||
Proto: "HTTP/3.0",
|
||||
ProtoMajor: 3,
|
||||
Header: hdr.Headers,
|
||||
ContentLength: hdr.ContentLength,
|
||||
}
|
||||
status, err := strconv.Atoi(hdr.Status)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid status code: %w", err)
|
||||
}
|
||||
rsp.StatusCode = status
|
||||
rsp.Status = hdr.Status + " " + http.StatusText(status)
|
||||
return rsp, nil
|
||||
}
|
||||
124
vendor/github.com/quic-go/quic-go/http3/http_stream.go
generated
vendored
Normal file
124
vendor/github.com/quic-go/quic-go/http3/http_stream.go
generated
vendored
Normal file
@@ -0,0 +1,124 @@
|
||||
package http3
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
// A Stream is a HTTP/3 stream.
|
||||
// When writing to and reading from the stream, data is framed in HTTP/3 DATA frames.
|
||||
type Stream quic.Stream
|
||||
|
||||
// The stream conforms to the quic.Stream interface, but instead of writing to and reading directly
|
||||
// from the QUIC stream, it writes to and reads from the HTTP stream.
|
||||
type stream struct {
|
||||
quic.Stream
|
||||
|
||||
buf []byte
|
||||
|
||||
onFrameError func()
|
||||
bytesRemainingInFrame uint64
|
||||
}
|
||||
|
||||
var _ Stream = &stream{}
|
||||
|
||||
func newStream(str quic.Stream, onFrameError func()) *stream {
|
||||
return &stream{
|
||||
Stream: str,
|
||||
onFrameError: onFrameError,
|
||||
buf: make([]byte, 0, 16),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *stream) Read(b []byte) (int, error) {
|
||||
if s.bytesRemainingInFrame == 0 {
|
||||
parseLoop:
|
||||
for {
|
||||
frame, err := parseNextFrame(s.Stream, nil)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
switch f := frame.(type) {
|
||||
case *headersFrame:
|
||||
// skip HEADERS frames
|
||||
continue
|
||||
case *dataFrame:
|
||||
s.bytesRemainingInFrame = f.Length
|
||||
break parseLoop
|
||||
default:
|
||||
s.onFrameError()
|
||||
// parseNextFrame skips over unknown frame types
|
||||
// Therefore, this condition is only entered when we parsed another known frame type.
|
||||
return 0, fmt.Errorf("peer sent an unexpected frame: %T", f)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var n int
|
||||
var err error
|
||||
if s.bytesRemainingInFrame < uint64(len(b)) {
|
||||
n, err = s.Stream.Read(b[:s.bytesRemainingInFrame])
|
||||
} else {
|
||||
n, err = s.Stream.Read(b)
|
||||
}
|
||||
s.bytesRemainingInFrame -= uint64(n)
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (s *stream) hasMoreData() bool {
|
||||
return s.bytesRemainingInFrame > 0
|
||||
}
|
||||
|
||||
func (s *stream) Write(b []byte) (int, error) {
|
||||
s.buf = s.buf[:0]
|
||||
s.buf = (&dataFrame{Length: uint64(len(b))}).Append(s.buf)
|
||||
if _, err := s.Stream.Write(s.buf); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return s.Stream.Write(b)
|
||||
}
|
||||
|
||||
var errTooMuchData = errors.New("peer sent too much data")
|
||||
|
||||
type lengthLimitedStream struct {
|
||||
*stream
|
||||
contentLength int64
|
||||
read int64
|
||||
resetStream bool
|
||||
}
|
||||
|
||||
var _ Stream = &lengthLimitedStream{}
|
||||
|
||||
func newLengthLimitedStream(str *stream, contentLength int64) *lengthLimitedStream {
|
||||
return &lengthLimitedStream{
|
||||
stream: str,
|
||||
contentLength: contentLength,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *lengthLimitedStream) checkContentLengthViolation() error {
|
||||
if s.read > s.contentLength || s.read == s.contentLength && s.hasMoreData() {
|
||||
if !s.resetStream {
|
||||
s.CancelRead(quic.StreamErrorCode(ErrCodeMessageError))
|
||||
s.CancelWrite(quic.StreamErrorCode(ErrCodeMessageError))
|
||||
s.resetStream = true
|
||||
}
|
||||
return errTooMuchData
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *lengthLimitedStream) Read(b []byte) (int, error) {
|
||||
if err := s.checkContentLengthViolation(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
n, err := s.stream.Read(b[:utils.Min(int64(len(b)), s.contentLength-s.read)])
|
||||
s.read += int64(n)
|
||||
if err := s.checkContentLengthViolation(); err != nil {
|
||||
return n, err
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
8
vendor/github.com/quic-go/quic-go/http3/mockgen.go
generated
vendored
Normal file
8
vendor/github.com/quic-go/quic-go/http3/mockgen.go
generated
vendored
Normal file
@@ -0,0 +1,8 @@
|
||||
//go:build gomock || generate
|
||||
|
||||
package http3
|
||||
|
||||
//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package http3 -destination mock_roundtripcloser_test.go github.com/quic-go/quic-go/http3 RoundTripCloser"
|
||||
type RoundTripCloser = roundTripCloser
|
||||
|
||||
//go:generate sh -c "go run go.uber.org/mock/mockgen -package http3 -destination mock_quic_early_listener_test.go github.com/quic-go/quic-go/http3 QUICEarlyListener"
|
||||
287
vendor/github.com/quic-go/quic-go/http3/request_writer.go
generated
vendored
Normal file
287
vendor/github.com/quic-go/quic-go/http3/request_writer.go
generated
vendored
Normal file
@@ -0,0 +1,287 @@
|
||||
package http3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/net/http/httpguts"
|
||||
"golang.org/x/net/http2/hpack"
|
||||
"golang.org/x/net/idna"
|
||||
|
||||
"github.com/quic-go/qpack"
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
const bodyCopyBufferSize = 8 * 1024
|
||||
|
||||
type requestWriter struct {
|
||||
mutex sync.Mutex
|
||||
encoder *qpack.Encoder
|
||||
headerBuf *bytes.Buffer
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
func newRequestWriter(logger utils.Logger) *requestWriter {
|
||||
headerBuf := &bytes.Buffer{}
|
||||
encoder := qpack.NewEncoder(headerBuf)
|
||||
return &requestWriter{
|
||||
encoder: encoder,
|
||||
headerBuf: headerBuf,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *requestWriter) WriteRequestHeader(str quic.Stream, req *http.Request, gzip bool) error {
|
||||
// TODO: figure out how to add support for trailers
|
||||
buf := &bytes.Buffer{}
|
||||
if err := w.writeHeaders(buf, req, gzip); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := str.Write(buf.Bytes())
|
||||
return err
|
||||
}
|
||||
|
||||
func (w *requestWriter) writeHeaders(wr io.Writer, req *http.Request, gzip bool) error {
|
||||
w.mutex.Lock()
|
||||
defer w.mutex.Unlock()
|
||||
defer w.encoder.Close()
|
||||
defer w.headerBuf.Reset()
|
||||
|
||||
if err := w.encodeHeaders(req, gzip, "", actualContentLength(req)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
b := make([]byte, 0, 128)
|
||||
b = (&headersFrame{Length: uint64(w.headerBuf.Len())}).Append(b)
|
||||
if _, err := wr.Write(b); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := wr.Write(w.headerBuf.Bytes())
|
||||
return err
|
||||
}
|
||||
|
||||
// copied from net/transport.go
|
||||
// Modified to support Extended CONNECT:
|
||||
// Contrary to what the godoc for the http.Request says,
|
||||
// we do respect the Proto field if the method is CONNECT.
|
||||
func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, trailers string, contentLength int64) error {
|
||||
host := req.Host
|
||||
if host == "" {
|
||||
host = req.URL.Host
|
||||
}
|
||||
host, err := httpguts.PunycodeHostPort(host)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !httpguts.ValidHostHeader(host) {
|
||||
return errors.New("http3: invalid Host header")
|
||||
}
|
||||
|
||||
// http.NewRequest sets this field to HTTP/1.1
|
||||
isExtendedConnect := req.Method == http.MethodConnect && req.Proto != "" && req.Proto != "HTTP/1.1"
|
||||
|
||||
var path string
|
||||
if req.Method != http.MethodConnect || isExtendedConnect {
|
||||
path = req.URL.RequestURI()
|
||||
if !validPseudoPath(path) {
|
||||
orig := path
|
||||
path = strings.TrimPrefix(path, req.URL.Scheme+"://"+host)
|
||||
if !validPseudoPath(path) {
|
||||
if req.URL.Opaque != "" {
|
||||
return fmt.Errorf("invalid request :path %q from URL.Opaque = %q", orig, req.URL.Opaque)
|
||||
} else {
|
||||
return fmt.Errorf("invalid request :path %q", orig)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check for any invalid headers and return an error before we
|
||||
// potentially pollute our hpack state. (We want to be able to
|
||||
// continue to reuse the hpack encoder for future requests)
|
||||
for k, vv := range req.Header {
|
||||
if !httpguts.ValidHeaderFieldName(k) {
|
||||
return fmt.Errorf("invalid HTTP header name %q", k)
|
||||
}
|
||||
for _, v := range vv {
|
||||
if !httpguts.ValidHeaderFieldValue(v) {
|
||||
return fmt.Errorf("invalid HTTP header value %q for header %q", v, k)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
enumerateHeaders := func(f func(name, value string)) {
|
||||
// 8.1.2.3 Request Pseudo-Header Fields
|
||||
// The :path pseudo-header field includes the path and query parts of the
|
||||
// target URI (the path-absolute production and optionally a '?' character
|
||||
// followed by the query production (see Sections 3.3 and 3.4 of
|
||||
// [RFC3986]).
|
||||
f(":authority", host)
|
||||
f(":method", req.Method)
|
||||
if req.Method != http.MethodConnect || isExtendedConnect {
|
||||
f(":path", path)
|
||||
f(":scheme", req.URL.Scheme)
|
||||
}
|
||||
if isExtendedConnect {
|
||||
f(":protocol", req.Proto)
|
||||
}
|
||||
if trailers != "" {
|
||||
f("trailer", trailers)
|
||||
}
|
||||
|
||||
var didUA bool
|
||||
for k, vv := range req.Header {
|
||||
if strings.EqualFold(k, "host") || strings.EqualFold(k, "content-length") {
|
||||
// Host is :authority, already sent.
|
||||
// Content-Length is automatic, set below.
|
||||
continue
|
||||
} else if strings.EqualFold(k, "connection") || strings.EqualFold(k, "proxy-connection") ||
|
||||
strings.EqualFold(k, "transfer-encoding") || strings.EqualFold(k, "upgrade") ||
|
||||
strings.EqualFold(k, "keep-alive") {
|
||||
// Per 8.1.2.2 Connection-Specific Header
|
||||
// Fields, don't send connection-specific
|
||||
// fields. We have already checked if any
|
||||
// are error-worthy so just ignore the rest.
|
||||
continue
|
||||
} else if strings.EqualFold(k, "user-agent") {
|
||||
// Match Go's http1 behavior: at most one
|
||||
// User-Agent. If set to nil or empty string,
|
||||
// then omit it. Otherwise if not mentioned,
|
||||
// include the default (below).
|
||||
didUA = true
|
||||
if len(vv) < 1 {
|
||||
continue
|
||||
}
|
||||
vv = vv[:1]
|
||||
if vv[0] == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
for _, v := range vv {
|
||||
f(k, v)
|
||||
}
|
||||
}
|
||||
if shouldSendReqContentLength(req.Method, contentLength) {
|
||||
f("content-length", strconv.FormatInt(contentLength, 10))
|
||||
}
|
||||
if addGzipHeader {
|
||||
f("accept-encoding", "gzip")
|
||||
}
|
||||
if !didUA {
|
||||
f("user-agent", defaultUserAgent)
|
||||
}
|
||||
}
|
||||
|
||||
// Do a first pass over the headers counting bytes to ensure
|
||||
// we don't exceed cc.peerMaxHeaderListSize. This is done as a
|
||||
// separate pass before encoding the headers to prevent
|
||||
// modifying the hpack state.
|
||||
hlSize := uint64(0)
|
||||
enumerateHeaders(func(name, value string) {
|
||||
hf := hpack.HeaderField{Name: name, Value: value}
|
||||
hlSize += uint64(hf.Size())
|
||||
})
|
||||
|
||||
// TODO: check maximum header list size
|
||||
// if hlSize > cc.peerMaxHeaderListSize {
|
||||
// return errRequestHeaderListSize
|
||||
// }
|
||||
|
||||
// trace := httptrace.ContextClientTrace(req.Context())
|
||||
// traceHeaders := traceHasWroteHeaderField(trace)
|
||||
|
||||
// Header list size is ok. Write the headers.
|
||||
enumerateHeaders(func(name, value string) {
|
||||
name = strings.ToLower(name)
|
||||
w.encoder.WriteField(qpack.HeaderField{Name: name, Value: value})
|
||||
// if traceHeaders {
|
||||
// traceWroteHeaderField(trace, name, value)
|
||||
// }
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// authorityAddr returns a given authority (a host/IP, or host:port / ip:port)
|
||||
// and returns a host:port. The port 443 is added if needed.
|
||||
func authorityAddr(scheme string, authority string) (addr string) {
|
||||
host, port, err := net.SplitHostPort(authority)
|
||||
if err != nil { // authority didn't have a port
|
||||
port = "443"
|
||||
if scheme == "http" {
|
||||
port = "80"
|
||||
}
|
||||
host = authority
|
||||
}
|
||||
if a, err := idna.ToASCII(host); err == nil {
|
||||
host = a
|
||||
}
|
||||
// IPv6 address literal, without a port:
|
||||
if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
|
||||
return host + ":" + port
|
||||
}
|
||||
return net.JoinHostPort(host, port)
|
||||
}
|
||||
|
||||
// validPseudoPath reports whether v is a valid :path pseudo-header
|
||||
// value. It must be either:
|
||||
//
|
||||
// *) a non-empty string starting with '/'
|
||||
// *) the string '*', for OPTIONS requests.
|
||||
//
|
||||
// For now this is only used a quick check for deciding when to clean
|
||||
// up Opaque URLs before sending requests from the Transport.
|
||||
// See golang.org/issue/16847
|
||||
//
|
||||
// We used to enforce that the path also didn't start with "//", but
|
||||
// Google's GFE accepts such paths and Chrome sends them, so ignore
|
||||
// that part of the spec. See golang.org/issue/19103.
|
||||
func validPseudoPath(v string) bool {
|
||||
return (len(v) > 0 && v[0] == '/') || v == "*"
|
||||
}
|
||||
|
||||
// actualContentLength returns a sanitized version of
|
||||
// req.ContentLength, where 0 actually means zero (not unknown) and -1
|
||||
// means unknown.
|
||||
func actualContentLength(req *http.Request) int64 {
|
||||
if req.Body == nil {
|
||||
return 0
|
||||
}
|
||||
if req.ContentLength != 0 {
|
||||
return req.ContentLength
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// shouldSendReqContentLength reports whether the http2.Transport should send
|
||||
// a "content-length" request header. This logic is basically a copy of the net/http
|
||||
// transferWriter.shouldSendContentLength.
|
||||
// The contentLength is the corrected contentLength (so 0 means actually 0, not unknown).
|
||||
// -1 means unknown.
|
||||
func shouldSendReqContentLength(method string, contentLength int64) bool {
|
||||
if contentLength > 0 {
|
||||
return true
|
||||
}
|
||||
if contentLength < 0 {
|
||||
return false
|
||||
}
|
||||
// For zero bodies, whether we send a content-length depends on the method.
|
||||
// It also kinda doesn't matter for http2 either way, with END_STREAM.
|
||||
switch method {
|
||||
case "POST", "PUT", "PATCH":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
219
vendor/github.com/quic-go/quic-go/http3/response_writer.go
generated
vendored
Normal file
219
vendor/github.com/quic-go/quic-go/http3/response_writer.go
generated
vendored
Normal file
@@ -0,0 +1,219 @@
|
||||
package http3
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
|
||||
"github.com/quic-go/qpack"
|
||||
)
|
||||
|
||||
// The maximum length of an encoded HTTP/3 frame header is 16:
|
||||
// The frame has a type and length field, both QUIC varints (maximum 8 bytes in length)
|
||||
const frameHeaderLen = 16
|
||||
|
||||
// headerWriter wraps the stream, so that the first Write call flushes the header to the stream
|
||||
type headerWriter struct {
|
||||
str quic.Stream
|
||||
header http.Header
|
||||
status int // status code passed to WriteHeader
|
||||
written bool
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
// writeHeader encodes and flush header to the stream
|
||||
func (hw *headerWriter) writeHeader() error {
|
||||
var headers bytes.Buffer
|
||||
enc := qpack.NewEncoder(&headers)
|
||||
enc.WriteField(qpack.HeaderField{Name: ":status", Value: strconv.Itoa(hw.status)})
|
||||
|
||||
for k, v := range hw.header {
|
||||
for index := range v {
|
||||
enc.WriteField(qpack.HeaderField{Name: strings.ToLower(k), Value: v[index]})
|
||||
}
|
||||
}
|
||||
|
||||
buf := make([]byte, 0, frameHeaderLen+headers.Len())
|
||||
buf = (&headersFrame{Length: uint64(headers.Len())}).Append(buf)
|
||||
hw.logger.Infof("Responding with %d", hw.status)
|
||||
buf = append(buf, headers.Bytes()...)
|
||||
|
||||
_, err := hw.str.Write(buf)
|
||||
return err
|
||||
}
|
||||
|
||||
// first Write will trigger flushing header
|
||||
func (hw *headerWriter) Write(p []byte) (int, error) {
|
||||
if !hw.written {
|
||||
if err := hw.writeHeader(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
hw.written = true
|
||||
}
|
||||
return hw.str.Write(p)
|
||||
}
|
||||
|
||||
type responseWriter struct {
|
||||
*headerWriter
|
||||
conn quic.Connection
|
||||
bufferedStr *bufio.Writer
|
||||
buf []byte
|
||||
|
||||
headerWritten bool
|
||||
contentLen int64 // if handler set valid Content-Length header
|
||||
numWritten int64 // bytes written
|
||||
}
|
||||
|
||||
var (
|
||||
_ http.ResponseWriter = &responseWriter{}
|
||||
_ http.Flusher = &responseWriter{}
|
||||
_ Hijacker = &responseWriter{}
|
||||
)
|
||||
|
||||
func newResponseWriter(str quic.Stream, conn quic.Connection, logger utils.Logger) *responseWriter {
|
||||
hw := &headerWriter{
|
||||
str: str,
|
||||
header: http.Header{},
|
||||
logger: logger,
|
||||
}
|
||||
return &responseWriter{
|
||||
headerWriter: hw,
|
||||
buf: make([]byte, frameHeaderLen),
|
||||
conn: conn,
|
||||
bufferedStr: bufio.NewWriter(hw),
|
||||
}
|
||||
}
|
||||
|
||||
func (w *responseWriter) Header() http.Header {
|
||||
return w.header
|
||||
}
|
||||
|
||||
func (w *responseWriter) WriteHeader(status int) {
|
||||
if w.headerWritten {
|
||||
return
|
||||
}
|
||||
|
||||
// http status must be 3 digits
|
||||
if status < 100 || status > 999 {
|
||||
panic(fmt.Sprintf("invalid WriteHeader code %v", status))
|
||||
}
|
||||
|
||||
if status >= 200 {
|
||||
w.headerWritten = true
|
||||
// Add Date header.
|
||||
// This is what the standard library does.
|
||||
// Can be disabled by setting the Date header to nil.
|
||||
if _, ok := w.header["Date"]; !ok {
|
||||
w.header.Set("Date", time.Now().UTC().Format(http.TimeFormat))
|
||||
}
|
||||
// Content-Length checking
|
||||
// use ParseUint instead of ParseInt, as negative values are invalid
|
||||
if clen := w.header.Get("Content-Length"); clen != "" {
|
||||
if cl, err := strconv.ParseUint(clen, 10, 63); err == nil {
|
||||
w.contentLen = int64(cl)
|
||||
} else {
|
||||
// emit a warning for malformed Content-Length and remove it
|
||||
w.logger.Errorf("Malformed Content-Length %s", clen)
|
||||
w.header.Del("Content-Length")
|
||||
}
|
||||
}
|
||||
}
|
||||
w.status = status
|
||||
|
||||
if !w.headerWritten {
|
||||
w.writeHeader()
|
||||
}
|
||||
}
|
||||
|
||||
func (w *responseWriter) Write(p []byte) (int, error) {
|
||||
bodyAllowed := bodyAllowedForStatus(w.status)
|
||||
if !w.headerWritten {
|
||||
// If body is not allowed, we don't need to (and we can't) sniff the content type.
|
||||
if bodyAllowed {
|
||||
// If no content type, apply sniffing algorithm to body.
|
||||
// We can't use `w.header.Get` here since if the Content-Type was set to nil, we shoundn't do sniffing.
|
||||
_, haveType := w.header["Content-Type"]
|
||||
|
||||
// If the Transfer-Encoding or Content-Encoding was set and is non-blank,
|
||||
// we shouldn't sniff the body.
|
||||
hasTE := w.header.Get("Transfer-Encoding") != ""
|
||||
hasCE := w.header.Get("Content-Encoding") != ""
|
||||
if !hasCE && !haveType && !hasTE && len(p) > 0 {
|
||||
w.header.Set("Content-Type", http.DetectContentType(p))
|
||||
}
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
bodyAllowed = true
|
||||
}
|
||||
if !bodyAllowed {
|
||||
return 0, http.ErrBodyNotAllowed
|
||||
}
|
||||
|
||||
w.numWritten += int64(len(p))
|
||||
if w.contentLen != 0 && w.numWritten > w.contentLen {
|
||||
return 0, http.ErrContentLength
|
||||
}
|
||||
|
||||
df := &dataFrame{Length: uint64(len(p))}
|
||||
w.buf = w.buf[:0]
|
||||
w.buf = df.Append(w.buf)
|
||||
if _, err := w.bufferedStr.Write(w.buf); err != nil {
|
||||
return 0, maybeReplaceError(err)
|
||||
}
|
||||
n, err := w.bufferedStr.Write(p)
|
||||
return n, maybeReplaceError(err)
|
||||
}
|
||||
|
||||
func (w *responseWriter) FlushError() error {
|
||||
if !w.headerWritten {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
if !w.written {
|
||||
if err := w.writeHeader(); err != nil {
|
||||
return maybeReplaceError(err)
|
||||
}
|
||||
w.written = true
|
||||
}
|
||||
return w.bufferedStr.Flush()
|
||||
}
|
||||
|
||||
func (w *responseWriter) Flush() {
|
||||
if err := w.FlushError(); err != nil {
|
||||
w.logger.Errorf("could not flush to stream: %s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func (w *responseWriter) StreamCreator() StreamCreator {
|
||||
return w.conn
|
||||
}
|
||||
|
||||
func (w *responseWriter) SetReadDeadline(deadline time.Time) error {
|
||||
return w.str.SetReadDeadline(deadline)
|
||||
}
|
||||
|
||||
func (w *responseWriter) SetWriteDeadline(deadline time.Time) error {
|
||||
return w.str.SetWriteDeadline(deadline)
|
||||
}
|
||||
|
||||
// copied from http2/http2.go
|
||||
// bodyAllowedForStatus reports whether a given response status code
|
||||
// permits a body. See RFC 2616, section 4.4.
|
||||
func bodyAllowedForStatus(status int) bool {
|
||||
switch {
|
||||
case status >= 100 && status <= 199:
|
||||
return false
|
||||
case status == http.StatusNoContent:
|
||||
return false
|
||||
case status == http.StatusNotModified:
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
301
vendor/github.com/quic-go/quic-go/http3/roundtrip.go
generated
vendored
Normal file
301
vendor/github.com/quic-go/quic-go/http3/roundtrip.go
generated
vendored
Normal file
@@ -0,0 +1,301 @@
|
||||
package http3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"golang.org/x/net/http/httpguts"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
)
|
||||
|
||||
type roundTripCloser interface {
|
||||
RoundTripOpt(*http.Request, RoundTripOpt) (*http.Response, error)
|
||||
HandshakeComplete() bool
|
||||
io.Closer
|
||||
}
|
||||
|
||||
type roundTripCloserWithCount struct {
|
||||
roundTripCloser
|
||||
useCount atomic.Int64
|
||||
}
|
||||
|
||||
// RoundTripper implements the http.RoundTripper interface
|
||||
type RoundTripper struct {
|
||||
mutex sync.Mutex
|
||||
|
||||
// DisableCompression, if true, prevents the Transport from
|
||||
// requesting compression with an "Accept-Encoding: gzip"
|
||||
// request header when the Request contains no existing
|
||||
// Accept-Encoding value. If the Transport requests gzip on
|
||||
// its own and gets a gzipped response, it's transparently
|
||||
// decoded in the Response.Body. However, if the user
|
||||
// explicitly requested gzip it is not automatically
|
||||
// uncompressed.
|
||||
DisableCompression bool
|
||||
|
||||
// TLSClientConfig specifies the TLS configuration to use with
|
||||
// tls.Client. If nil, the default configuration is used.
|
||||
TLSClientConfig *tls.Config
|
||||
|
||||
// QuicConfig is the quic.Config used for dialing new connections.
|
||||
// If nil, reasonable default values will be used.
|
||||
QuicConfig *quic.Config
|
||||
|
||||
// Enable support for HTTP/3 datagrams.
|
||||
// If set to true, QuicConfig.EnableDatagram will be set.
|
||||
// See https://datatracker.ietf.org/doc/html/rfc9297.
|
||||
EnableDatagrams bool
|
||||
|
||||
// Additional HTTP/3 settings.
|
||||
// It is invalid to specify any settings defined by the HTTP/3 draft and the datagram draft.
|
||||
AdditionalSettings map[uint64]uint64
|
||||
|
||||
// When set, this callback is called for the first unknown frame parsed on a bidirectional stream.
|
||||
// It is called right after parsing the frame type.
|
||||
// If parsing the frame type fails, the error is passed to the callback.
|
||||
// In that case, the frame type will not be set.
|
||||
// Callers can either ignore the frame and return control of the stream back to HTTP/3
|
||||
// (by returning hijacked false).
|
||||
// Alternatively, callers can take over the QUIC stream (by returning hijacked true).
|
||||
StreamHijacker func(FrameType, quic.Connection, quic.Stream, error) (hijacked bool, err error)
|
||||
|
||||
// When set, this callback is called for unknown unidirectional stream of unknown stream type.
|
||||
// If parsing the stream type fails, the error is passed to the callback.
|
||||
// In that case, the stream type will not be set.
|
||||
UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream, error) (hijacked bool)
|
||||
|
||||
// Dial specifies an optional dial function for creating QUIC
|
||||
// connections for requests.
|
||||
// If Dial is nil, a UDPConn will be created at the first request
|
||||
// and will be reused for subsequent connections to other servers.
|
||||
Dial func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error)
|
||||
|
||||
// MaxResponseHeaderBytes specifies a limit on how many response bytes are
|
||||
// allowed in the server's response header.
|
||||
// Zero means to use a default limit.
|
||||
MaxResponseHeaderBytes int64
|
||||
|
||||
newClient func(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) // so we can mock it in tests
|
||||
clients map[string]*roundTripCloserWithCount
|
||||
transport *quic.Transport
|
||||
}
|
||||
|
||||
// RoundTripOpt are options for the Transport.RoundTripOpt method.
|
||||
type RoundTripOpt struct {
|
||||
// OnlyCachedConn controls whether the RoundTripper may create a new QUIC connection.
|
||||
// If set true and no cached connection is available, RoundTripOpt will return ErrNoCachedConn.
|
||||
OnlyCachedConn bool
|
||||
// DontCloseRequestStream controls whether the request stream is closed after sending the request.
|
||||
// If set, context cancellations have no effect after the response headers are received.
|
||||
DontCloseRequestStream bool
|
||||
}
|
||||
|
||||
var (
|
||||
_ http.RoundTripper = &RoundTripper{}
|
||||
_ io.Closer = &RoundTripper{}
|
||||
)
|
||||
|
||||
// ErrNoCachedConn is returned when RoundTripper.OnlyCachedConn is set
|
||||
var ErrNoCachedConn = errors.New("http3: no cached connection was available")
|
||||
|
||||
// RoundTripOpt is like RoundTrip, but takes options.
|
||||
func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
|
||||
if req.URL == nil {
|
||||
closeRequestBody(req)
|
||||
return nil, errors.New("http3: nil Request.URL")
|
||||
}
|
||||
if req.URL.Scheme != "https" {
|
||||
closeRequestBody(req)
|
||||
return nil, fmt.Errorf("http3: unsupported protocol scheme: %s", req.URL.Scheme)
|
||||
}
|
||||
if req.URL.Host == "" {
|
||||
closeRequestBody(req)
|
||||
return nil, errors.New("http3: no Host in request URL")
|
||||
}
|
||||
if req.Header == nil {
|
||||
closeRequestBody(req)
|
||||
return nil, errors.New("http3: nil Request.Header")
|
||||
}
|
||||
for k, vv := range req.Header {
|
||||
if !httpguts.ValidHeaderFieldName(k) {
|
||||
return nil, fmt.Errorf("http3: invalid http header field name %q", k)
|
||||
}
|
||||
for _, v := range vv {
|
||||
if !httpguts.ValidHeaderFieldValue(v) {
|
||||
return nil, fmt.Errorf("http3: invalid http header field value %q for key %v", v, k)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if req.Method != "" && !validMethod(req.Method) {
|
||||
closeRequestBody(req)
|
||||
return nil, fmt.Errorf("http3: invalid method %q", req.Method)
|
||||
}
|
||||
|
||||
hostname := authorityAddr("https", hostnameFromRequest(req))
|
||||
cl, isReused, err := r.getClient(hostname, opt.OnlyCachedConn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer cl.useCount.Add(-1)
|
||||
rsp, err := cl.RoundTripOpt(req, opt)
|
||||
if err != nil {
|
||||
r.removeClient(hostname)
|
||||
if isReused {
|
||||
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
|
||||
return r.RoundTripOpt(req, opt)
|
||||
}
|
||||
}
|
||||
}
|
||||
return rsp, err
|
||||
}
|
||||
|
||||
// RoundTrip does a round trip.
|
||||
func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return r.RoundTripOpt(req, RoundTripOpt{})
|
||||
}
|
||||
|
||||
func (r *RoundTripper) getClient(hostname string, onlyCached bool) (rtc *roundTripCloserWithCount, isReused bool, err error) {
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
|
||||
if r.clients == nil {
|
||||
r.clients = make(map[string]*roundTripCloserWithCount)
|
||||
}
|
||||
|
||||
client, ok := r.clients[hostname]
|
||||
if !ok {
|
||||
if onlyCached {
|
||||
return nil, false, ErrNoCachedConn
|
||||
}
|
||||
var err error
|
||||
newCl := newClient
|
||||
if r.newClient != nil {
|
||||
newCl = r.newClient
|
||||
}
|
||||
dial := r.Dial
|
||||
if dial == nil {
|
||||
if r.transport == nil {
|
||||
udpConn, err := net.ListenUDP("udp", nil)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
r.transport = &quic.Transport{Conn: udpConn}
|
||||
}
|
||||
dial = r.makeDialer()
|
||||
}
|
||||
c, err := newCl(
|
||||
hostname,
|
||||
r.TLSClientConfig,
|
||||
&roundTripperOpts{
|
||||
EnableDatagram: r.EnableDatagrams,
|
||||
DisableCompression: r.DisableCompression,
|
||||
MaxHeaderBytes: r.MaxResponseHeaderBytes,
|
||||
StreamHijacker: r.StreamHijacker,
|
||||
UniStreamHijacker: r.UniStreamHijacker,
|
||||
},
|
||||
r.QuicConfig,
|
||||
dial,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
client = &roundTripCloserWithCount{roundTripCloser: c}
|
||||
r.clients[hostname] = client
|
||||
} else if client.HandshakeComplete() {
|
||||
isReused = true
|
||||
}
|
||||
client.useCount.Add(1)
|
||||
return client, isReused, nil
|
||||
}
|
||||
|
||||
func (r *RoundTripper) removeClient(hostname string) {
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
if r.clients == nil {
|
||||
return
|
||||
}
|
||||
delete(r.clients, hostname)
|
||||
}
|
||||
|
||||
// Close closes the QUIC connections that this RoundTripper has used.
|
||||
// It also closes the underlying UDPConn if it is not nil.
|
||||
func (r *RoundTripper) Close() error {
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
for _, client := range r.clients {
|
||||
if err := client.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
r.clients = nil
|
||||
if r.transport != nil {
|
||||
if err := r.transport.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := r.transport.Conn.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
r.transport = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func closeRequestBody(req *http.Request) {
|
||||
if req.Body != nil {
|
||||
req.Body.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func validMethod(method string) bool {
|
||||
/*
|
||||
Method = "OPTIONS" ; Section 9.2
|
||||
| "GET" ; Section 9.3
|
||||
| "HEAD" ; Section 9.4
|
||||
| "POST" ; Section 9.5
|
||||
| "PUT" ; Section 9.6
|
||||
| "DELETE" ; Section 9.7
|
||||
| "TRACE" ; Section 9.8
|
||||
| "CONNECT" ; Section 9.9
|
||||
| extension-method
|
||||
extension-method = token
|
||||
token = 1*<any CHAR except CTLs or separators>
|
||||
*/
|
||||
return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1
|
||||
}
|
||||
|
||||
// copied from net/http/http.go
|
||||
func isNotToken(r rune) bool {
|
||||
return !httpguts.IsTokenRune(r)
|
||||
}
|
||||
|
||||
// makeDialer makes a QUIC dialer using r.udpConn.
|
||||
func (r *RoundTripper) makeDialer() func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
|
||||
return func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return r.transport.DialEarly(ctx, udpAddr, tlsCfg, cfg)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RoundTripper) CloseIdleConnections() {
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
for hostname, client := range r.clients {
|
||||
if client.useCount.Load() == 0 {
|
||||
client.Close()
|
||||
delete(r.clients, hostname)
|
||||
}
|
||||
}
|
||||
}
|
||||
767
vendor/github.com/quic-go/quic-go/http3/server.go
generated
vendored
Normal file
767
vendor/github.com/quic-go/quic-go/http3/server.go
generated
vendored
Normal file
@@ -0,0 +1,767 @@
|
||||
package http3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
"github.com/quic-go/quic-go/quicvarint"
|
||||
|
||||
"github.com/quic-go/qpack"
|
||||
)
|
||||
|
||||
// allows mocking of quic.Listen and quic.ListenAddr
|
||||
var (
|
||||
quicListen = func(conn net.PacketConn, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) {
|
||||
return quic.ListenEarly(conn, tlsConf, config)
|
||||
}
|
||||
quicListenAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) {
|
||||
return quic.ListenAddrEarly(addr, tlsConf, config)
|
||||
}
|
||||
)
|
||||
|
||||
// NextProtoH3 is the ALPN protocol negotiated during the TLS handshake, for QUIC v1 and v2.
|
||||
const NextProtoH3 = "h3"
|
||||
|
||||
// StreamType is the stream type of a unidirectional stream.
|
||||
type StreamType uint64
|
||||
|
||||
const (
|
||||
streamTypeControlStream = 0
|
||||
streamTypePushStream = 1
|
||||
streamTypeQPACKEncoderStream = 2
|
||||
streamTypeQPACKDecoderStream = 3
|
||||
)
|
||||
|
||||
// A QUICEarlyListener listens for incoming QUIC connections.
|
||||
type QUICEarlyListener interface {
|
||||
Accept(context.Context) (quic.EarlyConnection, error)
|
||||
Addr() net.Addr
|
||||
io.Closer
|
||||
}
|
||||
|
||||
var _ QUICEarlyListener = &quic.EarlyListener{}
|
||||
|
||||
func versionToALPN(v protocol.VersionNumber) string {
|
||||
//nolint:exhaustive // These are all the versions we care about.
|
||||
switch v {
|
||||
case protocol.Version1, protocol.Version2:
|
||||
return NextProtoH3
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// ConfigureTLSConfig creates a new tls.Config which can be used
|
||||
// to create a quic.Listener meant for serving http3. The created
|
||||
// tls.Config adds the functionality of detecting the used QUIC version
|
||||
// in order to set the correct ALPN value for the http3 connection.
|
||||
func ConfigureTLSConfig(tlsConf *tls.Config) *tls.Config {
|
||||
// The tls.Config used to setup the quic.Listener needs to have the GetConfigForClient callback set.
|
||||
// That way, we can get the QUIC version and set the correct ALPN value.
|
||||
return &tls.Config{
|
||||
GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||
// determine the ALPN from the QUIC version used
|
||||
proto := NextProtoH3
|
||||
val := ch.Context().Value(quic.QUICVersionContextKey)
|
||||
if v, ok := val.(quic.VersionNumber); ok {
|
||||
proto = versionToALPN(v)
|
||||
}
|
||||
config := tlsConf
|
||||
if tlsConf.GetConfigForClient != nil {
|
||||
getConfigForClient := tlsConf.GetConfigForClient
|
||||
var err error
|
||||
conf, err := getConfigForClient(ch)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if conf != nil {
|
||||
config = conf
|
||||
}
|
||||
}
|
||||
if config == nil {
|
||||
return nil, nil
|
||||
}
|
||||
config = config.Clone()
|
||||
config.NextProtos = []string{proto}
|
||||
return config, nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// contextKey is a value for use with context.WithValue. It's used as
|
||||
// a pointer so it fits in an interface{} without allocation.
|
||||
type contextKey struct {
|
||||
name string
|
||||
}
|
||||
|
||||
func (k *contextKey) String() string { return "quic-go/http3 context value " + k.name }
|
||||
|
||||
// ServerContextKey is a context key. It can be used in HTTP
|
||||
// handlers with Context.Value to access the server that
|
||||
// started the handler. The associated value will be of
|
||||
// type *http3.Server.
|
||||
var ServerContextKey = &contextKey{"http3-server"}
|
||||
|
||||
type requestError struct {
|
||||
err error
|
||||
streamErr ErrCode
|
||||
connErr ErrCode
|
||||
}
|
||||
|
||||
func newStreamError(code ErrCode, err error) requestError {
|
||||
return requestError{err: err, streamErr: code}
|
||||
}
|
||||
|
||||
func newConnError(code ErrCode, err error) requestError {
|
||||
return requestError{err: err, connErr: code}
|
||||
}
|
||||
|
||||
// listenerInfo contains info about specific listener added with addListener
|
||||
type listenerInfo struct {
|
||||
port int // 0 means that no info about port is available
|
||||
}
|
||||
|
||||
// Server is a HTTP/3 server.
|
||||
type Server struct {
|
||||
// Addr optionally specifies the UDP address for the server to listen on,
|
||||
// in the form "host:port".
|
||||
//
|
||||
// When used by ListenAndServe and ListenAndServeTLS methods, if empty,
|
||||
// ":https" (port 443) is used. See net.Dial for details of the address
|
||||
// format.
|
||||
//
|
||||
// Otherwise, if Port is not set and underlying QUIC listeners do not
|
||||
// have valid port numbers, the port part is used in Alt-Svc headers set
|
||||
// with SetQuicHeaders.
|
||||
Addr string
|
||||
|
||||
// Port is used in Alt-Svc response headers set with SetQuicHeaders. If
|
||||
// needed Port can be manually set when the Server is created.
|
||||
//
|
||||
// This is useful when a Layer 4 firewall is redirecting UDP traffic and
|
||||
// clients must use a port different from the port the Server is
|
||||
// listening on.
|
||||
Port int
|
||||
|
||||
// TLSConfig provides a TLS configuration for use by server. It must be
|
||||
// set for ListenAndServe and Serve methods.
|
||||
TLSConfig *tls.Config
|
||||
|
||||
// QuicConfig provides the parameters for QUIC connection created with
|
||||
// Serve. If nil, it uses reasonable default values.
|
||||
//
|
||||
// Configured versions are also used in Alt-Svc response header set with
|
||||
// SetQuicHeaders.
|
||||
QuicConfig *quic.Config
|
||||
|
||||
// Handler is the HTTP request handler to use. If not set, defaults to
|
||||
// http.NotFound.
|
||||
Handler http.Handler
|
||||
|
||||
// EnableDatagrams enables support for HTTP/3 datagrams.
|
||||
// If set to true, QuicConfig.EnableDatagram will be set.
|
||||
// See https://datatracker.ietf.org/doc/html/rfc9297.
|
||||
EnableDatagrams bool
|
||||
|
||||
// MaxHeaderBytes controls the maximum number of bytes the server will
|
||||
// read parsing the request HEADERS frame. It does not limit the size of
|
||||
// the request body. If zero or negative, http.DefaultMaxHeaderBytes is
|
||||
// used.
|
||||
MaxHeaderBytes int
|
||||
|
||||
// AdditionalSettings specifies additional HTTP/3 settings.
|
||||
// It is invalid to specify any settings defined by the HTTP/3 draft and the datagram draft.
|
||||
AdditionalSettings map[uint64]uint64
|
||||
|
||||
// StreamHijacker, when set, is called for the first unknown frame parsed on a bidirectional stream.
|
||||
// It is called right after parsing the frame type.
|
||||
// If parsing the frame type fails, the error is passed to the callback.
|
||||
// In that case, the frame type will not be set.
|
||||
// Callers can either ignore the frame and return control of the stream back to HTTP/3
|
||||
// (by returning hijacked false).
|
||||
// Alternatively, callers can take over the QUIC stream (by returning hijacked true).
|
||||
StreamHijacker func(FrameType, quic.Connection, quic.Stream, error) (hijacked bool, err error)
|
||||
|
||||
// UniStreamHijacker, when set, is called for unknown unidirectional stream of unknown stream type.
|
||||
// If parsing the stream type fails, the error is passed to the callback.
|
||||
// In that case, the stream type will not be set.
|
||||
UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream, error) (hijacked bool)
|
||||
|
||||
mutex sync.RWMutex
|
||||
listeners map[*QUICEarlyListener]listenerInfo
|
||||
|
||||
closed bool
|
||||
|
||||
altSvcHeader string
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
// ListenAndServe listens on the UDP address s.Addr and calls s.Handler to handle HTTP/3 requests on incoming connections.
|
||||
//
|
||||
// If s.Addr is blank, ":https" is used.
|
||||
func (s *Server) ListenAndServe() error {
|
||||
return s.serveConn(s.TLSConfig, nil)
|
||||
}
|
||||
|
||||
// ListenAndServeTLS listens on the UDP address s.Addr and calls s.Handler to handle HTTP/3 requests on incoming connections.
|
||||
//
|
||||
// If s.Addr is blank, ":https" is used.
|
||||
func (s *Server) ListenAndServeTLS(certFile, keyFile string) error {
|
||||
var err error
|
||||
certs := make([]tls.Certificate, 1)
|
||||
certs[0], err = tls.LoadX509KeyPair(certFile, keyFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// We currently only use the cert-related stuff from tls.Config,
|
||||
// so we don't need to make a full copy.
|
||||
config := &tls.Config{
|
||||
Certificates: certs,
|
||||
}
|
||||
return s.serveConn(config, nil)
|
||||
}
|
||||
|
||||
// Serve an existing UDP connection.
|
||||
// It is possible to reuse the same connection for outgoing connections.
|
||||
// Closing the server does not close the connection.
|
||||
func (s *Server) Serve(conn net.PacketConn) error {
|
||||
return s.serveConn(s.TLSConfig, conn)
|
||||
}
|
||||
|
||||
// ServeQUICConn serves a single QUIC connection.
|
||||
func (s *Server) ServeQUICConn(conn quic.Connection) error {
|
||||
s.mutex.Lock()
|
||||
if s.logger == nil {
|
||||
s.logger = utils.DefaultLogger.WithPrefix("server")
|
||||
}
|
||||
s.mutex.Unlock()
|
||||
|
||||
return s.handleConn(conn)
|
||||
}
|
||||
|
||||
// ServeListener serves an existing QUIC listener.
|
||||
// Make sure you use http3.ConfigureTLSConfig to configure a tls.Config
|
||||
// and use it to construct a http3-friendly QUIC listener.
|
||||
// Closing the server does close the listener.
|
||||
// ServeListener always returns a non-nil error. After Shutdown or Close, the returned error is http.ErrServerClosed.
|
||||
func (s *Server) ServeListener(ln QUICEarlyListener) error {
|
||||
if err := s.addListener(&ln); err != nil {
|
||||
return err
|
||||
}
|
||||
defer s.removeListener(&ln)
|
||||
for {
|
||||
conn, err := ln.Accept(context.Background())
|
||||
if err == quic.ErrServerClosed {
|
||||
return http.ErrServerClosed
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go func() {
|
||||
if err := s.handleConn(conn); err != nil {
|
||||
s.logger.Debugf(err.Error())
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
var errServerWithoutTLSConfig = errors.New("use of http3.Server without TLSConfig")
|
||||
|
||||
func (s *Server) serveConn(tlsConf *tls.Config, conn net.PacketConn) error {
|
||||
if tlsConf == nil {
|
||||
return errServerWithoutTLSConfig
|
||||
}
|
||||
|
||||
s.mutex.Lock()
|
||||
closed := s.closed
|
||||
s.mutex.Unlock()
|
||||
if closed {
|
||||
return http.ErrServerClosed
|
||||
}
|
||||
|
||||
baseConf := ConfigureTLSConfig(tlsConf)
|
||||
quicConf := s.QuicConfig
|
||||
if quicConf == nil {
|
||||
quicConf = &quic.Config{Allow0RTT: true}
|
||||
} else {
|
||||
quicConf = s.QuicConfig.Clone()
|
||||
}
|
||||
if s.EnableDatagrams {
|
||||
quicConf.EnableDatagrams = true
|
||||
}
|
||||
|
||||
var ln QUICEarlyListener
|
||||
var err error
|
||||
if conn == nil {
|
||||
addr := s.Addr
|
||||
if addr == "" {
|
||||
addr = ":https"
|
||||
}
|
||||
ln, err = quicListenAddr(addr, baseConf, quicConf)
|
||||
} else {
|
||||
ln, err = quicListen(conn, baseConf, quicConf)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.ServeListener(ln)
|
||||
}
|
||||
|
||||
func extractPort(addr string) (int, error) {
|
||||
_, portStr, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
portInt, err := net.LookupPort("tcp", portStr)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return portInt, nil
|
||||
}
|
||||
|
||||
func (s *Server) generateAltSvcHeader() {
|
||||
if len(s.listeners) == 0 {
|
||||
// Don't announce any ports since no one is listening for connections
|
||||
s.altSvcHeader = ""
|
||||
return
|
||||
}
|
||||
|
||||
// This code assumes that we will use protocol.SupportedVersions if no quic.Config is passed.
|
||||
supportedVersions := protocol.SupportedVersions
|
||||
if s.QuicConfig != nil && len(s.QuicConfig.Versions) > 0 {
|
||||
supportedVersions = s.QuicConfig.Versions
|
||||
}
|
||||
|
||||
// keep track of which have been seen so we don't yield duplicate values
|
||||
seen := make(map[string]struct{}, len(supportedVersions))
|
||||
var versionStrings []string
|
||||
for _, version := range supportedVersions {
|
||||
if v := versionToALPN(version); len(v) > 0 {
|
||||
if _, ok := seen[v]; !ok {
|
||||
versionStrings = append(versionStrings, v)
|
||||
seen[v] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var altSvc []string
|
||||
addPort := func(port int) {
|
||||
for _, v := range versionStrings {
|
||||
altSvc = append(altSvc, fmt.Sprintf(`%s=":%d"; ma=2592000`, v, port))
|
||||
}
|
||||
}
|
||||
|
||||
if s.Port != 0 {
|
||||
// if Port is specified, we must use it instead of the
|
||||
// listener addresses since there's a reason it's specified.
|
||||
addPort(s.Port)
|
||||
} else {
|
||||
// if we have some listeners assigned, try to find ports
|
||||
// which we can announce, otherwise nothing should be announced
|
||||
validPortsFound := false
|
||||
for _, info := range s.listeners {
|
||||
if info.port != 0 {
|
||||
addPort(info.port)
|
||||
validPortsFound = true
|
||||
}
|
||||
}
|
||||
if !validPortsFound {
|
||||
if port, err := extractPort(s.Addr); err == nil {
|
||||
addPort(port)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
s.altSvcHeader = strings.Join(altSvc, ",")
|
||||
}
|
||||
|
||||
// We store a pointer to interface in the map set. This is safe because we only
|
||||
// call trackListener via Serve and can track+defer untrack the same pointer to
|
||||
// local variable there. We never need to compare a Listener from another caller.
|
||||
func (s *Server) addListener(l *QUICEarlyListener) error {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.closed {
|
||||
return http.ErrServerClosed
|
||||
}
|
||||
if s.logger == nil {
|
||||
s.logger = utils.DefaultLogger.WithPrefix("server")
|
||||
}
|
||||
if s.listeners == nil {
|
||||
s.listeners = make(map[*QUICEarlyListener]listenerInfo)
|
||||
}
|
||||
|
||||
if port, err := extractPort((*l).Addr().String()); err == nil {
|
||||
s.listeners[l] = listenerInfo{port}
|
||||
} else {
|
||||
s.logger.Errorf("Unable to extract port from listener %+v, will not be announced using SetQuicHeaders: %s", err)
|
||||
s.listeners[l] = listenerInfo{}
|
||||
}
|
||||
s.generateAltSvcHeader()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) removeListener(l *QUICEarlyListener) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
delete(s.listeners, l)
|
||||
s.generateAltSvcHeader()
|
||||
}
|
||||
|
||||
func (s *Server) handleConn(conn quic.Connection) error {
|
||||
decoder := qpack.NewDecoder(nil)
|
||||
|
||||
// send a SETTINGS frame
|
||||
str, err := conn.OpenUniStream()
|
||||
if err != nil {
|
||||
return fmt.Errorf("opening the control stream failed: %w", err)
|
||||
}
|
||||
b := make([]byte, 0, 64)
|
||||
b = quicvarint.Append(b, streamTypeControlStream) // stream type
|
||||
b = (&settingsFrame{Datagram: s.EnableDatagrams, Other: s.AdditionalSettings}).Append(b)
|
||||
str.Write(b)
|
||||
|
||||
go s.handleUnidirectionalStreams(conn)
|
||||
|
||||
// Process all requests immediately.
|
||||
// It's the client's responsibility to decide which requests are eligible for 0-RTT.
|
||||
for {
|
||||
str, err := conn.AcceptStream(context.Background())
|
||||
if err != nil {
|
||||
var appErr *quic.ApplicationError
|
||||
if errors.As(err, &appErr) && appErr.ErrorCode == quic.ApplicationErrorCode(ErrCodeNoError) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("accepting stream failed: %w", err)
|
||||
}
|
||||
go func() {
|
||||
rerr := s.handleRequest(conn, str, decoder, func() {
|
||||
conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "")
|
||||
})
|
||||
if rerr.err == errHijacked {
|
||||
return
|
||||
}
|
||||
if rerr.err != nil || rerr.streamErr != 0 || rerr.connErr != 0 {
|
||||
s.logger.Debugf("Handling request failed: %s", err)
|
||||
if rerr.streamErr != 0 {
|
||||
str.CancelWrite(quic.StreamErrorCode(rerr.streamErr))
|
||||
}
|
||||
if rerr.connErr != 0 {
|
||||
var reason string
|
||||
if rerr.err != nil {
|
||||
reason = rerr.err.Error()
|
||||
}
|
||||
conn.CloseWithError(quic.ApplicationErrorCode(rerr.connErr), reason)
|
||||
}
|
||||
return
|
||||
}
|
||||
str.Close()
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleUnidirectionalStreams(conn quic.Connection) {
|
||||
for {
|
||||
str, err := conn.AcceptUniStream(context.Background())
|
||||
if err != nil {
|
||||
s.logger.Debugf("accepting unidirectional stream failed: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
go func(str quic.ReceiveStream) {
|
||||
streamType, err := quicvarint.Read(quicvarint.NewReader(str))
|
||||
if err != nil {
|
||||
if s.UniStreamHijacker != nil && s.UniStreamHijacker(StreamType(streamType), conn, str, err) {
|
||||
return
|
||||
}
|
||||
s.logger.Debugf("reading stream type on stream %d failed: %s", str.StreamID(), err)
|
||||
return
|
||||
}
|
||||
// We're only interested in the control stream here.
|
||||
switch streamType {
|
||||
case streamTypeControlStream:
|
||||
case streamTypeQPACKEncoderStream, streamTypeQPACKDecoderStream:
|
||||
// Our QPACK implementation doesn't use the dynamic table yet.
|
||||
// TODO: check that only one stream of each type is opened.
|
||||
return
|
||||
case streamTypePushStream: // only the server can push
|
||||
conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "")
|
||||
return
|
||||
default:
|
||||
if s.UniStreamHijacker != nil && s.UniStreamHijacker(StreamType(streamType), conn, str, nil) {
|
||||
return
|
||||
}
|
||||
str.CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError))
|
||||
return
|
||||
}
|
||||
f, err := parseNextFrame(str, nil)
|
||||
if err != nil {
|
||||
conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameError), "")
|
||||
return
|
||||
}
|
||||
sf, ok := f.(*settingsFrame)
|
||||
if !ok {
|
||||
conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeMissingSettings), "")
|
||||
return
|
||||
}
|
||||
if !sf.Datagram {
|
||||
return
|
||||
}
|
||||
// If datagram support was enabled on our side as well as on the client side,
|
||||
// we can expect it to have been negotiated both on the transport and on the HTTP/3 layer.
|
||||
// Note: ConnectionState() will block until the handshake is complete (relevant when using 0-RTT).
|
||||
if s.EnableDatagrams && !conn.ConnectionState().SupportsDatagrams {
|
||||
conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeSettingsError), "missing QUIC Datagram support")
|
||||
}
|
||||
}(str)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) maxHeaderBytes() uint64 {
|
||||
if s.MaxHeaderBytes <= 0 {
|
||||
return http.DefaultMaxHeaderBytes
|
||||
}
|
||||
return uint64(s.MaxHeaderBytes)
|
||||
}
|
||||
|
||||
func (s *Server) handleRequest(conn quic.Connection, str quic.Stream, decoder *qpack.Decoder, onFrameError func()) requestError {
|
||||
var ufh unknownFrameHandlerFunc
|
||||
if s.StreamHijacker != nil {
|
||||
ufh = func(ft FrameType, e error) (processed bool, err error) { return s.StreamHijacker(ft, conn, str, e) }
|
||||
}
|
||||
frame, err := parseNextFrame(str, ufh)
|
||||
if err != nil {
|
||||
if err == errHijacked {
|
||||
return requestError{err: errHijacked}
|
||||
}
|
||||
return newStreamError(ErrCodeRequestIncomplete, err)
|
||||
}
|
||||
hf, ok := frame.(*headersFrame)
|
||||
if !ok {
|
||||
return newConnError(ErrCodeFrameUnexpected, errors.New("expected first frame to be a HEADERS frame"))
|
||||
}
|
||||
if hf.Length > s.maxHeaderBytes() {
|
||||
return newStreamError(ErrCodeFrameError, fmt.Errorf("HEADERS frame too large: %d bytes (max: %d)", hf.Length, s.maxHeaderBytes()))
|
||||
}
|
||||
headerBlock := make([]byte, hf.Length)
|
||||
if _, err := io.ReadFull(str, headerBlock); err != nil {
|
||||
return newStreamError(ErrCodeRequestIncomplete, err)
|
||||
}
|
||||
hfs, err := decoder.DecodeFull(headerBlock)
|
||||
if err != nil {
|
||||
// TODO: use the right error code
|
||||
return newConnError(ErrCodeGeneralProtocolError, err)
|
||||
}
|
||||
req, err := requestFromHeaders(hfs)
|
||||
if err != nil {
|
||||
return newStreamError(ErrCodeMessageError, err)
|
||||
}
|
||||
|
||||
connState := conn.ConnectionState().TLS
|
||||
req.TLS = &connState
|
||||
req.RemoteAddr = conn.RemoteAddr().String()
|
||||
|
||||
// Check that the client doesn't send more data in DATA frames than indicated by the Content-Length header (if set).
|
||||
// See section 4.1.2 of RFC 9114.
|
||||
var httpStr Stream
|
||||
if _, ok := req.Header["Content-Length"]; ok && req.ContentLength >= 0 {
|
||||
httpStr = newLengthLimitedStream(newStream(str, onFrameError), req.ContentLength)
|
||||
} else {
|
||||
httpStr = newStream(str, onFrameError)
|
||||
}
|
||||
body := newRequestBody(httpStr)
|
||||
req.Body = body
|
||||
|
||||
if s.logger.Debug() {
|
||||
s.logger.Infof("%s %s%s, on stream %d", req.Method, req.Host, req.RequestURI, str.StreamID())
|
||||
} else {
|
||||
s.logger.Infof("%s %s%s", req.Method, req.Host, req.RequestURI)
|
||||
}
|
||||
|
||||
ctx := str.Context()
|
||||
ctx = context.WithValue(ctx, ServerContextKey, s)
|
||||
ctx = context.WithValue(ctx, http.LocalAddrContextKey, conn.LocalAddr())
|
||||
req = req.WithContext(ctx)
|
||||
r := newResponseWriter(str, conn, s.logger)
|
||||
handler := s.Handler
|
||||
if handler == nil {
|
||||
handler = http.DefaultServeMux
|
||||
}
|
||||
|
||||
var panicked bool
|
||||
func() {
|
||||
defer func() {
|
||||
if p := recover(); p != nil {
|
||||
panicked = true
|
||||
if p == http.ErrAbortHandler {
|
||||
return
|
||||
}
|
||||
// Copied from net/http/server.go
|
||||
const size = 64 << 10
|
||||
buf := make([]byte, size)
|
||||
buf = buf[:runtime.Stack(buf, false)]
|
||||
s.logger.Errorf("http: panic serving: %v\n%s", p, buf)
|
||||
}
|
||||
}()
|
||||
handler.ServeHTTP(r, req)
|
||||
}()
|
||||
|
||||
if body.wasStreamHijacked() {
|
||||
return requestError{err: errHijacked}
|
||||
}
|
||||
|
||||
// only write response when there is no panic
|
||||
if !panicked {
|
||||
// response not written to the client yet, set Content-Length
|
||||
if !r.written {
|
||||
if _, haveCL := r.header["Content-Length"]; !haveCL {
|
||||
r.header.Set("Content-Length", strconv.FormatInt(r.numWritten, 10))
|
||||
}
|
||||
}
|
||||
r.Flush()
|
||||
}
|
||||
// If the EOF was read by the handler, CancelRead() is a no-op.
|
||||
str.CancelRead(quic.StreamErrorCode(ErrCodeNoError))
|
||||
return requestError{}
|
||||
}
|
||||
|
||||
// Close the server immediately, aborting requests and sending CONNECTION_CLOSE frames to connected clients.
|
||||
// Close in combination with ListenAndServe() (instead of Serve()) may race if it is called before a UDP socket is established.
|
||||
func (s *Server) Close() error {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
s.closed = true
|
||||
|
||||
var err error
|
||||
for ln := range s.listeners {
|
||||
if cerr := (*ln).Close(); cerr != nil && err == nil {
|
||||
err = cerr
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// CloseGracefully shuts down the server gracefully. The server sends a GOAWAY frame first, then waits for either timeout to trigger, or for all running requests to complete.
|
||||
// CloseGracefully in combination with ListenAndServe() (instead of Serve()) may race if it is called before a UDP socket is established.
|
||||
func (s *Server) CloseGracefully(timeout time.Duration) error {
|
||||
// TODO: implement
|
||||
return nil
|
||||
}
|
||||
|
||||
// ErrNoAltSvcPort is the error returned by SetQuicHeaders when no port was found
|
||||
// for Alt-Svc to announce. This can happen if listening on a PacketConn without a port
|
||||
// (UNIX socket, for example) and no port is specified in Server.Port or Server.Addr.
|
||||
var ErrNoAltSvcPort = errors.New("no port can be announced, specify it explicitly using Server.Port or Server.Addr")
|
||||
|
||||
// SetQuicHeaders can be used to set the proper headers that announce that this server supports HTTP/3.
|
||||
// The values set by default advertise all of the ports the server is listening on, but can be
|
||||
// changed to a specific port by setting Server.Port before launching the serverr.
|
||||
// If no listener's Addr().String() returns an address with a valid port, Server.Addr will be used
|
||||
// to extract the port, if specified.
|
||||
// For example, a server launched using ListenAndServe on an address with port 443 would set:
|
||||
//
|
||||
// Alt-Svc: h3=":443"; ma=2592000,h3-29=":443"; ma=2592000
|
||||
func (s *Server) SetQuicHeaders(hdr http.Header) error {
|
||||
s.mutex.RLock()
|
||||
defer s.mutex.RUnlock()
|
||||
|
||||
if s.altSvcHeader == "" {
|
||||
return ErrNoAltSvcPort
|
||||
}
|
||||
// use the map directly to avoid constant canonicalization
|
||||
// since the key is already canonicalized
|
||||
hdr["Alt-Svc"] = append(hdr["Alt-Svc"], s.altSvcHeader)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListenAndServeQUIC listens on the UDP network address addr and calls the
|
||||
// handler for HTTP/3 requests on incoming connections. http.DefaultServeMux is
|
||||
// used when handler is nil.
|
||||
func ListenAndServeQUIC(addr, certFile, keyFile string, handler http.Handler) error {
|
||||
server := &Server{
|
||||
Addr: addr,
|
||||
Handler: handler,
|
||||
}
|
||||
return server.ListenAndServeTLS(certFile, keyFile)
|
||||
}
|
||||
|
||||
// ListenAndServe listens on the given network address for both, TLS and QUIC
|
||||
// connections in parallel. It returns if one of the two returns an error.
|
||||
// http.DefaultServeMux is used when handler is nil.
|
||||
// The correct Alt-Svc headers for QUIC are set.
|
||||
func ListenAndServe(addr, certFile, keyFile string, handler http.Handler) error {
|
||||
// Load certs
|
||||
var err error
|
||||
certs := make([]tls.Certificate, 1)
|
||||
certs[0], err = tls.LoadX509KeyPair(certFile, keyFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// We currently only use the cert-related stuff from tls.Config,
|
||||
// so we don't need to make a full copy.
|
||||
config := &tls.Config{
|
||||
Certificates: certs,
|
||||
}
|
||||
|
||||
if addr == "" {
|
||||
addr = ":https"
|
||||
}
|
||||
|
||||
// Open the listeners
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
udpConn, err := net.ListenUDP("udp", udpAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer udpConn.Close()
|
||||
|
||||
if handler == nil {
|
||||
handler = http.DefaultServeMux
|
||||
}
|
||||
// Start the servers
|
||||
quicServer := &Server{
|
||||
TLSConfig: config,
|
||||
Handler: handler,
|
||||
}
|
||||
|
||||
hErr := make(chan error)
|
||||
qErr := make(chan error)
|
||||
go func() {
|
||||
hErr <- http.ListenAndServeTLS(addr, certFile, keyFile, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
quicServer.SetQuicHeaders(w.Header())
|
||||
handler.ServeHTTP(w, r)
|
||||
}))
|
||||
}()
|
||||
go func() {
|
||||
qErr <- quicServer.Serve(udpConn)
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-hErr:
|
||||
quicServer.Close()
|
||||
return err
|
||||
case err := <-qErr:
|
||||
// Cannot close the HTTP server or wait for requests to complete properly :/
|
||||
return err
|
||||
}
|
||||
}
|
||||
349
vendor/github.com/quic-go/quic-go/interface.go
generated
vendored
Normal file
349
vendor/github.com/quic-go/quic-go/interface.go
generated
vendored
Normal file
@@ -0,0 +1,349 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/handshake"
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/logging"
|
||||
)
|
||||
|
||||
// The StreamID is the ID of a QUIC stream.
|
||||
type StreamID = protocol.StreamID
|
||||
|
||||
// A VersionNumber is a QUIC version number.
|
||||
type VersionNumber = protocol.VersionNumber
|
||||
|
||||
const (
|
||||
// Version1 is RFC 9000
|
||||
Version1 = protocol.Version1
|
||||
// Version2 is RFC 9369
|
||||
Version2 = protocol.Version2
|
||||
)
|
||||
|
||||
// A ClientToken is a token received by the client.
|
||||
// It can be used to skip address validation on future connection attempts.
|
||||
type ClientToken struct {
|
||||
data []byte
|
||||
}
|
||||
|
||||
type TokenStore interface {
|
||||
// Pop searches for a ClientToken associated with the given key.
|
||||
// Since tokens are not supposed to be reused, it must remove the token from the cache.
|
||||
// It returns nil when no token is found.
|
||||
Pop(key string) (token *ClientToken)
|
||||
|
||||
// Put adds a token to the cache with the given key. It might get called
|
||||
// multiple times in a connection.
|
||||
Put(key string, token *ClientToken)
|
||||
}
|
||||
|
||||
// Err0RTTRejected is the returned from:
|
||||
// * Open{Uni}Stream{Sync}
|
||||
// * Accept{Uni}Stream
|
||||
// * Stream.Read and Stream.Write
|
||||
// when the server rejects a 0-RTT connection attempt.
|
||||
var Err0RTTRejected = errors.New("0-RTT rejected")
|
||||
|
||||
// ConnectionTracingKey can be used to associate a ConnectionTracer with a Connection.
|
||||
// It is set on the Connection.Context() context,
|
||||
// as well as on the context passed to logging.Tracer.NewConnectionTracer.
|
||||
var ConnectionTracingKey = connTracingCtxKey{}
|
||||
|
||||
type connTracingCtxKey struct{}
|
||||
|
||||
// QUICVersionContextKey can be used to find out the QUIC version of a TLS handshake from the
|
||||
// context returned by tls.Config.ClientHelloInfo.Context.
|
||||
var QUICVersionContextKey = handshake.QUICVersionContextKey
|
||||
|
||||
// Stream is the interface implemented by QUIC streams
|
||||
// In addition to the errors listed on the Connection,
|
||||
// calls to stream functions can return a StreamError if the stream is canceled.
|
||||
type Stream interface {
|
||||
ReceiveStream
|
||||
SendStream
|
||||
// SetDeadline sets the read and write deadlines associated
|
||||
// with the connection. It is equivalent to calling both
|
||||
// SetReadDeadline and SetWriteDeadline.
|
||||
SetDeadline(t time.Time) error
|
||||
}
|
||||
|
||||
// A ReceiveStream is a unidirectional Receive Stream.
|
||||
type ReceiveStream interface {
|
||||
// StreamID returns the stream ID.
|
||||
StreamID() StreamID
|
||||
// Read reads data from the stream.
|
||||
// Read can be made to time out and return a net.Error with Timeout() == true
|
||||
// after a fixed time limit; see SetDeadline and SetReadDeadline.
|
||||
// If the stream was canceled by the peer, the error implements the StreamError
|
||||
// interface, and Canceled() == true.
|
||||
// If the connection was closed due to a timeout, the error satisfies
|
||||
// the net.Error interface, and Timeout() will be true.
|
||||
io.Reader
|
||||
// CancelRead aborts receiving on this stream.
|
||||
// It will ask the peer to stop transmitting stream data.
|
||||
// Read will unblock immediately, and future Read calls will fail.
|
||||
// When called multiple times or after reading the io.EOF it is a no-op.
|
||||
CancelRead(StreamErrorCode)
|
||||
// SetReadDeadline sets the deadline for future Read calls and
|
||||
// any currently-blocked Read call.
|
||||
// A zero value for t means Read will not time out.
|
||||
|
||||
SetReadDeadline(t time.Time) error
|
||||
}
|
||||
|
||||
// A SendStream is a unidirectional Send Stream.
|
||||
type SendStream interface {
|
||||
// StreamID returns the stream ID.
|
||||
StreamID() StreamID
|
||||
// Write writes data to the stream.
|
||||
// Write can be made to time out and return a net.Error with Timeout() == true
|
||||
// after a fixed time limit; see SetDeadline and SetWriteDeadline.
|
||||
// If the stream was canceled by the peer, the error implements the StreamError
|
||||
// interface, and Canceled() == true.
|
||||
// If the connection was closed due to a timeout, the error satisfies
|
||||
// the net.Error interface, and Timeout() will be true.
|
||||
io.Writer
|
||||
// Close closes the write-direction of the stream.
|
||||
// Future calls to Write are not permitted after calling Close.
|
||||
// It must not be called concurrently with Write.
|
||||
// It must not be called after calling CancelWrite.
|
||||
io.Closer
|
||||
// CancelWrite aborts sending on this stream.
|
||||
// Data already written, but not yet delivered to the peer is not guaranteed to be delivered reliably.
|
||||
// Write will unblock immediately, and future calls to Write will fail.
|
||||
// When called multiple times or after closing the stream it is a no-op.
|
||||
CancelWrite(StreamErrorCode)
|
||||
// The Context is canceled as soon as the write-side of the stream is closed.
|
||||
// This happens when Close() or CancelWrite() is called, or when the peer
|
||||
// cancels the read-side of their stream.
|
||||
// The cancellation cause is set to the error that caused the stream to
|
||||
// close, or `context.Canceled` in case the stream is closed without error.
|
||||
Context() context.Context
|
||||
// SetWriteDeadline sets the deadline for future Write calls
|
||||
// and any currently-blocked Write call.
|
||||
// Even if write times out, it may return n > 0, indicating that
|
||||
// some data was successfully written.
|
||||
// A zero value for t means Write will not time out.
|
||||
SetWriteDeadline(t time.Time) error
|
||||
}
|
||||
|
||||
// A Connection is a QUIC connection between two peers.
|
||||
// Calls to the connection (and to streams) can return the following types of errors:
|
||||
// * ApplicationError: for errors triggered by the application running on top of QUIC
|
||||
// * TransportError: for errors triggered by the QUIC transport (in many cases a misbehaving peer)
|
||||
// * IdleTimeoutError: when the peer goes away unexpectedly (this is a net.Error timeout error)
|
||||
// * HandshakeTimeoutError: when the cryptographic handshake takes too long (this is a net.Error timeout error)
|
||||
// * StatelessResetError: when we receive a stateless reset (this is a net.Error temporary error)
|
||||
// * VersionNegotiationError: returned by the client, when there's no version overlap between the peers
|
||||
type Connection interface {
|
||||
// AcceptStream returns the next stream opened by the peer, blocking until one is available.
|
||||
// If the connection was closed due to a timeout, the error satisfies
|
||||
// the net.Error interface, and Timeout() will be true.
|
||||
AcceptStream(context.Context) (Stream, error)
|
||||
// AcceptUniStream returns the next unidirectional stream opened by the peer, blocking until one is available.
|
||||
// If the connection was closed due to a timeout, the error satisfies
|
||||
// the net.Error interface, and Timeout() will be true.
|
||||
AcceptUniStream(context.Context) (ReceiveStream, error)
|
||||
// OpenStream opens a new bidirectional QUIC stream.
|
||||
// There is no signaling to the peer about new streams:
|
||||
// The peer can only accept the stream after data has been sent on the stream.
|
||||
// If the error is non-nil, it satisfies the net.Error interface.
|
||||
// When reaching the peer's stream limit, err.Temporary() will be true.
|
||||
// If the connection was closed due to a timeout, Timeout() will be true.
|
||||
OpenStream() (Stream, error)
|
||||
// OpenStreamSync opens a new bidirectional QUIC stream.
|
||||
// It blocks until a new stream can be opened.
|
||||
// If the error is non-nil, it satisfies the net.Error interface.
|
||||
// If the connection was closed due to a timeout, Timeout() will be true.
|
||||
OpenStreamSync(context.Context) (Stream, error)
|
||||
// OpenUniStream opens a new outgoing unidirectional QUIC stream.
|
||||
// If the error is non-nil, it satisfies the net.Error interface.
|
||||
// When reaching the peer's stream limit, Temporary() will be true.
|
||||
// If the connection was closed due to a timeout, Timeout() will be true.
|
||||
OpenUniStream() (SendStream, error)
|
||||
// OpenUniStreamSync opens a new outgoing unidirectional QUIC stream.
|
||||
// It blocks until a new stream can be opened.
|
||||
// If the error is non-nil, it satisfies the net.Error interface.
|
||||
// If the connection was closed due to a timeout, Timeout() will be true.
|
||||
OpenUniStreamSync(context.Context) (SendStream, error)
|
||||
// LocalAddr returns the local address.
|
||||
LocalAddr() net.Addr
|
||||
// RemoteAddr returns the address of the peer.
|
||||
RemoteAddr() net.Addr
|
||||
// CloseWithError closes the connection with an error.
|
||||
// The error string will be sent to the peer.
|
||||
CloseWithError(ApplicationErrorCode, string) error
|
||||
// Context returns a context that is cancelled when the connection is closed.
|
||||
// The cancellation cause is set to the error that caused the connection to
|
||||
// close, or `context.Canceled` in case the listener is closed first.
|
||||
Context() context.Context
|
||||
// ConnectionState returns basic details about the QUIC connection.
|
||||
// Warning: This API should not be considered stable and might change soon.
|
||||
ConnectionState() ConnectionState
|
||||
|
||||
// SendMessage sends a message as a datagram, as specified in RFC 9221.
|
||||
SendMessage([]byte) error
|
||||
// ReceiveMessage gets a message received in a datagram, as specified in RFC 9221.
|
||||
ReceiveMessage(context.Context) ([]byte, error)
|
||||
}
|
||||
|
||||
// An EarlyConnection is a connection that is handshaking.
|
||||
// Data sent during the handshake is encrypted using the forward secure keys.
|
||||
// When using client certificates, the client's identity is only verified
|
||||
// after completion of the handshake.
|
||||
type EarlyConnection interface {
|
||||
Connection
|
||||
|
||||
// HandshakeComplete blocks until the handshake completes (or fails).
|
||||
// For the client, data sent before completion of the handshake is encrypted with 0-RTT keys.
|
||||
// For the server, data sent before completion of the handshake is encrypted with 1-RTT keys,
|
||||
// however the client's identity is only verified once the handshake completes.
|
||||
HandshakeComplete() <-chan struct{}
|
||||
|
||||
NextConnection() Connection
|
||||
}
|
||||
|
||||
// StatelessResetKey is a key used to derive stateless reset tokens.
|
||||
type StatelessResetKey [32]byte
|
||||
|
||||
// TokenGeneratorKey is a key used to encrypt session resumption tokens.
|
||||
type TokenGeneratorKey = handshake.TokenProtectorKey
|
||||
|
||||
// A ConnectionID is a QUIC Connection ID, as defined in RFC 9000.
|
||||
// It is not able to handle QUIC Connection IDs longer than 20 bytes,
|
||||
// as they are allowed by RFC 8999.
|
||||
type ConnectionID = protocol.ConnectionID
|
||||
|
||||
// ConnectionIDFromBytes interprets b as a Connection ID. It panics if b is
|
||||
// longer than 20 bytes.
|
||||
func ConnectionIDFromBytes(b []byte) ConnectionID {
|
||||
return protocol.ParseConnectionID(b)
|
||||
}
|
||||
|
||||
// A ConnectionIDGenerator is an interface that allows clients to implement their own format
|
||||
// for the Connection IDs that servers/clients use as SrcConnectionID in QUIC packets.
|
||||
//
|
||||
// Connection IDs generated by an implementation should always produce IDs of constant size.
|
||||
type ConnectionIDGenerator interface {
|
||||
// GenerateConnectionID generates a new ConnectionID.
|
||||
// Generated ConnectionIDs should be unique and observers should not be able to correlate two ConnectionIDs.
|
||||
GenerateConnectionID() (ConnectionID, error)
|
||||
|
||||
// ConnectionIDLen tells what is the length of the ConnectionIDs generated by the implementation of
|
||||
// this interface.
|
||||
// Effectively, this means that implementations of ConnectionIDGenerator must always return constant-size
|
||||
// connection IDs. Valid lengths are between 0 and 20 and calls to GenerateConnectionID.
|
||||
// 0-length ConnectionsIDs can be used when an endpoint (server or client) does not require multiplexing connections
|
||||
// in the presence of a connection migration environment.
|
||||
ConnectionIDLen() int
|
||||
}
|
||||
|
||||
// Config contains all configuration data needed for a QUIC server or client.
|
||||
type Config struct {
|
||||
// GetConfigForClient is called for incoming connections.
|
||||
// If the error is not nil, the connection attempt is refused.
|
||||
GetConfigForClient func(info *ClientHelloInfo) (*Config, error)
|
||||
// The QUIC versions that can be negotiated.
|
||||
// If not set, it uses all versions available.
|
||||
Versions []VersionNumber
|
||||
// HandshakeIdleTimeout is the idle timeout before completion of the handshake.
|
||||
// If we don't receive any packet from the peer within this time, the connection attempt is aborted.
|
||||
// Additionally, if the handshake doesn't complete in twice this time, the connection attempt is also aborted.
|
||||
// If this value is zero, the timeout is set to 5 seconds.
|
||||
HandshakeIdleTimeout time.Duration
|
||||
// MaxIdleTimeout is the maximum duration that may pass without any incoming network activity.
|
||||
// The actual value for the idle timeout is the minimum of this value and the peer's.
|
||||
// This value only applies after the handshake has completed.
|
||||
// If the timeout is exceeded, the connection is closed.
|
||||
// If this value is zero, the timeout is set to 30 seconds.
|
||||
MaxIdleTimeout time.Duration
|
||||
// RequireAddressValidation determines if a QUIC Retry packet is sent.
|
||||
// This allows the server to verify the client's address, at the cost of increasing the handshake latency by 1 RTT.
|
||||
// See https://datatracker.ietf.org/doc/html/rfc9000#section-8 for details.
|
||||
// If not set, every client is forced to prove its remote address.
|
||||
RequireAddressValidation func(net.Addr) bool
|
||||
// The TokenStore stores tokens received from the server.
|
||||
// Tokens are used to skip address validation on future connection attempts.
|
||||
// The key used to store tokens is the ServerName from the tls.Config, if set
|
||||
// otherwise the token is associated with the server's IP address.
|
||||
TokenStore TokenStore
|
||||
// InitialStreamReceiveWindow is the initial size of the stream-level flow control window for receiving data.
|
||||
// If the application is consuming data quickly enough, the flow control auto-tuning algorithm
|
||||
// will increase the window up to MaxStreamReceiveWindow.
|
||||
// If this value is zero, it will default to 512 KB.
|
||||
// Values larger than the maximum varint (quicvarint.Max) will be clipped to that value.
|
||||
InitialStreamReceiveWindow uint64
|
||||
// MaxStreamReceiveWindow is the maximum stream-level flow control window for receiving data.
|
||||
// If this value is zero, it will default to 6 MB.
|
||||
// Values larger than the maximum varint (quicvarint.Max) will be clipped to that value.
|
||||
MaxStreamReceiveWindow uint64
|
||||
// InitialConnectionReceiveWindow is the initial size of the stream-level flow control window for receiving data.
|
||||
// If the application is consuming data quickly enough, the flow control auto-tuning algorithm
|
||||
// will increase the window up to MaxConnectionReceiveWindow.
|
||||
// If this value is zero, it will default to 512 KB.
|
||||
// Values larger than the maximum varint (quicvarint.Max) will be clipped to that value.
|
||||
InitialConnectionReceiveWindow uint64
|
||||
// MaxConnectionReceiveWindow is the connection-level flow control window for receiving data.
|
||||
// If this value is zero, it will default to 15 MB.
|
||||
// Values larger than the maximum varint (quicvarint.Max) will be clipped to that value.
|
||||
MaxConnectionReceiveWindow uint64
|
||||
// AllowConnectionWindowIncrease is called every time the connection flow controller attempts
|
||||
// to increase the connection flow control window.
|
||||
// If set, the caller can prevent an increase of the window. Typically, it would do so to
|
||||
// limit the memory usage.
|
||||
// To avoid deadlocks, it is not valid to call other functions on the connection or on streams
|
||||
// in this callback.
|
||||
AllowConnectionWindowIncrease func(conn Connection, delta uint64) bool
|
||||
// MaxIncomingStreams is the maximum number of concurrent bidirectional streams that a peer is allowed to open.
|
||||
// If not set, it will default to 100.
|
||||
// If set to a negative value, it doesn't allow any bidirectional streams.
|
||||
// Values larger than 2^60 will be clipped to that value.
|
||||
MaxIncomingStreams int64
|
||||
// MaxIncomingUniStreams is the maximum number of concurrent unidirectional streams that a peer is allowed to open.
|
||||
// If not set, it will default to 100.
|
||||
// If set to a negative value, it doesn't allow any unidirectional streams.
|
||||
// Values larger than 2^60 will be clipped to that value.
|
||||
MaxIncomingUniStreams int64
|
||||
// KeepAlivePeriod defines whether this peer will periodically send a packet to keep the connection alive.
|
||||
// If set to 0, then no keep alive is sent. Otherwise, the keep alive is sent on that period (or at most
|
||||
// every half of MaxIdleTimeout, whichever is smaller).
|
||||
KeepAlivePeriod time.Duration
|
||||
// DisablePathMTUDiscovery disables Path MTU Discovery (RFC 8899).
|
||||
// This allows the sending of QUIC packets that fully utilize the available MTU of the path.
|
||||
// Path MTU discovery is only available on systems that allow setting of the Don't Fragment (DF) bit.
|
||||
// If unavailable or disabled, packets will be at most 1252 (IPv4) / 1232 (IPv6) bytes in size.
|
||||
DisablePathMTUDiscovery bool
|
||||
// Allow0RTT allows the application to decide if a 0-RTT connection attempt should be accepted.
|
||||
// Only valid for the server.
|
||||
Allow0RTT bool
|
||||
// Enable QUIC datagram support (RFC 9221).
|
||||
EnableDatagrams bool
|
||||
Tracer func(context.Context, logging.Perspective, ConnectionID) *logging.ConnectionTracer
|
||||
}
|
||||
|
||||
type ClientHelloInfo struct {
|
||||
RemoteAddr net.Addr
|
||||
}
|
||||
|
||||
// ConnectionState records basic details about a QUIC connection
|
||||
type ConnectionState struct {
|
||||
// TLS contains information about the TLS connection state, incl. the tls.ConnectionState.
|
||||
TLS tls.ConnectionState
|
||||
// SupportsDatagrams says if support for QUIC datagrams (RFC 9221) was negotiated.
|
||||
// This requires both nodes to support and enable the datagram extensions (via Config.EnableDatagrams).
|
||||
// If datagram support was negotiated, datagrams can be sent and received using the
|
||||
// SendMessage and ReceiveMessage methods on the Connection.
|
||||
SupportsDatagrams bool
|
||||
// Used0RTT says if 0-RTT resumption was used.
|
||||
Used0RTT bool
|
||||
// Version is the QUIC version of the QUIC connection.
|
||||
Version VersionNumber
|
||||
// GSO says if generic segmentation offload is used
|
||||
GSO bool
|
||||
}
|
||||
20
vendor/github.com/quic-go/quic-go/internal/ackhandler/ack_eliciting.go
generated
vendored
Normal file
20
vendor/github.com/quic-go/quic-go/internal/ackhandler/ack_eliciting.go
generated
vendored
Normal file
@@ -0,0 +1,20 @@
|
||||
package ackhandler
|
||||
|
||||
import "github.com/quic-go/quic-go/internal/wire"
|
||||
|
||||
// IsFrameAckEliciting returns true if the frame is ack-eliciting.
|
||||
func IsFrameAckEliciting(f wire.Frame) bool {
|
||||
_, isAck := f.(*wire.AckFrame)
|
||||
_, isConnectionClose := f.(*wire.ConnectionCloseFrame)
|
||||
return !isAck && !isConnectionClose
|
||||
}
|
||||
|
||||
// HasAckElicitingFrames returns true if at least one frame is ack-eliciting.
|
||||
func HasAckElicitingFrames(fs []Frame) bool {
|
||||
for _, f := range fs {
|
||||
if IsFrameAckEliciting(f.Frame) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
24
vendor/github.com/quic-go/quic-go/internal/ackhandler/ackhandler.go
generated
vendored
Normal file
24
vendor/github.com/quic-go/quic-go/internal/ackhandler/ackhandler.go
generated
vendored
Normal file
@@ -0,0 +1,24 @@
|
||||
package ackhandler
|
||||
|
||||
import (
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
"github.com/quic-go/quic-go/logging"
|
||||
)
|
||||
|
||||
// NewAckHandler creates a new SentPacketHandler and a new ReceivedPacketHandler.
|
||||
// clientAddressValidated indicates whether the address was validated beforehand by an address validation token.
|
||||
// clientAddressValidated has no effect for a client.
|
||||
func NewAckHandler(
|
||||
initialPacketNumber protocol.PacketNumber,
|
||||
initialMaxDatagramSize protocol.ByteCount,
|
||||
rttStats *utils.RTTStats,
|
||||
clientAddressValidated bool,
|
||||
enableECN bool,
|
||||
pers protocol.Perspective,
|
||||
tracer *logging.ConnectionTracer,
|
||||
logger utils.Logger,
|
||||
) (SentPacketHandler, ReceivedPacketHandler) {
|
||||
sph := newSentPacketHandler(initialPacketNumber, initialMaxDatagramSize, rttStats, clientAddressValidated, enableECN, pers, tracer, logger)
|
||||
return sph, newReceivedPacketHandler(sph, rttStats, logger)
|
||||
}
|
||||
296
vendor/github.com/quic-go/quic-go/internal/ackhandler/ecn.go
generated
vendored
Normal file
296
vendor/github.com/quic-go/quic-go/internal/ackhandler/ecn.go
generated
vendored
Normal file
@@ -0,0 +1,296 @@
|
||||
package ackhandler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
"github.com/quic-go/quic-go/logging"
|
||||
)
|
||||
|
||||
type ecnState uint8
|
||||
|
||||
const (
|
||||
ecnStateInitial ecnState = iota
|
||||
ecnStateTesting
|
||||
ecnStateUnknown
|
||||
ecnStateCapable
|
||||
ecnStateFailed
|
||||
)
|
||||
|
||||
// must fit into an uint8, otherwise numSentTesting and numLostTesting must have a larger type
|
||||
const numECNTestingPackets = 10
|
||||
|
||||
type ecnHandler interface {
|
||||
SentPacket(protocol.PacketNumber, protocol.ECN)
|
||||
Mode() protocol.ECN
|
||||
HandleNewlyAcked(packets []*packet, ect0, ect1, ecnce int64) (congested bool)
|
||||
LostPacket(protocol.PacketNumber)
|
||||
}
|
||||
|
||||
// The ecnTracker performs ECN validation of a path.
|
||||
// Once failed, it doesn't do any re-validation of the path.
|
||||
// It is designed only work for 1-RTT packets, it doesn't handle multiple packet number spaces.
|
||||
// In order to avoid revealing any internal state to on-path observers,
|
||||
// callers should make sure to start using ECN (i.e. calling Mode) for the very first 1-RTT packet sent.
|
||||
// The validation logic implemented here strictly follows the algorithm described in RFC 9000 section 13.4.2 and A.4.
|
||||
type ecnTracker struct {
|
||||
state ecnState
|
||||
numSentTesting, numLostTesting uint8
|
||||
|
||||
firstTestingPacket protocol.PacketNumber
|
||||
lastTestingPacket protocol.PacketNumber
|
||||
firstCapablePacket protocol.PacketNumber
|
||||
|
||||
numSentECT0, numSentECT1 int64
|
||||
numAckedECT0, numAckedECT1, numAckedECNCE int64
|
||||
|
||||
tracer *logging.ConnectionTracer
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
var _ ecnHandler = &ecnTracker{}
|
||||
|
||||
func newECNTracker(logger utils.Logger, tracer *logging.ConnectionTracer) *ecnTracker {
|
||||
return &ecnTracker{
|
||||
firstTestingPacket: protocol.InvalidPacketNumber,
|
||||
lastTestingPacket: protocol.InvalidPacketNumber,
|
||||
firstCapablePacket: protocol.InvalidPacketNumber,
|
||||
state: ecnStateInitial,
|
||||
logger: logger,
|
||||
tracer: tracer,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *ecnTracker) SentPacket(pn protocol.PacketNumber, ecn protocol.ECN) {
|
||||
//nolint:exhaustive // These are the only ones we need to take care of.
|
||||
switch ecn {
|
||||
case protocol.ECNNon:
|
||||
return
|
||||
case protocol.ECT0:
|
||||
e.numSentECT0++
|
||||
case protocol.ECT1:
|
||||
e.numSentECT1++
|
||||
case protocol.ECNUnsupported:
|
||||
if e.state != ecnStateFailed {
|
||||
panic("didn't expect ECN to be unsupported")
|
||||
}
|
||||
default:
|
||||
panic(fmt.Sprintf("sent packet with unexpected ECN marking: %s", ecn))
|
||||
}
|
||||
|
||||
if e.state == ecnStateCapable && e.firstCapablePacket == protocol.InvalidPacketNumber {
|
||||
e.firstCapablePacket = pn
|
||||
}
|
||||
|
||||
if e.state != ecnStateTesting {
|
||||
return
|
||||
}
|
||||
|
||||
e.numSentTesting++
|
||||
if e.firstTestingPacket == protocol.InvalidPacketNumber {
|
||||
e.firstTestingPacket = pn
|
||||
}
|
||||
if e.numSentECT0+e.numSentECT1 >= numECNTestingPackets {
|
||||
if e.tracer != nil && e.tracer.ECNStateUpdated != nil {
|
||||
e.tracer.ECNStateUpdated(logging.ECNStateUnknown, logging.ECNTriggerNoTrigger)
|
||||
}
|
||||
e.state = ecnStateUnknown
|
||||
e.lastTestingPacket = pn
|
||||
}
|
||||
}
|
||||
|
||||
func (e *ecnTracker) Mode() protocol.ECN {
|
||||
switch e.state {
|
||||
case ecnStateInitial:
|
||||
if e.tracer != nil && e.tracer.ECNStateUpdated != nil {
|
||||
e.tracer.ECNStateUpdated(logging.ECNStateTesting, logging.ECNTriggerNoTrigger)
|
||||
}
|
||||
e.state = ecnStateTesting
|
||||
return e.Mode()
|
||||
case ecnStateTesting, ecnStateCapable:
|
||||
return protocol.ECT0
|
||||
case ecnStateUnknown, ecnStateFailed:
|
||||
return protocol.ECNNon
|
||||
default:
|
||||
panic(fmt.Sprintf("unknown ECN state: %d", e.state))
|
||||
}
|
||||
}
|
||||
|
||||
func (e *ecnTracker) LostPacket(pn protocol.PacketNumber) {
|
||||
if e.state != ecnStateTesting && e.state != ecnStateUnknown {
|
||||
return
|
||||
}
|
||||
if !e.isTestingPacket(pn) {
|
||||
return
|
||||
}
|
||||
e.numLostTesting++
|
||||
// Only proceed if we have sent all 10 testing packets.
|
||||
if e.state != ecnStateUnknown {
|
||||
return
|
||||
}
|
||||
if e.numLostTesting >= e.numSentTesting {
|
||||
e.logger.Debugf("Disabling ECN. All testing packets were lost.")
|
||||
if e.tracer != nil && e.tracer.ECNStateUpdated != nil {
|
||||
e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedLostAllTestingPackets)
|
||||
}
|
||||
e.state = ecnStateFailed
|
||||
return
|
||||
}
|
||||
// Path validation also fails if some testing packets are lost, and all other testing packets where CE-marked
|
||||
e.failIfMangled()
|
||||
}
|
||||
|
||||
// HandleNewlyAcked handles the ECN counts on an ACK frame.
|
||||
// It must only be called for ACK frames that increase the largest acknowledged packet number,
|
||||
// see section 13.4.2.1 of RFC 9000.
|
||||
func (e *ecnTracker) HandleNewlyAcked(packets []*packet, ect0, ect1, ecnce int64) (congested bool) {
|
||||
if e.state == ecnStateFailed {
|
||||
return false
|
||||
}
|
||||
|
||||
// ECN validation can fail if the received total count for either ECT(0) or ECT(1) exceeds
|
||||
// the total number of packets sent with each corresponding ECT codepoint.
|
||||
if ect0 > e.numSentECT0 || ect1 > e.numSentECT1 {
|
||||
e.logger.Debugf("Disabling ECN. Received more ECT(0) / ECT(1) acknowledgements than packets sent.")
|
||||
if e.tracer != nil && e.tracer.ECNStateUpdated != nil {
|
||||
e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedMoreECNCountsThanSent)
|
||||
}
|
||||
e.state = ecnStateFailed
|
||||
return false
|
||||
}
|
||||
|
||||
// Count ECT0 and ECT1 marks that we used when sending the packets that are now being acknowledged.
|
||||
var ackedECT0, ackedECT1 int64
|
||||
for _, p := range packets {
|
||||
//nolint:exhaustive // We only ever send ECT(0) and ECT(1).
|
||||
switch e.ecnMarking(p.PacketNumber) {
|
||||
case protocol.ECT0:
|
||||
ackedECT0++
|
||||
case protocol.ECT1:
|
||||
ackedECT1++
|
||||
}
|
||||
}
|
||||
|
||||
// If an ACK frame newly acknowledges a packet that the endpoint sent with either the ECT(0) or ECT(1)
|
||||
// codepoint set, ECN validation fails if the corresponding ECN counts are not present in the ACK frame.
|
||||
// This check detects:
|
||||
// * paths that bleach all ECN marks, and
|
||||
// * peers that don't report any ECN counts
|
||||
if (ackedECT0 > 0 || ackedECT1 > 0) && ect0 == 0 && ect1 == 0 && ecnce == 0 {
|
||||
e.logger.Debugf("Disabling ECN. ECN-marked packet acknowledged, but no ECN counts on ACK frame.")
|
||||
if e.tracer != nil && e.tracer.ECNStateUpdated != nil {
|
||||
e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedNoECNCounts)
|
||||
}
|
||||
e.state = ecnStateFailed
|
||||
return false
|
||||
}
|
||||
|
||||
// Determine the increase in ECT0, ECT1 and ECNCE marks
|
||||
newECT0 := ect0 - e.numAckedECT0
|
||||
newECT1 := ect1 - e.numAckedECT1
|
||||
newECNCE := ecnce - e.numAckedECNCE
|
||||
|
||||
// We're only processing ACKs that increase the Largest Acked.
|
||||
// Therefore, the ECN counters should only ever increase.
|
||||
// Any decrease means that the peer's counting logic is broken.
|
||||
if newECT0 < 0 || newECT1 < 0 || newECNCE < 0 {
|
||||
e.logger.Debugf("Disabling ECN. ECN counts decreased unexpectedly.")
|
||||
if e.tracer != nil && e.tracer.ECNStateUpdated != nil {
|
||||
e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedDecreasedECNCounts)
|
||||
}
|
||||
e.state = ecnStateFailed
|
||||
return false
|
||||
}
|
||||
|
||||
// ECN validation also fails if the sum of the increase in ECT(0) and ECN-CE counts is less than the number
|
||||
// of newly acknowledged packets that were originally sent with an ECT(0) marking.
|
||||
// This could be the result of (partial) bleaching.
|
||||
if newECT0+newECNCE < ackedECT0 {
|
||||
e.logger.Debugf("Disabling ECN. Received less ECT(0) + ECN-CE than packets sent with ECT(0).")
|
||||
if e.tracer != nil && e.tracer.ECNStateUpdated != nil {
|
||||
e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedTooFewECNCounts)
|
||||
}
|
||||
e.state = ecnStateFailed
|
||||
return false
|
||||
}
|
||||
// Similarly, ECN validation fails if the sum of the increases to ECT(1) and ECN-CE counts is less than
|
||||
// the number of newly acknowledged packets sent with an ECT(1) marking.
|
||||
if newECT1+newECNCE < ackedECT1 {
|
||||
e.logger.Debugf("Disabling ECN. Received less ECT(1) + ECN-CE than packets sent with ECT(1).")
|
||||
if e.tracer != nil && e.tracer.ECNStateUpdated != nil {
|
||||
e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedTooFewECNCounts)
|
||||
}
|
||||
e.state = ecnStateFailed
|
||||
return false
|
||||
}
|
||||
|
||||
// update our counters
|
||||
e.numAckedECT0 = ect0
|
||||
e.numAckedECT1 = ect1
|
||||
e.numAckedECNCE = ecnce
|
||||
|
||||
// Detect mangling (a path remarking all ECN-marked testing packets as CE),
|
||||
// once all 10 testing packets have been sent out.
|
||||
if e.state == ecnStateUnknown {
|
||||
e.failIfMangled()
|
||||
if e.state == ecnStateFailed {
|
||||
return false
|
||||
}
|
||||
}
|
||||
if e.state == ecnStateTesting || e.state == ecnStateUnknown {
|
||||
var ackedTestingPacket bool
|
||||
for _, p := range packets {
|
||||
if e.isTestingPacket(p.PacketNumber) {
|
||||
ackedTestingPacket = true
|
||||
break
|
||||
}
|
||||
}
|
||||
// This check won't succeed if the path is mangling ECN-marks (i.e. rewrites all ECN-marked packets to CE).
|
||||
if ackedTestingPacket && (newECT0 > 0 || newECT1 > 0) {
|
||||
e.logger.Debugf("ECN capability confirmed.")
|
||||
if e.tracer != nil && e.tracer.ECNStateUpdated != nil {
|
||||
e.tracer.ECNStateUpdated(logging.ECNStateCapable, logging.ECNTriggerNoTrigger)
|
||||
}
|
||||
e.state = ecnStateCapable
|
||||
}
|
||||
}
|
||||
|
||||
// Don't trust CE marks before having confirmed ECN capability of the path.
|
||||
// Otherwise, mangling would be misinterpreted as actual congestion.
|
||||
return e.state == ecnStateCapable && newECNCE > 0
|
||||
}
|
||||
|
||||
// failIfMangled fails ECN validation if all testing packets are lost or CE-marked.
|
||||
func (e *ecnTracker) failIfMangled() {
|
||||
numAckedECNCE := e.numAckedECNCE + int64(e.numLostTesting)
|
||||
if e.numSentECT0+e.numSentECT1 > numAckedECNCE {
|
||||
return
|
||||
}
|
||||
if e.tracer != nil && e.tracer.ECNStateUpdated != nil {
|
||||
e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedManglingDetected)
|
||||
}
|
||||
e.state = ecnStateFailed
|
||||
}
|
||||
|
||||
func (e *ecnTracker) ecnMarking(pn protocol.PacketNumber) protocol.ECN {
|
||||
if pn < e.firstTestingPacket || e.firstTestingPacket == protocol.InvalidPacketNumber {
|
||||
return protocol.ECNNon
|
||||
}
|
||||
if pn < e.lastTestingPacket || e.lastTestingPacket == protocol.InvalidPacketNumber {
|
||||
return protocol.ECT0
|
||||
}
|
||||
if pn < e.firstCapablePacket || e.firstCapablePacket == protocol.InvalidPacketNumber {
|
||||
return protocol.ECNNon
|
||||
}
|
||||
// We don't need to deal with the case when ECN validation fails,
|
||||
// since we're ignoring any ECN counts reported in ACK frames in that case.
|
||||
return protocol.ECT0
|
||||
}
|
||||
|
||||
func (e *ecnTracker) isTestingPacket(pn protocol.PacketNumber) bool {
|
||||
if e.firstTestingPacket == protocol.InvalidPacketNumber {
|
||||
return false
|
||||
}
|
||||
return pn >= e.firstTestingPacket && (pn <= e.lastTestingPacket || e.lastTestingPacket == protocol.InvalidPacketNumber)
|
||||
}
|
||||
21
vendor/github.com/quic-go/quic-go/internal/ackhandler/frame.go
generated
vendored
Normal file
21
vendor/github.com/quic-go/quic-go/internal/ackhandler/frame.go
generated
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
package ackhandler
|
||||
|
||||
import (
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
// FrameHandler handles the acknowledgement and the loss of a frame.
|
||||
type FrameHandler interface {
|
||||
OnAcked(wire.Frame)
|
||||
OnLost(wire.Frame)
|
||||
}
|
||||
|
||||
type Frame struct {
|
||||
Frame wire.Frame // nil if the frame has already been acknowledged in another packet
|
||||
Handler FrameHandler
|
||||
}
|
||||
|
||||
type StreamFrame struct {
|
||||
Frame *wire.StreamFrame
|
||||
Handler FrameHandler
|
||||
}
|
||||
53
vendor/github.com/quic-go/quic-go/internal/ackhandler/interfaces.go
generated
vendored
Normal file
53
vendor/github.com/quic-go/quic-go/internal/ackhandler/interfaces.go
generated
vendored
Normal file
@@ -0,0 +1,53 @@
|
||||
package ackhandler
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
// SentPacketHandler handles ACKs received for outgoing packets
|
||||
type SentPacketHandler interface {
|
||||
// SentPacket may modify the packet
|
||||
SentPacket(t time.Time, pn, largestAcked protocol.PacketNumber, streamFrames []StreamFrame, frames []Frame, encLevel protocol.EncryptionLevel, ecn protocol.ECN, size protocol.ByteCount, isPathMTUProbePacket bool)
|
||||
// ReceivedAck processes an ACK frame.
|
||||
// It does not store a copy of the frame.
|
||||
ReceivedAck(f *wire.AckFrame, encLevel protocol.EncryptionLevel, rcvTime time.Time) (bool /* 1-RTT packet acked */, error)
|
||||
ReceivedBytes(protocol.ByteCount)
|
||||
DropPackets(protocol.EncryptionLevel)
|
||||
ResetForRetry(rcvTime time.Time) error
|
||||
SetHandshakeConfirmed()
|
||||
|
||||
// The SendMode determines if and what kind of packets can be sent.
|
||||
SendMode(now time.Time) SendMode
|
||||
// TimeUntilSend is the time when the next packet should be sent.
|
||||
// It is used for pacing packets.
|
||||
TimeUntilSend() time.Time
|
||||
SetMaxDatagramSize(count protocol.ByteCount)
|
||||
|
||||
// only to be called once the handshake is complete
|
||||
QueueProbePacket(protocol.EncryptionLevel) bool /* was a packet queued */
|
||||
|
||||
ECNMode(isShortHeaderPacket bool) protocol.ECN // isShortHeaderPacket should only be true for non-coalesced 1-RTT packets
|
||||
PeekPacketNumber(protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen)
|
||||
PopPacketNumber(protocol.EncryptionLevel) protocol.PacketNumber
|
||||
|
||||
GetLossDetectionTimeout() time.Time
|
||||
OnLossDetectionTimeout() error
|
||||
}
|
||||
|
||||
type sentPacketTracker interface {
|
||||
GetLowestPacketNotConfirmedAcked() protocol.PacketNumber
|
||||
ReceivedPacket(protocol.EncryptionLevel)
|
||||
}
|
||||
|
||||
// ReceivedPacketHandler handles ACKs needed to send for incoming packets
|
||||
type ReceivedPacketHandler interface {
|
||||
IsPotentiallyDuplicate(protocol.PacketNumber, protocol.EncryptionLevel) bool
|
||||
ReceivedPacket(pn protocol.PacketNumber, ecn protocol.ECN, encLevel protocol.EncryptionLevel, rcvTime time.Time, ackEliciting bool) error
|
||||
DropPackets(protocol.EncryptionLevel)
|
||||
|
||||
GetAlarmTimeout() time.Time
|
||||
GetAckFrame(encLevel protocol.EncryptionLevel, onlyIfQueued bool) *wire.AckFrame
|
||||
}
|
||||
9
vendor/github.com/quic-go/quic-go/internal/ackhandler/mockgen.go
generated
vendored
Normal file
9
vendor/github.com/quic-go/quic-go/internal/ackhandler/mockgen.go
generated
vendored
Normal file
@@ -0,0 +1,9 @@
|
||||
//go:build gomock || generate
|
||||
|
||||
package ackhandler
|
||||
|
||||
//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package ackhandler -destination mock_sent_packet_tracker_test.go github.com/quic-go/quic-go/internal/ackhandler SentPacketTracker"
|
||||
type SentPacketTracker = sentPacketTracker
|
||||
|
||||
//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package ackhandler -destination mock_ecn_handler_test.go github.com/quic-go/quic-go/internal/ackhandler ECNHandler"
|
||||
type ECNHandler = ecnHandler
|
||||
55
vendor/github.com/quic-go/quic-go/internal/ackhandler/packet.go
generated
vendored
Normal file
55
vendor/github.com/quic-go/quic-go/internal/ackhandler/packet.go
generated
vendored
Normal file
@@ -0,0 +1,55 @@
|
||||
package ackhandler
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
// A Packet is a packet
|
||||
type packet struct {
|
||||
SendTime time.Time
|
||||
PacketNumber protocol.PacketNumber
|
||||
StreamFrames []StreamFrame
|
||||
Frames []Frame
|
||||
LargestAcked protocol.PacketNumber // InvalidPacketNumber if the packet doesn't contain an ACK
|
||||
Length protocol.ByteCount
|
||||
EncryptionLevel protocol.EncryptionLevel
|
||||
|
||||
IsPathMTUProbePacket bool // We don't report the loss of Path MTU probe packets to the congestion controller.
|
||||
|
||||
includedInBytesInFlight bool
|
||||
declaredLost bool
|
||||
skippedPacket bool
|
||||
}
|
||||
|
||||
func (p *packet) outstanding() bool {
|
||||
return !p.declaredLost && !p.skippedPacket && !p.IsPathMTUProbePacket
|
||||
}
|
||||
|
||||
var packetPool = sync.Pool{New: func() any { return &packet{} }}
|
||||
|
||||
func getPacket() *packet {
|
||||
p := packetPool.Get().(*packet)
|
||||
p.PacketNumber = 0
|
||||
p.StreamFrames = nil
|
||||
p.Frames = nil
|
||||
p.LargestAcked = 0
|
||||
p.Length = 0
|
||||
p.EncryptionLevel = protocol.EncryptionLevel(0)
|
||||
p.SendTime = time.Time{}
|
||||
p.IsPathMTUProbePacket = false
|
||||
p.includedInBytesInFlight = false
|
||||
p.declaredLost = false
|
||||
p.skippedPacket = false
|
||||
return p
|
||||
}
|
||||
|
||||
// We currently only return Packets back into the pool when they're acknowledged (not when they're lost).
|
||||
// This simplifies the code, and gives the vast majority of the performance benefit we can gain from using the pool.
|
||||
func putPacket(p *packet) {
|
||||
p.Frames = nil
|
||||
p.StreamFrames = nil
|
||||
packetPool.Put(p)
|
||||
}
|
||||
84
vendor/github.com/quic-go/quic-go/internal/ackhandler/packet_number_generator.go
generated
vendored
Normal file
84
vendor/github.com/quic-go/quic-go/internal/ackhandler/packet_number_generator.go
generated
vendored
Normal file
@@ -0,0 +1,84 @@
|
||||
package ackhandler
|
||||
|
||||
import (
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
type packetNumberGenerator interface {
|
||||
Peek() protocol.PacketNumber
|
||||
// Pop pops the packet number.
|
||||
// It reports if the packet number (before the one just popped) was skipped.
|
||||
// It never skips more than one packet number in a row.
|
||||
Pop() (skipped bool, _ protocol.PacketNumber)
|
||||
}
|
||||
|
||||
type sequentialPacketNumberGenerator struct {
|
||||
next protocol.PacketNumber
|
||||
}
|
||||
|
||||
var _ packetNumberGenerator = &sequentialPacketNumberGenerator{}
|
||||
|
||||
func newSequentialPacketNumberGenerator(initial protocol.PacketNumber) packetNumberGenerator {
|
||||
return &sequentialPacketNumberGenerator{next: initial}
|
||||
}
|
||||
|
||||
func (p *sequentialPacketNumberGenerator) Peek() protocol.PacketNumber {
|
||||
return p.next
|
||||
}
|
||||
|
||||
func (p *sequentialPacketNumberGenerator) Pop() (bool, protocol.PacketNumber) {
|
||||
next := p.next
|
||||
p.next++
|
||||
return false, next
|
||||
}
|
||||
|
||||
// The skippingPacketNumberGenerator generates the packet number for the next packet
|
||||
// it randomly skips a packet number every averagePeriod packets (on average).
|
||||
// It is guaranteed to never skip two consecutive packet numbers.
|
||||
type skippingPacketNumberGenerator struct {
|
||||
period protocol.PacketNumber
|
||||
maxPeriod protocol.PacketNumber
|
||||
|
||||
next protocol.PacketNumber
|
||||
nextToSkip protocol.PacketNumber
|
||||
|
||||
rng utils.Rand
|
||||
}
|
||||
|
||||
var _ packetNumberGenerator = &skippingPacketNumberGenerator{}
|
||||
|
||||
func newSkippingPacketNumberGenerator(initial, initialPeriod, maxPeriod protocol.PacketNumber) packetNumberGenerator {
|
||||
g := &skippingPacketNumberGenerator{
|
||||
next: initial,
|
||||
period: initialPeriod,
|
||||
maxPeriod: maxPeriod,
|
||||
}
|
||||
g.generateNewSkip()
|
||||
return g
|
||||
}
|
||||
|
||||
func (p *skippingPacketNumberGenerator) Peek() protocol.PacketNumber {
|
||||
if p.next == p.nextToSkip {
|
||||
return p.next + 1
|
||||
}
|
||||
return p.next
|
||||
}
|
||||
|
||||
func (p *skippingPacketNumberGenerator) Pop() (bool, protocol.PacketNumber) {
|
||||
next := p.next
|
||||
if p.next == p.nextToSkip {
|
||||
next++
|
||||
p.next += 2
|
||||
p.generateNewSkip()
|
||||
return true, next
|
||||
}
|
||||
p.next++ // generate a new packet number for the next packet
|
||||
return false, next
|
||||
}
|
||||
|
||||
func (p *skippingPacketNumberGenerator) generateNewSkip() {
|
||||
// make sure that there are never two consecutive packet numbers that are skipped
|
||||
p.nextToSkip = p.next + 3 + protocol.PacketNumber(p.rng.Int31n(int32(2*p.period)))
|
||||
p.period = utils.Min(2*p.period, p.maxPeriod)
|
||||
}
|
||||
142
vendor/github.com/quic-go/quic-go/internal/ackhandler/received_packet_handler.go
generated
vendored
Normal file
142
vendor/github.com/quic-go/quic-go/internal/ackhandler/received_packet_handler.go
generated
vendored
Normal file
@@ -0,0 +1,142 @@
|
||||
package ackhandler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
type receivedPacketHandler struct {
|
||||
sentPackets sentPacketTracker
|
||||
|
||||
initialPackets *receivedPacketTracker
|
||||
handshakePackets *receivedPacketTracker
|
||||
appDataPackets *receivedPacketTracker
|
||||
|
||||
lowest1RTTPacket protocol.PacketNumber
|
||||
}
|
||||
|
||||
var _ ReceivedPacketHandler = &receivedPacketHandler{}
|
||||
|
||||
func newReceivedPacketHandler(
|
||||
sentPackets sentPacketTracker,
|
||||
rttStats *utils.RTTStats,
|
||||
logger utils.Logger,
|
||||
) ReceivedPacketHandler {
|
||||
return &receivedPacketHandler{
|
||||
sentPackets: sentPackets,
|
||||
initialPackets: newReceivedPacketTracker(rttStats, logger),
|
||||
handshakePackets: newReceivedPacketTracker(rttStats, logger),
|
||||
appDataPackets: newReceivedPacketTracker(rttStats, logger),
|
||||
lowest1RTTPacket: protocol.InvalidPacketNumber,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *receivedPacketHandler) ReceivedPacket(
|
||||
pn protocol.PacketNumber,
|
||||
ecn protocol.ECN,
|
||||
encLevel protocol.EncryptionLevel,
|
||||
rcvTime time.Time,
|
||||
ackEliciting bool,
|
||||
) error {
|
||||
h.sentPackets.ReceivedPacket(encLevel)
|
||||
switch encLevel {
|
||||
case protocol.EncryptionInitial:
|
||||
return h.initialPackets.ReceivedPacket(pn, ecn, rcvTime, ackEliciting)
|
||||
case protocol.EncryptionHandshake:
|
||||
// The Handshake packet number space might already have been dropped as a result
|
||||
// of processing the CRYPTO frame that was contained in this packet.
|
||||
if h.handshakePackets == nil {
|
||||
return nil
|
||||
}
|
||||
return h.handshakePackets.ReceivedPacket(pn, ecn, rcvTime, ackEliciting)
|
||||
case protocol.Encryption0RTT:
|
||||
if h.lowest1RTTPacket != protocol.InvalidPacketNumber && pn > h.lowest1RTTPacket {
|
||||
return fmt.Errorf("received packet number %d on a 0-RTT packet after receiving %d on a 1-RTT packet", pn, h.lowest1RTTPacket)
|
||||
}
|
||||
return h.appDataPackets.ReceivedPacket(pn, ecn, rcvTime, ackEliciting)
|
||||
case protocol.Encryption1RTT:
|
||||
if h.lowest1RTTPacket == protocol.InvalidPacketNumber || pn < h.lowest1RTTPacket {
|
||||
h.lowest1RTTPacket = pn
|
||||
}
|
||||
if err := h.appDataPackets.ReceivedPacket(pn, ecn, rcvTime, ackEliciting); err != nil {
|
||||
return err
|
||||
}
|
||||
h.appDataPackets.IgnoreBelow(h.sentPackets.GetLowestPacketNotConfirmedAcked())
|
||||
return nil
|
||||
default:
|
||||
panic(fmt.Sprintf("received packet with unknown encryption level: %s", encLevel))
|
||||
}
|
||||
}
|
||||
|
||||
func (h *receivedPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) {
|
||||
//nolint:exhaustive // 1-RTT packet number space is never dropped.
|
||||
switch encLevel {
|
||||
case protocol.EncryptionInitial:
|
||||
h.initialPackets = nil
|
||||
case protocol.EncryptionHandshake:
|
||||
h.handshakePackets = nil
|
||||
case protocol.Encryption0RTT:
|
||||
// Nothing to do here.
|
||||
// If we are rejecting 0-RTT, no 0-RTT packets will have been decrypted.
|
||||
default:
|
||||
panic(fmt.Sprintf("Cannot drop keys for encryption level %s", encLevel))
|
||||
}
|
||||
}
|
||||
|
||||
func (h *receivedPacketHandler) GetAlarmTimeout() time.Time {
|
||||
var initialAlarm, handshakeAlarm time.Time
|
||||
if h.initialPackets != nil {
|
||||
initialAlarm = h.initialPackets.GetAlarmTimeout()
|
||||
}
|
||||
if h.handshakePackets != nil {
|
||||
handshakeAlarm = h.handshakePackets.GetAlarmTimeout()
|
||||
}
|
||||
oneRTTAlarm := h.appDataPackets.GetAlarmTimeout()
|
||||
return utils.MinNonZeroTime(utils.MinNonZeroTime(initialAlarm, handshakeAlarm), oneRTTAlarm)
|
||||
}
|
||||
|
||||
func (h *receivedPacketHandler) GetAckFrame(encLevel protocol.EncryptionLevel, onlyIfQueued bool) *wire.AckFrame {
|
||||
var ack *wire.AckFrame
|
||||
//nolint:exhaustive // 0-RTT packets can't contain ACK frames.
|
||||
switch encLevel {
|
||||
case protocol.EncryptionInitial:
|
||||
if h.initialPackets != nil {
|
||||
ack = h.initialPackets.GetAckFrame(onlyIfQueued)
|
||||
}
|
||||
case protocol.EncryptionHandshake:
|
||||
if h.handshakePackets != nil {
|
||||
ack = h.handshakePackets.GetAckFrame(onlyIfQueued)
|
||||
}
|
||||
case protocol.Encryption1RTT:
|
||||
// 0-RTT packets can't contain ACK frames
|
||||
return h.appDataPackets.GetAckFrame(onlyIfQueued)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
// For Initial and Handshake ACKs, the delay time is ignored by the receiver.
|
||||
// Set it to 0 in order to save bytes.
|
||||
if ack != nil {
|
||||
ack.DelayTime = 0
|
||||
}
|
||||
return ack
|
||||
}
|
||||
|
||||
func (h *receivedPacketHandler) IsPotentiallyDuplicate(pn protocol.PacketNumber, encLevel protocol.EncryptionLevel) bool {
|
||||
switch encLevel {
|
||||
case protocol.EncryptionInitial:
|
||||
if h.initialPackets != nil {
|
||||
return h.initialPackets.IsPotentiallyDuplicate(pn)
|
||||
}
|
||||
case protocol.EncryptionHandshake:
|
||||
if h.handshakePackets != nil {
|
||||
return h.handshakePackets.IsPotentiallyDuplicate(pn)
|
||||
}
|
||||
case protocol.Encryption0RTT, protocol.Encryption1RTT:
|
||||
return h.appDataPackets.IsPotentiallyDuplicate(pn)
|
||||
}
|
||||
panic("unexpected encryption level")
|
||||
}
|
||||
151
vendor/github.com/quic-go/quic-go/internal/ackhandler/received_packet_history.go
generated
vendored
Normal file
151
vendor/github.com/quic-go/quic-go/internal/ackhandler/received_packet_history.go
generated
vendored
Normal file
@@ -0,0 +1,151 @@
|
||||
package ackhandler
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
list "github.com/quic-go/quic-go/internal/utils/linkedlist"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
// interval is an interval from one PacketNumber to the other
|
||||
type interval struct {
|
||||
Start protocol.PacketNumber
|
||||
End protocol.PacketNumber
|
||||
}
|
||||
|
||||
var intervalElementPool sync.Pool
|
||||
|
||||
func init() {
|
||||
intervalElementPool = *list.NewPool[interval]()
|
||||
}
|
||||
|
||||
// The receivedPacketHistory stores if a packet number has already been received.
|
||||
// It generates ACK ranges which can be used to assemble an ACK frame.
|
||||
// It does not store packet contents.
|
||||
type receivedPacketHistory struct {
|
||||
ranges *list.List[interval]
|
||||
|
||||
deletedBelow protocol.PacketNumber
|
||||
}
|
||||
|
||||
func newReceivedPacketHistory() *receivedPacketHistory {
|
||||
return &receivedPacketHistory{
|
||||
ranges: list.NewWithPool[interval](&intervalElementPool),
|
||||
}
|
||||
}
|
||||
|
||||
// ReceivedPacket registers a packet with PacketNumber p and updates the ranges
|
||||
func (h *receivedPacketHistory) ReceivedPacket(p protocol.PacketNumber) bool /* is a new packet (and not a duplicate / delayed packet) */ {
|
||||
// ignore delayed packets, if we already deleted the range
|
||||
if p < h.deletedBelow {
|
||||
return false
|
||||
}
|
||||
isNew := h.addToRanges(p)
|
||||
h.maybeDeleteOldRanges()
|
||||
return isNew
|
||||
}
|
||||
|
||||
func (h *receivedPacketHistory) addToRanges(p protocol.PacketNumber) bool /* is a new packet (and not a duplicate / delayed packet) */ {
|
||||
if h.ranges.Len() == 0 {
|
||||
h.ranges.PushBack(interval{Start: p, End: p})
|
||||
return true
|
||||
}
|
||||
|
||||
for el := h.ranges.Back(); el != nil; el = el.Prev() {
|
||||
// p already included in an existing range. Nothing to do here
|
||||
if p >= el.Value.Start && p <= el.Value.End {
|
||||
return false
|
||||
}
|
||||
|
||||
if el.Value.End == p-1 { // extend a range at the end
|
||||
el.Value.End = p
|
||||
return true
|
||||
}
|
||||
if el.Value.Start == p+1 { // extend a range at the beginning
|
||||
el.Value.Start = p
|
||||
|
||||
prev := el.Prev()
|
||||
if prev != nil && prev.Value.End+1 == el.Value.Start { // merge two ranges
|
||||
prev.Value.End = el.Value.End
|
||||
h.ranges.Remove(el)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// create a new range at the end
|
||||
if p > el.Value.End {
|
||||
h.ranges.InsertAfter(interval{Start: p, End: p}, el)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// create a new range at the beginning
|
||||
h.ranges.InsertBefore(interval{Start: p, End: p}, h.ranges.Front())
|
||||
return true
|
||||
}
|
||||
|
||||
// Delete old ranges, if we're tracking more than 500 of them.
|
||||
// This is a DoS defense against a peer that sends us too many gaps.
|
||||
func (h *receivedPacketHistory) maybeDeleteOldRanges() {
|
||||
for h.ranges.Len() > protocol.MaxNumAckRanges {
|
||||
h.ranges.Remove(h.ranges.Front())
|
||||
}
|
||||
}
|
||||
|
||||
// DeleteBelow deletes all entries below (but not including) p
|
||||
func (h *receivedPacketHistory) DeleteBelow(p protocol.PacketNumber) {
|
||||
if p < h.deletedBelow {
|
||||
return
|
||||
}
|
||||
h.deletedBelow = p
|
||||
|
||||
nextEl := h.ranges.Front()
|
||||
for el := h.ranges.Front(); nextEl != nil; el = nextEl {
|
||||
nextEl = el.Next()
|
||||
|
||||
if el.Value.End < p { // delete a whole range
|
||||
h.ranges.Remove(el)
|
||||
} else if p > el.Value.Start && p <= el.Value.End {
|
||||
el.Value.Start = p
|
||||
return
|
||||
} else { // no ranges affected. Nothing to do
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// AppendAckRanges appends to a slice of all AckRanges that can be used in an AckFrame
|
||||
func (h *receivedPacketHistory) AppendAckRanges(ackRanges []wire.AckRange) []wire.AckRange {
|
||||
if h.ranges.Len() > 0 {
|
||||
for el := h.ranges.Back(); el != nil; el = el.Prev() {
|
||||
ackRanges = append(ackRanges, wire.AckRange{Smallest: el.Value.Start, Largest: el.Value.End})
|
||||
}
|
||||
}
|
||||
return ackRanges
|
||||
}
|
||||
|
||||
func (h *receivedPacketHistory) GetHighestAckRange() wire.AckRange {
|
||||
ackRange := wire.AckRange{}
|
||||
if h.ranges.Len() > 0 {
|
||||
r := h.ranges.Back().Value
|
||||
ackRange.Smallest = r.Start
|
||||
ackRange.Largest = r.End
|
||||
}
|
||||
return ackRange
|
||||
}
|
||||
|
||||
func (h *receivedPacketHistory) IsPotentiallyDuplicate(p protocol.PacketNumber) bool {
|
||||
if p < h.deletedBelow {
|
||||
return true
|
||||
}
|
||||
for el := h.ranges.Back(); el != nil; el = el.Prev() {
|
||||
if p > el.Value.End {
|
||||
return false
|
||||
}
|
||||
if p <= el.Value.End && p >= el.Value.Start {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
196
vendor/github.com/quic-go/quic-go/internal/ackhandler/received_packet_tracker.go
generated
vendored
Normal file
196
vendor/github.com/quic-go/quic-go/internal/ackhandler/received_packet_tracker.go
generated
vendored
Normal file
@@ -0,0 +1,196 @@
|
||||
package ackhandler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
// number of ack-eliciting packets received before sending an ack.
|
||||
const packetsBeforeAck = 2
|
||||
|
||||
type receivedPacketTracker struct {
|
||||
largestObserved protocol.PacketNumber
|
||||
ignoreBelow protocol.PacketNumber
|
||||
largestObservedRcvdTime time.Time
|
||||
ect0, ect1, ecnce uint64
|
||||
|
||||
packetHistory *receivedPacketHistory
|
||||
|
||||
maxAckDelay time.Duration
|
||||
rttStats *utils.RTTStats
|
||||
|
||||
hasNewAck bool // true as soon as we received an ack-eliciting new packet
|
||||
ackQueued bool // true once we received more than 2 (or later in the connection 10) ack-eliciting packets
|
||||
|
||||
ackElicitingPacketsReceivedSinceLastAck int
|
||||
ackAlarm time.Time
|
||||
lastAck *wire.AckFrame
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
func newReceivedPacketTracker(
|
||||
rttStats *utils.RTTStats,
|
||||
logger utils.Logger,
|
||||
) *receivedPacketTracker {
|
||||
return &receivedPacketTracker{
|
||||
packetHistory: newReceivedPacketHistory(),
|
||||
maxAckDelay: protocol.MaxAckDelay,
|
||||
rttStats: rttStats,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *receivedPacketTracker) ReceivedPacket(pn protocol.PacketNumber, ecn protocol.ECN, rcvTime time.Time, ackEliciting bool) error {
|
||||
if isNew := h.packetHistory.ReceivedPacket(pn); !isNew {
|
||||
return fmt.Errorf("recevedPacketTracker BUG: ReceivedPacket called for old / duplicate packet %d", pn)
|
||||
}
|
||||
|
||||
isMissing := h.isMissing(pn)
|
||||
if pn >= h.largestObserved {
|
||||
h.largestObserved = pn
|
||||
h.largestObservedRcvdTime = rcvTime
|
||||
}
|
||||
|
||||
if ackEliciting {
|
||||
h.hasNewAck = true
|
||||
}
|
||||
if ackEliciting {
|
||||
h.maybeQueueACK(pn, rcvTime, isMissing)
|
||||
}
|
||||
//nolint:exhaustive // Only need to count ECT(0), ECT(1) and ECNCE.
|
||||
switch ecn {
|
||||
case protocol.ECT0:
|
||||
h.ect0++
|
||||
case protocol.ECT1:
|
||||
h.ect1++
|
||||
case protocol.ECNCE:
|
||||
h.ecnce++
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IgnoreBelow sets a lower limit for acknowledging packets.
|
||||
// Packets with packet numbers smaller than p will not be acked.
|
||||
func (h *receivedPacketTracker) IgnoreBelow(pn protocol.PacketNumber) {
|
||||
if pn <= h.ignoreBelow {
|
||||
return
|
||||
}
|
||||
h.ignoreBelow = pn
|
||||
h.packetHistory.DeleteBelow(pn)
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("\tIgnoring all packets below %d.", pn)
|
||||
}
|
||||
}
|
||||
|
||||
// isMissing says if a packet was reported missing in the last ACK.
|
||||
func (h *receivedPacketTracker) isMissing(p protocol.PacketNumber) bool {
|
||||
if h.lastAck == nil || p < h.ignoreBelow {
|
||||
return false
|
||||
}
|
||||
return p < h.lastAck.LargestAcked() && !h.lastAck.AcksPacket(p)
|
||||
}
|
||||
|
||||
func (h *receivedPacketTracker) hasNewMissingPackets() bool {
|
||||
if h.lastAck == nil {
|
||||
return false
|
||||
}
|
||||
highestRange := h.packetHistory.GetHighestAckRange()
|
||||
return highestRange.Smallest > h.lastAck.LargestAcked()+1 && highestRange.Len() == 1
|
||||
}
|
||||
|
||||
// maybeQueueACK queues an ACK, if necessary.
|
||||
func (h *receivedPacketTracker) maybeQueueACK(pn protocol.PacketNumber, rcvTime time.Time, wasMissing bool) {
|
||||
// always acknowledge the first packet
|
||||
if h.lastAck == nil {
|
||||
if !h.ackQueued {
|
||||
h.logger.Debugf("\tQueueing ACK because the first packet should be acknowledged.")
|
||||
}
|
||||
h.ackQueued = true
|
||||
return
|
||||
}
|
||||
|
||||
if h.ackQueued {
|
||||
return
|
||||
}
|
||||
|
||||
h.ackElicitingPacketsReceivedSinceLastAck++
|
||||
|
||||
// Send an ACK if this packet was reported missing in an ACK sent before.
|
||||
// Ack decimation with reordering relies on the timer to send an ACK, but if
|
||||
// missing packets we reported in the previous ack, send an ACK immediately.
|
||||
if wasMissing {
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("\tQueueing ACK because packet %d was missing before.", pn)
|
||||
}
|
||||
h.ackQueued = true
|
||||
}
|
||||
|
||||
// send an ACK every 2 ack-eliciting packets
|
||||
if h.ackElicitingPacketsReceivedSinceLastAck >= packetsBeforeAck {
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("\tQueueing ACK because packet %d packets were received after the last ACK (using initial threshold: %d).", h.ackElicitingPacketsReceivedSinceLastAck, packetsBeforeAck)
|
||||
}
|
||||
h.ackQueued = true
|
||||
} else if h.ackAlarm.IsZero() {
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("\tSetting ACK timer to max ack delay: %s", h.maxAckDelay)
|
||||
}
|
||||
h.ackAlarm = rcvTime.Add(h.maxAckDelay)
|
||||
}
|
||||
|
||||
// Queue an ACK if there are new missing packets to report.
|
||||
if h.hasNewMissingPackets() {
|
||||
h.logger.Debugf("\tQueuing ACK because there's a new missing packet to report.")
|
||||
h.ackQueued = true
|
||||
}
|
||||
|
||||
if h.ackQueued {
|
||||
// cancel the ack alarm
|
||||
h.ackAlarm = time.Time{}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *receivedPacketTracker) GetAckFrame(onlyIfQueued bool) *wire.AckFrame {
|
||||
if !h.hasNewAck {
|
||||
return nil
|
||||
}
|
||||
now := time.Now()
|
||||
if onlyIfQueued {
|
||||
if !h.ackQueued && (h.ackAlarm.IsZero() || h.ackAlarm.After(now)) {
|
||||
return nil
|
||||
}
|
||||
if h.logger.Debug() && !h.ackQueued && !h.ackAlarm.IsZero() {
|
||||
h.logger.Debugf("Sending ACK because the ACK timer expired.")
|
||||
}
|
||||
}
|
||||
|
||||
// This function always returns the same ACK frame struct, filled with the most recent values.
|
||||
ack := h.lastAck
|
||||
if ack == nil {
|
||||
ack = &wire.AckFrame{}
|
||||
}
|
||||
ack.Reset()
|
||||
ack.DelayTime = utils.Max(0, now.Sub(h.largestObservedRcvdTime))
|
||||
ack.ECT0 = h.ect0
|
||||
ack.ECT1 = h.ect1
|
||||
ack.ECNCE = h.ecnce
|
||||
ack.AckRanges = h.packetHistory.AppendAckRanges(ack.AckRanges)
|
||||
|
||||
h.lastAck = ack
|
||||
h.ackAlarm = time.Time{}
|
||||
h.ackQueued = false
|
||||
h.hasNewAck = false
|
||||
h.ackElicitingPacketsReceivedSinceLastAck = 0
|
||||
return ack
|
||||
}
|
||||
|
||||
func (h *receivedPacketTracker) GetAlarmTimeout() time.Time { return h.ackAlarm }
|
||||
|
||||
func (h *receivedPacketTracker) IsPotentiallyDuplicate(pn protocol.PacketNumber) bool {
|
||||
return h.packetHistory.IsPotentiallyDuplicate(pn)
|
||||
}
|
||||
46
vendor/github.com/quic-go/quic-go/internal/ackhandler/send_mode.go
generated
vendored
Normal file
46
vendor/github.com/quic-go/quic-go/internal/ackhandler/send_mode.go
generated
vendored
Normal file
@@ -0,0 +1,46 @@
|
||||
package ackhandler
|
||||
|
||||
import "fmt"
|
||||
|
||||
// The SendMode says what kind of packets can be sent.
|
||||
type SendMode uint8
|
||||
|
||||
const (
|
||||
// SendNone means that no packets should be sent
|
||||
SendNone SendMode = iota
|
||||
// SendAck means an ACK-only packet should be sent
|
||||
SendAck
|
||||
// SendPTOInitial means that an Initial probe packet should be sent
|
||||
SendPTOInitial
|
||||
// SendPTOHandshake means that a Handshake probe packet should be sent
|
||||
SendPTOHandshake
|
||||
// SendPTOAppData means that an Application data probe packet should be sent
|
||||
SendPTOAppData
|
||||
// SendPacingLimited means that the pacer doesn't allow sending of a packet right now,
|
||||
// but will do in a little while.
|
||||
// The timestamp when sending is allowed again can be obtained via the SentPacketHandler.TimeUntilSend.
|
||||
SendPacingLimited
|
||||
// SendAny means that any packet should be sent
|
||||
SendAny
|
||||
)
|
||||
|
||||
func (s SendMode) String() string {
|
||||
switch s {
|
||||
case SendNone:
|
||||
return "none"
|
||||
case SendAck:
|
||||
return "ack"
|
||||
case SendPTOInitial:
|
||||
return "pto (Initial)"
|
||||
case SendPTOHandshake:
|
||||
return "pto (Handshake)"
|
||||
case SendPTOAppData:
|
||||
return "pto (Application Data)"
|
||||
case SendAny:
|
||||
return "any"
|
||||
case SendPacingLimited:
|
||||
return "pacing limited"
|
||||
default:
|
||||
return fmt.Sprintf("invalid send mode: %d", s)
|
||||
}
|
||||
}
|
||||
928
vendor/github.com/quic-go/quic-go/internal/ackhandler/sent_packet_handler.go
generated
vendored
Normal file
928
vendor/github.com/quic-go/quic-go/internal/ackhandler/sent_packet_handler.go
generated
vendored
Normal file
@@ -0,0 +1,928 @@
|
||||
package ackhandler
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/congestion"
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/qerr"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
"github.com/quic-go/quic-go/logging"
|
||||
)
|
||||
|
||||
const (
|
||||
// Maximum reordering in time space before time based loss detection considers a packet lost.
|
||||
// Specified as an RTT multiplier.
|
||||
timeThreshold = 9.0 / 8
|
||||
// Maximum reordering in packets before packet threshold loss detection considers a packet lost.
|
||||
packetThreshold = 3
|
||||
// Before validating the client's address, the server won't send more than 3x bytes than it received.
|
||||
amplificationFactor = 3
|
||||
// We use Retry packets to derive an RTT estimate. Make sure we don't set the RTT to a super low value yet.
|
||||
minRTTAfterRetry = 5 * time.Millisecond
|
||||
// The PTO duration uses exponential backoff, but is truncated to a maximum value, as allowed by RFC 8961, section 4.4.
|
||||
maxPTODuration = 60 * time.Second
|
||||
)
|
||||
|
||||
type packetNumberSpace struct {
|
||||
history *sentPacketHistory
|
||||
pns packetNumberGenerator
|
||||
|
||||
lossTime time.Time
|
||||
lastAckElicitingPacketTime time.Time
|
||||
|
||||
largestAcked protocol.PacketNumber
|
||||
largestSent protocol.PacketNumber
|
||||
}
|
||||
|
||||
func newPacketNumberSpace(initialPN protocol.PacketNumber, skipPNs bool) *packetNumberSpace {
|
||||
var pns packetNumberGenerator
|
||||
if skipPNs {
|
||||
pns = newSkippingPacketNumberGenerator(initialPN, protocol.SkipPacketInitialPeriod, protocol.SkipPacketMaxPeriod)
|
||||
} else {
|
||||
pns = newSequentialPacketNumberGenerator(initialPN)
|
||||
}
|
||||
return &packetNumberSpace{
|
||||
history: newSentPacketHistory(),
|
||||
pns: pns,
|
||||
largestSent: protocol.InvalidPacketNumber,
|
||||
largestAcked: protocol.InvalidPacketNumber,
|
||||
}
|
||||
}
|
||||
|
||||
type sentPacketHandler struct {
|
||||
initialPackets *packetNumberSpace
|
||||
handshakePackets *packetNumberSpace
|
||||
appDataPackets *packetNumberSpace
|
||||
|
||||
// Do we know that the peer completed address validation yet?
|
||||
// Always true for the server.
|
||||
peerCompletedAddressValidation bool
|
||||
bytesReceived protocol.ByteCount
|
||||
bytesSent protocol.ByteCount
|
||||
// Have we validated the peer's address yet?
|
||||
// Always true for the client.
|
||||
peerAddressValidated bool
|
||||
|
||||
handshakeConfirmed bool
|
||||
|
||||
// lowestNotConfirmedAcked is the lowest packet number that we sent an ACK for, but haven't received confirmation, that this ACK actually arrived
|
||||
// example: we send an ACK for packets 90-100 with packet number 20
|
||||
// once we receive an ACK from the peer for packet 20, the lowestNotConfirmedAcked is 101
|
||||
// Only applies to the application-data packet number space.
|
||||
lowestNotConfirmedAcked protocol.PacketNumber
|
||||
|
||||
ackedPackets []*packet // to avoid allocations in detectAndRemoveAckedPackets
|
||||
|
||||
bytesInFlight protocol.ByteCount
|
||||
|
||||
congestion congestion.SendAlgorithmWithDebugInfos
|
||||
rttStats *utils.RTTStats
|
||||
|
||||
// The number of times a PTO has been sent without receiving an ack.
|
||||
ptoCount uint32
|
||||
ptoMode SendMode
|
||||
// The number of PTO probe packets that should be sent.
|
||||
// Only applies to the application-data packet number space.
|
||||
numProbesToSend int
|
||||
|
||||
// The alarm timeout
|
||||
alarm time.Time
|
||||
|
||||
enableECN bool
|
||||
ecnTracker ecnHandler
|
||||
|
||||
perspective protocol.Perspective
|
||||
|
||||
tracer *logging.ConnectionTracer
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
var (
|
||||
_ SentPacketHandler = &sentPacketHandler{}
|
||||
_ sentPacketTracker = &sentPacketHandler{}
|
||||
)
|
||||
|
||||
// clientAddressValidated indicates whether the address was validated beforehand by an address validation token.
|
||||
// If the address was validated, the amplification limit doesn't apply. It has no effect for a client.
|
||||
func newSentPacketHandler(
|
||||
initialPN protocol.PacketNumber,
|
||||
initialMaxDatagramSize protocol.ByteCount,
|
||||
rttStats *utils.RTTStats,
|
||||
clientAddressValidated bool,
|
||||
enableECN bool,
|
||||
pers protocol.Perspective,
|
||||
tracer *logging.ConnectionTracer,
|
||||
logger utils.Logger,
|
||||
) *sentPacketHandler {
|
||||
congestion := congestion.NewCubicSender(
|
||||
congestion.DefaultClock{},
|
||||
rttStats,
|
||||
initialMaxDatagramSize,
|
||||
true, // use Reno
|
||||
tracer,
|
||||
)
|
||||
|
||||
h := &sentPacketHandler{
|
||||
peerCompletedAddressValidation: pers == protocol.PerspectiveServer,
|
||||
peerAddressValidated: pers == protocol.PerspectiveClient || clientAddressValidated,
|
||||
initialPackets: newPacketNumberSpace(initialPN, false),
|
||||
handshakePackets: newPacketNumberSpace(0, false),
|
||||
appDataPackets: newPacketNumberSpace(0, true),
|
||||
rttStats: rttStats,
|
||||
congestion: congestion,
|
||||
perspective: pers,
|
||||
tracer: tracer,
|
||||
logger: logger,
|
||||
}
|
||||
if enableECN {
|
||||
h.enableECN = true
|
||||
h.ecnTracker = newECNTracker(logger, tracer)
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) removeFromBytesInFlight(p *packet) {
|
||||
if p.includedInBytesInFlight {
|
||||
if p.Length > h.bytesInFlight {
|
||||
panic("negative bytes_in_flight")
|
||||
}
|
||||
h.bytesInFlight -= p.Length
|
||||
p.includedInBytesInFlight = false
|
||||
}
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) {
|
||||
// The server won't await address validation after the handshake is confirmed.
|
||||
// This applies even if we didn't receive an ACK for a Handshake packet.
|
||||
if h.perspective == protocol.PerspectiveClient && encLevel == protocol.EncryptionHandshake {
|
||||
h.peerCompletedAddressValidation = true
|
||||
}
|
||||
// remove outstanding packets from bytes_in_flight
|
||||
if encLevel == protocol.EncryptionInitial || encLevel == protocol.EncryptionHandshake {
|
||||
pnSpace := h.getPacketNumberSpace(encLevel)
|
||||
// We might already have dropped this packet number space.
|
||||
if pnSpace == nil {
|
||||
return
|
||||
}
|
||||
pnSpace.history.Iterate(func(p *packet) (bool, error) {
|
||||
h.removeFromBytesInFlight(p)
|
||||
return true, nil
|
||||
})
|
||||
}
|
||||
// drop the packet history
|
||||
//nolint:exhaustive // Not every packet number space can be dropped.
|
||||
switch encLevel {
|
||||
case protocol.EncryptionInitial:
|
||||
h.initialPackets = nil
|
||||
case protocol.EncryptionHandshake:
|
||||
h.handshakePackets = nil
|
||||
case protocol.Encryption0RTT:
|
||||
// This function is only called when 0-RTT is rejected,
|
||||
// and not when the client drops 0-RTT keys when the handshake completes.
|
||||
// When 0-RTT is rejected, all application data sent so far becomes invalid.
|
||||
// Delete the packets from the history and remove them from bytes_in_flight.
|
||||
h.appDataPackets.history.Iterate(func(p *packet) (bool, error) {
|
||||
if p.EncryptionLevel != protocol.Encryption0RTT && !p.skippedPacket {
|
||||
return false, nil
|
||||
}
|
||||
h.removeFromBytesInFlight(p)
|
||||
h.appDataPackets.history.Remove(p.PacketNumber)
|
||||
return true, nil
|
||||
})
|
||||
default:
|
||||
panic(fmt.Sprintf("Cannot drop keys for encryption level %s", encLevel))
|
||||
}
|
||||
if h.tracer != nil && h.tracer.UpdatedPTOCount != nil && h.ptoCount != 0 {
|
||||
h.tracer.UpdatedPTOCount(0)
|
||||
}
|
||||
h.ptoCount = 0
|
||||
h.numProbesToSend = 0
|
||||
h.ptoMode = SendNone
|
||||
h.setLossDetectionTimer()
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) ReceivedBytes(n protocol.ByteCount) {
|
||||
wasAmplificationLimit := h.isAmplificationLimited()
|
||||
h.bytesReceived += n
|
||||
if wasAmplificationLimit && !h.isAmplificationLimited() {
|
||||
h.setLossDetectionTimer()
|
||||
}
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) ReceivedPacket(l protocol.EncryptionLevel) {
|
||||
if h.perspective == protocol.PerspectiveServer && l == protocol.EncryptionHandshake && !h.peerAddressValidated {
|
||||
h.peerAddressValidated = true
|
||||
h.setLossDetectionTimer()
|
||||
}
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) packetsInFlight() int {
|
||||
packetsInFlight := h.appDataPackets.history.Len()
|
||||
if h.handshakePackets != nil {
|
||||
packetsInFlight += h.handshakePackets.history.Len()
|
||||
}
|
||||
if h.initialPackets != nil {
|
||||
packetsInFlight += h.initialPackets.history.Len()
|
||||
}
|
||||
return packetsInFlight
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) SentPacket(
|
||||
t time.Time,
|
||||
pn, largestAcked protocol.PacketNumber,
|
||||
streamFrames []StreamFrame,
|
||||
frames []Frame,
|
||||
encLevel protocol.EncryptionLevel,
|
||||
ecn protocol.ECN,
|
||||
size protocol.ByteCount,
|
||||
isPathMTUProbePacket bool,
|
||||
) {
|
||||
h.bytesSent += size
|
||||
|
||||
pnSpace := h.getPacketNumberSpace(encLevel)
|
||||
if h.logger.Debug() && pnSpace.history.HasOutstandingPackets() {
|
||||
for p := utils.Max(0, pnSpace.largestSent+1); p < pn; p++ {
|
||||
h.logger.Debugf("Skipping packet number %d", p)
|
||||
}
|
||||
}
|
||||
|
||||
pnSpace.largestSent = pn
|
||||
isAckEliciting := len(streamFrames) > 0 || len(frames) > 0
|
||||
|
||||
if isAckEliciting {
|
||||
pnSpace.lastAckElicitingPacketTime = t
|
||||
h.bytesInFlight += size
|
||||
if h.numProbesToSend > 0 {
|
||||
h.numProbesToSend--
|
||||
}
|
||||
}
|
||||
h.congestion.OnPacketSent(t, h.bytesInFlight, pn, size, isAckEliciting)
|
||||
|
||||
if encLevel == protocol.Encryption1RTT && h.ecnTracker != nil {
|
||||
h.ecnTracker.SentPacket(pn, ecn)
|
||||
}
|
||||
|
||||
if !isAckEliciting {
|
||||
pnSpace.history.SentNonAckElicitingPacket(pn)
|
||||
if !h.peerCompletedAddressValidation {
|
||||
h.setLossDetectionTimer()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
p := getPacket()
|
||||
p.SendTime = t
|
||||
p.PacketNumber = pn
|
||||
p.EncryptionLevel = encLevel
|
||||
p.Length = size
|
||||
p.LargestAcked = largestAcked
|
||||
p.StreamFrames = streamFrames
|
||||
p.Frames = frames
|
||||
p.IsPathMTUProbePacket = isPathMTUProbePacket
|
||||
p.includedInBytesInFlight = true
|
||||
|
||||
pnSpace.history.SentAckElicitingPacket(p)
|
||||
if h.tracer != nil && h.tracer.UpdatedMetrics != nil {
|
||||
h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight())
|
||||
}
|
||||
h.setLossDetectionTimer()
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) getPacketNumberSpace(encLevel protocol.EncryptionLevel) *packetNumberSpace {
|
||||
switch encLevel {
|
||||
case protocol.EncryptionInitial:
|
||||
return h.initialPackets
|
||||
case protocol.EncryptionHandshake:
|
||||
return h.handshakePackets
|
||||
case protocol.Encryption0RTT, protocol.Encryption1RTT:
|
||||
return h.appDataPackets
|
||||
default:
|
||||
panic("invalid packet number space")
|
||||
}
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.EncryptionLevel, rcvTime time.Time) (bool /* contained 1-RTT packet */, error) {
|
||||
pnSpace := h.getPacketNumberSpace(encLevel)
|
||||
|
||||
largestAcked := ack.LargestAcked()
|
||||
if largestAcked > pnSpace.largestSent {
|
||||
return false, &qerr.TransportError{
|
||||
ErrorCode: qerr.ProtocolViolation,
|
||||
ErrorMessage: "received ACK for an unsent packet",
|
||||
}
|
||||
}
|
||||
|
||||
// Servers complete address validation when a protected packet is received.
|
||||
if h.perspective == protocol.PerspectiveClient && !h.peerCompletedAddressValidation &&
|
||||
(encLevel == protocol.EncryptionHandshake || encLevel == protocol.Encryption1RTT) {
|
||||
h.peerCompletedAddressValidation = true
|
||||
h.logger.Debugf("Peer doesn't await address validation any longer.")
|
||||
// Make sure that the timer is reset, even if this ACK doesn't acknowledge any (ack-eliciting) packets.
|
||||
h.setLossDetectionTimer()
|
||||
}
|
||||
|
||||
priorInFlight := h.bytesInFlight
|
||||
ackedPackets, err := h.detectAndRemoveAckedPackets(ack, encLevel)
|
||||
if err != nil || len(ackedPackets) == 0 {
|
||||
return false, err
|
||||
}
|
||||
// update the RTT, if the largest acked is newly acknowledged
|
||||
if len(ackedPackets) > 0 {
|
||||
if p := ackedPackets[len(ackedPackets)-1]; p.PacketNumber == ack.LargestAcked() {
|
||||
// don't use the ack delay for Initial and Handshake packets
|
||||
var ackDelay time.Duration
|
||||
if encLevel == protocol.Encryption1RTT {
|
||||
ackDelay = utils.Min(ack.DelayTime, h.rttStats.MaxAckDelay())
|
||||
}
|
||||
h.rttStats.UpdateRTT(rcvTime.Sub(p.SendTime), ackDelay, rcvTime)
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("\tupdated RTT: %s (σ: %s)", h.rttStats.SmoothedRTT(), h.rttStats.MeanDeviation())
|
||||
}
|
||||
h.congestion.MaybeExitSlowStart()
|
||||
}
|
||||
}
|
||||
|
||||
// Only inform the ECN tracker about new 1-RTT ACKs if the ACK increases the largest acked.
|
||||
if encLevel == protocol.Encryption1RTT && h.ecnTracker != nil && largestAcked > pnSpace.largestAcked {
|
||||
congested := h.ecnTracker.HandleNewlyAcked(ackedPackets, int64(ack.ECT0), int64(ack.ECT1), int64(ack.ECNCE))
|
||||
if congested {
|
||||
h.congestion.OnCongestionEvent(largestAcked, 0, priorInFlight)
|
||||
}
|
||||
}
|
||||
|
||||
pnSpace.largestAcked = utils.Max(pnSpace.largestAcked, largestAcked)
|
||||
|
||||
if err := h.detectLostPackets(rcvTime, encLevel); err != nil {
|
||||
return false, err
|
||||
}
|
||||
var acked1RTTPacket bool
|
||||
for _, p := range ackedPackets {
|
||||
if p.includedInBytesInFlight && !p.declaredLost {
|
||||
h.congestion.OnPacketAcked(p.PacketNumber, p.Length, priorInFlight, rcvTime)
|
||||
}
|
||||
if p.EncryptionLevel == protocol.Encryption1RTT {
|
||||
acked1RTTPacket = true
|
||||
}
|
||||
h.removeFromBytesInFlight(p)
|
||||
putPacket(p)
|
||||
}
|
||||
// After this point, we must not use ackedPackets any longer!
|
||||
// We've already returned the buffers.
|
||||
ackedPackets = nil //nolint:ineffassign // This is just to be on the safe side.
|
||||
|
||||
// Reset the pto_count unless the client is unsure if the server has validated the client's address.
|
||||
if h.peerCompletedAddressValidation {
|
||||
if h.tracer != nil && h.tracer.UpdatedPTOCount != nil && h.ptoCount != 0 {
|
||||
h.tracer.UpdatedPTOCount(0)
|
||||
}
|
||||
h.ptoCount = 0
|
||||
}
|
||||
h.numProbesToSend = 0
|
||||
|
||||
if h.tracer != nil && h.tracer.UpdatedMetrics != nil {
|
||||
h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight())
|
||||
}
|
||||
|
||||
h.setLossDetectionTimer()
|
||||
return acked1RTTPacket, nil
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) GetLowestPacketNotConfirmedAcked() protocol.PacketNumber {
|
||||
return h.lowestNotConfirmedAcked
|
||||
}
|
||||
|
||||
// Packets are returned in ascending packet number order.
|
||||
func (h *sentPacketHandler) detectAndRemoveAckedPackets(ack *wire.AckFrame, encLevel protocol.EncryptionLevel) ([]*packet, error) {
|
||||
pnSpace := h.getPacketNumberSpace(encLevel)
|
||||
h.ackedPackets = h.ackedPackets[:0]
|
||||
ackRangeIndex := 0
|
||||
lowestAcked := ack.LowestAcked()
|
||||
largestAcked := ack.LargestAcked()
|
||||
err := pnSpace.history.Iterate(func(p *packet) (bool, error) {
|
||||
// Ignore packets below the lowest acked
|
||||
if p.PacketNumber < lowestAcked {
|
||||
return true, nil
|
||||
}
|
||||
// Break after largest acked is reached
|
||||
if p.PacketNumber > largestAcked {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if ack.HasMissingRanges() {
|
||||
ackRange := ack.AckRanges[len(ack.AckRanges)-1-ackRangeIndex]
|
||||
|
||||
for p.PacketNumber > ackRange.Largest && ackRangeIndex < len(ack.AckRanges)-1 {
|
||||
ackRangeIndex++
|
||||
ackRange = ack.AckRanges[len(ack.AckRanges)-1-ackRangeIndex]
|
||||
}
|
||||
|
||||
if p.PacketNumber < ackRange.Smallest { // packet not contained in ACK range
|
||||
return true, nil
|
||||
}
|
||||
if p.PacketNumber > ackRange.Largest {
|
||||
return false, fmt.Errorf("BUG: ackhandler would have acked wrong packet %d, while evaluating range %d -> %d", p.PacketNumber, ackRange.Smallest, ackRange.Largest)
|
||||
}
|
||||
}
|
||||
if p.skippedPacket {
|
||||
return false, &qerr.TransportError{
|
||||
ErrorCode: qerr.ProtocolViolation,
|
||||
ErrorMessage: fmt.Sprintf("received an ACK for skipped packet number: %d (%s)", p.PacketNumber, encLevel),
|
||||
}
|
||||
}
|
||||
h.ackedPackets = append(h.ackedPackets, p)
|
||||
return true, nil
|
||||
})
|
||||
if h.logger.Debug() && len(h.ackedPackets) > 0 {
|
||||
pns := make([]protocol.PacketNumber, len(h.ackedPackets))
|
||||
for i, p := range h.ackedPackets {
|
||||
pns[i] = p.PacketNumber
|
||||
}
|
||||
h.logger.Debugf("\tnewly acked packets (%d): %d", len(pns), pns)
|
||||
}
|
||||
|
||||
for _, p := range h.ackedPackets {
|
||||
if p.LargestAcked != protocol.InvalidPacketNumber && encLevel == protocol.Encryption1RTT {
|
||||
h.lowestNotConfirmedAcked = utils.Max(h.lowestNotConfirmedAcked, p.LargestAcked+1)
|
||||
}
|
||||
|
||||
for _, f := range p.Frames {
|
||||
if f.Handler != nil {
|
||||
f.Handler.OnAcked(f.Frame)
|
||||
}
|
||||
}
|
||||
for _, f := range p.StreamFrames {
|
||||
if f.Handler != nil {
|
||||
f.Handler.OnAcked(f.Frame)
|
||||
}
|
||||
}
|
||||
if err := pnSpace.history.Remove(p.PacketNumber); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if h.tracer != nil && h.tracer.AcknowledgedPacket != nil {
|
||||
h.tracer.AcknowledgedPacket(encLevel, p.PacketNumber)
|
||||
}
|
||||
}
|
||||
|
||||
return h.ackedPackets, err
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) getLossTimeAndSpace() (time.Time, protocol.EncryptionLevel) {
|
||||
var encLevel protocol.EncryptionLevel
|
||||
var lossTime time.Time
|
||||
|
||||
if h.initialPackets != nil {
|
||||
lossTime = h.initialPackets.lossTime
|
||||
encLevel = protocol.EncryptionInitial
|
||||
}
|
||||
if h.handshakePackets != nil && (lossTime.IsZero() || (!h.handshakePackets.lossTime.IsZero() && h.handshakePackets.lossTime.Before(lossTime))) {
|
||||
lossTime = h.handshakePackets.lossTime
|
||||
encLevel = protocol.EncryptionHandshake
|
||||
}
|
||||
if lossTime.IsZero() || (!h.appDataPackets.lossTime.IsZero() && h.appDataPackets.lossTime.Before(lossTime)) {
|
||||
lossTime = h.appDataPackets.lossTime
|
||||
encLevel = protocol.Encryption1RTT
|
||||
}
|
||||
return lossTime, encLevel
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) getScaledPTO(includeMaxAckDelay bool) time.Duration {
|
||||
pto := h.rttStats.PTO(includeMaxAckDelay) << h.ptoCount
|
||||
if pto > maxPTODuration || pto <= 0 {
|
||||
return maxPTODuration
|
||||
}
|
||||
return pto
|
||||
}
|
||||
|
||||
// same logic as getLossTimeAndSpace, but for lastAckElicitingPacketTime instead of lossTime
|
||||
func (h *sentPacketHandler) getPTOTimeAndSpace() (pto time.Time, encLevel protocol.EncryptionLevel, ok bool) {
|
||||
// We only send application data probe packets once the handshake is confirmed,
|
||||
// because before that, we don't have the keys to decrypt ACKs sent in 1-RTT packets.
|
||||
if !h.handshakeConfirmed && !h.hasOutstandingCryptoPackets() {
|
||||
if h.peerCompletedAddressValidation {
|
||||
return
|
||||
}
|
||||
t := time.Now().Add(h.getScaledPTO(false))
|
||||
if h.initialPackets != nil {
|
||||
return t, protocol.EncryptionInitial, true
|
||||
}
|
||||
return t, protocol.EncryptionHandshake, true
|
||||
}
|
||||
|
||||
if h.initialPackets != nil {
|
||||
encLevel = protocol.EncryptionInitial
|
||||
if t := h.initialPackets.lastAckElicitingPacketTime; !t.IsZero() {
|
||||
pto = t.Add(h.getScaledPTO(false))
|
||||
}
|
||||
}
|
||||
if h.handshakePackets != nil && !h.handshakePackets.lastAckElicitingPacketTime.IsZero() {
|
||||
t := h.handshakePackets.lastAckElicitingPacketTime.Add(h.getScaledPTO(false))
|
||||
if pto.IsZero() || (!t.IsZero() && t.Before(pto)) {
|
||||
pto = t
|
||||
encLevel = protocol.EncryptionHandshake
|
||||
}
|
||||
}
|
||||
if h.handshakeConfirmed && !h.appDataPackets.lastAckElicitingPacketTime.IsZero() {
|
||||
t := h.appDataPackets.lastAckElicitingPacketTime.Add(h.getScaledPTO(true))
|
||||
if pto.IsZero() || (!t.IsZero() && t.Before(pto)) {
|
||||
pto = t
|
||||
encLevel = protocol.Encryption1RTT
|
||||
}
|
||||
}
|
||||
return pto, encLevel, true
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) hasOutstandingCryptoPackets() bool {
|
||||
if h.initialPackets != nil && h.initialPackets.history.HasOutstandingPackets() {
|
||||
return true
|
||||
}
|
||||
if h.handshakePackets != nil && h.handshakePackets.history.HasOutstandingPackets() {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) hasOutstandingPackets() bool {
|
||||
return h.appDataPackets.history.HasOutstandingPackets() || h.hasOutstandingCryptoPackets()
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) setLossDetectionTimer() {
|
||||
oldAlarm := h.alarm // only needed in case tracing is enabled
|
||||
lossTime, encLevel := h.getLossTimeAndSpace()
|
||||
if !lossTime.IsZero() {
|
||||
// Early retransmit timer or time loss detection.
|
||||
h.alarm = lossTime
|
||||
if h.tracer != nil && h.tracer.SetLossTimer != nil && h.alarm != oldAlarm {
|
||||
h.tracer.SetLossTimer(logging.TimerTypeACK, encLevel, h.alarm)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Cancel the alarm if amplification limited.
|
||||
if h.isAmplificationLimited() {
|
||||
h.alarm = time.Time{}
|
||||
if !oldAlarm.IsZero() {
|
||||
h.logger.Debugf("Canceling loss detection timer. Amplification limited.")
|
||||
if h.tracer != nil && h.tracer.LossTimerCanceled != nil {
|
||||
h.tracer.LossTimerCanceled()
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Cancel the alarm if no packets are outstanding
|
||||
if !h.hasOutstandingPackets() && h.peerCompletedAddressValidation {
|
||||
h.alarm = time.Time{}
|
||||
if !oldAlarm.IsZero() {
|
||||
h.logger.Debugf("Canceling loss detection timer. No packets in flight.")
|
||||
if h.tracer != nil && h.tracer.LossTimerCanceled != nil {
|
||||
h.tracer.LossTimerCanceled()
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// PTO alarm
|
||||
ptoTime, encLevel, ok := h.getPTOTimeAndSpace()
|
||||
if !ok {
|
||||
if !oldAlarm.IsZero() {
|
||||
h.alarm = time.Time{}
|
||||
h.logger.Debugf("Canceling loss detection timer. No PTO needed..")
|
||||
if h.tracer != nil && h.tracer.LossTimerCanceled != nil {
|
||||
h.tracer.LossTimerCanceled()
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
h.alarm = ptoTime
|
||||
if h.tracer != nil && h.tracer.SetLossTimer != nil && h.alarm != oldAlarm {
|
||||
h.tracer.SetLossTimer(logging.TimerTypePTO, encLevel, h.alarm)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.EncryptionLevel) error {
|
||||
pnSpace := h.getPacketNumberSpace(encLevel)
|
||||
pnSpace.lossTime = time.Time{}
|
||||
|
||||
maxRTT := float64(utils.Max(h.rttStats.LatestRTT(), h.rttStats.SmoothedRTT()))
|
||||
lossDelay := time.Duration(timeThreshold * maxRTT)
|
||||
|
||||
// Minimum time of granularity before packets are deemed lost.
|
||||
lossDelay = utils.Max(lossDelay, protocol.TimerGranularity)
|
||||
|
||||
// Packets sent before this time are deemed lost.
|
||||
lostSendTime := now.Add(-lossDelay)
|
||||
|
||||
priorInFlight := h.bytesInFlight
|
||||
return pnSpace.history.Iterate(func(p *packet) (bool, error) {
|
||||
if p.PacketNumber > pnSpace.largestAcked {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
var packetLost bool
|
||||
if p.SendTime.Before(lostSendTime) {
|
||||
packetLost = true
|
||||
if !p.skippedPacket {
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("\tlost packet %d (time threshold)", p.PacketNumber)
|
||||
}
|
||||
if h.tracer != nil && h.tracer.LostPacket != nil {
|
||||
h.tracer.LostPacket(p.EncryptionLevel, p.PacketNumber, logging.PacketLossTimeThreshold)
|
||||
}
|
||||
}
|
||||
} else if pnSpace.largestAcked >= p.PacketNumber+packetThreshold {
|
||||
packetLost = true
|
||||
if !p.skippedPacket {
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("\tlost packet %d (reordering threshold)", p.PacketNumber)
|
||||
}
|
||||
if h.tracer != nil && h.tracer.LostPacket != nil {
|
||||
h.tracer.LostPacket(p.EncryptionLevel, p.PacketNumber, logging.PacketLossReorderingThreshold)
|
||||
}
|
||||
}
|
||||
} else if pnSpace.lossTime.IsZero() {
|
||||
// Note: This conditional is only entered once per call
|
||||
lossTime := p.SendTime.Add(lossDelay)
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("\tsetting loss timer for packet %d (%s) to %s (in %s)", p.PacketNumber, encLevel, lossDelay, lossTime)
|
||||
}
|
||||
pnSpace.lossTime = lossTime
|
||||
}
|
||||
if packetLost {
|
||||
pnSpace.history.DeclareLost(p.PacketNumber)
|
||||
if !p.skippedPacket {
|
||||
// the bytes in flight need to be reduced no matter if the frames in this packet will be retransmitted
|
||||
h.removeFromBytesInFlight(p)
|
||||
h.queueFramesForRetransmission(p)
|
||||
if !p.IsPathMTUProbePacket {
|
||||
h.congestion.OnCongestionEvent(p.PacketNumber, p.Length, priorInFlight)
|
||||
}
|
||||
if encLevel == protocol.Encryption1RTT && h.ecnTracker != nil {
|
||||
h.ecnTracker.LostPacket(p.PacketNumber)
|
||||
}
|
||||
}
|
||||
}
|
||||
return true, nil
|
||||
})
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) OnLossDetectionTimeout() error {
|
||||
defer h.setLossDetectionTimer()
|
||||
earliestLossTime, encLevel := h.getLossTimeAndSpace()
|
||||
if !earliestLossTime.IsZero() {
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("Loss detection alarm fired in loss timer mode. Loss time: %s", earliestLossTime)
|
||||
}
|
||||
if h.tracer != nil && h.tracer.LossTimerExpired != nil {
|
||||
h.tracer.LossTimerExpired(logging.TimerTypeACK, encLevel)
|
||||
}
|
||||
// Early retransmit or time loss detection
|
||||
return h.detectLostPackets(time.Now(), encLevel)
|
||||
}
|
||||
|
||||
// PTO
|
||||
// When all outstanding are acknowledged, the alarm is canceled in
|
||||
// setLossDetectionTimer. This doesn't reset the timer in the session though.
|
||||
// When OnAlarm is called, we therefore need to make sure that there are
|
||||
// actually packets outstanding.
|
||||
if h.bytesInFlight == 0 && !h.peerCompletedAddressValidation {
|
||||
h.ptoCount++
|
||||
h.numProbesToSend++
|
||||
if h.initialPackets != nil {
|
||||
h.ptoMode = SendPTOInitial
|
||||
} else if h.handshakePackets != nil {
|
||||
h.ptoMode = SendPTOHandshake
|
||||
} else {
|
||||
return errors.New("sentPacketHandler BUG: PTO fired, but bytes_in_flight is 0 and Initial and Handshake already dropped")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
_, encLevel, ok := h.getPTOTimeAndSpace()
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
if ps := h.getPacketNumberSpace(encLevel); !ps.history.HasOutstandingPackets() && !h.peerCompletedAddressValidation {
|
||||
return nil
|
||||
}
|
||||
h.ptoCount++
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("Loss detection alarm for %s fired in PTO mode. PTO count: %d", encLevel, h.ptoCount)
|
||||
}
|
||||
if h.tracer != nil {
|
||||
if h.tracer.LossTimerExpired != nil {
|
||||
h.tracer.LossTimerExpired(logging.TimerTypePTO, encLevel)
|
||||
}
|
||||
if h.tracer.UpdatedPTOCount != nil {
|
||||
h.tracer.UpdatedPTOCount(h.ptoCount)
|
||||
}
|
||||
}
|
||||
h.numProbesToSend += 2
|
||||
//nolint:exhaustive // We never arm a PTO timer for 0-RTT packets.
|
||||
switch encLevel {
|
||||
case protocol.EncryptionInitial:
|
||||
h.ptoMode = SendPTOInitial
|
||||
case protocol.EncryptionHandshake:
|
||||
h.ptoMode = SendPTOHandshake
|
||||
case protocol.Encryption1RTT:
|
||||
// skip a packet number in order to elicit an immediate ACK
|
||||
pn := h.PopPacketNumber(protocol.Encryption1RTT)
|
||||
h.getPacketNumberSpace(protocol.Encryption1RTT).history.SkippedPacket(pn)
|
||||
h.ptoMode = SendPTOAppData
|
||||
default:
|
||||
return fmt.Errorf("PTO timer in unexpected encryption level: %s", encLevel)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) GetLossDetectionTimeout() time.Time {
|
||||
return h.alarm
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) ECNMode(isShortHeaderPacket bool) protocol.ECN {
|
||||
if !h.enableECN {
|
||||
return protocol.ECNUnsupported
|
||||
}
|
||||
if !isShortHeaderPacket {
|
||||
return protocol.ECNNon
|
||||
}
|
||||
return h.ecnTracker.Mode()
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) PeekPacketNumber(encLevel protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen) {
|
||||
pnSpace := h.getPacketNumberSpace(encLevel)
|
||||
pn := pnSpace.pns.Peek()
|
||||
// See section 17.1 of RFC 9000.
|
||||
return pn, protocol.GetPacketNumberLengthForHeader(pn, pnSpace.largestAcked)
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) PopPacketNumber(encLevel protocol.EncryptionLevel) protocol.PacketNumber {
|
||||
pnSpace := h.getPacketNumberSpace(encLevel)
|
||||
skipped, pn := pnSpace.pns.Pop()
|
||||
if skipped {
|
||||
skippedPN := pn - 1
|
||||
pnSpace.history.SkippedPacket(skippedPN)
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("Skipping packet number %d", skippedPN)
|
||||
}
|
||||
}
|
||||
return pn
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) SendMode(now time.Time) SendMode {
|
||||
numTrackedPackets := h.appDataPackets.history.Len()
|
||||
if h.initialPackets != nil {
|
||||
numTrackedPackets += h.initialPackets.history.Len()
|
||||
}
|
||||
if h.handshakePackets != nil {
|
||||
numTrackedPackets += h.handshakePackets.history.Len()
|
||||
}
|
||||
|
||||
if h.isAmplificationLimited() {
|
||||
h.logger.Debugf("Amplification window limited. Received %d bytes, already sent out %d bytes", h.bytesReceived, h.bytesSent)
|
||||
return SendNone
|
||||
}
|
||||
// Don't send any packets if we're keeping track of the maximum number of packets.
|
||||
// Note that since MaxOutstandingSentPackets is smaller than MaxTrackedSentPackets,
|
||||
// we will stop sending out new data when reaching MaxOutstandingSentPackets,
|
||||
// but still allow sending of retransmissions and ACKs.
|
||||
if numTrackedPackets >= protocol.MaxTrackedSentPackets {
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("Limited by the number of tracked packets: tracking %d packets, maximum %d", numTrackedPackets, protocol.MaxTrackedSentPackets)
|
||||
}
|
||||
return SendNone
|
||||
}
|
||||
if h.numProbesToSend > 0 {
|
||||
return h.ptoMode
|
||||
}
|
||||
// Only send ACKs if we're congestion limited.
|
||||
if !h.congestion.CanSend(h.bytesInFlight) {
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("Congestion limited: bytes in flight %d, window %d", h.bytesInFlight, h.congestion.GetCongestionWindow())
|
||||
}
|
||||
return SendAck
|
||||
}
|
||||
if numTrackedPackets >= protocol.MaxOutstandingSentPackets {
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("Max outstanding limited: tracking %d packets, maximum: %d", numTrackedPackets, protocol.MaxOutstandingSentPackets)
|
||||
}
|
||||
return SendAck
|
||||
}
|
||||
if !h.congestion.HasPacingBudget(now) {
|
||||
return SendPacingLimited
|
||||
}
|
||||
return SendAny
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) TimeUntilSend() time.Time {
|
||||
return h.congestion.TimeUntilSend(h.bytesInFlight)
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) SetMaxDatagramSize(s protocol.ByteCount) {
|
||||
h.congestion.SetMaxDatagramSize(s)
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) isAmplificationLimited() bool {
|
||||
if h.peerAddressValidated {
|
||||
return false
|
||||
}
|
||||
return h.bytesSent >= amplificationFactor*h.bytesReceived
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) QueueProbePacket(encLevel protocol.EncryptionLevel) bool {
|
||||
pnSpace := h.getPacketNumberSpace(encLevel)
|
||||
p := pnSpace.history.FirstOutstanding()
|
||||
if p == nil {
|
||||
return false
|
||||
}
|
||||
h.queueFramesForRetransmission(p)
|
||||
// TODO: don't declare the packet lost here.
|
||||
// Keep track of acknowledged frames instead.
|
||||
h.removeFromBytesInFlight(p)
|
||||
pnSpace.history.DeclareLost(p.PacketNumber)
|
||||
return true
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) queueFramesForRetransmission(p *packet) {
|
||||
if len(p.Frames) == 0 && len(p.StreamFrames) == 0 {
|
||||
panic("no frames")
|
||||
}
|
||||
for _, f := range p.Frames {
|
||||
if f.Handler != nil {
|
||||
f.Handler.OnLost(f.Frame)
|
||||
}
|
||||
}
|
||||
for _, f := range p.StreamFrames {
|
||||
if f.Handler != nil {
|
||||
f.Handler.OnLost(f.Frame)
|
||||
}
|
||||
}
|
||||
p.StreamFrames = nil
|
||||
p.Frames = nil
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) ResetForRetry(now time.Time) error {
|
||||
h.bytesInFlight = 0
|
||||
var firstPacketSendTime time.Time
|
||||
h.initialPackets.history.Iterate(func(p *packet) (bool, error) {
|
||||
if firstPacketSendTime.IsZero() {
|
||||
firstPacketSendTime = p.SendTime
|
||||
}
|
||||
if p.declaredLost || p.skippedPacket {
|
||||
return true, nil
|
||||
}
|
||||
h.queueFramesForRetransmission(p)
|
||||
return true, nil
|
||||
})
|
||||
// All application data packets sent at this point are 0-RTT packets.
|
||||
// In the case of a Retry, we can assume that the server dropped all of them.
|
||||
h.appDataPackets.history.Iterate(func(p *packet) (bool, error) {
|
||||
if !p.declaredLost && !p.skippedPacket {
|
||||
h.queueFramesForRetransmission(p)
|
||||
}
|
||||
return true, nil
|
||||
})
|
||||
|
||||
// Only use the Retry to estimate the RTT if we didn't send any retransmission for the Initial.
|
||||
// Otherwise, we don't know which Initial the Retry was sent in response to.
|
||||
if h.ptoCount == 0 {
|
||||
// Don't set the RTT to a value lower than 5ms here.
|
||||
h.rttStats.UpdateRTT(utils.Max(minRTTAfterRetry, now.Sub(firstPacketSendTime)), 0, now)
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("\tupdated RTT: %s (σ: %s)", h.rttStats.SmoothedRTT(), h.rttStats.MeanDeviation())
|
||||
}
|
||||
if h.tracer != nil && h.tracer.UpdatedMetrics != nil {
|
||||
h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight())
|
||||
}
|
||||
}
|
||||
h.initialPackets = newPacketNumberSpace(h.initialPackets.pns.Peek(), false)
|
||||
h.appDataPackets = newPacketNumberSpace(h.appDataPackets.pns.Peek(), true)
|
||||
oldAlarm := h.alarm
|
||||
h.alarm = time.Time{}
|
||||
if h.tracer != nil {
|
||||
if h.tracer.UpdatedPTOCount != nil {
|
||||
h.tracer.UpdatedPTOCount(0)
|
||||
}
|
||||
if !oldAlarm.IsZero() && h.tracer.LossTimerCanceled != nil {
|
||||
h.tracer.LossTimerCanceled()
|
||||
}
|
||||
}
|
||||
h.ptoCount = 0
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *sentPacketHandler) SetHandshakeConfirmed() {
|
||||
if h.initialPackets != nil {
|
||||
panic("didn't drop initial correctly")
|
||||
}
|
||||
if h.handshakePackets != nil {
|
||||
panic("didn't drop handshake correctly")
|
||||
}
|
||||
h.handshakeConfirmed = true
|
||||
// We don't send PTOs for application data packets before the handshake completes.
|
||||
// Make sure the timer is armed now, if necessary.
|
||||
h.setLossDetectionTimer()
|
||||
}
|
||||
177
vendor/github.com/quic-go/quic-go/internal/ackhandler/sent_packet_history.go
generated
vendored
Normal file
177
vendor/github.com/quic-go/quic-go/internal/ackhandler/sent_packet_history.go
generated
vendored
Normal file
@@ -0,0 +1,177 @@
|
||||
package ackhandler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
type sentPacketHistory struct {
|
||||
packets []*packet
|
||||
|
||||
numOutstanding int
|
||||
|
||||
highestPacketNumber protocol.PacketNumber
|
||||
}
|
||||
|
||||
func newSentPacketHistory() *sentPacketHistory {
|
||||
return &sentPacketHistory{
|
||||
packets: make([]*packet, 0, 32),
|
||||
highestPacketNumber: protocol.InvalidPacketNumber,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *sentPacketHistory) checkSequentialPacketNumberUse(pn protocol.PacketNumber) {
|
||||
if h.highestPacketNumber != protocol.InvalidPacketNumber {
|
||||
if pn != h.highestPacketNumber+1 {
|
||||
panic("non-sequential packet number use")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *sentPacketHistory) SkippedPacket(pn protocol.PacketNumber) {
|
||||
h.checkSequentialPacketNumberUse(pn)
|
||||
h.highestPacketNumber = pn
|
||||
h.packets = append(h.packets, &packet{
|
||||
PacketNumber: pn,
|
||||
skippedPacket: true,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *sentPacketHistory) SentNonAckElicitingPacket(pn protocol.PacketNumber) {
|
||||
h.checkSequentialPacketNumberUse(pn)
|
||||
h.highestPacketNumber = pn
|
||||
if len(h.packets) > 0 {
|
||||
h.packets = append(h.packets, nil)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *sentPacketHistory) SentAckElicitingPacket(p *packet) {
|
||||
h.checkSequentialPacketNumberUse(p.PacketNumber)
|
||||
h.highestPacketNumber = p.PacketNumber
|
||||
h.packets = append(h.packets, p)
|
||||
if p.outstanding() {
|
||||
h.numOutstanding++
|
||||
}
|
||||
}
|
||||
|
||||
// Iterate iterates through all packets.
|
||||
func (h *sentPacketHistory) Iterate(cb func(*packet) (cont bool, err error)) error {
|
||||
for _, p := range h.packets {
|
||||
if p == nil {
|
||||
continue
|
||||
}
|
||||
cont, err := cb(p)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !cont {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// FirstOutstanding returns the first outstanding packet.
|
||||
func (h *sentPacketHistory) FirstOutstanding() *packet {
|
||||
if !h.HasOutstandingPackets() {
|
||||
return nil
|
||||
}
|
||||
for _, p := range h.packets {
|
||||
if p != nil && p.outstanding() {
|
||||
return p
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *sentPacketHistory) Len() int {
|
||||
return len(h.packets)
|
||||
}
|
||||
|
||||
func (h *sentPacketHistory) Remove(pn protocol.PacketNumber) error {
|
||||
idx, ok := h.getIndex(pn)
|
||||
if !ok {
|
||||
return fmt.Errorf("packet %d not found in sent packet history", pn)
|
||||
}
|
||||
p := h.packets[idx]
|
||||
if p.outstanding() {
|
||||
h.numOutstanding--
|
||||
if h.numOutstanding < 0 {
|
||||
panic("negative number of outstanding packets")
|
||||
}
|
||||
}
|
||||
h.packets[idx] = nil
|
||||
// clean up all skipped packets directly before this packet number
|
||||
for idx > 0 {
|
||||
idx--
|
||||
p := h.packets[idx]
|
||||
if p == nil || !p.skippedPacket {
|
||||
break
|
||||
}
|
||||
h.packets[idx] = nil
|
||||
}
|
||||
if idx == 0 {
|
||||
h.cleanupStart()
|
||||
}
|
||||
if len(h.packets) > 0 && h.packets[0] == nil {
|
||||
panic("remove failed")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getIndex gets the index of packet p in the packets slice.
|
||||
func (h *sentPacketHistory) getIndex(p protocol.PacketNumber) (int, bool) {
|
||||
if len(h.packets) == 0 {
|
||||
return 0, false
|
||||
}
|
||||
first := h.packets[0].PacketNumber
|
||||
if p < first {
|
||||
return 0, false
|
||||
}
|
||||
index := int(p - first)
|
||||
if index > len(h.packets)-1 {
|
||||
return 0, false
|
||||
}
|
||||
return index, true
|
||||
}
|
||||
|
||||
func (h *sentPacketHistory) HasOutstandingPackets() bool {
|
||||
return h.numOutstanding > 0
|
||||
}
|
||||
|
||||
// delete all nil entries at the beginning of the packets slice
|
||||
func (h *sentPacketHistory) cleanupStart() {
|
||||
for i, p := range h.packets {
|
||||
if p != nil {
|
||||
h.packets = h.packets[i:]
|
||||
return
|
||||
}
|
||||
}
|
||||
h.packets = h.packets[:0]
|
||||
}
|
||||
|
||||
func (h *sentPacketHistory) LowestPacketNumber() protocol.PacketNumber {
|
||||
if len(h.packets) == 0 {
|
||||
return protocol.InvalidPacketNumber
|
||||
}
|
||||
return h.packets[0].PacketNumber
|
||||
}
|
||||
|
||||
func (h *sentPacketHistory) DeclareLost(pn protocol.PacketNumber) {
|
||||
idx, ok := h.getIndex(pn)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
p := h.packets[idx]
|
||||
if p.outstanding() {
|
||||
h.numOutstanding--
|
||||
if h.numOutstanding < 0 {
|
||||
panic("negative number of outstanding packets")
|
||||
}
|
||||
}
|
||||
h.packets[idx] = nil
|
||||
if idx == 0 {
|
||||
h.cleanupStart()
|
||||
}
|
||||
}
|
||||
25
vendor/github.com/quic-go/quic-go/internal/congestion/bandwidth.go
generated
vendored
Normal file
25
vendor/github.com/quic-go/quic-go/internal/congestion/bandwidth.go
generated
vendored
Normal file
@@ -0,0 +1,25 @@
|
||||
package congestion
|
||||
|
||||
import (
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
// Bandwidth of a connection
|
||||
type Bandwidth uint64
|
||||
|
||||
const infBandwidth Bandwidth = math.MaxUint64
|
||||
|
||||
const (
|
||||
// BitsPerSecond is 1 bit per second
|
||||
BitsPerSecond Bandwidth = 1
|
||||
// BytesPerSecond is 1 byte per second
|
||||
BytesPerSecond = 8 * BitsPerSecond
|
||||
)
|
||||
|
||||
// BandwidthFromDelta calculates the bandwidth from a number of bytes and a time delta
|
||||
func BandwidthFromDelta(bytes protocol.ByteCount, delta time.Duration) Bandwidth {
|
||||
return Bandwidth(bytes) * Bandwidth(time.Second) / Bandwidth(delta) * BytesPerSecond
|
||||
}
|
||||
18
vendor/github.com/quic-go/quic-go/internal/congestion/clock.go
generated
vendored
Normal file
18
vendor/github.com/quic-go/quic-go/internal/congestion/clock.go
generated
vendored
Normal file
@@ -0,0 +1,18 @@
|
||||
package congestion
|
||||
|
||||
import "time"
|
||||
|
||||
// A Clock returns the current time
|
||||
type Clock interface {
|
||||
Now() time.Time
|
||||
}
|
||||
|
||||
// DefaultClock implements the Clock interface using the Go stdlib clock.
|
||||
type DefaultClock struct{}
|
||||
|
||||
var _ Clock = DefaultClock{}
|
||||
|
||||
// Now gets the current time
|
||||
func (DefaultClock) Now() time.Time {
|
||||
return time.Now()
|
||||
}
|
||||
214
vendor/github.com/quic-go/quic-go/internal/congestion/cubic.go
generated
vendored
Normal file
214
vendor/github.com/quic-go/quic-go/internal/congestion/cubic.go
generated
vendored
Normal file
@@ -0,0 +1,214 @@
|
||||
package congestion
|
||||
|
||||
import (
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
// This cubic implementation is based on the one found in Chromiums's QUIC
|
||||
// implementation, in the files net/quic/congestion_control/cubic.{hh,cc}.
|
||||
|
||||
// Constants based on TCP defaults.
|
||||
// The following constants are in 2^10 fractions of a second instead of ms to
|
||||
// allow a 10 shift right to divide.
|
||||
|
||||
// 1024*1024^3 (first 1024 is from 0.100^3)
|
||||
// where 0.100 is 100 ms which is the scaling round trip time.
|
||||
const (
|
||||
cubeScale = 40
|
||||
cubeCongestionWindowScale = 410
|
||||
cubeFactor protocol.ByteCount = 1 << cubeScale / cubeCongestionWindowScale / maxDatagramSize
|
||||
// TODO: when re-enabling cubic, make sure to use the actual packet size here
|
||||
maxDatagramSize = protocol.ByteCount(protocol.InitialPacketSizeIPv4)
|
||||
)
|
||||
|
||||
const defaultNumConnections = 1
|
||||
|
||||
// Default Cubic backoff factor
|
||||
const beta float32 = 0.7
|
||||
|
||||
// Additional backoff factor when loss occurs in the concave part of the Cubic
|
||||
// curve. This additional backoff factor is expected to give up bandwidth to
|
||||
// new concurrent flows and speed up convergence.
|
||||
const betaLastMax float32 = 0.85
|
||||
|
||||
// Cubic implements the cubic algorithm from TCP
|
||||
type Cubic struct {
|
||||
clock Clock
|
||||
|
||||
// Number of connections to simulate.
|
||||
numConnections int
|
||||
|
||||
// Time when this cycle started, after last loss event.
|
||||
epoch time.Time
|
||||
|
||||
// Max congestion window used just before last loss event.
|
||||
// Note: to improve fairness to other streams an additional back off is
|
||||
// applied to this value if the new value is below our latest value.
|
||||
lastMaxCongestionWindow protocol.ByteCount
|
||||
|
||||
// Number of acked bytes since the cycle started (epoch).
|
||||
ackedBytesCount protocol.ByteCount
|
||||
|
||||
// TCP Reno equivalent congestion window in packets.
|
||||
estimatedTCPcongestionWindow protocol.ByteCount
|
||||
|
||||
// Origin point of cubic function.
|
||||
originPointCongestionWindow protocol.ByteCount
|
||||
|
||||
// Time to origin point of cubic function in 2^10 fractions of a second.
|
||||
timeToOriginPoint uint32
|
||||
|
||||
// Last congestion window in packets computed by cubic function.
|
||||
lastTargetCongestionWindow protocol.ByteCount
|
||||
}
|
||||
|
||||
// NewCubic returns a new Cubic instance
|
||||
func NewCubic(clock Clock) *Cubic {
|
||||
c := &Cubic{
|
||||
clock: clock,
|
||||
numConnections: defaultNumConnections,
|
||||
}
|
||||
c.Reset()
|
||||
return c
|
||||
}
|
||||
|
||||
// Reset is called after a timeout to reset the cubic state
|
||||
func (c *Cubic) Reset() {
|
||||
c.epoch = time.Time{}
|
||||
c.lastMaxCongestionWindow = 0
|
||||
c.ackedBytesCount = 0
|
||||
c.estimatedTCPcongestionWindow = 0
|
||||
c.originPointCongestionWindow = 0
|
||||
c.timeToOriginPoint = 0
|
||||
c.lastTargetCongestionWindow = 0
|
||||
}
|
||||
|
||||
func (c *Cubic) alpha() float32 {
|
||||
// TCPFriendly alpha is described in Section 3.3 of the CUBIC paper. Note that
|
||||
// beta here is a cwnd multiplier, and is equal to 1-beta from the paper.
|
||||
// We derive the equivalent alpha for an N-connection emulation as:
|
||||
b := c.beta()
|
||||
return 3 * float32(c.numConnections) * float32(c.numConnections) * (1 - b) / (1 + b)
|
||||
}
|
||||
|
||||
func (c *Cubic) beta() float32 {
|
||||
// kNConnectionBeta is the backoff factor after loss for our N-connection
|
||||
// emulation, which emulates the effective backoff of an ensemble of N
|
||||
// TCP-Reno connections on a single loss event. The effective multiplier is
|
||||
// computed as:
|
||||
return (float32(c.numConnections) - 1 + beta) / float32(c.numConnections)
|
||||
}
|
||||
|
||||
func (c *Cubic) betaLastMax() float32 {
|
||||
// betaLastMax is the additional backoff factor after loss for our
|
||||
// N-connection emulation, which emulates the additional backoff of
|
||||
// an ensemble of N TCP-Reno connections on a single loss event. The
|
||||
// effective multiplier is computed as:
|
||||
return (float32(c.numConnections) - 1 + betaLastMax) / float32(c.numConnections)
|
||||
}
|
||||
|
||||
// OnApplicationLimited is called on ack arrival when sender is unable to use
|
||||
// the available congestion window. Resets Cubic state during quiescence.
|
||||
func (c *Cubic) OnApplicationLimited() {
|
||||
// When sender is not using the available congestion window, the window does
|
||||
// not grow. But to be RTT-independent, Cubic assumes that the sender has been
|
||||
// using the entire window during the time since the beginning of the current
|
||||
// "epoch" (the end of the last loss recovery period). Since
|
||||
// application-limited periods break this assumption, we reset the epoch when
|
||||
// in such a period. This reset effectively freezes congestion window growth
|
||||
// through application-limited periods and allows Cubic growth to continue
|
||||
// when the entire window is being used.
|
||||
c.epoch = time.Time{}
|
||||
}
|
||||
|
||||
// CongestionWindowAfterPacketLoss computes a new congestion window to use after
|
||||
// a loss event. Returns the new congestion window in packets. The new
|
||||
// congestion window is a multiplicative decrease of our current window.
|
||||
func (c *Cubic) CongestionWindowAfterPacketLoss(currentCongestionWindow protocol.ByteCount) protocol.ByteCount {
|
||||
if currentCongestionWindow+maxDatagramSize < c.lastMaxCongestionWindow {
|
||||
// We never reached the old max, so assume we are competing with another
|
||||
// flow. Use our extra back off factor to allow the other flow to go up.
|
||||
c.lastMaxCongestionWindow = protocol.ByteCount(c.betaLastMax() * float32(currentCongestionWindow))
|
||||
} else {
|
||||
c.lastMaxCongestionWindow = currentCongestionWindow
|
||||
}
|
||||
c.epoch = time.Time{} // Reset time.
|
||||
return protocol.ByteCount(float32(currentCongestionWindow) * c.beta())
|
||||
}
|
||||
|
||||
// CongestionWindowAfterAck computes a new congestion window to use after a received ACK.
|
||||
// Returns the new congestion window in packets. The new congestion window
|
||||
// follows a cubic function that depends on the time passed since last
|
||||
// packet loss.
|
||||
func (c *Cubic) CongestionWindowAfterAck(
|
||||
ackedBytes protocol.ByteCount,
|
||||
currentCongestionWindow protocol.ByteCount,
|
||||
delayMin time.Duration,
|
||||
eventTime time.Time,
|
||||
) protocol.ByteCount {
|
||||
c.ackedBytesCount += ackedBytes
|
||||
|
||||
if c.epoch.IsZero() {
|
||||
// First ACK after a loss event.
|
||||
c.epoch = eventTime // Start of epoch.
|
||||
c.ackedBytesCount = ackedBytes // Reset count.
|
||||
// Reset estimated_tcp_congestion_window_ to be in sync with cubic.
|
||||
c.estimatedTCPcongestionWindow = currentCongestionWindow
|
||||
if c.lastMaxCongestionWindow <= currentCongestionWindow {
|
||||
c.timeToOriginPoint = 0
|
||||
c.originPointCongestionWindow = currentCongestionWindow
|
||||
} else {
|
||||
c.timeToOriginPoint = uint32(math.Cbrt(float64(cubeFactor * (c.lastMaxCongestionWindow - currentCongestionWindow))))
|
||||
c.originPointCongestionWindow = c.lastMaxCongestionWindow
|
||||
}
|
||||
}
|
||||
|
||||
// Change the time unit from microseconds to 2^10 fractions per second. Take
|
||||
// the round trip time in account. This is done to allow us to use shift as a
|
||||
// divide operator.
|
||||
elapsedTime := int64(eventTime.Add(delayMin).Sub(c.epoch)/time.Microsecond) << 10 / (1000 * 1000)
|
||||
|
||||
// Right-shifts of negative, signed numbers have implementation-dependent
|
||||
// behavior, so force the offset to be positive, as is done in the kernel.
|
||||
offset := int64(c.timeToOriginPoint) - elapsedTime
|
||||
if offset < 0 {
|
||||
offset = -offset
|
||||
}
|
||||
|
||||
deltaCongestionWindow := protocol.ByteCount(cubeCongestionWindowScale*offset*offset*offset) * maxDatagramSize >> cubeScale
|
||||
var targetCongestionWindow protocol.ByteCount
|
||||
if elapsedTime > int64(c.timeToOriginPoint) {
|
||||
targetCongestionWindow = c.originPointCongestionWindow + deltaCongestionWindow
|
||||
} else {
|
||||
targetCongestionWindow = c.originPointCongestionWindow - deltaCongestionWindow
|
||||
}
|
||||
// Limit the CWND increase to half the acked bytes.
|
||||
targetCongestionWindow = utils.Min(targetCongestionWindow, currentCongestionWindow+c.ackedBytesCount/2)
|
||||
|
||||
// Increase the window by approximately Alpha * 1 MSS of bytes every
|
||||
// time we ack an estimated tcp window of bytes. For small
|
||||
// congestion windows (less than 25), the formula below will
|
||||
// increase slightly slower than linearly per estimated tcp window
|
||||
// of bytes.
|
||||
c.estimatedTCPcongestionWindow += protocol.ByteCount(float32(c.ackedBytesCount) * c.alpha() * float32(maxDatagramSize) / float32(c.estimatedTCPcongestionWindow))
|
||||
c.ackedBytesCount = 0
|
||||
|
||||
// We have a new cubic congestion window.
|
||||
c.lastTargetCongestionWindow = targetCongestionWindow
|
||||
|
||||
// Compute target congestion_window based on cubic target and estimated TCP
|
||||
// congestion_window, use highest (fastest).
|
||||
if targetCongestionWindow < c.estimatedTCPcongestionWindow {
|
||||
targetCongestionWindow = c.estimatedTCPcongestionWindow
|
||||
}
|
||||
return targetCongestionWindow
|
||||
}
|
||||
|
||||
// SetNumConnections sets the number of emulated connections
|
||||
func (c *Cubic) SetNumConnections(n int) {
|
||||
c.numConnections = n
|
||||
}
|
||||
316
vendor/github.com/quic-go/quic-go/internal/congestion/cubic_sender.go
generated
vendored
Normal file
316
vendor/github.com/quic-go/quic-go/internal/congestion/cubic_sender.go
generated
vendored
Normal file
@@ -0,0 +1,316 @@
|
||||
package congestion
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
"github.com/quic-go/quic-go/logging"
|
||||
)
|
||||
|
||||
const (
|
||||
// maxDatagramSize is the default maximum packet size used in the Linux TCP implementation.
|
||||
// Used in QUIC for congestion window computations in bytes.
|
||||
initialMaxDatagramSize = protocol.ByteCount(protocol.InitialPacketSizeIPv4)
|
||||
maxBurstPackets = 3
|
||||
renoBeta = 0.7 // Reno backoff factor.
|
||||
minCongestionWindowPackets = 2
|
||||
initialCongestionWindow = 32
|
||||
)
|
||||
|
||||
type cubicSender struct {
|
||||
hybridSlowStart HybridSlowStart
|
||||
rttStats *utils.RTTStats
|
||||
cubic *Cubic
|
||||
pacer *pacer
|
||||
clock Clock
|
||||
|
||||
reno bool
|
||||
|
||||
// Track the largest packet that has been sent.
|
||||
largestSentPacketNumber protocol.PacketNumber
|
||||
|
||||
// Track the largest packet that has been acked.
|
||||
largestAckedPacketNumber protocol.PacketNumber
|
||||
|
||||
// Track the largest packet number outstanding when a CWND cutback occurs.
|
||||
largestSentAtLastCutback protocol.PacketNumber
|
||||
|
||||
// Whether the last loss event caused us to exit slowstart.
|
||||
// Used for stats collection of slowstartPacketsLost
|
||||
lastCutbackExitedSlowstart bool
|
||||
|
||||
// Congestion window in bytes.
|
||||
congestionWindow protocol.ByteCount
|
||||
|
||||
// Slow start congestion window in bytes, aka ssthresh.
|
||||
slowStartThreshold protocol.ByteCount
|
||||
|
||||
// ACK counter for the Reno implementation.
|
||||
numAckedPackets uint64
|
||||
|
||||
initialCongestionWindow protocol.ByteCount
|
||||
initialMaxCongestionWindow protocol.ByteCount
|
||||
|
||||
maxDatagramSize protocol.ByteCount
|
||||
|
||||
lastState logging.CongestionState
|
||||
tracer *logging.ConnectionTracer
|
||||
}
|
||||
|
||||
var (
|
||||
_ SendAlgorithm = &cubicSender{}
|
||||
_ SendAlgorithmWithDebugInfos = &cubicSender{}
|
||||
)
|
||||
|
||||
// NewCubicSender makes a new cubic sender
|
||||
func NewCubicSender(
|
||||
clock Clock,
|
||||
rttStats *utils.RTTStats,
|
||||
initialMaxDatagramSize protocol.ByteCount,
|
||||
reno bool,
|
||||
tracer *logging.ConnectionTracer,
|
||||
) *cubicSender {
|
||||
return newCubicSender(
|
||||
clock,
|
||||
rttStats,
|
||||
reno,
|
||||
initialMaxDatagramSize,
|
||||
initialCongestionWindow*initialMaxDatagramSize,
|
||||
protocol.MaxCongestionWindowPackets*initialMaxDatagramSize,
|
||||
tracer,
|
||||
)
|
||||
}
|
||||
|
||||
func newCubicSender(
|
||||
clock Clock,
|
||||
rttStats *utils.RTTStats,
|
||||
reno bool,
|
||||
initialMaxDatagramSize,
|
||||
initialCongestionWindow,
|
||||
initialMaxCongestionWindow protocol.ByteCount,
|
||||
tracer *logging.ConnectionTracer,
|
||||
) *cubicSender {
|
||||
c := &cubicSender{
|
||||
rttStats: rttStats,
|
||||
largestSentPacketNumber: protocol.InvalidPacketNumber,
|
||||
largestAckedPacketNumber: protocol.InvalidPacketNumber,
|
||||
largestSentAtLastCutback: protocol.InvalidPacketNumber,
|
||||
initialCongestionWindow: initialCongestionWindow,
|
||||
initialMaxCongestionWindow: initialMaxCongestionWindow,
|
||||
congestionWindow: initialCongestionWindow,
|
||||
slowStartThreshold: protocol.MaxByteCount,
|
||||
cubic: NewCubic(clock),
|
||||
clock: clock,
|
||||
reno: reno,
|
||||
tracer: tracer,
|
||||
maxDatagramSize: initialMaxDatagramSize,
|
||||
}
|
||||
c.pacer = newPacer(c.BandwidthEstimate)
|
||||
if c.tracer != nil && c.tracer.UpdatedCongestionState != nil {
|
||||
c.lastState = logging.CongestionStateSlowStart
|
||||
c.tracer.UpdatedCongestionState(logging.CongestionStateSlowStart)
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// TimeUntilSend returns when the next packet should be sent.
|
||||
func (c *cubicSender) TimeUntilSend(_ protocol.ByteCount) time.Time {
|
||||
return c.pacer.TimeUntilSend()
|
||||
}
|
||||
|
||||
func (c *cubicSender) HasPacingBudget(now time.Time) bool {
|
||||
return c.pacer.Budget(now) >= c.maxDatagramSize
|
||||
}
|
||||
|
||||
func (c *cubicSender) maxCongestionWindow() protocol.ByteCount {
|
||||
return c.maxDatagramSize * protocol.MaxCongestionWindowPackets
|
||||
}
|
||||
|
||||
func (c *cubicSender) minCongestionWindow() protocol.ByteCount {
|
||||
return c.maxDatagramSize * minCongestionWindowPackets
|
||||
}
|
||||
|
||||
func (c *cubicSender) OnPacketSent(
|
||||
sentTime time.Time,
|
||||
_ protocol.ByteCount,
|
||||
packetNumber protocol.PacketNumber,
|
||||
bytes protocol.ByteCount,
|
||||
isRetransmittable bool,
|
||||
) {
|
||||
c.pacer.SentPacket(sentTime, bytes)
|
||||
if !isRetransmittable {
|
||||
return
|
||||
}
|
||||
c.largestSentPacketNumber = packetNumber
|
||||
c.hybridSlowStart.OnPacketSent(packetNumber)
|
||||
}
|
||||
|
||||
func (c *cubicSender) CanSend(bytesInFlight protocol.ByteCount) bool {
|
||||
return bytesInFlight < c.GetCongestionWindow()
|
||||
}
|
||||
|
||||
func (c *cubicSender) InRecovery() bool {
|
||||
return c.largestAckedPacketNumber != protocol.InvalidPacketNumber && c.largestAckedPacketNumber <= c.largestSentAtLastCutback
|
||||
}
|
||||
|
||||
func (c *cubicSender) InSlowStart() bool {
|
||||
return c.GetCongestionWindow() < c.slowStartThreshold
|
||||
}
|
||||
|
||||
func (c *cubicSender) GetCongestionWindow() protocol.ByteCount {
|
||||
return c.congestionWindow
|
||||
}
|
||||
|
||||
func (c *cubicSender) MaybeExitSlowStart() {
|
||||
if c.InSlowStart() &&
|
||||
c.hybridSlowStart.ShouldExitSlowStart(c.rttStats.LatestRTT(), c.rttStats.MinRTT(), c.GetCongestionWindow()/c.maxDatagramSize) {
|
||||
// exit slow start
|
||||
c.slowStartThreshold = c.congestionWindow
|
||||
c.maybeTraceStateChange(logging.CongestionStateCongestionAvoidance)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *cubicSender) OnPacketAcked(
|
||||
ackedPacketNumber protocol.PacketNumber,
|
||||
ackedBytes protocol.ByteCount,
|
||||
priorInFlight protocol.ByteCount,
|
||||
eventTime time.Time,
|
||||
) {
|
||||
c.largestAckedPacketNumber = utils.Max(ackedPacketNumber, c.largestAckedPacketNumber)
|
||||
if c.InRecovery() {
|
||||
return
|
||||
}
|
||||
c.maybeIncreaseCwnd(ackedPacketNumber, ackedBytes, priorInFlight, eventTime)
|
||||
if c.InSlowStart() {
|
||||
c.hybridSlowStart.OnPacketAcked(ackedPacketNumber)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *cubicSender) OnCongestionEvent(packetNumber protocol.PacketNumber, lostBytes, priorInFlight protocol.ByteCount) {
|
||||
// TCP NewReno (RFC6582) says that once a loss occurs, any losses in packets
|
||||
// already sent should be treated as a single loss event, since it's expected.
|
||||
if packetNumber <= c.largestSentAtLastCutback {
|
||||
return
|
||||
}
|
||||
c.lastCutbackExitedSlowstart = c.InSlowStart()
|
||||
c.maybeTraceStateChange(logging.CongestionStateRecovery)
|
||||
|
||||
if c.reno {
|
||||
c.congestionWindow = protocol.ByteCount(float64(c.congestionWindow) * renoBeta)
|
||||
} else {
|
||||
c.congestionWindow = c.cubic.CongestionWindowAfterPacketLoss(c.congestionWindow)
|
||||
}
|
||||
if minCwnd := c.minCongestionWindow(); c.congestionWindow < minCwnd {
|
||||
c.congestionWindow = minCwnd
|
||||
}
|
||||
c.slowStartThreshold = c.congestionWindow
|
||||
c.largestSentAtLastCutback = c.largestSentPacketNumber
|
||||
// reset packet count from congestion avoidance mode. We start
|
||||
// counting again when we're out of recovery.
|
||||
c.numAckedPackets = 0
|
||||
}
|
||||
|
||||
// Called when we receive an ack. Normal TCP tracks how many packets one ack
|
||||
// represents, but quic has a separate ack for each packet.
|
||||
func (c *cubicSender) maybeIncreaseCwnd(
|
||||
_ protocol.PacketNumber,
|
||||
ackedBytes protocol.ByteCount,
|
||||
priorInFlight protocol.ByteCount,
|
||||
eventTime time.Time,
|
||||
) {
|
||||
// Do not increase the congestion window unless the sender is close to using
|
||||
// the current window.
|
||||
if !c.isCwndLimited(priorInFlight) {
|
||||
c.cubic.OnApplicationLimited()
|
||||
c.maybeTraceStateChange(logging.CongestionStateApplicationLimited)
|
||||
return
|
||||
}
|
||||
if c.congestionWindow >= c.maxCongestionWindow() {
|
||||
return
|
||||
}
|
||||
if c.InSlowStart() {
|
||||
// TCP slow start, exponential growth, increase by one for each ACK.
|
||||
c.congestionWindow += c.maxDatagramSize
|
||||
c.maybeTraceStateChange(logging.CongestionStateSlowStart)
|
||||
return
|
||||
}
|
||||
// Congestion avoidance
|
||||
c.maybeTraceStateChange(logging.CongestionStateCongestionAvoidance)
|
||||
if c.reno {
|
||||
// Classic Reno congestion avoidance.
|
||||
c.numAckedPackets++
|
||||
if c.numAckedPackets >= uint64(c.congestionWindow/c.maxDatagramSize) {
|
||||
c.congestionWindow += c.maxDatagramSize
|
||||
c.numAckedPackets = 0
|
||||
}
|
||||
} else {
|
||||
c.congestionWindow = utils.Min(c.maxCongestionWindow(), c.cubic.CongestionWindowAfterAck(ackedBytes, c.congestionWindow, c.rttStats.MinRTT(), eventTime))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *cubicSender) isCwndLimited(bytesInFlight protocol.ByteCount) bool {
|
||||
congestionWindow := c.GetCongestionWindow()
|
||||
if bytesInFlight >= congestionWindow {
|
||||
return true
|
||||
}
|
||||
availableBytes := congestionWindow - bytesInFlight
|
||||
slowStartLimited := c.InSlowStart() && bytesInFlight > congestionWindow/2
|
||||
return slowStartLimited || availableBytes <= maxBurstPackets*c.maxDatagramSize
|
||||
}
|
||||
|
||||
// BandwidthEstimate returns the current bandwidth estimate
|
||||
func (c *cubicSender) BandwidthEstimate() Bandwidth {
|
||||
srtt := c.rttStats.SmoothedRTT()
|
||||
if srtt == 0 {
|
||||
// If we haven't measured an rtt, the bandwidth estimate is unknown.
|
||||
return infBandwidth
|
||||
}
|
||||
return BandwidthFromDelta(c.GetCongestionWindow(), srtt)
|
||||
}
|
||||
|
||||
// OnRetransmissionTimeout is called on an retransmission timeout
|
||||
func (c *cubicSender) OnRetransmissionTimeout(packetsRetransmitted bool) {
|
||||
c.largestSentAtLastCutback = protocol.InvalidPacketNumber
|
||||
if !packetsRetransmitted {
|
||||
return
|
||||
}
|
||||
c.hybridSlowStart.Restart()
|
||||
c.cubic.Reset()
|
||||
c.slowStartThreshold = c.congestionWindow / 2
|
||||
c.congestionWindow = c.minCongestionWindow()
|
||||
}
|
||||
|
||||
// OnConnectionMigration is called when the connection is migrated (?)
|
||||
func (c *cubicSender) OnConnectionMigration() {
|
||||
c.hybridSlowStart.Restart()
|
||||
c.largestSentPacketNumber = protocol.InvalidPacketNumber
|
||||
c.largestAckedPacketNumber = protocol.InvalidPacketNumber
|
||||
c.largestSentAtLastCutback = protocol.InvalidPacketNumber
|
||||
c.lastCutbackExitedSlowstart = false
|
||||
c.cubic.Reset()
|
||||
c.numAckedPackets = 0
|
||||
c.congestionWindow = c.initialCongestionWindow
|
||||
c.slowStartThreshold = c.initialMaxCongestionWindow
|
||||
}
|
||||
|
||||
func (c *cubicSender) maybeTraceStateChange(new logging.CongestionState) {
|
||||
if c.tracer == nil || c.tracer.UpdatedCongestionState == nil || new == c.lastState {
|
||||
return
|
||||
}
|
||||
c.tracer.UpdatedCongestionState(new)
|
||||
c.lastState = new
|
||||
}
|
||||
|
||||
func (c *cubicSender) SetMaxDatagramSize(s protocol.ByteCount) {
|
||||
if s < c.maxDatagramSize {
|
||||
panic(fmt.Sprintf("congestion BUG: decreased max datagram size from %d to %d", c.maxDatagramSize, s))
|
||||
}
|
||||
cwndIsMinCwnd := c.congestionWindow == c.minCongestionWindow()
|
||||
c.maxDatagramSize = s
|
||||
if cwndIsMinCwnd {
|
||||
c.congestionWindow = c.minCongestionWindow()
|
||||
}
|
||||
c.pacer.SetMaxDatagramSize(s)
|
||||
}
|
||||
113
vendor/github.com/quic-go/quic-go/internal/congestion/hybrid_slow_start.go
generated
vendored
Normal file
113
vendor/github.com/quic-go/quic-go/internal/congestion/hybrid_slow_start.go
generated
vendored
Normal file
@@ -0,0 +1,113 @@
|
||||
package congestion
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
// Note(pwestin): the magic clamping numbers come from the original code in
|
||||
// tcp_cubic.c.
|
||||
const hybridStartLowWindow = protocol.ByteCount(16)
|
||||
|
||||
// Number of delay samples for detecting the increase of delay.
|
||||
const hybridStartMinSamples = uint32(8)
|
||||
|
||||
// Exit slow start if the min rtt has increased by more than 1/8th.
|
||||
const hybridStartDelayFactorExp = 3 // 2^3 = 8
|
||||
// The original paper specifies 2 and 8ms, but those have changed over time.
|
||||
const (
|
||||
hybridStartDelayMinThresholdUs = int64(4000)
|
||||
hybridStartDelayMaxThresholdUs = int64(16000)
|
||||
)
|
||||
|
||||
// HybridSlowStart implements the TCP hybrid slow start algorithm
|
||||
type HybridSlowStart struct {
|
||||
endPacketNumber protocol.PacketNumber
|
||||
lastSentPacketNumber protocol.PacketNumber
|
||||
started bool
|
||||
currentMinRTT time.Duration
|
||||
rttSampleCount uint32
|
||||
hystartFound bool
|
||||
}
|
||||
|
||||
// StartReceiveRound is called for the start of each receive round (burst) in the slow start phase.
|
||||
func (s *HybridSlowStart) StartReceiveRound(lastSent protocol.PacketNumber) {
|
||||
s.endPacketNumber = lastSent
|
||||
s.currentMinRTT = 0
|
||||
s.rttSampleCount = 0
|
||||
s.started = true
|
||||
}
|
||||
|
||||
// IsEndOfRound returns true if this ack is the last packet number of our current slow start round.
|
||||
func (s *HybridSlowStart) IsEndOfRound(ack protocol.PacketNumber) bool {
|
||||
return s.endPacketNumber < ack
|
||||
}
|
||||
|
||||
// ShouldExitSlowStart should be called on every new ack frame, since a new
|
||||
// RTT measurement can be made then.
|
||||
// rtt: the RTT for this ack packet.
|
||||
// minRTT: is the lowest delay (RTT) we have seen during the session.
|
||||
// congestionWindow: the congestion window in packets.
|
||||
func (s *HybridSlowStart) ShouldExitSlowStart(latestRTT time.Duration, minRTT time.Duration, congestionWindow protocol.ByteCount) bool {
|
||||
if !s.started {
|
||||
// Time to start the hybrid slow start.
|
||||
s.StartReceiveRound(s.lastSentPacketNumber)
|
||||
}
|
||||
if s.hystartFound {
|
||||
return true
|
||||
}
|
||||
// Second detection parameter - delay increase detection.
|
||||
// Compare the minimum delay (s.currentMinRTT) of the current
|
||||
// burst of packets relative to the minimum delay during the session.
|
||||
// Note: we only look at the first few(8) packets in each burst, since we
|
||||
// only want to compare the lowest RTT of the burst relative to previous
|
||||
// bursts.
|
||||
s.rttSampleCount++
|
||||
if s.rttSampleCount <= hybridStartMinSamples {
|
||||
if s.currentMinRTT == 0 || s.currentMinRTT > latestRTT {
|
||||
s.currentMinRTT = latestRTT
|
||||
}
|
||||
}
|
||||
// We only need to check this once per round.
|
||||
if s.rttSampleCount == hybridStartMinSamples {
|
||||
// Divide minRTT by 8 to get a rtt increase threshold for exiting.
|
||||
minRTTincreaseThresholdUs := int64(minRTT / time.Microsecond >> hybridStartDelayFactorExp)
|
||||
// Ensure the rtt threshold is never less than 2ms or more than 16ms.
|
||||
minRTTincreaseThresholdUs = utils.Min(minRTTincreaseThresholdUs, hybridStartDelayMaxThresholdUs)
|
||||
minRTTincreaseThreshold := time.Duration(utils.Max(minRTTincreaseThresholdUs, hybridStartDelayMinThresholdUs)) * time.Microsecond
|
||||
|
||||
if s.currentMinRTT > (minRTT + minRTTincreaseThreshold) {
|
||||
s.hystartFound = true
|
||||
}
|
||||
}
|
||||
// Exit from slow start if the cwnd is greater than 16 and
|
||||
// increasing delay is found.
|
||||
return congestionWindow >= hybridStartLowWindow && s.hystartFound
|
||||
}
|
||||
|
||||
// OnPacketSent is called when a packet was sent
|
||||
func (s *HybridSlowStart) OnPacketSent(packetNumber protocol.PacketNumber) {
|
||||
s.lastSentPacketNumber = packetNumber
|
||||
}
|
||||
|
||||
// OnPacketAcked gets invoked after ShouldExitSlowStart, so it's best to end
|
||||
// the round when the final packet of the burst is received and start it on
|
||||
// the next incoming ack.
|
||||
func (s *HybridSlowStart) OnPacketAcked(ackedPacketNumber protocol.PacketNumber) {
|
||||
if s.IsEndOfRound(ackedPacketNumber) {
|
||||
s.started = false
|
||||
}
|
||||
}
|
||||
|
||||
// Started returns true if started
|
||||
func (s *HybridSlowStart) Started() bool {
|
||||
return s.started
|
||||
}
|
||||
|
||||
// Restart the slow start phase
|
||||
func (s *HybridSlowStart) Restart() {
|
||||
s.started = false
|
||||
s.hystartFound = false
|
||||
}
|
||||
28
vendor/github.com/quic-go/quic-go/internal/congestion/interface.go
generated
vendored
Normal file
28
vendor/github.com/quic-go/quic-go/internal/congestion/interface.go
generated
vendored
Normal file
@@ -0,0 +1,28 @@
|
||||
package congestion
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
// A SendAlgorithm performs congestion control
|
||||
type SendAlgorithm interface {
|
||||
TimeUntilSend(bytesInFlight protocol.ByteCount) time.Time
|
||||
HasPacingBudget(now time.Time) bool
|
||||
OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool)
|
||||
CanSend(bytesInFlight protocol.ByteCount) bool
|
||||
MaybeExitSlowStart()
|
||||
OnPacketAcked(number protocol.PacketNumber, ackedBytes protocol.ByteCount, priorInFlight protocol.ByteCount, eventTime time.Time)
|
||||
OnCongestionEvent(number protocol.PacketNumber, lostBytes protocol.ByteCount, priorInFlight protocol.ByteCount)
|
||||
OnRetransmissionTimeout(packetsRetransmitted bool)
|
||||
SetMaxDatagramSize(protocol.ByteCount)
|
||||
}
|
||||
|
||||
// A SendAlgorithmWithDebugInfos is a SendAlgorithm that exposes some debug infos
|
||||
type SendAlgorithmWithDebugInfos interface {
|
||||
SendAlgorithm
|
||||
InSlowStart() bool
|
||||
InRecovery() bool
|
||||
GetCongestionWindow() protocol.ByteCount
|
||||
}
|
||||
80
vendor/github.com/quic-go/quic-go/internal/congestion/pacer.go
generated
vendored
Normal file
80
vendor/github.com/quic-go/quic-go/internal/congestion/pacer.go
generated
vendored
Normal file
@@ -0,0 +1,80 @@
|
||||
package congestion
|
||||
|
||||
import (
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
const maxBurstSizePackets = 10
|
||||
|
||||
// The pacer implements a token bucket pacing algorithm.
|
||||
type pacer struct {
|
||||
budgetAtLastSent protocol.ByteCount
|
||||
maxDatagramSize protocol.ByteCount
|
||||
lastSentTime time.Time
|
||||
adjustedBandwidth func() uint64 // in bytes/s
|
||||
}
|
||||
|
||||
func newPacer(getBandwidth func() Bandwidth) *pacer {
|
||||
p := &pacer{
|
||||
maxDatagramSize: initialMaxDatagramSize,
|
||||
adjustedBandwidth: func() uint64 {
|
||||
// Bandwidth is in bits/s. We need the value in bytes/s.
|
||||
bw := uint64(getBandwidth() / BytesPerSecond)
|
||||
// Use a slightly higher value than the actual measured bandwidth.
|
||||
// RTT variations then won't result in under-utilization of the congestion window.
|
||||
// Ultimately, this will result in sending packets as acknowledgments are received rather than when timers fire,
|
||||
// provided the congestion window is fully utilized and acknowledgments arrive at regular intervals.
|
||||
return bw * 5 / 4
|
||||
},
|
||||
}
|
||||
p.budgetAtLastSent = p.maxBurstSize()
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *pacer) SentPacket(sendTime time.Time, size protocol.ByteCount) {
|
||||
budget := p.Budget(sendTime)
|
||||
if size > budget {
|
||||
p.budgetAtLastSent = 0
|
||||
} else {
|
||||
p.budgetAtLastSent = budget - size
|
||||
}
|
||||
p.lastSentTime = sendTime
|
||||
}
|
||||
|
||||
func (p *pacer) Budget(now time.Time) protocol.ByteCount {
|
||||
if p.lastSentTime.IsZero() {
|
||||
return p.maxBurstSize()
|
||||
}
|
||||
budget := p.budgetAtLastSent + (protocol.ByteCount(p.adjustedBandwidth())*protocol.ByteCount(now.Sub(p.lastSentTime).Nanoseconds()))/1e9
|
||||
if budget < 0 { // protect against overflows
|
||||
budget = protocol.MaxByteCount
|
||||
}
|
||||
return utils.Min(p.maxBurstSize(), budget)
|
||||
}
|
||||
|
||||
func (p *pacer) maxBurstSize() protocol.ByteCount {
|
||||
return utils.Max(
|
||||
protocol.ByteCount(uint64((protocol.MinPacingDelay+protocol.TimerGranularity).Nanoseconds())*p.adjustedBandwidth())/1e9,
|
||||
maxBurstSizePackets*p.maxDatagramSize,
|
||||
)
|
||||
}
|
||||
|
||||
// TimeUntilSend returns when the next packet should be sent.
|
||||
// It returns the zero value of time.Time if a packet can be sent immediately.
|
||||
func (p *pacer) TimeUntilSend() time.Time {
|
||||
if p.budgetAtLastSent >= p.maxDatagramSize {
|
||||
return time.Time{}
|
||||
}
|
||||
return p.lastSentTime.Add(utils.Max(
|
||||
protocol.MinPacingDelay,
|
||||
time.Duration(math.Ceil(float64(p.maxDatagramSize-p.budgetAtLastSent)*1e9/float64(p.adjustedBandwidth())))*time.Nanosecond,
|
||||
))
|
||||
}
|
||||
|
||||
func (p *pacer) SetMaxDatagramSize(s protocol.ByteCount) {
|
||||
p.maxDatagramSize = s
|
||||
}
|
||||
125
vendor/github.com/quic-go/quic-go/internal/flowcontrol/base_flow_controller.go
generated
vendored
Normal file
125
vendor/github.com/quic-go/quic-go/internal/flowcontrol/base_flow_controller.go
generated
vendored
Normal file
@@ -0,0 +1,125 @@
|
||||
package flowcontrol
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
type baseFlowController struct {
|
||||
// for sending data
|
||||
bytesSent protocol.ByteCount
|
||||
sendWindow protocol.ByteCount
|
||||
lastBlockedAt protocol.ByteCount
|
||||
|
||||
// for receiving data
|
||||
//nolint:structcheck // The mutex is used both by the stream and the connection flow controller
|
||||
mutex sync.Mutex
|
||||
bytesRead protocol.ByteCount
|
||||
highestReceived protocol.ByteCount
|
||||
receiveWindow protocol.ByteCount
|
||||
receiveWindowSize protocol.ByteCount
|
||||
maxReceiveWindowSize protocol.ByteCount
|
||||
|
||||
allowWindowIncrease func(size protocol.ByteCount) bool
|
||||
|
||||
epochStartTime time.Time
|
||||
epochStartOffset protocol.ByteCount
|
||||
rttStats *utils.RTTStats
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
// IsNewlyBlocked says if it is newly blocked by flow control.
|
||||
// For every offset, it only returns true once.
|
||||
// If it is blocked, the offset is returned.
|
||||
func (c *baseFlowController) IsNewlyBlocked() (bool, protocol.ByteCount) {
|
||||
if c.sendWindowSize() != 0 || c.sendWindow == c.lastBlockedAt {
|
||||
return false, 0
|
||||
}
|
||||
c.lastBlockedAt = c.sendWindow
|
||||
return true, c.sendWindow
|
||||
}
|
||||
|
||||
func (c *baseFlowController) AddBytesSent(n protocol.ByteCount) {
|
||||
c.bytesSent += n
|
||||
}
|
||||
|
||||
// UpdateSendWindow is called after receiving a MAX_{STREAM_}DATA frame.
|
||||
func (c *baseFlowController) UpdateSendWindow(offset protocol.ByteCount) {
|
||||
if offset > c.sendWindow {
|
||||
c.sendWindow = offset
|
||||
}
|
||||
}
|
||||
|
||||
func (c *baseFlowController) sendWindowSize() protocol.ByteCount {
|
||||
// this only happens during connection establishment, when data is sent before we receive the peer's transport parameters
|
||||
if c.bytesSent > c.sendWindow {
|
||||
return 0
|
||||
}
|
||||
return c.sendWindow - c.bytesSent
|
||||
}
|
||||
|
||||
// needs to be called with locked mutex
|
||||
func (c *baseFlowController) addBytesRead(n protocol.ByteCount) {
|
||||
// pretend we sent a WindowUpdate when reading the first byte
|
||||
// this way auto-tuning of the window size already works for the first WindowUpdate
|
||||
if c.bytesRead == 0 {
|
||||
c.startNewAutoTuningEpoch(time.Now())
|
||||
}
|
||||
c.bytesRead += n
|
||||
}
|
||||
|
||||
func (c *baseFlowController) hasWindowUpdate() bool {
|
||||
bytesRemaining := c.receiveWindow - c.bytesRead
|
||||
// update the window when more than the threshold was consumed
|
||||
return bytesRemaining <= protocol.ByteCount(float64(c.receiveWindowSize)*(1-protocol.WindowUpdateThreshold))
|
||||
}
|
||||
|
||||
// getWindowUpdate updates the receive window, if necessary
|
||||
// it returns the new offset
|
||||
func (c *baseFlowController) getWindowUpdate() protocol.ByteCount {
|
||||
if !c.hasWindowUpdate() {
|
||||
return 0
|
||||
}
|
||||
|
||||
c.maybeAdjustWindowSize()
|
||||
c.receiveWindow = c.bytesRead + c.receiveWindowSize
|
||||
return c.receiveWindow
|
||||
}
|
||||
|
||||
// maybeAdjustWindowSize increases the receiveWindowSize if we're sending updates too often.
|
||||
// For details about auto-tuning, see https://docs.google.com/document/d/1SExkMmGiz8VYzV3s9E35JQlJ73vhzCekKkDi85F1qCE/edit?usp=sharing.
|
||||
func (c *baseFlowController) maybeAdjustWindowSize() {
|
||||
bytesReadInEpoch := c.bytesRead - c.epochStartOffset
|
||||
// don't do anything if less than half the window has been consumed
|
||||
if bytesReadInEpoch <= c.receiveWindowSize/2 {
|
||||
return
|
||||
}
|
||||
rtt := c.rttStats.SmoothedRTT()
|
||||
if rtt == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
fraction := float64(bytesReadInEpoch) / float64(c.receiveWindowSize)
|
||||
now := time.Now()
|
||||
if now.Sub(c.epochStartTime) < time.Duration(4*fraction*float64(rtt)) {
|
||||
// window is consumed too fast, try to increase the window size
|
||||
newSize := utils.Min(2*c.receiveWindowSize, c.maxReceiveWindowSize)
|
||||
if newSize > c.receiveWindowSize && (c.allowWindowIncrease == nil || c.allowWindowIncrease(newSize-c.receiveWindowSize)) {
|
||||
c.receiveWindowSize = newSize
|
||||
}
|
||||
}
|
||||
c.startNewAutoTuningEpoch(now)
|
||||
}
|
||||
|
||||
func (c *baseFlowController) startNewAutoTuningEpoch(now time.Time) {
|
||||
c.epochStartTime = now
|
||||
c.epochStartOffset = c.bytesRead
|
||||
}
|
||||
|
||||
func (c *baseFlowController) checkFlowControlViolation() bool {
|
||||
return c.highestReceived > c.receiveWindow
|
||||
}
|
||||
112
vendor/github.com/quic-go/quic-go/internal/flowcontrol/connection_flow_controller.go
generated
vendored
Normal file
112
vendor/github.com/quic-go/quic-go/internal/flowcontrol/connection_flow_controller.go
generated
vendored
Normal file
@@ -0,0 +1,112 @@
|
||||
package flowcontrol
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/qerr"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
type connectionFlowController struct {
|
||||
baseFlowController
|
||||
|
||||
queueWindowUpdate func()
|
||||
}
|
||||
|
||||
var _ ConnectionFlowController = &connectionFlowController{}
|
||||
|
||||
// NewConnectionFlowController gets a new flow controller for the connection
|
||||
// It is created before we receive the peer's transport parameters, thus it starts with a sendWindow of 0.
|
||||
func NewConnectionFlowController(
|
||||
receiveWindow protocol.ByteCount,
|
||||
maxReceiveWindow protocol.ByteCount,
|
||||
queueWindowUpdate func(),
|
||||
allowWindowIncrease func(size protocol.ByteCount) bool,
|
||||
rttStats *utils.RTTStats,
|
||||
logger utils.Logger,
|
||||
) ConnectionFlowController {
|
||||
return &connectionFlowController{
|
||||
baseFlowController: baseFlowController{
|
||||
rttStats: rttStats,
|
||||
receiveWindow: receiveWindow,
|
||||
receiveWindowSize: receiveWindow,
|
||||
maxReceiveWindowSize: maxReceiveWindow,
|
||||
allowWindowIncrease: allowWindowIncrease,
|
||||
logger: logger,
|
||||
},
|
||||
queueWindowUpdate: queueWindowUpdate,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *connectionFlowController) SendWindowSize() protocol.ByteCount {
|
||||
return c.baseFlowController.sendWindowSize()
|
||||
}
|
||||
|
||||
// IncrementHighestReceived adds an increment to the highestReceived value
|
||||
func (c *connectionFlowController) IncrementHighestReceived(increment protocol.ByteCount) error {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
c.highestReceived += increment
|
||||
if c.checkFlowControlViolation() {
|
||||
return &qerr.TransportError{
|
||||
ErrorCode: qerr.FlowControlError,
|
||||
ErrorMessage: fmt.Sprintf("received %d bytes for the connection, allowed %d bytes", c.highestReceived, c.receiveWindow),
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *connectionFlowController) AddBytesRead(n protocol.ByteCount) {
|
||||
c.mutex.Lock()
|
||||
c.baseFlowController.addBytesRead(n)
|
||||
shouldQueueWindowUpdate := c.hasWindowUpdate()
|
||||
c.mutex.Unlock()
|
||||
if shouldQueueWindowUpdate {
|
||||
c.queueWindowUpdate()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *connectionFlowController) GetWindowUpdate() protocol.ByteCount {
|
||||
c.mutex.Lock()
|
||||
oldWindowSize := c.receiveWindowSize
|
||||
offset := c.baseFlowController.getWindowUpdate()
|
||||
if oldWindowSize < c.receiveWindowSize {
|
||||
c.logger.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowSize/(1<<10))
|
||||
}
|
||||
c.mutex.Unlock()
|
||||
return offset
|
||||
}
|
||||
|
||||
// EnsureMinimumWindowSize sets a minimum window size
|
||||
// it should make sure that the connection-level window is increased when a stream-level window grows
|
||||
func (c *connectionFlowController) EnsureMinimumWindowSize(inc protocol.ByteCount) {
|
||||
c.mutex.Lock()
|
||||
if inc > c.receiveWindowSize {
|
||||
c.logger.Debugf("Increasing receive flow control window for the connection to %d kB, in response to stream flow control window increase", c.receiveWindowSize/(1<<10))
|
||||
newSize := utils.Min(inc, c.maxReceiveWindowSize)
|
||||
if delta := newSize - c.receiveWindowSize; delta > 0 && c.allowWindowIncrease(delta) {
|
||||
c.receiveWindowSize = newSize
|
||||
}
|
||||
c.startNewAutoTuningEpoch(time.Now())
|
||||
}
|
||||
c.mutex.Unlock()
|
||||
}
|
||||
|
||||
// Reset rests the flow controller. This happens when 0-RTT is rejected.
|
||||
// All stream data is invalidated, it's if we had never opened a stream and never sent any data.
|
||||
// At that point, we only have sent stream data, but we didn't have the keys to open 1-RTT keys yet.
|
||||
func (c *connectionFlowController) Reset() error {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
if c.bytesRead > 0 || c.highestReceived > 0 || !c.epochStartTime.IsZero() {
|
||||
return errors.New("flow controller reset after reading data")
|
||||
}
|
||||
c.bytesSent = 0
|
||||
c.lastBlockedAt = 0
|
||||
return nil
|
||||
}
|
||||
42
vendor/github.com/quic-go/quic-go/internal/flowcontrol/interface.go
generated
vendored
Normal file
42
vendor/github.com/quic-go/quic-go/internal/flowcontrol/interface.go
generated
vendored
Normal file
@@ -0,0 +1,42 @@
|
||||
package flowcontrol
|
||||
|
||||
import "github.com/quic-go/quic-go/internal/protocol"
|
||||
|
||||
type flowController interface {
|
||||
// for sending
|
||||
SendWindowSize() protocol.ByteCount
|
||||
UpdateSendWindow(protocol.ByteCount)
|
||||
AddBytesSent(protocol.ByteCount)
|
||||
// for receiving
|
||||
AddBytesRead(protocol.ByteCount)
|
||||
GetWindowUpdate() protocol.ByteCount // returns 0 if no update is necessary
|
||||
IsNewlyBlocked() (bool, protocol.ByteCount)
|
||||
}
|
||||
|
||||
// A StreamFlowController is a flow controller for a QUIC stream.
|
||||
type StreamFlowController interface {
|
||||
flowController
|
||||
// for receiving
|
||||
// UpdateHighestReceived should be called when a new highest offset is received
|
||||
// final has to be to true if this is the final offset of the stream,
|
||||
// as contained in a STREAM frame with FIN bit, and the RESET_STREAM frame
|
||||
UpdateHighestReceived(offset protocol.ByteCount, final bool) error
|
||||
// Abandon should be called when reading from the stream is aborted early,
|
||||
// and there won't be any further calls to AddBytesRead.
|
||||
Abandon()
|
||||
}
|
||||
|
||||
// The ConnectionFlowController is the flow controller for the connection.
|
||||
type ConnectionFlowController interface {
|
||||
flowController
|
||||
Reset() error
|
||||
}
|
||||
|
||||
type connectionFlowControllerI interface {
|
||||
ConnectionFlowController
|
||||
// The following two methods are not supposed to be called from outside this packet, but are needed internally
|
||||
// for sending
|
||||
EnsureMinimumWindowSize(protocol.ByteCount)
|
||||
// for receiving
|
||||
IncrementHighestReceived(protocol.ByteCount) error
|
||||
}
|
||||
149
vendor/github.com/quic-go/quic-go/internal/flowcontrol/stream_flow_controller.go
generated
vendored
Normal file
149
vendor/github.com/quic-go/quic-go/internal/flowcontrol/stream_flow_controller.go
generated
vendored
Normal file
@@ -0,0 +1,149 @@
|
||||
package flowcontrol
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/qerr"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
type streamFlowController struct {
|
||||
baseFlowController
|
||||
|
||||
streamID protocol.StreamID
|
||||
|
||||
queueWindowUpdate func()
|
||||
|
||||
connection connectionFlowControllerI
|
||||
|
||||
receivedFinalOffset bool
|
||||
}
|
||||
|
||||
var _ StreamFlowController = &streamFlowController{}
|
||||
|
||||
// NewStreamFlowController gets a new flow controller for a stream
|
||||
func NewStreamFlowController(
|
||||
streamID protocol.StreamID,
|
||||
cfc ConnectionFlowController,
|
||||
receiveWindow protocol.ByteCount,
|
||||
maxReceiveWindow protocol.ByteCount,
|
||||
initialSendWindow protocol.ByteCount,
|
||||
queueWindowUpdate func(protocol.StreamID),
|
||||
rttStats *utils.RTTStats,
|
||||
logger utils.Logger,
|
||||
) StreamFlowController {
|
||||
return &streamFlowController{
|
||||
streamID: streamID,
|
||||
connection: cfc.(connectionFlowControllerI),
|
||||
queueWindowUpdate: func() { queueWindowUpdate(streamID) },
|
||||
baseFlowController: baseFlowController{
|
||||
rttStats: rttStats,
|
||||
receiveWindow: receiveWindow,
|
||||
receiveWindowSize: receiveWindow,
|
||||
maxReceiveWindowSize: maxReceiveWindow,
|
||||
sendWindow: initialSendWindow,
|
||||
logger: logger,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateHighestReceived updates the highestReceived value, if the offset is higher.
|
||||
func (c *streamFlowController) UpdateHighestReceived(offset protocol.ByteCount, final bool) error {
|
||||
// If the final offset for this stream is already known, check for consistency.
|
||||
if c.receivedFinalOffset {
|
||||
// If we receive another final offset, check that it's the same.
|
||||
if final && offset != c.highestReceived {
|
||||
return &qerr.TransportError{
|
||||
ErrorCode: qerr.FinalSizeError,
|
||||
ErrorMessage: fmt.Sprintf("received inconsistent final offset for stream %d (old: %d, new: %d bytes)", c.streamID, c.highestReceived, offset),
|
||||
}
|
||||
}
|
||||
// Check that the offset is below the final offset.
|
||||
if offset > c.highestReceived {
|
||||
return &qerr.TransportError{
|
||||
ErrorCode: qerr.FinalSizeError,
|
||||
ErrorMessage: fmt.Sprintf("received offset %d for stream %d, but final offset was already received at %d", offset, c.streamID, c.highestReceived),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if final {
|
||||
c.receivedFinalOffset = true
|
||||
}
|
||||
if offset == c.highestReceived {
|
||||
return nil
|
||||
}
|
||||
// A higher offset was received before.
|
||||
// This can happen due to reordering.
|
||||
if offset <= c.highestReceived {
|
||||
if final {
|
||||
return &qerr.TransportError{
|
||||
ErrorCode: qerr.FinalSizeError,
|
||||
ErrorMessage: fmt.Sprintf("received final offset %d for stream %d, but already received offset %d before", offset, c.streamID, c.highestReceived),
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
increment := offset - c.highestReceived
|
||||
c.highestReceived = offset
|
||||
if c.checkFlowControlViolation() {
|
||||
return &qerr.TransportError{
|
||||
ErrorCode: qerr.FlowControlError,
|
||||
ErrorMessage: fmt.Sprintf("received %d bytes on stream %d, allowed %d bytes", offset, c.streamID, c.receiveWindow),
|
||||
}
|
||||
}
|
||||
return c.connection.IncrementHighestReceived(increment)
|
||||
}
|
||||
|
||||
func (c *streamFlowController) AddBytesRead(n protocol.ByteCount) {
|
||||
c.mutex.Lock()
|
||||
c.baseFlowController.addBytesRead(n)
|
||||
shouldQueueWindowUpdate := c.shouldQueueWindowUpdate()
|
||||
c.mutex.Unlock()
|
||||
if shouldQueueWindowUpdate {
|
||||
c.queueWindowUpdate()
|
||||
}
|
||||
c.connection.AddBytesRead(n)
|
||||
}
|
||||
|
||||
func (c *streamFlowController) Abandon() {
|
||||
c.mutex.Lock()
|
||||
unread := c.highestReceived - c.bytesRead
|
||||
c.mutex.Unlock()
|
||||
if unread > 0 {
|
||||
c.connection.AddBytesRead(unread)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *streamFlowController) AddBytesSent(n protocol.ByteCount) {
|
||||
c.baseFlowController.AddBytesSent(n)
|
||||
c.connection.AddBytesSent(n)
|
||||
}
|
||||
|
||||
func (c *streamFlowController) SendWindowSize() protocol.ByteCount {
|
||||
return utils.Min(c.baseFlowController.sendWindowSize(), c.connection.SendWindowSize())
|
||||
}
|
||||
|
||||
func (c *streamFlowController) shouldQueueWindowUpdate() bool {
|
||||
return !c.receivedFinalOffset && c.hasWindowUpdate()
|
||||
}
|
||||
|
||||
func (c *streamFlowController) GetWindowUpdate() protocol.ByteCount {
|
||||
// If we already received the final offset for this stream, the peer won't need any additional flow control credit.
|
||||
if c.receivedFinalOffset {
|
||||
return 0
|
||||
}
|
||||
|
||||
// Don't use defer for unlocking the mutex here, GetWindowUpdate() is called frequently and defer shows up in the profiler
|
||||
c.mutex.Lock()
|
||||
oldWindowSize := c.receiveWindowSize
|
||||
offset := c.baseFlowController.getWindowUpdate()
|
||||
if c.receiveWindowSize > oldWindowSize { // auto-tuning enlarged the window size
|
||||
c.logger.Debugf("Increasing receive flow control window for stream %d to %d kB", c.streamID, c.receiveWindowSize/(1<<10))
|
||||
c.connection.EnsureMinimumWindowSize(protocol.ByteCount(float64(c.receiveWindowSize) * protocol.ConnectionFlowControlMultiplier))
|
||||
}
|
||||
c.mutex.Unlock()
|
||||
return offset
|
||||
}
|
||||
94
vendor/github.com/quic-go/quic-go/internal/handshake/aead.go
generated
vendored
Normal file
94
vendor/github.com/quic-go/quic-go/internal/handshake/aead.go
generated
vendored
Normal file
@@ -0,0 +1,94 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"crypto/cipher"
|
||||
"encoding/binary"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
)
|
||||
|
||||
func createAEAD(suite *cipherSuite, trafficSecret []byte, v protocol.VersionNumber) cipher.AEAD {
|
||||
keyLabel := hkdfLabelKeyV1
|
||||
ivLabel := hkdfLabelIVV1
|
||||
if v == protocol.Version2 {
|
||||
keyLabel = hkdfLabelKeyV2
|
||||
ivLabel = hkdfLabelIVV2
|
||||
}
|
||||
key := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, keyLabel, suite.KeyLen)
|
||||
iv := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, ivLabel, suite.IVLen())
|
||||
return suite.AEAD(key, iv)
|
||||
}
|
||||
|
||||
type longHeaderSealer struct {
|
||||
aead cipher.AEAD
|
||||
headerProtector headerProtector
|
||||
|
||||
// use a single slice to avoid allocations
|
||||
nonceBuf []byte
|
||||
}
|
||||
|
||||
var _ LongHeaderSealer = &longHeaderSealer{}
|
||||
|
||||
func newLongHeaderSealer(aead cipher.AEAD, headerProtector headerProtector) LongHeaderSealer {
|
||||
return &longHeaderSealer{
|
||||
aead: aead,
|
||||
headerProtector: headerProtector,
|
||||
nonceBuf: make([]byte, aead.NonceSize()),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *longHeaderSealer) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []byte {
|
||||
binary.BigEndian.PutUint64(s.nonceBuf[len(s.nonceBuf)-8:], uint64(pn))
|
||||
// The AEAD we're using here will be the qtls.aeadAESGCM13.
|
||||
// It uses the nonce provided here and XOR it with the IV.
|
||||
return s.aead.Seal(dst, s.nonceBuf, src, ad)
|
||||
}
|
||||
|
||||
func (s *longHeaderSealer) EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte) {
|
||||
s.headerProtector.EncryptHeader(sample, firstByte, pnBytes)
|
||||
}
|
||||
|
||||
func (s *longHeaderSealer) Overhead() int {
|
||||
return s.aead.Overhead()
|
||||
}
|
||||
|
||||
type longHeaderOpener struct {
|
||||
aead cipher.AEAD
|
||||
headerProtector headerProtector
|
||||
highestRcvdPN protocol.PacketNumber // highest packet number received (which could be successfully unprotected)
|
||||
|
||||
// use a single slice to avoid allocations
|
||||
nonceBuf []byte
|
||||
}
|
||||
|
||||
var _ LongHeaderOpener = &longHeaderOpener{}
|
||||
|
||||
func newLongHeaderOpener(aead cipher.AEAD, headerProtector headerProtector) LongHeaderOpener {
|
||||
return &longHeaderOpener{
|
||||
aead: aead,
|
||||
headerProtector: headerProtector,
|
||||
nonceBuf: make([]byte, aead.NonceSize()),
|
||||
}
|
||||
}
|
||||
|
||||
func (o *longHeaderOpener) DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber {
|
||||
return protocol.DecodePacketNumber(wirePNLen, o.highestRcvdPN, wirePN)
|
||||
}
|
||||
|
||||
func (o *longHeaderOpener) Open(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) {
|
||||
binary.BigEndian.PutUint64(o.nonceBuf[len(o.nonceBuf)-8:], uint64(pn))
|
||||
// The AEAD we're using here will be the qtls.aeadAESGCM13.
|
||||
// It uses the nonce provided here and XOR it with the IV.
|
||||
dec, err := o.aead.Open(dst, o.nonceBuf, src, ad)
|
||||
if err == nil {
|
||||
o.highestRcvdPN = utils.Max(o.highestRcvdPN, pn)
|
||||
} else {
|
||||
err = ErrDecryptionFailed
|
||||
}
|
||||
return dec, err
|
||||
}
|
||||
|
||||
func (o *longHeaderOpener) DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) {
|
||||
o.headerProtector.DecryptHeader(sample, firstByte, pnBytes)
|
||||
}
|
||||
104
vendor/github.com/quic-go/quic-go/internal/handshake/cipher_suite.go
generated
vendored
Normal file
104
vendor/github.com/quic-go/quic-go/internal/handshake/cipher_suite.go
generated
vendored
Normal file
@@ -0,0 +1,104 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
)
|
||||
|
||||
// These cipher suite implementations are copied from the standard library crypto/tls package.
|
||||
|
||||
const aeadNonceLength = 12
|
||||
|
||||
type cipherSuite struct {
|
||||
ID uint16
|
||||
Hash crypto.Hash
|
||||
KeyLen int
|
||||
AEAD func(key, nonceMask []byte) cipher.AEAD
|
||||
}
|
||||
|
||||
func (s cipherSuite) IVLen() int { return aeadNonceLength }
|
||||
|
||||
func getCipherSuite(id uint16) *cipherSuite {
|
||||
switch id {
|
||||
case tls.TLS_AES_128_GCM_SHA256:
|
||||
return &cipherSuite{ID: tls.TLS_AES_128_GCM_SHA256, Hash: crypto.SHA256, KeyLen: 16, AEAD: aeadAESGCMTLS13}
|
||||
case tls.TLS_CHACHA20_POLY1305_SHA256:
|
||||
return &cipherSuite{ID: tls.TLS_CHACHA20_POLY1305_SHA256, Hash: crypto.SHA256, KeyLen: 32, AEAD: aeadChaCha20Poly1305}
|
||||
case tls.TLS_AES_256_GCM_SHA384:
|
||||
return &cipherSuite{ID: tls.TLS_AES_256_GCM_SHA384, Hash: crypto.SHA384, KeyLen: 32, AEAD: aeadAESGCMTLS13}
|
||||
default:
|
||||
panic(fmt.Sprintf("unknown cypher suite: %d", id))
|
||||
}
|
||||
}
|
||||
|
||||
func aeadAESGCMTLS13(key, nonceMask []byte) cipher.AEAD {
|
||||
if len(nonceMask) != aeadNonceLength {
|
||||
panic("tls: internal error: wrong nonce length")
|
||||
}
|
||||
aes, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
aead, err := cipher.NewGCM(aes)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
ret := &xorNonceAEAD{aead: aead}
|
||||
copy(ret.nonceMask[:], nonceMask)
|
||||
return ret
|
||||
}
|
||||
|
||||
func aeadChaCha20Poly1305(key, nonceMask []byte) cipher.AEAD {
|
||||
if len(nonceMask) != aeadNonceLength {
|
||||
panic("tls: internal error: wrong nonce length")
|
||||
}
|
||||
aead, err := chacha20poly1305.New(key)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
ret := &xorNonceAEAD{aead: aead}
|
||||
copy(ret.nonceMask[:], nonceMask)
|
||||
return ret
|
||||
}
|
||||
|
||||
// xorNonceAEAD wraps an AEAD by XORing in a fixed pattern to the nonce
|
||||
// before each call.
|
||||
type xorNonceAEAD struct {
|
||||
nonceMask [aeadNonceLength]byte
|
||||
aead cipher.AEAD
|
||||
}
|
||||
|
||||
func (f *xorNonceAEAD) NonceSize() int { return 8 } // 64-bit sequence number
|
||||
func (f *xorNonceAEAD) Overhead() int { return f.aead.Overhead() }
|
||||
func (f *xorNonceAEAD) explicitNonceLen() int { return 0 }
|
||||
|
||||
func (f *xorNonceAEAD) Seal(out, nonce, plaintext, additionalData []byte) []byte {
|
||||
for i, b := range nonce {
|
||||
f.nonceMask[4+i] ^= b
|
||||
}
|
||||
result := f.aead.Seal(out, f.nonceMask[:], plaintext, additionalData)
|
||||
for i, b := range nonce {
|
||||
f.nonceMask[4+i] ^= b
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func (f *xorNonceAEAD) Open(out, nonce, ciphertext, additionalData []byte) ([]byte, error) {
|
||||
for i, b := range nonce {
|
||||
f.nonceMask[4+i] ^= b
|
||||
}
|
||||
result, err := f.aead.Open(out, f.nonceMask[:], ciphertext, additionalData)
|
||||
for i, b := range nonce {
|
||||
f.nonceMask[4+i] ^= b
|
||||
}
|
||||
|
||||
return result, err
|
||||
}
|
||||
21
vendor/github.com/quic-go/quic-go/internal/handshake/conn.go
generated
vendored
Normal file
21
vendor/github.com/quic-go/quic-go/internal/handshake/conn.go
generated
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
type conn struct {
|
||||
localAddr, remoteAddr net.Addr
|
||||
}
|
||||
|
||||
var _ net.Conn = &conn{}
|
||||
|
||||
func (c *conn) Read([]byte) (int, error) { return 0, nil }
|
||||
func (c *conn) Write([]byte) (int, error) { return 0, nil }
|
||||
func (c *conn) Close() error { return nil }
|
||||
func (c *conn) RemoteAddr() net.Addr { return c.remoteAddr }
|
||||
func (c *conn) LocalAddr() net.Addr { return c.localAddr }
|
||||
func (c *conn) SetReadDeadline(time.Time) error { return nil }
|
||||
func (c *conn) SetWriteDeadline(time.Time) error { return nil }
|
||||
func (c *conn) SetDeadline(time.Time) error { return nil }
|
||||
681
vendor/github.com/quic-go/quic-go/internal/handshake/crypto_setup.go
generated
vendored
Normal file
681
vendor/github.com/quic-go/quic-go/internal/handshake/crypto_setup.go
generated
vendored
Normal file
@@ -0,0 +1,681 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/qerr"
|
||||
"github.com/quic-go/quic-go/internal/qtls"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
"github.com/quic-go/quic-go/logging"
|
||||
"github.com/quic-go/quic-go/quicvarint"
|
||||
)
|
||||
|
||||
type quicVersionContextKey struct{}
|
||||
|
||||
var QUICVersionContextKey = &quicVersionContextKey{}
|
||||
|
||||
const clientSessionStateRevision = 3
|
||||
|
||||
type cryptoSetup struct {
|
||||
tlsConf *tls.Config
|
||||
conn *qtls.QUICConn
|
||||
|
||||
events []Event
|
||||
|
||||
version protocol.VersionNumber
|
||||
|
||||
ourParams *wire.TransportParameters
|
||||
peerParams *wire.TransportParameters
|
||||
|
||||
zeroRTTParameters *wire.TransportParameters
|
||||
allow0RTT bool
|
||||
|
||||
rttStats *utils.RTTStats
|
||||
|
||||
tracer *logging.ConnectionTracer
|
||||
logger utils.Logger
|
||||
|
||||
perspective protocol.Perspective
|
||||
|
||||
mutex sync.Mutex // protects all members below
|
||||
|
||||
handshakeCompleteTime time.Time
|
||||
|
||||
zeroRTTOpener LongHeaderOpener // only set for the server
|
||||
zeroRTTSealer LongHeaderSealer // only set for the client
|
||||
|
||||
initialOpener LongHeaderOpener
|
||||
initialSealer LongHeaderSealer
|
||||
|
||||
handshakeOpener LongHeaderOpener
|
||||
handshakeSealer LongHeaderSealer
|
||||
|
||||
used0RTT atomic.Bool
|
||||
|
||||
aead *updatableAEAD
|
||||
has1RTTSealer bool
|
||||
has1RTTOpener bool
|
||||
}
|
||||
|
||||
var _ CryptoSetup = &cryptoSetup{}
|
||||
|
||||
// NewCryptoSetupClient creates a new crypto setup for the client
|
||||
func NewCryptoSetupClient(
|
||||
connID protocol.ConnectionID,
|
||||
tp *wire.TransportParameters,
|
||||
tlsConf *tls.Config,
|
||||
enable0RTT bool,
|
||||
rttStats *utils.RTTStats,
|
||||
tracer *logging.ConnectionTracer,
|
||||
logger utils.Logger,
|
||||
version protocol.VersionNumber,
|
||||
) CryptoSetup {
|
||||
cs := newCryptoSetup(
|
||||
connID,
|
||||
tp,
|
||||
rttStats,
|
||||
tracer,
|
||||
logger,
|
||||
protocol.PerspectiveClient,
|
||||
version,
|
||||
)
|
||||
|
||||
tlsConf = tlsConf.Clone()
|
||||
tlsConf.MinVersion = tls.VersionTLS13
|
||||
quicConf := &qtls.QUICConfig{TLSConfig: tlsConf}
|
||||
qtls.SetupConfigForClient(quicConf, cs.marshalDataForSessionState, cs.handleDataFromSessionState)
|
||||
cs.tlsConf = tlsConf
|
||||
|
||||
cs.conn = qtls.QUICClient(quicConf)
|
||||
cs.conn.SetTransportParameters(cs.ourParams.Marshal(protocol.PerspectiveClient))
|
||||
|
||||
return cs
|
||||
}
|
||||
|
||||
// NewCryptoSetupServer creates a new crypto setup for the server
|
||||
func NewCryptoSetupServer(
|
||||
connID protocol.ConnectionID,
|
||||
localAddr, remoteAddr net.Addr,
|
||||
tp *wire.TransportParameters,
|
||||
tlsConf *tls.Config,
|
||||
allow0RTT bool,
|
||||
rttStats *utils.RTTStats,
|
||||
tracer *logging.ConnectionTracer,
|
||||
logger utils.Logger,
|
||||
version protocol.VersionNumber,
|
||||
) CryptoSetup {
|
||||
cs := newCryptoSetup(
|
||||
connID,
|
||||
tp,
|
||||
rttStats,
|
||||
tracer,
|
||||
logger,
|
||||
protocol.PerspectiveServer,
|
||||
version,
|
||||
)
|
||||
cs.allow0RTT = allow0RTT
|
||||
|
||||
quicConf := &qtls.QUICConfig{TLSConfig: tlsConf}
|
||||
qtls.SetupConfigForServer(quicConf, cs.allow0RTT, cs.getDataForSessionTicket, cs.handleSessionTicket)
|
||||
addConnToClientHelloInfo(quicConf.TLSConfig, localAddr, remoteAddr)
|
||||
|
||||
cs.tlsConf = quicConf.TLSConfig
|
||||
cs.conn = qtls.QUICServer(quicConf)
|
||||
|
||||
return cs
|
||||
}
|
||||
|
||||
// The tls.Config contains two callbacks that pass in a tls.ClientHelloInfo.
|
||||
// Since crypto/tls doesn't do it, we need to make sure to set the Conn field with a fake net.Conn
|
||||
// that allows the caller to get the local and the remote address.
|
||||
func addConnToClientHelloInfo(conf *tls.Config, localAddr, remoteAddr net.Addr) {
|
||||
if conf.GetConfigForClient != nil {
|
||||
gcfc := conf.GetConfigForClient
|
||||
conf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||
info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr}
|
||||
c, err := gcfc(info)
|
||||
if c != nil {
|
||||
c = c.Clone()
|
||||
// This won't be necessary anymore once https://github.com/golang/go/issues/63722 is accepted.
|
||||
c.MinVersion = tls.VersionTLS13
|
||||
// We're returning a tls.Config here, so we need to apply this recursively.
|
||||
addConnToClientHelloInfo(c, localAddr, remoteAddr)
|
||||
}
|
||||
return c, err
|
||||
}
|
||||
}
|
||||
if conf.GetCertificate != nil {
|
||||
gc := conf.GetCertificate
|
||||
conf.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr}
|
||||
return gc(info)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func newCryptoSetup(
|
||||
connID protocol.ConnectionID,
|
||||
tp *wire.TransportParameters,
|
||||
rttStats *utils.RTTStats,
|
||||
tracer *logging.ConnectionTracer,
|
||||
logger utils.Logger,
|
||||
perspective protocol.Perspective,
|
||||
version protocol.VersionNumber,
|
||||
) *cryptoSetup {
|
||||
initialSealer, initialOpener := NewInitialAEAD(connID, perspective, version)
|
||||
if tracer != nil && tracer.UpdatedKeyFromTLS != nil {
|
||||
tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveClient)
|
||||
tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveServer)
|
||||
}
|
||||
return &cryptoSetup{
|
||||
initialSealer: initialSealer,
|
||||
initialOpener: initialOpener,
|
||||
aead: newUpdatableAEAD(rttStats, tracer, logger, version),
|
||||
events: make([]Event, 0, 16),
|
||||
ourParams: tp,
|
||||
rttStats: rttStats,
|
||||
tracer: tracer,
|
||||
logger: logger,
|
||||
perspective: perspective,
|
||||
version: version,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) ChangeConnectionID(id protocol.ConnectionID) {
|
||||
initialSealer, initialOpener := NewInitialAEAD(id, h.perspective, h.version)
|
||||
h.initialSealer = initialSealer
|
||||
h.initialOpener = initialOpener
|
||||
if h.tracer != nil && h.tracer.UpdatedKeyFromTLS != nil {
|
||||
h.tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveClient)
|
||||
h.tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveServer)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) SetLargest1RTTAcked(pn protocol.PacketNumber) error {
|
||||
return h.aead.SetLargestAcked(pn)
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) StartHandshake() error {
|
||||
err := h.conn.Start(context.WithValue(context.Background(), QUICVersionContextKey, h.version))
|
||||
if err != nil {
|
||||
return wrapError(err)
|
||||
}
|
||||
for {
|
||||
ev := h.conn.NextEvent()
|
||||
done, err := h.handleEvent(ev)
|
||||
if err != nil {
|
||||
return wrapError(err)
|
||||
}
|
||||
if done {
|
||||
break
|
||||
}
|
||||
}
|
||||
if h.perspective == protocol.PerspectiveClient {
|
||||
if h.zeroRTTSealer != nil && h.zeroRTTParameters != nil {
|
||||
h.logger.Debugf("Doing 0-RTT.")
|
||||
h.events = append(h.events, Event{Kind: EventRestoredTransportParameters, TransportParameters: h.zeroRTTParameters})
|
||||
} else {
|
||||
h.logger.Debugf("Not doing 0-RTT. Has sealer: %t, has params: %t", h.zeroRTTSealer != nil, h.zeroRTTParameters != nil)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the crypto setup.
|
||||
// It aborts the handshake, if it is still running.
|
||||
func (h *cryptoSetup) Close() error {
|
||||
return h.conn.Close()
|
||||
}
|
||||
|
||||
// HandleMessage handles a TLS handshake message.
|
||||
// It is called by the crypto streams when a new message is available.
|
||||
func (h *cryptoSetup) HandleMessage(data []byte, encLevel protocol.EncryptionLevel) error {
|
||||
if err := h.handleMessage(data, encLevel); err != nil {
|
||||
return wrapError(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) handleMessage(data []byte, encLevel protocol.EncryptionLevel) error {
|
||||
if err := h.conn.HandleData(qtls.ToTLSEncryptionLevel(encLevel), data); err != nil {
|
||||
return err
|
||||
}
|
||||
for {
|
||||
ev := h.conn.NextEvent()
|
||||
done, err := h.handleEvent(ev)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if done {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) handleEvent(ev qtls.QUICEvent) (done bool, err error) {
|
||||
switch ev.Kind {
|
||||
case qtls.QUICNoEvent:
|
||||
return true, nil
|
||||
case qtls.QUICSetReadSecret:
|
||||
h.SetReadKey(ev.Level, ev.Suite, ev.Data)
|
||||
return false, nil
|
||||
case qtls.QUICSetWriteSecret:
|
||||
h.SetWriteKey(ev.Level, ev.Suite, ev.Data)
|
||||
return false, nil
|
||||
case qtls.QUICTransportParameters:
|
||||
return false, h.handleTransportParameters(ev.Data)
|
||||
case qtls.QUICTransportParametersRequired:
|
||||
h.conn.SetTransportParameters(h.ourParams.Marshal(h.perspective))
|
||||
return false, nil
|
||||
case qtls.QUICRejectedEarlyData:
|
||||
h.rejected0RTT()
|
||||
return false, nil
|
||||
case qtls.QUICWriteData:
|
||||
h.WriteRecord(ev.Level, ev.Data)
|
||||
return false, nil
|
||||
case qtls.QUICHandshakeDone:
|
||||
h.handshakeComplete()
|
||||
return false, nil
|
||||
default:
|
||||
return false, fmt.Errorf("unexpected event: %d", ev.Kind)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) NextEvent() Event {
|
||||
if len(h.events) == 0 {
|
||||
return Event{Kind: EventNoEvent}
|
||||
}
|
||||
ev := h.events[0]
|
||||
h.events = h.events[1:]
|
||||
return ev
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) handleTransportParameters(data []byte) error {
|
||||
var tp wire.TransportParameters
|
||||
if err := tp.Unmarshal(data, h.perspective.Opposite()); err != nil {
|
||||
return err
|
||||
}
|
||||
h.peerParams = &tp
|
||||
h.events = append(h.events, Event{Kind: EventReceivedTransportParameters, TransportParameters: h.peerParams})
|
||||
return nil
|
||||
}
|
||||
|
||||
// must be called after receiving the transport parameters
|
||||
func (h *cryptoSetup) marshalDataForSessionState() []byte {
|
||||
b := make([]byte, 0, 256)
|
||||
b = quicvarint.Append(b, clientSessionStateRevision)
|
||||
b = quicvarint.Append(b, uint64(h.rttStats.SmoothedRTT().Microseconds()))
|
||||
return h.peerParams.MarshalForSessionTicket(b)
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) handleDataFromSessionState(data []byte) {
|
||||
tp, err := h.handleDataFromSessionStateImpl(data)
|
||||
if err != nil {
|
||||
h.logger.Debugf("Restoring of transport parameters from session ticket failed: %s", err.Error())
|
||||
return
|
||||
}
|
||||
h.zeroRTTParameters = tp
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) handleDataFromSessionStateImpl(data []byte) (*wire.TransportParameters, error) {
|
||||
r := bytes.NewReader(data)
|
||||
ver, err := quicvarint.Read(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if ver != clientSessionStateRevision {
|
||||
return nil, fmt.Errorf("mismatching version. Got %d, expected %d", ver, clientSessionStateRevision)
|
||||
}
|
||||
rtt, err := quicvarint.Read(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
h.rttStats.SetInitialRTT(time.Duration(rtt) * time.Microsecond)
|
||||
var tp wire.TransportParameters
|
||||
if err := tp.UnmarshalFromSessionTicket(r); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &tp, nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) getDataForSessionTicket() []byte {
|
||||
ticket := &sessionTicket{
|
||||
RTT: h.rttStats.SmoothedRTT(),
|
||||
}
|
||||
if h.allow0RTT {
|
||||
ticket.Parameters = h.ourParams
|
||||
}
|
||||
return ticket.Marshal()
|
||||
}
|
||||
|
||||
// GetSessionTicket generates a new session ticket.
|
||||
// Due to limitations in crypto/tls, it's only possible to generate a single session ticket per connection.
|
||||
// It is only valid for the server.
|
||||
func (h *cryptoSetup) GetSessionTicket() ([]byte, error) {
|
||||
if err := qtls.SendSessionTicket(h.conn, h.allow0RTT); err != nil {
|
||||
// Session tickets might be disabled by tls.Config.SessionTicketsDisabled.
|
||||
// We can't check h.tlsConfig here, since the actual config might have been obtained from
|
||||
// the GetConfigForClient callback.
|
||||
// See https://github.com/golang/go/issues/62032.
|
||||
// Once that issue is resolved, this error assertion can be removed.
|
||||
if strings.Contains(err.Error(), "session ticket keys unavailable") {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
ev := h.conn.NextEvent()
|
||||
if ev.Kind != qtls.QUICWriteData || ev.Level != qtls.QUICEncryptionLevelApplication {
|
||||
panic("crypto/tls bug: where's my session ticket?")
|
||||
}
|
||||
ticket := ev.Data
|
||||
if ev := h.conn.NextEvent(); ev.Kind != qtls.QUICNoEvent {
|
||||
panic("crypto/tls bug: why more than one ticket?")
|
||||
}
|
||||
return ticket, nil
|
||||
}
|
||||
|
||||
// handleSessionTicket is called for the server when receiving the client's session ticket.
|
||||
// It reads parameters from the session ticket and decides whether to accept 0-RTT when the session ticket is used for 0-RTT.
|
||||
func (h *cryptoSetup) handleSessionTicket(sessionTicketData []byte, using0RTT bool) bool {
|
||||
var t sessionTicket
|
||||
if err := t.Unmarshal(sessionTicketData, using0RTT); err != nil {
|
||||
h.logger.Debugf("Unmarshalling session ticket failed: %s", err.Error())
|
||||
return false
|
||||
}
|
||||
h.rttStats.SetInitialRTT(t.RTT)
|
||||
if !using0RTT {
|
||||
return false
|
||||
}
|
||||
valid := h.ourParams.ValidFor0RTT(t.Parameters)
|
||||
if !valid {
|
||||
h.logger.Debugf("Transport parameters changed. Rejecting 0-RTT.")
|
||||
return false
|
||||
}
|
||||
if !h.allow0RTT {
|
||||
h.logger.Debugf("0-RTT not allowed. Rejecting 0-RTT.")
|
||||
return false
|
||||
}
|
||||
h.logger.Debugf("Accepting 0-RTT. Restoring RTT from session ticket: %s", t.RTT)
|
||||
return true
|
||||
}
|
||||
|
||||
// rejected0RTT is called for the client when the server rejects 0-RTT.
|
||||
func (h *cryptoSetup) rejected0RTT() {
|
||||
h.logger.Debugf("0-RTT was rejected. Dropping 0-RTT keys.")
|
||||
|
||||
h.mutex.Lock()
|
||||
had0RTTKeys := h.zeroRTTSealer != nil
|
||||
h.zeroRTTSealer = nil
|
||||
h.mutex.Unlock()
|
||||
|
||||
if had0RTTKeys {
|
||||
h.events = append(h.events, Event{Kind: EventDiscard0RTTKeys})
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) SetReadKey(el qtls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) {
|
||||
suite := getCipherSuite(suiteID)
|
||||
h.mutex.Lock()
|
||||
//nolint:exhaustive // The TLS stack doesn't export Initial keys.
|
||||
switch el {
|
||||
case qtls.QUICEncryptionLevelEarly:
|
||||
if h.perspective == protocol.PerspectiveClient {
|
||||
panic("Received 0-RTT read key for the client")
|
||||
}
|
||||
h.zeroRTTOpener = newLongHeaderOpener(
|
||||
createAEAD(suite, trafficSecret, h.version),
|
||||
newHeaderProtector(suite, trafficSecret, true, h.version),
|
||||
)
|
||||
h.used0RTT.Store(true)
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("Installed 0-RTT Read keys (using %s)", tls.CipherSuiteName(suite.ID))
|
||||
}
|
||||
case qtls.QUICEncryptionLevelHandshake:
|
||||
h.handshakeOpener = newLongHeaderOpener(
|
||||
createAEAD(suite, trafficSecret, h.version),
|
||||
newHeaderProtector(suite, trafficSecret, true, h.version),
|
||||
)
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("Installed Handshake Read keys (using %s)", tls.CipherSuiteName(suite.ID))
|
||||
}
|
||||
case qtls.QUICEncryptionLevelApplication:
|
||||
h.aead.SetReadKey(suite, trafficSecret)
|
||||
h.has1RTTOpener = true
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("Installed 1-RTT Read keys (using %s)", tls.CipherSuiteName(suite.ID))
|
||||
}
|
||||
default:
|
||||
panic("unexpected read encryption level")
|
||||
}
|
||||
h.mutex.Unlock()
|
||||
h.events = append(h.events, Event{Kind: EventReceivedReadKeys})
|
||||
if h.tracer != nil && h.tracer.UpdatedKeyFromTLS != nil {
|
||||
h.tracer.UpdatedKeyFromTLS(qtls.FromTLSEncryptionLevel(el), h.perspective.Opposite())
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) SetWriteKey(el qtls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) {
|
||||
suite := getCipherSuite(suiteID)
|
||||
h.mutex.Lock()
|
||||
//nolint:exhaustive // The TLS stack doesn't export Initial keys.
|
||||
switch el {
|
||||
case qtls.QUICEncryptionLevelEarly:
|
||||
if h.perspective == protocol.PerspectiveServer {
|
||||
panic("Received 0-RTT write key for the server")
|
||||
}
|
||||
h.zeroRTTSealer = newLongHeaderSealer(
|
||||
createAEAD(suite, trafficSecret, h.version),
|
||||
newHeaderProtector(suite, trafficSecret, true, h.version),
|
||||
)
|
||||
h.mutex.Unlock()
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("Installed 0-RTT Write keys (using %s)", tls.CipherSuiteName(suite.ID))
|
||||
}
|
||||
if h.tracer != nil && h.tracer.UpdatedKeyFromTLS != nil {
|
||||
h.tracer.UpdatedKeyFromTLS(protocol.Encryption0RTT, h.perspective)
|
||||
}
|
||||
// don't set used0RTT here. 0-RTT might still get rejected.
|
||||
return
|
||||
case qtls.QUICEncryptionLevelHandshake:
|
||||
h.handshakeSealer = newLongHeaderSealer(
|
||||
createAEAD(suite, trafficSecret, h.version),
|
||||
newHeaderProtector(suite, trafficSecret, true, h.version),
|
||||
)
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("Installed Handshake Write keys (using %s)", tls.CipherSuiteName(suite.ID))
|
||||
}
|
||||
case qtls.QUICEncryptionLevelApplication:
|
||||
h.aead.SetWriteKey(suite, trafficSecret)
|
||||
h.has1RTTSealer = true
|
||||
if h.logger.Debug() {
|
||||
h.logger.Debugf("Installed 1-RTT Write keys (using %s)", tls.CipherSuiteName(suite.ID))
|
||||
}
|
||||
if h.zeroRTTSealer != nil {
|
||||
// Once we receive handshake keys, we know that 0-RTT was not rejected.
|
||||
h.used0RTT.Store(true)
|
||||
h.zeroRTTSealer = nil
|
||||
h.logger.Debugf("Dropping 0-RTT keys.")
|
||||
if h.tracer != nil && h.tracer.DroppedEncryptionLevel != nil {
|
||||
h.tracer.DroppedEncryptionLevel(protocol.Encryption0RTT)
|
||||
}
|
||||
}
|
||||
default:
|
||||
panic("unexpected write encryption level")
|
||||
}
|
||||
h.mutex.Unlock()
|
||||
if h.tracer != nil && h.tracer.UpdatedKeyFromTLS != nil {
|
||||
h.tracer.UpdatedKeyFromTLS(qtls.FromTLSEncryptionLevel(el), h.perspective)
|
||||
}
|
||||
}
|
||||
|
||||
// WriteRecord is called when TLS writes data
|
||||
func (h *cryptoSetup) WriteRecord(encLevel qtls.QUICEncryptionLevel, p []byte) {
|
||||
//nolint:exhaustive // handshake records can only be written for Initial and Handshake.
|
||||
switch encLevel {
|
||||
case qtls.QUICEncryptionLevelInitial:
|
||||
h.events = append(h.events, Event{Kind: EventWriteInitialData, Data: p})
|
||||
case qtls.QUICEncryptionLevelHandshake:
|
||||
h.events = append(h.events, Event{Kind: EventWriteHandshakeData, Data: p})
|
||||
case qtls.QUICEncryptionLevelApplication:
|
||||
panic("unexpected write")
|
||||
default:
|
||||
panic(fmt.Sprintf("unexpected write encryption level: %s", encLevel))
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) DiscardInitialKeys() {
|
||||
h.mutex.Lock()
|
||||
dropped := h.initialOpener != nil
|
||||
h.initialOpener = nil
|
||||
h.initialSealer = nil
|
||||
h.mutex.Unlock()
|
||||
if dropped {
|
||||
h.logger.Debugf("Dropping Initial keys.")
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) handshakeComplete() {
|
||||
h.handshakeCompleteTime = time.Now()
|
||||
h.events = append(h.events, Event{Kind: EventHandshakeComplete})
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) SetHandshakeConfirmed() {
|
||||
h.aead.SetHandshakeConfirmed()
|
||||
// drop Handshake keys
|
||||
var dropped bool
|
||||
h.mutex.Lock()
|
||||
if h.handshakeOpener != nil {
|
||||
h.handshakeOpener = nil
|
||||
h.handshakeSealer = nil
|
||||
dropped = true
|
||||
}
|
||||
h.mutex.Unlock()
|
||||
if dropped {
|
||||
h.logger.Debugf("Dropping Handshake keys.")
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) GetInitialSealer() (LongHeaderSealer, error) {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
if h.initialSealer == nil {
|
||||
return nil, ErrKeysDropped
|
||||
}
|
||||
return h.initialSealer, nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) Get0RTTSealer() (LongHeaderSealer, error) {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
if h.zeroRTTSealer == nil {
|
||||
return nil, ErrKeysDropped
|
||||
}
|
||||
return h.zeroRTTSealer, nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) GetHandshakeSealer() (LongHeaderSealer, error) {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
if h.handshakeSealer == nil {
|
||||
if h.initialSealer == nil {
|
||||
return nil, ErrKeysDropped
|
||||
}
|
||||
return nil, ErrKeysNotYetAvailable
|
||||
}
|
||||
return h.handshakeSealer, nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) Get1RTTSealer() (ShortHeaderSealer, error) {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
if !h.has1RTTSealer {
|
||||
return nil, ErrKeysNotYetAvailable
|
||||
}
|
||||
return h.aead, nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) GetInitialOpener() (LongHeaderOpener, error) {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
if h.initialOpener == nil {
|
||||
return nil, ErrKeysDropped
|
||||
}
|
||||
return h.initialOpener, nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) Get0RTTOpener() (LongHeaderOpener, error) {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
if h.zeroRTTOpener == nil {
|
||||
if h.initialOpener != nil {
|
||||
return nil, ErrKeysNotYetAvailable
|
||||
}
|
||||
// if the initial opener is also not available, the keys were already dropped
|
||||
return nil, ErrKeysDropped
|
||||
}
|
||||
return h.zeroRTTOpener, nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) GetHandshakeOpener() (LongHeaderOpener, error) {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
if h.handshakeOpener == nil {
|
||||
if h.initialOpener != nil {
|
||||
return nil, ErrKeysNotYetAvailable
|
||||
}
|
||||
// if the initial opener is also not available, the keys were already dropped
|
||||
return nil, ErrKeysDropped
|
||||
}
|
||||
return h.handshakeOpener, nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) Get1RTTOpener() (ShortHeaderOpener, error) {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
if h.zeroRTTOpener != nil && time.Since(h.handshakeCompleteTime) > 3*h.rttStats.PTO(true) {
|
||||
h.zeroRTTOpener = nil
|
||||
h.logger.Debugf("Dropping 0-RTT keys.")
|
||||
if h.tracer != nil && h.tracer.DroppedEncryptionLevel != nil {
|
||||
h.tracer.DroppedEncryptionLevel(protocol.Encryption0RTT)
|
||||
}
|
||||
}
|
||||
|
||||
if !h.has1RTTOpener {
|
||||
return nil, ErrKeysNotYetAvailable
|
||||
}
|
||||
return h.aead, nil
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) ConnectionState() ConnectionState {
|
||||
return ConnectionState{
|
||||
ConnectionState: h.conn.ConnectionState(),
|
||||
Used0RTT: h.used0RTT.Load(),
|
||||
}
|
||||
}
|
||||
|
||||
func wrapError(err error) error {
|
||||
// alert 80 is an internal error
|
||||
if alertErr := qtls.AlertError(0); errors.As(err, &alertErr) && alertErr != 80 {
|
||||
return qerr.NewLocalCryptoError(uint8(alertErr), err)
|
||||
}
|
||||
return &qerr.TransportError{ErrorCode: qerr.InternalError, ErrorMessage: err.Error()}
|
||||
}
|
||||
135
vendor/github.com/quic-go/quic-go/internal/handshake/header_protector.go
generated
vendored
Normal file
135
vendor/github.com/quic-go/quic-go/internal/handshake/header_protector.go
generated
vendored
Normal file
@@ -0,0 +1,135 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/tls"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/crypto/chacha20"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
type headerProtector interface {
|
||||
EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte)
|
||||
DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte)
|
||||
}
|
||||
|
||||
func hkdfHeaderProtectionLabel(v protocol.VersionNumber) string {
|
||||
if v == protocol.Version2 {
|
||||
return "quicv2 hp"
|
||||
}
|
||||
return "quic hp"
|
||||
}
|
||||
|
||||
func newHeaderProtector(suite *cipherSuite, trafficSecret []byte, isLongHeader bool, v protocol.VersionNumber) headerProtector {
|
||||
hkdfLabel := hkdfHeaderProtectionLabel(v)
|
||||
switch suite.ID {
|
||||
case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384:
|
||||
return newAESHeaderProtector(suite, trafficSecret, isLongHeader, hkdfLabel)
|
||||
case tls.TLS_CHACHA20_POLY1305_SHA256:
|
||||
return newChaChaHeaderProtector(suite, trafficSecret, isLongHeader, hkdfLabel)
|
||||
default:
|
||||
panic(fmt.Sprintf("Invalid cipher suite id: %d", suite.ID))
|
||||
}
|
||||
}
|
||||
|
||||
type aesHeaderProtector struct {
|
||||
mask []byte
|
||||
block cipher.Block
|
||||
isLongHeader bool
|
||||
}
|
||||
|
||||
var _ headerProtector = &aesHeaderProtector{}
|
||||
|
||||
func newAESHeaderProtector(suite *cipherSuite, trafficSecret []byte, isLongHeader bool, hkdfLabel string) headerProtector {
|
||||
hpKey := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, hkdfLabel, suite.KeyLen)
|
||||
block, err := aes.NewCipher(hpKey)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("error creating new AES cipher: %s", err))
|
||||
}
|
||||
return &aesHeaderProtector{
|
||||
block: block,
|
||||
mask: make([]byte, block.BlockSize()),
|
||||
isLongHeader: isLongHeader,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *aesHeaderProtector) DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
|
||||
p.apply(sample, firstByte, hdrBytes)
|
||||
}
|
||||
|
||||
func (p *aesHeaderProtector) EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
|
||||
p.apply(sample, firstByte, hdrBytes)
|
||||
}
|
||||
|
||||
func (p *aesHeaderProtector) apply(sample []byte, firstByte *byte, hdrBytes []byte) {
|
||||
if len(sample) != len(p.mask) {
|
||||
panic("invalid sample size")
|
||||
}
|
||||
p.block.Encrypt(p.mask, sample)
|
||||
if p.isLongHeader {
|
||||
*firstByte ^= p.mask[0] & 0xf
|
||||
} else {
|
||||
*firstByte ^= p.mask[0] & 0x1f
|
||||
}
|
||||
for i := range hdrBytes {
|
||||
hdrBytes[i] ^= p.mask[i+1]
|
||||
}
|
||||
}
|
||||
|
||||
type chachaHeaderProtector struct {
|
||||
mask [5]byte
|
||||
|
||||
key [32]byte
|
||||
isLongHeader bool
|
||||
}
|
||||
|
||||
var _ headerProtector = &chachaHeaderProtector{}
|
||||
|
||||
func newChaChaHeaderProtector(suite *cipherSuite, trafficSecret []byte, isLongHeader bool, hkdfLabel string) headerProtector {
|
||||
hpKey := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, hkdfLabel, suite.KeyLen)
|
||||
|
||||
p := &chachaHeaderProtector{
|
||||
isLongHeader: isLongHeader,
|
||||
}
|
||||
copy(p.key[:], hpKey)
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *chachaHeaderProtector) DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
|
||||
p.apply(sample, firstByte, hdrBytes)
|
||||
}
|
||||
|
||||
func (p *chachaHeaderProtector) EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
|
||||
p.apply(sample, firstByte, hdrBytes)
|
||||
}
|
||||
|
||||
func (p *chachaHeaderProtector) apply(sample []byte, firstByte *byte, hdrBytes []byte) {
|
||||
if len(sample) != 16 {
|
||||
panic("invalid sample size")
|
||||
}
|
||||
for i := 0; i < 5; i++ {
|
||||
p.mask[i] = 0
|
||||
}
|
||||
cipher, err := chacha20.NewUnauthenticatedCipher(p.key[:], sample[4:])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
cipher.SetCounter(binary.LittleEndian.Uint32(sample[:4]))
|
||||
cipher.XORKeyStream(p.mask[:], p.mask[:])
|
||||
p.applyMask(firstByte, hdrBytes)
|
||||
}
|
||||
|
||||
func (p *chachaHeaderProtector) applyMask(firstByte *byte, hdrBytes []byte) {
|
||||
if p.isLongHeader {
|
||||
*firstByte ^= p.mask[0] & 0xf
|
||||
} else {
|
||||
*firstByte ^= p.mask[0] & 0x1f
|
||||
}
|
||||
for i := range hdrBytes {
|
||||
hdrBytes[i] ^= p.mask[i+1]
|
||||
}
|
||||
}
|
||||
29
vendor/github.com/quic-go/quic-go/internal/handshake/hkdf.go
generated
vendored
Normal file
29
vendor/github.com/quic-go/quic-go/internal/handshake/hkdf.go
generated
vendored
Normal file
@@ -0,0 +1,29 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"encoding/binary"
|
||||
|
||||
"golang.org/x/crypto/hkdf"
|
||||
)
|
||||
|
||||
// hkdfExpandLabel HKDF expands a label.
|
||||
// Since this implementation avoids using a cryptobyte.Builder, it is about 15% faster than the
|
||||
// hkdfExpandLabel in the standard library.
|
||||
func hkdfExpandLabel(hash crypto.Hash, secret, context []byte, label string, length int) []byte {
|
||||
b := make([]byte, 3, 3+6+len(label)+1+len(context))
|
||||
binary.BigEndian.PutUint16(b, uint16(length))
|
||||
b[2] = uint8(6 + len(label))
|
||||
b = append(b, []byte("tls13 ")...)
|
||||
b = append(b, []byte(label)...)
|
||||
b = b[:3+6+len(label)+1]
|
||||
b[3+6+len(label)] = uint8(len(context))
|
||||
b = append(b, context...)
|
||||
|
||||
out := make([]byte, length)
|
||||
n, err := hkdf.Expand(hash.New, secret, b).Read(out)
|
||||
if err != nil || n != length {
|
||||
panic("quic: HKDF-Expand-Label invocation failed unexpectedly")
|
||||
}
|
||||
return out
|
||||
}
|
||||
71
vendor/github.com/quic-go/quic-go/internal/handshake/initial_aead.go
generated
vendored
Normal file
71
vendor/github.com/quic-go/quic-go/internal/handshake/initial_aead.go
generated
vendored
Normal file
@@ -0,0 +1,71 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/tls"
|
||||
|
||||
"golang.org/x/crypto/hkdf"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
var (
|
||||
quicSaltV1 = []byte{0x38, 0x76, 0x2c, 0xf7, 0xf5, 0x59, 0x34, 0xb3, 0x4d, 0x17, 0x9a, 0xe6, 0xa4, 0xc8, 0x0c, 0xad, 0xcc, 0xbb, 0x7f, 0x0a}
|
||||
quicSaltV2 = []byte{0x0d, 0xed, 0xe3, 0xde, 0xf7, 0x00, 0xa6, 0xdb, 0x81, 0x93, 0x81, 0xbe, 0x6e, 0x26, 0x9d, 0xcb, 0xf9, 0xbd, 0x2e, 0xd9}
|
||||
)
|
||||
|
||||
const (
|
||||
hkdfLabelKeyV1 = "quic key"
|
||||
hkdfLabelKeyV2 = "quicv2 key"
|
||||
hkdfLabelIVV1 = "quic iv"
|
||||
hkdfLabelIVV2 = "quicv2 iv"
|
||||
)
|
||||
|
||||
func getSalt(v protocol.VersionNumber) []byte {
|
||||
if v == protocol.Version2 {
|
||||
return quicSaltV2
|
||||
}
|
||||
return quicSaltV1
|
||||
}
|
||||
|
||||
var initialSuite = getCipherSuite(tls.TLS_AES_128_GCM_SHA256)
|
||||
|
||||
// NewInitialAEAD creates a new AEAD for Initial encryption / decryption.
|
||||
func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective, v protocol.VersionNumber) (LongHeaderSealer, LongHeaderOpener) {
|
||||
clientSecret, serverSecret := computeSecrets(connID, v)
|
||||
var mySecret, otherSecret []byte
|
||||
if pers == protocol.PerspectiveClient {
|
||||
mySecret = clientSecret
|
||||
otherSecret = serverSecret
|
||||
} else {
|
||||
mySecret = serverSecret
|
||||
otherSecret = clientSecret
|
||||
}
|
||||
myKey, myIV := computeInitialKeyAndIV(mySecret, v)
|
||||
otherKey, otherIV := computeInitialKeyAndIV(otherSecret, v)
|
||||
|
||||
encrypter := initialSuite.AEAD(myKey, myIV)
|
||||
decrypter := initialSuite.AEAD(otherKey, otherIV)
|
||||
|
||||
return newLongHeaderSealer(encrypter, newHeaderProtector(initialSuite, mySecret, true, v)),
|
||||
newLongHeaderOpener(decrypter, newAESHeaderProtector(initialSuite, otherSecret, true, hkdfHeaderProtectionLabel(v)))
|
||||
}
|
||||
|
||||
func computeSecrets(connID protocol.ConnectionID, v protocol.VersionNumber) (clientSecret, serverSecret []byte) {
|
||||
initialSecret := hkdf.Extract(crypto.SHA256.New, connID.Bytes(), getSalt(v))
|
||||
clientSecret = hkdfExpandLabel(crypto.SHA256, initialSecret, []byte{}, "client in", crypto.SHA256.Size())
|
||||
serverSecret = hkdfExpandLabel(crypto.SHA256, initialSecret, []byte{}, "server in", crypto.SHA256.Size())
|
||||
return
|
||||
}
|
||||
|
||||
func computeInitialKeyAndIV(secret []byte, v protocol.VersionNumber) (key, iv []byte) {
|
||||
keyLabel := hkdfLabelKeyV1
|
||||
ivLabel := hkdfLabelIVV1
|
||||
if v == protocol.Version2 {
|
||||
keyLabel = hkdfLabelKeyV2
|
||||
ivLabel = hkdfLabelIVV2
|
||||
}
|
||||
key = hkdfExpandLabel(crypto.SHA256, secret, []byte{}, keyLabel, 16)
|
||||
iv = hkdfExpandLabel(crypto.SHA256, secret, []byte{}, ivLabel, 12)
|
||||
return
|
||||
}
|
||||
116
vendor/github.com/quic-go/quic-go/internal/handshake/interface.go
generated
vendored
Normal file
116
vendor/github.com/quic-go/quic-go/internal/handshake/interface.go
generated
vendored
Normal file
@@ -0,0 +1,116 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrKeysNotYetAvailable is returned when an opener or a sealer is requested for an encryption level,
|
||||
// but the corresponding opener has not yet been initialized
|
||||
// This can happen when packets arrive out of order.
|
||||
ErrKeysNotYetAvailable = errors.New("CryptoSetup: keys at this encryption level not yet available")
|
||||
// ErrKeysDropped is returned when an opener or a sealer is requested for an encryption level,
|
||||
// but the corresponding keys have already been dropped.
|
||||
ErrKeysDropped = errors.New("CryptoSetup: keys were already dropped")
|
||||
// ErrDecryptionFailed is returned when the AEAD fails to open the packet.
|
||||
ErrDecryptionFailed = errors.New("decryption failed")
|
||||
)
|
||||
|
||||
type headerDecryptor interface {
|
||||
DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte)
|
||||
}
|
||||
|
||||
// LongHeaderOpener opens a long header packet
|
||||
type LongHeaderOpener interface {
|
||||
headerDecryptor
|
||||
DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber
|
||||
Open(dst, src []byte, pn protocol.PacketNumber, associatedData []byte) ([]byte, error)
|
||||
}
|
||||
|
||||
// ShortHeaderOpener opens a short header packet
|
||||
type ShortHeaderOpener interface {
|
||||
headerDecryptor
|
||||
DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber
|
||||
Open(dst, src []byte, rcvTime time.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, associatedData []byte) ([]byte, error)
|
||||
}
|
||||
|
||||
// LongHeaderSealer seals a long header packet
|
||||
type LongHeaderSealer interface {
|
||||
Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte
|
||||
EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte)
|
||||
Overhead() int
|
||||
}
|
||||
|
||||
// ShortHeaderSealer seals a short header packet
|
||||
type ShortHeaderSealer interface {
|
||||
LongHeaderSealer
|
||||
KeyPhase() protocol.KeyPhaseBit
|
||||
}
|
||||
|
||||
type ConnectionState struct {
|
||||
tls.ConnectionState
|
||||
Used0RTT bool
|
||||
}
|
||||
|
||||
// EventKind is the kind of handshake event.
|
||||
type EventKind uint8
|
||||
|
||||
const (
|
||||
// EventNoEvent signals that there are no new handshake events
|
||||
EventNoEvent EventKind = iota + 1
|
||||
// EventWriteInitialData contains new CRYPTO data to send at the Initial encryption level
|
||||
EventWriteInitialData
|
||||
// EventWriteHandshakeData contains new CRYPTO data to send at the Handshake encryption level
|
||||
EventWriteHandshakeData
|
||||
// EventReceivedReadKeys signals that new decryption keys are available.
|
||||
// It doesn't say which encryption level those keys are for.
|
||||
EventReceivedReadKeys
|
||||
// EventDiscard0RTTKeys signals that the Handshake keys were discarded.
|
||||
EventDiscard0RTTKeys
|
||||
// EventReceivedTransportParameters contains the transport parameters sent by the peer.
|
||||
EventReceivedTransportParameters
|
||||
// EventRestoredTransportParameters contains the transport parameters restored from the session ticket.
|
||||
// It is only used for the client.
|
||||
EventRestoredTransportParameters
|
||||
// EventHandshakeComplete signals that the TLS handshake was completed.
|
||||
EventHandshakeComplete
|
||||
)
|
||||
|
||||
// Event is a handshake event.
|
||||
type Event struct {
|
||||
Kind EventKind
|
||||
Data []byte
|
||||
TransportParameters *wire.TransportParameters
|
||||
}
|
||||
|
||||
// CryptoSetup handles the handshake and protecting / unprotecting packets
|
||||
type CryptoSetup interface {
|
||||
StartHandshake() error
|
||||
io.Closer
|
||||
ChangeConnectionID(protocol.ConnectionID)
|
||||
GetSessionTicket() ([]byte, error)
|
||||
|
||||
HandleMessage([]byte, protocol.EncryptionLevel) error
|
||||
NextEvent() Event
|
||||
|
||||
SetLargest1RTTAcked(protocol.PacketNumber) error
|
||||
DiscardInitialKeys()
|
||||
SetHandshakeConfirmed()
|
||||
ConnectionState() ConnectionState
|
||||
|
||||
GetInitialOpener() (LongHeaderOpener, error)
|
||||
GetHandshakeOpener() (LongHeaderOpener, error)
|
||||
Get0RTTOpener() (LongHeaderOpener, error)
|
||||
Get1RTTOpener() (ShortHeaderOpener, error)
|
||||
|
||||
GetInitialSealer() (LongHeaderSealer, error)
|
||||
GetHandshakeSealer() (LongHeaderSealer, error)
|
||||
Get0RTTSealer() (LongHeaderSealer, error)
|
||||
Get1RTTSealer() (ShortHeaderSealer, error)
|
||||
}
|
||||
63
vendor/github.com/quic-go/quic-go/internal/handshake/retry.go
generated
vendored
Normal file
63
vendor/github.com/quic-go/quic-go/internal/handshake/retry.go
generated
vendored
Normal file
@@ -0,0 +1,63 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
var (
|
||||
retryAEADv1 cipher.AEAD // used for QUIC v1 (RFC 9000)
|
||||
retryAEADv2 cipher.AEAD // used for QUIC v2 (RFC 9369)
|
||||
)
|
||||
|
||||
func init() {
|
||||
retryAEADv1 = initAEAD([16]byte{0xbe, 0x0c, 0x69, 0x0b, 0x9f, 0x66, 0x57, 0x5a, 0x1d, 0x76, 0x6b, 0x54, 0xe3, 0x68, 0xc8, 0x4e})
|
||||
retryAEADv2 = initAEAD([16]byte{0x8f, 0xb4, 0xb0, 0x1b, 0x56, 0xac, 0x48, 0xe2, 0x60, 0xfb, 0xcb, 0xce, 0xad, 0x7c, 0xcc, 0x92})
|
||||
}
|
||||
|
||||
func initAEAD(key [16]byte) cipher.AEAD {
|
||||
aes, err := aes.NewCipher(key[:])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
aead, err := cipher.NewGCM(aes)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return aead
|
||||
}
|
||||
|
||||
var (
|
||||
retryBuf bytes.Buffer
|
||||
retryMutex sync.Mutex
|
||||
retryNonceV1 = [12]byte{0x46, 0x15, 0x99, 0xd3, 0x5d, 0x63, 0x2b, 0xf2, 0x23, 0x98, 0x25, 0xbb}
|
||||
retryNonceV2 = [12]byte{0xd8, 0x69, 0x69, 0xbc, 0x2d, 0x7c, 0x6d, 0x99, 0x90, 0xef, 0xb0, 0x4a}
|
||||
)
|
||||
|
||||
// GetRetryIntegrityTag calculates the integrity tag on a Retry packet
|
||||
func GetRetryIntegrityTag(retry []byte, origDestConnID protocol.ConnectionID, version protocol.VersionNumber) *[16]byte {
|
||||
retryMutex.Lock()
|
||||
defer retryMutex.Unlock()
|
||||
|
||||
retryBuf.WriteByte(uint8(origDestConnID.Len()))
|
||||
retryBuf.Write(origDestConnID.Bytes())
|
||||
retryBuf.Write(retry)
|
||||
defer retryBuf.Reset()
|
||||
|
||||
var tag [16]byte
|
||||
var sealed []byte
|
||||
if version == protocol.Version2 {
|
||||
sealed = retryAEADv2.Seal(tag[:0], retryNonceV2[:], nil, retryBuf.Bytes())
|
||||
} else {
|
||||
sealed = retryAEADv1.Seal(tag[:0], retryNonceV1[:], nil, retryBuf.Bytes())
|
||||
}
|
||||
if len(sealed) != 16 {
|
||||
panic(fmt.Sprintf("unexpected Retry integrity tag length: %d", len(sealed)))
|
||||
}
|
||||
return &tag
|
||||
}
|
||||
54
vendor/github.com/quic-go/quic-go/internal/handshake/session_ticket.go
generated
vendored
Normal file
54
vendor/github.com/quic-go/quic-go/internal/handshake/session_ticket.go
generated
vendored
Normal file
@@ -0,0 +1,54 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
"github.com/quic-go/quic-go/quicvarint"
|
||||
)
|
||||
|
||||
const sessionTicketRevision = 4
|
||||
|
||||
type sessionTicket struct {
|
||||
Parameters *wire.TransportParameters
|
||||
RTT time.Duration // to be encoded in mus
|
||||
}
|
||||
|
||||
func (t *sessionTicket) Marshal() []byte {
|
||||
b := make([]byte, 0, 256)
|
||||
b = quicvarint.Append(b, sessionTicketRevision)
|
||||
b = quicvarint.Append(b, uint64(t.RTT.Microseconds()))
|
||||
if t.Parameters == nil {
|
||||
return b
|
||||
}
|
||||
return t.Parameters.MarshalForSessionTicket(b)
|
||||
}
|
||||
|
||||
func (t *sessionTicket) Unmarshal(b []byte, using0RTT bool) error {
|
||||
r := bytes.NewReader(b)
|
||||
rev, err := quicvarint.Read(r)
|
||||
if err != nil {
|
||||
return errors.New("failed to read session ticket revision")
|
||||
}
|
||||
if rev != sessionTicketRevision {
|
||||
return fmt.Errorf("unknown session ticket revision: %d", rev)
|
||||
}
|
||||
rtt, err := quicvarint.Read(r)
|
||||
if err != nil {
|
||||
return errors.New("failed to read RTT")
|
||||
}
|
||||
if using0RTT {
|
||||
var tp wire.TransportParameters
|
||||
if err := tp.UnmarshalFromSessionTicket(r); err != nil {
|
||||
return fmt.Errorf("unmarshaling transport parameters from session ticket failed: %s", err.Error())
|
||||
}
|
||||
t.Parameters = &tp
|
||||
} else if r.Len() > 0 {
|
||||
return fmt.Errorf("the session ticket has more bytes than expected")
|
||||
}
|
||||
t.RTT = time.Duration(rtt) * time.Microsecond
|
||||
return nil
|
||||
}
|
||||
120
vendor/github.com/quic-go/quic-go/internal/handshake/token_generator.go
generated
vendored
Normal file
120
vendor/github.com/quic-go/quic-go/internal/handshake/token_generator.go
generated
vendored
Normal file
@@ -0,0 +1,120 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/asn1"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
const (
|
||||
tokenPrefixIP byte = iota
|
||||
tokenPrefixString
|
||||
)
|
||||
|
||||
// A Token is derived from the client address and can be used to verify the ownership of this address.
|
||||
type Token struct {
|
||||
IsRetryToken bool
|
||||
SentTime time.Time
|
||||
encodedRemoteAddr []byte
|
||||
// only set for retry tokens
|
||||
OriginalDestConnectionID protocol.ConnectionID
|
||||
RetrySrcConnectionID protocol.ConnectionID
|
||||
}
|
||||
|
||||
// ValidateRemoteAddr validates the address, but does not check expiration
|
||||
func (t *Token) ValidateRemoteAddr(addr net.Addr) bool {
|
||||
return bytes.Equal(encodeRemoteAddr(addr), t.encodedRemoteAddr)
|
||||
}
|
||||
|
||||
// token is the struct that is used for ASN1 serialization and deserialization
|
||||
type token struct {
|
||||
IsRetryToken bool
|
||||
RemoteAddr []byte
|
||||
Timestamp int64
|
||||
OriginalDestConnectionID []byte
|
||||
RetrySrcConnectionID []byte
|
||||
}
|
||||
|
||||
// A TokenGenerator generates tokens
|
||||
type TokenGenerator struct {
|
||||
tokenProtector tokenProtector
|
||||
}
|
||||
|
||||
// NewTokenGenerator initializes a new TokenGenerator
|
||||
func NewTokenGenerator(key TokenProtectorKey) *TokenGenerator {
|
||||
return &TokenGenerator{tokenProtector: newTokenProtector(key)}
|
||||
}
|
||||
|
||||
// NewRetryToken generates a new token for a Retry for a given source address
|
||||
func (g *TokenGenerator) NewRetryToken(
|
||||
raddr net.Addr,
|
||||
origDestConnID protocol.ConnectionID,
|
||||
retrySrcConnID protocol.ConnectionID,
|
||||
) ([]byte, error) {
|
||||
data, err := asn1.Marshal(token{
|
||||
IsRetryToken: true,
|
||||
RemoteAddr: encodeRemoteAddr(raddr),
|
||||
OriginalDestConnectionID: origDestConnID.Bytes(),
|
||||
RetrySrcConnectionID: retrySrcConnID.Bytes(),
|
||||
Timestamp: time.Now().UnixNano(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return g.tokenProtector.NewToken(data)
|
||||
}
|
||||
|
||||
// NewToken generates a new token to be sent in a NEW_TOKEN frame
|
||||
func (g *TokenGenerator) NewToken(raddr net.Addr) ([]byte, error) {
|
||||
data, err := asn1.Marshal(token{
|
||||
RemoteAddr: encodeRemoteAddr(raddr),
|
||||
Timestamp: time.Now().UnixNano(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return g.tokenProtector.NewToken(data)
|
||||
}
|
||||
|
||||
// DecodeToken decodes a token
|
||||
func (g *TokenGenerator) DecodeToken(encrypted []byte) (*Token, error) {
|
||||
// if the client didn't send any token, DecodeToken will be called with a nil-slice
|
||||
if len(encrypted) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
data, err := g.tokenProtector.DecodeToken(encrypted)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t := &token{}
|
||||
rest, err := asn1.Unmarshal(data, t)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(rest) != 0 {
|
||||
return nil, fmt.Errorf("rest when unpacking token: %d", len(rest))
|
||||
}
|
||||
token := &Token{
|
||||
IsRetryToken: t.IsRetryToken,
|
||||
SentTime: time.Unix(0, t.Timestamp),
|
||||
encodedRemoteAddr: t.RemoteAddr,
|
||||
}
|
||||
if t.IsRetryToken {
|
||||
token.OriginalDestConnectionID = protocol.ParseConnectionID(t.OriginalDestConnectionID)
|
||||
token.RetrySrcConnectionID = protocol.ParseConnectionID(t.RetrySrcConnectionID)
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// encodeRemoteAddr encodes a remote address such that it can be saved in the token
|
||||
func encodeRemoteAddr(remoteAddr net.Addr) []byte {
|
||||
if udpAddr, ok := remoteAddr.(*net.UDPAddr); ok {
|
||||
return append([]byte{tokenPrefixIP}, udpAddr.IP...)
|
||||
}
|
||||
return append([]byte{tokenPrefixString}, []byte(remoteAddr.String())...)
|
||||
}
|
||||
82
vendor/github.com/quic-go/quic-go/internal/handshake/token_protector.go
generated
vendored
Normal file
82
vendor/github.com/quic-go/quic-go/internal/handshake/token_protector.go
generated
vendored
Normal file
@@ -0,0 +1,82 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"golang.org/x/crypto/hkdf"
|
||||
)
|
||||
|
||||
// TokenProtectorKey is the key used to encrypt both Retry and session resumption tokens.
|
||||
type TokenProtectorKey [32]byte
|
||||
|
||||
// TokenProtector is used to create and verify a token
|
||||
type tokenProtector interface {
|
||||
// NewToken creates a new token
|
||||
NewToken([]byte) ([]byte, error)
|
||||
// DecodeToken decodes a token
|
||||
DecodeToken([]byte) ([]byte, error)
|
||||
}
|
||||
|
||||
const tokenNonceSize = 32
|
||||
|
||||
// tokenProtector is used to create and verify a token
|
||||
type tokenProtectorImpl struct {
|
||||
key TokenProtectorKey
|
||||
}
|
||||
|
||||
// newTokenProtector creates a source for source address tokens
|
||||
func newTokenProtector(key TokenProtectorKey) tokenProtector {
|
||||
return &tokenProtectorImpl{key: key}
|
||||
}
|
||||
|
||||
// NewToken encodes data into a new token.
|
||||
func (s *tokenProtectorImpl) NewToken(data []byte) ([]byte, error) {
|
||||
var nonce [tokenNonceSize]byte
|
||||
if _, err := rand.Read(nonce[:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
aead, aeadNonce, err := s.createAEAD(nonce[:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return append(nonce[:], aead.Seal(nil, aeadNonce, data, nil)...), nil
|
||||
}
|
||||
|
||||
// DecodeToken decodes a token.
|
||||
func (s *tokenProtectorImpl) DecodeToken(p []byte) ([]byte, error) {
|
||||
if len(p) < tokenNonceSize {
|
||||
return nil, fmt.Errorf("token too short: %d", len(p))
|
||||
}
|
||||
nonce := p[:tokenNonceSize]
|
||||
aead, aeadNonce, err := s.createAEAD(nonce)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return aead.Open(nil, aeadNonce, p[tokenNonceSize:], nil)
|
||||
}
|
||||
|
||||
func (s *tokenProtectorImpl) createAEAD(nonce []byte) (cipher.AEAD, []byte, error) {
|
||||
h := hkdf.New(sha256.New, s.key[:], nonce, []byte("quic-go token source"))
|
||||
key := make([]byte, 32) // use a 32 byte key, in order to select AES-256
|
||||
if _, err := io.ReadFull(h, key); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
aeadNonce := make([]byte, 12)
|
||||
if _, err := io.ReadFull(h, aeadNonce); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
c, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
aead, err := cipher.NewGCM(c)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return aead, aeadNonce, nil
|
||||
}
|
||||
332
vendor/github.com/quic-go/quic-go/internal/handshake/updatable_aead.go
generated
vendored
Normal file
332
vendor/github.com/quic-go/quic-go/internal/handshake/updatable_aead.go
generated
vendored
Normal file
@@ -0,0 +1,332 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/cipher"
|
||||
"crypto/tls"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/qerr"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
"github.com/quic-go/quic-go/logging"
|
||||
)
|
||||
|
||||
// KeyUpdateInterval is the maximum number of packets we send or receive before initiating a key update.
|
||||
// It's a package-level variable to allow modifying it for testing purposes.
|
||||
var KeyUpdateInterval uint64 = protocol.KeyUpdateInterval
|
||||
|
||||
// FirstKeyUpdateInterval is the maximum number of packets we send or receive before initiating the first key update.
|
||||
// It's a package-level variable to allow modifying it for testing purposes.
|
||||
var FirstKeyUpdateInterval uint64 = 100
|
||||
|
||||
type updatableAEAD struct {
|
||||
suite *cipherSuite
|
||||
|
||||
keyPhase protocol.KeyPhase
|
||||
largestAcked protocol.PacketNumber
|
||||
firstPacketNumber protocol.PacketNumber
|
||||
handshakeConfirmed bool
|
||||
|
||||
invalidPacketLimit uint64
|
||||
invalidPacketCount uint64
|
||||
|
||||
// Time when the keys should be dropped. Keys are dropped on the next call to Open().
|
||||
prevRcvAEADExpiry time.Time
|
||||
prevRcvAEAD cipher.AEAD
|
||||
|
||||
firstRcvdWithCurrentKey protocol.PacketNumber
|
||||
firstSentWithCurrentKey protocol.PacketNumber
|
||||
highestRcvdPN protocol.PacketNumber // highest packet number received (which could be successfully unprotected)
|
||||
numRcvdWithCurrentKey uint64
|
||||
numSentWithCurrentKey uint64
|
||||
rcvAEAD cipher.AEAD
|
||||
sendAEAD cipher.AEAD
|
||||
// caches cipher.AEAD.Overhead(). This speeds up calls to Overhead().
|
||||
aeadOverhead int
|
||||
|
||||
nextRcvAEAD cipher.AEAD
|
||||
nextSendAEAD cipher.AEAD
|
||||
nextRcvTrafficSecret []byte
|
||||
nextSendTrafficSecret []byte
|
||||
|
||||
headerDecrypter headerProtector
|
||||
headerEncrypter headerProtector
|
||||
|
||||
rttStats *utils.RTTStats
|
||||
|
||||
tracer *logging.ConnectionTracer
|
||||
logger utils.Logger
|
||||
version protocol.VersionNumber
|
||||
|
||||
// use a single slice to avoid allocations
|
||||
nonceBuf []byte
|
||||
}
|
||||
|
||||
var (
|
||||
_ ShortHeaderOpener = &updatableAEAD{}
|
||||
_ ShortHeaderSealer = &updatableAEAD{}
|
||||
)
|
||||
|
||||
func newUpdatableAEAD(rttStats *utils.RTTStats, tracer *logging.ConnectionTracer, logger utils.Logger, version protocol.VersionNumber) *updatableAEAD {
|
||||
return &updatableAEAD{
|
||||
firstPacketNumber: protocol.InvalidPacketNumber,
|
||||
largestAcked: protocol.InvalidPacketNumber,
|
||||
firstRcvdWithCurrentKey: protocol.InvalidPacketNumber,
|
||||
firstSentWithCurrentKey: protocol.InvalidPacketNumber,
|
||||
rttStats: rttStats,
|
||||
tracer: tracer,
|
||||
logger: logger,
|
||||
version: version,
|
||||
}
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) rollKeys() {
|
||||
if a.prevRcvAEAD != nil {
|
||||
a.logger.Debugf("Dropping key phase %d ahead of scheduled time. Drop time was: %s", a.keyPhase-1, a.prevRcvAEADExpiry)
|
||||
if a.tracer != nil && a.tracer.DroppedKey != nil {
|
||||
a.tracer.DroppedKey(a.keyPhase - 1)
|
||||
}
|
||||
a.prevRcvAEADExpiry = time.Time{}
|
||||
}
|
||||
|
||||
a.keyPhase++
|
||||
a.firstRcvdWithCurrentKey = protocol.InvalidPacketNumber
|
||||
a.firstSentWithCurrentKey = protocol.InvalidPacketNumber
|
||||
a.numRcvdWithCurrentKey = 0
|
||||
a.numSentWithCurrentKey = 0
|
||||
a.prevRcvAEAD = a.rcvAEAD
|
||||
a.rcvAEAD = a.nextRcvAEAD
|
||||
a.sendAEAD = a.nextSendAEAD
|
||||
|
||||
a.nextRcvTrafficSecret = a.getNextTrafficSecret(a.suite.Hash, a.nextRcvTrafficSecret)
|
||||
a.nextSendTrafficSecret = a.getNextTrafficSecret(a.suite.Hash, a.nextSendTrafficSecret)
|
||||
a.nextRcvAEAD = createAEAD(a.suite, a.nextRcvTrafficSecret, a.version)
|
||||
a.nextSendAEAD = createAEAD(a.suite, a.nextSendTrafficSecret, a.version)
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) startKeyDropTimer(now time.Time) {
|
||||
d := 3 * a.rttStats.PTO(true)
|
||||
a.logger.Debugf("Starting key drop timer to drop key phase %d (in %s)", a.keyPhase-1, d)
|
||||
a.prevRcvAEADExpiry = now.Add(d)
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) getNextTrafficSecret(hash crypto.Hash, ts []byte) []byte {
|
||||
return hkdfExpandLabel(hash, ts, []byte{}, "quic ku", hash.Size())
|
||||
}
|
||||
|
||||
// SetReadKey sets the read key.
|
||||
// For the client, this function is called before SetWriteKey.
|
||||
// For the server, this function is called after SetWriteKey.
|
||||
func (a *updatableAEAD) SetReadKey(suite *cipherSuite, trafficSecret []byte) {
|
||||
a.rcvAEAD = createAEAD(suite, trafficSecret, a.version)
|
||||
a.headerDecrypter = newHeaderProtector(suite, trafficSecret, false, a.version)
|
||||
if a.suite == nil {
|
||||
a.setAEADParameters(a.rcvAEAD, suite)
|
||||
}
|
||||
|
||||
a.nextRcvTrafficSecret = a.getNextTrafficSecret(suite.Hash, trafficSecret)
|
||||
a.nextRcvAEAD = createAEAD(suite, a.nextRcvTrafficSecret, a.version)
|
||||
}
|
||||
|
||||
// SetWriteKey sets the write key.
|
||||
// For the client, this function is called after SetReadKey.
|
||||
// For the server, this function is called before SetWriteKey.
|
||||
func (a *updatableAEAD) SetWriteKey(suite *cipherSuite, trafficSecret []byte) {
|
||||
a.sendAEAD = createAEAD(suite, trafficSecret, a.version)
|
||||
a.headerEncrypter = newHeaderProtector(suite, trafficSecret, false, a.version)
|
||||
if a.suite == nil {
|
||||
a.setAEADParameters(a.sendAEAD, suite)
|
||||
}
|
||||
|
||||
a.nextSendTrafficSecret = a.getNextTrafficSecret(suite.Hash, trafficSecret)
|
||||
a.nextSendAEAD = createAEAD(suite, a.nextSendTrafficSecret, a.version)
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) setAEADParameters(aead cipher.AEAD, suite *cipherSuite) {
|
||||
a.nonceBuf = make([]byte, aead.NonceSize())
|
||||
a.aeadOverhead = aead.Overhead()
|
||||
a.suite = suite
|
||||
switch suite.ID {
|
||||
case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384:
|
||||
a.invalidPacketLimit = protocol.InvalidPacketLimitAES
|
||||
case tls.TLS_CHACHA20_POLY1305_SHA256:
|
||||
a.invalidPacketLimit = protocol.InvalidPacketLimitChaCha
|
||||
default:
|
||||
panic(fmt.Sprintf("unknown cipher suite %d", suite.ID))
|
||||
}
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber {
|
||||
return protocol.DecodePacketNumber(wirePNLen, a.highestRcvdPN, wirePN)
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) Open(dst, src []byte, rcvTime time.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, ad []byte) ([]byte, error) {
|
||||
dec, err := a.open(dst, src, rcvTime, pn, kp, ad)
|
||||
if err == ErrDecryptionFailed {
|
||||
a.invalidPacketCount++
|
||||
if a.invalidPacketCount >= a.invalidPacketLimit {
|
||||
return nil, &qerr.TransportError{ErrorCode: qerr.AEADLimitReached}
|
||||
}
|
||||
}
|
||||
if err == nil {
|
||||
a.highestRcvdPN = utils.Max(a.highestRcvdPN, pn)
|
||||
}
|
||||
return dec, err
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) open(dst, src []byte, rcvTime time.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, ad []byte) ([]byte, error) {
|
||||
if a.prevRcvAEAD != nil && !a.prevRcvAEADExpiry.IsZero() && rcvTime.After(a.prevRcvAEADExpiry) {
|
||||
a.prevRcvAEAD = nil
|
||||
a.logger.Debugf("Dropping key phase %d", a.keyPhase-1)
|
||||
a.prevRcvAEADExpiry = time.Time{}
|
||||
if a.tracer != nil && a.tracer.DroppedKey != nil {
|
||||
a.tracer.DroppedKey(a.keyPhase - 1)
|
||||
}
|
||||
}
|
||||
binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn))
|
||||
if kp != a.keyPhase.Bit() {
|
||||
if a.keyPhase > 0 && a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber || pn < a.firstRcvdWithCurrentKey {
|
||||
if a.prevRcvAEAD == nil {
|
||||
return nil, ErrKeysDropped
|
||||
}
|
||||
// we updated the key, but the peer hasn't updated yet
|
||||
dec, err := a.prevRcvAEAD.Open(dst, a.nonceBuf, src, ad)
|
||||
if err != nil {
|
||||
err = ErrDecryptionFailed
|
||||
}
|
||||
return dec, err
|
||||
}
|
||||
// try opening the packet with the next key phase
|
||||
dec, err := a.nextRcvAEAD.Open(dst, a.nonceBuf, src, ad)
|
||||
if err != nil {
|
||||
return nil, ErrDecryptionFailed
|
||||
}
|
||||
// Opening succeeded. Check if the peer was allowed to update.
|
||||
if a.keyPhase > 0 && a.firstSentWithCurrentKey == protocol.InvalidPacketNumber {
|
||||
return nil, &qerr.TransportError{
|
||||
ErrorCode: qerr.KeyUpdateError,
|
||||
ErrorMessage: "keys updated too quickly",
|
||||
}
|
||||
}
|
||||
a.rollKeys()
|
||||
a.logger.Debugf("Peer updated keys to %d", a.keyPhase)
|
||||
// The peer initiated this key update. It's safe to drop the keys for the previous generation now.
|
||||
// Start a timer to drop the previous key generation.
|
||||
a.startKeyDropTimer(rcvTime)
|
||||
if a.tracer != nil && a.tracer.UpdatedKey != nil {
|
||||
a.tracer.UpdatedKey(a.keyPhase, true)
|
||||
}
|
||||
a.firstRcvdWithCurrentKey = pn
|
||||
return dec, err
|
||||
}
|
||||
// The AEAD we're using here will be the qtls.aeadAESGCM13.
|
||||
// It uses the nonce provided here and XOR it with the IV.
|
||||
dec, err := a.rcvAEAD.Open(dst, a.nonceBuf, src, ad)
|
||||
if err != nil {
|
||||
return dec, ErrDecryptionFailed
|
||||
}
|
||||
a.numRcvdWithCurrentKey++
|
||||
if a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber {
|
||||
// We initiated the key updated, and now we received the first packet protected with the new key phase.
|
||||
// Therefore, we are certain that the peer rolled its keys as well. Start a timer to drop the old keys.
|
||||
if a.keyPhase > 0 {
|
||||
a.logger.Debugf("Peer confirmed key update to phase %d", a.keyPhase)
|
||||
a.startKeyDropTimer(rcvTime)
|
||||
}
|
||||
a.firstRcvdWithCurrentKey = pn
|
||||
}
|
||||
return dec, err
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []byte {
|
||||
if a.firstSentWithCurrentKey == protocol.InvalidPacketNumber {
|
||||
a.firstSentWithCurrentKey = pn
|
||||
}
|
||||
if a.firstPacketNumber == protocol.InvalidPacketNumber {
|
||||
a.firstPacketNumber = pn
|
||||
}
|
||||
a.numSentWithCurrentKey++
|
||||
binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn))
|
||||
// The AEAD we're using here will be the qtls.aeadAESGCM13.
|
||||
// It uses the nonce provided here and XOR it with the IV.
|
||||
return a.sendAEAD.Seal(dst, a.nonceBuf, src, ad)
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) SetLargestAcked(pn protocol.PacketNumber) error {
|
||||
if a.firstSentWithCurrentKey != protocol.InvalidPacketNumber &&
|
||||
pn >= a.firstSentWithCurrentKey && a.numRcvdWithCurrentKey == 0 {
|
||||
return &qerr.TransportError{
|
||||
ErrorCode: qerr.KeyUpdateError,
|
||||
ErrorMessage: fmt.Sprintf("received ACK for key phase %d, but peer didn't update keys", a.keyPhase),
|
||||
}
|
||||
}
|
||||
a.largestAcked = pn
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) SetHandshakeConfirmed() {
|
||||
a.handshakeConfirmed = true
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) updateAllowed() bool {
|
||||
if !a.handshakeConfirmed {
|
||||
return false
|
||||
}
|
||||
// the first key update is allowed as soon as the handshake is confirmed
|
||||
return a.keyPhase == 0 ||
|
||||
// subsequent key updates as soon as a packet sent with that key phase has been acknowledged
|
||||
(a.firstSentWithCurrentKey != protocol.InvalidPacketNumber &&
|
||||
a.largestAcked != protocol.InvalidPacketNumber &&
|
||||
a.largestAcked >= a.firstSentWithCurrentKey)
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) shouldInitiateKeyUpdate() bool {
|
||||
if !a.updateAllowed() {
|
||||
return false
|
||||
}
|
||||
// Initiate the first key update shortly after the handshake, in order to exercise the key update mechanism.
|
||||
if a.keyPhase == 0 {
|
||||
if a.numRcvdWithCurrentKey >= FirstKeyUpdateInterval || a.numSentWithCurrentKey >= FirstKeyUpdateInterval {
|
||||
return true
|
||||
}
|
||||
}
|
||||
if a.numRcvdWithCurrentKey >= KeyUpdateInterval {
|
||||
a.logger.Debugf("Received %d packets with current key phase. Initiating key update to the next key phase: %d", a.numRcvdWithCurrentKey, a.keyPhase+1)
|
||||
return true
|
||||
}
|
||||
if a.numSentWithCurrentKey >= KeyUpdateInterval {
|
||||
a.logger.Debugf("Sent %d packets with current key phase. Initiating key update to the next key phase: %d", a.numSentWithCurrentKey, a.keyPhase+1)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) KeyPhase() protocol.KeyPhaseBit {
|
||||
if a.shouldInitiateKeyUpdate() {
|
||||
a.rollKeys()
|
||||
a.logger.Debugf("Initiating key update to key phase %d", a.keyPhase)
|
||||
if a.tracer != nil && a.tracer.UpdatedKey != nil {
|
||||
a.tracer.UpdatedKey(a.keyPhase, false)
|
||||
}
|
||||
}
|
||||
return a.keyPhase.Bit()
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) Overhead() int {
|
||||
return a.aeadOverhead
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
|
||||
a.headerEncrypter.EncryptHeader(sample, firstByte, hdrBytes)
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
|
||||
a.headerDecrypter.DecryptHeader(sample, firstByte, hdrBytes)
|
||||
}
|
||||
|
||||
func (a *updatableAEAD) FirstPacketNumber() protocol.PacketNumber {
|
||||
return a.firstPacketNumber
|
||||
}
|
||||
50
vendor/github.com/quic-go/quic-go/internal/logutils/frame.go
generated
vendored
Normal file
50
vendor/github.com/quic-go/quic-go/internal/logutils/frame.go
generated
vendored
Normal file
@@ -0,0 +1,50 @@
|
||||
package logutils
|
||||
|
||||
import (
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
"github.com/quic-go/quic-go/logging"
|
||||
)
|
||||
|
||||
// ConvertFrame converts a wire.Frame into a logging.Frame.
|
||||
// This makes it possible for external packages to access the frames.
|
||||
// Furthermore, it removes the data slices from CRYPTO and STREAM frames.
|
||||
func ConvertFrame(frame wire.Frame) logging.Frame {
|
||||
switch f := frame.(type) {
|
||||
case *wire.AckFrame:
|
||||
// We use a pool for ACK frames.
|
||||
// Implementations of the tracer interface may hold on to frames, so we need to make a copy here.
|
||||
return ConvertAckFrame(f)
|
||||
case *wire.CryptoFrame:
|
||||
return &logging.CryptoFrame{
|
||||
Offset: f.Offset,
|
||||
Length: protocol.ByteCount(len(f.Data)),
|
||||
}
|
||||
case *wire.StreamFrame:
|
||||
return &logging.StreamFrame{
|
||||
StreamID: f.StreamID,
|
||||
Offset: f.Offset,
|
||||
Length: f.DataLen(),
|
||||
Fin: f.Fin,
|
||||
}
|
||||
case *wire.DatagramFrame:
|
||||
return &logging.DatagramFrame{
|
||||
Length: logging.ByteCount(len(f.Data)),
|
||||
}
|
||||
default:
|
||||
return logging.Frame(frame)
|
||||
}
|
||||
}
|
||||
|
||||
func ConvertAckFrame(f *wire.AckFrame) *logging.AckFrame {
|
||||
ranges := make([]wire.AckRange, 0, len(f.AckRanges))
|
||||
ranges = append(ranges, f.AckRanges...)
|
||||
ack := &logging.AckFrame{
|
||||
AckRanges: ranges,
|
||||
DelayTime: f.DelayTime,
|
||||
ECNCE: f.ECNCE,
|
||||
ECT0: f.ECT0,
|
||||
ECT1: f.ECT1,
|
||||
}
|
||||
return ack
|
||||
}
|
||||
116
vendor/github.com/quic-go/quic-go/internal/protocol/connection_id.go
generated
vendored
Normal file
116
vendor/github.com/quic-go/quic-go/internal/protocol/connection_id.go
generated
vendored
Normal file
@@ -0,0 +1,116 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
var ErrInvalidConnectionIDLen = errors.New("invalid Connection ID length")
|
||||
|
||||
// An ArbitraryLenConnectionID is a QUIC Connection ID able to represent Connection IDs according to RFC 8999.
|
||||
// Future QUIC versions might allow connection ID lengths up to 255 bytes, while QUIC v1
|
||||
// restricts the length to 20 bytes.
|
||||
type ArbitraryLenConnectionID []byte
|
||||
|
||||
func (c ArbitraryLenConnectionID) Len() int {
|
||||
return len(c)
|
||||
}
|
||||
|
||||
func (c ArbitraryLenConnectionID) Bytes() []byte {
|
||||
return c
|
||||
}
|
||||
|
||||
func (c ArbitraryLenConnectionID) String() string {
|
||||
if c.Len() == 0 {
|
||||
return "(empty)"
|
||||
}
|
||||
return fmt.Sprintf("%x", c.Bytes())
|
||||
}
|
||||
|
||||
const maxConnectionIDLen = 20
|
||||
|
||||
// A ConnectionID in QUIC
|
||||
type ConnectionID struct {
|
||||
b [20]byte
|
||||
l uint8
|
||||
}
|
||||
|
||||
// GenerateConnectionID generates a connection ID using cryptographic random
|
||||
func GenerateConnectionID(l int) (ConnectionID, error) {
|
||||
var c ConnectionID
|
||||
c.l = uint8(l)
|
||||
_, err := rand.Read(c.b[:l])
|
||||
return c, err
|
||||
}
|
||||
|
||||
// ParseConnectionID interprets b as a Connection ID.
|
||||
// It panics if b is longer than 20 bytes.
|
||||
func ParseConnectionID(b []byte) ConnectionID {
|
||||
if len(b) > maxConnectionIDLen {
|
||||
panic("invalid conn id length")
|
||||
}
|
||||
var c ConnectionID
|
||||
c.l = uint8(len(b))
|
||||
copy(c.b[:c.l], b)
|
||||
return c
|
||||
}
|
||||
|
||||
// GenerateConnectionIDForInitial generates a connection ID for the Initial packet.
|
||||
// It uses a length randomly chosen between 8 and 20 bytes.
|
||||
func GenerateConnectionIDForInitial() (ConnectionID, error) {
|
||||
r := make([]byte, 1)
|
||||
if _, err := rand.Read(r); err != nil {
|
||||
return ConnectionID{}, err
|
||||
}
|
||||
l := MinConnectionIDLenInitial + int(r[0])%(maxConnectionIDLen-MinConnectionIDLenInitial+1)
|
||||
return GenerateConnectionID(l)
|
||||
}
|
||||
|
||||
// ReadConnectionID reads a connection ID of length len from the given io.Reader.
|
||||
// It returns io.EOF if there are not enough bytes to read.
|
||||
func ReadConnectionID(r io.Reader, l int) (ConnectionID, error) {
|
||||
var c ConnectionID
|
||||
if l == 0 {
|
||||
return c, nil
|
||||
}
|
||||
if l > maxConnectionIDLen {
|
||||
return c, ErrInvalidConnectionIDLen
|
||||
}
|
||||
c.l = uint8(l)
|
||||
_, err := io.ReadFull(r, c.b[:l])
|
||||
if err == io.ErrUnexpectedEOF {
|
||||
return c, io.EOF
|
||||
}
|
||||
return c, err
|
||||
}
|
||||
|
||||
// Len returns the length of the connection ID in bytes
|
||||
func (c ConnectionID) Len() int {
|
||||
return int(c.l)
|
||||
}
|
||||
|
||||
// Bytes returns the byte representation
|
||||
func (c ConnectionID) Bytes() []byte {
|
||||
return c.b[:c.l]
|
||||
}
|
||||
|
||||
func (c ConnectionID) String() string {
|
||||
if c.Len() == 0 {
|
||||
return "(empty)"
|
||||
}
|
||||
return fmt.Sprintf("%x", c.Bytes())
|
||||
}
|
||||
|
||||
type DefaultConnectionIDGenerator struct {
|
||||
ConnLen int
|
||||
}
|
||||
|
||||
func (d *DefaultConnectionIDGenerator) GenerateConnectionID() (ConnectionID, error) {
|
||||
return GenerateConnectionID(d.ConnLen)
|
||||
}
|
||||
|
||||
func (d *DefaultConnectionIDGenerator) ConnectionIDLen() int {
|
||||
return d.ConnLen
|
||||
}
|
||||
30
vendor/github.com/quic-go/quic-go/internal/protocol/encryption_level.go
generated
vendored
Normal file
30
vendor/github.com/quic-go/quic-go/internal/protocol/encryption_level.go
generated
vendored
Normal file
@@ -0,0 +1,30 @@
|
||||
package protocol
|
||||
|
||||
// EncryptionLevel is the encryption level
|
||||
// Default value is Unencrypted
|
||||
type EncryptionLevel uint8
|
||||
|
||||
const (
|
||||
// EncryptionInitial is the Initial encryption level
|
||||
EncryptionInitial EncryptionLevel = 1 + iota
|
||||
// EncryptionHandshake is the Handshake encryption level
|
||||
EncryptionHandshake
|
||||
// Encryption0RTT is the 0-RTT encryption level
|
||||
Encryption0RTT
|
||||
// Encryption1RTT is the 1-RTT encryption level
|
||||
Encryption1RTT
|
||||
)
|
||||
|
||||
func (e EncryptionLevel) String() string {
|
||||
switch e {
|
||||
case EncryptionInitial:
|
||||
return "Initial"
|
||||
case EncryptionHandshake:
|
||||
return "Handshake"
|
||||
case Encryption0RTT:
|
||||
return "0-RTT"
|
||||
case Encryption1RTT:
|
||||
return "1-RTT"
|
||||
}
|
||||
return "unknown"
|
||||
}
|
||||
36
vendor/github.com/quic-go/quic-go/internal/protocol/key_phase.go
generated
vendored
Normal file
36
vendor/github.com/quic-go/quic-go/internal/protocol/key_phase.go
generated
vendored
Normal file
@@ -0,0 +1,36 @@
|
||||
package protocol
|
||||
|
||||
// KeyPhase is the key phase
|
||||
type KeyPhase uint64
|
||||
|
||||
// Bit determines the key phase bit
|
||||
func (p KeyPhase) Bit() KeyPhaseBit {
|
||||
if p%2 == 0 {
|
||||
return KeyPhaseZero
|
||||
}
|
||||
return KeyPhaseOne
|
||||
}
|
||||
|
||||
// KeyPhaseBit is the key phase bit
|
||||
type KeyPhaseBit uint8
|
||||
|
||||
const (
|
||||
// KeyPhaseUndefined is an undefined key phase
|
||||
KeyPhaseUndefined KeyPhaseBit = iota
|
||||
// KeyPhaseZero is key phase 0
|
||||
KeyPhaseZero
|
||||
// KeyPhaseOne is key phase 1
|
||||
KeyPhaseOne
|
||||
)
|
||||
|
||||
func (p KeyPhaseBit) String() string {
|
||||
//nolint:exhaustive
|
||||
switch p {
|
||||
case KeyPhaseZero:
|
||||
return "0"
|
||||
case KeyPhaseOne:
|
||||
return "1"
|
||||
default:
|
||||
return "undefined"
|
||||
}
|
||||
}
|
||||
79
vendor/github.com/quic-go/quic-go/internal/protocol/packet_number.go
generated
vendored
Normal file
79
vendor/github.com/quic-go/quic-go/internal/protocol/packet_number.go
generated
vendored
Normal file
@@ -0,0 +1,79 @@
|
||||
package protocol
|
||||
|
||||
// A PacketNumber in QUIC
|
||||
type PacketNumber int64
|
||||
|
||||
// InvalidPacketNumber is a packet number that is never sent.
|
||||
// In QUIC, 0 is a valid packet number.
|
||||
const InvalidPacketNumber PacketNumber = -1
|
||||
|
||||
// PacketNumberLen is the length of the packet number in bytes
|
||||
type PacketNumberLen uint8
|
||||
|
||||
const (
|
||||
// PacketNumberLen1 is a packet number length of 1 byte
|
||||
PacketNumberLen1 PacketNumberLen = 1
|
||||
// PacketNumberLen2 is a packet number length of 2 bytes
|
||||
PacketNumberLen2 PacketNumberLen = 2
|
||||
// PacketNumberLen3 is a packet number length of 3 bytes
|
||||
PacketNumberLen3 PacketNumberLen = 3
|
||||
// PacketNumberLen4 is a packet number length of 4 bytes
|
||||
PacketNumberLen4 PacketNumberLen = 4
|
||||
)
|
||||
|
||||
// DecodePacketNumber calculates the packet number based on the received packet number, its length and the last seen packet number
|
||||
func DecodePacketNumber(
|
||||
packetNumberLength PacketNumberLen,
|
||||
lastPacketNumber PacketNumber,
|
||||
wirePacketNumber PacketNumber,
|
||||
) PacketNumber {
|
||||
var epochDelta PacketNumber
|
||||
switch packetNumberLength {
|
||||
case PacketNumberLen1:
|
||||
epochDelta = PacketNumber(1) << 8
|
||||
case PacketNumberLen2:
|
||||
epochDelta = PacketNumber(1) << 16
|
||||
case PacketNumberLen3:
|
||||
epochDelta = PacketNumber(1) << 24
|
||||
case PacketNumberLen4:
|
||||
epochDelta = PacketNumber(1) << 32
|
||||
}
|
||||
epoch := lastPacketNumber & ^(epochDelta - 1)
|
||||
var prevEpochBegin PacketNumber
|
||||
if epoch > epochDelta {
|
||||
prevEpochBegin = epoch - epochDelta
|
||||
}
|
||||
nextEpochBegin := epoch + epochDelta
|
||||
return closestTo(
|
||||
lastPacketNumber+1,
|
||||
epoch+wirePacketNumber,
|
||||
closestTo(lastPacketNumber+1, prevEpochBegin+wirePacketNumber, nextEpochBegin+wirePacketNumber),
|
||||
)
|
||||
}
|
||||
|
||||
func closestTo(target, a, b PacketNumber) PacketNumber {
|
||||
if delta(target, a) < delta(target, b) {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func delta(a, b PacketNumber) PacketNumber {
|
||||
if a < b {
|
||||
return b - a
|
||||
}
|
||||
return a - b
|
||||
}
|
||||
|
||||
// GetPacketNumberLengthForHeader gets the length of the packet number for the public header
|
||||
// it never chooses a PacketNumberLen of 1 byte, since this is too short under certain circumstances
|
||||
func GetPacketNumberLengthForHeader(packetNumber, leastUnacked PacketNumber) PacketNumberLen {
|
||||
diff := uint64(packetNumber - leastUnacked)
|
||||
if diff < (1 << (16 - 1)) {
|
||||
return PacketNumberLen2
|
||||
}
|
||||
if diff < (1 << (24 - 1)) {
|
||||
return PacketNumberLen3
|
||||
}
|
||||
return PacketNumberLen4
|
||||
}
|
||||
190
vendor/github.com/quic-go/quic-go/internal/protocol/params.go
generated
vendored
Normal file
190
vendor/github.com/quic-go/quic-go/internal/protocol/params.go
generated
vendored
Normal file
@@ -0,0 +1,190 @@
|
||||
package protocol
|
||||
|
||||
import "time"
|
||||
|
||||
// DesiredReceiveBufferSize is the kernel UDP receive buffer size that we'd like to use.
|
||||
const DesiredReceiveBufferSize = (1 << 20) * 2 // 2 MB
|
||||
|
||||
// DesiredSendBufferSize is the kernel UDP send buffer size that we'd like to use.
|
||||
const DesiredSendBufferSize = (1 << 20) * 2 // 2 MB
|
||||
|
||||
// InitialPacketSizeIPv4 is the maximum packet size that we use for sending IPv4 packets.
|
||||
const InitialPacketSizeIPv4 = 1252
|
||||
|
||||
// InitialPacketSizeIPv6 is the maximum packet size that we use for sending IPv6 packets.
|
||||
const InitialPacketSizeIPv6 = 1232
|
||||
|
||||
// MaxCongestionWindowPackets is the maximum congestion window in packet.
|
||||
const MaxCongestionWindowPackets = 10000
|
||||
|
||||
// MaxUndecryptablePackets limits the number of undecryptable packets that are queued in the connection.
|
||||
const MaxUndecryptablePackets = 32
|
||||
|
||||
// ConnectionFlowControlMultiplier determines how much larger the connection flow control windows needs to be relative to any stream's flow control window
|
||||
// This is the value that Chromium is using
|
||||
const ConnectionFlowControlMultiplier = 1.5
|
||||
|
||||
// DefaultInitialMaxStreamData is the default initial stream-level flow control window for receiving data
|
||||
const DefaultInitialMaxStreamData = (1 << 10) * 512 // 512 kb
|
||||
|
||||
// DefaultInitialMaxData is the connection-level flow control window for receiving data
|
||||
const DefaultInitialMaxData = ConnectionFlowControlMultiplier * DefaultInitialMaxStreamData
|
||||
|
||||
// DefaultMaxReceiveStreamFlowControlWindow is the default maximum stream-level flow control window for receiving data
|
||||
const DefaultMaxReceiveStreamFlowControlWindow = 6 * (1 << 20) // 6 MB
|
||||
|
||||
// DefaultMaxReceiveConnectionFlowControlWindow is the default connection-level flow control window for receiving data
|
||||
const DefaultMaxReceiveConnectionFlowControlWindow = 15 * (1 << 20) // 15 MB
|
||||
|
||||
// WindowUpdateThreshold is the fraction of the receive window that has to be consumed before an higher offset is advertised to the client
|
||||
const WindowUpdateThreshold = 0.25
|
||||
|
||||
// DefaultMaxIncomingStreams is the maximum number of streams that a peer may open
|
||||
const DefaultMaxIncomingStreams = 100
|
||||
|
||||
// DefaultMaxIncomingUniStreams is the maximum number of unidirectional streams that a peer may open
|
||||
const DefaultMaxIncomingUniStreams = 100
|
||||
|
||||
// MaxServerUnprocessedPackets is the max number of packets stored in the server that are not yet processed.
|
||||
const MaxServerUnprocessedPackets = 1024
|
||||
|
||||
// MaxConnUnprocessedPackets is the max number of packets stored in each connection that are not yet processed.
|
||||
const MaxConnUnprocessedPackets = 256
|
||||
|
||||
// SkipPacketInitialPeriod is the initial period length used for packet number skipping to prevent an Optimistic ACK attack.
|
||||
// Every time a packet number is skipped, the period is doubled, up to SkipPacketMaxPeriod.
|
||||
const SkipPacketInitialPeriod PacketNumber = 256
|
||||
|
||||
// SkipPacketMaxPeriod is the maximum period length used for packet number skipping.
|
||||
const SkipPacketMaxPeriod PacketNumber = 128 * 1024
|
||||
|
||||
// MaxAcceptQueueSize is the maximum number of connections that the server queues for accepting.
|
||||
// If the queue is full, new connection attempts will be rejected.
|
||||
const MaxAcceptQueueSize = 32
|
||||
|
||||
// TokenValidity is the duration that a (non-retry) token is considered valid
|
||||
const TokenValidity = 24 * time.Hour
|
||||
|
||||
// MaxOutstandingSentPackets is maximum number of packets saved for retransmission.
|
||||
// When reached, it imposes a soft limit on sending new packets:
|
||||
// Sending ACKs and retransmission is still allowed, but now new regular packets can be sent.
|
||||
const MaxOutstandingSentPackets = 2 * MaxCongestionWindowPackets
|
||||
|
||||
// MaxTrackedSentPackets is maximum number of sent packets saved for retransmission.
|
||||
// When reached, no more packets will be sent.
|
||||
// This value *must* be larger than MaxOutstandingSentPackets.
|
||||
const MaxTrackedSentPackets = MaxOutstandingSentPackets * 5 / 4
|
||||
|
||||
// MaxNonAckElicitingAcks is the maximum number of packets containing an ACK,
|
||||
// but no ack-eliciting frames, that we send in a row
|
||||
const MaxNonAckElicitingAcks = 19
|
||||
|
||||
// MaxStreamFrameSorterGaps is the maximum number of gaps between received StreamFrames
|
||||
// prevents DoS attacks against the streamFrameSorter
|
||||
const MaxStreamFrameSorterGaps = 1000
|
||||
|
||||
// MinStreamFrameBufferSize is the minimum data length of a received STREAM frame
|
||||
// that we use the buffer for. This protects against a DoS where an attacker would send us
|
||||
// very small STREAM frames to consume a lot of memory.
|
||||
const MinStreamFrameBufferSize = 128
|
||||
|
||||
// MinCoalescedPacketSize is the minimum size of a coalesced packet that we pack.
|
||||
// If a packet has less than this number of bytes, we won't coalesce any more packets onto it.
|
||||
const MinCoalescedPacketSize = 128
|
||||
|
||||
// MaxCryptoStreamOffset is the maximum offset allowed on any of the crypto streams.
|
||||
// This limits the size of the ClientHello and Certificates that can be received.
|
||||
const MaxCryptoStreamOffset = 16 * (1 << 10)
|
||||
|
||||
// MinRemoteIdleTimeout is the minimum value that we accept for the remote idle timeout
|
||||
const MinRemoteIdleTimeout = 5 * time.Second
|
||||
|
||||
// DefaultIdleTimeout is the default idle timeout
|
||||
const DefaultIdleTimeout = 30 * time.Second
|
||||
|
||||
// DefaultHandshakeIdleTimeout is the default idle timeout used before handshake completion.
|
||||
const DefaultHandshakeIdleTimeout = 5 * time.Second
|
||||
|
||||
// MaxKeepAliveInterval is the maximum time until we send a packet to keep a connection alive.
|
||||
// It should be shorter than the time that NATs clear their mapping.
|
||||
const MaxKeepAliveInterval = 20 * time.Second
|
||||
|
||||
// RetiredConnectionIDDeleteTimeout is the time we keep closed connections around in order to retransmit the CONNECTION_CLOSE.
|
||||
// after this time all information about the old connection will be deleted
|
||||
const RetiredConnectionIDDeleteTimeout = 5 * time.Second
|
||||
|
||||
// MinStreamFrameSize is the minimum size that has to be left in a packet, so that we add another STREAM frame.
|
||||
// This avoids splitting up STREAM frames into small pieces, which has 2 advantages:
|
||||
// 1. it reduces the framing overhead
|
||||
// 2. it reduces the head-of-line blocking, when a packet is lost
|
||||
const MinStreamFrameSize ByteCount = 128
|
||||
|
||||
// MaxPostHandshakeCryptoFrameSize is the maximum size of CRYPTO frames
|
||||
// we send after the handshake completes.
|
||||
const MaxPostHandshakeCryptoFrameSize = 1000
|
||||
|
||||
// MaxAckFrameSize is the maximum size for an ACK frame that we write
|
||||
// Due to the varint encoding, ACK frames can grow (almost) indefinitely large.
|
||||
// The MaxAckFrameSize should be large enough to encode many ACK range,
|
||||
// but must ensure that a maximum size ACK frame fits into one packet.
|
||||
const MaxAckFrameSize ByteCount = 1000
|
||||
|
||||
// MaxDatagramFrameSize is the maximum size of a DATAGRAM frame (RFC 9221).
|
||||
// The size is chosen such that a DATAGRAM frame fits into a QUIC packet.
|
||||
const MaxDatagramFrameSize ByteCount = 1200
|
||||
|
||||
// DatagramRcvQueueLen is the length of the receive queue for DATAGRAM frames (RFC 9221)
|
||||
const DatagramRcvQueueLen = 128
|
||||
|
||||
// MaxNumAckRanges is the maximum number of ACK ranges that we send in an ACK frame.
|
||||
// It also serves as a limit for the packet history.
|
||||
// If at any point we keep track of more ranges, old ranges are discarded.
|
||||
const MaxNumAckRanges = 32
|
||||
|
||||
// MinPacingDelay is the minimum duration that is used for packet pacing
|
||||
// If the packet packing frequency is higher, multiple packets might be sent at once.
|
||||
// Example: For a packet pacing delay of 200μs, we would send 5 packets at once, wait for 1ms, and so forth.
|
||||
const MinPacingDelay = time.Millisecond
|
||||
|
||||
// DefaultConnectionIDLength is the connection ID length that is used for multiplexed connections
|
||||
// if no other value is configured.
|
||||
const DefaultConnectionIDLength = 4
|
||||
|
||||
// MaxActiveConnectionIDs is the number of connection IDs that we're storing.
|
||||
const MaxActiveConnectionIDs = 4
|
||||
|
||||
// MaxIssuedConnectionIDs is the maximum number of connection IDs that we're issuing at the same time.
|
||||
const MaxIssuedConnectionIDs = 6
|
||||
|
||||
// PacketsPerConnectionID is the number of packets we send using one connection ID.
|
||||
// If the peer provices us with enough new connection IDs, we switch to a new connection ID.
|
||||
const PacketsPerConnectionID = 10000
|
||||
|
||||
// AckDelayExponent is the ack delay exponent used when sending ACKs.
|
||||
const AckDelayExponent = 3
|
||||
|
||||
// Estimated timer granularity.
|
||||
// The loss detection timer will not be set to a value smaller than granularity.
|
||||
const TimerGranularity = time.Millisecond
|
||||
|
||||
// MaxAckDelay is the maximum time by which we delay sending ACKs.
|
||||
const MaxAckDelay = 25 * time.Millisecond
|
||||
|
||||
// MaxAckDelayInclGranularity is the max_ack_delay including the timer granularity.
|
||||
// This is the value that should be advertised to the peer.
|
||||
const MaxAckDelayInclGranularity = MaxAckDelay + TimerGranularity
|
||||
|
||||
// KeyUpdateInterval is the maximum number of packets we send or receive before initiating a key update.
|
||||
const KeyUpdateInterval = 100 * 1000
|
||||
|
||||
// Max0RTTQueueingDuration is the maximum time that we store 0-RTT packets in order to wait for the corresponding Initial to be received.
|
||||
const Max0RTTQueueingDuration = 100 * time.Millisecond
|
||||
|
||||
// Max0RTTQueues is the maximum number of connections that we buffer 0-RTT packets for.
|
||||
const Max0RTTQueues = 32
|
||||
|
||||
// Max0RTTQueueLen is the maximum number of 0-RTT packets that we buffer for each connection.
|
||||
// When a new connection is created, all buffered packets are passed to the connection immediately.
|
||||
// To avoid blocking, this value has to be smaller than MaxConnUnprocessedPackets.
|
||||
// To avoid packets being dropped as undecryptable by the connection, this value has to be smaller than MaxUndecryptablePackets.
|
||||
const Max0RTTQueueLen = 31
|
||||
26
vendor/github.com/quic-go/quic-go/internal/protocol/perspective.go
generated
vendored
Normal file
26
vendor/github.com/quic-go/quic-go/internal/protocol/perspective.go
generated
vendored
Normal file
@@ -0,0 +1,26 @@
|
||||
package protocol
|
||||
|
||||
// Perspective determines if we're acting as a server or a client
|
||||
type Perspective int
|
||||
|
||||
// the perspectives
|
||||
const (
|
||||
PerspectiveServer Perspective = 1
|
||||
PerspectiveClient Perspective = 2
|
||||
)
|
||||
|
||||
// Opposite returns the perspective of the peer
|
||||
func (p Perspective) Opposite() Perspective {
|
||||
return 3 - p
|
||||
}
|
||||
|
||||
func (p Perspective) String() string {
|
||||
switch p {
|
||||
case PerspectiveServer:
|
||||
return "Server"
|
||||
case PerspectiveClient:
|
||||
return "Client"
|
||||
default:
|
||||
return "invalid perspective"
|
||||
}
|
||||
}
|
||||
152
vendor/github.com/quic-go/quic-go/internal/protocol/protocol.go
generated
vendored
Normal file
152
vendor/github.com/quic-go/quic-go/internal/protocol/protocol.go
generated
vendored
Normal file
@@ -0,0 +1,152 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// The PacketType is the Long Header Type
|
||||
type PacketType uint8
|
||||
|
||||
const (
|
||||
// PacketTypeInitial is the packet type of an Initial packet
|
||||
PacketTypeInitial PacketType = 1 + iota
|
||||
// PacketTypeRetry is the packet type of a Retry packet
|
||||
PacketTypeRetry
|
||||
// PacketTypeHandshake is the packet type of a Handshake packet
|
||||
PacketTypeHandshake
|
||||
// PacketType0RTT is the packet type of a 0-RTT packet
|
||||
PacketType0RTT
|
||||
)
|
||||
|
||||
func (t PacketType) String() string {
|
||||
switch t {
|
||||
case PacketTypeInitial:
|
||||
return "Initial"
|
||||
case PacketTypeRetry:
|
||||
return "Retry"
|
||||
case PacketTypeHandshake:
|
||||
return "Handshake"
|
||||
case PacketType0RTT:
|
||||
return "0-RTT Protected"
|
||||
default:
|
||||
return fmt.Sprintf("unknown packet type: %d", t)
|
||||
}
|
||||
}
|
||||
|
||||
type ECN uint8
|
||||
|
||||
const (
|
||||
ECNUnsupported ECN = iota
|
||||
ECNNon // 00
|
||||
ECT1 // 01
|
||||
ECT0 // 10
|
||||
ECNCE // 11
|
||||
)
|
||||
|
||||
func ParseECNHeaderBits(bits byte) ECN {
|
||||
switch bits {
|
||||
case 0:
|
||||
return ECNNon
|
||||
case 0b00000010:
|
||||
return ECT0
|
||||
case 0b00000001:
|
||||
return ECT1
|
||||
case 0b00000011:
|
||||
return ECNCE
|
||||
default:
|
||||
panic("invalid ECN bits")
|
||||
}
|
||||
}
|
||||
|
||||
func (e ECN) ToHeaderBits() byte {
|
||||
//nolint:exhaustive // There are only 4 values.
|
||||
switch e {
|
||||
case ECNNon:
|
||||
return 0
|
||||
case ECT0:
|
||||
return 0b00000010
|
||||
case ECT1:
|
||||
return 0b00000001
|
||||
case ECNCE:
|
||||
return 0b00000011
|
||||
default:
|
||||
panic("ECN unsupported")
|
||||
}
|
||||
}
|
||||
|
||||
func (e ECN) String() string {
|
||||
switch e {
|
||||
case ECNUnsupported:
|
||||
return "ECN unsupported"
|
||||
case ECNNon:
|
||||
return "Not-ECT"
|
||||
case ECT1:
|
||||
return "ECT(1)"
|
||||
case ECT0:
|
||||
return "ECT(0)"
|
||||
case ECNCE:
|
||||
return "CE"
|
||||
default:
|
||||
return fmt.Sprintf("invalid ECN value: %d", e)
|
||||
}
|
||||
}
|
||||
|
||||
// A ByteCount in QUIC
|
||||
type ByteCount int64
|
||||
|
||||
// MaxByteCount is the maximum value of a ByteCount
|
||||
const MaxByteCount = ByteCount(1<<62 - 1)
|
||||
|
||||
// InvalidByteCount is an invalid byte count
|
||||
const InvalidByteCount ByteCount = -1
|
||||
|
||||
// A StatelessResetToken is a stateless reset token.
|
||||
type StatelessResetToken [16]byte
|
||||
|
||||
// MaxPacketBufferSize maximum packet size of any QUIC packet, based on
|
||||
// ethernet's max size, minus the IP and UDP headers. IPv6 has a 40 byte header,
|
||||
// UDP adds an additional 8 bytes. This is a total overhead of 48 bytes.
|
||||
// Ethernet's max packet size is 1500 bytes, 1500 - 48 = 1452.
|
||||
const MaxPacketBufferSize = 1452
|
||||
|
||||
// MaxLargePacketBufferSize is used when using GSO
|
||||
const MaxLargePacketBufferSize = 20 * 1024
|
||||
|
||||
// MinInitialPacketSize is the minimum size an Initial packet is required to have.
|
||||
const MinInitialPacketSize = 1200
|
||||
|
||||
// MinUnknownVersionPacketSize is the minimum size a packet with an unknown version
|
||||
// needs to have in order to trigger a Version Negotiation packet.
|
||||
const MinUnknownVersionPacketSize = MinInitialPacketSize
|
||||
|
||||
// MinStatelessResetSize is the minimum size of a stateless reset packet that we send
|
||||
const MinStatelessResetSize = 1 /* first byte */ + 20 /* max. conn ID length */ + 4 /* max. packet number length */ + 1 /* min. payload length */ + 16 /* token */
|
||||
|
||||
// MinConnectionIDLenInitial is the minimum length of the destination connection ID on an Initial packet.
|
||||
const MinConnectionIDLenInitial = 8
|
||||
|
||||
// DefaultAckDelayExponent is the default ack delay exponent
|
||||
const DefaultAckDelayExponent = 3
|
||||
|
||||
// DefaultActiveConnectionIDLimit is the default active connection ID limit
|
||||
const DefaultActiveConnectionIDLimit = 2
|
||||
|
||||
// MaxAckDelayExponent is the maximum ack delay exponent
|
||||
const MaxAckDelayExponent = 20
|
||||
|
||||
// DefaultMaxAckDelay is the default max_ack_delay
|
||||
const DefaultMaxAckDelay = 25 * time.Millisecond
|
||||
|
||||
// MaxMaxAckDelay is the maximum max_ack_delay
|
||||
const MaxMaxAckDelay = (1<<14 - 1) * time.Millisecond
|
||||
|
||||
// MaxConnIDLen is the maximum length of the connection ID
|
||||
const MaxConnIDLen = 20
|
||||
|
||||
// InvalidPacketLimitAES is the maximum number of packets that we can fail to decrypt when using
|
||||
// AEAD_AES_128_GCM or AEAD_AES_265_GCM.
|
||||
const InvalidPacketLimitAES = 1 << 52
|
||||
|
||||
// InvalidPacketLimitChaCha is the maximum number of packets that we can fail to decrypt when using AEAD_CHACHA20_POLY1305.
|
||||
const InvalidPacketLimitChaCha = 1 << 36
|
||||
76
vendor/github.com/quic-go/quic-go/internal/protocol/stream.go
generated
vendored
Normal file
76
vendor/github.com/quic-go/quic-go/internal/protocol/stream.go
generated
vendored
Normal file
@@ -0,0 +1,76 @@
|
||||
package protocol
|
||||
|
||||
// StreamType encodes if this is a unidirectional or bidirectional stream
|
||||
type StreamType uint8
|
||||
|
||||
const (
|
||||
// StreamTypeUni is a unidirectional stream
|
||||
StreamTypeUni StreamType = iota
|
||||
// StreamTypeBidi is a bidirectional stream
|
||||
StreamTypeBidi
|
||||
)
|
||||
|
||||
// InvalidPacketNumber is a stream ID that is invalid.
|
||||
// The first valid stream ID in QUIC is 0.
|
||||
const InvalidStreamID StreamID = -1
|
||||
|
||||
// StreamNum is the stream number
|
||||
type StreamNum int64
|
||||
|
||||
const (
|
||||
// InvalidStreamNum is an invalid stream number.
|
||||
InvalidStreamNum = -1
|
||||
// MaxStreamCount is the maximum stream count value that can be sent in MAX_STREAMS frames
|
||||
// and as the stream count in the transport parameters
|
||||
MaxStreamCount StreamNum = 1 << 60
|
||||
)
|
||||
|
||||
// StreamID calculates the stream ID.
|
||||
func (s StreamNum) StreamID(stype StreamType, pers Perspective) StreamID {
|
||||
if s == 0 {
|
||||
return InvalidStreamID
|
||||
}
|
||||
var first StreamID
|
||||
switch stype {
|
||||
case StreamTypeBidi:
|
||||
switch pers {
|
||||
case PerspectiveClient:
|
||||
first = 0
|
||||
case PerspectiveServer:
|
||||
first = 1
|
||||
}
|
||||
case StreamTypeUni:
|
||||
switch pers {
|
||||
case PerspectiveClient:
|
||||
first = 2
|
||||
case PerspectiveServer:
|
||||
first = 3
|
||||
}
|
||||
}
|
||||
return first + 4*StreamID(s-1)
|
||||
}
|
||||
|
||||
// A StreamID in QUIC
|
||||
type StreamID int64
|
||||
|
||||
// InitiatedBy says if the stream was initiated by the client or by the server
|
||||
func (s StreamID) InitiatedBy() Perspective {
|
||||
if s%2 == 0 {
|
||||
return PerspectiveClient
|
||||
}
|
||||
return PerspectiveServer
|
||||
}
|
||||
|
||||
// Type says if this is a unidirectional or bidirectional stream
|
||||
func (s StreamID) Type() StreamType {
|
||||
if s%4 >= 2 {
|
||||
return StreamTypeUni
|
||||
}
|
||||
return StreamTypeBidi
|
||||
}
|
||||
|
||||
// StreamNum returns how many streams in total are below this
|
||||
// Example: for stream 9 it returns 3 (i.e. streams 1, 5 and 9)
|
||||
func (s StreamID) StreamNum() StreamNum {
|
||||
return StreamNum(s/4) + 1
|
||||
}
|
||||
105
vendor/github.com/quic-go/quic-go/internal/protocol/version.go
generated
vendored
Normal file
105
vendor/github.com/quic-go/quic-go/internal/protocol/version.go
generated
vendored
Normal file
@@ -0,0 +1,105 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"math"
|
||||
)
|
||||
|
||||
// VersionNumber is a version number as int
|
||||
type VersionNumber uint32
|
||||
|
||||
// gQUIC version range as defined in the wiki: https://github.com/quicwg/base-drafts/wiki/QUIC-Versions
|
||||
const (
|
||||
gquicVersion0 = 0x51303030
|
||||
maxGquicVersion = 0x51303439
|
||||
)
|
||||
|
||||
// The version numbers, making grepping easier
|
||||
const (
|
||||
VersionUnknown VersionNumber = math.MaxUint32
|
||||
versionDraft29 VersionNumber = 0xff00001d // draft-29 used to be a widely deployed version
|
||||
Version1 VersionNumber = 0x1
|
||||
Version2 VersionNumber = 0x6b3343cf
|
||||
)
|
||||
|
||||
// SupportedVersions lists the versions that the server supports
|
||||
// must be in sorted descending order
|
||||
var SupportedVersions = []VersionNumber{Version1, Version2}
|
||||
|
||||
// IsValidVersion says if the version is known to quic-go
|
||||
func IsValidVersion(v VersionNumber) bool {
|
||||
return v == Version1 || IsSupportedVersion(SupportedVersions, v)
|
||||
}
|
||||
|
||||
func (vn VersionNumber) String() string {
|
||||
//nolint:exhaustive
|
||||
switch vn {
|
||||
case VersionUnknown:
|
||||
return "unknown"
|
||||
case versionDraft29:
|
||||
return "draft-29"
|
||||
case Version1:
|
||||
return "v1"
|
||||
case Version2:
|
||||
return "v2"
|
||||
default:
|
||||
if vn.isGQUIC() {
|
||||
return fmt.Sprintf("gQUIC %d", vn.toGQUICVersion())
|
||||
}
|
||||
return fmt.Sprintf("%#x", uint32(vn))
|
||||
}
|
||||
}
|
||||
|
||||
func (vn VersionNumber) isGQUIC() bool {
|
||||
return vn > gquicVersion0 && vn <= maxGquicVersion
|
||||
}
|
||||
|
||||
func (vn VersionNumber) toGQUICVersion() int {
|
||||
return int(10*(vn-gquicVersion0)/0x100) + int(vn%0x10)
|
||||
}
|
||||
|
||||
// IsSupportedVersion returns true if the server supports this version
|
||||
func IsSupportedVersion(supported []VersionNumber, v VersionNumber) bool {
|
||||
for _, t := range supported {
|
||||
if t == v {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ChooseSupportedVersion finds the best version in the overlap of ours and theirs
|
||||
// ours is a slice of versions that we support, sorted by our preference (descending)
|
||||
// theirs is a slice of versions offered by the peer. The order does not matter.
|
||||
// The bool returned indicates if a matching version was found.
|
||||
func ChooseSupportedVersion(ours, theirs []VersionNumber) (VersionNumber, bool) {
|
||||
for _, ourVer := range ours {
|
||||
for _, theirVer := range theirs {
|
||||
if ourVer == theirVer {
|
||||
return ourVer, true
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// generateReservedVersion generates a reserved version number (v & 0x0f0f0f0f == 0x0a0a0a0a)
|
||||
func generateReservedVersion() VersionNumber {
|
||||
b := make([]byte, 4)
|
||||
_, _ = rand.Read(b) // ignore the error here. Failure to read random data doesn't break anything
|
||||
return VersionNumber((binary.BigEndian.Uint32(b) | 0x0a0a0a0a) & 0xfafafafa)
|
||||
}
|
||||
|
||||
// GetGreasedVersions adds one reserved version number to a slice of version numbers, at a random position
|
||||
func GetGreasedVersions(supported []VersionNumber) []VersionNumber {
|
||||
b := make([]byte, 1)
|
||||
_, _ = rand.Read(b) // ignore the error here. Failure to read random data doesn't break anything
|
||||
randPos := int(b[0]) % (len(supported) + 1)
|
||||
greased := make([]VersionNumber, len(supported)+1)
|
||||
copy(greased, supported[:randPos])
|
||||
greased[randPos] = generateReservedVersion()
|
||||
copy(greased[randPos+1:], supported[randPos:])
|
||||
return greased
|
||||
}
|
||||
88
vendor/github.com/quic-go/quic-go/internal/qerr/error_codes.go
generated
vendored
Normal file
88
vendor/github.com/quic-go/quic-go/internal/qerr/error_codes.go
generated
vendored
Normal file
@@ -0,0 +1,88 @@
|
||||
package qerr
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/qtls"
|
||||
)
|
||||
|
||||
// TransportErrorCode is a QUIC transport error.
|
||||
type TransportErrorCode uint64
|
||||
|
||||
// The error codes defined by QUIC
|
||||
const (
|
||||
NoError TransportErrorCode = 0x0
|
||||
InternalError TransportErrorCode = 0x1
|
||||
ConnectionRefused TransportErrorCode = 0x2
|
||||
FlowControlError TransportErrorCode = 0x3
|
||||
StreamLimitError TransportErrorCode = 0x4
|
||||
StreamStateError TransportErrorCode = 0x5
|
||||
FinalSizeError TransportErrorCode = 0x6
|
||||
FrameEncodingError TransportErrorCode = 0x7
|
||||
TransportParameterError TransportErrorCode = 0x8
|
||||
ConnectionIDLimitError TransportErrorCode = 0x9
|
||||
ProtocolViolation TransportErrorCode = 0xa
|
||||
InvalidToken TransportErrorCode = 0xb
|
||||
ApplicationErrorErrorCode TransportErrorCode = 0xc
|
||||
CryptoBufferExceeded TransportErrorCode = 0xd
|
||||
KeyUpdateError TransportErrorCode = 0xe
|
||||
AEADLimitReached TransportErrorCode = 0xf
|
||||
NoViablePathError TransportErrorCode = 0x10
|
||||
)
|
||||
|
||||
func (e TransportErrorCode) IsCryptoError() bool {
|
||||
return e >= 0x100 && e < 0x200
|
||||
}
|
||||
|
||||
// Message is a description of the error.
|
||||
// It only returns a non-empty string for crypto errors.
|
||||
func (e TransportErrorCode) Message() string {
|
||||
if !e.IsCryptoError() {
|
||||
return ""
|
||||
}
|
||||
return qtls.AlertError(e - 0x100).Error()
|
||||
}
|
||||
|
||||
func (e TransportErrorCode) String() string {
|
||||
switch e {
|
||||
case NoError:
|
||||
return "NO_ERROR"
|
||||
case InternalError:
|
||||
return "INTERNAL_ERROR"
|
||||
case ConnectionRefused:
|
||||
return "CONNECTION_REFUSED"
|
||||
case FlowControlError:
|
||||
return "FLOW_CONTROL_ERROR"
|
||||
case StreamLimitError:
|
||||
return "STREAM_LIMIT_ERROR"
|
||||
case StreamStateError:
|
||||
return "STREAM_STATE_ERROR"
|
||||
case FinalSizeError:
|
||||
return "FINAL_SIZE_ERROR"
|
||||
case FrameEncodingError:
|
||||
return "FRAME_ENCODING_ERROR"
|
||||
case TransportParameterError:
|
||||
return "TRANSPORT_PARAMETER_ERROR"
|
||||
case ConnectionIDLimitError:
|
||||
return "CONNECTION_ID_LIMIT_ERROR"
|
||||
case ProtocolViolation:
|
||||
return "PROTOCOL_VIOLATION"
|
||||
case InvalidToken:
|
||||
return "INVALID_TOKEN"
|
||||
case ApplicationErrorErrorCode:
|
||||
return "APPLICATION_ERROR"
|
||||
case CryptoBufferExceeded:
|
||||
return "CRYPTO_BUFFER_EXCEEDED"
|
||||
case KeyUpdateError:
|
||||
return "KEY_UPDATE_ERROR"
|
||||
case AEADLimitReached:
|
||||
return "AEAD_LIMIT_REACHED"
|
||||
case NoViablePathError:
|
||||
return "NO_VIABLE_PATH"
|
||||
default:
|
||||
if e.IsCryptoError() {
|
||||
return fmt.Sprintf("CRYPTO_ERROR %#x", uint16(e))
|
||||
}
|
||||
return fmt.Sprintf("unknown error code: %#x", uint16(e))
|
||||
}
|
||||
}
|
||||
139
vendor/github.com/quic-go/quic-go/internal/qerr/errors.go
generated
vendored
Normal file
139
vendor/github.com/quic-go/quic-go/internal/qerr/errors.go
generated
vendored
Normal file
@@ -0,0 +1,139 @@
|
||||
package qerr
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrHandshakeTimeout = &HandshakeTimeoutError{}
|
||||
ErrIdleTimeout = &IdleTimeoutError{}
|
||||
)
|
||||
|
||||
type TransportError struct {
|
||||
Remote bool
|
||||
FrameType uint64
|
||||
ErrorCode TransportErrorCode
|
||||
ErrorMessage string
|
||||
error error // only set for local errors, sometimes
|
||||
}
|
||||
|
||||
var _ error = &TransportError{}
|
||||
|
||||
// NewLocalCryptoError create a new TransportError instance for a crypto error
|
||||
func NewLocalCryptoError(tlsAlert uint8, err error) *TransportError {
|
||||
return &TransportError{
|
||||
ErrorCode: 0x100 + TransportErrorCode(tlsAlert),
|
||||
error: err,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *TransportError) Error() string {
|
||||
str := fmt.Sprintf("%s (%s)", e.ErrorCode.String(), getRole(e.Remote))
|
||||
if e.FrameType != 0 {
|
||||
str += fmt.Sprintf(" (frame type: %#x)", e.FrameType)
|
||||
}
|
||||
msg := e.ErrorMessage
|
||||
if len(msg) == 0 && e.error != nil {
|
||||
msg = e.error.Error()
|
||||
}
|
||||
if len(msg) == 0 {
|
||||
msg = e.ErrorCode.Message()
|
||||
}
|
||||
if len(msg) == 0 {
|
||||
return str
|
||||
}
|
||||
return str + ": " + msg
|
||||
}
|
||||
|
||||
func (e *TransportError) Is(target error) bool {
|
||||
return target == net.ErrClosed
|
||||
}
|
||||
|
||||
func (e *TransportError) Unwrap() error {
|
||||
return e.error
|
||||
}
|
||||
|
||||
// An ApplicationErrorCode is an application-defined error code.
|
||||
type ApplicationErrorCode uint64
|
||||
|
||||
func (e *ApplicationError) Is(target error) bool {
|
||||
return target == net.ErrClosed
|
||||
}
|
||||
|
||||
// A StreamErrorCode is an error code used to cancel streams.
|
||||
type StreamErrorCode uint64
|
||||
|
||||
type ApplicationError struct {
|
||||
Remote bool
|
||||
ErrorCode ApplicationErrorCode
|
||||
ErrorMessage string
|
||||
}
|
||||
|
||||
var _ error = &ApplicationError{}
|
||||
|
||||
func (e *ApplicationError) Error() string {
|
||||
if len(e.ErrorMessage) == 0 {
|
||||
return fmt.Sprintf("Application error %#x (%s)", e.ErrorCode, getRole(e.Remote))
|
||||
}
|
||||
return fmt.Sprintf("Application error %#x (%s): %s", e.ErrorCode, getRole(e.Remote), e.ErrorMessage)
|
||||
}
|
||||
|
||||
type IdleTimeoutError struct{}
|
||||
|
||||
var _ error = &IdleTimeoutError{}
|
||||
|
||||
func (e *IdleTimeoutError) Timeout() bool { return true }
|
||||
func (e *IdleTimeoutError) Temporary() bool { return false }
|
||||
func (e *IdleTimeoutError) Error() string { return "timeout: no recent network activity" }
|
||||
func (e *IdleTimeoutError) Is(target error) bool { return target == net.ErrClosed }
|
||||
|
||||
type HandshakeTimeoutError struct{}
|
||||
|
||||
var _ error = &HandshakeTimeoutError{}
|
||||
|
||||
func (e *HandshakeTimeoutError) Timeout() bool { return true }
|
||||
func (e *HandshakeTimeoutError) Temporary() bool { return false }
|
||||
func (e *HandshakeTimeoutError) Error() string { return "timeout: handshake did not complete in time" }
|
||||
func (e *HandshakeTimeoutError) Is(target error) bool { return target == net.ErrClosed }
|
||||
|
||||
// A VersionNegotiationError occurs when the client and the server can't agree on a QUIC version.
|
||||
type VersionNegotiationError struct {
|
||||
Ours []protocol.VersionNumber
|
||||
Theirs []protocol.VersionNumber
|
||||
}
|
||||
|
||||
func (e *VersionNegotiationError) Error() string {
|
||||
return fmt.Sprintf("no compatible QUIC version found (we support %s, server offered %s)", e.Ours, e.Theirs)
|
||||
}
|
||||
|
||||
func (e *VersionNegotiationError) Is(target error) bool {
|
||||
return target == net.ErrClosed
|
||||
}
|
||||
|
||||
// A StatelessResetError occurs when we receive a stateless reset.
|
||||
type StatelessResetError struct {
|
||||
Token protocol.StatelessResetToken
|
||||
}
|
||||
|
||||
var _ net.Error = &StatelessResetError{}
|
||||
|
||||
func (e *StatelessResetError) Error() string {
|
||||
return fmt.Sprintf("received a stateless reset with token %x", e.Token)
|
||||
}
|
||||
|
||||
func (e *StatelessResetError) Is(target error) bool {
|
||||
return target == net.ErrClosed
|
||||
}
|
||||
|
||||
func (e *StatelessResetError) Timeout() bool { return false }
|
||||
func (e *StatelessResetError) Temporary() bool { return true }
|
||||
|
||||
func getRole(remote bool) string {
|
||||
if remote {
|
||||
return "remote"
|
||||
}
|
||||
return "local"
|
||||
}
|
||||
66
vendor/github.com/quic-go/quic-go/internal/qtls/cipher_suite_go121.go
generated
vendored
Normal file
66
vendor/github.com/quic-go/quic-go/internal/qtls/cipher_suite_go121.go
generated
vendored
Normal file
@@ -0,0 +1,66 @@
|
||||
//go:build go1.21
|
||||
|
||||
package qtls
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/cipher"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
type cipherSuiteTLS13 struct {
|
||||
ID uint16
|
||||
KeyLen int
|
||||
AEAD func(key, fixedNonce []byte) cipher.AEAD
|
||||
Hash crypto.Hash
|
||||
}
|
||||
|
||||
//go:linkname cipherSuiteTLS13ByID crypto/tls.cipherSuiteTLS13ByID
|
||||
func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13
|
||||
|
||||
//go:linkname cipherSuitesTLS13 crypto/tls.cipherSuitesTLS13
|
||||
var cipherSuitesTLS13 []unsafe.Pointer
|
||||
|
||||
//go:linkname defaultCipherSuitesTLS13 crypto/tls.defaultCipherSuitesTLS13
|
||||
var defaultCipherSuitesTLS13 []uint16
|
||||
|
||||
//go:linkname defaultCipherSuitesTLS13NoAES crypto/tls.defaultCipherSuitesTLS13NoAES
|
||||
var defaultCipherSuitesTLS13NoAES []uint16
|
||||
|
||||
var cipherSuitesModified bool
|
||||
|
||||
// SetCipherSuite modifies the cipherSuiteTLS13 slice of cipher suites inside qtls
|
||||
// such that it only contains the cipher suite with the chosen id.
|
||||
// The reset function returned resets them back to the original value.
|
||||
func SetCipherSuite(id uint16) (reset func()) {
|
||||
if cipherSuitesModified {
|
||||
panic("cipher suites modified multiple times without resetting")
|
||||
}
|
||||
cipherSuitesModified = true
|
||||
|
||||
origCipherSuitesTLS13 := append([]unsafe.Pointer{}, cipherSuitesTLS13...)
|
||||
origDefaultCipherSuitesTLS13 := append([]uint16{}, defaultCipherSuitesTLS13...)
|
||||
origDefaultCipherSuitesTLS13NoAES := append([]uint16{}, defaultCipherSuitesTLS13NoAES...)
|
||||
// The order is given by the order of the slice elements in cipherSuitesTLS13 in qtls.
|
||||
switch id {
|
||||
case tls.TLS_AES_128_GCM_SHA256:
|
||||
cipherSuitesTLS13 = cipherSuitesTLS13[:1]
|
||||
case tls.TLS_CHACHA20_POLY1305_SHA256:
|
||||
cipherSuitesTLS13 = cipherSuitesTLS13[1:2]
|
||||
case tls.TLS_AES_256_GCM_SHA384:
|
||||
cipherSuitesTLS13 = cipherSuitesTLS13[2:]
|
||||
default:
|
||||
panic(fmt.Sprintf("unexpected cipher suite: %d", id))
|
||||
}
|
||||
defaultCipherSuitesTLS13 = []uint16{id}
|
||||
defaultCipherSuitesTLS13NoAES = []uint16{id}
|
||||
|
||||
return func() {
|
||||
cipherSuitesTLS13 = origCipherSuitesTLS13
|
||||
defaultCipherSuitesTLS13 = origDefaultCipherSuitesTLS13
|
||||
defaultCipherSuitesTLS13NoAES = origDefaultCipherSuitesTLS13NoAES
|
||||
cipherSuitesModified = false
|
||||
}
|
||||
}
|
||||
61
vendor/github.com/quic-go/quic-go/internal/qtls/client_session_cache.go
generated
vendored
Normal file
61
vendor/github.com/quic-go/quic-go/internal/qtls/client_session_cache.go
generated
vendored
Normal file
@@ -0,0 +1,61 @@
|
||||
//go:build go1.21
|
||||
|
||||
package qtls
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
)
|
||||
|
||||
type clientSessionCache struct {
|
||||
getData func() []byte
|
||||
setData func([]byte)
|
||||
wrapped tls.ClientSessionCache
|
||||
}
|
||||
|
||||
var _ tls.ClientSessionCache = &clientSessionCache{}
|
||||
|
||||
func (c clientSessionCache) Put(key string, cs *tls.ClientSessionState) {
|
||||
if cs == nil {
|
||||
c.wrapped.Put(key, nil)
|
||||
return
|
||||
}
|
||||
ticket, state, err := cs.ResumptionState()
|
||||
if err != nil || state == nil {
|
||||
c.wrapped.Put(key, cs)
|
||||
return
|
||||
}
|
||||
state.Extra = append(state.Extra, addExtraPrefix(c.getData()))
|
||||
newCS, err := tls.NewResumptionState(ticket, state)
|
||||
if err != nil {
|
||||
// It's not clear why this would error. Just save the original state.
|
||||
c.wrapped.Put(key, cs)
|
||||
return
|
||||
}
|
||||
c.wrapped.Put(key, newCS)
|
||||
}
|
||||
|
||||
func (c clientSessionCache) Get(key string) (*tls.ClientSessionState, bool) {
|
||||
cs, ok := c.wrapped.Get(key)
|
||||
if !ok || cs == nil {
|
||||
return cs, ok
|
||||
}
|
||||
ticket, state, err := cs.ResumptionState()
|
||||
if err != nil {
|
||||
// It's not clear why this would error.
|
||||
// Remove the ticket from the session cache, so we don't run into this error over and over again
|
||||
c.wrapped.Put(key, nil)
|
||||
return nil, false
|
||||
}
|
||||
// restore QUIC transport parameters and RTT stored in state.Extra
|
||||
if extra := findExtraData(state.Extra); extra != nil {
|
||||
c.setData(extra)
|
||||
}
|
||||
session, err := tls.NewResumptionState(ticket, state)
|
||||
if err != nil {
|
||||
// It's not clear why this would error.
|
||||
// Remove the ticket from the session cache, so we don't run into this error over and over again
|
||||
c.wrapped.Put(key, nil)
|
||||
return nil, false
|
||||
}
|
||||
return session, true
|
||||
}
|
||||
147
vendor/github.com/quic-go/quic-go/internal/qtls/go120.go
generated
vendored
Normal file
147
vendor/github.com/quic-go/quic-go/internal/qtls/go120.go
generated
vendored
Normal file
@@ -0,0 +1,147 @@
|
||||
//go:build go1.20 && !go1.21
|
||||
|
||||
package qtls
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"unsafe"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
|
||||
"github.com/quic-go/qtls-go1-20"
|
||||
)
|
||||
|
||||
type (
|
||||
QUICConn = qtls.QUICConn
|
||||
QUICConfig = qtls.QUICConfig
|
||||
QUICEvent = qtls.QUICEvent
|
||||
QUICEventKind = qtls.QUICEventKind
|
||||
QUICEncryptionLevel = qtls.QUICEncryptionLevel
|
||||
AlertError = qtls.AlertError
|
||||
)
|
||||
|
||||
const (
|
||||
QUICEncryptionLevelInitial = qtls.QUICEncryptionLevelInitial
|
||||
QUICEncryptionLevelEarly = qtls.QUICEncryptionLevelEarly
|
||||
QUICEncryptionLevelHandshake = qtls.QUICEncryptionLevelHandshake
|
||||
QUICEncryptionLevelApplication = qtls.QUICEncryptionLevelApplication
|
||||
)
|
||||
|
||||
const (
|
||||
QUICNoEvent = qtls.QUICNoEvent
|
||||
QUICSetReadSecret = qtls.QUICSetReadSecret
|
||||
QUICSetWriteSecret = qtls.QUICSetWriteSecret
|
||||
QUICWriteData = qtls.QUICWriteData
|
||||
QUICTransportParameters = qtls.QUICTransportParameters
|
||||
QUICTransportParametersRequired = qtls.QUICTransportParametersRequired
|
||||
QUICRejectedEarlyData = qtls.QUICRejectedEarlyData
|
||||
QUICHandshakeDone = qtls.QUICHandshakeDone
|
||||
)
|
||||
|
||||
func SetupConfigForServer(conf *QUICConfig, enable0RTT bool, getDataForSessionTicket func() []byte, handleSessionTicket func([]byte, bool) bool) {
|
||||
qtls.InitSessionTicketKeys(conf.TLSConfig)
|
||||
conf.TLSConfig = conf.TLSConfig.Clone()
|
||||
conf.TLSConfig.MinVersion = tls.VersionTLS13
|
||||
conf.ExtraConfig = &qtls.ExtraConfig{
|
||||
Enable0RTT: enable0RTT,
|
||||
Accept0RTT: func(data []byte) bool {
|
||||
return handleSessionTicket(data, true)
|
||||
},
|
||||
GetAppDataForSessionTicket: getDataForSessionTicket,
|
||||
}
|
||||
}
|
||||
|
||||
func SetupConfigForClient(conf *QUICConfig, getDataForSessionState func() []byte, setDataFromSessionState func([]byte)) {
|
||||
conf.ExtraConfig = &qtls.ExtraConfig{
|
||||
GetAppDataForSessionState: getDataForSessionState,
|
||||
SetAppDataFromSessionState: setDataFromSessionState,
|
||||
}
|
||||
}
|
||||
|
||||
func QUICServer(config *QUICConfig) *QUICConn {
|
||||
return qtls.QUICServer(config)
|
||||
}
|
||||
|
||||
func QUICClient(config *QUICConfig) *QUICConn {
|
||||
return qtls.QUICClient(config)
|
||||
}
|
||||
|
||||
func ToTLSEncryptionLevel(e protocol.EncryptionLevel) qtls.QUICEncryptionLevel {
|
||||
switch e {
|
||||
case protocol.EncryptionInitial:
|
||||
return qtls.QUICEncryptionLevelInitial
|
||||
case protocol.EncryptionHandshake:
|
||||
return qtls.QUICEncryptionLevelHandshake
|
||||
case protocol.Encryption1RTT:
|
||||
return qtls.QUICEncryptionLevelApplication
|
||||
case protocol.Encryption0RTT:
|
||||
return qtls.QUICEncryptionLevelEarly
|
||||
default:
|
||||
panic(fmt.Sprintf("unexpected encryption level: %s", e))
|
||||
}
|
||||
}
|
||||
|
||||
func FromTLSEncryptionLevel(e qtls.QUICEncryptionLevel) protocol.EncryptionLevel {
|
||||
switch e {
|
||||
case qtls.QUICEncryptionLevelInitial:
|
||||
return protocol.EncryptionInitial
|
||||
case qtls.QUICEncryptionLevelHandshake:
|
||||
return protocol.EncryptionHandshake
|
||||
case qtls.QUICEncryptionLevelApplication:
|
||||
return protocol.Encryption1RTT
|
||||
case qtls.QUICEncryptionLevelEarly:
|
||||
return protocol.Encryption0RTT
|
||||
default:
|
||||
panic(fmt.Sprintf("unexpect encryption level: %s", e))
|
||||
}
|
||||
}
|
||||
|
||||
//go:linkname cipherSuitesTLS13 github.com/quic-go/qtls-go1-20.cipherSuitesTLS13
|
||||
var cipherSuitesTLS13 []unsafe.Pointer
|
||||
|
||||
//go:linkname defaultCipherSuitesTLS13 github.com/quic-go/qtls-go1-20.defaultCipherSuitesTLS13
|
||||
var defaultCipherSuitesTLS13 []uint16
|
||||
|
||||
//go:linkname defaultCipherSuitesTLS13NoAES github.com/quic-go/qtls-go1-20.defaultCipherSuitesTLS13NoAES
|
||||
var defaultCipherSuitesTLS13NoAES []uint16
|
||||
|
||||
var cipherSuitesModified bool
|
||||
|
||||
// SetCipherSuite modifies the cipherSuiteTLS13 slice of cipher suites inside qtls
|
||||
// such that it only contains the cipher suite with the chosen id.
|
||||
// The reset function returned resets them back to the original value.
|
||||
func SetCipherSuite(id uint16) (reset func()) {
|
||||
if cipherSuitesModified {
|
||||
panic("cipher suites modified multiple times without resetting")
|
||||
}
|
||||
cipherSuitesModified = true
|
||||
|
||||
origCipherSuitesTLS13 := append([]unsafe.Pointer{}, cipherSuitesTLS13...)
|
||||
origDefaultCipherSuitesTLS13 := append([]uint16{}, defaultCipherSuitesTLS13...)
|
||||
origDefaultCipherSuitesTLS13NoAES := append([]uint16{}, defaultCipherSuitesTLS13NoAES...)
|
||||
// The order is given by the order of the slice elements in cipherSuitesTLS13 in qtls.
|
||||
switch id {
|
||||
case tls.TLS_AES_128_GCM_SHA256:
|
||||
cipherSuitesTLS13 = cipherSuitesTLS13[:1]
|
||||
case tls.TLS_CHACHA20_POLY1305_SHA256:
|
||||
cipherSuitesTLS13 = cipherSuitesTLS13[1:2]
|
||||
case tls.TLS_AES_256_GCM_SHA384:
|
||||
cipherSuitesTLS13 = cipherSuitesTLS13[2:]
|
||||
default:
|
||||
panic(fmt.Sprintf("unexpected cipher suite: %d", id))
|
||||
}
|
||||
defaultCipherSuitesTLS13 = []uint16{id}
|
||||
defaultCipherSuitesTLS13NoAES = []uint16{id}
|
||||
|
||||
return func() {
|
||||
cipherSuitesTLS13 = origCipherSuitesTLS13
|
||||
defaultCipherSuitesTLS13 = origDefaultCipherSuitesTLS13
|
||||
defaultCipherSuitesTLS13NoAES = origDefaultCipherSuitesTLS13NoAES
|
||||
cipherSuitesModified = false
|
||||
}
|
||||
}
|
||||
|
||||
func SendSessionTicket(c *QUICConn, allow0RTT bool) error {
|
||||
return c.SendSessionTicket(allow0RTT)
|
||||
}
|
||||
159
vendor/github.com/quic-go/quic-go/internal/qtls/go121.go
generated
vendored
Normal file
159
vendor/github.com/quic-go/quic-go/internal/qtls/go121.go
generated
vendored
Normal file
@@ -0,0 +1,159 @@
|
||||
//go:build go1.21
|
||||
|
||||
package qtls
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
type (
|
||||
QUICConn = tls.QUICConn
|
||||
QUICConfig = tls.QUICConfig
|
||||
QUICEvent = tls.QUICEvent
|
||||
QUICEventKind = tls.QUICEventKind
|
||||
QUICEncryptionLevel = tls.QUICEncryptionLevel
|
||||
QUICSessionTicketOptions = tls.QUICSessionTicketOptions
|
||||
AlertError = tls.AlertError
|
||||
)
|
||||
|
||||
const (
|
||||
QUICEncryptionLevelInitial = tls.QUICEncryptionLevelInitial
|
||||
QUICEncryptionLevelEarly = tls.QUICEncryptionLevelEarly
|
||||
QUICEncryptionLevelHandshake = tls.QUICEncryptionLevelHandshake
|
||||
QUICEncryptionLevelApplication = tls.QUICEncryptionLevelApplication
|
||||
)
|
||||
|
||||
const (
|
||||
QUICNoEvent = tls.QUICNoEvent
|
||||
QUICSetReadSecret = tls.QUICSetReadSecret
|
||||
QUICSetWriteSecret = tls.QUICSetWriteSecret
|
||||
QUICWriteData = tls.QUICWriteData
|
||||
QUICTransportParameters = tls.QUICTransportParameters
|
||||
QUICTransportParametersRequired = tls.QUICTransportParametersRequired
|
||||
QUICRejectedEarlyData = tls.QUICRejectedEarlyData
|
||||
QUICHandshakeDone = tls.QUICHandshakeDone
|
||||
)
|
||||
|
||||
func QUICServer(config *QUICConfig) *QUICConn { return tls.QUICServer(config) }
|
||||
func QUICClient(config *QUICConfig) *QUICConn { return tls.QUICClient(config) }
|
||||
|
||||
func SetupConfigForServer(qconf *QUICConfig, _ bool, getData func() []byte, handleSessionTicket func([]byte, bool) bool) {
|
||||
conf := qconf.TLSConfig
|
||||
|
||||
// Workaround for https://github.com/golang/go/issues/60506.
|
||||
// This initializes the session tickets _before_ cloning the config.
|
||||
_, _ = conf.DecryptTicket(nil, tls.ConnectionState{})
|
||||
|
||||
conf = conf.Clone()
|
||||
conf.MinVersion = tls.VersionTLS13
|
||||
qconf.TLSConfig = conf
|
||||
|
||||
// add callbacks to save transport parameters into the session ticket
|
||||
origWrapSession := conf.WrapSession
|
||||
conf.WrapSession = func(cs tls.ConnectionState, state *tls.SessionState) ([]byte, error) {
|
||||
// Add QUIC session ticket
|
||||
state.Extra = append(state.Extra, addExtraPrefix(getData()))
|
||||
|
||||
if origWrapSession != nil {
|
||||
return origWrapSession(cs, state)
|
||||
}
|
||||
b, err := conf.EncryptTicket(cs, state)
|
||||
return b, err
|
||||
}
|
||||
origUnwrapSession := conf.UnwrapSession
|
||||
// UnwrapSession might be called multiple times, as the client can use multiple session tickets.
|
||||
// However, using 0-RTT is only possible with the first session ticket.
|
||||
// crypto/tls guarantees that this callback is called in the same order as the session ticket in the ClientHello.
|
||||
var unwrapCount int
|
||||
conf.UnwrapSession = func(identity []byte, connState tls.ConnectionState) (*tls.SessionState, error) {
|
||||
unwrapCount++
|
||||
var state *tls.SessionState
|
||||
var err error
|
||||
if origUnwrapSession != nil {
|
||||
state, err = origUnwrapSession(identity, connState)
|
||||
} else {
|
||||
state, err = conf.DecryptTicket(identity, connState)
|
||||
}
|
||||
if err != nil || state == nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
extra := findExtraData(state.Extra)
|
||||
if extra != nil {
|
||||
state.EarlyData = handleSessionTicket(extra, state.EarlyData && unwrapCount == 1)
|
||||
} else {
|
||||
state.EarlyData = false
|
||||
}
|
||||
|
||||
return state, nil
|
||||
}
|
||||
}
|
||||
|
||||
func SetupConfigForClient(qconf *QUICConfig, getData func() []byte, setData func([]byte)) {
|
||||
conf := qconf.TLSConfig
|
||||
if conf.ClientSessionCache != nil {
|
||||
origCache := conf.ClientSessionCache
|
||||
conf.ClientSessionCache = &clientSessionCache{
|
||||
wrapped: origCache,
|
||||
getData: getData,
|
||||
setData: setData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func ToTLSEncryptionLevel(e protocol.EncryptionLevel) tls.QUICEncryptionLevel {
|
||||
switch e {
|
||||
case protocol.EncryptionInitial:
|
||||
return tls.QUICEncryptionLevelInitial
|
||||
case protocol.EncryptionHandshake:
|
||||
return tls.QUICEncryptionLevelHandshake
|
||||
case protocol.Encryption1RTT:
|
||||
return tls.QUICEncryptionLevelApplication
|
||||
case protocol.Encryption0RTT:
|
||||
return tls.QUICEncryptionLevelEarly
|
||||
default:
|
||||
panic(fmt.Sprintf("unexpected encryption level: %s", e))
|
||||
}
|
||||
}
|
||||
|
||||
func FromTLSEncryptionLevel(e tls.QUICEncryptionLevel) protocol.EncryptionLevel {
|
||||
switch e {
|
||||
case tls.QUICEncryptionLevelInitial:
|
||||
return protocol.EncryptionInitial
|
||||
case tls.QUICEncryptionLevelHandshake:
|
||||
return protocol.EncryptionHandshake
|
||||
case tls.QUICEncryptionLevelApplication:
|
||||
return protocol.Encryption1RTT
|
||||
case tls.QUICEncryptionLevelEarly:
|
||||
return protocol.Encryption0RTT
|
||||
default:
|
||||
panic(fmt.Sprintf("unexpect encryption level: %s", e))
|
||||
}
|
||||
}
|
||||
|
||||
const extraPrefix = "quic-go1"
|
||||
|
||||
func addExtraPrefix(b []byte) []byte {
|
||||
return append([]byte(extraPrefix), b...)
|
||||
}
|
||||
|
||||
func findExtraData(extras [][]byte) []byte {
|
||||
prefix := []byte(extraPrefix)
|
||||
for _, extra := range extras {
|
||||
if len(extra) < len(prefix) || !bytes.Equal(prefix, extra[:len(prefix)]) {
|
||||
continue
|
||||
}
|
||||
return extra[len(prefix):]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func SendSessionTicket(c *QUICConn, allow0RTT bool) error {
|
||||
return c.SendSessionTicket(tls.QUICSessionTicketOptions{
|
||||
EarlyData: allow0RTT,
|
||||
})
|
||||
}
|
||||
5
vendor/github.com/quic-go/quic-go/internal/qtls/go_oldversion.go
generated
vendored
Normal file
5
vendor/github.com/quic-go/quic-go/internal/qtls/go_oldversion.go
generated
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
//go:build !go1.20
|
||||
|
||||
package qtls
|
||||
|
||||
var _ int = "The version of quic-go you're using can't be built using outdated Go versions. For more details, please see https://github.com/quic-go/quic-go/wiki/quic-go-and-Go-versions."
|
||||
26
vendor/github.com/quic-go/quic-go/internal/utils/buffered_write_closer.go
generated
vendored
Normal file
26
vendor/github.com/quic-go/quic-go/internal/utils/buffered_write_closer.go
generated
vendored
Normal file
@@ -0,0 +1,26 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
)
|
||||
|
||||
type bufferedWriteCloser struct {
|
||||
*bufio.Writer
|
||||
io.Closer
|
||||
}
|
||||
|
||||
// NewBufferedWriteCloser creates an io.WriteCloser from a bufio.Writer and an io.Closer
|
||||
func NewBufferedWriteCloser(writer *bufio.Writer, closer io.Closer) io.WriteCloser {
|
||||
return &bufferedWriteCloser{
|
||||
Writer: writer,
|
||||
Closer: closer,
|
||||
}
|
||||
}
|
||||
|
||||
func (h bufferedWriteCloser) Close() error {
|
||||
if err := h.Writer.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
return h.Closer.Close()
|
||||
}
|
||||
21
vendor/github.com/quic-go/quic-go/internal/utils/byteorder.go
generated
vendored
Normal file
21
vendor/github.com/quic-go/quic-go/internal/utils/byteorder.go
generated
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
)
|
||||
|
||||
// A ByteOrder specifies how to convert byte sequences into 16-, 32-, or 64-bit unsigned integers.
|
||||
type ByteOrder interface {
|
||||
Uint32([]byte) uint32
|
||||
Uint24([]byte) uint32
|
||||
Uint16([]byte) uint16
|
||||
|
||||
ReadUint32(io.ByteReader) (uint32, error)
|
||||
ReadUint24(io.ByteReader) (uint32, error)
|
||||
ReadUint16(io.ByteReader) (uint16, error)
|
||||
|
||||
WriteUint32(*bytes.Buffer, uint32)
|
||||
WriteUint24(*bytes.Buffer, uint32)
|
||||
WriteUint16(*bytes.Buffer, uint16)
|
||||
}
|
||||
103
vendor/github.com/quic-go/quic-go/internal/utils/byteorder_big_endian.go
generated
vendored
Normal file
103
vendor/github.com/quic-go/quic-go/internal/utils/byteorder_big_endian.go
generated
vendored
Normal file
@@ -0,0 +1,103 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
)
|
||||
|
||||
// BigEndian is the big-endian implementation of ByteOrder.
|
||||
var BigEndian ByteOrder = bigEndian{}
|
||||
|
||||
type bigEndian struct{}
|
||||
|
||||
var _ ByteOrder = &bigEndian{}
|
||||
|
||||
// ReadUintN reads N bytes
|
||||
func (bigEndian) ReadUintN(b io.ByteReader, length uint8) (uint64, error) {
|
||||
var res uint64
|
||||
for i := uint8(0); i < length; i++ {
|
||||
bt, err := b.ReadByte()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
res ^= uint64(bt) << ((length - 1 - i) * 8)
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// ReadUint32 reads a uint32
|
||||
func (bigEndian) ReadUint32(b io.ByteReader) (uint32, error) {
|
||||
var b1, b2, b3, b4 uint8
|
||||
var err error
|
||||
if b4, err = b.ReadByte(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if b3, err = b.ReadByte(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if b2, err = b.ReadByte(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if b1, err = b.ReadByte(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return uint32(b1) + uint32(b2)<<8 + uint32(b3)<<16 + uint32(b4)<<24, nil
|
||||
}
|
||||
|
||||
// ReadUint24 reads a uint24
|
||||
func (bigEndian) ReadUint24(b io.ByteReader) (uint32, error) {
|
||||
var b1, b2, b3 uint8
|
||||
var err error
|
||||
if b3, err = b.ReadByte(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if b2, err = b.ReadByte(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if b1, err = b.ReadByte(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return uint32(b1) + uint32(b2)<<8 + uint32(b3)<<16, nil
|
||||
}
|
||||
|
||||
// ReadUint16 reads a uint16
|
||||
func (bigEndian) ReadUint16(b io.ByteReader) (uint16, error) {
|
||||
var b1, b2 uint8
|
||||
var err error
|
||||
if b2, err = b.ReadByte(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if b1, err = b.ReadByte(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return uint16(b1) + uint16(b2)<<8, nil
|
||||
}
|
||||
|
||||
func (bigEndian) Uint32(b []byte) uint32 {
|
||||
return binary.BigEndian.Uint32(b)
|
||||
}
|
||||
|
||||
func (bigEndian) Uint24(b []byte) uint32 {
|
||||
_ = b[2] // bounds check hint to compiler; see golang.org/issue/14808
|
||||
return uint32(b[2]) | uint32(b[1])<<8 | uint32(b[0])<<16
|
||||
}
|
||||
|
||||
func (bigEndian) Uint16(b []byte) uint16 {
|
||||
return binary.BigEndian.Uint16(b)
|
||||
}
|
||||
|
||||
// WriteUint32 writes a uint32
|
||||
func (bigEndian) WriteUint32(b *bytes.Buffer, i uint32) {
|
||||
b.Write([]byte{uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i)})
|
||||
}
|
||||
|
||||
// WriteUint24 writes a uint24
|
||||
func (bigEndian) WriteUint24(b *bytes.Buffer, i uint32) {
|
||||
b.Write([]byte{uint8(i >> 16), uint8(i >> 8), uint8(i)})
|
||||
}
|
||||
|
||||
// WriteUint16 writes a uint16
|
||||
func (bigEndian) WriteUint16(b *bytes.Buffer, i uint16) {
|
||||
b.Write([]byte{uint8(i >> 8), uint8(i)})
|
||||
}
|
||||
10
vendor/github.com/quic-go/quic-go/internal/utils/ip.go
generated
vendored
Normal file
10
vendor/github.com/quic-go/quic-go/internal/utils/ip.go
generated
vendored
Normal file
@@ -0,0 +1,10 @@
|
||||
package utils
|
||||
|
||||
import "net"
|
||||
|
||||
func IsIPv4(ip net.IP) bool {
|
||||
// If ip is not an IPv4 address, To4 returns nil.
|
||||
// Note that there might be some corner cases, where this is not correct.
|
||||
// See https://stackoverflow.com/questions/22751035/golang-distinguish-ipv4-ipv6.
|
||||
return ip.To4() != nil
|
||||
}
|
||||
6
vendor/github.com/quic-go/quic-go/internal/utils/linkedlist/README.md
generated
vendored
Normal file
6
vendor/github.com/quic-go/quic-go/internal/utils/linkedlist/README.md
generated
vendored
Normal file
@@ -0,0 +1,6 @@
|
||||
# Usage
|
||||
|
||||
This is the Go standard library implementation of a linked list
|
||||
(https://golang.org/src/container/list/list.go), with the following modifications:
|
||||
* it uses Go generics
|
||||
* it allows passing in a `sync.Pool` (via the `NewWithPool` constructor) to reduce allocations of `Element` structs
|
||||
264
vendor/github.com/quic-go/quic-go/internal/utils/linkedlist/linkedlist.go
generated
vendored
Normal file
264
vendor/github.com/quic-go/quic-go/internal/utils/linkedlist/linkedlist.go
generated
vendored
Normal file
@@ -0,0 +1,264 @@
|
||||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package list implements a doubly linked list.
|
||||
//
|
||||
// To iterate over a list (where l is a *List[T]):
|
||||
//
|
||||
// for e := l.Front(); e != nil; e = e.Next() {
|
||||
// // do something with e.Value
|
||||
// }
|
||||
package list
|
||||
|
||||
import "sync"
|
||||
|
||||
func NewPool[T any]() *sync.Pool {
|
||||
return &sync.Pool{New: func() any { return &Element[T]{} }}
|
||||
}
|
||||
|
||||
// Element is an element of a linked list.
|
||||
type Element[T any] struct {
|
||||
// Next and previous pointers in the doubly-linked list of elements.
|
||||
// To simplify the implementation, internally a list l is implemented
|
||||
// as a ring, such that &l.root is both the next element of the last
|
||||
// list element (l.Back()) and the previous element of the first list
|
||||
// element (l.Front()).
|
||||
next, prev *Element[T]
|
||||
|
||||
// The list to which this element belongs.
|
||||
list *List[T]
|
||||
|
||||
// The value stored with this element.
|
||||
Value T
|
||||
}
|
||||
|
||||
// Next returns the next list element or nil.
|
||||
func (e *Element[T]) Next() *Element[T] {
|
||||
if p := e.next; e.list != nil && p != &e.list.root {
|
||||
return p
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Prev returns the previous list element or nil.
|
||||
func (e *Element[T]) Prev() *Element[T] {
|
||||
if p := e.prev; e.list != nil && p != &e.list.root {
|
||||
return p
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Element[T]) List() *List[T] {
|
||||
return e.list
|
||||
}
|
||||
|
||||
// List represents a doubly linked list.
|
||||
// The zero value for List is an empty list ready to use.
|
||||
type List[T any] struct {
|
||||
root Element[T] // sentinel list element, only &root, root.prev, and root.next are used
|
||||
len int // current list length excluding (this) sentinel element
|
||||
|
||||
pool *sync.Pool
|
||||
}
|
||||
|
||||
// Init initializes or clears list l.
|
||||
func (l *List[T]) Init() *List[T] {
|
||||
l.root.next = &l.root
|
||||
l.root.prev = &l.root
|
||||
l.len = 0
|
||||
return l
|
||||
}
|
||||
|
||||
// New returns an initialized list.
|
||||
func New[T any]() *List[T] { return new(List[T]).Init() }
|
||||
|
||||
// NewWithPool returns an initialized list, using a sync.Pool for list elements.
|
||||
func NewWithPool[T any](pool *sync.Pool) *List[T] {
|
||||
l := &List[T]{pool: pool}
|
||||
return l.Init()
|
||||
}
|
||||
|
||||
// Len returns the number of elements of list l.
|
||||
// The complexity is O(1).
|
||||
func (l *List[T]) Len() int { return l.len }
|
||||
|
||||
// Front returns the first element of list l or nil if the list is empty.
|
||||
func (l *List[T]) Front() *Element[T] {
|
||||
if l.len == 0 {
|
||||
return nil
|
||||
}
|
||||
return l.root.next
|
||||
}
|
||||
|
||||
// Back returns the last element of list l or nil if the list is empty.
|
||||
func (l *List[T]) Back() *Element[T] {
|
||||
if l.len == 0 {
|
||||
return nil
|
||||
}
|
||||
return l.root.prev
|
||||
}
|
||||
|
||||
// lazyInit lazily initializes a zero List value.
|
||||
func (l *List[T]) lazyInit() {
|
||||
if l.root.next == nil {
|
||||
l.Init()
|
||||
}
|
||||
}
|
||||
|
||||
// insert inserts e after at, increments l.len, and returns e.
|
||||
func (l *List[T]) insert(e, at *Element[T]) *Element[T] {
|
||||
e.prev = at
|
||||
e.next = at.next
|
||||
e.prev.next = e
|
||||
e.next.prev = e
|
||||
e.list = l
|
||||
l.len++
|
||||
return e
|
||||
}
|
||||
|
||||
// insertValue is a convenience wrapper for insert(&Element{Value: v}, at).
|
||||
func (l *List[T]) insertValue(v T, at *Element[T]) *Element[T] {
|
||||
var e *Element[T]
|
||||
if l.pool != nil {
|
||||
e = l.pool.Get().(*Element[T])
|
||||
} else {
|
||||
e = &Element[T]{}
|
||||
}
|
||||
e.Value = v
|
||||
return l.insert(e, at)
|
||||
}
|
||||
|
||||
// remove removes e from its list, decrements l.len
|
||||
func (l *List[T]) remove(e *Element[T]) {
|
||||
e.prev.next = e.next
|
||||
e.next.prev = e.prev
|
||||
e.next = nil // avoid memory leaks
|
||||
e.prev = nil // avoid memory leaks
|
||||
e.list = nil
|
||||
if l.pool != nil {
|
||||
l.pool.Put(e)
|
||||
}
|
||||
l.len--
|
||||
}
|
||||
|
||||
// move moves e to next to at.
|
||||
func (l *List[T]) move(e, at *Element[T]) {
|
||||
if e == at {
|
||||
return
|
||||
}
|
||||
e.prev.next = e.next
|
||||
e.next.prev = e.prev
|
||||
|
||||
e.prev = at
|
||||
e.next = at.next
|
||||
e.prev.next = e
|
||||
e.next.prev = e
|
||||
}
|
||||
|
||||
// Remove removes e from l if e is an element of list l.
|
||||
// It returns the element value e.Value.
|
||||
// The element must not be nil.
|
||||
func (l *List[T]) Remove(e *Element[T]) T {
|
||||
v := e.Value
|
||||
if e.list == l {
|
||||
// if e.list == l, l must have been initialized when e was inserted
|
||||
// in l or l == nil (e is a zero Element) and l.remove will crash
|
||||
l.remove(e)
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// PushFront inserts a new element e with value v at the front of list l and returns e.
|
||||
func (l *List[T]) PushFront(v T) *Element[T] {
|
||||
l.lazyInit()
|
||||
return l.insertValue(v, &l.root)
|
||||
}
|
||||
|
||||
// PushBack inserts a new element e with value v at the back of list l and returns e.
|
||||
func (l *List[T]) PushBack(v T) *Element[T] {
|
||||
l.lazyInit()
|
||||
return l.insertValue(v, l.root.prev)
|
||||
}
|
||||
|
||||
// InsertBefore inserts a new element e with value v immediately before mark and returns e.
|
||||
// If mark is not an element of l, the list is not modified.
|
||||
// The mark must not be nil.
|
||||
func (l *List[T]) InsertBefore(v T, mark *Element[T]) *Element[T] {
|
||||
if mark.list != l {
|
||||
return nil
|
||||
}
|
||||
// see comment in List.Remove about initialization of l
|
||||
return l.insertValue(v, mark.prev)
|
||||
}
|
||||
|
||||
// InsertAfter inserts a new element e with value v immediately after mark and returns e.
|
||||
// If mark is not an element of l, the list is not modified.
|
||||
// The mark must not be nil.
|
||||
func (l *List[T]) InsertAfter(v T, mark *Element[T]) *Element[T] {
|
||||
if mark.list != l {
|
||||
return nil
|
||||
}
|
||||
// see comment in List.Remove about initialization of l
|
||||
return l.insertValue(v, mark)
|
||||
}
|
||||
|
||||
// MoveToFront moves element e to the front of list l.
|
||||
// If e is not an element of l, the list is not modified.
|
||||
// The element must not be nil.
|
||||
func (l *List[T]) MoveToFront(e *Element[T]) {
|
||||
if e.list != l || l.root.next == e {
|
||||
return
|
||||
}
|
||||
// see comment in List.Remove about initialization of l
|
||||
l.move(e, &l.root)
|
||||
}
|
||||
|
||||
// MoveToBack moves element e to the back of list l.
|
||||
// If e is not an element of l, the list is not modified.
|
||||
// The element must not be nil.
|
||||
func (l *List[T]) MoveToBack(e *Element[T]) {
|
||||
if e.list != l || l.root.prev == e {
|
||||
return
|
||||
}
|
||||
// see comment in List.Remove about initialization of l
|
||||
l.move(e, l.root.prev)
|
||||
}
|
||||
|
||||
// MoveBefore moves element e to its new position before mark.
|
||||
// If e or mark is not an element of l, or e == mark, the list is not modified.
|
||||
// The element and mark must not be nil.
|
||||
func (l *List[T]) MoveBefore(e, mark *Element[T]) {
|
||||
if e.list != l || e == mark || mark.list != l {
|
||||
return
|
||||
}
|
||||
l.move(e, mark.prev)
|
||||
}
|
||||
|
||||
// MoveAfter moves element e to its new position after mark.
|
||||
// If e or mark is not an element of l, or e == mark, the list is not modified.
|
||||
// The element and mark must not be nil.
|
||||
func (l *List[T]) MoveAfter(e, mark *Element[T]) {
|
||||
if e.list != l || e == mark || mark.list != l {
|
||||
return
|
||||
}
|
||||
l.move(e, mark)
|
||||
}
|
||||
|
||||
// PushBackList inserts a copy of another list at the back of list l.
|
||||
// The lists l and other may be the same. They must not be nil.
|
||||
func (l *List[T]) PushBackList(other *List[T]) {
|
||||
l.lazyInit()
|
||||
for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() {
|
||||
l.insertValue(e.Value, l.root.prev)
|
||||
}
|
||||
}
|
||||
|
||||
// PushFrontList inserts a copy of another list at the front of list l.
|
||||
// The lists l and other may be the same. They must not be nil.
|
||||
func (l *List[T]) PushFrontList(other *List[T]) {
|
||||
l.lazyInit()
|
||||
for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() {
|
||||
l.insertValue(e.Value, &l.root)
|
||||
}
|
||||
}
|
||||
131
vendor/github.com/quic-go/quic-go/internal/utils/log.go
generated
vendored
Normal file
131
vendor/github.com/quic-go/quic-go/internal/utils/log.go
generated
vendored
Normal file
@@ -0,0 +1,131 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// LogLevel of quic-go
|
||||
type LogLevel uint8
|
||||
|
||||
const (
|
||||
// LogLevelNothing disables
|
||||
LogLevelNothing LogLevel = iota
|
||||
// LogLevelError enables err logs
|
||||
LogLevelError
|
||||
// LogLevelInfo enables info logs (e.g. packets)
|
||||
LogLevelInfo
|
||||
// LogLevelDebug enables debug logs (e.g. packet contents)
|
||||
LogLevelDebug
|
||||
)
|
||||
|
||||
const logEnv = "QUIC_GO_LOG_LEVEL"
|
||||
|
||||
// A Logger logs.
|
||||
type Logger interface {
|
||||
SetLogLevel(LogLevel)
|
||||
SetLogTimeFormat(format string)
|
||||
WithPrefix(prefix string) Logger
|
||||
Debug() bool
|
||||
|
||||
Errorf(format string, args ...interface{})
|
||||
Infof(format string, args ...interface{})
|
||||
Debugf(format string, args ...interface{})
|
||||
}
|
||||
|
||||
// DefaultLogger is used by quic-go for logging.
|
||||
var DefaultLogger Logger
|
||||
|
||||
type defaultLogger struct {
|
||||
prefix string
|
||||
|
||||
logLevel LogLevel
|
||||
timeFormat string
|
||||
}
|
||||
|
||||
var _ Logger = &defaultLogger{}
|
||||
|
||||
// SetLogLevel sets the log level
|
||||
func (l *defaultLogger) SetLogLevel(level LogLevel) {
|
||||
l.logLevel = level
|
||||
}
|
||||
|
||||
// SetLogTimeFormat sets the format of the timestamp
|
||||
// an empty string disables the logging of timestamps
|
||||
func (l *defaultLogger) SetLogTimeFormat(format string) {
|
||||
log.SetFlags(0) // disable timestamp logging done by the log package
|
||||
l.timeFormat = format
|
||||
}
|
||||
|
||||
// Debugf logs something
|
||||
func (l *defaultLogger) Debugf(format string, args ...interface{}) {
|
||||
if l.logLevel == LogLevelDebug {
|
||||
l.logMessage(format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
// Infof logs something
|
||||
func (l *defaultLogger) Infof(format string, args ...interface{}) {
|
||||
if l.logLevel >= LogLevelInfo {
|
||||
l.logMessage(format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
// Errorf logs something
|
||||
func (l *defaultLogger) Errorf(format string, args ...interface{}) {
|
||||
if l.logLevel >= LogLevelError {
|
||||
l.logMessage(format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *defaultLogger) logMessage(format string, args ...interface{}) {
|
||||
var pre string
|
||||
|
||||
if len(l.timeFormat) > 0 {
|
||||
pre = time.Now().Format(l.timeFormat) + " "
|
||||
}
|
||||
if len(l.prefix) > 0 {
|
||||
pre += l.prefix + " "
|
||||
}
|
||||
log.Printf(pre+format, args...)
|
||||
}
|
||||
|
||||
func (l *defaultLogger) WithPrefix(prefix string) Logger {
|
||||
if len(l.prefix) > 0 {
|
||||
prefix = l.prefix + " " + prefix
|
||||
}
|
||||
return &defaultLogger{
|
||||
logLevel: l.logLevel,
|
||||
timeFormat: l.timeFormat,
|
||||
prefix: prefix,
|
||||
}
|
||||
}
|
||||
|
||||
// Debug returns true if the log level is LogLevelDebug
|
||||
func (l *defaultLogger) Debug() bool {
|
||||
return l.logLevel == LogLevelDebug
|
||||
}
|
||||
|
||||
func init() {
|
||||
DefaultLogger = &defaultLogger{}
|
||||
DefaultLogger.SetLogLevel(readLoggingEnv())
|
||||
}
|
||||
|
||||
func readLoggingEnv() LogLevel {
|
||||
switch strings.ToLower(os.Getenv(logEnv)) {
|
||||
case "":
|
||||
return LogLevelNothing
|
||||
case "debug":
|
||||
return LogLevelDebug
|
||||
case "info":
|
||||
return LogLevelInfo
|
||||
case "error":
|
||||
return LogLevelError
|
||||
default:
|
||||
fmt.Fprintln(os.Stderr, "invalid quic-go log level, see https://github.com/quic-go/quic-go/wiki/Logging")
|
||||
return LogLevelNothing
|
||||
}
|
||||
}
|
||||
72
vendor/github.com/quic-go/quic-go/internal/utils/minmax.go
generated
vendored
Normal file
72
vendor/github.com/quic-go/quic-go/internal/utils/minmax.go
generated
vendored
Normal file
@@ -0,0 +1,72 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"golang.org/x/exp/constraints"
|
||||
)
|
||||
|
||||
// InfDuration is a duration of infinite length
|
||||
const InfDuration = time.Duration(math.MaxInt64)
|
||||
|
||||
func Max[T constraints.Ordered](a, b T) T {
|
||||
if a < b {
|
||||
return b
|
||||
}
|
||||
return a
|
||||
}
|
||||
|
||||
func Min[T constraints.Ordered](a, b T) T {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// MinNonZeroDuration return the minimum duration that's not zero.
|
||||
func MinNonZeroDuration(a, b time.Duration) time.Duration {
|
||||
if a == 0 {
|
||||
return b
|
||||
}
|
||||
if b == 0 {
|
||||
return a
|
||||
}
|
||||
return Min(a, b)
|
||||
}
|
||||
|
||||
// AbsDuration returns the absolute value of a time duration
|
||||
func AbsDuration(d time.Duration) time.Duration {
|
||||
if d >= 0 {
|
||||
return d
|
||||
}
|
||||
return -d
|
||||
}
|
||||
|
||||
// MinTime returns the earlier time
|
||||
func MinTime(a, b time.Time) time.Time {
|
||||
if a.After(b) {
|
||||
return b
|
||||
}
|
||||
return a
|
||||
}
|
||||
|
||||
// MinNonZeroTime returns the earlist time that is not time.Time{}
|
||||
// If both a and b are time.Time{}, it returns time.Time{}
|
||||
func MinNonZeroTime(a, b time.Time) time.Time {
|
||||
if a.IsZero() {
|
||||
return b
|
||||
}
|
||||
if b.IsZero() {
|
||||
return a
|
||||
}
|
||||
return MinTime(a, b)
|
||||
}
|
||||
|
||||
// MaxTime returns the later time
|
||||
func MaxTime(a, b time.Time) time.Time {
|
||||
if a.After(b) {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
29
vendor/github.com/quic-go/quic-go/internal/utils/rand.go
generated
vendored
Normal file
29
vendor/github.com/quic-go/quic-go/internal/utils/rand.go
generated
vendored
Normal file
@@ -0,0 +1,29 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
)
|
||||
|
||||
// Rand is a wrapper around crypto/rand that adds some convenience functions known from math/rand.
|
||||
type Rand struct {
|
||||
buf [4]byte
|
||||
}
|
||||
|
||||
func (r *Rand) Int31() int32 {
|
||||
rand.Read(r.buf[:])
|
||||
return int32(binary.BigEndian.Uint32(r.buf[:]) & ^uint32(1<<31))
|
||||
}
|
||||
|
||||
// copied from the standard library math/rand implementation of Int63n
|
||||
func (r *Rand) Int31n(n int32) int32 {
|
||||
if n&(n-1) == 0 { // n is power of two, can mask
|
||||
return r.Int31() & (n - 1)
|
||||
}
|
||||
max := int32((1 << 31) - 1 - (1<<31)%uint32(n))
|
||||
v := r.Int31()
|
||||
for v > max {
|
||||
v = r.Int31()
|
||||
}
|
||||
return v % n
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user