WIP: Save agent roles integration work before CHORUS rebrand

- Agent roles and coordination features
- Chat API integration testing
- New configuration and workspace management

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
anthonyrawlins
2025-08-01 02:21:11 +10:00
parent 81b473d48f
commit 5978a0b8f5
3713 changed files with 1103925 additions and 59 deletions

7
vendor/github.com/quic-go/qpack/.codecov.yml generated vendored Normal file
View File

@@ -0,0 +1,7 @@
coverage:
round: nearest
status:
project:
default:
threshold: 1
patch: false

6
vendor/github.com/quic-go/qpack/.gitignore generated vendored Normal file
View File

@@ -0,0 +1,6 @@
fuzzing/*.zip
fuzzing/coverprofile
fuzzing/crashers
fuzzing/sonarprofile
fuzzing/suppressions
fuzzing/corpus/

3
vendor/github.com/quic-go/qpack/.gitmodules generated vendored Normal file
View File

@@ -0,0 +1,3 @@
[submodule "integrationtests/interop/qifs"]
path = integrationtests/interop/qifs
url = https://github.com/qpackers/qifs.git

27
vendor/github.com/quic-go/qpack/.golangci.yml generated vendored Normal file
View File

@@ -0,0 +1,27 @@
run:
linters-settings:
linters:
disable-all: true
enable:
- asciicheck
- deadcode
- exhaustive
- exportloopref
- goconst
- gofmt # redundant, since gofmt *should* be a no-op after gofumpt
- gofumpt
- goimports
- gosimple
- ineffassign
- misspell
- prealloc
- scopelint
- staticcheck
- stylecheck
- structcheck
- unconvert
- unparam
- unused
- varcheck
- vet

7
vendor/github.com/quic-go/qpack/LICENSE.md generated vendored Normal file
View File

@@ -0,0 +1,7 @@
Copyright 2019 Marten Seemann
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

20
vendor/github.com/quic-go/qpack/README.md generated vendored Normal file
View File

@@ -0,0 +1,20 @@
# QPACK
[![Godoc Reference](https://img.shields.io/badge/godoc-reference-blue.svg?style=flat-square)](https://godoc.org/github.com/marten-seemann/qpack)
[![Code Coverage](https://img.shields.io/codecov/c/github/marten-seemann/qpack/master.svg?style=flat-square)](https://codecov.io/gh/marten-seemann/qpack)
This is a minimal QPACK ([RFC 9204](https://datatracker.ietf.org/doc/html/rfc9204)) implementation in Go. It is minimal in the sense that it doesn't use the dynamic table at all, but just the static table and (Huffman encoded) string literals. Wherever possible, it reuses code from the [HPACK implementation in the Go standard library](https://github.com/golang/net/tree/master/http2/hpack).
It should be able to interoperate with other QPACK implemetations (both encoders and decoders), however it won't achieve a high compression efficiency.
## Running the interop tests
Install the [QPACK interop files](https://github.com/qpackers/qifs/) by running
```bash
git submodule update --init --recursive
```
Then run the tests:
```bash
ginkgo -r integrationtests
```

271
vendor/github.com/quic-go/qpack/decoder.go generated vendored Normal file
View File

@@ -0,0 +1,271 @@
package qpack
import (
"bytes"
"errors"
"fmt"
"sync"
"golang.org/x/net/http2/hpack"
)
// A decodingError is something the spec defines as a decoding error.
type decodingError struct {
err error
}
func (de decodingError) Error() string {
return fmt.Sprintf("decoding error: %v", de.err)
}
// An invalidIndexError is returned when an encoder references a table
// entry before the static table or after the end of the dynamic table.
type invalidIndexError int
func (e invalidIndexError) Error() string {
return fmt.Sprintf("invalid indexed representation index %d", int(e))
}
var errNoDynamicTable = decodingError{errors.New("no dynamic table")}
// errNeedMore is an internal sentinel error value that means the
// buffer is truncated and we need to read more data before we can
// continue parsing.
var errNeedMore = errors.New("need more data")
// A Decoder is the decoding context for incremental processing of
// header blocks.
type Decoder struct {
mutex sync.Mutex
emitFunc func(f HeaderField)
readRequiredInsertCount bool
readDeltaBase bool
// buf is the unparsed buffer. It's only written to
// saveBuf if it was truncated in the middle of a header
// block. Because it's usually not owned, we can only
// process it under Write.
buf []byte // not owned; only valid during Write
// saveBuf is previous data passed to Write which we weren't able
// to fully parse before. Unlike buf, we own this data.
saveBuf bytes.Buffer
}
// NewDecoder returns a new decoder
// The emitFunc will be called for each valid field parsed,
// in the same goroutine as calls to Write, before Write returns.
func NewDecoder(emitFunc func(f HeaderField)) *Decoder {
return &Decoder{emitFunc: emitFunc}
}
func (d *Decoder) Write(p []byte) (int, error) {
if len(p) == 0 {
return 0, nil
}
d.mutex.Lock()
n, err := d.writeLocked(p)
d.mutex.Unlock()
return n, err
}
func (d *Decoder) writeLocked(p []byte) (int, error) {
// Only copy the data if we have to. Optimistically assume
// that p will contain a complete header block.
if d.saveBuf.Len() == 0 {
d.buf = p
} else {
d.saveBuf.Write(p)
d.buf = d.saveBuf.Bytes()
d.saveBuf.Reset()
}
if err := d.decode(); err != nil {
if err != errNeedMore {
return 0, err
}
// TODO: limit the size of the buffer
d.saveBuf.Write(d.buf)
}
return len(p), nil
}
// DecodeFull decodes an entire block.
func (d *Decoder) DecodeFull(p []byte) ([]HeaderField, error) {
if len(p) == 0 {
return []HeaderField{}, nil
}
d.mutex.Lock()
defer d.mutex.Unlock()
saveFunc := d.emitFunc
defer func() { d.emitFunc = saveFunc }()
var hf []HeaderField
d.emitFunc = func(f HeaderField) { hf = append(hf, f) }
if _, err := d.writeLocked(p); err != nil {
return nil, err
}
if err := d.Close(); err != nil {
return nil, err
}
return hf, nil
}
// Close declares that the decoding is complete and resets the Decoder
// to be reused again for a new header block. If there is any remaining
// data in the decoder's buffer, Close returns an error.
func (d *Decoder) Close() error {
if d.saveBuf.Len() > 0 {
d.saveBuf.Reset()
return decodingError{errors.New("truncated headers")}
}
d.readRequiredInsertCount = false
d.readDeltaBase = false
return nil
}
func (d *Decoder) decode() error {
if !d.readRequiredInsertCount {
requiredInsertCount, rest, err := readVarInt(8, d.buf)
if err != nil {
return err
}
d.readRequiredInsertCount = true
if requiredInsertCount != 0 {
return decodingError{errors.New("expected Required Insert Count to be zero")}
}
d.buf = rest
}
if !d.readDeltaBase {
base, rest, err := readVarInt(7, d.buf)
if err != nil {
return err
}
d.readDeltaBase = true
if base != 0 {
return decodingError{errors.New("expected Base to be zero")}
}
d.buf = rest
}
if len(d.buf) == 0 {
return errNeedMore
}
for len(d.buf) > 0 {
b := d.buf[0]
var err error
switch {
case b&0x80 > 0: // 1xxxxxxx
err = d.parseIndexedHeaderField()
case b&0xc0 == 0x40: // 01xxxxxx
err = d.parseLiteralHeaderField()
case b&0xe0 == 0x20: // 001xxxxx
err = d.parseLiteralHeaderFieldWithoutNameReference()
default:
err = fmt.Errorf("unexpected type byte: %#x", b)
}
if err != nil {
return err
}
}
return nil
}
func (d *Decoder) parseIndexedHeaderField() error {
buf := d.buf
if buf[0]&0x40 == 0 {
return errNoDynamicTable
}
index, buf, err := readVarInt(6, buf)
if err != nil {
return err
}
hf, ok := d.at(index)
if !ok {
return decodingError{invalidIndexError(index)}
}
d.emitFunc(hf)
d.buf = buf
return nil
}
func (d *Decoder) parseLiteralHeaderField() error {
buf := d.buf
if buf[0]&0x20 > 0 || buf[0]&0x10 == 0 {
return errNoDynamicTable
}
index, buf, err := readVarInt(4, buf)
if err != nil {
return err
}
hf, ok := d.at(index)
if !ok {
return decodingError{invalidIndexError(index)}
}
if len(buf) == 0 {
return errNeedMore
}
usesHuffman := buf[0]&0x80 > 0
val, buf, err := d.readString(buf, 7, usesHuffman)
if err != nil {
return err
}
hf.Value = val
d.emitFunc(hf)
d.buf = buf
return nil
}
func (d *Decoder) parseLiteralHeaderFieldWithoutNameReference() error {
buf := d.buf
usesHuffmanForName := buf[0]&0x8 > 0
name, buf, err := d.readString(buf, 3, usesHuffmanForName)
if err != nil {
return err
}
if len(buf) == 0 {
return errNeedMore
}
usesHuffmanForVal := buf[0]&0x80 > 0
val, buf, err := d.readString(buf, 7, usesHuffmanForVal)
if err != nil {
return err
}
d.emitFunc(HeaderField{Name: name, Value: val})
d.buf = buf
return nil
}
func (d *Decoder) readString(buf []byte, n uint8, usesHuffman bool) (string, []byte, error) {
l, buf, err := readVarInt(n, buf)
if err != nil {
return "", nil, err
}
if uint64(len(buf)) < l {
return "", nil, errNeedMore
}
var val string
if usesHuffman {
var err error
val, err = hpack.HuffmanDecodeToString(buf[:l])
if err != nil {
return "", nil, err
}
} else {
val = string(buf[:l])
}
buf = buf[l:]
return val, buf, nil
}
func (d *Decoder) at(i uint64) (hf HeaderField, ok bool) {
if i >= uint64(len(staticTableEntries)) {
return
}
return staticTableEntries[i], true
}

95
vendor/github.com/quic-go/qpack/encoder.go generated vendored Normal file
View File

@@ -0,0 +1,95 @@
package qpack
import (
"io"
"golang.org/x/net/http2/hpack"
)
// An Encoder performs QPACK encoding.
type Encoder struct {
wrotePrefix bool
w io.Writer
buf []byte
}
// NewEncoder returns a new Encoder which performs QPACK encoding. An
// encoded data is written to w.
func NewEncoder(w io.Writer) *Encoder {
return &Encoder{w: w}
}
// WriteField encodes f into a single Write to e's underlying Writer.
// This function may also produce bytes for the Header Block Prefix
// if necessary. If produced, it is done before encoding f.
func (e *Encoder) WriteField(f HeaderField) error {
// write the Header Block Prefix
if !e.wrotePrefix {
e.buf = appendVarInt(e.buf, 8, 0)
e.buf = appendVarInt(e.buf, 7, 0)
e.wrotePrefix = true
}
idxAndVals, nameFound := encoderMap[f.Name]
if nameFound {
if idxAndVals.values == nil {
if len(f.Value) == 0 {
e.writeIndexedField(idxAndVals.idx)
} else {
e.writeLiteralFieldWithNameReference(&f, idxAndVals.idx)
}
} else {
valIdx, valueFound := idxAndVals.values[f.Value]
if valueFound {
e.writeIndexedField(valIdx)
} else {
e.writeLiteralFieldWithNameReference(&f, idxAndVals.idx)
}
}
} else {
e.writeLiteralFieldWithoutNameReference(f)
}
_, err := e.w.Write(e.buf)
e.buf = e.buf[:0]
return err
}
// Close declares that the encoding is complete and resets the Encoder
// to be reused again for a new header block.
func (e *Encoder) Close() error {
e.wrotePrefix = false
return nil
}
func (e *Encoder) writeLiteralFieldWithoutNameReference(f HeaderField) {
offset := len(e.buf)
e.buf = appendVarInt(e.buf, 3, hpack.HuffmanEncodeLength(f.Name))
e.buf[offset] ^= 0x20 ^ 0x8
e.buf = hpack.AppendHuffmanString(e.buf, f.Name)
offset = len(e.buf)
e.buf = appendVarInt(e.buf, 7, hpack.HuffmanEncodeLength(f.Value))
e.buf[offset] ^= 0x80
e.buf = hpack.AppendHuffmanString(e.buf, f.Value)
}
// Encodes a header field whose name is present in one of the tables.
func (e *Encoder) writeLiteralFieldWithNameReference(f *HeaderField, id uint8) {
offset := len(e.buf)
e.buf = appendVarInt(e.buf, 4, uint64(id))
// Set the 01NTxxxx pattern, forcing N to 0 and T to 1
e.buf[offset] ^= 0x50
offset = len(e.buf)
e.buf = appendVarInt(e.buf, 7, hpack.HuffmanEncodeLength(f.Value))
e.buf[offset] ^= 0x80
e.buf = hpack.AppendHuffmanString(e.buf, f.Value)
}
// Encodes an indexed field, meaning it's entirely defined in one of the tables.
func (e *Encoder) writeIndexedField(id uint8) {
offset := len(e.buf)
e.buf = appendVarInt(e.buf, 6, uint64(id))
// Set the 1Txxxxxx pattern, forcing T to 1
e.buf[offset] ^= 0xc0
}

16
vendor/github.com/quic-go/qpack/header_field.go generated vendored Normal file
View File

@@ -0,0 +1,16 @@
package qpack
// A HeaderField is a name-value pair. Both the name and value are
// treated as opaque sequences of octets.
type HeaderField struct {
Name string
Value string
}
// IsPseudo reports whether the header field is an HTTP3 pseudo header.
// That is, it reports whether it starts with a colon.
// It is not otherwise guaranteed to be a valid pseudo header field,
// though.
func (hf HeaderField) IsPseudo() bool {
return len(hf.Name) != 0 && hf.Name[0] == ':'
}

255
vendor/github.com/quic-go/qpack/static_table.go generated vendored Normal file
View File

@@ -0,0 +1,255 @@
package qpack
var staticTableEntries = [...]HeaderField{
{Name: ":authority"},
{Name: ":path", Value: "/"},
{Name: "age", Value: "0"},
{Name: "content-disposition"},
{Name: "content-length", Value: "0"},
{Name: "cookie"},
{Name: "date"},
{Name: "etag"},
{Name: "if-modified-since"},
{Name: "if-none-match"},
{Name: "last-modified"},
{Name: "link"},
{Name: "location"},
{Name: "referer"},
{Name: "set-cookie"},
{Name: ":method", Value: "CONNECT"},
{Name: ":method", Value: "DELETE"},
{Name: ":method", Value: "GET"},
{Name: ":method", Value: "HEAD"},
{Name: ":method", Value: "OPTIONS"},
{Name: ":method", Value: "POST"},
{Name: ":method", Value: "PUT"},
{Name: ":scheme", Value: "http"},
{Name: ":scheme", Value: "https"},
{Name: ":status", Value: "103"},
{Name: ":status", Value: "200"},
{Name: ":status", Value: "304"},
{Name: ":status", Value: "404"},
{Name: ":status", Value: "503"},
{Name: "accept", Value: "*/*"},
{Name: "accept", Value: "application/dns-message"},
{Name: "accept-encoding", Value: "gzip, deflate, br"},
{Name: "accept-ranges", Value: "bytes"},
{Name: "access-control-allow-headers", Value: "cache-control"},
{Name: "access-control-allow-headers", Value: "content-type"},
{Name: "access-control-allow-origin", Value: "*"},
{Name: "cache-control", Value: "max-age=0"},
{Name: "cache-control", Value: "max-age=2592000"},
{Name: "cache-control", Value: "max-age=604800"},
{Name: "cache-control", Value: "no-cache"},
{Name: "cache-control", Value: "no-store"},
{Name: "cache-control", Value: "public, max-age=31536000"},
{Name: "content-encoding", Value: "br"},
{Name: "content-encoding", Value: "gzip"},
{Name: "content-type", Value: "application/dns-message"},
{Name: "content-type", Value: "application/javascript"},
{Name: "content-type", Value: "application/json"},
{Name: "content-type", Value: "application/x-www-form-urlencoded"},
{Name: "content-type", Value: "image/gif"},
{Name: "content-type", Value: "image/jpeg"},
{Name: "content-type", Value: "image/png"},
{Name: "content-type", Value: "text/css"},
{Name: "content-type", Value: "text/html; charset=utf-8"},
{Name: "content-type", Value: "text/plain"},
{Name: "content-type", Value: "text/plain;charset=utf-8"},
{Name: "range", Value: "bytes=0-"},
{Name: "strict-transport-security", Value: "max-age=31536000"},
{Name: "strict-transport-security", Value: "max-age=31536000; includesubdomains"},
{Name: "strict-transport-security", Value: "max-age=31536000; includesubdomains; preload"},
{Name: "vary", Value: "accept-encoding"},
{Name: "vary", Value: "origin"},
{Name: "x-content-type-options", Value: "nosniff"},
{Name: "x-xss-protection", Value: "1; mode=block"},
{Name: ":status", Value: "100"},
{Name: ":status", Value: "204"},
{Name: ":status", Value: "206"},
{Name: ":status", Value: "302"},
{Name: ":status", Value: "400"},
{Name: ":status", Value: "403"},
{Name: ":status", Value: "421"},
{Name: ":status", Value: "425"},
{Name: ":status", Value: "500"},
{Name: "accept-language"},
{Name: "access-control-allow-credentials", Value: "FALSE"},
{Name: "access-control-allow-credentials", Value: "TRUE"},
{Name: "access-control-allow-headers", Value: "*"},
{Name: "access-control-allow-methods", Value: "get"},
{Name: "access-control-allow-methods", Value: "get, post, options"},
{Name: "access-control-allow-methods", Value: "options"},
{Name: "access-control-expose-headers", Value: "content-length"},
{Name: "access-control-request-headers", Value: "content-type"},
{Name: "access-control-request-method", Value: "get"},
{Name: "access-control-request-method", Value: "post"},
{Name: "alt-svc", Value: "clear"},
{Name: "authorization"},
{Name: "content-security-policy", Value: "script-src 'none'; object-src 'none'; base-uri 'none'"},
{Name: "early-data", Value: "1"},
{Name: "expect-ct"},
{Name: "forwarded"},
{Name: "if-range"},
{Name: "origin"},
{Name: "purpose", Value: "prefetch"},
{Name: "server"},
{Name: "timing-allow-origin", Value: "*"},
{Name: "upgrade-insecure-requests", Value: "1"},
{Name: "user-agent"},
{Name: "x-forwarded-for"},
{Name: "x-frame-options", Value: "deny"},
{Name: "x-frame-options", Value: "sameorigin"},
}
// Only needed for tests.
// use go:linkname to retrieve the static table.
//
//nolint:deadcode,unused
func getStaticTable() []HeaderField {
return staticTableEntries[:]
}
type indexAndValues struct {
idx uint8
values map[string]uint8
}
// A map of the header names from the static table to their index in the table.
// This is used by the encoder to quickly find if a header is in the static table
// and what value should be used to encode it.
// There's a second level of mapping for the headers that have some predefined
// values in the static table.
var encoderMap = map[string]indexAndValues{
":authority": {0, nil},
":path": {1, map[string]uint8{"/": 1}},
"age": {2, map[string]uint8{"0": 2}},
"content-disposition": {3, nil},
"content-length": {4, map[string]uint8{"0": 4}},
"cookie": {5, nil},
"date": {6, nil},
"etag": {7, nil},
"if-modified-since": {8, nil},
"if-none-match": {9, nil},
"last-modified": {10, nil},
"link": {11, nil},
"location": {12, nil},
"referer": {13, nil},
"set-cookie": {14, nil},
":method": {15, map[string]uint8{
"CONNECT": 15,
"DELETE": 16,
"GET": 17,
"HEAD": 18,
"OPTIONS": 19,
"POST": 20,
"PUT": 21,
}},
":scheme": {22, map[string]uint8{
"http": 22,
"https": 23,
}},
":status": {24, map[string]uint8{
"103": 24,
"200": 25,
"304": 26,
"404": 27,
"503": 28,
"100": 63,
"204": 64,
"206": 65,
"302": 66,
"400": 67,
"403": 68,
"421": 69,
"425": 70,
"500": 71,
}},
"accept": {29, map[string]uint8{
"*/*": 29,
"application/dns-message": 30,
}},
"accept-encoding": {31, map[string]uint8{"gzip, deflate, br": 31}},
"accept-ranges": {32, map[string]uint8{"bytes": 32}},
"access-control-allow-headers": {33, map[string]uint8{
"cache-control": 33,
"content-type": 34,
"*": 75,
}},
"access-control-allow-origin": {35, map[string]uint8{"*": 35}},
"cache-control": {36, map[string]uint8{
"max-age=0": 36,
"max-age=2592000": 37,
"max-age=604800": 38,
"no-cache": 39,
"no-store": 40,
"public, max-age=31536000": 41,
}},
"content-encoding": {42, map[string]uint8{
"br": 42,
"gzip": 43,
}},
"content-type": {44, map[string]uint8{
"application/dns-message": 44,
"application/javascript": 45,
"application/json": 46,
"application/x-www-form-urlencoded": 47,
"image/gif": 48,
"image/jpeg": 49,
"image/png": 50,
"text/css": 51,
"text/html; charset=utf-8": 52,
"text/plain": 53,
"text/plain;charset=utf-8": 54,
}},
"range": {55, map[string]uint8{"bytes=0-": 55}},
"strict-transport-security": {56, map[string]uint8{
"max-age=31536000": 56,
"max-age=31536000; includesubdomains": 57,
"max-age=31536000; includesubdomains; preload": 58,
}},
"vary": {59, map[string]uint8{
"accept-encoding": 59,
"origin": 60,
}},
"x-content-type-options": {61, map[string]uint8{"nosniff": 61}},
"x-xss-protection": {62, map[string]uint8{"1; mode=block": 62}},
// ":status" is duplicated and takes index 63 to 71
"accept-language": {72, nil},
"access-control-allow-credentials": {73, map[string]uint8{
"FALSE": 73,
"TRUE": 74,
}},
// "access-control-allow-headers" is duplicated and takes index 75
"access-control-allow-methods": {76, map[string]uint8{
"get": 76,
"get, post, options": 77,
"options": 78,
}},
"access-control-expose-headers": {79, map[string]uint8{"content-length": 79}},
"access-control-request-headers": {80, map[string]uint8{"content-type": 80}},
"access-control-request-method": {81, map[string]uint8{
"get": 81,
"post": 82,
}},
"alt-svc": {83, map[string]uint8{"clear": 83}},
"authorization": {84, nil},
"content-security-policy": {85, map[string]uint8{
"script-src 'none'; object-src 'none'; base-uri 'none'": 85,
}},
"early-data": {86, map[string]uint8{"1": 86}},
"expect-ct": {87, nil},
"forwarded": {88, nil},
"if-range": {89, nil},
"origin": {90, nil},
"purpose": {91, map[string]uint8{"prefetch": 91}},
"server": {92, nil},
"timing-allow-origin": {93, map[string]uint8{"*": 93}},
"upgrade-insecure-requests": {94, map[string]uint8{"1": 94}},
"user-agent": {95, nil},
"x-forwarded-for": {96, nil},
"x-frame-options": {97, map[string]uint8{
"deny": 97,
"sameorigin": 98,
}},
}

5
vendor/github.com/quic-go/qpack/tools.go generated vendored Normal file
View File

@@ -0,0 +1,5 @@
//go:build tools
package qpack
import _ "github.com/onsi/ginkgo/v2/ginkgo"

66
vendor/github.com/quic-go/qpack/varint.go generated vendored Normal file
View File

@@ -0,0 +1,66 @@
package qpack
// copied from the Go standard library HPACK implementation
import "errors"
var errVarintOverflow = errors.New("varint integer overflow")
// appendVarInt appends i, as encoded in variable integer form using n
// bit prefix, to dst and returns the extended buffer.
//
// See
// http://http2.github.io/http2-spec/compression.html#integer.representation
func appendVarInt(dst []byte, n byte, i uint64) []byte {
k := uint64((1 << n) - 1)
if i < k {
return append(dst, byte(i))
}
dst = append(dst, byte(k))
i -= k
for ; i >= 128; i >>= 7 {
dst = append(dst, byte(0x80|(i&0x7f)))
}
return append(dst, byte(i))
}
// readVarInt reads an unsigned variable length integer off the
// beginning of p. n is the parameter as described in
// http://http2.github.io/http2-spec/compression.html#rfc.section.5.1.
//
// n must always be between 1 and 8.
//
// The returned remain buffer is either a smaller suffix of p, or err != nil.
// The error is errNeedMore if p doesn't contain a complete integer.
func readVarInt(n byte, p []byte) (i uint64, remain []byte, err error) {
if n < 1 || n > 8 {
panic("bad n")
}
if len(p) == 0 {
return 0, p, errNeedMore
}
i = uint64(p[0])
if n < 8 {
i &= (1 << uint64(n)) - 1
}
if i < (1<<uint64(n))-1 {
return i, p[1:], nil
}
origP := p
p = p[1:]
var m uint64
for len(p) > 0 {
b := p[0]
p = p[1:]
i += uint64(b&127) << m
if b&128 == 0 {
return i, p, nil
}
m += 7
if m >= 63 { // TODO: proper overflow check. making this up.
return 0, origP, errVarintOverflow
}
}
return 0, origP, errNeedMore
}

27
vendor/github.com/quic-go/qtls-go1-20/LICENSE generated vendored Normal file
View File

@@ -0,0 +1,27 @@
Copyright (c) 2009 The Go Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

6
vendor/github.com/quic-go/qtls-go1-20/README.md generated vendored Normal file
View File

@@ -0,0 +1,6 @@
# qtls
[![Go Reference](https://pkg.go.dev/badge/github.com/quic-go/qtls-go1-20.svg)](https://pkg.go.dev/github.com/quic-go/qtls-go1-20)
[![.github/workflows/go-test.yml](https://github.com/quic-go/qtls-go1-20/actions/workflows/go-test.yml/badge.svg)](https://github.com/quic-go/qtls-go1-20/actions/workflows/go-test.yml)
This repository contains a modified version of the standard library's TLS implementation, modified for the QUIC protocol. It is used by [quic-go](https://github.com/quic-go/quic-go).

109
vendor/github.com/quic-go/qtls-go1-20/alert.go generated vendored Normal file
View File

@@ -0,0 +1,109 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package qtls
import "strconv"
// An AlertError is a TLS alert.
//
// When using a QUIC transport, QUICConn methods will return an error
// which wraps AlertError rather than sending a TLS alert.
type AlertError uint8
func (e AlertError) Error() string {
return alert(e).String()
}
type alert uint8
const (
// alert level
alertLevelWarning = 1
alertLevelError = 2
)
const (
alertCloseNotify alert = 0
alertUnexpectedMessage alert = 10
alertBadRecordMAC alert = 20
alertDecryptionFailed alert = 21
alertRecordOverflow alert = 22
alertDecompressionFailure alert = 30
alertHandshakeFailure alert = 40
alertBadCertificate alert = 42
alertUnsupportedCertificate alert = 43
alertCertificateRevoked alert = 44
alertCertificateExpired alert = 45
alertCertificateUnknown alert = 46
alertIllegalParameter alert = 47
alertUnknownCA alert = 48
alertAccessDenied alert = 49
alertDecodeError alert = 50
alertDecryptError alert = 51
alertExportRestriction alert = 60
alertProtocolVersion alert = 70
alertInsufficientSecurity alert = 71
alertInternalError alert = 80
alertInappropriateFallback alert = 86
alertUserCanceled alert = 90
alertNoRenegotiation alert = 100
alertMissingExtension alert = 109
alertUnsupportedExtension alert = 110
alertCertificateUnobtainable alert = 111
alertUnrecognizedName alert = 112
alertBadCertificateStatusResponse alert = 113
alertBadCertificateHashValue alert = 114
alertUnknownPSKIdentity alert = 115
alertCertificateRequired alert = 116
alertNoApplicationProtocol alert = 120
)
var alertText = map[alert]string{
alertCloseNotify: "close notify",
alertUnexpectedMessage: "unexpected message",
alertBadRecordMAC: "bad record MAC",
alertDecryptionFailed: "decryption failed",
alertRecordOverflow: "record overflow",
alertDecompressionFailure: "decompression failure",
alertHandshakeFailure: "handshake failure",
alertBadCertificate: "bad certificate",
alertUnsupportedCertificate: "unsupported certificate",
alertCertificateRevoked: "revoked certificate",
alertCertificateExpired: "expired certificate",
alertCertificateUnknown: "unknown certificate",
alertIllegalParameter: "illegal parameter",
alertUnknownCA: "unknown certificate authority",
alertAccessDenied: "access denied",
alertDecodeError: "error decoding message",
alertDecryptError: "error decrypting message",
alertExportRestriction: "export restriction",
alertProtocolVersion: "protocol version not supported",
alertInsufficientSecurity: "insufficient security level",
alertInternalError: "internal error",
alertInappropriateFallback: "inappropriate fallback",
alertUserCanceled: "user canceled",
alertNoRenegotiation: "no renegotiation",
alertMissingExtension: "missing extension",
alertUnsupportedExtension: "unsupported extension",
alertCertificateUnobtainable: "certificate unobtainable",
alertUnrecognizedName: "unrecognized name",
alertBadCertificateStatusResponse: "bad certificate status response",
alertBadCertificateHashValue: "bad certificate hash value",
alertUnknownPSKIdentity: "unknown PSK identity",
alertCertificateRequired: "certificate required",
alertNoApplicationProtocol: "no application protocol",
}
func (e alert) String() string {
s, ok := alertText[e]
if ok {
return "tls: " + s
}
return "tls: alert(" + strconv.Itoa(int(e)) + ")"
}
func (e alert) Error() string {
return e.String()
}

293
vendor/github.com/quic-go/qtls-go1-20/auth.go generated vendored Normal file
View File

@@ -0,0 +1,293 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package qtls
import (
"bytes"
"crypto"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rsa"
"errors"
"fmt"
"hash"
"io"
)
// verifyHandshakeSignature verifies a signature against pre-hashed
// (if required) handshake contents.
func verifyHandshakeSignature(sigType uint8, pubkey crypto.PublicKey, hashFunc crypto.Hash, signed, sig []byte) error {
switch sigType {
case signatureECDSA:
pubKey, ok := pubkey.(*ecdsa.PublicKey)
if !ok {
return fmt.Errorf("expected an ECDSA public key, got %T", pubkey)
}
if !ecdsa.VerifyASN1(pubKey, signed, sig) {
return errors.New("ECDSA verification failure")
}
case signatureEd25519:
pubKey, ok := pubkey.(ed25519.PublicKey)
if !ok {
return fmt.Errorf("expected an Ed25519 public key, got %T", pubkey)
}
if !ed25519.Verify(pubKey, signed, sig) {
return errors.New("Ed25519 verification failure")
}
case signaturePKCS1v15:
pubKey, ok := pubkey.(*rsa.PublicKey)
if !ok {
return fmt.Errorf("expected an RSA public key, got %T", pubkey)
}
if err := rsa.VerifyPKCS1v15(pubKey, hashFunc, signed, sig); err != nil {
return err
}
case signatureRSAPSS:
pubKey, ok := pubkey.(*rsa.PublicKey)
if !ok {
return fmt.Errorf("expected an RSA public key, got %T", pubkey)
}
signOpts := &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash}
if err := rsa.VerifyPSS(pubKey, hashFunc, signed, sig, signOpts); err != nil {
return err
}
default:
return errors.New("internal error: unknown signature type")
}
return nil
}
const (
serverSignatureContext = "TLS 1.3, server CertificateVerify\x00"
clientSignatureContext = "TLS 1.3, client CertificateVerify\x00"
)
var signaturePadding = []byte{
0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
}
// signedMessage returns the pre-hashed (if necessary) message to be signed by
// certificate keys in TLS 1.3. See RFC 8446, Section 4.4.3.
func signedMessage(sigHash crypto.Hash, context string, transcript hash.Hash) []byte {
if sigHash == directSigning {
b := &bytes.Buffer{}
b.Write(signaturePadding)
io.WriteString(b, context)
b.Write(transcript.Sum(nil))
return b.Bytes()
}
h := sigHash.New()
h.Write(signaturePadding)
io.WriteString(h, context)
h.Write(transcript.Sum(nil))
return h.Sum(nil)
}
// typeAndHashFromSignatureScheme returns the corresponding signature type and
// crypto.Hash for a given TLS SignatureScheme.
func typeAndHashFromSignatureScheme(signatureAlgorithm SignatureScheme) (sigType uint8, hash crypto.Hash, err error) {
switch signatureAlgorithm {
case PKCS1WithSHA1, PKCS1WithSHA256, PKCS1WithSHA384, PKCS1WithSHA512:
sigType = signaturePKCS1v15
case PSSWithSHA256, PSSWithSHA384, PSSWithSHA512:
sigType = signatureRSAPSS
case ECDSAWithSHA1, ECDSAWithP256AndSHA256, ECDSAWithP384AndSHA384, ECDSAWithP521AndSHA512:
sigType = signatureECDSA
case Ed25519:
sigType = signatureEd25519
default:
return 0, 0, fmt.Errorf("unsupported signature algorithm: %v", signatureAlgorithm)
}
switch signatureAlgorithm {
case PKCS1WithSHA1, ECDSAWithSHA1:
hash = crypto.SHA1
case PKCS1WithSHA256, PSSWithSHA256, ECDSAWithP256AndSHA256:
hash = crypto.SHA256
case PKCS1WithSHA384, PSSWithSHA384, ECDSAWithP384AndSHA384:
hash = crypto.SHA384
case PKCS1WithSHA512, PSSWithSHA512, ECDSAWithP521AndSHA512:
hash = crypto.SHA512
case Ed25519:
hash = directSigning
default:
return 0, 0, fmt.Errorf("unsupported signature algorithm: %v", signatureAlgorithm)
}
return sigType, hash, nil
}
// legacyTypeAndHashFromPublicKey returns the fixed signature type and crypto.Hash for
// a given public key used with TLS 1.0 and 1.1, before the introduction of
// signature algorithm negotiation.
func legacyTypeAndHashFromPublicKey(pub crypto.PublicKey) (sigType uint8, hash crypto.Hash, err error) {
switch pub.(type) {
case *rsa.PublicKey:
return signaturePKCS1v15, crypto.MD5SHA1, nil
case *ecdsa.PublicKey:
return signatureECDSA, crypto.SHA1, nil
case ed25519.PublicKey:
// RFC 8422 specifies support for Ed25519 in TLS 1.0 and 1.1,
// but it requires holding on to a handshake transcript to do a
// full signature, and not even OpenSSL bothers with the
// complexity, so we can't even test it properly.
return 0, 0, fmt.Errorf("tls: Ed25519 public keys are not supported before TLS 1.2")
default:
return 0, 0, fmt.Errorf("tls: unsupported public key: %T", pub)
}
}
var rsaSignatureSchemes = []struct {
scheme SignatureScheme
minModulusBytes int
maxVersion uint16
}{
// RSA-PSS is used with PSSSaltLengthEqualsHash, and requires
// emLen >= hLen + sLen + 2
{PSSWithSHA256, crypto.SHA256.Size()*2 + 2, VersionTLS13},
{PSSWithSHA384, crypto.SHA384.Size()*2 + 2, VersionTLS13},
{PSSWithSHA512, crypto.SHA512.Size()*2 + 2, VersionTLS13},
// PKCS #1 v1.5 uses prefixes from hashPrefixes in crypto/rsa, and requires
// emLen >= len(prefix) + hLen + 11
// TLS 1.3 dropped support for PKCS #1 v1.5 in favor of RSA-PSS.
{PKCS1WithSHA256, 19 + crypto.SHA256.Size() + 11, VersionTLS12},
{PKCS1WithSHA384, 19 + crypto.SHA384.Size() + 11, VersionTLS12},
{PKCS1WithSHA512, 19 + crypto.SHA512.Size() + 11, VersionTLS12},
{PKCS1WithSHA1, 15 + crypto.SHA1.Size() + 11, VersionTLS12},
}
// signatureSchemesForCertificate returns the list of supported SignatureSchemes
// for a given certificate, based on the public key and the protocol version,
// and optionally filtered by its explicit SupportedSignatureAlgorithms.
//
// This function must be kept in sync with supportedSignatureAlgorithms.
// FIPS filtering is applied in the caller, selectSignatureScheme.
func signatureSchemesForCertificate(version uint16, cert *Certificate) []SignatureScheme {
priv, ok := cert.PrivateKey.(crypto.Signer)
if !ok {
return nil
}
var sigAlgs []SignatureScheme
switch pub := priv.Public().(type) {
case *ecdsa.PublicKey:
if version != VersionTLS13 {
// In TLS 1.2 and earlier, ECDSA algorithms are not
// constrained to a single curve.
sigAlgs = []SignatureScheme{
ECDSAWithP256AndSHA256,
ECDSAWithP384AndSHA384,
ECDSAWithP521AndSHA512,
ECDSAWithSHA1,
}
break
}
switch pub.Curve {
case elliptic.P256():
sigAlgs = []SignatureScheme{ECDSAWithP256AndSHA256}
case elliptic.P384():
sigAlgs = []SignatureScheme{ECDSAWithP384AndSHA384}
case elliptic.P521():
sigAlgs = []SignatureScheme{ECDSAWithP521AndSHA512}
default:
return nil
}
case *rsa.PublicKey:
size := pub.Size()
sigAlgs = make([]SignatureScheme, 0, len(rsaSignatureSchemes))
for _, candidate := range rsaSignatureSchemes {
if size >= candidate.minModulusBytes && version <= candidate.maxVersion {
sigAlgs = append(sigAlgs, candidate.scheme)
}
}
case ed25519.PublicKey:
sigAlgs = []SignatureScheme{Ed25519}
default:
return nil
}
if cert.SupportedSignatureAlgorithms != nil {
var filteredSigAlgs []SignatureScheme
for _, sigAlg := range sigAlgs {
if isSupportedSignatureAlgorithm(sigAlg, cert.SupportedSignatureAlgorithms) {
filteredSigAlgs = append(filteredSigAlgs, sigAlg)
}
}
return filteredSigAlgs
}
return sigAlgs
}
// selectSignatureScheme picks a SignatureScheme from the peer's preference list
// that works with the selected certificate. It's only called for protocol
// versions that support signature algorithms, so TLS 1.2 and 1.3.
func selectSignatureScheme(vers uint16, c *Certificate, peerAlgs []SignatureScheme) (SignatureScheme, error) {
supportedAlgs := signatureSchemesForCertificate(vers, c)
if len(supportedAlgs) == 0 {
return 0, unsupportedCertificateError(c)
}
if len(peerAlgs) == 0 && vers == VersionTLS12 {
// For TLS 1.2, if the client didn't send signature_algorithms then we
// can assume that it supports SHA1. See RFC 5246, Section 7.4.1.4.1.
peerAlgs = []SignatureScheme{PKCS1WithSHA1, ECDSAWithSHA1}
}
// Pick signature scheme in the peer's preference order, as our
// preference order is not configurable.
for _, preferredAlg := range peerAlgs {
if needFIPS() && !isSupportedSignatureAlgorithm(preferredAlg, fipsSupportedSignatureAlgorithms) {
continue
}
if isSupportedSignatureAlgorithm(preferredAlg, supportedAlgs) {
return preferredAlg, nil
}
}
return 0, errors.New("tls: peer doesn't support any of the certificate's signature algorithms")
}
// unsupportedCertificateError returns a helpful error for certificates with
// an unsupported private key.
func unsupportedCertificateError(cert *Certificate) error {
switch cert.PrivateKey.(type) {
case rsa.PrivateKey, ecdsa.PrivateKey:
return fmt.Errorf("tls: unsupported certificate: private key is %T, expected *%T",
cert.PrivateKey, cert.PrivateKey)
case *ed25519.PrivateKey:
return fmt.Errorf("tls: unsupported certificate: private key is *ed25519.PrivateKey, expected ed25519.PrivateKey")
}
signer, ok := cert.PrivateKey.(crypto.Signer)
if !ok {
return fmt.Errorf("tls: certificate private key (%T) does not implement crypto.Signer",
cert.PrivateKey)
}
switch pub := signer.Public().(type) {
case *ecdsa.PublicKey:
switch pub.Curve {
case elliptic.P256():
case elliptic.P384():
case elliptic.P521():
default:
return fmt.Errorf("tls: unsupported certificate curve (%s)", pub.Curve.Params().Name)
}
case *rsa.PublicKey:
return fmt.Errorf("tls: certificate RSA key size too small for supported signature algorithms")
case ed25519.PublicKey:
default:
return fmt.Errorf("tls: unsupported certificate key (%T)", pub)
}
if cert.SupportedSignatureAlgorithms != nil {
return fmt.Errorf("tls: peer doesn't support the certificate custom signature algorithms")
}
return fmt.Errorf("tls: internal error: unsupported key (%T)", cert.PrivateKey)
}

95
vendor/github.com/quic-go/qtls-go1-20/cache.go generated vendored Normal file
View File

@@ -0,0 +1,95 @@
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package qtls
import (
"crypto/x509"
"runtime"
"sync"
"sync/atomic"
)
type cacheEntry struct {
refs atomic.Int64
cert *x509.Certificate
}
// certCache implements an intern table for reference counted x509.Certificates,
// implemented in a similar fashion to BoringSSL's CRYPTO_BUFFER_POOL. This
// allows for a single x509.Certificate to be kept in memory and referenced from
// multiple Conns. Returned references should not be mutated by callers. Certificates
// are still safe to use after they are removed from the cache.
//
// Certificates are returned wrapped in a activeCert struct that should be held by
// the caller. When references to the activeCert are freed, the number of references
// to the certificate in the cache is decremented. Once the number of references
// reaches zero, the entry is evicted from the cache.
//
// The main difference between this implementation and CRYPTO_BUFFER_POOL is that
// CRYPTO_BUFFER_POOL is a more generic structure which supports blobs of data,
// rather than specific structures. Since we only care about x509.Certificates,
// certCache is implemented as a specific cache, rather than a generic one.
//
// See https://boringssl.googlesource.com/boringssl/+/master/include/openssl/pool.h
// and https://boringssl.googlesource.com/boringssl/+/master/crypto/pool/pool.c
// for the BoringSSL reference.
type certCache struct {
sync.Map
}
var clientCertCache = new(certCache)
// activeCert is a handle to a certificate held in the cache. Once there are
// no alive activeCerts for a given certificate, the certificate is removed
// from the cache by a finalizer.
type activeCert struct {
cert *x509.Certificate
}
// active increments the number of references to the entry, wraps the
// certificate in the entry in a activeCert, and sets the finalizer.
//
// Note that there is a race between active and the finalizer set on the
// returned activeCert, triggered if active is called after the ref count is
// decremented such that refs may be > 0 when evict is called. We consider this
// safe, since the caller holding an activeCert for an entry that is no longer
// in the cache is fine, with the only side effect being the memory overhead of
// there being more than one distinct reference to a certificate alive at once.
func (cc *certCache) active(e *cacheEntry) *activeCert {
e.refs.Add(1)
a := &activeCert{e.cert}
runtime.SetFinalizer(a, func(_ *activeCert) {
if e.refs.Add(-1) == 0 {
cc.evict(e)
}
})
return a
}
// evict removes a cacheEntry from the cache.
func (cc *certCache) evict(e *cacheEntry) {
cc.Delete(string(e.cert.Raw))
}
// newCert returns a x509.Certificate parsed from der. If there is already a copy
// of the certificate in the cache, a reference to the existing certificate will
// be returned. Otherwise, a fresh certificate will be added to the cache, and
// the reference returned. The returned reference should not be mutated.
func (cc *certCache) newCert(der []byte) (*activeCert, error) {
if entry, ok := cc.Load(string(der)); ok {
return cc.active(entry.(*cacheEntry)), nil
}
cert, err := x509.ParseCertificate(der)
if err != nil {
return nil, err
}
entry := &cacheEntry{cert: cert}
if entry, loaded := cc.LoadOrStore(string(der), entry); loaded {
return cc.active(entry.(*cacheEntry)), nil
}
return cc.active(entry), nil
}

691
vendor/github.com/quic-go/qtls-go1-20/cipher_suites.go generated vendored Normal file
View File

@@ -0,0 +1,691 @@
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package qtls
import (
"crypto"
"crypto/aes"
"crypto/cipher"
"crypto/des"
"crypto/hmac"
"crypto/rc4"
"crypto/sha1"
"crypto/sha256"
"fmt"
"hash"
"runtime"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/sys/cpu"
)
// CipherSuite is a TLS cipher suite. Note that most functions in this package
// accept and expose cipher suite IDs instead of this type.
type CipherSuite struct {
ID uint16
Name string
// Supported versions is the list of TLS protocol versions that can
// negotiate this cipher suite.
SupportedVersions []uint16
// Insecure is true if the cipher suite has known security issues
// due to its primitives, design, or implementation.
Insecure bool
}
var (
supportedUpToTLS12 = []uint16{VersionTLS10, VersionTLS11, VersionTLS12}
supportedOnlyTLS12 = []uint16{VersionTLS12}
supportedOnlyTLS13 = []uint16{VersionTLS13}
)
// CipherSuites returns a list of cipher suites currently implemented by this
// package, excluding those with security issues, which are returned by
// InsecureCipherSuites.
//
// The list is sorted by ID. Note that the default cipher suites selected by
// this package might depend on logic that can't be captured by a static list,
// and might not match those returned by this function.
func CipherSuites() []*CipherSuite {
return []*CipherSuite{
{TLS_RSA_WITH_AES_128_CBC_SHA, "TLS_RSA_WITH_AES_128_CBC_SHA", supportedUpToTLS12, false},
{TLS_RSA_WITH_AES_256_CBC_SHA, "TLS_RSA_WITH_AES_256_CBC_SHA", supportedUpToTLS12, false},
{TLS_RSA_WITH_AES_128_GCM_SHA256, "TLS_RSA_WITH_AES_128_GCM_SHA256", supportedOnlyTLS12, false},
{TLS_RSA_WITH_AES_256_GCM_SHA384, "TLS_RSA_WITH_AES_256_GCM_SHA384", supportedOnlyTLS12, false},
{TLS_AES_128_GCM_SHA256, "TLS_AES_128_GCM_SHA256", supportedOnlyTLS13, false},
{TLS_AES_256_GCM_SHA384, "TLS_AES_256_GCM_SHA384", supportedOnlyTLS13, false},
{TLS_CHACHA20_POLY1305_SHA256, "TLS_CHACHA20_POLY1305_SHA256", supportedOnlyTLS13, false},
{TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA", supportedUpToTLS12, false},
{TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA", supportedUpToTLS12, false},
{TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA", supportedUpToTLS12, false},
{TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA", supportedUpToTLS12, false},
{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256", supportedOnlyTLS12, false},
{TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384", supportedOnlyTLS12, false},
{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", supportedOnlyTLS12, false},
{TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384", supportedOnlyTLS12, false},
{TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256", supportedOnlyTLS12, false},
{TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256", supportedOnlyTLS12, false},
}
}
// InsecureCipherSuites returns a list of cipher suites currently implemented by
// this package and which have security issues.
//
// Most applications should not use the cipher suites in this list, and should
// only use those returned by CipherSuites.
func InsecureCipherSuites() []*CipherSuite {
// This list includes RC4, CBC_SHA256, and 3DES cipher suites. See
// cipherSuitesPreferenceOrder for details.
return []*CipherSuite{
{TLS_RSA_WITH_RC4_128_SHA, "TLS_RSA_WITH_RC4_128_SHA", supportedUpToTLS12, true},
{TLS_RSA_WITH_3DES_EDE_CBC_SHA, "TLS_RSA_WITH_3DES_EDE_CBC_SHA", supportedUpToTLS12, true},
{TLS_RSA_WITH_AES_128_CBC_SHA256, "TLS_RSA_WITH_AES_128_CBC_SHA256", supportedOnlyTLS12, true},
{TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, "TLS_ECDHE_ECDSA_WITH_RC4_128_SHA", supportedUpToTLS12, true},
{TLS_ECDHE_RSA_WITH_RC4_128_SHA, "TLS_ECDHE_RSA_WITH_RC4_128_SHA", supportedUpToTLS12, true},
{TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, "TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA", supportedUpToTLS12, true},
{TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256", supportedOnlyTLS12, true},
{TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256", supportedOnlyTLS12, true},
}
}
// CipherSuiteName returns the standard name for the passed cipher suite ID
// (e.g. "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256"), or a fallback representation
// of the ID value if the cipher suite is not implemented by this package.
func CipherSuiteName(id uint16) string {
for _, c := range CipherSuites() {
if c.ID == id {
return c.Name
}
}
for _, c := range InsecureCipherSuites() {
if c.ID == id {
return c.Name
}
}
return fmt.Sprintf("0x%04X", id)
}
const (
// suiteECDHE indicates that the cipher suite involves elliptic curve
// Diffie-Hellman. This means that it should only be selected when the
// client indicates that it supports ECC with a curve and point format
// that we're happy with.
suiteECDHE = 1 << iota
// suiteECSign indicates that the cipher suite involves an ECDSA or
// EdDSA signature and therefore may only be selected when the server's
// certificate is ECDSA or EdDSA. If this is not set then the cipher suite
// is RSA based.
suiteECSign
// suiteTLS12 indicates that the cipher suite should only be advertised
// and accepted when using TLS 1.2.
suiteTLS12
// suiteSHA384 indicates that the cipher suite uses SHA384 as the
// handshake hash.
suiteSHA384
)
// A cipherSuite is a TLS 1.01.2 cipher suite, and defines the key exchange
// mechanism, as well as the cipher+MAC pair or the AEAD.
type cipherSuite struct {
id uint16
// the lengths, in bytes, of the key material needed for each component.
keyLen int
macLen int
ivLen int
ka func(version uint16) keyAgreement
// flags is a bitmask of the suite* values, above.
flags int
cipher func(key, iv []byte, isRead bool) any
mac func(key []byte) hash.Hash
aead func(key, fixedNonce []byte) aead
}
var cipherSuites = []*cipherSuite{ // TODO: replace with a map, since the order doesn't matter.
{TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, 32, 0, 12, ecdheRSAKA, suiteECDHE | suiteTLS12, nil, nil, aeadChaCha20Poly1305},
{TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, 32, 0, 12, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12, nil, nil, aeadChaCha20Poly1305},
{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, ecdheRSAKA, suiteECDHE | suiteTLS12, nil, nil, aeadAESGCM},
{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12, nil, nil, aeadAESGCM},
{TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, 32, 0, 4, ecdheRSAKA, suiteECDHE | suiteTLS12 | suiteSHA384, nil, nil, aeadAESGCM},
{TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, 32, 0, 4, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12 | suiteSHA384, nil, nil, aeadAESGCM},
{TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, 16, 32, 16, ecdheRSAKA, suiteECDHE | suiteTLS12, cipherAES, macSHA256, nil},
{TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, 16, 20, 16, ecdheRSAKA, suiteECDHE, cipherAES, macSHA1, nil},
{TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, 16, 32, 16, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12, cipherAES, macSHA256, nil},
{TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, 16, 20, 16, ecdheECDSAKA, suiteECDHE | suiteECSign, cipherAES, macSHA1, nil},
{TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, 32, 20, 16, ecdheRSAKA, suiteECDHE, cipherAES, macSHA1, nil},
{TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, 32, 20, 16, ecdheECDSAKA, suiteECDHE | suiteECSign, cipherAES, macSHA1, nil},
{TLS_RSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, rsaKA, suiteTLS12, nil, nil, aeadAESGCM},
{TLS_RSA_WITH_AES_256_GCM_SHA384, 32, 0, 4, rsaKA, suiteTLS12 | suiteSHA384, nil, nil, aeadAESGCM},
{TLS_RSA_WITH_AES_128_CBC_SHA256, 16, 32, 16, rsaKA, suiteTLS12, cipherAES, macSHA256, nil},
{TLS_RSA_WITH_AES_128_CBC_SHA, 16, 20, 16, rsaKA, 0, cipherAES, macSHA1, nil},
{TLS_RSA_WITH_AES_256_CBC_SHA, 32, 20, 16, rsaKA, 0, cipherAES, macSHA1, nil},
{TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, 24, 20, 8, ecdheRSAKA, suiteECDHE, cipher3DES, macSHA1, nil},
{TLS_RSA_WITH_3DES_EDE_CBC_SHA, 24, 20, 8, rsaKA, 0, cipher3DES, macSHA1, nil},
{TLS_RSA_WITH_RC4_128_SHA, 16, 20, 0, rsaKA, 0, cipherRC4, macSHA1, nil},
{TLS_ECDHE_RSA_WITH_RC4_128_SHA, 16, 20, 0, ecdheRSAKA, suiteECDHE, cipherRC4, macSHA1, nil},
{TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, 16, 20, 0, ecdheECDSAKA, suiteECDHE | suiteECSign, cipherRC4, macSHA1, nil},
}
// selectCipherSuite returns the first TLS 1.01.2 cipher suite from ids which
// is also in supportedIDs and passes the ok filter.
func selectCipherSuite(ids, supportedIDs []uint16, ok func(*cipherSuite) bool) *cipherSuite {
for _, id := range ids {
candidate := cipherSuiteByID(id)
if candidate == nil || !ok(candidate) {
continue
}
for _, suppID := range supportedIDs {
if id == suppID {
return candidate
}
}
}
return nil
}
// A cipherSuiteTLS13 defines only the pair of the AEAD algorithm and hash
// algorithm to be used with HKDF. See RFC 8446, Appendix B.4.
type cipherSuiteTLS13 struct {
id uint16
keyLen int
aead func(key, fixedNonce []byte) aead
hash crypto.Hash
}
var cipherSuitesTLS13 = []*cipherSuiteTLS13{ // TODO: replace with a map.
{TLS_AES_128_GCM_SHA256, 16, aeadAESGCMTLS13, crypto.SHA256},
{TLS_CHACHA20_POLY1305_SHA256, 32, aeadChaCha20Poly1305, crypto.SHA256},
{TLS_AES_256_GCM_SHA384, 32, aeadAESGCMTLS13, crypto.SHA384},
}
// cipherSuitesPreferenceOrder is the order in which we'll select (on the
// server) or advertise (on the client) TLS 1.01.2 cipher suites.
//
// Cipher suites are filtered but not reordered based on the application and
// peer's preferences, meaning we'll never select a suite lower in this list if
// any higher one is available. This makes it more defensible to keep weaker
// cipher suites enabled, especially on the server side where we get the last
// word, since there are no known downgrade attacks on cipher suites selection.
//
// The list is sorted by applying the following priority rules, stopping at the
// first (most important) applicable one:
//
// - Anything else comes before RC4
//
// RC4 has practically exploitable biases. See https://www.rc4nomore.com.
//
// - Anything else comes before CBC_SHA256
//
// SHA-256 variants of the CBC ciphersuites don't implement any Lucky13
// countermeasures. See http://www.isg.rhul.ac.uk/tls/Lucky13.html and
// https://www.imperialviolet.org/2013/02/04/luckythirteen.html.
//
// - Anything else comes before 3DES
//
// 3DES has 64-bit blocks, which makes it fundamentally susceptible to
// birthday attacks. See https://sweet32.info.
//
// - ECDHE comes before anything else
//
// Once we got the broken stuff out of the way, the most important
// property a cipher suite can have is forward secrecy. We don't
// implement FFDHE, so that means ECDHE.
//
// - AEADs come before CBC ciphers
//
// Even with Lucky13 countermeasures, MAC-then-Encrypt CBC cipher suites
// are fundamentally fragile, and suffered from an endless sequence of
// padding oracle attacks. See https://eprint.iacr.org/2015/1129,
// https://www.imperialviolet.org/2014/12/08/poodleagain.html, and
// https://blog.cloudflare.com/yet-another-padding-oracle-in-openssl-cbc-ciphersuites/.
//
// - AES comes before ChaCha20
//
// When AES hardware is available, AES-128-GCM and AES-256-GCM are faster
// than ChaCha20Poly1305.
//
// When AES hardware is not available, AES-128-GCM is one or more of: much
// slower, way more complex, and less safe (because not constant time)
// than ChaCha20Poly1305.
//
// We use this list if we think both peers have AES hardware, and
// cipherSuitesPreferenceOrderNoAES otherwise.
//
// - AES-128 comes before AES-256
//
// The only potential advantages of AES-256 are better multi-target
// margins, and hypothetical post-quantum properties. Neither apply to
// TLS, and AES-256 is slower due to its four extra rounds (which don't
// contribute to the advantages above).
//
// - ECDSA comes before RSA
//
// The relative order of ECDSA and RSA cipher suites doesn't matter,
// as they depend on the certificate. Pick one to get a stable order.
var cipherSuitesPreferenceOrder = []uint16{
// AEADs w/ ECDHE
TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
// CBC w/ ECDHE
TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
// AEADs w/o ECDHE
TLS_RSA_WITH_AES_128_GCM_SHA256,
TLS_RSA_WITH_AES_256_GCM_SHA384,
// CBC w/o ECDHE
TLS_RSA_WITH_AES_128_CBC_SHA,
TLS_RSA_WITH_AES_256_CBC_SHA,
// 3DES
TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
TLS_RSA_WITH_3DES_EDE_CBC_SHA,
// CBC_SHA256
TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256,
TLS_RSA_WITH_AES_128_CBC_SHA256,
// RC4
TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA,
TLS_RSA_WITH_RC4_128_SHA,
}
var cipherSuitesPreferenceOrderNoAES = []uint16{
// ChaCha20Poly1305
TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
// AES-GCM w/ ECDHE
TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
// The rest of cipherSuitesPreferenceOrder.
TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
TLS_RSA_WITH_AES_128_GCM_SHA256,
TLS_RSA_WITH_AES_256_GCM_SHA384,
TLS_RSA_WITH_AES_128_CBC_SHA,
TLS_RSA_WITH_AES_256_CBC_SHA,
TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
TLS_RSA_WITH_3DES_EDE_CBC_SHA,
TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256,
TLS_RSA_WITH_AES_128_CBC_SHA256,
TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA,
TLS_RSA_WITH_RC4_128_SHA,
}
// disabledCipherSuites are not used unless explicitly listed in
// Config.CipherSuites. They MUST be at the end of cipherSuitesPreferenceOrder.
var disabledCipherSuites = []uint16{
// CBC_SHA256
TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256,
TLS_RSA_WITH_AES_128_CBC_SHA256,
// RC4
TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA,
TLS_RSA_WITH_RC4_128_SHA,
}
var (
defaultCipherSuitesLen = len(cipherSuitesPreferenceOrder) - len(disabledCipherSuites)
defaultCipherSuites = cipherSuitesPreferenceOrder[:defaultCipherSuitesLen]
)
// defaultCipherSuitesTLS13 is also the preference order, since there are no
// disabled by default TLS 1.3 cipher suites. The same AES vs ChaCha20 logic as
// cipherSuitesPreferenceOrder applies.
var defaultCipherSuitesTLS13 = []uint16{
TLS_AES_128_GCM_SHA256,
TLS_AES_256_GCM_SHA384,
TLS_CHACHA20_POLY1305_SHA256,
}
var defaultCipherSuitesTLS13NoAES = []uint16{
TLS_CHACHA20_POLY1305_SHA256,
TLS_AES_128_GCM_SHA256,
TLS_AES_256_GCM_SHA384,
}
var (
hasGCMAsmAMD64 = cpu.X86.HasAES && cpu.X86.HasPCLMULQDQ
hasGCMAsmARM64 = cpu.ARM64.HasAES && cpu.ARM64.HasPMULL
// Keep in sync with crypto/aes/cipher_s390x.go.
hasGCMAsmS390X = cpu.S390X.HasAES && cpu.S390X.HasAESCBC && cpu.S390X.HasAESCTR &&
(cpu.S390X.HasGHASH || cpu.S390X.HasAESGCM)
hasAESGCMHardwareSupport = runtime.GOARCH == "amd64" && hasGCMAsmAMD64 ||
runtime.GOARCH == "arm64" && hasGCMAsmARM64 ||
runtime.GOARCH == "s390x" && hasGCMAsmS390X
)
var aesgcmCiphers = map[uint16]bool{
// TLS 1.2
TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256: true,
TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384: true,
TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256: true,
TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384: true,
// TLS 1.3
TLS_AES_128_GCM_SHA256: true,
TLS_AES_256_GCM_SHA384: true,
}
var nonAESGCMAEADCiphers = map[uint16]bool{
// TLS 1.2
TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305: true,
TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305: true,
// TLS 1.3
TLS_CHACHA20_POLY1305_SHA256: true,
}
// aesgcmPreferred returns whether the first known cipher in the preference list
// is an AES-GCM cipher, implying the peer has hardware support for it.
func aesgcmPreferred(ciphers []uint16) bool {
for _, cID := range ciphers {
if c := cipherSuiteByID(cID); c != nil {
return aesgcmCiphers[cID]
}
if c := cipherSuiteTLS13ByID(cID); c != nil {
return aesgcmCiphers[cID]
}
}
return false
}
func cipherRC4(key, iv []byte, isRead bool) any {
cipher, _ := rc4.NewCipher(key)
return cipher
}
func cipher3DES(key, iv []byte, isRead bool) any {
block, _ := des.NewTripleDESCipher(key)
if isRead {
return cipher.NewCBCDecrypter(block, iv)
}
return cipher.NewCBCEncrypter(block, iv)
}
func cipherAES(key, iv []byte, isRead bool) any {
block, _ := aes.NewCipher(key)
if isRead {
return cipher.NewCBCDecrypter(block, iv)
}
return cipher.NewCBCEncrypter(block, iv)
}
// macSHA1 returns a SHA-1 based constant time MAC.
func macSHA1(key []byte) hash.Hash {
h := sha1.New
h = newConstantTimeHash(h)
return hmac.New(h, key)
}
// macSHA256 returns a SHA-256 based MAC. This is only supported in TLS 1.2 and
// is currently only used in disabled-by-default cipher suites.
func macSHA256(key []byte) hash.Hash {
return hmac.New(sha256.New, key)
}
type aead interface {
cipher.AEAD
// explicitNonceLen returns the number of bytes of explicit nonce
// included in each record. This is eight for older AEADs and
// zero for modern ones.
explicitNonceLen() int
}
const (
aeadNonceLength = 12
noncePrefixLength = 4
)
// prefixNonceAEAD wraps an AEAD and prefixes a fixed portion of the nonce to
// each call.
type prefixNonceAEAD struct {
// nonce contains the fixed part of the nonce in the first four bytes.
nonce [aeadNonceLength]byte
aead cipher.AEAD
}
func (f *prefixNonceAEAD) NonceSize() int { return aeadNonceLength - noncePrefixLength }
func (f *prefixNonceAEAD) Overhead() int { return f.aead.Overhead() }
func (f *prefixNonceAEAD) explicitNonceLen() int { return f.NonceSize() }
func (f *prefixNonceAEAD) Seal(out, nonce, plaintext, additionalData []byte) []byte {
copy(f.nonce[4:], nonce)
return f.aead.Seal(out, f.nonce[:], plaintext, additionalData)
}
func (f *prefixNonceAEAD) Open(out, nonce, ciphertext, additionalData []byte) ([]byte, error) {
copy(f.nonce[4:], nonce)
return f.aead.Open(out, f.nonce[:], ciphertext, additionalData)
}
// xorNonceAEAD wraps an AEAD by XORing in a fixed pattern to the nonce
// before each call.
type xorNonceAEAD struct {
nonceMask [aeadNonceLength]byte
aead cipher.AEAD
}
func (f *xorNonceAEAD) NonceSize() int { return 8 } // 64-bit sequence number
func (f *xorNonceAEAD) Overhead() int { return f.aead.Overhead() }
func (f *xorNonceAEAD) explicitNonceLen() int { return 0 }
func (f *xorNonceAEAD) Seal(out, nonce, plaintext, additionalData []byte) []byte {
for i, b := range nonce {
f.nonceMask[4+i] ^= b
}
result := f.aead.Seal(out, f.nonceMask[:], plaintext, additionalData)
for i, b := range nonce {
f.nonceMask[4+i] ^= b
}
return result
}
func (f *xorNonceAEAD) Open(out, nonce, ciphertext, additionalData []byte) ([]byte, error) {
for i, b := range nonce {
f.nonceMask[4+i] ^= b
}
result, err := f.aead.Open(out, f.nonceMask[:], ciphertext, additionalData)
for i, b := range nonce {
f.nonceMask[4+i] ^= b
}
return result, err
}
func aeadAESGCM(key, noncePrefix []byte) aead {
if len(noncePrefix) != noncePrefixLength {
panic("tls: internal error: wrong nonce length")
}
aes, err := aes.NewCipher(key)
if err != nil {
panic(err)
}
var aead cipher.AEAD
aead, err = cipher.NewGCM(aes)
if err != nil {
panic(err)
}
ret := &prefixNonceAEAD{aead: aead}
copy(ret.nonce[:], noncePrefix)
return ret
}
func aeadAESGCMTLS13(key, nonceMask []byte) aead {
if len(nonceMask) != aeadNonceLength {
panic("tls: internal error: wrong nonce length")
}
aes, err := aes.NewCipher(key)
if err != nil {
panic(err)
}
aead, err := cipher.NewGCM(aes)
if err != nil {
panic(err)
}
ret := &xorNonceAEAD{aead: aead}
copy(ret.nonceMask[:], nonceMask)
return ret
}
func aeadChaCha20Poly1305(key, nonceMask []byte) aead {
if len(nonceMask) != aeadNonceLength {
panic("tls: internal error: wrong nonce length")
}
aead, err := chacha20poly1305.New(key)
if err != nil {
panic(err)
}
ret := &xorNonceAEAD{aead: aead}
copy(ret.nonceMask[:], nonceMask)
return ret
}
type constantTimeHash interface {
hash.Hash
ConstantTimeSum(b []byte) []byte
}
// cthWrapper wraps any hash.Hash that implements ConstantTimeSum, and replaces
// with that all calls to Sum. It's used to obtain a ConstantTimeSum-based HMAC.
type cthWrapper struct {
h constantTimeHash
}
func (c *cthWrapper) Size() int { return c.h.Size() }
func (c *cthWrapper) BlockSize() int { return c.h.BlockSize() }
func (c *cthWrapper) Reset() { c.h.Reset() }
func (c *cthWrapper) Write(p []byte) (int, error) { return c.h.Write(p) }
func (c *cthWrapper) Sum(b []byte) []byte { return c.h.ConstantTimeSum(b) }
func newConstantTimeHash(h func() hash.Hash) func() hash.Hash {
return func() hash.Hash {
return &cthWrapper{h().(constantTimeHash)}
}
}
// tls10MAC implements the TLS 1.0 MAC function. RFC 2246, Section 6.2.3.
func tls10MAC(h hash.Hash, out, seq, header, data, extra []byte) []byte {
h.Reset()
h.Write(seq)
h.Write(header)
h.Write(data)
res := h.Sum(out)
if extra != nil {
h.Write(extra)
}
return res
}
func rsaKA(version uint16) keyAgreement {
return rsaKeyAgreement{}
}
func ecdheECDSAKA(version uint16) keyAgreement {
return &ecdheKeyAgreement{
isRSA: false,
version: version,
}
}
func ecdheRSAKA(version uint16) keyAgreement {
return &ecdheKeyAgreement{
isRSA: true,
version: version,
}
}
// mutualCipherSuite returns a cipherSuite given a list of supported
// ciphersuites and the id requested by the peer.
func mutualCipherSuite(have []uint16, want uint16) *cipherSuite {
for _, id := range have {
if id == want {
return cipherSuiteByID(id)
}
}
return nil
}
func cipherSuiteByID(id uint16) *cipherSuite {
for _, cipherSuite := range cipherSuites {
if cipherSuite.id == id {
return cipherSuite
}
}
return nil
}
func mutualCipherSuiteTLS13(have []uint16, want uint16) *cipherSuiteTLS13 {
for _, id := range have {
if id == want {
return cipherSuiteTLS13ByID(id)
}
}
return nil
}
func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13 {
for _, cipherSuite := range cipherSuitesTLS13 {
if cipherSuite.id == id {
return cipherSuite
}
}
return nil
}
// A list of cipher suite IDs that are, or have been, implemented by this
// package.
//
// See https://www.iana.org/assignments/tls-parameters/tls-parameters.xml
const (
// TLS 1.0 - 1.2 cipher suites.
TLS_RSA_WITH_RC4_128_SHA uint16 = 0x0005
TLS_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0x000a
TLS_RSA_WITH_AES_128_CBC_SHA uint16 = 0x002f
TLS_RSA_WITH_AES_256_CBC_SHA uint16 = 0x0035
TLS_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0x003c
TLS_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0x009c
TLS_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0x009d
TLS_ECDHE_ECDSA_WITH_RC4_128_SHA uint16 = 0xc007
TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA uint16 = 0xc009
TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA uint16 = 0xc00a
TLS_ECDHE_RSA_WITH_RC4_128_SHA uint16 = 0xc011
TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xc012
TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA uint16 = 0xc013
TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA uint16 = 0xc014
TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256 uint16 = 0xc023
TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0xc027
TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0xc02f
TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 uint16 = 0xc02b
TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0xc030
TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 uint16 = 0xc02c
TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xcca8
TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xcca9
// TLS 1.3 cipher suites.
TLS_AES_128_GCM_SHA256 uint16 = 0x1301
TLS_AES_256_GCM_SHA384 uint16 = 0x1302
TLS_CHACHA20_POLY1305_SHA256 uint16 = 0x1303
// TLS_FALLBACK_SCSV isn't a standard cipher suite but an indicator
// that the client is doing version fallback. See RFC 7507.
TLS_FALLBACK_SCSV uint16 = 0x5600
// Legacy names for the corresponding cipher suites with the correct _SHA256
// suffix, retained for backward compatibility.
TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305 = TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256
TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305 = TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256
)

1454
vendor/github.com/quic-go/qtls-go1-20/common.go generated vendored Normal file

File diff suppressed because it is too large Load Diff

1643
vendor/github.com/quic-go/qtls-go1-20/conn.go generated vendored Normal file

File diff suppressed because it is too large Load Diff

1130
vendor/github.com/quic-go/qtls-go1-20/handshake_client.go generated vendored Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,782 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package qtls
import (
"bytes"
"context"
"crypto"
"crypto/ecdh"
"crypto/hmac"
"crypto/rsa"
"encoding/binary"
"errors"
"hash"
"time"
"golang.org/x/crypto/cryptobyte"
)
type clientHandshakeStateTLS13 struct {
c *Conn
ctx context.Context
serverHello *serverHelloMsg
hello *clientHelloMsg
ecdheKey *ecdh.PrivateKey
session *clientSessionState
earlySecret []byte
binderKey []byte
certReq *certificateRequestMsgTLS13
usingPSK bool
sentDummyCCS bool
suite *cipherSuiteTLS13
transcript hash.Hash
masterSecret []byte
trafficSecret []byte // client_application_traffic_secret_0
}
// handshake requires hs.c, hs.hello, hs.serverHello, hs.ecdheKey, and,
// optionally, hs.session, hs.earlySecret and hs.binderKey to be set.
func (hs *clientHandshakeStateTLS13) handshake() error {
c := hs.c
if needFIPS() {
return errors.New("tls: internal error: TLS 1.3 reached in FIPS mode")
}
// The server must not select TLS 1.3 in a renegotiation. See RFC 8446,
// sections 4.1.2 and 4.1.3.
if c.handshakes > 0 {
c.sendAlert(alertProtocolVersion)
return errors.New("tls: server selected TLS 1.3 in a renegotiation")
}
// Consistency check on the presence of a keyShare and its parameters.
if hs.ecdheKey == nil || len(hs.hello.keyShares) != 1 {
return c.sendAlert(alertInternalError)
}
if err := hs.checkServerHelloOrHRR(); err != nil {
return err
}
hs.transcript = hs.suite.hash.New()
if err := transcriptMsg(hs.hello, hs.transcript); err != nil {
return err
}
if bytes.Equal(hs.serverHello.random, helloRetryRequestRandom) {
if err := hs.sendDummyChangeCipherSpec(); err != nil {
return err
}
if err := hs.processHelloRetryRequest(); err != nil {
return err
}
}
if err := transcriptMsg(hs.serverHello, hs.transcript); err != nil {
return err
}
c.buffering = true
if err := hs.processServerHello(); err != nil {
return err
}
if err := hs.sendDummyChangeCipherSpec(); err != nil {
return err
}
if err := hs.establishHandshakeKeys(); err != nil {
return err
}
if err := hs.readServerParameters(); err != nil {
return err
}
if err := hs.readServerCertificate(); err != nil {
return err
}
if err := hs.readServerFinished(); err != nil {
return err
}
if err := hs.sendClientCertificate(); err != nil {
return err
}
if err := hs.sendClientFinished(); err != nil {
return err
}
if _, err := c.flush(); err != nil {
return err
}
c.isHandshakeComplete.Store(true)
return nil
}
// checkServerHelloOrHRR does validity checks that apply to both ServerHello and
// HelloRetryRequest messages. It sets hs.suite.
func (hs *clientHandshakeStateTLS13) checkServerHelloOrHRR() error {
c := hs.c
if hs.serverHello.supportedVersion == 0 {
c.sendAlert(alertMissingExtension)
return errors.New("tls: server selected TLS 1.3 using the legacy version field")
}
if hs.serverHello.supportedVersion != VersionTLS13 {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server selected an invalid version after a HelloRetryRequest")
}
if hs.serverHello.vers != VersionTLS12 {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server sent an incorrect legacy version")
}
if hs.serverHello.ocspStapling ||
hs.serverHello.ticketSupported ||
hs.serverHello.secureRenegotiationSupported ||
len(hs.serverHello.secureRenegotiation) != 0 ||
len(hs.serverHello.alpnProtocol) != 0 ||
len(hs.serverHello.scts) != 0 {
c.sendAlert(alertUnsupportedExtension)
return errors.New("tls: server sent a ServerHello extension forbidden in TLS 1.3")
}
if !bytes.Equal(hs.hello.sessionId, hs.serverHello.sessionId) {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server did not echo the legacy session ID")
}
if hs.serverHello.compressionMethod != compressionNone {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server selected unsupported compression format")
}
selectedSuite := mutualCipherSuiteTLS13(hs.hello.cipherSuites, hs.serverHello.cipherSuite)
if hs.suite != nil && selectedSuite != hs.suite {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server changed cipher suite after a HelloRetryRequest")
}
if selectedSuite == nil {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server chose an unconfigured cipher suite")
}
hs.suite = selectedSuite
c.cipherSuite = hs.suite.id
return nil
}
// sendDummyChangeCipherSpec sends a ChangeCipherSpec record for compatibility
// with middleboxes that didn't implement TLS correctly. See RFC 8446, Appendix D.4.
func (hs *clientHandshakeStateTLS13) sendDummyChangeCipherSpec() error {
if hs.c.quic != nil {
return nil
}
if hs.sentDummyCCS {
return nil
}
hs.sentDummyCCS = true
return hs.c.writeChangeCipherRecord()
}
// processHelloRetryRequest handles the HRR in hs.serverHello, modifies and
// resends hs.hello, and reads the new ServerHello into hs.serverHello.
func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error {
c := hs.c
// The first ClientHello gets double-hashed into the transcript upon a
// HelloRetryRequest. (The idea is that the server might offload transcript
// storage to the client in the cookie.) See RFC 8446, Section 4.4.1.
chHash := hs.transcript.Sum(nil)
hs.transcript.Reset()
hs.transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))})
hs.transcript.Write(chHash)
if err := transcriptMsg(hs.serverHello, hs.transcript); err != nil {
return err
}
// The only HelloRetryRequest extensions we support are key_share and
// cookie, and clients must abort the handshake if the HRR would not result
// in any change in the ClientHello.
if hs.serverHello.selectedGroup == 0 && hs.serverHello.cookie == nil {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server sent an unnecessary HelloRetryRequest message")
}
if hs.serverHello.cookie != nil {
hs.hello.cookie = hs.serverHello.cookie
}
if hs.serverHello.serverShare.group != 0 {
c.sendAlert(alertDecodeError)
return errors.New("tls: received malformed key_share extension")
}
// If the server sent a key_share extension selecting a group, ensure it's
// a group we advertised but did not send a key share for, and send a key
// share for it this time.
if curveID := hs.serverHello.selectedGroup; curveID != 0 {
curveOK := false
for _, id := range hs.hello.supportedCurves {
if id == curveID {
curveOK = true
break
}
}
if !curveOK {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server selected unsupported group")
}
if sentID, _ := curveIDForCurve(hs.ecdheKey.Curve()); sentID == curveID {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server sent an unnecessary HelloRetryRequest key_share")
}
if _, ok := curveForCurveID(curveID); !ok {
c.sendAlert(alertInternalError)
return errors.New("tls: CurvePreferences includes unsupported curve")
}
key, err := generateECDHEKey(c.config.rand(), curveID)
if err != nil {
c.sendAlert(alertInternalError)
return err
}
hs.ecdheKey = key
hs.hello.keyShares = []keyShare{{group: curveID, data: key.PublicKey().Bytes()}}
}
hs.hello.raw = nil
if len(hs.hello.pskIdentities) > 0 {
pskSuite := cipherSuiteTLS13ByID(hs.session.cipherSuite)
if pskSuite == nil {
return c.sendAlert(alertInternalError)
}
if pskSuite.hash == hs.suite.hash {
// Update binders and obfuscated_ticket_age.
ticketAge := uint32(c.config.time().Sub(hs.session.receivedAt) / time.Millisecond)
hs.hello.pskIdentities[0].obfuscatedTicketAge = ticketAge + hs.session.ageAdd
transcript := hs.suite.hash.New()
transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))})
transcript.Write(chHash)
if err := transcriptMsg(hs.serverHello, transcript); err != nil {
return err
}
helloBytes, err := hs.hello.marshalWithoutBinders()
if err != nil {
return err
}
transcript.Write(helloBytes)
pskBinders := [][]byte{hs.suite.finishedHash(hs.binderKey, transcript)}
if err := hs.hello.updateBinders(pskBinders); err != nil {
return err
}
} else {
// Server selected a cipher suite incompatible with the PSK.
hs.hello.pskIdentities = nil
hs.hello.pskBinders = nil
}
}
if hs.hello.earlyData {
hs.hello.earlyData = false
c.quicRejectedEarlyData()
}
if _, err := hs.c.writeHandshakeRecord(hs.hello, hs.transcript); err != nil {
return err
}
// serverHelloMsg is not included in the transcript
msg, err := c.readHandshake(nil)
if err != nil {
return err
}
serverHello, ok := msg.(*serverHelloMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(serverHello, msg)
}
hs.serverHello = serverHello
if err := hs.checkServerHelloOrHRR(); err != nil {
return err
}
return nil
}
func (hs *clientHandshakeStateTLS13) processServerHello() error {
c := hs.c
if bytes.Equal(hs.serverHello.random, helloRetryRequestRandom) {
c.sendAlert(alertUnexpectedMessage)
return errors.New("tls: server sent two HelloRetryRequest messages")
}
if len(hs.serverHello.cookie) != 0 {
c.sendAlert(alertUnsupportedExtension)
return errors.New("tls: server sent a cookie in a normal ServerHello")
}
if hs.serverHello.selectedGroup != 0 {
c.sendAlert(alertDecodeError)
return errors.New("tls: malformed key_share extension")
}
if hs.serverHello.serverShare.group == 0 {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server did not send a key share")
}
if sentID, _ := curveIDForCurve(hs.ecdheKey.Curve()); hs.serverHello.serverShare.group != sentID {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server selected unsupported group")
}
if !hs.serverHello.selectedIdentityPresent {
return nil
}
if int(hs.serverHello.selectedIdentity) >= len(hs.hello.pskIdentities) {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server selected an invalid PSK")
}
if len(hs.hello.pskIdentities) != 1 || hs.session == nil {
return c.sendAlert(alertInternalError)
}
pskSuite := cipherSuiteTLS13ByID(hs.session.cipherSuite)
if pskSuite == nil {
return c.sendAlert(alertInternalError)
}
if pskSuite.hash != hs.suite.hash {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server selected an invalid PSK and cipher suite pair")
}
hs.usingPSK = true
c.didResume = true
c.peerCertificates = hs.session.serverCertificates
c.verifiedChains = hs.session.verifiedChains
c.ocspResponse = hs.session.ocspResponse
c.scts = hs.session.scts
return nil
}
func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error {
c := hs.c
peerKey, err := hs.ecdheKey.Curve().NewPublicKey(hs.serverHello.serverShare.data)
if err != nil {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: invalid server key share")
}
sharedKey, err := hs.ecdheKey.ECDH(peerKey)
if err != nil {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: invalid server key share")
}
earlySecret := hs.earlySecret
if !hs.usingPSK {
earlySecret = hs.suite.extract(nil, nil)
}
handshakeSecret := hs.suite.extract(sharedKey,
hs.suite.deriveSecret(earlySecret, "derived", nil))
clientSecret := hs.suite.deriveSecret(handshakeSecret,
clientHandshakeTrafficLabel, hs.transcript)
c.out.setTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, clientSecret)
serverSecret := hs.suite.deriveSecret(handshakeSecret,
serverHandshakeTrafficLabel, hs.transcript)
c.in.setTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, serverSecret)
if c.quic != nil {
if c.hand.Len() != 0 {
c.sendAlert(alertUnexpectedMessage)
}
c.quicSetWriteSecret(QUICEncryptionLevelHandshake, hs.suite.id, clientSecret)
c.quicSetReadSecret(QUICEncryptionLevelHandshake, hs.suite.id, serverSecret)
}
err = c.config.writeKeyLog(keyLogLabelClientHandshake, hs.hello.random, clientSecret)
if err != nil {
c.sendAlert(alertInternalError)
return err
}
err = c.config.writeKeyLog(keyLogLabelServerHandshake, hs.hello.random, serverSecret)
if err != nil {
c.sendAlert(alertInternalError)
return err
}
hs.masterSecret = hs.suite.extract(nil,
hs.suite.deriveSecret(handshakeSecret, "derived", nil))
return nil
}
func (hs *clientHandshakeStateTLS13) readServerParameters() error {
c := hs.c
msg, err := c.readHandshake(hs.transcript)
if err != nil {
return err
}
encryptedExtensions, ok := msg.(*encryptedExtensionsMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(encryptedExtensions, msg)
}
if err := checkALPN(hs.hello.alpnProtocols, encryptedExtensions.alpnProtocol, c.quic != nil); err != nil {
// RFC 8446 specifies that no_application_protocol is sent by servers, but
// does not specify how clients handle the selection of an incompatible protocol.
// RFC 9001 Section 8.1 specifies that QUIC clients send no_application_protocol
// in this case. Always sending no_application_protocol seems reasonable.
c.sendAlert(alertNoApplicationProtocol)
return err
}
c.clientProtocol = encryptedExtensions.alpnProtocol
if c.quic != nil {
if encryptedExtensions.quicTransportParameters == nil {
// RFC 9001 Section 8.2.
c.sendAlert(alertMissingExtension)
return errors.New("tls: server did not send a quic_transport_parameters extension")
}
c.quicSetTransportParameters(encryptedExtensions.quicTransportParameters)
} else {
if encryptedExtensions.quicTransportParameters != nil {
c.sendAlert(alertUnsupportedExtension)
return errors.New("tls: server sent an unexpected quic_transport_parameters extension")
}
}
if hs.hello.earlyData && !encryptedExtensions.earlyData {
c.quicRejectedEarlyData()
}
return nil
}
func (hs *clientHandshakeStateTLS13) readServerCertificate() error {
c := hs.c
// Either a PSK or a certificate is always used, but not both.
// See RFC 8446, Section 4.1.1.
if hs.usingPSK {
// Make sure the connection is still being verified whether or not this
// is a resumption. Resumptions currently don't reverify certificates so
// they don't call verifyServerCertificate. See Issue 31641.
if c.config.VerifyConnection != nil {
if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
c.sendAlert(alertBadCertificate)
return err
}
}
return nil
}
msg, err := c.readHandshake(hs.transcript)
if err != nil {
return err
}
certReq, ok := msg.(*certificateRequestMsgTLS13)
if ok {
hs.certReq = certReq
msg, err = c.readHandshake(hs.transcript)
if err != nil {
return err
}
}
certMsg, ok := msg.(*certificateMsgTLS13)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(certMsg, msg)
}
if len(certMsg.certificate.Certificate) == 0 {
c.sendAlert(alertDecodeError)
return errors.New("tls: received empty certificates message")
}
c.scts = certMsg.certificate.SignedCertificateTimestamps
c.ocspResponse = certMsg.certificate.OCSPStaple
if err := c.verifyServerCertificate(certMsg.certificate.Certificate); err != nil {
return err
}
// certificateVerifyMsg is included in the transcript, but not until
// after we verify the handshake signature, since the state before
// this message was sent is used.
msg, err = c.readHandshake(nil)
if err != nil {
return err
}
certVerify, ok := msg.(*certificateVerifyMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(certVerify, msg)
}
// See RFC 8446, Section 4.4.3.
if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, supportedSignatureAlgorithms()) {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: certificate used with invalid signature algorithm")
}
sigType, sigHash, err := typeAndHashFromSignatureScheme(certVerify.signatureAlgorithm)
if err != nil {
return c.sendAlert(alertInternalError)
}
if sigType == signaturePKCS1v15 || sigHash == crypto.SHA1 {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: certificate used with invalid signature algorithm")
}
signed := signedMessage(sigHash, serverSignatureContext, hs.transcript)
if err := verifyHandshakeSignature(sigType, c.peerCertificates[0].PublicKey,
sigHash, signed, certVerify.signature); err != nil {
c.sendAlert(alertDecryptError)
return errors.New("tls: invalid signature by the server certificate: " + err.Error())
}
if err := transcriptMsg(certVerify, hs.transcript); err != nil {
return err
}
return nil
}
func (hs *clientHandshakeStateTLS13) readServerFinished() error {
c := hs.c
// finishedMsg is included in the transcript, but not until after we
// check the client version, since the state before this message was
// sent is used during verification.
msg, err := c.readHandshake(nil)
if err != nil {
return err
}
finished, ok := msg.(*finishedMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(finished, msg)
}
expectedMAC := hs.suite.finishedHash(c.in.trafficSecret, hs.transcript)
if !hmac.Equal(expectedMAC, finished.verifyData) {
c.sendAlert(alertDecryptError)
return errors.New("tls: invalid server finished hash")
}
if err := transcriptMsg(finished, hs.transcript); err != nil {
return err
}
// Derive secrets that take context through the server Finished.
hs.trafficSecret = hs.suite.deriveSecret(hs.masterSecret,
clientApplicationTrafficLabel, hs.transcript)
serverSecret := hs.suite.deriveSecret(hs.masterSecret,
serverApplicationTrafficLabel, hs.transcript)
c.in.setTrafficSecret(hs.suite, QUICEncryptionLevelApplication, serverSecret)
err = c.config.writeKeyLog(keyLogLabelClientTraffic, hs.hello.random, hs.trafficSecret)
if err != nil {
c.sendAlert(alertInternalError)
return err
}
err = c.config.writeKeyLog(keyLogLabelServerTraffic, hs.hello.random, serverSecret)
if err != nil {
c.sendAlert(alertInternalError)
return err
}
c.ekm = hs.suite.exportKeyingMaterial(hs.masterSecret, hs.transcript)
return nil
}
func (hs *clientHandshakeStateTLS13) sendClientCertificate() error {
c := hs.c
if hs.certReq == nil {
return nil
}
cert, err := c.getClientCertificate(toCertificateRequestInfo(&certificateRequestInfo{
AcceptableCAs: hs.certReq.certificateAuthorities,
SignatureSchemes: hs.certReq.supportedSignatureAlgorithms,
Version: c.vers,
ctx: hs.ctx,
}))
if err != nil {
return err
}
certMsg := new(certificateMsgTLS13)
certMsg.certificate = *cert
certMsg.scts = hs.certReq.scts && len(cert.SignedCertificateTimestamps) > 0
certMsg.ocspStapling = hs.certReq.ocspStapling && len(cert.OCSPStaple) > 0
if _, err := hs.c.writeHandshakeRecord(certMsg, hs.transcript); err != nil {
return err
}
// If we sent an empty certificate message, skip the CertificateVerify.
if len(cert.Certificate) == 0 {
return nil
}
certVerifyMsg := new(certificateVerifyMsg)
certVerifyMsg.hasSignatureAlgorithm = true
certVerifyMsg.signatureAlgorithm, err = selectSignatureScheme(c.vers, cert, hs.certReq.supportedSignatureAlgorithms)
if err != nil {
// getClientCertificate returned a certificate incompatible with the
// CertificateRequestInfo supported signature algorithms.
c.sendAlert(alertHandshakeFailure)
return err
}
sigType, sigHash, err := typeAndHashFromSignatureScheme(certVerifyMsg.signatureAlgorithm)
if err != nil {
return c.sendAlert(alertInternalError)
}
signed := signedMessage(sigHash, clientSignatureContext, hs.transcript)
signOpts := crypto.SignerOpts(sigHash)
if sigType == signatureRSAPSS {
signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash}
}
sig, err := cert.PrivateKey.(crypto.Signer).Sign(c.config.rand(), signed, signOpts)
if err != nil {
c.sendAlert(alertInternalError)
return errors.New("tls: failed to sign handshake: " + err.Error())
}
certVerifyMsg.signature = sig
if _, err := hs.c.writeHandshakeRecord(certVerifyMsg, hs.transcript); err != nil {
return err
}
return nil
}
func (hs *clientHandshakeStateTLS13) sendClientFinished() error {
c := hs.c
finished := &finishedMsg{
verifyData: hs.suite.finishedHash(c.out.trafficSecret, hs.transcript),
}
if _, err := hs.c.writeHandshakeRecord(finished, hs.transcript); err != nil {
return err
}
c.out.setTrafficSecret(hs.suite, QUICEncryptionLevelApplication, hs.trafficSecret)
if !c.config.SessionTicketsDisabled && c.config.ClientSessionCache != nil {
c.resumptionSecret = hs.suite.deriveSecret(hs.masterSecret,
resumptionLabel, hs.transcript)
}
if c.quic != nil {
if c.hand.Len() != 0 {
c.sendAlert(alertUnexpectedMessage)
}
c.quicSetWriteSecret(QUICEncryptionLevelApplication, hs.suite.id, hs.trafficSecret)
}
return nil
}
func (c *Conn) handleNewSessionTicket(msg *newSessionTicketMsgTLS13) error {
if !c.isClient {
c.sendAlert(alertUnexpectedMessage)
return errors.New("tls: received new session ticket from a client")
}
if c.config.SessionTicketsDisabled || c.config.ClientSessionCache == nil {
return nil
}
// See RFC 8446, Section 4.6.1.
if msg.lifetime == 0 {
return nil
}
lifetime := time.Duration(msg.lifetime) * time.Second
if lifetime > maxSessionTicketLifetime {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: received a session ticket with invalid lifetime")
}
cipherSuite := cipherSuiteTLS13ByID(c.cipherSuite)
if cipherSuite == nil || c.resumptionSecret == nil {
return c.sendAlert(alertInternalError)
}
// We need to save the max_early_data_size that the server sent us, in order
// to decide if we're going to try 0-RTT with this ticket.
// However, at the same time, the qtls.ClientSessionTicket needs to be equal to
// the tls.ClientSessionTicket, so we can't just add a new field to the struct.
// We therefore abuse the nonce field (which is a byte slice)
nonceWithEarlyData := make([]byte, len(msg.nonce)+4)
binary.BigEndian.PutUint32(nonceWithEarlyData, msg.maxEarlyData)
copy(nonceWithEarlyData[4:], msg.nonce)
var appData []byte
if c.extraConfig != nil && c.extraConfig.GetAppDataForSessionState != nil {
appData = c.extraConfig.GetAppDataForSessionState()
}
var b cryptobyte.Builder
b.AddUint16(clientSessionStateVersion) // revision
b.AddUint32(msg.maxEarlyData)
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(appData)
})
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(msg.nonce)
})
// Save the resumption_master_secret and nonce instead of deriving the PSK
// to do the least amount of work on NewSessionTicket messages before we
// know if the ticket will be used. Forward secrecy of resumed connections
// is guaranteed by the requirement for pskModeDHE.
session := &clientSessionState{
sessionTicket: msg.label,
vers: c.vers,
cipherSuite: c.cipherSuite,
masterSecret: c.resumptionSecret,
serverCertificates: c.peerCertificates,
verifiedChains: c.verifiedChains,
receivedAt: c.config.time(),
nonce: b.BytesOrPanic(),
useBy: c.config.time().Add(lifetime),
ageAdd: msg.ageAdd,
ocspResponse: c.ocspResponse,
scts: c.scts,
}
cacheKey := c.clientSessionCacheKey()
if cacheKey != "" {
c.config.ClientSessionCache.Put(cacheKey, toClientSessionState(session))
}
return nil
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,899 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package qtls
import (
"context"
"crypto"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rsa"
"crypto/subtle"
"crypto/x509"
"errors"
"fmt"
"hash"
"io"
"time"
)
// serverHandshakeState contains details of a server handshake in progress.
// It's discarded once the handshake has completed.
type serverHandshakeState struct {
c *Conn
ctx context.Context
clientHello *clientHelloMsg
hello *serverHelloMsg
suite *cipherSuite
ecdheOk bool
ecSignOk bool
rsaDecryptOk bool
rsaSignOk bool
sessionState *sessionState
finishedHash finishedHash
masterSecret []byte
cert *Certificate
}
// serverHandshake performs a TLS handshake as a server.
func (c *Conn) serverHandshake(ctx context.Context) error {
clientHello, err := c.readClientHello(ctx)
if err != nil {
return err
}
if c.vers == VersionTLS13 {
hs := serverHandshakeStateTLS13{
c: c,
ctx: ctx,
clientHello: clientHello,
}
return hs.handshake()
}
hs := serverHandshakeState{
c: c,
ctx: ctx,
clientHello: clientHello,
}
return hs.handshake()
}
func (hs *serverHandshakeState) handshake() error {
c := hs.c
if err := hs.processClientHello(); err != nil {
return err
}
// For an overview of TLS handshaking, see RFC 5246, Section 7.3.
c.buffering = true
if hs.checkForResumption() {
// The client has included a session ticket and so we do an abbreviated handshake.
c.didResume = true
if err := hs.doResumeHandshake(); err != nil {
return err
}
if err := hs.establishKeys(); err != nil {
return err
}
if err := hs.sendSessionTicket(); err != nil {
return err
}
if err := hs.sendFinished(c.serverFinished[:]); err != nil {
return err
}
if _, err := c.flush(); err != nil {
return err
}
c.clientFinishedIsFirst = false
if err := hs.readFinished(nil); err != nil {
return err
}
} else {
// The client didn't include a session ticket, or it wasn't
// valid so we do a full handshake.
if err := hs.pickCipherSuite(); err != nil {
return err
}
if err := hs.doFullHandshake(); err != nil {
return err
}
if err := hs.establishKeys(); err != nil {
return err
}
if err := hs.readFinished(c.clientFinished[:]); err != nil {
return err
}
c.clientFinishedIsFirst = true
c.buffering = true
if err := hs.sendSessionTicket(); err != nil {
return err
}
if err := hs.sendFinished(nil); err != nil {
return err
}
if _, err := c.flush(); err != nil {
return err
}
}
c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.clientHello.random, hs.hello.random)
c.isHandshakeComplete.Store(true)
return nil
}
// readClientHello reads a ClientHello message and selects the protocol version.
func (c *Conn) readClientHello(ctx context.Context) (*clientHelloMsg, error) {
// clientHelloMsg is included in the transcript, but we haven't initialized
// it yet. The respective handshake functions will record it themselves.
msg, err := c.readHandshake(nil)
if err != nil {
return nil, err
}
clientHello, ok := msg.(*clientHelloMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return nil, unexpectedMessageError(clientHello, msg)
}
var configForClient *config
originalConfig := c.config
if c.config.GetConfigForClient != nil {
chi := newClientHelloInfo(ctx, c, clientHello)
if cfc, err := c.config.GetConfigForClient(chi); err != nil {
c.sendAlert(alertInternalError)
return nil, err
} else if cfc != nil {
configForClient = fromConfig(cfc)
c.config = configForClient
}
}
c.ticketKeys = originalConfig.ticketKeys(configForClient)
clientVersions := clientHello.supportedVersions
if len(clientHello.supportedVersions) == 0 {
clientVersions = supportedVersionsFromMax(clientHello.vers)
}
c.vers, ok = c.config.mutualVersion(roleServer, clientVersions)
if !ok {
c.sendAlert(alertProtocolVersion)
return nil, fmt.Errorf("tls: client offered only unsupported versions: %x", clientVersions)
}
c.haveVers = true
c.in.version = c.vers
c.out.version = c.vers
return clientHello, nil
}
func (hs *serverHandshakeState) processClientHello() error {
c := hs.c
hs.hello = new(serverHelloMsg)
hs.hello.vers = c.vers
foundCompression := false
// We only support null compression, so check that the client offered it.
for _, compression := range hs.clientHello.compressionMethods {
if compression == compressionNone {
foundCompression = true
break
}
}
if !foundCompression {
c.sendAlert(alertHandshakeFailure)
return errors.New("tls: client does not support uncompressed connections")
}
hs.hello.random = make([]byte, 32)
serverRandom := hs.hello.random
// Downgrade protection canaries. See RFC 8446, Section 4.1.3.
maxVers := c.config.maxSupportedVersion(roleServer)
if maxVers >= VersionTLS12 && c.vers < maxVers || testingOnlyForceDowngradeCanary {
if c.vers == VersionTLS12 {
copy(serverRandom[24:], downgradeCanaryTLS12)
} else {
copy(serverRandom[24:], downgradeCanaryTLS11)
}
serverRandom = serverRandom[:24]
}
_, err := io.ReadFull(c.config.rand(), serverRandom)
if err != nil {
c.sendAlert(alertInternalError)
return err
}
if len(hs.clientHello.secureRenegotiation) != 0 {
c.sendAlert(alertHandshakeFailure)
return errors.New("tls: initial handshake had non-empty renegotiation extension")
}
hs.hello.secureRenegotiationSupported = hs.clientHello.secureRenegotiationSupported
hs.hello.compressionMethod = compressionNone
if len(hs.clientHello.serverName) > 0 {
c.serverName = hs.clientHello.serverName
}
selectedProto, err := negotiateALPN(c.config.NextProtos, hs.clientHello.alpnProtocols, false)
if err != nil {
c.sendAlert(alertNoApplicationProtocol)
return err
}
hs.hello.alpnProtocol = selectedProto
c.clientProtocol = selectedProto
hs.cert, err = c.config.getCertificate(newClientHelloInfo(hs.ctx, c, hs.clientHello))
if err != nil {
if err == errNoCertificates {
c.sendAlert(alertUnrecognizedName)
} else {
c.sendAlert(alertInternalError)
}
return err
}
if hs.clientHello.scts {
hs.hello.scts = hs.cert.SignedCertificateTimestamps
}
hs.ecdheOk = supportsECDHE(c.config, hs.clientHello.supportedCurves, hs.clientHello.supportedPoints)
if hs.ecdheOk && len(hs.clientHello.supportedPoints) > 0 {
// Although omitting the ec_point_formats extension is permitted, some
// old OpenSSL version will refuse to handshake if not present.
//
// Per RFC 4492, section 5.1.2, implementations MUST support the
// uncompressed point format. See golang.org/issue/31943.
hs.hello.supportedPoints = []uint8{pointFormatUncompressed}
}
if priv, ok := hs.cert.PrivateKey.(crypto.Signer); ok {
switch priv.Public().(type) {
case *ecdsa.PublicKey:
hs.ecSignOk = true
case ed25519.PublicKey:
hs.ecSignOk = true
case *rsa.PublicKey:
hs.rsaSignOk = true
default:
c.sendAlert(alertInternalError)
return fmt.Errorf("tls: unsupported signing key type (%T)", priv.Public())
}
}
if priv, ok := hs.cert.PrivateKey.(crypto.Decrypter); ok {
switch priv.Public().(type) {
case *rsa.PublicKey:
hs.rsaDecryptOk = true
default:
c.sendAlert(alertInternalError)
return fmt.Errorf("tls: unsupported decryption key type (%T)", priv.Public())
}
}
return nil
}
// negotiateALPN picks a shared ALPN protocol that both sides support in server
// preference order. If ALPN is not configured or the peer doesn't support it,
// it returns "" and no error.
func negotiateALPN(serverProtos, clientProtos []string, quic bool) (string, error) {
if len(serverProtos) == 0 || len(clientProtos) == 0 {
if quic && len(serverProtos) != 0 {
// RFC 9001, Section 8.1
return "", fmt.Errorf("tls: client did not request an application protocol")
}
return "", nil
}
var http11fallback bool
for _, s := range serverProtos {
for _, c := range clientProtos {
if s == c {
return s, nil
}
if s == "h2" && c == "http/1.1" {
http11fallback = true
}
}
}
// As a special case, let http/1.1 clients connect to h2 servers as if they
// didn't support ALPN. We used not to enforce protocol overlap, so over
// time a number of HTTP servers were configured with only "h2", but
// expected to accept connections from "http/1.1" clients. See Issue 46310.
if http11fallback {
return "", nil
}
return "", fmt.Errorf("tls: client requested unsupported application protocols (%s)", clientProtos)
}
// supportsECDHE returns whether ECDHE key exchanges can be used with this
// pre-TLS 1.3 client.
func supportsECDHE(c *config, supportedCurves []CurveID, supportedPoints []uint8) bool {
supportsCurve := false
for _, curve := range supportedCurves {
if c.supportsCurve(curve) {
supportsCurve = true
break
}
}
supportsPointFormat := false
for _, pointFormat := range supportedPoints {
if pointFormat == pointFormatUncompressed {
supportsPointFormat = true
break
}
}
// Per RFC 8422, Section 5.1.2, if the Supported Point Formats extension is
// missing, uncompressed points are supported. If supportedPoints is empty,
// the extension must be missing, as an empty extension body is rejected by
// the parser. See https://go.dev/issue/49126.
if len(supportedPoints) == 0 {
supportsPointFormat = true
}
return supportsCurve && supportsPointFormat
}
func (hs *serverHandshakeState) pickCipherSuite() error {
c := hs.c
preferenceOrder := cipherSuitesPreferenceOrder
if !hasAESGCMHardwareSupport || !aesgcmPreferred(hs.clientHello.cipherSuites) {
preferenceOrder = cipherSuitesPreferenceOrderNoAES
}
configCipherSuites := c.config.cipherSuites()
preferenceList := make([]uint16, 0, len(configCipherSuites))
for _, suiteID := range preferenceOrder {
for _, id := range configCipherSuites {
if id == suiteID {
preferenceList = append(preferenceList, id)
break
}
}
}
hs.suite = selectCipherSuite(preferenceList, hs.clientHello.cipherSuites, hs.cipherSuiteOk)
if hs.suite == nil {
c.sendAlert(alertHandshakeFailure)
return errors.New("tls: no cipher suite supported by both client and server")
}
c.cipherSuite = hs.suite.id
for _, id := range hs.clientHello.cipherSuites {
if id == TLS_FALLBACK_SCSV {
// The client is doing a fallback connection. See RFC 7507.
if hs.clientHello.vers < c.config.maxSupportedVersion(roleServer) {
c.sendAlert(alertInappropriateFallback)
return errors.New("tls: client using inappropriate protocol fallback")
}
break
}
}
return nil
}
func (hs *serverHandshakeState) cipherSuiteOk(c *cipherSuite) bool {
if c.flags&suiteECDHE != 0 {
if !hs.ecdheOk {
return false
}
if c.flags&suiteECSign != 0 {
if !hs.ecSignOk {
return false
}
} else if !hs.rsaSignOk {
return false
}
} else if !hs.rsaDecryptOk {
return false
}
if hs.c.vers < VersionTLS12 && c.flags&suiteTLS12 != 0 {
return false
}
return true
}
// checkForResumption reports whether we should perform resumption on this connection.
func (hs *serverHandshakeState) checkForResumption() bool {
c := hs.c
if c.config.SessionTicketsDisabled {
return false
}
plaintext, usedOldKey := c.decryptTicket(hs.clientHello.sessionTicket)
if plaintext == nil {
return false
}
hs.sessionState = &sessionState{usedOldKey: usedOldKey}
ok := hs.sessionState.unmarshal(plaintext)
if !ok {
return false
}
createdAt := time.Unix(int64(hs.sessionState.createdAt), 0)
if c.config.time().Sub(createdAt) > maxSessionTicketLifetime {
return false
}
// Never resume a session for a different TLS version.
if c.vers != hs.sessionState.vers {
return false
}
cipherSuiteOk := false
// Check that the client is still offering the ciphersuite in the session.
for _, id := range hs.clientHello.cipherSuites {
if id == hs.sessionState.cipherSuite {
cipherSuiteOk = true
break
}
}
if !cipherSuiteOk {
return false
}
// Check that we also support the ciphersuite from the session.
hs.suite = selectCipherSuite([]uint16{hs.sessionState.cipherSuite},
c.config.cipherSuites(), hs.cipherSuiteOk)
if hs.suite == nil {
return false
}
sessionHasClientCerts := len(hs.sessionState.certificates) != 0
needClientCerts := requiresClientCert(c.config.ClientAuth)
if needClientCerts && !sessionHasClientCerts {
return false
}
if sessionHasClientCerts && c.config.ClientAuth == NoClientCert {
return false
}
return true
}
func (hs *serverHandshakeState) doResumeHandshake() error {
c := hs.c
hs.hello.cipherSuite = hs.suite.id
c.cipherSuite = hs.suite.id
// We echo the client's session ID in the ServerHello to let it know
// that we're doing a resumption.
hs.hello.sessionId = hs.clientHello.sessionId
hs.hello.ticketSupported = hs.sessionState.usedOldKey
hs.finishedHash = newFinishedHash(c.vers, hs.suite)
hs.finishedHash.discardHandshakeBuffer()
if err := transcriptMsg(hs.clientHello, &hs.finishedHash); err != nil {
return err
}
if _, err := hs.c.writeHandshakeRecord(hs.hello, &hs.finishedHash); err != nil {
return err
}
if err := c.processCertsFromClient(Certificate{
Certificate: hs.sessionState.certificates,
}); err != nil {
return err
}
if c.config.VerifyConnection != nil {
if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
c.sendAlert(alertBadCertificate)
return err
}
}
hs.masterSecret = hs.sessionState.masterSecret
return nil
}
func (hs *serverHandshakeState) doFullHandshake() error {
c := hs.c
if hs.clientHello.ocspStapling && len(hs.cert.OCSPStaple) > 0 {
hs.hello.ocspStapling = true
}
hs.hello.ticketSupported = hs.clientHello.ticketSupported && !c.config.SessionTicketsDisabled
hs.hello.cipherSuite = hs.suite.id
hs.finishedHash = newFinishedHash(hs.c.vers, hs.suite)
if c.config.ClientAuth == NoClientCert {
// No need to keep a full record of the handshake if client
// certificates won't be used.
hs.finishedHash.discardHandshakeBuffer()
}
if err := transcriptMsg(hs.clientHello, &hs.finishedHash); err != nil {
return err
}
if _, err := hs.c.writeHandshakeRecord(hs.hello, &hs.finishedHash); err != nil {
return err
}
certMsg := new(certificateMsg)
certMsg.certificates = hs.cert.Certificate
if _, err := hs.c.writeHandshakeRecord(certMsg, &hs.finishedHash); err != nil {
return err
}
if hs.hello.ocspStapling {
certStatus := new(certificateStatusMsg)
certStatus.response = hs.cert.OCSPStaple
if _, err := hs.c.writeHandshakeRecord(certStatus, &hs.finishedHash); err != nil {
return err
}
}
keyAgreement := hs.suite.ka(c.vers)
skx, err := keyAgreement.generateServerKeyExchange(c.config, hs.cert, hs.clientHello, hs.hello)
if err != nil {
c.sendAlert(alertHandshakeFailure)
return err
}
if skx != nil {
if _, err := hs.c.writeHandshakeRecord(skx, &hs.finishedHash); err != nil {
return err
}
}
var certReq *certificateRequestMsg
if c.config.ClientAuth >= RequestClientCert {
// Request a client certificate
certReq = new(certificateRequestMsg)
certReq.certificateTypes = []byte{
byte(certTypeRSASign),
byte(certTypeECDSASign),
}
if c.vers >= VersionTLS12 {
certReq.hasSignatureAlgorithm = true
certReq.supportedSignatureAlgorithms = supportedSignatureAlgorithms()
}
// An empty list of certificateAuthorities signals to
// the client that it may send any certificate in response
// to our request. When we know the CAs we trust, then
// we can send them down, so that the client can choose
// an appropriate certificate to give to us.
if c.config.ClientCAs != nil {
certReq.certificateAuthorities = c.config.ClientCAs.Subjects()
}
if _, err := hs.c.writeHandshakeRecord(certReq, &hs.finishedHash); err != nil {
return err
}
}
helloDone := new(serverHelloDoneMsg)
if _, err := hs.c.writeHandshakeRecord(helloDone, &hs.finishedHash); err != nil {
return err
}
if _, err := c.flush(); err != nil {
return err
}
var pub crypto.PublicKey // public key for client auth, if any
msg, err := c.readHandshake(&hs.finishedHash)
if err != nil {
return err
}
// If we requested a client certificate, then the client must send a
// certificate message, even if it's empty.
if c.config.ClientAuth >= RequestClientCert {
certMsg, ok := msg.(*certificateMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(certMsg, msg)
}
if err := c.processCertsFromClient(Certificate{
Certificate: certMsg.certificates,
}); err != nil {
return err
}
if len(certMsg.certificates) != 0 {
pub = c.peerCertificates[0].PublicKey
}
msg, err = c.readHandshake(&hs.finishedHash)
if err != nil {
return err
}
}
if c.config.VerifyConnection != nil {
if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
c.sendAlert(alertBadCertificate)
return err
}
}
// Get client key exchange
ckx, ok := msg.(*clientKeyExchangeMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(ckx, msg)
}
preMasterSecret, err := keyAgreement.processClientKeyExchange(c.config, hs.cert, ckx, c.vers)
if err != nil {
c.sendAlert(alertHandshakeFailure)
return err
}
hs.masterSecret = masterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret, hs.clientHello.random, hs.hello.random)
if err := c.config.writeKeyLog(keyLogLabelTLS12, hs.clientHello.random, hs.masterSecret); err != nil {
c.sendAlert(alertInternalError)
return err
}
// If we received a client cert in response to our certificate request message,
// the client will send us a certificateVerifyMsg immediately after the
// clientKeyExchangeMsg. This message is a digest of all preceding
// handshake-layer messages that is signed using the private key corresponding
// to the client's certificate. This allows us to verify that the client is in
// possession of the private key of the certificate.
if len(c.peerCertificates) > 0 {
// certificateVerifyMsg is included in the transcript, but not until
// after we verify the handshake signature, since the state before
// this message was sent is used.
msg, err = c.readHandshake(nil)
if err != nil {
return err
}
certVerify, ok := msg.(*certificateVerifyMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(certVerify, msg)
}
var sigType uint8
var sigHash crypto.Hash
if c.vers >= VersionTLS12 {
if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, certReq.supportedSignatureAlgorithms) {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: client certificate used with invalid signature algorithm")
}
sigType, sigHash, err = typeAndHashFromSignatureScheme(certVerify.signatureAlgorithm)
if err != nil {
return c.sendAlert(alertInternalError)
}
} else {
sigType, sigHash, err = legacyTypeAndHashFromPublicKey(pub)
if err != nil {
c.sendAlert(alertIllegalParameter)
return err
}
}
signed := hs.finishedHash.hashForClientCertificate(sigType, sigHash)
if err := verifyHandshakeSignature(sigType, pub, sigHash, signed, certVerify.signature); err != nil {
c.sendAlert(alertDecryptError)
return errors.New("tls: invalid signature by the client certificate: " + err.Error())
}
if err := transcriptMsg(certVerify, &hs.finishedHash); err != nil {
return err
}
}
hs.finishedHash.discardHandshakeBuffer()
return nil
}
func (hs *serverHandshakeState) establishKeys() error {
c := hs.c
clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV :=
keysFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.clientHello.random, hs.hello.random, hs.suite.macLen, hs.suite.keyLen, hs.suite.ivLen)
var clientCipher, serverCipher any
var clientHash, serverHash hash.Hash
if hs.suite.aead == nil {
clientCipher = hs.suite.cipher(clientKey, clientIV, true /* for reading */)
clientHash = hs.suite.mac(clientMAC)
serverCipher = hs.suite.cipher(serverKey, serverIV, false /* not for reading */)
serverHash = hs.suite.mac(serverMAC)
} else {
clientCipher = hs.suite.aead(clientKey, clientIV)
serverCipher = hs.suite.aead(serverKey, serverIV)
}
c.in.prepareCipherSpec(c.vers, clientCipher, clientHash)
c.out.prepareCipherSpec(c.vers, serverCipher, serverHash)
return nil
}
func (hs *serverHandshakeState) readFinished(out []byte) error {
c := hs.c
if err := c.readChangeCipherSpec(); err != nil {
return err
}
// finishedMsg is included in the transcript, but not until after we
// check the client version, since the state before this message was
// sent is used during verification.
msg, err := c.readHandshake(nil)
if err != nil {
return err
}
clientFinished, ok := msg.(*finishedMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(clientFinished, msg)
}
verify := hs.finishedHash.clientSum(hs.masterSecret)
if len(verify) != len(clientFinished.verifyData) ||
subtle.ConstantTimeCompare(verify, clientFinished.verifyData) != 1 {
c.sendAlert(alertHandshakeFailure)
return errors.New("tls: client's Finished message is incorrect")
}
if err := transcriptMsg(clientFinished, &hs.finishedHash); err != nil {
return err
}
copy(out, verify)
return nil
}
func (hs *serverHandshakeState) sendSessionTicket() error {
// ticketSupported is set in a resumption handshake if the
// ticket from the client was encrypted with an old session
// ticket key and thus a refreshed ticket should be sent.
if !hs.hello.ticketSupported {
return nil
}
c := hs.c
m := new(newSessionTicketMsg)
createdAt := uint64(c.config.time().Unix())
if hs.sessionState != nil {
// If this is re-wrapping an old key, then keep
// the original time it was created.
createdAt = hs.sessionState.createdAt
}
var certsFromClient [][]byte
for _, cert := range c.peerCertificates {
certsFromClient = append(certsFromClient, cert.Raw)
}
state := sessionState{
vers: c.vers,
cipherSuite: hs.suite.id,
createdAt: createdAt,
masterSecret: hs.masterSecret,
certificates: certsFromClient,
}
stateBytes, err := state.marshal()
if err != nil {
return err
}
m.ticket, err = c.encryptTicket(stateBytes)
if err != nil {
return err
}
if _, err := hs.c.writeHandshakeRecord(m, &hs.finishedHash); err != nil {
return err
}
return nil
}
func (hs *serverHandshakeState) sendFinished(out []byte) error {
c := hs.c
if err := c.writeChangeCipherRecord(); err != nil {
return err
}
finished := new(finishedMsg)
finished.verifyData = hs.finishedHash.serverSum(hs.masterSecret)
if _, err := hs.c.writeHandshakeRecord(finished, &hs.finishedHash); err != nil {
return err
}
copy(out, finished.verifyData)
return nil
}
// processCertsFromClient takes a chain of client certificates either from a
// Certificates message or from a sessionState and verifies them. It returns
// the public key of the leaf certificate.
func (c *Conn) processCertsFromClient(certificate Certificate) error {
certificates := certificate.Certificate
certs := make([]*x509.Certificate, len(certificates))
var err error
for i, asn1Data := range certificates {
if certs[i], err = x509.ParseCertificate(asn1Data); err != nil {
c.sendAlert(alertBadCertificate)
return errors.New("tls: failed to parse client certificate: " + err.Error())
}
if certs[i].PublicKeyAlgorithm == x509.RSA && certs[i].PublicKey.(*rsa.PublicKey).N.BitLen() > maxRSAKeySize {
c.sendAlert(alertBadCertificate)
return fmt.Errorf("tls: client sent certificate containing RSA key larger than %d bits", maxRSAKeySize)
}
}
if len(certs) == 0 && requiresClientCert(c.config.ClientAuth) {
c.sendAlert(alertBadCertificate)
return errors.New("tls: client didn't provide a certificate")
}
if c.config.ClientAuth >= VerifyClientCertIfGiven && len(certs) > 0 {
opts := x509.VerifyOptions{
Roots: c.config.ClientCAs,
CurrentTime: c.config.time(),
Intermediates: x509.NewCertPool(),
KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
}
for _, cert := range certs[1:] {
opts.Intermediates.AddCert(cert)
}
chains, err := certs[0].Verify(opts)
if err != nil {
c.sendAlert(alertBadCertificate)
return &CertificateVerificationError{UnverifiedCertificates: certs, Err: err}
}
c.verifiedChains = chains
}
c.peerCertificates = certs
c.ocspResponse = certificate.OCSPStaple
c.scts = certificate.SignedCertificateTimestamps
if len(certs) > 0 {
switch certs[0].PublicKey.(type) {
case *ecdsa.PublicKey, *rsa.PublicKey, ed25519.PublicKey:
default:
c.sendAlert(alertUnsupportedCertificate)
return fmt.Errorf("tls: client certificate contains an unsupported public key of type %T", certs[0].PublicKey)
}
}
if c.config.VerifyPeerCertificate != nil {
if err := c.config.VerifyPeerCertificate(certificates, c.verifiedChains); err != nil {
c.sendAlert(alertBadCertificate)
return err
}
}
return nil
}
func newClientHelloInfo(ctx context.Context, c *Conn, clientHello *clientHelloMsg) *ClientHelloInfo {
supportedVersions := clientHello.supportedVersions
if len(clientHello.supportedVersions) == 0 {
supportedVersions = supportedVersionsFromMax(clientHello.vers)
}
return toClientHelloInfo(&clientHelloInfo{
CipherSuites: clientHello.cipherSuites,
ServerName: clientHello.serverName,
SupportedCurves: clientHello.supportedCurves,
SupportedPoints: clientHello.supportedPoints,
SignatureSchemes: clientHello.supportedSignatureAlgorithms,
SupportedProtos: clientHello.alpnProtocols,
SupportedVersions: supportedVersions,
Conn: c.conn,
config: toConfig(c.config),
ctx: ctx,
})
}

View File

@@ -0,0 +1,979 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package qtls
import (
"bytes"
"context"
"crypto"
"crypto/hmac"
"crypto/rsa"
"errors"
"hash"
"io"
"time"
)
// maxClientPSKIdentities is the number of client PSK identities the server will
// attempt to validate. It will ignore the rest not to let cheap ClientHello
// messages cause too much work in session ticket decryption attempts.
const maxClientPSKIdentities = 5
type serverHandshakeStateTLS13 struct {
c *Conn
ctx context.Context
clientHello *clientHelloMsg
hello *serverHelloMsg
alpnNegotiationErr error
encryptedExtensions *encryptedExtensionsMsg
sentDummyCCS bool
usingPSK bool
suite *cipherSuiteTLS13
cert *Certificate
sigAlg SignatureScheme
earlySecret []byte
sharedKey []byte
handshakeSecret []byte
masterSecret []byte
trafficSecret []byte // client_application_traffic_secret_0
transcript hash.Hash
clientFinished []byte
earlyData bool
}
func (hs *serverHandshakeStateTLS13) handshake() error {
c := hs.c
if needFIPS() {
return errors.New("tls: internal error: TLS 1.3 reached in FIPS mode")
}
// For an overview of the TLS 1.3 handshake, see RFC 8446, Section 2.
if err := hs.processClientHello(); err != nil {
return err
}
if err := hs.checkForResumption(); err != nil {
return err
}
if err := hs.pickCertificate(); err != nil {
return err
}
c.buffering = true
if err := hs.sendServerParameters(); err != nil {
return err
}
if err := hs.sendServerCertificate(); err != nil {
return err
}
if err := hs.sendServerFinished(); err != nil {
return err
}
// Note that at this point we could start sending application data without
// waiting for the client's second flight, but the application might not
// expect the lack of replay protection of the ClientHello parameters.
if _, err := c.flush(); err != nil {
return err
}
if err := hs.readClientCertificate(); err != nil {
return err
}
if err := hs.readClientFinished(); err != nil {
return err
}
c.isHandshakeComplete.Store(true)
return nil
}
func (hs *serverHandshakeStateTLS13) processClientHello() error {
c := hs.c
hs.hello = new(serverHelloMsg)
hs.encryptedExtensions = new(encryptedExtensionsMsg)
// TLS 1.3 froze the ServerHello.legacy_version field, and uses
// supported_versions instead. See RFC 8446, sections 4.1.3 and 4.2.1.
hs.hello.vers = VersionTLS12
hs.hello.supportedVersion = c.vers
if len(hs.clientHello.supportedVersions) == 0 {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: client used the legacy version field to negotiate TLS 1.3")
}
// Abort if the client is doing a fallback and landing lower than what we
// support. See RFC 7507, which however does not specify the interaction
// with supported_versions. The only difference is that with
// supported_versions a client has a chance to attempt a [TLS 1.2, TLS 1.4]
// handshake in case TLS 1.3 is broken but 1.2 is not. Alas, in that case,
// it will have to drop the TLS_FALLBACK_SCSV protection if it falls back to
// TLS 1.2, because a TLS 1.3 server would abort here. The situation before
// supported_versions was not better because there was just no way to do a
// TLS 1.4 handshake without risking the server selecting TLS 1.3.
for _, id := range hs.clientHello.cipherSuites {
if id == TLS_FALLBACK_SCSV {
// Use c.vers instead of max(supported_versions) because an attacker
// could defeat this by adding an arbitrary high version otherwise.
if c.vers < c.config.maxSupportedVersion(roleServer) {
c.sendAlert(alertInappropriateFallback)
return errors.New("tls: client using inappropriate protocol fallback")
}
break
}
}
if len(hs.clientHello.compressionMethods) != 1 ||
hs.clientHello.compressionMethods[0] != compressionNone {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: TLS 1.3 client supports illegal compression methods")
}
hs.hello.random = make([]byte, 32)
if _, err := io.ReadFull(c.config.rand(), hs.hello.random); err != nil {
c.sendAlert(alertInternalError)
return err
}
if len(hs.clientHello.secureRenegotiation) != 0 {
c.sendAlert(alertHandshakeFailure)
return errors.New("tls: initial handshake had non-empty renegotiation extension")
}
hs.hello.sessionId = hs.clientHello.sessionId
hs.hello.compressionMethod = compressionNone
preferenceList := defaultCipherSuitesTLS13
if !hasAESGCMHardwareSupport || !aesgcmPreferred(hs.clientHello.cipherSuites) {
preferenceList = defaultCipherSuitesTLS13NoAES
}
for _, suiteID := range preferenceList {
hs.suite = mutualCipherSuiteTLS13(hs.clientHello.cipherSuites, suiteID)
if hs.suite != nil {
break
}
}
if hs.suite == nil {
c.sendAlert(alertHandshakeFailure)
return errors.New("tls: no cipher suite supported by both client and server")
}
c.cipherSuite = hs.suite.id
hs.hello.cipherSuite = hs.suite.id
hs.transcript = hs.suite.hash.New()
// Pick the ECDHE group in server preference order, but give priority to
// groups with a key share, to avoid a HelloRetryRequest round-trip.
var selectedGroup CurveID
var clientKeyShare *keyShare
GroupSelection:
for _, preferredGroup := range c.config.curvePreferences() {
for _, ks := range hs.clientHello.keyShares {
if ks.group == preferredGroup {
selectedGroup = ks.group
clientKeyShare = &ks
break GroupSelection
}
}
if selectedGroup != 0 {
continue
}
for _, group := range hs.clientHello.supportedCurves {
if group == preferredGroup {
selectedGroup = group
break
}
}
}
if selectedGroup == 0 {
c.sendAlert(alertHandshakeFailure)
return errors.New("tls: no ECDHE curve supported by both client and server")
}
if clientKeyShare == nil {
if err := hs.doHelloRetryRequest(selectedGroup); err != nil {
return err
}
clientKeyShare = &hs.clientHello.keyShares[0]
}
if _, ok := curveForCurveID(selectedGroup); !ok {
c.sendAlert(alertInternalError)
return errors.New("tls: CurvePreferences includes unsupported curve")
}
key, err := generateECDHEKey(c.config.rand(), selectedGroup)
if err != nil {
c.sendAlert(alertInternalError)
return err
}
hs.hello.serverShare = keyShare{group: selectedGroup, data: key.PublicKey().Bytes()}
peerKey, err := key.Curve().NewPublicKey(clientKeyShare.data)
if err != nil {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: invalid client key share")
}
hs.sharedKey, err = key.ECDH(peerKey)
if err != nil {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: invalid client key share")
}
if c.quic != nil {
if hs.clientHello.quicTransportParameters == nil {
// RFC 9001 Section 8.2.
c.sendAlert(alertMissingExtension)
return errors.New("tls: client did not send a quic_transport_parameters extension")
}
c.quicSetTransportParameters(hs.clientHello.quicTransportParameters)
} else {
if hs.clientHello.quicTransportParameters != nil {
c.sendAlert(alertUnsupportedExtension)
return errors.New("tls: client sent an unexpected quic_transport_parameters extension")
}
}
c.serverName = hs.clientHello.serverName
selectedProto, err := negotiateALPN(c.config.NextProtos, hs.clientHello.alpnProtocols, c.quic != nil)
if err != nil {
hs.alpnNegotiationErr = err
}
hs.encryptedExtensions.alpnProtocol = selectedProto
c.clientProtocol = selectedProto
return nil
}
func (hs *serverHandshakeStateTLS13) checkForResumption() error {
c := hs.c
if c.config.SessionTicketsDisabled {
return nil
}
modeOK := false
for _, mode := range hs.clientHello.pskModes {
if mode == pskModeDHE {
modeOK = true
break
}
}
if !modeOK {
return nil
}
if len(hs.clientHello.pskIdentities) != len(hs.clientHello.pskBinders) {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: invalid or missing PSK binders")
}
if len(hs.clientHello.pskIdentities) == 0 {
return nil
}
for i, identity := range hs.clientHello.pskIdentities {
if i >= maxClientPSKIdentities {
break
}
plaintext, _ := c.decryptTicket(identity.label)
if plaintext == nil {
continue
}
sessionState := new(sessionStateTLS13)
if ok := sessionState.unmarshal(plaintext); !ok {
continue
}
if hs.clientHello.earlyData {
if sessionState.maxEarlyData == 0 {
c.sendAlert(alertUnsupportedExtension)
return errors.New("tls: client sent unexpected early data")
}
if hs.alpnNegotiationErr == nil && sessionState.alpn == c.clientProtocol &&
c.extraConfig != nil && c.extraConfig.Enable0RTT &&
c.extraConfig.Accept0RTT != nil && c.extraConfig.Accept0RTT(sessionState.appData) {
hs.encryptedExtensions.earlyData = true
}
}
createdAt := time.Unix(int64(sessionState.createdAt), 0)
if c.config.time().Sub(createdAt) > maxSessionTicketLifetime {
continue
}
// We don't check the obfuscated ticket age because it's affected by
// clock skew and it's only a freshness signal useful for shrinking the
// window for replay attacks, which don't affect us as we don't do 0-RTT.
pskSuite := cipherSuiteTLS13ByID(sessionState.cipherSuite)
if pskSuite == nil || pskSuite.hash != hs.suite.hash {
continue
}
// PSK connections don't re-establish client certificates, but carry
// them over in the session ticket. Ensure the presence of client certs
// in the ticket is consistent with the configured requirements.
sessionHasClientCerts := len(sessionState.certificate.Certificate) != 0
needClientCerts := requiresClientCert(c.config.ClientAuth)
if needClientCerts && !sessionHasClientCerts {
continue
}
if sessionHasClientCerts && c.config.ClientAuth == NoClientCert {
continue
}
psk := hs.suite.expandLabel(sessionState.resumptionSecret, "resumption",
nil, hs.suite.hash.Size())
hs.earlySecret = hs.suite.extract(psk, nil)
binderKey := hs.suite.deriveSecret(hs.earlySecret, resumptionBinderLabel, nil)
// Clone the transcript in case a HelloRetryRequest was recorded.
transcript := cloneHash(hs.transcript, hs.suite.hash)
if transcript == nil {
c.sendAlert(alertInternalError)
return errors.New("tls: internal error: failed to clone hash")
}
clientHelloBytes, err := hs.clientHello.marshalWithoutBinders()
if err != nil {
c.sendAlert(alertInternalError)
return err
}
transcript.Write(clientHelloBytes)
pskBinder := hs.suite.finishedHash(binderKey, transcript)
if !hmac.Equal(hs.clientHello.pskBinders[i], pskBinder) {
c.sendAlert(alertDecryptError)
return errors.New("tls: invalid PSK binder")
}
if c.quic != nil && hs.clientHello.earlyData && hs.encryptedExtensions.earlyData && i == 0 &&
sessionState.maxEarlyData > 0 && sessionState.cipherSuite == hs.suite.id {
hs.earlyData = true
transcript := hs.suite.hash.New()
if err := transcriptMsg(hs.clientHello, transcript); err != nil {
return err
}
earlyTrafficSecret := hs.suite.deriveSecret(hs.earlySecret, clientEarlyTrafficLabel, transcript)
c.quicSetReadSecret(QUICEncryptionLevelEarly, hs.suite.id, earlyTrafficSecret)
}
c.didResume = true
if err := c.processCertsFromClient(sessionState.certificate); err != nil {
return err
}
hs.hello.selectedIdentityPresent = true
hs.hello.selectedIdentity = uint16(i)
hs.usingPSK = true
return nil
}
return nil
}
// cloneHash uses the encoding.BinaryMarshaler and encoding.BinaryUnmarshaler
// interfaces implemented by standard library hashes to clone the state of in
// to a new instance of h. It returns nil if the operation fails.
func cloneHash(in hash.Hash, h crypto.Hash) hash.Hash {
// Recreate the interface to avoid importing encoding.
type binaryMarshaler interface {
MarshalBinary() (data []byte, err error)
UnmarshalBinary(data []byte) error
}
marshaler, ok := in.(binaryMarshaler)
if !ok {
return nil
}
state, err := marshaler.MarshalBinary()
if err != nil {
return nil
}
out := h.New()
unmarshaler, ok := out.(binaryMarshaler)
if !ok {
return nil
}
if err := unmarshaler.UnmarshalBinary(state); err != nil {
return nil
}
return out
}
func (hs *serverHandshakeStateTLS13) pickCertificate() error {
c := hs.c
// Only one of PSK and certificates are used at a time.
if hs.usingPSK {
return nil
}
// signature_algorithms is required in TLS 1.3. See RFC 8446, Section 4.2.3.
if len(hs.clientHello.supportedSignatureAlgorithms) == 0 {
return c.sendAlert(alertMissingExtension)
}
certificate, err := c.config.getCertificate(newClientHelloInfo(hs.ctx, c, hs.clientHello))
if err != nil {
if err == errNoCertificates {
c.sendAlert(alertUnrecognizedName)
} else {
c.sendAlert(alertInternalError)
}
return err
}
hs.sigAlg, err = selectSignatureScheme(c.vers, certificate, hs.clientHello.supportedSignatureAlgorithms)
if err != nil {
// getCertificate returned a certificate that is unsupported or
// incompatible with the client's signature algorithms.
c.sendAlert(alertHandshakeFailure)
return err
}
hs.cert = certificate
return nil
}
// sendDummyChangeCipherSpec sends a ChangeCipherSpec record for compatibility
// with middleboxes that didn't implement TLS correctly. See RFC 8446, Appendix D.4.
func (hs *serverHandshakeStateTLS13) sendDummyChangeCipherSpec() error {
if hs.c.quic != nil {
return nil
}
if hs.sentDummyCCS {
return nil
}
hs.sentDummyCCS = true
return hs.c.writeChangeCipherRecord()
}
func (hs *serverHandshakeStateTLS13) doHelloRetryRequest(selectedGroup CurveID) error {
c := hs.c
// The first ClientHello gets double-hashed into the transcript upon a
// HelloRetryRequest. See RFC 8446, Section 4.4.1.
if err := transcriptMsg(hs.clientHello, hs.transcript); err != nil {
return err
}
chHash := hs.transcript.Sum(nil)
hs.transcript.Reset()
hs.transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))})
hs.transcript.Write(chHash)
helloRetryRequest := &serverHelloMsg{
vers: hs.hello.vers,
random: helloRetryRequestRandom,
sessionId: hs.hello.sessionId,
cipherSuite: hs.hello.cipherSuite,
compressionMethod: hs.hello.compressionMethod,
supportedVersion: hs.hello.supportedVersion,
selectedGroup: selectedGroup,
}
if _, err := hs.c.writeHandshakeRecord(helloRetryRequest, hs.transcript); err != nil {
return err
}
if err := hs.sendDummyChangeCipherSpec(); err != nil {
return err
}
// clientHelloMsg is not included in the transcript.
msg, err := c.readHandshake(nil)
if err != nil {
return err
}
clientHello, ok := msg.(*clientHelloMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(clientHello, msg)
}
if len(clientHello.keyShares) != 1 || clientHello.keyShares[0].group != selectedGroup {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: client sent invalid key share in second ClientHello")
}
if clientHello.earlyData {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: client indicated early data in second ClientHello")
}
if illegalClientHelloChange(clientHello, hs.clientHello) {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: client illegally modified second ClientHello")
}
if illegalClientHelloChange(clientHello, hs.clientHello) {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: client illegally modified second ClientHello")
}
hs.clientHello = clientHello
return nil
}
// illegalClientHelloChange reports whether the two ClientHello messages are
// different, with the exception of the changes allowed before and after a
// HelloRetryRequest. See RFC 8446, Section 4.1.2.
func illegalClientHelloChange(ch, ch1 *clientHelloMsg) bool {
if len(ch.supportedVersions) != len(ch1.supportedVersions) ||
len(ch.cipherSuites) != len(ch1.cipherSuites) ||
len(ch.supportedCurves) != len(ch1.supportedCurves) ||
len(ch.supportedSignatureAlgorithms) != len(ch1.supportedSignatureAlgorithms) ||
len(ch.supportedSignatureAlgorithmsCert) != len(ch1.supportedSignatureAlgorithmsCert) ||
len(ch.alpnProtocols) != len(ch1.alpnProtocols) {
return true
}
for i := range ch.supportedVersions {
if ch.supportedVersions[i] != ch1.supportedVersions[i] {
return true
}
}
for i := range ch.cipherSuites {
if ch.cipherSuites[i] != ch1.cipherSuites[i] {
return true
}
}
for i := range ch.supportedCurves {
if ch.supportedCurves[i] != ch1.supportedCurves[i] {
return true
}
}
for i := range ch.supportedSignatureAlgorithms {
if ch.supportedSignatureAlgorithms[i] != ch1.supportedSignatureAlgorithms[i] {
return true
}
}
for i := range ch.supportedSignatureAlgorithmsCert {
if ch.supportedSignatureAlgorithmsCert[i] != ch1.supportedSignatureAlgorithmsCert[i] {
return true
}
}
for i := range ch.alpnProtocols {
if ch.alpnProtocols[i] != ch1.alpnProtocols[i] {
return true
}
}
return ch.vers != ch1.vers ||
!bytes.Equal(ch.random, ch1.random) ||
!bytes.Equal(ch.sessionId, ch1.sessionId) ||
!bytes.Equal(ch.compressionMethods, ch1.compressionMethods) ||
ch.serverName != ch1.serverName ||
ch.ocspStapling != ch1.ocspStapling ||
!bytes.Equal(ch.supportedPoints, ch1.supportedPoints) ||
ch.ticketSupported != ch1.ticketSupported ||
!bytes.Equal(ch.sessionTicket, ch1.sessionTicket) ||
ch.secureRenegotiationSupported != ch1.secureRenegotiationSupported ||
!bytes.Equal(ch.secureRenegotiation, ch1.secureRenegotiation) ||
ch.scts != ch1.scts ||
!bytes.Equal(ch.cookie, ch1.cookie) ||
!bytes.Equal(ch.pskModes, ch1.pskModes)
}
func (hs *serverHandshakeStateTLS13) sendServerParameters() error {
c := hs.c
if err := transcriptMsg(hs.clientHello, hs.transcript); err != nil {
return err
}
if _, err := hs.c.writeHandshakeRecord(hs.hello, hs.transcript); err != nil {
return err
}
if err := hs.sendDummyChangeCipherSpec(); err != nil {
return err
}
earlySecret := hs.earlySecret
if earlySecret == nil {
earlySecret = hs.suite.extract(nil, nil)
}
hs.handshakeSecret = hs.suite.extract(hs.sharedKey,
hs.suite.deriveSecret(earlySecret, "derived", nil))
clientSecret := hs.suite.deriveSecret(hs.handshakeSecret,
clientHandshakeTrafficLabel, hs.transcript)
c.in.setTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, clientSecret)
serverSecret := hs.suite.deriveSecret(hs.handshakeSecret,
serverHandshakeTrafficLabel, hs.transcript)
c.out.setTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, serverSecret)
if c.quic != nil {
if c.hand.Len() != 0 {
c.sendAlert(alertUnexpectedMessage)
}
c.quicSetWriteSecret(QUICEncryptionLevelHandshake, hs.suite.id, serverSecret)
c.quicSetReadSecret(QUICEncryptionLevelHandshake, hs.suite.id, clientSecret)
}
err := c.config.writeKeyLog(keyLogLabelClientHandshake, hs.clientHello.random, clientSecret)
if err != nil {
c.sendAlert(alertInternalError)
return err
}
err = c.config.writeKeyLog(keyLogLabelServerHandshake, hs.clientHello.random, serverSecret)
if err != nil {
c.sendAlert(alertInternalError)
return err
}
selectedProto, err := negotiateALPN(c.config.NextProtos, hs.clientHello.alpnProtocols, c.quic != nil)
if err != nil {
c.sendAlert(alertNoApplicationProtocol)
return err
}
hs.encryptedExtensions.alpnProtocol = selectedProto
c.clientProtocol = selectedProto
if c.quic != nil {
p, err := c.quicGetTransportParameters()
if err != nil {
return err
}
hs.encryptedExtensions.quicTransportParameters = p
}
if _, err := hs.c.writeHandshakeRecord(hs.encryptedExtensions, hs.transcript); err != nil {
return err
}
return nil
}
func (hs *serverHandshakeStateTLS13) requestClientCert() bool {
return hs.c.config.ClientAuth >= RequestClientCert && !hs.usingPSK
}
func (hs *serverHandshakeStateTLS13) sendServerCertificate() error {
c := hs.c
// Only one of PSK and certificates are used at a time.
if hs.usingPSK {
return nil
}
if hs.requestClientCert() {
// Request a client certificate
certReq := new(certificateRequestMsgTLS13)
certReq.ocspStapling = true
certReq.scts = true
certReq.supportedSignatureAlgorithms = supportedSignatureAlgorithms()
if c.config.ClientCAs != nil {
certReq.certificateAuthorities = c.config.ClientCAs.Subjects()
}
if _, err := hs.c.writeHandshakeRecord(certReq, hs.transcript); err != nil {
return err
}
}
certMsg := new(certificateMsgTLS13)
certMsg.certificate = *hs.cert
certMsg.scts = hs.clientHello.scts && len(hs.cert.SignedCertificateTimestamps) > 0
certMsg.ocspStapling = hs.clientHello.ocspStapling && len(hs.cert.OCSPStaple) > 0
if _, err := hs.c.writeHandshakeRecord(certMsg, hs.transcript); err != nil {
return err
}
certVerifyMsg := new(certificateVerifyMsg)
certVerifyMsg.hasSignatureAlgorithm = true
certVerifyMsg.signatureAlgorithm = hs.sigAlg
sigType, sigHash, err := typeAndHashFromSignatureScheme(hs.sigAlg)
if err != nil {
return c.sendAlert(alertInternalError)
}
signed := signedMessage(sigHash, serverSignatureContext, hs.transcript)
signOpts := crypto.SignerOpts(sigHash)
if sigType == signatureRSAPSS {
signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash}
}
sig, err := hs.cert.PrivateKey.(crypto.Signer).Sign(c.config.rand(), signed, signOpts)
if err != nil {
public := hs.cert.PrivateKey.(crypto.Signer).Public()
if rsaKey, ok := public.(*rsa.PublicKey); ok && sigType == signatureRSAPSS &&
rsaKey.N.BitLen()/8 < sigHash.Size()*2+2 { // key too small for RSA-PSS
c.sendAlert(alertHandshakeFailure)
} else {
c.sendAlert(alertInternalError)
}
return errors.New("tls: failed to sign handshake: " + err.Error())
}
certVerifyMsg.signature = sig
if _, err := hs.c.writeHandshakeRecord(certVerifyMsg, hs.transcript); err != nil {
return err
}
return nil
}
func (hs *serverHandshakeStateTLS13) sendServerFinished() error {
c := hs.c
finished := &finishedMsg{
verifyData: hs.suite.finishedHash(c.out.trafficSecret, hs.transcript),
}
if _, err := hs.c.writeHandshakeRecord(finished, hs.transcript); err != nil {
return err
}
// Derive secrets that take context through the server Finished.
hs.masterSecret = hs.suite.extract(nil,
hs.suite.deriveSecret(hs.handshakeSecret, "derived", nil))
hs.trafficSecret = hs.suite.deriveSecret(hs.masterSecret,
clientApplicationTrafficLabel, hs.transcript)
serverSecret := hs.suite.deriveSecret(hs.masterSecret,
serverApplicationTrafficLabel, hs.transcript)
c.out.setTrafficSecret(hs.suite, QUICEncryptionLevelApplication, serverSecret)
if c.quic != nil {
if c.hand.Len() != 0 {
// TODO: Handle this in setTrafficSecret?
c.sendAlert(alertUnexpectedMessage)
}
c.quicSetWriteSecret(QUICEncryptionLevelApplication, hs.suite.id, serverSecret)
}
err := c.config.writeKeyLog(keyLogLabelClientTraffic, hs.clientHello.random, hs.trafficSecret)
if err != nil {
c.sendAlert(alertInternalError)
return err
}
err = c.config.writeKeyLog(keyLogLabelServerTraffic, hs.clientHello.random, serverSecret)
if err != nil {
c.sendAlert(alertInternalError)
return err
}
c.ekm = hs.suite.exportKeyingMaterial(hs.masterSecret, hs.transcript)
// If we did not request client certificates, at this point we can
// precompute the client finished and roll the transcript forward to send
// session tickets in our first flight.
if !hs.requestClientCert() {
if err := hs.sendSessionTickets(); err != nil {
return err
}
}
return nil
}
func (hs *serverHandshakeStateTLS13) shouldSendSessionTickets() bool {
if hs.c.config.SessionTicketsDisabled {
return false
}
// QUIC tickets are sent by QUICConn.SendSessionTicket, not automatically.
if hs.c.quic != nil {
return false
}
// Don't send tickets the client wouldn't use. See RFC 8446, Section 4.2.9.
for _, pskMode := range hs.clientHello.pskModes {
if pskMode == pskModeDHE {
return true
}
}
return false
}
func (hs *serverHandshakeStateTLS13) sendSessionTickets() error {
c := hs.c
hs.clientFinished = hs.suite.finishedHash(c.in.trafficSecret, hs.transcript)
finishedMsg := &finishedMsg{
verifyData: hs.clientFinished,
}
if err := transcriptMsg(finishedMsg, hs.transcript); err != nil {
return err
}
c.resumptionSecret = hs.suite.deriveSecret(hs.masterSecret,
resumptionLabel, hs.transcript)
if !hs.shouldSendSessionTickets() {
return nil
}
return c.sendSessionTicket(false)
}
func (c *Conn) sendSessionTicket(earlyData bool) error {
suite := cipherSuiteTLS13ByID(c.cipherSuite)
if suite == nil {
return errors.New("tls: internal error: unknown cipher suite")
}
m := new(newSessionTicketMsgTLS13)
var certsFromClient [][]byte
for _, cert := range c.peerCertificates {
certsFromClient = append(certsFromClient, cert.Raw)
}
state := sessionStateTLS13{
cipherSuite: suite.id,
createdAt: uint64(c.config.time().Unix()),
resumptionSecret: c.resumptionSecret,
certificate: Certificate{
Certificate: certsFromClient,
OCSPStaple: c.ocspResponse,
SignedCertificateTimestamps: c.scts,
},
alpn: c.clientProtocol,
}
if earlyData {
state.maxEarlyData = 0xffffffff
state.appData = c.extraConfig.GetAppDataForSessionTicket()
}
stateBytes, err := state.marshal()
if err != nil {
c.sendAlert(alertInternalError)
return err
}
m.label, err = c.encryptTicket(stateBytes)
if err != nil {
return err
}
m.lifetime = uint32(maxSessionTicketLifetime / time.Second)
// ticket_age_add is a random 32-bit value. See RFC 8446, section 4.6.1
// The value is not stored anywhere; we never need to check the ticket age
// because 0-RTT is not supported.
ageAdd := make([]byte, 4)
_, err = c.config.rand().Read(ageAdd)
if err != nil {
return err
}
if earlyData {
// RFC 9001, Section 4.6.1
m.maxEarlyData = 0xffffffff
}
if _, err := c.writeHandshakeRecord(m, nil); err != nil {
return err
}
return nil
}
func (hs *serverHandshakeStateTLS13) readClientCertificate() error {
c := hs.c
if !hs.requestClientCert() {
// Make sure the connection is still being verified whether or not
// the server requested a client certificate.
if c.config.VerifyConnection != nil {
if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
c.sendAlert(alertBadCertificate)
return err
}
}
return nil
}
// If we requested a client certificate, then the client must send a
// certificate message. If it's empty, no CertificateVerify is sent.
msg, err := c.readHandshake(hs.transcript)
if err != nil {
return err
}
certMsg, ok := msg.(*certificateMsgTLS13)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(certMsg, msg)
}
if err := c.processCertsFromClient(certMsg.certificate); err != nil {
return err
}
if c.config.VerifyConnection != nil {
if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
c.sendAlert(alertBadCertificate)
return err
}
}
if len(certMsg.certificate.Certificate) != 0 {
// certificateVerifyMsg is included in the transcript, but not until
// after we verify the handshake signature, since the state before
// this message was sent is used.
msg, err = c.readHandshake(nil)
if err != nil {
return err
}
certVerify, ok := msg.(*certificateVerifyMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(certVerify, msg)
}
// See RFC 8446, Section 4.4.3.
if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, supportedSignatureAlgorithms()) {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: client certificate used with invalid signature algorithm")
}
sigType, sigHash, err := typeAndHashFromSignatureScheme(certVerify.signatureAlgorithm)
if err != nil {
return c.sendAlert(alertInternalError)
}
if sigType == signaturePKCS1v15 || sigHash == crypto.SHA1 {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: client certificate used with invalid signature algorithm")
}
signed := signedMessage(sigHash, clientSignatureContext, hs.transcript)
if err := verifyHandshakeSignature(sigType, c.peerCertificates[0].PublicKey,
sigHash, signed, certVerify.signature); err != nil {
c.sendAlert(alertDecryptError)
return errors.New("tls: invalid signature by the client certificate: " + err.Error())
}
if err := transcriptMsg(certVerify, hs.transcript); err != nil {
return err
}
}
// If we waited until the client certificates to send session tickets, we
// are ready to do it now.
if err := hs.sendSessionTickets(); err != nil {
return err
}
return nil
}
func (hs *serverHandshakeStateTLS13) readClientFinished() error {
c := hs.c
// finishedMsg is not included in the transcript.
msg, err := c.readHandshake(nil)
if err != nil {
return err
}
finished, ok := msg.(*finishedMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(finished, msg)
}
if !hmac.Equal(hs.clientFinished, finished.verifyData) {
c.sendAlert(alertDecryptError)
return errors.New("tls: invalid client finished hash")
}
c.in.setTrafficSecret(hs.suite, QUICEncryptionLevelApplication, hs.trafficSecret)
return nil
}

366
vendor/github.com/quic-go/qtls-go1-20/key_agreement.go generated vendored Normal file
View File

@@ -0,0 +1,366 @@
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package qtls
import (
"crypto"
"crypto/ecdh"
"crypto/md5"
"crypto/rsa"
"crypto/sha1"
"crypto/x509"
"errors"
"fmt"
"io"
)
// a keyAgreement implements the client and server side of a TLS key agreement
// protocol by generating and processing key exchange messages.
type keyAgreement interface {
// On the server side, the first two methods are called in order.
// In the case that the key agreement protocol doesn't use a
// ServerKeyExchange message, generateServerKeyExchange can return nil,
// nil.
generateServerKeyExchange(*config, *Certificate, *clientHelloMsg, *serverHelloMsg) (*serverKeyExchangeMsg, error)
processClientKeyExchange(*config, *Certificate, *clientKeyExchangeMsg, uint16) ([]byte, error)
// On the client side, the next two methods are called in order.
// This method may not be called if the server doesn't send a
// ServerKeyExchange message.
processServerKeyExchange(*config, *clientHelloMsg, *serverHelloMsg, *x509.Certificate, *serverKeyExchangeMsg) error
generateClientKeyExchange(*config, *clientHelloMsg, *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error)
}
var errClientKeyExchange = errors.New("tls: invalid ClientKeyExchange message")
var errServerKeyExchange = errors.New("tls: invalid ServerKeyExchange message")
// rsaKeyAgreement implements the standard TLS key agreement where the client
// encrypts the pre-master secret to the server's public key.
type rsaKeyAgreement struct{}
func (ka rsaKeyAgreement) generateServerKeyExchange(config *config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) {
return nil, nil
}
func (ka rsaKeyAgreement) processClientKeyExchange(config *config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) {
if len(ckx.ciphertext) < 2 {
return nil, errClientKeyExchange
}
ciphertextLen := int(ckx.ciphertext[0])<<8 | int(ckx.ciphertext[1])
if ciphertextLen != len(ckx.ciphertext)-2 {
return nil, errClientKeyExchange
}
ciphertext := ckx.ciphertext[2:]
priv, ok := cert.PrivateKey.(crypto.Decrypter)
if !ok {
return nil, errors.New("tls: certificate private key does not implement crypto.Decrypter")
}
// Perform constant time RSA PKCS #1 v1.5 decryption
preMasterSecret, err := priv.Decrypt(config.rand(), ciphertext, &rsa.PKCS1v15DecryptOptions{SessionKeyLen: 48})
if err != nil {
return nil, err
}
// We don't check the version number in the premaster secret. For one,
// by checking it, we would leak information about the validity of the
// encrypted pre-master secret. Secondly, it provides only a small
// benefit against a downgrade attack and some implementations send the
// wrong version anyway. See the discussion at the end of section
// 7.4.7.1 of RFC 4346.
return preMasterSecret, nil
}
func (ka rsaKeyAgreement) processServerKeyExchange(config *config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error {
return errors.New("tls: unexpected ServerKeyExchange")
}
func (ka rsaKeyAgreement) generateClientKeyExchange(config *config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) {
preMasterSecret := make([]byte, 48)
preMasterSecret[0] = byte(clientHello.vers >> 8)
preMasterSecret[1] = byte(clientHello.vers)
_, err := io.ReadFull(config.rand(), preMasterSecret[2:])
if err != nil {
return nil, nil, err
}
rsaKey, ok := cert.PublicKey.(*rsa.PublicKey)
if !ok {
return nil, nil, errors.New("tls: server certificate contains incorrect key type for selected ciphersuite")
}
encrypted, err := rsa.EncryptPKCS1v15(config.rand(), rsaKey, preMasterSecret)
if err != nil {
return nil, nil, err
}
ckx := new(clientKeyExchangeMsg)
ckx.ciphertext = make([]byte, len(encrypted)+2)
ckx.ciphertext[0] = byte(len(encrypted) >> 8)
ckx.ciphertext[1] = byte(len(encrypted))
copy(ckx.ciphertext[2:], encrypted)
return preMasterSecret, ckx, nil
}
// sha1Hash calculates a SHA1 hash over the given byte slices.
func sha1Hash(slices [][]byte) []byte {
hsha1 := sha1.New()
for _, slice := range slices {
hsha1.Write(slice)
}
return hsha1.Sum(nil)
}
// md5SHA1Hash implements TLS 1.0's hybrid hash function which consists of the
// concatenation of an MD5 and SHA1 hash.
func md5SHA1Hash(slices [][]byte) []byte {
md5sha1 := make([]byte, md5.Size+sha1.Size)
hmd5 := md5.New()
for _, slice := range slices {
hmd5.Write(slice)
}
copy(md5sha1, hmd5.Sum(nil))
copy(md5sha1[md5.Size:], sha1Hash(slices))
return md5sha1
}
// hashForServerKeyExchange hashes the given slices and returns their digest
// using the given hash function (for >= TLS 1.2) or using a default based on
// the sigType (for earlier TLS versions). For Ed25519 signatures, which don't
// do pre-hashing, it returns the concatenation of the slices.
func hashForServerKeyExchange(sigType uint8, hashFunc crypto.Hash, version uint16, slices ...[]byte) []byte {
if sigType == signatureEd25519 {
var signed []byte
for _, slice := range slices {
signed = append(signed, slice...)
}
return signed
}
if version >= VersionTLS12 {
h := hashFunc.New()
for _, slice := range slices {
h.Write(slice)
}
digest := h.Sum(nil)
return digest
}
if sigType == signatureECDSA {
return sha1Hash(slices)
}
return md5SHA1Hash(slices)
}
// ecdheKeyAgreement implements a TLS key agreement where the server
// generates an ephemeral EC public/private key pair and signs it. The
// pre-master secret is then calculated using ECDH. The signature may
// be ECDSA, Ed25519 or RSA.
type ecdheKeyAgreement struct {
version uint16
isRSA bool
key *ecdh.PrivateKey
// ckx and preMasterSecret are generated in processServerKeyExchange
// and returned in generateClientKeyExchange.
ckx *clientKeyExchangeMsg
preMasterSecret []byte
}
func (ka *ecdheKeyAgreement) generateServerKeyExchange(config *config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) {
var curveID CurveID
for _, c := range clientHello.supportedCurves {
if config.supportsCurve(c) {
curveID = c
break
}
}
if curveID == 0 {
return nil, errors.New("tls: no supported elliptic curves offered")
}
if _, ok := curveForCurveID(curveID); !ok {
return nil, errors.New("tls: CurvePreferences includes unsupported curve")
}
key, err := generateECDHEKey(config.rand(), curveID)
if err != nil {
return nil, err
}
ka.key = key
// See RFC 4492, Section 5.4.
ecdhePublic := key.PublicKey().Bytes()
serverECDHEParams := make([]byte, 1+2+1+len(ecdhePublic))
serverECDHEParams[0] = 3 // named curve
serverECDHEParams[1] = byte(curveID >> 8)
serverECDHEParams[2] = byte(curveID)
serverECDHEParams[3] = byte(len(ecdhePublic))
copy(serverECDHEParams[4:], ecdhePublic)
priv, ok := cert.PrivateKey.(crypto.Signer)
if !ok {
return nil, fmt.Errorf("tls: certificate private key of type %T does not implement crypto.Signer", cert.PrivateKey)
}
var signatureAlgorithm SignatureScheme
var sigType uint8
var sigHash crypto.Hash
if ka.version >= VersionTLS12 {
signatureAlgorithm, err = selectSignatureScheme(ka.version, cert, clientHello.supportedSignatureAlgorithms)
if err != nil {
return nil, err
}
sigType, sigHash, err = typeAndHashFromSignatureScheme(signatureAlgorithm)
if err != nil {
return nil, err
}
} else {
sigType, sigHash, err = legacyTypeAndHashFromPublicKey(priv.Public())
if err != nil {
return nil, err
}
}
if (sigType == signaturePKCS1v15 || sigType == signatureRSAPSS) != ka.isRSA {
return nil, errors.New("tls: certificate cannot be used with the selected cipher suite")
}
signed := hashForServerKeyExchange(sigType, sigHash, ka.version, clientHello.random, hello.random, serverECDHEParams)
signOpts := crypto.SignerOpts(sigHash)
if sigType == signatureRSAPSS {
signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash}
}
sig, err := priv.Sign(config.rand(), signed, signOpts)
if err != nil {
return nil, errors.New("tls: failed to sign ECDHE parameters: " + err.Error())
}
skx := new(serverKeyExchangeMsg)
sigAndHashLen := 0
if ka.version >= VersionTLS12 {
sigAndHashLen = 2
}
skx.key = make([]byte, len(serverECDHEParams)+sigAndHashLen+2+len(sig))
copy(skx.key, serverECDHEParams)
k := skx.key[len(serverECDHEParams):]
if ka.version >= VersionTLS12 {
k[0] = byte(signatureAlgorithm >> 8)
k[1] = byte(signatureAlgorithm)
k = k[2:]
}
k[0] = byte(len(sig) >> 8)
k[1] = byte(len(sig))
copy(k[2:], sig)
return skx, nil
}
func (ka *ecdheKeyAgreement) processClientKeyExchange(config *config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) {
if len(ckx.ciphertext) == 0 || int(ckx.ciphertext[0]) != len(ckx.ciphertext)-1 {
return nil, errClientKeyExchange
}
peerKey, err := ka.key.Curve().NewPublicKey(ckx.ciphertext[1:])
if err != nil {
return nil, errClientKeyExchange
}
preMasterSecret, err := ka.key.ECDH(peerKey)
if err != nil {
return nil, errClientKeyExchange
}
return preMasterSecret, nil
}
func (ka *ecdheKeyAgreement) processServerKeyExchange(config *config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error {
if len(skx.key) < 4 {
return errServerKeyExchange
}
if skx.key[0] != 3 { // named curve
return errors.New("tls: server selected unsupported curve")
}
curveID := CurveID(skx.key[1])<<8 | CurveID(skx.key[2])
publicLen := int(skx.key[3])
if publicLen+4 > len(skx.key) {
return errServerKeyExchange
}
serverECDHEParams := skx.key[:4+publicLen]
publicKey := serverECDHEParams[4:]
sig := skx.key[4+publicLen:]
if len(sig) < 2 {
return errServerKeyExchange
}
if _, ok := curveForCurveID(curveID); !ok {
return errors.New("tls: server selected unsupported curve")
}
key, err := generateECDHEKey(config.rand(), curveID)
if err != nil {
return err
}
ka.key = key
peerKey, err := key.Curve().NewPublicKey(publicKey)
if err != nil {
return errServerKeyExchange
}
ka.preMasterSecret, err = key.ECDH(peerKey)
if err != nil {
return errServerKeyExchange
}
ourPublicKey := key.PublicKey().Bytes()
ka.ckx = new(clientKeyExchangeMsg)
ka.ckx.ciphertext = make([]byte, 1+len(ourPublicKey))
ka.ckx.ciphertext[0] = byte(len(ourPublicKey))
copy(ka.ckx.ciphertext[1:], ourPublicKey)
var sigType uint8
var sigHash crypto.Hash
if ka.version >= VersionTLS12 {
signatureAlgorithm := SignatureScheme(sig[0])<<8 | SignatureScheme(sig[1])
sig = sig[2:]
if len(sig) < 2 {
return errServerKeyExchange
}
if !isSupportedSignatureAlgorithm(signatureAlgorithm, clientHello.supportedSignatureAlgorithms) {
return errors.New("tls: certificate used with invalid signature algorithm")
}
sigType, sigHash, err = typeAndHashFromSignatureScheme(signatureAlgorithm)
if err != nil {
return err
}
} else {
sigType, sigHash, err = legacyTypeAndHashFromPublicKey(cert.PublicKey)
if err != nil {
return err
}
}
if (sigType == signaturePKCS1v15 || sigType == signatureRSAPSS) != ka.isRSA {
return errServerKeyExchange
}
sigLen := int(sig[0])<<8 | int(sig[1])
if sigLen+2 != len(sig) {
return errServerKeyExchange
}
sig = sig[2:]
signed := hashForServerKeyExchange(sigType, sigHash, ka.version, clientHello.random, serverHello.random, serverECDHEParams)
if err := verifyHandshakeSignature(sigType, cert.PublicKey, sigHash, signed, sig); err != nil {
return errors.New("tls: invalid signature by the server certificate: " + err.Error())
}
return nil
}
func (ka *ecdheKeyAgreement) generateClientKeyExchange(config *config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) {
if ka.ckx == nil {
return nil, nil, errors.New("tls: missing ServerKeyExchange message")
}
return ka.preMasterSecret, ka.ckx, nil
}

159
vendor/github.com/quic-go/qtls-go1-20/key_schedule.go generated vendored Normal file
View File

@@ -0,0 +1,159 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package qtls
import (
"crypto/ecdh"
"crypto/hmac"
"errors"
"fmt"
"hash"
"io"
"golang.org/x/crypto/cryptobyte"
"golang.org/x/crypto/hkdf"
)
// This file contains the functions necessary to compute the TLS 1.3 key
// schedule. See RFC 8446, Section 7.
const (
resumptionBinderLabel = "res binder"
clientEarlyTrafficLabel = "c e traffic"
clientHandshakeTrafficLabel = "c hs traffic"
serverHandshakeTrafficLabel = "s hs traffic"
clientApplicationTrafficLabel = "c ap traffic"
serverApplicationTrafficLabel = "s ap traffic"
exporterLabel = "exp master"
resumptionLabel = "res master"
trafficUpdateLabel = "traffic upd"
)
// expandLabel implements HKDF-Expand-Label from RFC 8446, Section 7.1.
func (c *cipherSuiteTLS13) expandLabel(secret []byte, label string, context []byte, length int) []byte {
var hkdfLabel cryptobyte.Builder
hkdfLabel.AddUint16(uint16(length))
hkdfLabel.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes([]byte("tls13 "))
b.AddBytes([]byte(label))
})
hkdfLabel.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(context)
})
hkdfLabelBytes, err := hkdfLabel.Bytes()
if err != nil {
// Rather than calling BytesOrPanic, we explicitly handle this error, in
// order to provide a reasonable error message. It should be basically
// impossible for this to panic, and routing errors back through the
// tree rooted in this function is quite painful. The labels are fixed
// size, and the context is either a fixed-length computed hash, or
// parsed from a field which has the same length limitation. As such, an
// error here is likely to only be caused during development.
//
// NOTE: another reasonable approach here might be to return a
// randomized slice if we encounter an error, which would break the
// connection, but avoid panicking. This would perhaps be safer but
// significantly more confusing to users.
panic(fmt.Errorf("failed to construct HKDF label: %s", err))
}
out := make([]byte, length)
n, err := hkdf.Expand(c.hash.New, secret, hkdfLabelBytes).Read(out)
if err != nil || n != length {
panic("tls: HKDF-Expand-Label invocation failed unexpectedly")
}
return out
}
// deriveSecret implements Derive-Secret from RFC 8446, Section 7.1.
func (c *cipherSuiteTLS13) deriveSecret(secret []byte, label string, transcript hash.Hash) []byte {
if transcript == nil {
transcript = c.hash.New()
}
return c.expandLabel(secret, label, transcript.Sum(nil), c.hash.Size())
}
// extract implements HKDF-Extract with the cipher suite hash.
func (c *cipherSuiteTLS13) extract(newSecret, currentSecret []byte) []byte {
if newSecret == nil {
newSecret = make([]byte, c.hash.Size())
}
return hkdf.Extract(c.hash.New, newSecret, currentSecret)
}
// nextTrafficSecret generates the next traffic secret, given the current one,
// according to RFC 8446, Section 7.2.
func (c *cipherSuiteTLS13) nextTrafficSecret(trafficSecret []byte) []byte {
return c.expandLabel(trafficSecret, trafficUpdateLabel, nil, c.hash.Size())
}
// trafficKey generates traffic keys according to RFC 8446, Section 7.3.
func (c *cipherSuiteTLS13) trafficKey(trafficSecret []byte) (key, iv []byte) {
key = c.expandLabel(trafficSecret, "key", nil, c.keyLen)
iv = c.expandLabel(trafficSecret, "iv", nil, aeadNonceLength)
return
}
// finishedHash generates the Finished verify_data or PskBinderEntry according
// to RFC 8446, Section 4.4.4. See sections 4.4 and 4.2.11.2 for the baseKey
// selection.
func (c *cipherSuiteTLS13) finishedHash(baseKey []byte, transcript hash.Hash) []byte {
finishedKey := c.expandLabel(baseKey, "finished", nil, c.hash.Size())
verifyData := hmac.New(c.hash.New, finishedKey)
verifyData.Write(transcript.Sum(nil))
return verifyData.Sum(nil)
}
// exportKeyingMaterial implements RFC5705 exporters for TLS 1.3 according to
// RFC 8446, Section 7.5.
func (c *cipherSuiteTLS13) exportKeyingMaterial(masterSecret []byte, transcript hash.Hash) func(string, []byte, int) ([]byte, error) {
expMasterSecret := c.deriveSecret(masterSecret, exporterLabel, transcript)
return func(label string, context []byte, length int) ([]byte, error) {
secret := c.deriveSecret(expMasterSecret, label, nil)
h := c.hash.New()
h.Write(context)
return c.expandLabel(secret, "exporter", h.Sum(nil), length), nil
}
}
// generateECDHEKey returns a PrivateKey that implements Diffie-Hellman
// according to RFC 8446, Section 4.2.8.2.
func generateECDHEKey(rand io.Reader, curveID CurveID) (*ecdh.PrivateKey, error) {
curve, ok := curveForCurveID(curveID)
if !ok {
return nil, errors.New("tls: internal error: unsupported curve")
}
return curve.GenerateKey(rand)
}
func curveForCurveID(id CurveID) (ecdh.Curve, bool) {
switch id {
case X25519:
return ecdh.X25519(), true
case CurveP256:
return ecdh.P256(), true
case CurveP384:
return ecdh.P384(), true
case CurveP521:
return ecdh.P521(), true
default:
return nil, false
}
}
func curveIDForCurve(curve ecdh.Curve) (CurveID, bool) {
switch curve {
case ecdh.X25519():
return X25519, true
case ecdh.P256():
return CurveP256, true
case ecdh.P384():
return CurveP384, true
case ecdh.P521():
return CurveP521, true
default:
return 0, false
}
}

18
vendor/github.com/quic-go/qtls-go1-20/notboring.go generated vendored Normal file
View File

@@ -0,0 +1,18 @@
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package qtls
func needFIPS() bool { return false }
func supportedSignatureAlgorithms() []SignatureScheme {
return defaultSupportedSignatureAlgorithms
}
func fipsMinVersion(c *config) uint16 { panic("fipsMinVersion") }
func fipsMaxVersion(c *config) uint16 { panic("fipsMaxVersion") }
func fipsCurvePreferences(c *config) []CurveID { panic("fipsCurvePreferences") }
func fipsCipherSuites(c *config) []uint16 { panic("fipsCipherSuites") }
var fipsSupportedSignatureAlgorithms []SignatureScheme

283
vendor/github.com/quic-go/qtls-go1-20/prf.go generated vendored Normal file
View File

@@ -0,0 +1,283 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package qtls
import (
"crypto"
"crypto/hmac"
"crypto/md5"
"crypto/sha1"
"crypto/sha256"
"crypto/sha512"
"errors"
"fmt"
"hash"
)
// Split a premaster secret in two as specified in RFC 4346, Section 5.
func splitPreMasterSecret(secret []byte) (s1, s2 []byte) {
s1 = secret[0 : (len(secret)+1)/2]
s2 = secret[len(secret)/2:]
return
}
// pHash implements the P_hash function, as defined in RFC 4346, Section 5.
func pHash(result, secret, seed []byte, hash func() hash.Hash) {
h := hmac.New(hash, secret)
h.Write(seed)
a := h.Sum(nil)
j := 0
for j < len(result) {
h.Reset()
h.Write(a)
h.Write(seed)
b := h.Sum(nil)
copy(result[j:], b)
j += len(b)
h.Reset()
h.Write(a)
a = h.Sum(nil)
}
}
// prf10 implements the TLS 1.0 pseudo-random function, as defined in RFC 2246, Section 5.
func prf10(result, secret, label, seed []byte) {
hashSHA1 := sha1.New
hashMD5 := md5.New
labelAndSeed := make([]byte, len(label)+len(seed))
copy(labelAndSeed, label)
copy(labelAndSeed[len(label):], seed)
s1, s2 := splitPreMasterSecret(secret)
pHash(result, s1, labelAndSeed, hashMD5)
result2 := make([]byte, len(result))
pHash(result2, s2, labelAndSeed, hashSHA1)
for i, b := range result2 {
result[i] ^= b
}
}
// prf12 implements the TLS 1.2 pseudo-random function, as defined in RFC 5246, Section 5.
func prf12(hashFunc func() hash.Hash) func(result, secret, label, seed []byte) {
return func(result, secret, label, seed []byte) {
labelAndSeed := make([]byte, len(label)+len(seed))
copy(labelAndSeed, label)
copy(labelAndSeed[len(label):], seed)
pHash(result, secret, labelAndSeed, hashFunc)
}
}
const (
masterSecretLength = 48 // Length of a master secret in TLS 1.1.
finishedVerifyLength = 12 // Length of verify_data in a Finished message.
)
var masterSecretLabel = []byte("master secret")
var keyExpansionLabel = []byte("key expansion")
var clientFinishedLabel = []byte("client finished")
var serverFinishedLabel = []byte("server finished")
func prfAndHashForVersion(version uint16, suite *cipherSuite) (func(result, secret, label, seed []byte), crypto.Hash) {
switch version {
case VersionTLS10, VersionTLS11:
return prf10, crypto.Hash(0)
case VersionTLS12:
if suite.flags&suiteSHA384 != 0 {
return prf12(sha512.New384), crypto.SHA384
}
return prf12(sha256.New), crypto.SHA256
default:
panic("unknown version")
}
}
func prfForVersion(version uint16, suite *cipherSuite) func(result, secret, label, seed []byte) {
prf, _ := prfAndHashForVersion(version, suite)
return prf
}
// masterFromPreMasterSecret generates the master secret from the pre-master
// secret. See RFC 5246, Section 8.1.
func masterFromPreMasterSecret(version uint16, suite *cipherSuite, preMasterSecret, clientRandom, serverRandom []byte) []byte {
seed := make([]byte, 0, len(clientRandom)+len(serverRandom))
seed = append(seed, clientRandom...)
seed = append(seed, serverRandom...)
masterSecret := make([]byte, masterSecretLength)
prfForVersion(version, suite)(masterSecret, preMasterSecret, masterSecretLabel, seed)
return masterSecret
}
// keysFromMasterSecret generates the connection keys from the master
// secret, given the lengths of the MAC key, cipher key and IV, as defined in
// RFC 2246, Section 6.3.
func keysFromMasterSecret(version uint16, suite *cipherSuite, masterSecret, clientRandom, serverRandom []byte, macLen, keyLen, ivLen int) (clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV []byte) {
seed := make([]byte, 0, len(serverRandom)+len(clientRandom))
seed = append(seed, serverRandom...)
seed = append(seed, clientRandom...)
n := 2*macLen + 2*keyLen + 2*ivLen
keyMaterial := make([]byte, n)
prfForVersion(version, suite)(keyMaterial, masterSecret, keyExpansionLabel, seed)
clientMAC = keyMaterial[:macLen]
keyMaterial = keyMaterial[macLen:]
serverMAC = keyMaterial[:macLen]
keyMaterial = keyMaterial[macLen:]
clientKey = keyMaterial[:keyLen]
keyMaterial = keyMaterial[keyLen:]
serverKey = keyMaterial[:keyLen]
keyMaterial = keyMaterial[keyLen:]
clientIV = keyMaterial[:ivLen]
keyMaterial = keyMaterial[ivLen:]
serverIV = keyMaterial[:ivLen]
return
}
func newFinishedHash(version uint16, cipherSuite *cipherSuite) finishedHash {
var buffer []byte
if version >= VersionTLS12 {
buffer = []byte{}
}
prf, hash := prfAndHashForVersion(version, cipherSuite)
if hash != 0 {
return finishedHash{hash.New(), hash.New(), nil, nil, buffer, version, prf}
}
return finishedHash{sha1.New(), sha1.New(), md5.New(), md5.New(), buffer, version, prf}
}
// A finishedHash calculates the hash of a set of handshake messages suitable
// for including in a Finished message.
type finishedHash struct {
client hash.Hash
server hash.Hash
// Prior to TLS 1.2, an additional MD5 hash is required.
clientMD5 hash.Hash
serverMD5 hash.Hash
// In TLS 1.2, a full buffer is sadly required.
buffer []byte
version uint16
prf func(result, secret, label, seed []byte)
}
func (h *finishedHash) Write(msg []byte) (n int, err error) {
h.client.Write(msg)
h.server.Write(msg)
if h.version < VersionTLS12 {
h.clientMD5.Write(msg)
h.serverMD5.Write(msg)
}
if h.buffer != nil {
h.buffer = append(h.buffer, msg...)
}
return len(msg), nil
}
func (h finishedHash) Sum() []byte {
if h.version >= VersionTLS12 {
return h.client.Sum(nil)
}
out := make([]byte, 0, md5.Size+sha1.Size)
out = h.clientMD5.Sum(out)
return h.client.Sum(out)
}
// clientSum returns the contents of the verify_data member of a client's
// Finished message.
func (h finishedHash) clientSum(masterSecret []byte) []byte {
out := make([]byte, finishedVerifyLength)
h.prf(out, masterSecret, clientFinishedLabel, h.Sum())
return out
}
// serverSum returns the contents of the verify_data member of a server's
// Finished message.
func (h finishedHash) serverSum(masterSecret []byte) []byte {
out := make([]byte, finishedVerifyLength)
h.prf(out, masterSecret, serverFinishedLabel, h.Sum())
return out
}
// hashForClientCertificate returns the handshake messages so far, pre-hashed if
// necessary, suitable for signing by a TLS client certificate.
func (h finishedHash) hashForClientCertificate(sigType uint8, hashAlg crypto.Hash) []byte {
if (h.version >= VersionTLS12 || sigType == signatureEd25519) && h.buffer == nil {
panic("tls: handshake hash for a client certificate requested after discarding the handshake buffer")
}
if sigType == signatureEd25519 {
return h.buffer
}
if h.version >= VersionTLS12 {
hash := hashAlg.New()
hash.Write(h.buffer)
return hash.Sum(nil)
}
if sigType == signatureECDSA {
return h.server.Sum(nil)
}
return h.Sum()
}
// discardHandshakeBuffer is called when there is no more need to
// buffer the entirety of the handshake messages.
func (h *finishedHash) discardHandshakeBuffer() {
h.buffer = nil
}
// noExportedKeyingMaterial is used as a value of
// ConnectionState.ekm when renegotiation is enabled and thus
// we wish to fail all key-material export requests.
func noExportedKeyingMaterial(label string, context []byte, length int) ([]byte, error) {
return nil, errors.New("crypto/tls: ExportKeyingMaterial is unavailable when renegotiation is enabled")
}
// ekmFromMasterSecret generates exported keying material as defined in RFC 5705.
func ekmFromMasterSecret(version uint16, suite *cipherSuite, masterSecret, clientRandom, serverRandom []byte) func(string, []byte, int) ([]byte, error) {
return func(label string, context []byte, length int) ([]byte, error) {
switch label {
case "client finished", "server finished", "master secret", "key expansion":
// These values are reserved and may not be used.
return nil, fmt.Errorf("crypto/tls: reserved ExportKeyingMaterial label: %s", label)
}
seedLen := len(serverRandom) + len(clientRandom)
if context != nil {
seedLen += 2 + len(context)
}
seed := make([]byte, 0, seedLen)
seed = append(seed, clientRandom...)
seed = append(seed, serverRandom...)
if context != nil {
if len(context) >= 1<<16 {
return nil, fmt.Errorf("crypto/tls: ExportKeyingMaterial context too long")
}
seed = append(seed, byte(len(context)>>8), byte(len(context)))
seed = append(seed, context...)
}
keyMaterial := make([]byte, length)
prfForVersion(version, suite)(keyMaterial, masterSecret, []byte(label), seed)
return keyMaterial, nil
}
}

418
vendor/github.com/quic-go/qtls-go1-20/quic.go generated vendored Normal file
View File

@@ -0,0 +1,418 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package qtls
import (
"context"
"errors"
"fmt"
)
// QUICEncryptionLevel represents a QUIC encryption level used to transmit
// handshake messages.
type QUICEncryptionLevel int
const (
QUICEncryptionLevelInitial = QUICEncryptionLevel(iota)
QUICEncryptionLevelEarly
QUICEncryptionLevelHandshake
QUICEncryptionLevelApplication
)
func (l QUICEncryptionLevel) String() string {
switch l {
case QUICEncryptionLevelInitial:
return "Initial"
case QUICEncryptionLevelEarly:
return "Early"
case QUICEncryptionLevelHandshake:
return "Handshake"
case QUICEncryptionLevelApplication:
return "Application"
default:
return fmt.Sprintf("QUICEncryptionLevel(%v)", int(l))
}
}
// A QUICConn represents a connection which uses a QUIC implementation as the underlying
// transport as described in RFC 9001.
//
// Methods of QUICConn are not safe for concurrent use.
type QUICConn struct {
conn *Conn
sessionTicketSent bool
}
// A QUICConfig configures a QUICConn.
type QUICConfig struct {
TLSConfig *Config
ExtraConfig *ExtraConfig
}
// A QUICEventKind is a type of operation on a QUIC connection.
type QUICEventKind int
const (
// QUICNoEvent indicates that there are no events available.
QUICNoEvent QUICEventKind = iota
// QUICSetReadSecret and QUICSetWriteSecret provide the read and write
// secrets for a given encryption level.
// QUICEvent.Level, QUICEvent.Data, and QUICEvent.Suite are set.
//
// Secrets for the Initial encryption level are derived from the initial
// destination connection ID, and are not provided by the QUICConn.
QUICSetReadSecret
QUICSetWriteSecret
// QUICWriteData provides data to send to the peer in CRYPTO frames.
// QUICEvent.Data is set.
QUICWriteData
// QUICTransportParameters provides the peer's QUIC transport parameters.
// QUICEvent.Data is set.
QUICTransportParameters
// QUICTransportParametersRequired indicates that the caller must provide
// QUIC transport parameters to send to the peer. The caller should set
// the transport parameters with QUICConn.SetTransportParameters and call
// QUICConn.NextEvent again.
//
// If transport parameters are set before calling QUICConn.Start, the
// connection will never generate a QUICTransportParametersRequired event.
QUICTransportParametersRequired
// QUICRejectedEarlyData indicates that the server rejected 0-RTT data even
// if we offered it. It's returned before QUICEncryptionLevelApplication
// keys are returned.
QUICRejectedEarlyData
// QUICHandshakeDone indicates that the TLS handshake has completed.
QUICHandshakeDone
)
// A QUICEvent is an event occurring on a QUIC connection.
//
// The type of event is specified by the Kind field.
// The contents of the other fields are kind-specific.
type QUICEvent struct {
Kind QUICEventKind
// Set for QUICSetReadSecret, QUICSetWriteSecret, and QUICWriteData.
Level QUICEncryptionLevel
// Set for QUICTransportParameters, QUICSetReadSecret, QUICSetWriteSecret, and QUICWriteData.
// The contents are owned by crypto/tls, and are valid until the next NextEvent call.
Data []byte
// Set for QUICSetReadSecret and QUICSetWriteSecret.
Suite uint16
}
type quicState struct {
events []QUICEvent
nextEvent int
// eventArr is a statically allocated event array, large enough to handle
// the usual maximum number of events resulting from a single call: transport
// parameters, Initial data, Early read secret, Handshake write and read
// secrets, Handshake data, Application write secret, Application data.
eventArr [8]QUICEvent
started bool
signalc chan struct{} // handshake data is available to be read
blockedc chan struct{} // handshake is waiting for data, closed when done
cancelc <-chan struct{} // handshake has been canceled
cancel context.CancelFunc
// readbuf is shared between HandleData and the handshake goroutine.
// HandshakeCryptoData passes ownership to the handshake goroutine by
// reading from signalc, and reclaims ownership by reading from blockedc.
readbuf []byte
transportParams []byte // to send to the peer
}
// QUICClient returns a new TLS client side connection using QUICTransport as the
// underlying transport. The config cannot be nil.
//
// The config's MinVersion must be at least TLS 1.3.
func QUICClient(config *QUICConfig) *QUICConn {
return newQUICConn(Client(nil, config.TLSConfig), config.ExtraConfig)
}
// QUICServer returns a new TLS server side connection using QUICTransport as the
// underlying transport. The config cannot be nil.
//
// The config's MinVersion must be at least TLS 1.3.
func QUICServer(config *QUICConfig) *QUICConn {
return newQUICConn(Server(nil, config.TLSConfig), config.ExtraConfig)
}
func newQUICConn(conn *Conn, extraConfig *ExtraConfig) *QUICConn {
conn.quic = &quicState{
signalc: make(chan struct{}),
blockedc: make(chan struct{}),
}
conn.quic.events = conn.quic.eventArr[:0]
conn.extraConfig = extraConfig
return &QUICConn{
conn: conn,
}
}
// Start starts the client or server handshake protocol.
// It may produce connection events, which may be read with NextEvent.
//
// Start must be called at most once.
func (q *QUICConn) Start(ctx context.Context) error {
if q.conn.quic.started {
return quicError(errors.New("tls: Start called more than once"))
}
q.conn.quic.started = true
if q.conn.config.MinVersion < VersionTLS13 {
return quicError(errors.New("tls: Config MinVersion must be at least TLS 1.13"))
}
go q.conn.HandshakeContext(ctx)
if _, ok := <-q.conn.quic.blockedc; !ok {
return q.conn.handshakeErr
}
return nil
}
// NextEvent returns the next event occurring on the connection.
// It returns an event with a Kind of QUICNoEvent when no events are available.
func (q *QUICConn) NextEvent() QUICEvent {
qs := q.conn.quic
if last := qs.nextEvent - 1; last >= 0 && len(qs.events[last].Data) > 0 {
// Write over some of the previous event's data,
// to catch callers erroniously retaining it.
qs.events[last].Data[0] = 0
}
if qs.nextEvent >= len(qs.events) {
qs.events = qs.events[:0]
qs.nextEvent = 0
return QUICEvent{Kind: QUICNoEvent}
}
e := qs.events[qs.nextEvent]
qs.events[qs.nextEvent] = QUICEvent{} // zero out references to data
qs.nextEvent++
return e
}
// Close closes the connection and stops any in-progress handshake.
func (q *QUICConn) Close() error {
if q.conn.quic.cancel == nil {
return nil // never started
}
q.conn.quic.cancel()
for range q.conn.quic.blockedc {
// Wait for the handshake goroutine to return.
}
return q.conn.handshakeErr
}
// HandleData handles handshake bytes received from the peer.
// It may produce connection events, which may be read with NextEvent.
func (q *QUICConn) HandleData(level QUICEncryptionLevel, data []byte) error {
c := q.conn
if c.in.level != level {
return quicError(c.in.setErrorLocked(errors.New("tls: handshake data received at wrong level")))
}
c.quic.readbuf = data
<-c.quic.signalc
_, ok := <-c.quic.blockedc
if ok {
// The handshake goroutine is waiting for more data.
return nil
}
// The handshake goroutine has exited.
c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock()
c.hand.Write(c.quic.readbuf)
c.quic.readbuf = nil
for q.conn.hand.Len() >= 4 && q.conn.handshakeErr == nil {
b := q.conn.hand.Bytes()
n := int(b[1])<<16 | int(b[2])<<8 | int(b[3])
if n > maxHandshake {
q.conn.handshakeErr = fmt.Errorf("tls: handshake message of length %d bytes exceeds maximum of %d bytes", n, maxHandshake)
break
}
if len(b) < 4+n {
return nil
}
if err := q.conn.handlePostHandshakeMessage(); err != nil {
q.conn.handshakeErr = err
}
}
if q.conn.handshakeErr != nil {
return quicError(q.conn.handshakeErr)
}
return nil
}
// SendSessionTicket sends a session ticket to the client.
// It produces connection events, which may be read with NextEvent.
// Currently, it can only be called once.
func (q *QUICConn) SendSessionTicket(earlyData bool) error {
c := q.conn
if !c.isHandshakeComplete.Load() {
return quicError(errors.New("tls: SendSessionTicket called before handshake completed"))
}
if c.isClient {
return quicError(errors.New("tls: SendSessionTicket called on the client"))
}
if q.sessionTicketSent {
return quicError(errors.New("tls: SendSessionTicket called multiple times"))
}
q.sessionTicketSent = true
return quicError(c.sendSessionTicket(earlyData))
}
// ConnectionState returns basic TLS details about the connection.
func (q *QUICConn) ConnectionState() ConnectionState {
return q.conn.ConnectionState()
}
// SetTransportParameters sets the transport parameters to send to the peer.
//
// Server connections may delay setting the transport parameters until after
// receiving the client's transport parameters. See QUICTransportParametersRequired.
func (q *QUICConn) SetTransportParameters(params []byte) {
if params == nil {
params = []byte{}
}
q.conn.quic.transportParams = params
if q.conn.quic.started {
<-q.conn.quic.signalc
<-q.conn.quic.blockedc
}
}
// quicError ensures err is an AlertError.
// If err is not already, quicError wraps it with alertInternalError.
func quicError(err error) error {
if err == nil {
return nil
}
var ae AlertError
if errors.As(err, &ae) {
return err
}
var a alert
if !errors.As(err, &a) {
a = alertInternalError
}
// Return an error wrapping the original error and an AlertError.
// Truncate the text of the alert to 0 characters.
return fmt.Errorf("%w%.0w", err, AlertError(a))
}
func (c *Conn) quicReadHandshakeBytes(n int) error {
for c.hand.Len() < n {
if err := c.quicWaitForSignal(); err != nil {
return err
}
}
return nil
}
func (c *Conn) quicSetReadSecret(level QUICEncryptionLevel, suite uint16, secret []byte) {
c.quic.events = append(c.quic.events, QUICEvent{
Kind: QUICSetReadSecret,
Level: level,
Suite: suite,
Data: secret,
})
}
func (c *Conn) quicSetWriteSecret(level QUICEncryptionLevel, suite uint16, secret []byte) {
c.quic.events = append(c.quic.events, QUICEvent{
Kind: QUICSetWriteSecret,
Level: level,
Suite: suite,
Data: secret,
})
}
func (c *Conn) quicWriteCryptoData(level QUICEncryptionLevel, data []byte) {
var last *QUICEvent
if len(c.quic.events) > 0 {
last = &c.quic.events[len(c.quic.events)-1]
}
if last == nil || last.Kind != QUICWriteData || last.Level != level {
c.quic.events = append(c.quic.events, QUICEvent{
Kind: QUICWriteData,
Level: level,
})
last = &c.quic.events[len(c.quic.events)-1]
}
last.Data = append(last.Data, data...)
}
func (c *Conn) quicSetTransportParameters(params []byte) {
c.quic.events = append(c.quic.events, QUICEvent{
Kind: QUICTransportParameters,
Data: params,
})
}
func (c *Conn) quicGetTransportParameters() ([]byte, error) {
if c.quic.transportParams == nil {
c.quic.events = append(c.quic.events, QUICEvent{
Kind: QUICTransportParametersRequired,
})
}
for c.quic.transportParams == nil {
if err := c.quicWaitForSignal(); err != nil {
return nil, err
}
}
return c.quic.transportParams, nil
}
func (c *Conn) quicHandshakeComplete() {
c.quic.events = append(c.quic.events, QUICEvent{
Kind: QUICHandshakeDone,
})
}
func (c *Conn) quicRejectedEarlyData() {
c.quic.events = append(c.quic.events, QUICEvent{
Kind: QUICRejectedEarlyData,
})
}
// quicWaitForSignal notifies the QUICConn that handshake progress is blocked,
// and waits for a signal that the handshake should proceed.
//
// The handshake may become blocked waiting for handshake bytes
// or for the user to provide transport parameters.
func (c *Conn) quicWaitForSignal() error {
// Drop the handshake mutex while blocked to allow the user
// to call ConnectionState before the handshake completes.
c.handshakeMutex.Unlock()
defer c.handshakeMutex.Lock()
// Send on blockedc to notify the QUICConn that the handshake is blocked.
// Exported methods of QUICConn wait for the handshake to become blocked
// before returning to the user.
select {
case c.quic.blockedc <- struct{}{}:
case <-c.quic.cancelc:
return c.sendAlertLocked(alertCloseNotify)
}
// The QUICConn reads from signalc to notify us that the handshake may
// be able to proceed. (The QUICConn reads, because we close signalc to
// indicate that the handshake has completed.)
select {
case c.quic.signalc <- struct{}{}:
c.hand.Write(c.quic.readbuf)
c.quic.readbuf = nil
case <-c.quic.cancelc:
return c.sendAlertLocked(alertCloseNotify)
}
return nil
}

203
vendor/github.com/quic-go/qtls-go1-20/ticket.go generated vendored Normal file
View File

@@ -0,0 +1,203 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package qtls
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/hmac"
"crypto/sha256"
"crypto/subtle"
"errors"
"golang.org/x/crypto/cryptobyte"
"io"
)
// sessionState contains the information that is serialized into a session
// ticket in order to later resume a connection.
type sessionState struct {
vers uint16
cipherSuite uint16
createdAt uint64
masterSecret []byte // opaque master_secret<1..2^16-1>;
// struct { opaque certificate<1..2^24-1> } Certificate;
certificates [][]byte // Certificate certificate_list<0..2^24-1>;
// usedOldKey is true if the ticket from which this session came from
// was encrypted with an older key and thus should be refreshed.
usedOldKey bool
}
func (m *sessionState) marshal() ([]byte, error) {
var b cryptobyte.Builder
b.AddUint16(m.vers)
b.AddUint16(m.cipherSuite)
addUint64(&b, m.createdAt)
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(m.masterSecret)
})
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
for _, cert := range m.certificates {
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(cert)
})
}
})
return b.Bytes()
}
func (m *sessionState) unmarshal(data []byte) bool {
*m = sessionState{usedOldKey: m.usedOldKey}
s := cryptobyte.String(data)
if ok := s.ReadUint16(&m.vers) &&
s.ReadUint16(&m.cipherSuite) &&
readUint64(&s, &m.createdAt) &&
readUint16LengthPrefixed(&s, &m.masterSecret) &&
len(m.masterSecret) != 0; !ok {
return false
}
var certList cryptobyte.String
if !s.ReadUint24LengthPrefixed(&certList) {
return false
}
for !certList.Empty() {
var cert []byte
if !readUint24LengthPrefixed(&certList, &cert) {
return false
}
m.certificates = append(m.certificates, cert)
}
return s.Empty()
}
// sessionStateTLS13 is the content of a TLS 1.3 session ticket. Its first
// version (revision = 0) doesn't carry any of the information needed for 0-RTT
// validation and the nonce is always empty.
// version (revision = 1) carries the max_early_data_size sent in the ticket.
// version (revision = 2) carries the ALPN sent in the ticket.
type sessionStateTLS13 struct {
// uint8 version = 0x0304;
// uint8 revision = 2;
cipherSuite uint16
createdAt uint64
resumptionSecret []byte // opaque resumption_master_secret<1..2^8-1>;
certificate Certificate // CertificateEntry certificate_list<0..2^24-1>;
maxEarlyData uint32
alpn string
appData []byte
}
func (m *sessionStateTLS13) marshal() ([]byte, error) {
var b cryptobyte.Builder
b.AddUint16(VersionTLS13)
b.AddUint8(2) // revision
b.AddUint16(m.cipherSuite)
addUint64(&b, m.createdAt)
b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(m.resumptionSecret)
})
marshalCertificate(&b, m.certificate)
b.AddUint32(m.maxEarlyData)
b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes([]byte(m.alpn))
})
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(m.appData)
})
return b.Bytes()
}
func (m *sessionStateTLS13) unmarshal(data []byte) bool {
*m = sessionStateTLS13{}
s := cryptobyte.String(data)
var version uint16
var revision uint8
var alpn []byte
ret := s.ReadUint16(&version) &&
version == VersionTLS13 &&
s.ReadUint8(&revision) &&
revision == 2 &&
s.ReadUint16(&m.cipherSuite) &&
readUint64(&s, &m.createdAt) &&
readUint8LengthPrefixed(&s, &m.resumptionSecret) &&
len(m.resumptionSecret) != 0 &&
unmarshalCertificate(&s, &m.certificate) &&
s.ReadUint32(&m.maxEarlyData) &&
readUint8LengthPrefixed(&s, &alpn) &&
readUint16LengthPrefixed(&s, &m.appData) &&
s.Empty()
m.alpn = string(alpn)
return ret
}
func (c *Conn) encryptTicket(state []byte) ([]byte, error) {
if len(c.ticketKeys) == 0 {
return nil, errors.New("tls: internal error: session ticket keys unavailable")
}
encrypted := make([]byte, ticketKeyNameLen+aes.BlockSize+len(state)+sha256.Size)
keyName := encrypted[:ticketKeyNameLen]
iv := encrypted[ticketKeyNameLen : ticketKeyNameLen+aes.BlockSize]
macBytes := encrypted[len(encrypted)-sha256.Size:]
if _, err := io.ReadFull(c.config.rand(), iv); err != nil {
return nil, err
}
key := c.ticketKeys[0]
copy(keyName, key.keyName[:])
block, err := aes.NewCipher(key.aesKey[:])
if err != nil {
return nil, errors.New("tls: failed to create cipher while encrypting ticket: " + err.Error())
}
cipher.NewCTR(block, iv).XORKeyStream(encrypted[ticketKeyNameLen+aes.BlockSize:], state)
mac := hmac.New(sha256.New, key.hmacKey[:])
mac.Write(encrypted[:len(encrypted)-sha256.Size])
mac.Sum(macBytes[:0])
return encrypted, nil
}
func (c *Conn) decryptTicket(encrypted []byte) (plaintext []byte, usedOldKey bool) {
if len(encrypted) < ticketKeyNameLen+aes.BlockSize+sha256.Size {
return nil, false
}
keyName := encrypted[:ticketKeyNameLen]
iv := encrypted[ticketKeyNameLen : ticketKeyNameLen+aes.BlockSize]
macBytes := encrypted[len(encrypted)-sha256.Size:]
ciphertext := encrypted[ticketKeyNameLen+aes.BlockSize : len(encrypted)-sha256.Size]
keyIndex := -1
for i, candidateKey := range c.ticketKeys {
if bytes.Equal(keyName, candidateKey.keyName[:]) {
keyIndex = i
break
}
}
if keyIndex == -1 {
return nil, false
}
key := &c.ticketKeys[keyIndex]
mac := hmac.New(sha256.New, key.hmacKey[:])
mac.Write(encrypted[:len(encrypted)-sha256.Size])
expected := mac.Sum(nil)
if subtle.ConstantTimeCompare(macBytes, expected) != 1 {
return nil, false
}
block, err := aes.NewCipher(key.aesKey[:])
if err != nil {
return nil, false
}
plaintext = make([]byte, len(ciphertext))
cipher.NewCTR(block, iv).XORKeyStream(plaintext, ciphertext)
return plaintext, keyIndex > 0
}

356
vendor/github.com/quic-go/qtls-go1-20/tls.go generated vendored Normal file
View File

@@ -0,0 +1,356 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// package qtls partially implements TLS 1.2, as specified in RFC 5246,
// and TLS 1.3, as specified in RFC 8446.
package qtls
// BUG(agl): The crypto/tls package only implements some countermeasures
// against Lucky13 attacks on CBC-mode encryption, and only on SHA1
// variants. See http://www.isg.rhul.ac.uk/tls/TLStiming.pdf and
// https://www.imperialviolet.org/2013/02/04/luckythirteen.html.
import (
"bytes"
"context"
"crypto"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"net"
"os"
"strings"
)
// Server returns a new TLS server side connection
// using conn as the underlying transport.
// The configuration config must be non-nil and must include
// at least one certificate or else set GetCertificate.
func Server(conn net.Conn, config *Config) *Conn {
c := &Conn{
conn: conn,
config: fromConfig(config),
}
c.handshakeFn = c.serverHandshake
return c
}
// Client returns a new TLS client side connection
// using conn as the underlying transport.
// The config cannot be nil: users must set either ServerName or
// InsecureSkipVerify in the config.
func Client(conn net.Conn, config *Config) *Conn {
c := &Conn{
conn: conn,
config: fromConfig(config),
isClient: true,
}
c.handshakeFn = c.clientHandshake
return c
}
// A listener implements a network listener (net.Listener) for TLS connections.
type listener struct {
net.Listener
config *Config
}
// Accept waits for and returns the next incoming TLS connection.
// The returned connection is of type *Conn.
func (l *listener) Accept() (net.Conn, error) {
c, err := l.Listener.Accept()
if err != nil {
return nil, err
}
return Server(c, l.config), nil
}
// NewListener creates a Listener which accepts connections from an inner
// Listener and wraps each connection with Server.
// The configuration config must be non-nil and must include
// at least one certificate or else set GetCertificate.
func NewListener(inner net.Listener, config *Config) net.Listener {
l := new(listener)
l.Listener = inner
l.config = config
return l
}
// Listen creates a TLS listener accepting connections on the
// given network address using net.Listen.
// The configuration config must be non-nil and must include
// at least one certificate or else set GetCertificate.
func Listen(network, laddr string, config *Config) (net.Listener, error) {
if config == nil || len(config.Certificates) == 0 &&
config.GetCertificate == nil && config.GetConfigForClient == nil {
return nil, errors.New("tls: neither Certificates, GetCertificate, nor GetConfigForClient set in Config")
}
l, err := net.Listen(network, laddr)
if err != nil {
return nil, err
}
return NewListener(l, config), nil
}
type timeoutError struct{}
func (timeoutError) Error() string { return "tls: DialWithDialer timed out" }
func (timeoutError) Timeout() bool { return true }
func (timeoutError) Temporary() bool { return true }
// DialWithDialer connects to the given network address using dialer.Dial and
// then initiates a TLS handshake, returning the resulting TLS connection. Any
// timeout or deadline given in the dialer apply to connection and TLS
// handshake as a whole.
//
// DialWithDialer interprets a nil configuration as equivalent to the zero
// configuration; see the documentation of Config for the defaults.
//
// DialWithDialer uses context.Background internally; to specify the context,
// use Dialer.DialContext with NetDialer set to the desired dialer.
func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
return dial(context.Background(), dialer, network, addr, config)
}
func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
if netDialer.Timeout != 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, netDialer.Timeout)
defer cancel()
}
if !netDialer.Deadline.IsZero() {
var cancel context.CancelFunc
ctx, cancel = context.WithDeadline(ctx, netDialer.Deadline)
defer cancel()
}
rawConn, err := netDialer.DialContext(ctx, network, addr)
if err != nil {
return nil, err
}
colonPos := strings.LastIndex(addr, ":")
if colonPos == -1 {
colonPos = len(addr)
}
hostname := addr[:colonPos]
if config == nil {
config = defaultConfig()
}
// If no ServerName is set, infer the ServerName
// from the hostname we're connecting to.
if config.ServerName == "" {
// Make a copy to avoid polluting argument or default.
c := config.Clone()
c.ServerName = hostname
config = c
}
conn := Client(rawConn, config)
if err := conn.HandshakeContext(ctx); err != nil {
rawConn.Close()
return nil, err
}
return conn, nil
}
// Dial connects to the given network address using net.Dial
// and then initiates a TLS handshake, returning the resulting
// TLS connection.
// Dial interprets a nil configuration as equivalent to
// the zero configuration; see the documentation of Config
// for the defaults.
func Dial(network, addr string, config *Config) (*Conn, error) {
return DialWithDialer(new(net.Dialer), network, addr, config)
}
// Dialer dials TLS connections given a configuration and a Dialer for the
// underlying connection.
type Dialer struct {
// NetDialer is the optional dialer to use for the TLS connections'
// underlying TCP connections.
// A nil NetDialer is equivalent to the net.Dialer zero value.
NetDialer *net.Dialer
// Config is the TLS configuration to use for new connections.
// A nil configuration is equivalent to the zero
// configuration; see the documentation of Config for the
// defaults.
Config *Config
}
// Dial connects to the given network address and initiates a TLS
// handshake, returning the resulting TLS connection.
//
// The returned Conn, if any, will always be of type *Conn.
//
// Dial uses context.Background internally; to specify the context,
// use DialContext.
func (d *Dialer) Dial(network, addr string) (net.Conn, error) {
return d.DialContext(context.Background(), network, addr)
}
func (d *Dialer) netDialer() *net.Dialer {
if d.NetDialer != nil {
return d.NetDialer
}
return new(net.Dialer)
}
// DialContext connects to the given network address and initiates a TLS
// handshake, returning the resulting TLS connection.
//
// The provided Context must be non-nil. If the context expires before
// the connection is complete, an error is returned. Once successfully
// connected, any expiration of the context will not affect the
// connection.
//
// The returned Conn, if any, will always be of type *Conn.
func (d *Dialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
c, err := dial(ctx, d.netDialer(), network, addr, d.Config)
if err != nil {
// Don't return c (a typed nil) in an interface.
return nil, err
}
return c, nil
}
// LoadX509KeyPair reads and parses a public/private key pair from a pair
// of files. The files must contain PEM encoded data. The certificate file
// may contain intermediate certificates following the leaf certificate to
// form a certificate chain. On successful return, Certificate.Leaf will
// be nil because the parsed form of the certificate is not retained.
func LoadX509KeyPair(certFile, keyFile string) (Certificate, error) {
certPEMBlock, err := os.ReadFile(certFile)
if err != nil {
return Certificate{}, err
}
keyPEMBlock, err := os.ReadFile(keyFile)
if err != nil {
return Certificate{}, err
}
return X509KeyPair(certPEMBlock, keyPEMBlock)
}
// X509KeyPair parses a public/private key pair from a pair of
// PEM encoded data. On successful return, Certificate.Leaf will be nil because
// the parsed form of the certificate is not retained.
func X509KeyPair(certPEMBlock, keyPEMBlock []byte) (Certificate, error) {
fail := func(err error) (Certificate, error) { return Certificate{}, err }
var cert Certificate
var skippedBlockTypes []string
for {
var certDERBlock *pem.Block
certDERBlock, certPEMBlock = pem.Decode(certPEMBlock)
if certDERBlock == nil {
break
}
if certDERBlock.Type == "CERTIFICATE" {
cert.Certificate = append(cert.Certificate, certDERBlock.Bytes)
} else {
skippedBlockTypes = append(skippedBlockTypes, certDERBlock.Type)
}
}
if len(cert.Certificate) == 0 {
if len(skippedBlockTypes) == 0 {
return fail(errors.New("tls: failed to find any PEM data in certificate input"))
}
if len(skippedBlockTypes) == 1 && strings.HasSuffix(skippedBlockTypes[0], "PRIVATE KEY") {
return fail(errors.New("tls: failed to find certificate PEM data in certificate input, but did find a private key; PEM inputs may have been switched"))
}
return fail(fmt.Errorf("tls: failed to find \"CERTIFICATE\" PEM block in certificate input after skipping PEM blocks of the following types: %v", skippedBlockTypes))
}
skippedBlockTypes = skippedBlockTypes[:0]
var keyDERBlock *pem.Block
for {
keyDERBlock, keyPEMBlock = pem.Decode(keyPEMBlock)
if keyDERBlock == nil {
if len(skippedBlockTypes) == 0 {
return fail(errors.New("tls: failed to find any PEM data in key input"))
}
if len(skippedBlockTypes) == 1 && skippedBlockTypes[0] == "CERTIFICATE" {
return fail(errors.New("tls: found a certificate rather than a key in the PEM for the private key"))
}
return fail(fmt.Errorf("tls: failed to find PEM block with type ending in \"PRIVATE KEY\" in key input after skipping PEM blocks of the following types: %v", skippedBlockTypes))
}
if keyDERBlock.Type == "PRIVATE KEY" || strings.HasSuffix(keyDERBlock.Type, " PRIVATE KEY") {
break
}
skippedBlockTypes = append(skippedBlockTypes, keyDERBlock.Type)
}
// We don't need to parse the public key for TLS, but we so do anyway
// to check that it looks sane and matches the private key.
x509Cert, err := x509.ParseCertificate(cert.Certificate[0])
if err != nil {
return fail(err)
}
cert.PrivateKey, err = parsePrivateKey(keyDERBlock.Bytes)
if err != nil {
return fail(err)
}
switch pub := x509Cert.PublicKey.(type) {
case *rsa.PublicKey:
priv, ok := cert.PrivateKey.(*rsa.PrivateKey)
if !ok {
return fail(errors.New("tls: private key type does not match public key type"))
}
if pub.N.Cmp(priv.N) != 0 {
return fail(errors.New("tls: private key does not match public key"))
}
case *ecdsa.PublicKey:
priv, ok := cert.PrivateKey.(*ecdsa.PrivateKey)
if !ok {
return fail(errors.New("tls: private key type does not match public key type"))
}
if pub.X.Cmp(priv.X) != 0 || pub.Y.Cmp(priv.Y) != 0 {
return fail(errors.New("tls: private key does not match public key"))
}
case ed25519.PublicKey:
priv, ok := cert.PrivateKey.(ed25519.PrivateKey)
if !ok {
return fail(errors.New("tls: private key type does not match public key type"))
}
if !bytes.Equal(priv.Public().(ed25519.PublicKey), pub) {
return fail(errors.New("tls: private key does not match public key"))
}
default:
return fail(errors.New("tls: unknown public key algorithm"))
}
return cert, nil
}
// Attempt to parse the given private key DER block. OpenSSL 0.9.8 generates
// PKCS #1 private keys by default, while OpenSSL 1.0.0 generates PKCS #8 keys.
// OpenSSL ecparam generates SEC1 EC private keys for ECDSA. We try all three.
func parsePrivateKey(der []byte) (crypto.PrivateKey, error) {
if key, err := x509.ParsePKCS1PrivateKey(der); err == nil {
return key, nil
}
if key, err := x509.ParsePKCS8PrivateKey(der); err == nil {
switch key := key.(type) {
case *rsa.PrivateKey, *ecdsa.PrivateKey, ed25519.PrivateKey:
return key, nil
default:
return nil, errors.New("tls: found unknown private key type in PKCS#8 wrapping")
}
}
if key, err := x509.ParseECPrivateKey(der); err == nil {
return key, nil
}
return nil, errors.New("tls: failed to parse private key")
}

101
vendor/github.com/quic-go/qtls-go1-20/unsafe.go generated vendored Normal file
View File

@@ -0,0 +1,101 @@
package qtls
import (
"crypto/tls"
"reflect"
"unsafe"
)
func init() {
if !structsEqual(&tls.ConnectionState{}, &connectionState{}) {
panic("qtls.ConnectionState doesn't match")
}
if !structsEqual(&tls.ClientSessionState{}, &clientSessionState{}) {
panic("qtls.ClientSessionState doesn't match")
}
if !structsEqual(&tls.CertificateRequestInfo{}, &certificateRequestInfo{}) {
panic("qtls.CertificateRequestInfo doesn't match")
}
if !structsEqual(&tls.Config{}, &config{}) {
panic("qtls.Config doesn't match")
}
if !structsEqual(&tls.ClientHelloInfo{}, &clientHelloInfo{}) {
panic("qtls.ClientHelloInfo doesn't match")
}
}
func toConnectionState(c connectionState) ConnectionState {
return *(*ConnectionState)(unsafe.Pointer(&c))
}
func toClientSessionState(s *clientSessionState) *ClientSessionState {
return (*ClientSessionState)(unsafe.Pointer(s))
}
func fromClientSessionState(s *ClientSessionState) *clientSessionState {
return (*clientSessionState)(unsafe.Pointer(s))
}
func toCertificateRequestInfo(i *certificateRequestInfo) *CertificateRequestInfo {
return (*CertificateRequestInfo)(unsafe.Pointer(i))
}
func toConfig(c *config) *Config {
return (*Config)(unsafe.Pointer(c))
}
func fromConfig(c *Config) *config {
return (*config)(unsafe.Pointer(c))
}
func toClientHelloInfo(chi *clientHelloInfo) *ClientHelloInfo {
return (*ClientHelloInfo)(unsafe.Pointer(chi))
}
func structsEqual(a, b interface{}) bool {
return compare(reflect.ValueOf(a), reflect.ValueOf(b))
}
func compare(a, b reflect.Value) bool {
sa := a.Elem()
sb := b.Elem()
if sa.NumField() != sb.NumField() {
return false
}
for i := 0; i < sa.NumField(); i++ {
fa := sa.Type().Field(i)
fb := sb.Type().Field(i)
if !reflect.DeepEqual(fa.Index, fb.Index) || fa.Name != fb.Name || fa.Anonymous != fb.Anonymous || fa.Offset != fb.Offset || !reflect.DeepEqual(fa.Type, fb.Type) {
if fa.Type.Kind() != fb.Type.Kind() {
return false
}
if fa.Type.Kind() == reflect.Slice {
if !compareStruct(fa.Type.Elem(), fb.Type.Elem()) {
return false
}
continue
}
return false
}
}
return true
}
func compareStruct(a, b reflect.Type) bool {
if a.NumField() != b.NumField() {
return false
}
for i := 0; i < a.NumField(); i++ {
fa := a.Field(i)
fb := b.Field(i)
if !reflect.DeepEqual(fa.Index, fb.Index) || fa.Name != fb.Name || fa.Anonymous != fb.Anonymous || fa.Offset != fb.Offset || !reflect.DeepEqual(fa.Type, fb.Type) {
return false
}
}
return true
}
// InitSessionTicketKeys triggers the initialization of session ticket keys.
func InitSessionTicketKeys(conf *Config) {
fromConfig(conf).ticketKeys(nil)
}

17
vendor/github.com/quic-go/quic-go/.gitignore generated vendored Normal file
View File

@@ -0,0 +1,17 @@
debug
debug.test
main
mockgen_tmp.go
*.qtr
*.qlog
*.txt
race.[0-9]*
fuzzing/*/*.zip
fuzzing/*/coverprofile
fuzzing/*/crashers
fuzzing/*/sonarprofile
fuzzing/*/suppressions
fuzzing/*/corpus/
gomock_reflect_*/

44
vendor/github.com/quic-go/quic-go/.golangci.yml generated vendored Normal file
View File

@@ -0,0 +1,44 @@
run:
skip-files:
- internal/handshake/cipher_suite.go
linters-settings:
depguard:
type: blacklist
packages:
- github.com/marten-seemann/qtls
- github.com/quic-go/qtls-go1-19
- github.com/quic-go/qtls-go1-20
packages-with-error-message:
- github.com/marten-seemann/qtls: "importing qtls only allowed in internal/qtls"
- github.com/quic-go/qtls-go1-19: "importing qtls only allowed in internal/qtls"
- github.com/quic-go/qtls-go1-20: "importing qtls only allowed in internal/qtls"
misspell:
ignore-words:
- ect
linters:
disable-all: true
enable:
- asciicheck
- depguard
- exhaustive
- exportloopref
- goimports
- gofmt # redundant, since gofmt *should* be a no-op after gofumpt
- gofumpt
- gosimple
- ineffassign
- misspell
- prealloc
- staticcheck
- stylecheck
- unconvert
- unparam
- unused
- vet
issues:
exclude-rules:
- path: internal/qtls
linters:
- depguard

109
vendor/github.com/quic-go/quic-go/Changelog.md generated vendored Normal file
View File

@@ -0,0 +1,109 @@
# Changelog
## v0.22.0 (2021-07-25)
- Use `ReadBatch` to read multiple UDP packets from the socket with a single syscall
- Add a config option (`Config.DisableVersionNegotiationPackets`) to disable sending of Version Negotiation packets
- Drop support for QUIC draft versions 32 and 34
- Remove the `RetireBugBackwardsCompatibilityMode`, which was intended to mitigate a bug when retiring connection IDs in quic-go in v0.17.2 and ealier
## v0.21.2 (2021-07-15)
- Update qtls (for Go 1.15, 1.16 and 1.17rc1) to include the fix for the crypto/tls panic (see https://groups.google.com/g/golang-dev/c/5LJ2V7rd-Ag/m/YGLHVBZ6AAAJ for details)
## v0.21.0 (2021-06-01)
- quic-go now supports RFC 9000!
## v0.20.0 (2021-03-19)
- Remove the `quic.Config.HandshakeTimeout`. Introduce a `quic.Config.HandshakeIdleTimeout`.
## v0.17.1 (2020-06-20)
- Supports QUIC WG draft-29.
- Improve bundling of ACK frames (#2543).
## v0.16.0 (2020-05-31)
- Supports QUIC WG draft-28.
## v0.15.0 (2020-03-01)
- Supports QUIC WG draft-27.
- Add support for 0-RTT.
- Remove `Session.Close()`. Applications need to pass an application error code to the transport using `Session.CloseWithError()`.
- Make the TLS Cipher Suites configurable (via `tls.Config.CipherSuites`).
## v0.14.0 (2019-12-04)
- Supports QUIC WG draft-24.
## v0.13.0 (2019-11-05)
- Supports QUIC WG draft-23.
- Add an `EarlyListener` that allows sending of 0.5-RTT data.
- Add a `TokenStore` to store address validation tokens.
- Issue and use new connection IDs during a connection.
## v0.12.0 (2019-08-05)
- Implement HTTP/3.
- Rename `quic.Cookie` to `quic.Token` and `quic.Config.AcceptCookie` to `quic.Config.AcceptToken`.
- Distinguish between Retry tokens and tokens sent in NEW_TOKEN frames.
- Enforce application protocol negotiation (via `tls.Config.NextProtos`).
- Use a varint for error codes.
- Add support for [quic-trace](https://github.com/google/quic-trace).
- Add a context to `Listener.Accept`, `Session.Accept{Uni}Stream` and `Session.Open{Uni}StreamSync`.
- Implement TLS key updates.
## v0.11.0 (2019-04-05)
- Drop support for gQUIC. For qQUIC support, please switch to the *gquic* branch.
- Implement QUIC WG draft-19.
- Use [qtls](https://github.com/marten-seemann/qtls) for TLS 1.3.
- Return a `tls.ConnectionState` from `quic.Session.ConnectionState()`.
- Remove the error return values from `quic.Stream.CancelRead()` and `quic.Stream.CancelWrite()`
## v0.10.0 (2018-08-28)
- Add support for QUIC 44, drop support for QUIC 42.
## v0.9.0 (2018-08-15)
- Add a `quic.Config` option for the length of the connection ID (for IETF QUIC).
- Split Session.Close into one method for regular closing and one for closing with an error.
## v0.8.0 (2018-06-26)
- Add support for unidirectional streams (for IETF QUIC).
- Add a `quic.Config` option for the maximum number of incoming streams.
- Add support for QUIC 42 and 43.
- Add dial functions that use a context.
- Multiplex clients on a net.PacketConn, when using Dial(conn).
## v0.7.0 (2018-02-03)
- The lower boundary for packets included in ACKs is now derived, and the value sent in STOP_WAITING frames is ignored.
- Remove `DialNonFWSecure` and `DialAddrNonFWSecure`.
- Expose the `ConnectionState` in the `Session` (experimental API).
- Implement packet pacing.
## v0.6.0 (2017-12-12)
- Add support for QUIC 39, drop support for QUIC 35 - 37
- Added `quic.Config` options for maximal flow control windows
- Add a `quic.Config` option for QUIC versions
- Add a `quic.Config` option to request omission of the connection ID from a server
- Add a `quic.Config` option to configure the source address validation
- Add a `quic.Config` option to configure the handshake timeout
- Add a `quic.Config` option to configure the idle timeout
- Add a `quic.Config` option to configure keep-alive
- Rename the STK to Cookie
- Implement `net.Conn`-style deadlines for streams
- Remove the `tls.Config` from the `quic.Config`. The `tls.Config` must now be passed to the `Dial` and `Listen` functions as a separate parameter. See the [Godoc](https://godoc.org/github.com/quic-go/quic-go) for details.
- Changed the log level environment variable to only accept strings ("DEBUG", "INFO", "ERROR"), see [the wiki](https://github.com/quic-go/quic-go/wiki/Logging) for more details.
- Rename the `h2quic.QuicRoundTripper` to `h2quic.RoundTripper`
- Changed `h2quic.Server.Serve()` to accept a `net.PacketConn`
- Drop support for Go 1.7 and 1.8.
- Various bugfixes

21
vendor/github.com/quic-go/quic-go/LICENSE generated vendored Normal file
View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2016 the quic-go authors & Google, Inc.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

230
vendor/github.com/quic-go/quic-go/README.md generated vendored Normal file
View File

@@ -0,0 +1,230 @@
# A QUIC implementation in pure Go
<img src="docs/quic.png" width=303 height=124>
[![PkgGoDev](https://pkg.go.dev/badge/github.com/quic-go/quic-go)](https://pkg.go.dev/github.com/quic-go/quic-go)
[![Code Coverage](https://img.shields.io/codecov/c/github/quic-go/quic-go/master.svg?style=flat-square)](https://codecov.io/gh/quic-go/quic-go/)
[![Fuzzing Status](https://oss-fuzz-build-logs.storage.googleapis.com/badges/quic-go.svg)](https://bugs.chromium.org/p/oss-fuzz/issues/list?sort=-opened&can=1&q=proj:quic-go)
quic-go is an implementation of the QUIC protocol ([RFC 9000](https://datatracker.ietf.org/doc/html/rfc9000), [RFC 9001](https://datatracker.ietf.org/doc/html/rfc9001), [RFC 9002](https://datatracker.ietf.org/doc/html/rfc9002)) in Go. It has support for HTTP/3 ([RFC 9114](https://datatracker.ietf.org/doc/html/rfc9114)), including QPACK ([RFC 9204](https://datatracker.ietf.org/doc/html/rfc9204)).
In addition to these base RFCs, it also implements the following RFCs:
* Unreliable Datagram Extension ([RFC 9221](https://datatracker.ietf.org/doc/html/rfc9221))
* Datagram Packetization Layer Path MTU Discovery (DPLPMTUD, [RFC 8899](https://datatracker.ietf.org/doc/html/rfc8899))
* QUIC Version 2 ([RFC 9369](https://datatracker.ietf.org/doc/html/rfc9369))
## Using QUIC
### Running a Server
The central entry point is the `quic.Transport`. A transport manages QUIC connections running on a single UDP socket. Since QUIC uses Connection IDs, it can demultiplex a listener (accepting incoming connections) and an arbitrary number of outgoing QUIC connections on the same UDP socket.
```go
udpConn, err := net.ListenUDP("udp4", &net.UDPAddr{Port: 1234})
// ... error handling
tr := quic.Transport{
Conn: udpConn,
}
ln, err := tr.Listen(tlsConf, quicConf)
// ... error handling
go func() {
for {
conn, err := ln.Accept()
// ... error handling
// handle the connection, usually in a new Go routine
}
}()
```
The listener `ln` can now be used to accept incoming QUIC connections by (repeatedly) calling the `Accept` method (see below for more information on the `quic.Connection`).
As a shortcut, `quic.Listen` and `quic.ListenAddr` can be used without explicitly initializing a `quic.Transport`:
```
ln, err := quic.Listen(udpConn, tlsConf, quicConf)
```
When using the shortcut, it's not possible to reuse the same UDP socket for outgoing connections.
### Running a Client
As mentioned above, multiple outgoing connections can share a single UDP socket, since QUIC uses Connection IDs to demultiplex connections.
```go
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) // 3s handshake timeout
defer cancel()
conn, err := tr.Dial(ctx, <server address>, <tls.Config>, <quic.Config>)
// ... error handling
```
As a shortcut, `quic.Dial` and `quic.DialAddr` can be used without explictly initializing a `quic.Transport`:
```go
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) // 3s handshake timeout
defer cancel()
conn, err := quic.Dial(ctx, conn, <server address>, <tls.Config>, <quic.Config>)
```
Just as we saw before when used a similar shortcut to run a server, it's also not possible to reuse the same UDP socket for other outgoing connections, or to listen for incoming connections.
### Using a QUIC Connection
#### Accepting Streams
QUIC is a stream-multiplexed transport. A `quic.Connection` fundamentally differs from the `net.Conn` and the `net.PacketConn` interface defined in the standard library. Data is sent and received on (unidirectional and bidirectional) streams (and, if supported, in [datagrams](#quic-datagrams)), not on the connection itself. The stream state machine is described in detail in [Section 3 of RFC 9000](https://datatracker.ietf.org/doc/html/rfc9000#section-3).
Note: A unidirectional stream is a stream that the initiator can only write to (`quic.SendStream`), and the receiver can only read from (`quic.ReceiveStream`). A bidirectional stream (`quic.Stream`) allows reading from and writing to for both sides.
On the receiver side, streams are accepted using the `AcceptStream` (for bidirectional) and `AcceptUniStream` functions. For most user cases, it makes sense to call these functions in a loop:
```go
for {
str, err := conn.AcceptStream(context.Background()) // for bidirectional streams
// ... error handling
// handle the stream, usually in a new Go routine
}
```
These functions return an error when the underlying QUIC connection is closed.
#### Opening Streams
There are two slightly different ways to open streams, one synchronous and one (potentially) asynchronous. This API is necessary since the receiver grants us a certain number of streams that we're allowed to open. It may grant us additional streams later on (typically when existing streams are closed), but it means that at the time we want to open a new stream, we might not be able to do so.
Using the synchronous method `OpenStreamSync` for bidirectional streams, and `OpenUniStreamSync` for unidirectional streams, an application can block until the peer allows opening additional streams. In case that we're allowed to open a new stream, these methods return right away:
```go
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
str, err := conn.OpenStreamSync(ctx) // wait up to 5s to open a new bidirectional stream
```
The asynchronous version never blocks. If it's currently not possible to open a new stream, it returns a `net.Error` timeout error:
```go
str, err := conn.OpenStream()
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
// It's currently not possible to open another stream,
// but it might be possible later, once the peer allowed us to do so.
}
```
These functions return an error when the underlying QUIC connection is closed.
#### Using Streams
Using QUIC streams is pretty straightforward. The `quic.ReceiveStream` implements the `io.Reader` interface, and the `quic.SendStream` implements the `io.Writer` interface. A bidirectional stream (`quic.Stream`) implements both these interfaces. Conceptually, a bidirectional stream can be thought of as the composition of two unidirectional streams in opposite directions.
Calling `Close` on a `quic.SendStream` or a `quic.Stream` closes the send side of the stream. On the receiver side, this will be surfaced as an `io.EOF` returned from the `io.Reader` once all data has been consumed. Note that for bidirectional streams, `Close` _only_ closes the send side of the stream. It is still possible to read from the stream until the peer closes or resets the stream.
In case the application wishes to abort sending on a `quic.SendStream` or a `quic.Stream` , it can reset the send side by calling `CancelWrite` with an application-defined error code (an unsigned 62-bit number). On the receiver side, this surfaced as a `quic.StreamError` containing that error code on the `io.Reader`. Note that for bidirectional streams, `CancelWrite` _only_ resets the send side of the stream. It is still possible to read from the stream until the peer closes or resets the stream.
Conversely, in case the application wishes to abort receiving from a `quic.ReceiveStream` or a `quic.Stream`, it can ask the sender to abort data transmission by calling `CancelRead` with an application-defined error code (an unsigned 62-bit number). On the receiver side, this surfaced as a `quic.StreamError` containing that error code on the `io.Writer`. Note that for bidirectional streams, `CancelWrite` _only_ resets the receive side of the stream. It is still possible to write to the stream.
A bidirectional stream is only closed once both the read and the write side of the stream have been either closed and reset. Only then the peer is granted a new stream according to the maximum number of concurrent streams configured via `quic.Config.MaxIncomingStreams`.
### Configuring QUIC
The `quic.Config` struct passed to both the listen and dial calls (see above) contains a wide range of configuration options for QUIC connections, incl. the ability to fine-tune flow control limits, the number of streams that the peer is allowed to open concurrently, keep-alives, idle timeouts, and many more. Please refer to the documentation for the `quic.Config` for details.
The `quic.Transport` contains a few configuration options that don't apply to any single QUIC connection, but to all connections handled by that transport. It is highly recommend to set the `StatelessResetToken`, which allows endpoints to quickly recover from crashes / reboots of our node (see [Section 10.3 of RFC 9000](https://datatracker.ietf.org/doc/html/rfc9000#section-10.3)).
### Closing a Connection
#### When the remote Peer closes the Connection
In case the peer closes the QUIC connection, all calls to open streams, accept streams, as well as all methods on streams immediately return an error. Additionally, it is set as cancellation cause of the connection context. Users can use errors assertions to find out what exactly went wrong:
* `quic.VersionNegotiationError`: Happens during the handshake, if there is no overlap between our and the remote's supported QUIC versions.
* `quic.HandshakeTimeoutError`: Happens if the QUIC handshake doesn't complete within the time specified in `quic.Config.HandshakeTimeout`.
* `quic.IdleTimeoutError`: Happens after completion of the handshake if the connection is idle for longer than the minimum of both peers idle timeouts (as configured by `quic.Config.IdleTimeout`). The connection is considered idle when no stream data (and datagrams, if applicable) are exchanged for that period. The QUIC connection can be instructed to regularly send a packet to prevent a connection from going idle by setting `quic.Config.KeepAlive`. However, this is no guarantee that the peer doesn't suddenly go away (e.g. by abruptly shutting down the node or by crashing), or by a NAT binding expiring, in which case this error might still occur.
* `quic.StatelessResetError`: Happens when the remote peer lost the state required to decrypt the packet. This requires the `quic.Transport.StatelessResetToken` to be configured by the peer.
* `quic.TransportError`: Happens if when the QUIC protocol is violated. Unless the error code is `APPLICATION_ERROR`, this will not happen unless one of the QUIC stacks involved is misbehaving. Please open an issue if you encounter this error.
* `quic.ApplicationError`: Happens when the remote decides to close the connection, see below.
#### Initiated by the Application
A `quic.Connection` can be closed using `CloseWithError`:
```go
conn.CloseWithError(0x42, "error 0x42 occurred")
```
Applications can transmit both an error code (an unsigned 62-bit number) as well as a UTF-8 encoded human-readable reason. The error code allows the receiver to learn why the connection was closed, and the reason can be useful for debugging purposes.
On the receiver side, this is surfaced as a `quic.ApplicationError`.
### QUIC Datagrams
Unreliable datagrams are a QUIC extension ([RFC 9221](https://datatracker.ietf.org/doc/html/rfc9221)) that is negotiated during the handshake. Support can be enabled by setting the `quic.Config.EnableDatagram` flag. Note that this doesn't guarantee that the peer also supports datagrams. Whether or not the feature negotiation succeeded can be learned from the `quic.ConnectionState.SupportsDatagrams` obtained from `quic.Connection.ConnectionState()`.
QUIC DATAGRAMs are a new QUIC frame type sent in QUIC 1-RTT packets (i.e. after completion of the handshake). Therefore, they're end-to-end encrypted and congestion-controlled. However, if a DATAGRAM frame is deemed lost by QUIC's loss detection mechanism, they are not automatically retransmitted.
Datagrams are sent using the `SendMessage` method on the `quic.Connection`:
```go
conn.SendMessage([]byte("foobar"))
```
And received using `ReceiveMessage`:
```go
msg, err := conn.ReceiveMessage()
```
Note that this code path is currently not optimized. It works for datagrams that are sent occasionally, but it doesn't achieve the same throughput as writing data on a stream. Please get in touch on issue #3766 if your use case relies on high datagram throughput, or if you'd like to help fix this issue. There are also some restrictions regarding the maximum message size (see #3599).
## Using HTTP/3
### As a server
See the [example server](example/main.go). Starting a QUIC server is very similar to the standard library http package in Go:
```go
http.Handle("/", http.FileServer(http.Dir(wwwDir)))
http3.ListenAndServeQUIC("localhost:4242", "/path/to/cert/chain.pem", "/path/to/privkey.pem", nil)
```
### As a client
See the [example client](example/client/main.go). Use a `http3.RoundTripper` as a `Transport` in a `http.Client`.
```go
http.Client{
Transport: &http3.RoundTripper{},
}
```
## Projects using quic-go
| Project | Description | Stars |
| --------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------- |
| [AdGuardHome](https://github.com/AdguardTeam/AdGuardHome) | Free and open source, powerful network-wide ads & trackers blocking DNS server. | ![GitHub Repo stars](https://img.shields.io/github/stars/AdguardTeam/AdGuardHome?style=flat-square) |
| [algernon](https://github.com/xyproto/algernon) | Small self-contained pure-Go web server with Lua, Markdown, HTTP/2, QUIC, Redis and PostgreSQL support | ![GitHub Repo stars](https://img.shields.io/github/stars/xyproto/algernon?style=flat-square) |
| [caddy](https://github.com/caddyserver/caddy/) | Fast, multi-platform web server with automatic HTTPS | ![GitHub Repo stars](https://img.shields.io/github/stars/caddyserver/caddy?style=flat-square) |
| [cloudflared](https://github.com/cloudflare/cloudflared) | A tunneling daemon that proxies traffic from the Cloudflare network to your origins | ![GitHub Repo stars](https://img.shields.io/github/stars/cloudflare/cloudflared?style=flat-square) |
| [go-libp2p](https://github.com/libp2p/go-libp2p) | libp2p implementation in Go, powering [Kubo](https://github.com/ipfs/kubo) (IPFS) and [Lotus](https://github.com/filecoin-project/lotus) (Filecoin), among others | ![GitHub Repo stars](https://img.shields.io/github/stars/libp2p/go-libp2p?style=flat-square) |
| [Hysteria](https://github.com/apernet/hysteria) | A powerful, lightning fast and censorship resistant proxy | ![GitHub Repo stars](https://img.shields.io/github/stars/apernet/hysteria?style=flat-square) |
| [Mercure](https://github.com/dunglas/mercure) | An open, easy, fast, reliable and battery-efficient solution for real-time communications | ![GitHub Repo stars](https://img.shields.io/github/stars/dunglas/mercure?style=flat-square) |
| [OONI Probe](https://github.com/ooni/probe-cli) | Next generation OONI Probe. Library and CLI tool. | ![GitHub Repo stars](https://img.shields.io/github/stars/ooni/probe-cli?style=flat-square) |
| [syncthing](https://github.com/syncthing/syncthing/) | Open Source Continuous File Synchronization | ![GitHub Repo stars](https://img.shields.io/github/stars/syncthing/syncthing?style=flat-square) |
| [traefik](https://github.com/traefik/traefik) | The Cloud Native Application Proxy | ![GitHub Repo stars](https://img.shields.io/github/stars/traefik/traefik?style=flat-square) |
| [v2ray-core](https://github.com/v2fly/v2ray-core) | A platform for building proxies to bypass network restrictions | ![GitHub Repo stars](https://img.shields.io/github/stars/v2fly/v2ray-core?style=flat-square) |
| [YoMo](https://github.com/yomorun/yomo) | Streaming Serverless Framework for Geo-distributed System | ![GitHub Repo stars](https://img.shields.io/github/stars/yomorun/yomo?style=flat-square) |
If you'd like to see your project added to this list, please send us a PR.
## Release Policy
quic-go always aims to support the latest two Go releases.
### Dependency on forked crypto/tls
Since the standard library didn't provide any QUIC APIs before the Go 1.21 release, we had to fork crypto/tls to add the required APIs ourselves: [qtls for Go 1.20](https://github.com/quic-go/qtls-go1-20).
This had led to a lot of pain in the Go ecosystem, and we're happy that we can rely on Go 1.21 going forward.
## Contributing
We are always happy to welcome new contributors! We have a number of self-contained issues that are suitable for first-time contributors, they are tagged with [help wanted](https://github.com/quic-go/quic-go/issues?q=is%3Aissue+is%3Aopen+label%3A%22help+wanted%22). If you have any questions, please feel free to reach out by opening an issue or leaving a comment.

19
vendor/github.com/quic-go/quic-go/SECURITY.md generated vendored Normal file
View File

@@ -0,0 +1,19 @@
# Security Policy
quic-go still in development. This means that there may be problems in our protocols,
or there may be mistakes in our implementations.
We take security vulnerabilities very seriously. If you discover a security issue,
please bring it to our attention right away!
## Reporting a Vulnerability
If you find a vulnerability that may affect live deployments -- for example, by exposing
a remote execution exploit -- please [**report privately**](https://github.com/quic-go/quic-go/security/advisories/new).
Please **DO NOT file a public issue**.
If the issue is an implementation weakness that cannot be immediately exploited or
something not yet deployed, just discuss it openly.
## Reporting a non security bug
For non-security bugs, please simply file a GitHub [issue](https://github.com/quic-go/quic-go/issues/new).

92
vendor/github.com/quic-go/quic-go/buffer_pool.go generated vendored Normal file
View File

@@ -0,0 +1,92 @@
package quic
import (
"sync"
"github.com/quic-go/quic-go/internal/protocol"
)
type packetBuffer struct {
Data []byte
// refCount counts how many packets Data is used in.
// It doesn't support concurrent use.
// It is > 1 when used for coalesced packet.
refCount int
}
// Split increases the refCount.
// It must be called when a packet buffer is used for more than one packet,
// e.g. when splitting coalesced packets.
func (b *packetBuffer) Split() {
b.refCount++
}
// Decrement decrements the reference counter.
// It doesn't put the buffer back into the pool.
func (b *packetBuffer) Decrement() {
b.refCount--
if b.refCount < 0 {
panic("negative packetBuffer refCount")
}
}
// MaybeRelease puts the packet buffer back into the pool,
// if the reference counter already reached 0.
func (b *packetBuffer) MaybeRelease() {
// only put the packetBuffer back if it's not used any more
if b.refCount == 0 {
b.putBack()
}
}
// Release puts back the packet buffer into the pool.
// It should be called when processing is definitely finished.
func (b *packetBuffer) Release() {
b.Decrement()
if b.refCount != 0 {
panic("packetBuffer refCount not zero")
}
b.putBack()
}
// Len returns the length of Data
func (b *packetBuffer) Len() protocol.ByteCount { return protocol.ByteCount(len(b.Data)) }
func (b *packetBuffer) Cap() protocol.ByteCount { return protocol.ByteCount(cap(b.Data)) }
func (b *packetBuffer) putBack() {
if cap(b.Data) == protocol.MaxPacketBufferSize {
bufferPool.Put(b)
return
}
if cap(b.Data) == protocol.MaxLargePacketBufferSize {
largeBufferPool.Put(b)
return
}
panic("putPacketBuffer called with packet of wrong size!")
}
var bufferPool, largeBufferPool sync.Pool
func getPacketBuffer() *packetBuffer {
buf := bufferPool.Get().(*packetBuffer)
buf.refCount = 1
buf.Data = buf.Data[:0]
return buf
}
func getLargePacketBuffer() *packetBuffer {
buf := largeBufferPool.Get().(*packetBuffer)
buf.refCount = 1
buf.Data = buf.Data[:0]
return buf
}
func init() {
bufferPool.New = func() any {
return &packetBuffer{Data: make([]byte, 0, protocol.MaxPacketBufferSize)}
}
largeBufferPool.New = func() any {
return &packetBuffer{Data: make([]byte, 0, protocol.MaxLargePacketBufferSize)}
}
}

251
vendor/github.com/quic-go/quic-go/client.go generated vendored Normal file
View File

@@ -0,0 +1,251 @@
package quic
import (
"context"
"crypto/tls"
"errors"
"net"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/logging"
)
type client struct {
sendConn sendConn
use0RTT bool
packetHandlers packetHandlerManager
onClose func()
tlsConf *tls.Config
config *Config
connIDGenerator ConnectionIDGenerator
srcConnID protocol.ConnectionID
destConnID protocol.ConnectionID
initialPacketNumber protocol.PacketNumber
hasNegotiatedVersion bool
version protocol.VersionNumber
handshakeChan chan struct{}
conn quicConn
tracer *logging.ConnectionTracer
tracingID uint64
logger utils.Logger
}
// make it possible to mock connection ID for initial generation in the tests
var generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial
// DialAddr establishes a new QUIC connection to a server.
// It resolves the address, and then creates a new UDP connection to dial the QUIC server.
// When the QUIC connection is closed, this UDP connection is closed.
// See Dial for more details.
func DialAddr(ctx context.Context, addr string, tlsConf *tls.Config, conf *Config) (Connection, error) {
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
return nil, err
}
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
tr, err := setupTransport(udpConn, tlsConf, true)
if err != nil {
return nil, err
}
return tr.dial(ctx, udpAddr, addr, tlsConf, conf, false)
}
// DialAddrEarly establishes a new 0-RTT QUIC connection to a server.
// See DialAddr for more details.
func DialAddrEarly(ctx context.Context, addr string, tlsConf *tls.Config, conf *Config) (EarlyConnection, error) {
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
return nil, err
}
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
tr, err := setupTransport(udpConn, tlsConf, true)
if err != nil {
return nil, err
}
conn, err := tr.dial(ctx, udpAddr, addr, tlsConf, conf, true)
if err != nil {
tr.Close()
return nil, err
}
return conn, nil
}
// DialEarly establishes a new 0-RTT QUIC connection to a server using a net.PacketConn.
// See Dial for more details.
func DialEarly(ctx context.Context, c net.PacketConn, addr net.Addr, tlsConf *tls.Config, conf *Config) (EarlyConnection, error) {
dl, err := setupTransport(c, tlsConf, false)
if err != nil {
return nil, err
}
conn, err := dl.DialEarly(ctx, addr, tlsConf, conf)
if err != nil {
dl.Close()
return nil, err
}
return conn, nil
}
// Dial establishes a new QUIC connection to a server using a net.PacketConn.
// If the PacketConn satisfies the OOBCapablePacketConn interface (as a net.UDPConn does),
// ECN and packet info support will be enabled. In this case, ReadMsgUDP and WriteMsgUDP
// will be used instead of ReadFrom and WriteTo to read/write packets.
// The tls.Config must define an application protocol (using NextProtos).
//
// This is a convenience function. More advanced use cases should instantiate a Transport,
// which offers configuration options for a more fine-grained control of the connection establishment,
// including reusing the underlying UDP socket for multiple QUIC connections.
func Dial(ctx context.Context, c net.PacketConn, addr net.Addr, tlsConf *tls.Config, conf *Config) (Connection, error) {
dl, err := setupTransport(c, tlsConf, false)
if err != nil {
return nil, err
}
conn, err := dl.Dial(ctx, addr, tlsConf, conf)
if err != nil {
dl.Close()
return nil, err
}
return conn, nil
}
func setupTransport(c net.PacketConn, tlsConf *tls.Config, createdPacketConn bool) (*Transport, error) {
if tlsConf == nil {
return nil, errors.New("quic: tls.Config not set")
}
return &Transport{
Conn: c,
createdConn: createdPacketConn,
isSingleUse: true,
}, nil
}
func dial(
ctx context.Context,
conn sendConn,
connIDGenerator ConnectionIDGenerator,
packetHandlers packetHandlerManager,
tlsConf *tls.Config,
config *Config,
onClose func(),
use0RTT bool,
) (quicConn, error) {
c, err := newClient(conn, connIDGenerator, config, tlsConf, onClose, use0RTT)
if err != nil {
return nil, err
}
c.packetHandlers = packetHandlers
c.tracingID = nextConnTracingID()
if c.config.Tracer != nil {
c.tracer = c.config.Tracer(context.WithValue(ctx, ConnectionTracingKey, c.tracingID), protocol.PerspectiveClient, c.destConnID)
}
if c.tracer != nil && c.tracer.StartedConnection != nil {
c.tracer.StartedConnection(c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID)
}
if err := c.dial(ctx); err != nil {
return nil, err
}
return c.conn, nil
}
func newClient(sendConn sendConn, connIDGenerator ConnectionIDGenerator, config *Config, tlsConf *tls.Config, onClose func(), use0RTT bool) (*client, error) {
srcConnID, err := connIDGenerator.GenerateConnectionID()
if err != nil {
return nil, err
}
destConnID, err := generateConnectionIDForInitial()
if err != nil {
return nil, err
}
c := &client{
connIDGenerator: connIDGenerator,
srcConnID: srcConnID,
destConnID: destConnID,
sendConn: sendConn,
use0RTT: use0RTT,
onClose: onClose,
tlsConf: tlsConf,
config: config,
version: config.Versions[0],
handshakeChan: make(chan struct{}),
logger: utils.DefaultLogger.WithPrefix("client"),
}
return c, nil
}
func (c *client) dial(ctx context.Context) error {
c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
c.conn = newClientConnection(
c.sendConn,
c.packetHandlers,
c.destConnID,
c.srcConnID,
c.connIDGenerator,
c.config,
c.tlsConf,
c.initialPacketNumber,
c.use0RTT,
c.hasNegotiatedVersion,
c.tracer,
c.tracingID,
c.logger,
c.version,
)
c.packetHandlers.Add(c.srcConnID, c.conn)
errorChan := make(chan error, 1)
recreateChan := make(chan errCloseForRecreating)
go func() {
err := c.conn.run()
var recreateErr *errCloseForRecreating
if errors.As(err, &recreateErr) {
recreateChan <- *recreateErr
return
}
if c.onClose != nil {
c.onClose()
}
errorChan <- err // returns as soon as the connection is closed
}()
// only set when we're using 0-RTT
// Otherwise, earlyConnChan will be nil. Receiving from a nil chan blocks forever.
var earlyConnChan <-chan struct{}
if c.use0RTT {
earlyConnChan = c.conn.earlyConnReady()
}
select {
case <-ctx.Done():
c.conn.shutdown()
return context.Cause(ctx)
case err := <-errorChan:
return err
case recreateErr := <-recreateChan:
c.initialPacketNumber = recreateErr.nextPacketNumber
c.version = recreateErr.nextVersion
c.hasNegotiatedVersion = true
return c.dial(ctx)
case <-earlyConnChan:
// ready to send 0-RTT data
return nil
case <-c.conn.HandshakeComplete():
// handshake successfully completed
return nil
}
}

64
vendor/github.com/quic-go/quic-go/closed_conn.go generated vendored Normal file
View File

@@ -0,0 +1,64 @@
package quic
import (
"math/bits"
"net"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
)
// A closedLocalConn is a connection that we closed locally.
// When receiving packets for such a connection, we need to retransmit the packet containing the CONNECTION_CLOSE frame,
// with an exponential backoff.
type closedLocalConn struct {
counter uint32
perspective protocol.Perspective
logger utils.Logger
sendPacket func(net.Addr, packetInfo)
}
var _ packetHandler = &closedLocalConn{}
// newClosedLocalConn creates a new closedLocalConn and runs it.
func newClosedLocalConn(sendPacket func(net.Addr, packetInfo), pers protocol.Perspective, logger utils.Logger) packetHandler {
return &closedLocalConn{
sendPacket: sendPacket,
perspective: pers,
logger: logger,
}
}
func (c *closedLocalConn) handlePacket(p receivedPacket) {
c.counter++
// exponential backoff
// only send a CONNECTION_CLOSE for the 1st, 2nd, 4th, 8th, 16th, ... packet arriving
if bits.OnesCount32(c.counter) != 1 {
return
}
c.logger.Debugf("Received %d packets after sending CONNECTION_CLOSE. Retransmitting.", c.counter)
c.sendPacket(p.remoteAddr, p.info)
}
func (c *closedLocalConn) shutdown() {}
func (c *closedLocalConn) destroy(error) {}
func (c *closedLocalConn) getPerspective() protocol.Perspective { return c.perspective }
// A closedRemoteConn is a connection that was closed remotely.
// For such a connection, we might receive reordered packets that were sent before the CONNECTION_CLOSE.
// We can just ignore those packets.
type closedRemoteConn struct {
perspective protocol.Perspective
}
var _ packetHandler = &closedRemoteConn{}
func newClosedRemoteConn(pers protocol.Perspective) packetHandler {
return &closedRemoteConn{perspective: pers}
}
func (s *closedRemoteConn) handlePacket(receivedPacket) {}
func (s *closedRemoteConn) shutdown() {}
func (s *closedRemoteConn) destroy(error) {}
func (s *closedRemoteConn) getPerspective() protocol.Perspective { return s.perspective }

14
vendor/github.com/quic-go/quic-go/codecov.yml generated vendored Normal file
View File

@@ -0,0 +1,14 @@
coverage:
round: nearest
ignore:
- http3/gzip_reader.go
- interop/
- internal/handshake/cipher_suite.go
- internal/utils/linkedlist/linkedlist.go
- fuzzing/
- metrics/
status:
project:
default:
threshold: 0.5
patch: false

129
vendor/github.com/quic-go/quic-go/config.go generated vendored Normal file
View File

@@ -0,0 +1,129 @@
package quic
import (
"fmt"
"net"
"time"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/quicvarint"
)
// Clone clones a Config
func (c *Config) Clone() *Config {
copy := *c
return &copy
}
func (c *Config) handshakeTimeout() time.Duration {
return 2 * c.HandshakeIdleTimeout
}
func (c *Config) maxRetryTokenAge() time.Duration {
return c.handshakeTimeout()
}
func validateConfig(config *Config) error {
if config == nil {
return nil
}
const maxStreams = 1 << 60
if config.MaxIncomingStreams > maxStreams {
config.MaxIncomingStreams = maxStreams
}
if config.MaxIncomingUniStreams > maxStreams {
config.MaxIncomingUniStreams = maxStreams
}
if config.MaxStreamReceiveWindow > quicvarint.Max {
config.MaxStreamReceiveWindow = quicvarint.Max
}
if config.MaxConnectionReceiveWindow > quicvarint.Max {
config.MaxConnectionReceiveWindow = quicvarint.Max
}
// check that all QUIC versions are actually supported
for _, v := range config.Versions {
if !protocol.IsValidVersion(v) {
return fmt.Errorf("invalid QUIC version: %s", v)
}
}
return nil
}
// populateServerConfig populates fields in the quic.Config with their default values, if none are set
// it may be called with nil
func populateServerConfig(config *Config) *Config {
config = populateConfig(config)
if config.RequireAddressValidation == nil {
config.RequireAddressValidation = func(net.Addr) bool { return false }
}
return config
}
// populateConfig populates fields in the quic.Config with their default values, if none are set
// it may be called with nil
func populateConfig(config *Config) *Config {
if config == nil {
config = &Config{}
}
versions := config.Versions
if len(versions) == 0 {
versions = protocol.SupportedVersions
}
handshakeIdleTimeout := protocol.DefaultHandshakeIdleTimeout
if config.HandshakeIdleTimeout != 0 {
handshakeIdleTimeout = config.HandshakeIdleTimeout
}
idleTimeout := protocol.DefaultIdleTimeout
if config.MaxIdleTimeout != 0 {
idleTimeout = config.MaxIdleTimeout
}
initialStreamReceiveWindow := config.InitialStreamReceiveWindow
if initialStreamReceiveWindow == 0 {
initialStreamReceiveWindow = protocol.DefaultInitialMaxStreamData
}
maxStreamReceiveWindow := config.MaxStreamReceiveWindow
if maxStreamReceiveWindow == 0 {
maxStreamReceiveWindow = protocol.DefaultMaxReceiveStreamFlowControlWindow
}
initialConnectionReceiveWindow := config.InitialConnectionReceiveWindow
if initialConnectionReceiveWindow == 0 {
initialConnectionReceiveWindow = protocol.DefaultInitialMaxData
}
maxConnectionReceiveWindow := config.MaxConnectionReceiveWindow
if maxConnectionReceiveWindow == 0 {
maxConnectionReceiveWindow = protocol.DefaultMaxReceiveConnectionFlowControlWindow
}
maxIncomingStreams := config.MaxIncomingStreams
if maxIncomingStreams == 0 {
maxIncomingStreams = protocol.DefaultMaxIncomingStreams
} else if maxIncomingStreams < 0 {
maxIncomingStreams = 0
}
maxIncomingUniStreams := config.MaxIncomingUniStreams
if maxIncomingUniStreams == 0 {
maxIncomingUniStreams = protocol.DefaultMaxIncomingUniStreams
} else if maxIncomingUniStreams < 0 {
maxIncomingUniStreams = 0
}
return &Config{
GetConfigForClient: config.GetConfigForClient,
Versions: versions,
HandshakeIdleTimeout: handshakeIdleTimeout,
MaxIdleTimeout: idleTimeout,
RequireAddressValidation: config.RequireAddressValidation,
KeepAlivePeriod: config.KeepAlivePeriod,
InitialStreamReceiveWindow: initialStreamReceiveWindow,
MaxStreamReceiveWindow: maxStreamReceiveWindow,
InitialConnectionReceiveWindow: initialConnectionReceiveWindow,
MaxConnectionReceiveWindow: maxConnectionReceiveWindow,
AllowConnectionWindowIncrease: config.AllowConnectionWindowIncrease,
MaxIncomingStreams: maxIncomingStreams,
MaxIncomingUniStreams: maxIncomingUniStreams,
TokenStore: config.TokenStore,
EnableDatagrams: config.EnableDatagrams,
DisablePathMTUDiscovery: config.DisablePathMTUDiscovery,
Allow0RTT: config.Allow0RTT,
Tracer: config.Tracer,
}
}

139
vendor/github.com/quic-go/quic-go/conn_id_generator.go generated vendored Normal file
View File

@@ -0,0 +1,139 @@
package quic
import (
"fmt"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/internal/wire"
)
type connIDGenerator struct {
generator ConnectionIDGenerator
highestSeq uint64
activeSrcConnIDs map[uint64]protocol.ConnectionID
initialClientDestConnID *protocol.ConnectionID // nil for the client
addConnectionID func(protocol.ConnectionID)
getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken
removeConnectionID func(protocol.ConnectionID)
retireConnectionID func(protocol.ConnectionID)
replaceWithClosed func([]protocol.ConnectionID, protocol.Perspective, []byte)
queueControlFrame func(wire.Frame)
}
func newConnIDGenerator(
initialConnectionID protocol.ConnectionID,
initialClientDestConnID *protocol.ConnectionID, // nil for the client
addConnectionID func(protocol.ConnectionID),
getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken,
removeConnectionID func(protocol.ConnectionID),
retireConnectionID func(protocol.ConnectionID),
replaceWithClosed func([]protocol.ConnectionID, protocol.Perspective, []byte),
queueControlFrame func(wire.Frame),
generator ConnectionIDGenerator,
) *connIDGenerator {
m := &connIDGenerator{
generator: generator,
activeSrcConnIDs: make(map[uint64]protocol.ConnectionID),
addConnectionID: addConnectionID,
getStatelessResetToken: getStatelessResetToken,
removeConnectionID: removeConnectionID,
retireConnectionID: retireConnectionID,
replaceWithClosed: replaceWithClosed,
queueControlFrame: queueControlFrame,
}
m.activeSrcConnIDs[0] = initialConnectionID
m.initialClientDestConnID = initialClientDestConnID
return m
}
func (m *connIDGenerator) SetMaxActiveConnIDs(limit uint64) error {
if m.generator.ConnectionIDLen() == 0 {
return nil
}
// The active_connection_id_limit transport parameter is the number of
// connection IDs the peer will store. This limit includes the connection ID
// used during the handshake, and the one sent in the preferred_address
// transport parameter.
// We currently don't send the preferred_address transport parameter,
// so we can issue (limit - 1) connection IDs.
for i := uint64(len(m.activeSrcConnIDs)); i < utils.Min(limit, protocol.MaxIssuedConnectionIDs); i++ {
if err := m.issueNewConnID(); err != nil {
return err
}
}
return nil
}
func (m *connIDGenerator) Retire(seq uint64, sentWithDestConnID protocol.ConnectionID) error {
if seq > m.highestSeq {
return &qerr.TransportError{
ErrorCode: qerr.ProtocolViolation,
ErrorMessage: fmt.Sprintf("retired connection ID %d (highest issued: %d)", seq, m.highestSeq),
}
}
connID, ok := m.activeSrcConnIDs[seq]
// We might already have deleted this connection ID, if this is a duplicate frame.
if !ok {
return nil
}
if connID == sentWithDestConnID {
return &qerr.TransportError{
ErrorCode: qerr.ProtocolViolation,
ErrorMessage: fmt.Sprintf("retired connection ID %d (%s), which was used as the Destination Connection ID on this packet", seq, connID),
}
}
m.retireConnectionID(connID)
delete(m.activeSrcConnIDs, seq)
// Don't issue a replacement for the initial connection ID.
if seq == 0 {
return nil
}
return m.issueNewConnID()
}
func (m *connIDGenerator) issueNewConnID() error {
connID, err := m.generator.GenerateConnectionID()
if err != nil {
return err
}
m.activeSrcConnIDs[m.highestSeq+1] = connID
m.addConnectionID(connID)
m.queueControlFrame(&wire.NewConnectionIDFrame{
SequenceNumber: m.highestSeq + 1,
ConnectionID: connID,
StatelessResetToken: m.getStatelessResetToken(connID),
})
m.highestSeq++
return nil
}
func (m *connIDGenerator) SetHandshakeComplete() {
if m.initialClientDestConnID != nil {
m.retireConnectionID(*m.initialClientDestConnID)
m.initialClientDestConnID = nil
}
}
func (m *connIDGenerator) RemoveAll() {
if m.initialClientDestConnID != nil {
m.removeConnectionID(*m.initialClientDestConnID)
}
for _, connID := range m.activeSrcConnIDs {
m.removeConnectionID(connID)
}
}
func (m *connIDGenerator) ReplaceWithClosed(pers protocol.Perspective, connClose []byte) {
connIDs := make([]protocol.ConnectionID, 0, len(m.activeSrcConnIDs)+1)
if m.initialClientDestConnID != nil {
connIDs = append(connIDs, *m.initialClientDestConnID)
}
for _, connID := range m.activeSrcConnIDs {
connIDs = append(connIDs, connID)
}
m.replaceWithClosed(connIDs, pers, connClose)
}

214
vendor/github.com/quic-go/quic-go/conn_id_manager.go generated vendored Normal file
View File

@@ -0,0 +1,214 @@
package quic
import (
"fmt"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/utils"
list "github.com/quic-go/quic-go/internal/utils/linkedlist"
"github.com/quic-go/quic-go/internal/wire"
)
type newConnID struct {
SequenceNumber uint64
ConnectionID protocol.ConnectionID
StatelessResetToken protocol.StatelessResetToken
}
type connIDManager struct {
queue list.List[newConnID]
handshakeComplete bool
activeSequenceNumber uint64
highestRetired uint64
activeConnectionID protocol.ConnectionID
activeStatelessResetToken *protocol.StatelessResetToken
// We change the connection ID after sending on average
// protocol.PacketsPerConnectionID packets. The actual value is randomized
// hide the packet loss rate from on-path observers.
rand utils.Rand
packetsSinceLastChange uint32
packetsPerConnectionID uint32
addStatelessResetToken func(protocol.StatelessResetToken)
removeStatelessResetToken func(protocol.StatelessResetToken)
queueControlFrame func(wire.Frame)
}
func newConnIDManager(
initialDestConnID protocol.ConnectionID,
addStatelessResetToken func(protocol.StatelessResetToken),
removeStatelessResetToken func(protocol.StatelessResetToken),
queueControlFrame func(wire.Frame),
) *connIDManager {
return &connIDManager{
activeConnectionID: initialDestConnID,
addStatelessResetToken: addStatelessResetToken,
removeStatelessResetToken: removeStatelessResetToken,
queueControlFrame: queueControlFrame,
}
}
func (h *connIDManager) AddFromPreferredAddress(connID protocol.ConnectionID, resetToken protocol.StatelessResetToken) error {
return h.addConnectionID(1, connID, resetToken)
}
func (h *connIDManager) Add(f *wire.NewConnectionIDFrame) error {
if err := h.add(f); err != nil {
return err
}
if h.queue.Len() >= protocol.MaxActiveConnectionIDs {
return &qerr.TransportError{ErrorCode: qerr.ConnectionIDLimitError}
}
return nil
}
func (h *connIDManager) add(f *wire.NewConnectionIDFrame) error {
// If the NEW_CONNECTION_ID frame is reordered, such that its sequence number is smaller than the currently active
// connection ID or if it was already retired, send the RETIRE_CONNECTION_ID frame immediately.
if f.SequenceNumber < h.activeSequenceNumber || f.SequenceNumber < h.highestRetired {
h.queueControlFrame(&wire.RetireConnectionIDFrame{
SequenceNumber: f.SequenceNumber,
})
return nil
}
// Retire elements in the queue.
// Doesn't retire the active connection ID.
if f.RetirePriorTo > h.highestRetired {
var next *list.Element[newConnID]
for el := h.queue.Front(); el != nil; el = next {
if el.Value.SequenceNumber >= f.RetirePriorTo {
break
}
next = el.Next()
h.queueControlFrame(&wire.RetireConnectionIDFrame{
SequenceNumber: el.Value.SequenceNumber,
})
h.queue.Remove(el)
}
h.highestRetired = f.RetirePriorTo
}
if f.SequenceNumber == h.activeSequenceNumber {
return nil
}
if err := h.addConnectionID(f.SequenceNumber, f.ConnectionID, f.StatelessResetToken); err != nil {
return err
}
// Retire the active connection ID, if necessary.
if h.activeSequenceNumber < f.RetirePriorTo {
// The queue is guaranteed to have at least one element at this point.
h.updateConnectionID()
}
return nil
}
func (h *connIDManager) addConnectionID(seq uint64, connID protocol.ConnectionID, resetToken protocol.StatelessResetToken) error {
// insert a new element at the end
if h.queue.Len() == 0 || h.queue.Back().Value.SequenceNumber < seq {
h.queue.PushBack(newConnID{
SequenceNumber: seq,
ConnectionID: connID,
StatelessResetToken: resetToken,
})
return nil
}
// insert a new element somewhere in the middle
for el := h.queue.Front(); el != nil; el = el.Next() {
if el.Value.SequenceNumber == seq {
if el.Value.ConnectionID != connID {
return fmt.Errorf("received conflicting connection IDs for sequence number %d", seq)
}
if el.Value.StatelessResetToken != resetToken {
return fmt.Errorf("received conflicting stateless reset tokens for sequence number %d", seq)
}
break
}
if el.Value.SequenceNumber > seq {
h.queue.InsertBefore(newConnID{
SequenceNumber: seq,
ConnectionID: connID,
StatelessResetToken: resetToken,
}, el)
break
}
}
return nil
}
func (h *connIDManager) updateConnectionID() {
h.queueControlFrame(&wire.RetireConnectionIDFrame{
SequenceNumber: h.activeSequenceNumber,
})
h.highestRetired = utils.Max(h.highestRetired, h.activeSequenceNumber)
if h.activeStatelessResetToken != nil {
h.removeStatelessResetToken(*h.activeStatelessResetToken)
}
front := h.queue.Remove(h.queue.Front())
h.activeSequenceNumber = front.SequenceNumber
h.activeConnectionID = front.ConnectionID
h.activeStatelessResetToken = &front.StatelessResetToken
h.packetsSinceLastChange = 0
h.packetsPerConnectionID = protocol.PacketsPerConnectionID/2 + uint32(h.rand.Int31n(protocol.PacketsPerConnectionID))
h.addStatelessResetToken(*h.activeStatelessResetToken)
}
func (h *connIDManager) Close() {
if h.activeStatelessResetToken != nil {
h.removeStatelessResetToken(*h.activeStatelessResetToken)
}
}
// is called when the server performs a Retry
// and when the server changes the connection ID in the first Initial sent
func (h *connIDManager) ChangeInitialConnID(newConnID protocol.ConnectionID) {
if h.activeSequenceNumber != 0 {
panic("expected first connection ID to have sequence number 0")
}
h.activeConnectionID = newConnID
}
// is called when the server provides a stateless reset token in the transport parameters
func (h *connIDManager) SetStatelessResetToken(token protocol.StatelessResetToken) {
if h.activeSequenceNumber != 0 {
panic("expected first connection ID to have sequence number 0")
}
h.activeStatelessResetToken = &token
h.addStatelessResetToken(token)
}
func (h *connIDManager) SentPacket() {
h.packetsSinceLastChange++
}
func (h *connIDManager) shouldUpdateConnID() bool {
if !h.handshakeComplete {
return false
}
// initiate the first change as early as possible (after handshake completion)
if h.queue.Len() > 0 && h.activeSequenceNumber == 0 {
return true
}
// For later changes, only change if
// 1. The queue of connection IDs is filled more than 50%.
// 2. We sent at least PacketsPerConnectionID packets
return 2*h.queue.Len() >= protocol.MaxActiveConnectionIDs &&
h.packetsSinceLastChange >= h.packetsPerConnectionID
}
func (h *connIDManager) Get() protocol.ConnectionID {
if h.shouldUpdateConnID() {
h.updateConnectionID()
}
return h.activeConnectionID
}
func (h *connIDManager) SetHandshakeComplete() {
h.handshakeComplete = true
}

2387
vendor/github.com/quic-go/quic-go/connection.go generated vendored Normal file

File diff suppressed because it is too large Load Diff

51
vendor/github.com/quic-go/quic-go/connection_timer.go generated vendored Normal file
View File

@@ -0,0 +1,51 @@
package quic
import (
"time"
"github.com/quic-go/quic-go/internal/utils"
)
var deadlineSendImmediately = time.Time{}.Add(42 * time.Millisecond) // any value > time.Time{} and before time.Now() is fine
type connectionTimer struct {
timer *utils.Timer
last time.Time
}
func newTimer() *connectionTimer {
return &connectionTimer{timer: utils.NewTimer()}
}
func (t *connectionTimer) SetRead() {
if deadline := t.timer.Deadline(); deadline != deadlineSendImmediately {
t.last = deadline
}
t.timer.SetRead()
}
func (t *connectionTimer) Chan() <-chan time.Time {
return t.timer.Chan()
}
// SetTimer resets the timer.
// It makes sure that the deadline is strictly increasing.
// This prevents busy-looping in cases where the timer fires, but we can't actually send out a packet.
// This doesn't apply to the pacing deadline, which can be set multiple times to deadlineSendImmediately.
func (t *connectionTimer) SetTimer(idleTimeoutOrKeepAlive, ackAlarm, lossTime, pacing time.Time) {
deadline := idleTimeoutOrKeepAlive
if !ackAlarm.IsZero() && ackAlarm.Before(deadline) && ackAlarm.After(t.last) {
deadline = ackAlarm
}
if !lossTime.IsZero() && lossTime.Before(deadline) && lossTime.After(t.last) {
deadline = lossTime
}
if !pacing.IsZero() && pacing.Before(deadline) {
deadline = pacing
}
t.timer.Reset(deadline)
}
func (t *connectionTimer) Stop() {
t.timer.Stop()
}

107
vendor/github.com/quic-go/quic-go/crypto_stream.go generated vendored Normal file
View File

@@ -0,0 +1,107 @@
package quic
import (
"fmt"
"io"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/internal/wire"
)
type cryptoStream interface {
// for receiving data
HandleCryptoFrame(*wire.CryptoFrame) error
GetCryptoData() []byte
Finish() error
// for sending data
io.Writer
HasData() bool
PopCryptoFrame(protocol.ByteCount) *wire.CryptoFrame
}
type cryptoStreamImpl struct {
queue *frameSorter
msgBuf []byte
highestOffset protocol.ByteCount
finished bool
writeOffset protocol.ByteCount
writeBuf []byte
}
func newCryptoStream() cryptoStream {
return &cryptoStreamImpl{queue: newFrameSorter()}
}
func (s *cryptoStreamImpl) HandleCryptoFrame(f *wire.CryptoFrame) error {
highestOffset := f.Offset + protocol.ByteCount(len(f.Data))
if maxOffset := highestOffset; maxOffset > protocol.MaxCryptoStreamOffset {
return &qerr.TransportError{
ErrorCode: qerr.CryptoBufferExceeded,
ErrorMessage: fmt.Sprintf("received invalid offset %d on crypto stream, maximum allowed %d", maxOffset, protocol.MaxCryptoStreamOffset),
}
}
if s.finished {
if highestOffset > s.highestOffset {
// reject crypto data received after this stream was already finished
return &qerr.TransportError{
ErrorCode: qerr.ProtocolViolation,
ErrorMessage: "received crypto data after change of encryption level",
}
}
// ignore data with a smaller offset than the highest received
// could e.g. be a retransmission
return nil
}
s.highestOffset = utils.Max(s.highestOffset, highestOffset)
if err := s.queue.Push(f.Data, f.Offset, nil); err != nil {
return err
}
for {
_, data, _ := s.queue.Pop()
if data == nil {
return nil
}
s.msgBuf = append(s.msgBuf, data...)
}
}
// GetCryptoData retrieves data that was received in CRYPTO frames
func (s *cryptoStreamImpl) GetCryptoData() []byte {
b := s.msgBuf
s.msgBuf = nil
return b
}
func (s *cryptoStreamImpl) Finish() error {
if s.queue.HasMoreData() {
return &qerr.TransportError{
ErrorCode: qerr.ProtocolViolation,
ErrorMessage: "encryption level changed, but crypto stream has more data to read",
}
}
s.finished = true
return nil
}
// Writes writes data that should be sent out in CRYPTO frames
func (s *cryptoStreamImpl) Write(p []byte) (int, error) {
s.writeBuf = append(s.writeBuf, p...)
return len(p), nil
}
func (s *cryptoStreamImpl) HasData() bool {
return len(s.writeBuf) > 0
}
func (s *cryptoStreamImpl) PopCryptoFrame(maxLen protocol.ByteCount) *wire.CryptoFrame {
f := &wire.CryptoFrame{Offset: s.writeOffset}
n := utils.Min(f.MaxDataLen(maxLen), protocol.ByteCount(len(s.writeBuf)))
f.Data = s.writeBuf[:n]
s.writeBuf = s.writeBuf[n:]
s.writeOffset += n
return f
}

View File

@@ -0,0 +1,82 @@
package quic
import (
"fmt"
"github.com/quic-go/quic-go/internal/handshake"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/wire"
)
type cryptoDataHandler interface {
HandleMessage([]byte, protocol.EncryptionLevel) error
NextEvent() handshake.Event
}
type cryptoStreamManager struct {
cryptoHandler cryptoDataHandler
initialStream cryptoStream
handshakeStream cryptoStream
oneRTTStream cryptoStream
}
func newCryptoStreamManager(
cryptoHandler cryptoDataHandler,
initialStream cryptoStream,
handshakeStream cryptoStream,
oneRTTStream cryptoStream,
) *cryptoStreamManager {
return &cryptoStreamManager{
cryptoHandler: cryptoHandler,
initialStream: initialStream,
handshakeStream: handshakeStream,
oneRTTStream: oneRTTStream,
}
}
func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) error {
var str cryptoStream
//nolint:exhaustive // CRYPTO frames cannot be sent in 0-RTT packets.
switch encLevel {
case protocol.EncryptionInitial:
str = m.initialStream
case protocol.EncryptionHandshake:
str = m.handshakeStream
case protocol.Encryption1RTT:
str = m.oneRTTStream
default:
return fmt.Errorf("received CRYPTO frame with unexpected encryption level: %s", encLevel)
}
if err := str.HandleCryptoFrame(frame); err != nil {
return err
}
for {
data := str.GetCryptoData()
if data == nil {
return nil
}
if err := m.cryptoHandler.HandleMessage(data, encLevel); err != nil {
return err
}
}
}
func (m *cryptoStreamManager) GetPostHandshakeData(maxSize protocol.ByteCount) *wire.CryptoFrame {
if !m.oneRTTStream.HasData() {
return nil
}
return m.oneRTTStream.PopCryptoFrame(maxSize)
}
func (m *cryptoStreamManager) Drop(encLevel protocol.EncryptionLevel) error {
//nolint:exhaustive // 1-RTT keys should never get dropped.
switch encLevel {
case protocol.EncryptionInitial:
return m.initialStream.Finish()
case protocol.EncryptionHandshake:
return m.handshakeStream.Finish()
default:
panic(fmt.Sprintf("dropped unexpected encryption level: %s", encLevel))
}
}

126
vendor/github.com/quic-go/quic-go/datagram_queue.go generated vendored Normal file
View File

@@ -0,0 +1,126 @@
package quic
import (
"context"
"sync"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/internal/wire"
)
type datagramQueue struct {
sendQueue chan *wire.DatagramFrame
nextFrame *wire.DatagramFrame
rcvMx sync.Mutex
rcvQueue [][]byte
rcvd chan struct{} // used to notify Receive that a new datagram was received
closeErr error
closed chan struct{}
hasData func()
dequeued chan struct{}
logger utils.Logger
}
func newDatagramQueue(hasData func(), logger utils.Logger) *datagramQueue {
return &datagramQueue{
hasData: hasData,
sendQueue: make(chan *wire.DatagramFrame, 1),
rcvd: make(chan struct{}, 1),
dequeued: make(chan struct{}),
closed: make(chan struct{}),
logger: logger,
}
}
// AddAndWait queues a new DATAGRAM frame for sending.
// It blocks until the frame has been dequeued.
func (h *datagramQueue) AddAndWait(f *wire.DatagramFrame) error {
select {
case h.sendQueue <- f:
h.hasData()
case <-h.closed:
return h.closeErr
}
select {
case <-h.dequeued:
return nil
case <-h.closed:
return h.closeErr
}
}
// Peek gets the next DATAGRAM frame for sending.
// If actually sent out, Pop needs to be called before the next call to Peek.
func (h *datagramQueue) Peek() *wire.DatagramFrame {
if h.nextFrame != nil {
return h.nextFrame
}
select {
case h.nextFrame = <-h.sendQueue:
h.dequeued <- struct{}{}
default:
return nil
}
return h.nextFrame
}
func (h *datagramQueue) Pop() {
if h.nextFrame == nil {
panic("datagramQueue BUG: Pop called for nil frame")
}
h.nextFrame = nil
}
// HandleDatagramFrame handles a received DATAGRAM frame.
func (h *datagramQueue) HandleDatagramFrame(f *wire.DatagramFrame) {
data := make([]byte, len(f.Data))
copy(data, f.Data)
var queued bool
h.rcvMx.Lock()
if len(h.rcvQueue) < protocol.DatagramRcvQueueLen {
h.rcvQueue = append(h.rcvQueue, data)
queued = true
select {
case h.rcvd <- struct{}{}:
default:
}
}
h.rcvMx.Unlock()
if !queued && h.logger.Debug() {
h.logger.Debugf("Discarding DATAGRAM frame (%d bytes payload)", len(f.Data))
}
}
// Receive gets a received DATAGRAM frame.
func (h *datagramQueue) Receive(ctx context.Context) ([]byte, error) {
for {
h.rcvMx.Lock()
if len(h.rcvQueue) > 0 {
data := h.rcvQueue[0]
h.rcvQueue = h.rcvQueue[1:]
h.rcvMx.Unlock()
return data, nil
}
h.rcvMx.Unlock()
select {
case <-h.rcvd:
continue
case <-h.closed:
return nil, h.closeErr
case <-ctx.Done():
return nil, ctx.Err()
}
}
}
func (h *datagramQueue) CloseWithError(e error) {
h.closeErr = e
close(h.closed)
}

63
vendor/github.com/quic-go/quic-go/errors.go generated vendored Normal file
View File

@@ -0,0 +1,63 @@
package quic
import (
"fmt"
"github.com/quic-go/quic-go/internal/qerr"
)
type (
TransportError = qerr.TransportError
ApplicationError = qerr.ApplicationError
VersionNegotiationError = qerr.VersionNegotiationError
StatelessResetError = qerr.StatelessResetError
IdleTimeoutError = qerr.IdleTimeoutError
HandshakeTimeoutError = qerr.HandshakeTimeoutError
)
type (
TransportErrorCode = qerr.TransportErrorCode
ApplicationErrorCode = qerr.ApplicationErrorCode
StreamErrorCode = qerr.StreamErrorCode
)
const (
NoError = qerr.NoError
InternalError = qerr.InternalError
ConnectionRefused = qerr.ConnectionRefused
FlowControlError = qerr.FlowControlError
StreamLimitError = qerr.StreamLimitError
StreamStateError = qerr.StreamStateError
FinalSizeError = qerr.FinalSizeError
FrameEncodingError = qerr.FrameEncodingError
TransportParameterError = qerr.TransportParameterError
ConnectionIDLimitError = qerr.ConnectionIDLimitError
ProtocolViolation = qerr.ProtocolViolation
InvalidToken = qerr.InvalidToken
ApplicationErrorErrorCode = qerr.ApplicationErrorErrorCode
CryptoBufferExceeded = qerr.CryptoBufferExceeded
KeyUpdateError = qerr.KeyUpdateError
AEADLimitReached = qerr.AEADLimitReached
NoViablePathError = qerr.NoViablePathError
)
// A StreamError is used for Stream.CancelRead and Stream.CancelWrite.
// It is also returned from Stream.Read and Stream.Write if the peer canceled reading or writing.
type StreamError struct {
StreamID StreamID
ErrorCode StreamErrorCode
Remote bool
}
func (e *StreamError) Is(target error) bool {
_, ok := target.(*StreamError)
return ok
}
func (e *StreamError) Error() string {
pers := "local"
if e.Remote {
pers = "remote"
}
return fmt.Sprintf("stream %d canceled by %s with error code %d", e.StreamID, pers, e.ErrorCode)
}

237
vendor/github.com/quic-go/quic-go/frame_sorter.go generated vendored Normal file
View File

@@ -0,0 +1,237 @@
package quic
import (
"errors"
"sync"
"github.com/quic-go/quic-go/internal/protocol"
list "github.com/quic-go/quic-go/internal/utils/linkedlist"
)
// byteInterval is an interval from one ByteCount to the other
type byteInterval struct {
Start protocol.ByteCount
End protocol.ByteCount
}
var byteIntervalElementPool sync.Pool
func init() {
byteIntervalElementPool = *list.NewPool[byteInterval]()
}
type frameSorterEntry struct {
Data []byte
DoneCb func()
}
type frameSorter struct {
queue map[protocol.ByteCount]frameSorterEntry
readPos protocol.ByteCount
gaps *list.List[byteInterval]
}
var errDuplicateStreamData = errors.New("duplicate stream data")
func newFrameSorter() *frameSorter {
s := frameSorter{
gaps: list.NewWithPool[byteInterval](&byteIntervalElementPool),
queue: make(map[protocol.ByteCount]frameSorterEntry),
}
s.gaps.PushFront(byteInterval{Start: 0, End: protocol.MaxByteCount})
return &s
}
func (s *frameSorter) Push(data []byte, offset protocol.ByteCount, doneCb func()) error {
err := s.push(data, offset, doneCb)
if err == errDuplicateStreamData {
if doneCb != nil {
doneCb()
}
return nil
}
return err
}
func (s *frameSorter) push(data []byte, offset protocol.ByteCount, doneCb func()) error {
if len(data) == 0 {
return errDuplicateStreamData
}
start := offset
end := offset + protocol.ByteCount(len(data))
if end <= s.gaps.Front().Value.Start {
return errDuplicateStreamData
}
startGap, startsInGap := s.findStartGap(start)
endGap, endsInGap := s.findEndGap(startGap, end)
startGapEqualsEndGap := startGap == endGap
if (startGapEqualsEndGap && end <= startGap.Value.Start) ||
(!startGapEqualsEndGap && startGap.Value.End >= endGap.Value.Start && end <= startGap.Value.Start) {
return errDuplicateStreamData
}
startGapNext := startGap.Next()
startGapEnd := startGap.Value.End // save it, in case startGap is modified
endGapStart := endGap.Value.Start // save it, in case endGap is modified
endGapEnd := endGap.Value.End // save it, in case endGap is modified
var adjustedStartGapEnd bool
var wasCut bool
pos := start
var hasReplacedAtLeastOne bool
for {
oldEntry, ok := s.queue[pos]
if !ok {
break
}
oldEntryLen := protocol.ByteCount(len(oldEntry.Data))
if end-pos > oldEntryLen || (hasReplacedAtLeastOne && end-pos == oldEntryLen) {
// The existing frame is shorter than the new frame. Replace it.
delete(s.queue, pos)
pos += oldEntryLen
hasReplacedAtLeastOne = true
if oldEntry.DoneCb != nil {
oldEntry.DoneCb()
}
} else {
if !hasReplacedAtLeastOne {
return errDuplicateStreamData
}
// The existing frame is longer than the new frame.
// Cut the new frame such that the end aligns with the start of the existing frame.
data = data[:pos-start]
end = pos
wasCut = true
break
}
}
if !startsInGap && !hasReplacedAtLeastOne {
// cut the frame, such that it starts at the start of the gap
data = data[startGap.Value.Start-start:]
start = startGap.Value.Start
wasCut = true
}
if start <= startGap.Value.Start {
if end >= startGap.Value.End {
// The frame covers the whole startGap. Delete the gap.
s.gaps.Remove(startGap)
} else {
startGap.Value.Start = end
}
} else if !hasReplacedAtLeastOne {
startGap.Value.End = start
adjustedStartGapEnd = true
}
if !startGapEqualsEndGap {
s.deleteConsecutive(startGapEnd)
var nextGap *list.Element[byteInterval]
for gap := startGapNext; gap.Value.End < endGapStart; gap = nextGap {
nextGap = gap.Next()
s.deleteConsecutive(gap.Value.End)
s.gaps.Remove(gap)
}
}
if !endsInGap && start != endGapEnd && end > endGapEnd {
// cut the frame, such that it ends at the end of the gap
data = data[:endGapEnd-start]
end = endGapEnd
wasCut = true
}
if end == endGapEnd {
if !startGapEqualsEndGap {
// The frame covers the whole endGap. Delete the gap.
s.gaps.Remove(endGap)
}
} else {
if startGapEqualsEndGap && adjustedStartGapEnd {
// The frame split the existing gap into two.
s.gaps.InsertAfter(byteInterval{Start: end, End: startGapEnd}, startGap)
} else if !startGapEqualsEndGap {
endGap.Value.Start = end
}
}
if wasCut && len(data) < protocol.MinStreamFrameBufferSize {
newData := make([]byte, len(data))
copy(newData, data)
data = newData
if doneCb != nil {
doneCb()
doneCb = nil
}
}
if s.gaps.Len() > protocol.MaxStreamFrameSorterGaps {
return errors.New("too many gaps in received data")
}
s.queue[start] = frameSorterEntry{Data: data, DoneCb: doneCb}
return nil
}
func (s *frameSorter) findStartGap(offset protocol.ByteCount) (*list.Element[byteInterval], bool) {
for gap := s.gaps.Front(); gap != nil; gap = gap.Next() {
if offset >= gap.Value.Start && offset <= gap.Value.End {
return gap, true
}
if offset < gap.Value.Start {
return gap, false
}
}
panic("no gap found")
}
func (s *frameSorter) findEndGap(startGap *list.Element[byteInterval], offset protocol.ByteCount) (*list.Element[byteInterval], bool) {
for gap := startGap; gap != nil; gap = gap.Next() {
if offset >= gap.Value.Start && offset < gap.Value.End {
return gap, true
}
if offset < gap.Value.Start {
return gap.Prev(), false
}
}
panic("no gap found")
}
// deleteConsecutive deletes consecutive frames from the queue, starting at pos
func (s *frameSorter) deleteConsecutive(pos protocol.ByteCount) {
for {
oldEntry, ok := s.queue[pos]
if !ok {
break
}
oldEntryLen := protocol.ByteCount(len(oldEntry.Data))
delete(s.queue, pos)
if oldEntry.DoneCb != nil {
oldEntry.DoneCb()
}
pos += oldEntryLen
}
}
func (s *frameSorter) Pop() (protocol.ByteCount, []byte, func()) {
entry, ok := s.queue[s.readPos]
if !ok {
return s.readPos, nil, nil
}
delete(s.queue, s.readPos)
offset := s.readPos
s.readPos += protocol.ByteCount(len(entry.Data))
if s.gaps.Front().Value.End <= s.readPos {
panic("frame sorter BUG: read position higher than a gap")
}
return offset, entry.Data, entry.DoneCb
}
// HasMoreData says if there is any more data queued at *any* offset.
func (s *frameSorter) HasMoreData() bool {
return len(s.queue) > 0
}

165
vendor/github.com/quic-go/quic-go/framer.go generated vendored Normal file
View File

@@ -0,0 +1,165 @@
package quic
import (
"errors"
"sync"
"github.com/quic-go/quic-go/internal/ackhandler"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils/ringbuffer"
"github.com/quic-go/quic-go/internal/wire"
"github.com/quic-go/quic-go/quicvarint"
)
type framer interface {
HasData() bool
QueueControlFrame(wire.Frame)
AppendControlFrames([]ackhandler.Frame, protocol.ByteCount, protocol.VersionNumber) ([]ackhandler.Frame, protocol.ByteCount)
AddActiveStream(protocol.StreamID)
AppendStreamFrames([]ackhandler.StreamFrame, protocol.ByteCount, protocol.VersionNumber) ([]ackhandler.StreamFrame, protocol.ByteCount)
Handle0RTTRejection() error
}
type framerI struct {
mutex sync.Mutex
streamGetter streamGetter
activeStreams map[protocol.StreamID]struct{}
streamQueue ringbuffer.RingBuffer[protocol.StreamID]
controlFrameMutex sync.Mutex
controlFrames []wire.Frame
}
var _ framer = &framerI{}
func newFramer(streamGetter streamGetter) framer {
return &framerI{
streamGetter: streamGetter,
activeStreams: make(map[protocol.StreamID]struct{}),
}
}
func (f *framerI) HasData() bool {
f.mutex.Lock()
hasData := !f.streamQueue.Empty()
f.mutex.Unlock()
if hasData {
return true
}
f.controlFrameMutex.Lock()
hasData = len(f.controlFrames) > 0
f.controlFrameMutex.Unlock()
return hasData
}
func (f *framerI) QueueControlFrame(frame wire.Frame) {
f.controlFrameMutex.Lock()
f.controlFrames = append(f.controlFrames, frame)
f.controlFrameMutex.Unlock()
}
func (f *framerI) AppendControlFrames(frames []ackhandler.Frame, maxLen protocol.ByteCount, v protocol.VersionNumber) ([]ackhandler.Frame, protocol.ByteCount) {
var length protocol.ByteCount
f.controlFrameMutex.Lock()
for len(f.controlFrames) > 0 {
frame := f.controlFrames[len(f.controlFrames)-1]
frameLen := frame.Length(v)
if length+frameLen > maxLen {
break
}
frames = append(frames, ackhandler.Frame{Frame: frame})
length += frameLen
f.controlFrames = f.controlFrames[:len(f.controlFrames)-1]
}
f.controlFrameMutex.Unlock()
return frames, length
}
func (f *framerI) AddActiveStream(id protocol.StreamID) {
f.mutex.Lock()
if _, ok := f.activeStreams[id]; !ok {
f.streamQueue.PushBack(id)
f.activeStreams[id] = struct{}{}
}
f.mutex.Unlock()
}
func (f *framerI) AppendStreamFrames(frames []ackhandler.StreamFrame, maxLen protocol.ByteCount, v protocol.VersionNumber) ([]ackhandler.StreamFrame, protocol.ByteCount) {
startLen := len(frames)
var length protocol.ByteCount
f.mutex.Lock()
// pop STREAM frames, until less than MinStreamFrameSize bytes are left in the packet
numActiveStreams := f.streamQueue.Len()
for i := 0; i < numActiveStreams; i++ {
if protocol.MinStreamFrameSize+length > maxLen {
break
}
id := f.streamQueue.PopFront()
// This should never return an error. Better check it anyway.
// The stream will only be in the streamQueue, if it enqueued itself there.
str, err := f.streamGetter.GetOrOpenSendStream(id)
// The stream can be nil if it completed after it said it had data.
if str == nil || err != nil {
delete(f.activeStreams, id)
continue
}
remainingLen := maxLen - length
// For the last STREAM frame, we'll remove the DataLen field later.
// Therefore, we can pretend to have more bytes available when popping
// the STREAM frame (which will always have the DataLen set).
remainingLen += quicvarint.Len(uint64(remainingLen))
frame, ok, hasMoreData := str.popStreamFrame(remainingLen, v)
if hasMoreData { // put the stream back in the queue (at the end)
f.streamQueue.PushBack(id)
} else { // no more data to send. Stream is not active
delete(f.activeStreams, id)
}
// The frame can be "nil"
// * if the receiveStream was canceled after it said it had data
// * the remaining size doesn't allow us to add another STREAM frame
if !ok {
continue
}
frames = append(frames, frame)
length += frame.Frame.Length(v)
}
f.mutex.Unlock()
if len(frames) > startLen {
l := frames[len(frames)-1].Frame.Length(v)
// account for the smaller size of the last STREAM frame
frames[len(frames)-1].Frame.DataLenPresent = false
length += frames[len(frames)-1].Frame.Length(v) - l
}
return frames, length
}
func (f *framerI) Handle0RTTRejection() error {
f.mutex.Lock()
defer f.mutex.Unlock()
f.controlFrameMutex.Lock()
f.streamQueue.Clear()
for id := range f.activeStreams {
delete(f.activeStreams, id)
}
var j int
for i, frame := range f.controlFrames {
switch frame.(type) {
case *wire.MaxDataFrame, *wire.MaxStreamDataFrame, *wire.MaxStreamsFrame:
return errors.New("didn't expect MAX_DATA / MAX_STREAM_DATA / MAX_STREAMS frame to be sent in 0-RTT")
case *wire.DataBlockedFrame, *wire.StreamDataBlockedFrame, *wire.StreamsBlockedFrame:
continue
default:
f.controlFrames[j] = f.controlFrames[i]
j++
}
}
f.controlFrames = f.controlFrames[:j]
f.controlFrameMutex.Unlock()
return nil
}

136
vendor/github.com/quic-go/quic-go/http3/body.go generated vendored Normal file
View File

@@ -0,0 +1,136 @@
package http3
import (
"context"
"io"
"net"
"github.com/quic-go/quic-go"
)
// The HTTPStreamer allows taking over a HTTP/3 stream. The interface is implemented by:
// * for the server: the http.Request.Body
// * for the client: the http.Response.Body
// On the client side, the stream will be closed for writing, unless the DontCloseRequestStream RoundTripOpt was set.
// When a stream is taken over, it's the caller's responsibility to close the stream.
type HTTPStreamer interface {
HTTPStream() Stream
}
type StreamCreator interface {
// Context returns a context that is cancelled when the underlying connection is closed.
Context() context.Context
OpenStream() (quic.Stream, error)
OpenStreamSync(context.Context) (quic.Stream, error)
OpenUniStream() (quic.SendStream, error)
OpenUniStreamSync(context.Context) (quic.SendStream, error)
LocalAddr() net.Addr
RemoteAddr() net.Addr
ConnectionState() quic.ConnectionState
}
var _ StreamCreator = quic.Connection(nil)
// A Hijacker allows hijacking of the stream creating part of a quic.Session from a http.Response.Body.
// It is used by WebTransport to create WebTransport streams after a session has been established.
type Hijacker interface {
StreamCreator() StreamCreator
}
// The body of a http.Request or http.Response.
type body struct {
str quic.Stream
wasHijacked bool // set when HTTPStream is called
}
var (
_ io.ReadCloser = &body{}
_ HTTPStreamer = &body{}
)
func newRequestBody(str Stream) *body {
return &body{str: str}
}
func (r *body) HTTPStream() Stream {
r.wasHijacked = true
return r.str
}
func (r *body) wasStreamHijacked() bool {
return r.wasHijacked
}
func (r *body) Read(b []byte) (int, error) {
n, err := r.str.Read(b)
return n, maybeReplaceError(err)
}
func (r *body) Close() error {
r.str.CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled))
return nil
}
type hijackableBody struct {
body
conn quic.Connection // only needed to implement Hijacker
// only set for the http.Response
// The channel is closed when the user is done with this response:
// either when Read() errors, or when Close() is called.
reqDone chan<- struct{}
reqDoneClosed bool
}
var (
_ Hijacker = &hijackableBody{}
_ HTTPStreamer = &hijackableBody{}
)
func newResponseBody(str Stream, conn quic.Connection, done chan<- struct{}) *hijackableBody {
return &hijackableBody{
body: body{
str: str,
},
reqDone: done,
conn: conn,
}
}
func (r *hijackableBody) StreamCreator() StreamCreator {
return r.conn
}
func (r *hijackableBody) Read(b []byte) (int, error) {
n, err := r.str.Read(b)
if err != nil {
r.requestDone()
}
return n, maybeReplaceError(err)
}
func (r *hijackableBody) requestDone() {
if r.reqDoneClosed || r.reqDone == nil {
return
}
if r.reqDone != nil {
close(r.reqDone)
}
r.reqDoneClosed = true
}
func (r *body) StreamID() quic.StreamID {
return r.str.StreamID()
}
func (r *hijackableBody) Close() error {
r.requestDone()
// If the EOF was read, CancelRead() is a no-op.
r.str.CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled))
return nil
}
func (r *hijackableBody) HTTPStream() Stream {
return r.str
}

55
vendor/github.com/quic-go/quic-go/http3/capsule.go generated vendored Normal file
View File

@@ -0,0 +1,55 @@
package http3
import (
"io"
"github.com/quic-go/quic-go/quicvarint"
)
// CapsuleType is the type of the capsule.
type CapsuleType uint64
type exactReader struct {
R *io.LimitedReader
}
func (r *exactReader) Read(b []byte) (int, error) {
n, err := r.R.Read(b)
if r.R.N > 0 {
return n, io.ErrUnexpectedEOF
}
return n, err
}
// ParseCapsule parses the header of a Capsule.
// It returns an io.LimitedReader that can be used to read the Capsule value.
// The Capsule value must be read entirely (i.e. until the io.EOF) before using r again.
func ParseCapsule(r quicvarint.Reader) (CapsuleType, io.Reader, error) {
ct, err := quicvarint.Read(r)
if err != nil {
if err == io.EOF {
return 0, nil, io.ErrUnexpectedEOF
}
return 0, nil, err
}
l, err := quicvarint.Read(r)
if err != nil {
if err == io.EOF {
return 0, nil, io.ErrUnexpectedEOF
}
return 0, nil, err
}
return CapsuleType(ct), &exactReader{R: io.LimitReader(r, int64(l)).(*io.LimitedReader)}, nil
}
// WriteCapsule writes a capsule
func WriteCapsule(w quicvarint.Writer, ct CapsuleType, value []byte) error {
b := make([]byte, 0, 16)
b = quicvarint.Append(b, uint64(ct))
b = quicvarint.Append(b, uint64(len(value)))
if _, err := w.Write(b); err != nil {
return err
}
_, err := w.Write(value)
return err
}

477
vendor/github.com/quic-go/quic-go/http3/client.go generated vendored Normal file
View File

@@ -0,0 +1,477 @@
package http3
import (
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/http"
"strconv"
"sync"
"sync/atomic"
"time"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/quicvarint"
"github.com/quic-go/qpack"
)
// MethodGet0RTT allows a GET request to be sent using 0-RTT.
// Note that 0-RTT data doesn't provide replay protection.
const MethodGet0RTT = "GET_0RTT"
const (
defaultUserAgent = "quic-go HTTP/3"
defaultMaxResponseHeaderBytes = 10 * 1 << 20 // 10 MB
)
var defaultQuicConfig = &quic.Config{
MaxIncomingStreams: -1, // don't allow the server to create bidirectional streams
KeepAlivePeriod: 10 * time.Second,
}
type dialFunc func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error)
var dialAddr dialFunc = quic.DialAddrEarly
type roundTripperOpts struct {
DisableCompression bool
EnableDatagram bool
MaxHeaderBytes int64
AdditionalSettings map[uint64]uint64
StreamHijacker func(FrameType, quic.Connection, quic.Stream, error) (hijacked bool, err error)
UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream, error) (hijacked bool)
}
// client is a HTTP3 client doing requests
type client struct {
tlsConf *tls.Config
config *quic.Config
opts *roundTripperOpts
dialOnce sync.Once
dialer dialFunc
handshakeErr error
requestWriter *requestWriter
decoder *qpack.Decoder
hostname string
conn atomic.Pointer[quic.EarlyConnection]
logger utils.Logger
}
var _ roundTripCloser = &client{}
func newClient(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) {
if conf == nil {
conf = defaultQuicConfig.Clone()
}
if len(conf.Versions) == 0 {
conf = conf.Clone()
conf.Versions = []quic.VersionNumber{protocol.SupportedVersions[0]}
}
if len(conf.Versions) != 1 {
return nil, errors.New("can only use a single QUIC version for dialing a HTTP/3 connection")
}
if conf.MaxIncomingStreams == 0 {
conf.MaxIncomingStreams = -1 // don't allow any bidirectional streams
}
conf.EnableDatagrams = opts.EnableDatagram
logger := utils.DefaultLogger.WithPrefix("h3 client")
if tlsConf == nil {
tlsConf = &tls.Config{}
} else {
tlsConf = tlsConf.Clone()
}
if tlsConf.ServerName == "" {
sni, _, err := net.SplitHostPort(hostname)
if err != nil {
// It's ok if net.SplitHostPort returns an error - it could be a hostname/IP address without a port.
sni = hostname
}
tlsConf.ServerName = sni
}
// Replace existing ALPNs by H3
tlsConf.NextProtos = []string{versionToALPN(conf.Versions[0])}
return &client{
hostname: authorityAddr("https", hostname),
tlsConf: tlsConf,
requestWriter: newRequestWriter(logger),
decoder: qpack.NewDecoder(func(hf qpack.HeaderField) {}),
config: conf,
opts: opts,
dialer: dialer,
logger: logger,
}, nil
}
func (c *client) dial(ctx context.Context) error {
var err error
var conn quic.EarlyConnection
if c.dialer != nil {
conn, err = c.dialer(ctx, c.hostname, c.tlsConf, c.config)
} else {
conn, err = dialAddr(ctx, c.hostname, c.tlsConf, c.config)
}
if err != nil {
return err
}
c.conn.Store(&conn)
// send the SETTINGs frame, using 0-RTT data, if possible
go func() {
if err := c.setupConn(conn); err != nil {
c.logger.Debugf("Setting up connection failed: %s", err)
conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeInternalError), "")
}
}()
if c.opts.StreamHijacker != nil {
go c.handleBidirectionalStreams(conn)
}
go c.handleUnidirectionalStreams(conn)
return nil
}
func (c *client) setupConn(conn quic.EarlyConnection) error {
// open the control stream
str, err := conn.OpenUniStream()
if err != nil {
return err
}
b := make([]byte, 0, 64)
b = quicvarint.Append(b, streamTypeControlStream)
// send the SETTINGS frame
b = (&settingsFrame{Datagram: c.opts.EnableDatagram, Other: c.opts.AdditionalSettings}).Append(b)
_, err = str.Write(b)
return err
}
func (c *client) handleBidirectionalStreams(conn quic.EarlyConnection) {
for {
str, err := conn.AcceptStream(context.Background())
if err != nil {
c.logger.Debugf("accepting bidirectional stream failed: %s", err)
return
}
go func(str quic.Stream) {
_, err := parseNextFrame(str, func(ft FrameType, e error) (processed bool, err error) {
return c.opts.StreamHijacker(ft, conn, str, e)
})
if err == errHijacked {
return
}
if err != nil {
c.logger.Debugf("error handling stream: %s", err)
}
conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "received HTTP/3 frame on bidirectional stream")
}(str)
}
}
func (c *client) handleUnidirectionalStreams(conn quic.EarlyConnection) {
for {
str, err := conn.AcceptUniStream(context.Background())
if err != nil {
c.logger.Debugf("accepting unidirectional stream failed: %s", err)
return
}
go func(str quic.ReceiveStream) {
streamType, err := quicvarint.Read(quicvarint.NewReader(str))
if err != nil {
if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), conn, str, err) {
return
}
c.logger.Debugf("reading stream type on stream %d failed: %s", str.StreamID(), err)
return
}
// We're only interested in the control stream here.
switch streamType {
case streamTypeControlStream:
case streamTypeQPACKEncoderStream, streamTypeQPACKDecoderStream:
// Our QPACK implementation doesn't use the dynamic table yet.
// TODO: check that only one stream of each type is opened.
return
case streamTypePushStream:
// We never increased the Push ID, so we don't expect any push streams.
conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeIDError), "")
return
default:
if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), conn, str, nil) {
return
}
str.CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError))
return
}
f, err := parseNextFrame(str, nil)
if err != nil {
conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameError), "")
return
}
sf, ok := f.(*settingsFrame)
if !ok {
conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeMissingSettings), "")
return
}
if !sf.Datagram {
return
}
// If datagram support was enabled on our side as well as on the server side,
// we can expect it to have been negotiated both on the transport and on the HTTP/3 layer.
// Note: ConnectionState() will block until the handshake is complete (relevant when using 0-RTT).
if c.opts.EnableDatagram && !conn.ConnectionState().SupportsDatagrams {
conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeSettingsError), "missing QUIC Datagram support")
}
}(str)
}
}
func (c *client) Close() error {
conn := c.conn.Load()
if conn == nil {
return nil
}
return (*conn).CloseWithError(quic.ApplicationErrorCode(ErrCodeNoError), "")
}
func (c *client) maxHeaderBytes() uint64 {
if c.opts.MaxHeaderBytes <= 0 {
return defaultMaxResponseHeaderBytes
}
return uint64(c.opts.MaxHeaderBytes)
}
// RoundTripOpt executes a request and returns a response
func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
if authorityAddr("https", hostnameFromRequest(req)) != c.hostname {
return nil, fmt.Errorf("http3 client BUG: RoundTripOpt called for the wrong client (expected %s, got %s)", c.hostname, req.Host)
}
c.dialOnce.Do(func() {
c.handshakeErr = c.dial(req.Context())
})
if c.handshakeErr != nil {
return nil, c.handshakeErr
}
// At this point, c.conn is guaranteed to be set.
conn := *c.conn.Load()
// Immediately send out this request, if this is a 0-RTT request.
if req.Method == MethodGet0RTT {
req.Method = http.MethodGet
} else {
// wait for the handshake to complete
select {
case <-conn.HandshakeComplete():
case <-req.Context().Done():
return nil, req.Context().Err()
}
}
str, err := conn.OpenStreamSync(req.Context())
if err != nil {
return nil, err
}
// Request Cancellation:
// This go routine keeps running even after RoundTripOpt() returns.
// It is shut down when the application is done processing the body.
reqDone := make(chan struct{})
done := make(chan struct{})
go func() {
defer close(done)
select {
case <-req.Context().Done():
str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled))
str.CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled))
case <-reqDone:
}
}()
doneChan := reqDone
if opt.DontCloseRequestStream {
doneChan = nil
}
rsp, rerr := c.doRequest(req, conn, str, opt, doneChan)
if rerr.err != nil { // if any error occurred
close(reqDone)
<-done
if rerr.streamErr != 0 { // if it was a stream error
str.CancelWrite(quic.StreamErrorCode(rerr.streamErr))
}
if rerr.connErr != 0 { // if it was a connection error
var reason string
if rerr.err != nil {
reason = rerr.err.Error()
}
conn.CloseWithError(quic.ApplicationErrorCode(rerr.connErr), reason)
}
return nil, maybeReplaceError(rerr.err)
}
if opt.DontCloseRequestStream {
close(reqDone)
<-done
}
return rsp, maybeReplaceError(rerr.err)
}
// cancelingReader reads from the io.Reader.
// It cancels writing on the stream if any error other than io.EOF occurs.
type cancelingReader struct {
r io.Reader
str Stream
}
func (r *cancelingReader) Read(b []byte) (int, error) {
n, err := r.r.Read(b)
if err != nil && err != io.EOF {
r.str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled))
}
return n, err
}
func (c *client) sendRequestBody(str Stream, body io.ReadCloser, contentLength int64) error {
defer body.Close()
buf := make([]byte, bodyCopyBufferSize)
sr := &cancelingReader{str: str, r: body}
if contentLength == -1 {
_, err := io.CopyBuffer(str, sr, buf)
return err
}
// make sure we don't send more bytes than the content length
n, err := io.CopyBuffer(str, io.LimitReader(sr, contentLength), buf)
if err != nil {
return err
}
var extra int64
extra, err = io.CopyBuffer(io.Discard, sr, buf)
n += extra
if n > contentLength {
str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled))
return fmt.Errorf("http: ContentLength=%d with Body length %d", contentLength, n)
}
return err
}
func (c *client) doRequest(req *http.Request, conn quic.EarlyConnection, str quic.Stream, opt RoundTripOpt, reqDone chan<- struct{}) (*http.Response, requestError) {
var requestGzip bool
if !c.opts.DisableCompression && req.Method != "HEAD" && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" {
requestGzip = true
}
if err := c.requestWriter.WriteRequestHeader(str, req, requestGzip); err != nil {
return nil, newStreamError(ErrCodeInternalError, err)
}
if req.Body == nil && !opt.DontCloseRequestStream {
str.Close()
}
hstr := newStream(str, func() { conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "") })
if req.Body != nil {
// send the request body asynchronously
go func() {
contentLength := int64(-1)
// According to the documentation for http.Request.ContentLength,
// a value of 0 with a non-nil Body is also treated as unknown content length.
if req.ContentLength > 0 {
contentLength = req.ContentLength
}
if err := c.sendRequestBody(hstr, req.Body, contentLength); err != nil {
c.logger.Errorf("Error writing request: %s", err)
}
if !opt.DontCloseRequestStream {
hstr.Close()
}
}()
}
frame, err := parseNextFrame(str, nil)
if err != nil {
return nil, newStreamError(ErrCodeFrameError, err)
}
hf, ok := frame.(*headersFrame)
if !ok {
return nil, newConnError(ErrCodeFrameUnexpected, errors.New("expected first frame to be a HEADERS frame"))
}
if hf.Length > c.maxHeaderBytes() {
return nil, newStreamError(ErrCodeFrameError, fmt.Errorf("HEADERS frame too large: %d bytes (max: %d)", hf.Length, c.maxHeaderBytes()))
}
headerBlock := make([]byte, hf.Length)
if _, err := io.ReadFull(str, headerBlock); err != nil {
return nil, newStreamError(ErrCodeRequestIncomplete, err)
}
hfs, err := c.decoder.DecodeFull(headerBlock)
if err != nil {
// TODO: use the right error code
return nil, newConnError(ErrCodeGeneralProtocolError, err)
}
res, err := responseFromHeaders(hfs)
if err != nil {
return nil, newStreamError(ErrCodeMessageError, err)
}
connState := conn.ConnectionState().TLS
res.TLS = &connState
res.Request = req
// Check that the server doesn't send more data in DATA frames than indicated by the Content-Length header (if set).
// See section 4.1.2 of RFC 9114.
var httpStr Stream
if _, ok := res.Header["Content-Length"]; ok && res.ContentLength >= 0 {
httpStr = newLengthLimitedStream(hstr, res.ContentLength)
} else {
httpStr = hstr
}
respBody := newResponseBody(httpStr, conn, reqDone)
// Rules for when to set Content-Length are defined in https://tools.ietf.org/html/rfc7230#section-3.3.2.
_, hasTransferEncoding := res.Header["Transfer-Encoding"]
isInformational := res.StatusCode >= 100 && res.StatusCode < 200
isNoContent := res.StatusCode == http.StatusNoContent
isSuccessfulConnect := req.Method == http.MethodConnect && res.StatusCode >= 200 && res.StatusCode < 300
if !hasTransferEncoding && !isInformational && !isNoContent && !isSuccessfulConnect {
res.ContentLength = -1
if clens, ok := res.Header["Content-Length"]; ok && len(clens) == 1 {
if clen64, err := strconv.ParseInt(clens[0], 10, 64); err == nil {
res.ContentLength = clen64
}
}
}
if requestGzip && res.Header.Get("Content-Encoding") == "gzip" {
res.Header.Del("Content-Encoding")
res.Header.Del("Content-Length")
res.ContentLength = -1
res.Body = newGzipReader(respBody)
res.Uncompressed = true
} else {
res.Body = respBody
}
return res, requestError{}
}
func (c *client) HandshakeComplete() bool {
conn := c.conn.Load()
if conn == nil {
return false
}
select {
case <-(*conn).HandshakeComplete():
return true
default:
return false
}
}

58
vendor/github.com/quic-go/quic-go/http3/error.go generated vendored Normal file
View File

@@ -0,0 +1,58 @@
package http3
import (
"errors"
"fmt"
"github.com/quic-go/quic-go"
)
// Error is returned from the round tripper (for HTTP clients)
// and inside the HTTP handler (for HTTP servers) if an HTTP/3 error occurs.
// See section 8 of RFC 9114.
type Error struct {
Remote bool
ErrorCode ErrCode
ErrorMessage string
}
var _ error = &Error{}
func (e *Error) Error() string {
s := e.ErrorCode.string()
if s == "" {
s = fmt.Sprintf("H3 error (%#x)", uint64(e.ErrorCode))
}
// Usually errors are remote. Only make it explicit for local errors.
if !e.Remote {
s += " (local)"
}
if e.ErrorMessage != "" {
s += ": " + e.ErrorMessage
}
return s
}
func maybeReplaceError(err error) error {
if err == nil {
return nil
}
var (
e Error
strErr *quic.StreamError
appErr *quic.ApplicationError
)
switch {
default:
return err
case errors.As(err, &strErr):
e.Remote = strErr.Remote
e.ErrorCode = ErrCode(strErr.ErrorCode)
case errors.As(err, &appErr):
e.Remote = appErr.Remote
e.ErrorCode = ErrCode(appErr.ErrorCode)
e.ErrorMessage = appErr.ErrorMessage
}
return &e
}

81
vendor/github.com/quic-go/quic-go/http3/error_codes.go generated vendored Normal file
View File

@@ -0,0 +1,81 @@
package http3
import (
"fmt"
"github.com/quic-go/quic-go"
)
type ErrCode quic.ApplicationErrorCode
const (
ErrCodeNoError ErrCode = 0x100
ErrCodeGeneralProtocolError ErrCode = 0x101
ErrCodeInternalError ErrCode = 0x102
ErrCodeStreamCreationError ErrCode = 0x103
ErrCodeClosedCriticalStream ErrCode = 0x104
ErrCodeFrameUnexpected ErrCode = 0x105
ErrCodeFrameError ErrCode = 0x106
ErrCodeExcessiveLoad ErrCode = 0x107
ErrCodeIDError ErrCode = 0x108
ErrCodeSettingsError ErrCode = 0x109
ErrCodeMissingSettings ErrCode = 0x10a
ErrCodeRequestRejected ErrCode = 0x10b
ErrCodeRequestCanceled ErrCode = 0x10c
ErrCodeRequestIncomplete ErrCode = 0x10d
ErrCodeMessageError ErrCode = 0x10e
ErrCodeConnectError ErrCode = 0x10f
ErrCodeVersionFallback ErrCode = 0x110
ErrCodeDatagramError ErrCode = 0x33
)
func (e ErrCode) String() string {
s := e.string()
if s != "" {
return s
}
return fmt.Sprintf("unknown error code: %#x", uint16(e))
}
func (e ErrCode) string() string {
switch e {
case ErrCodeNoError:
return "H3_NO_ERROR"
case ErrCodeGeneralProtocolError:
return "H3_GENERAL_PROTOCOL_ERROR"
case ErrCodeInternalError:
return "H3_INTERNAL_ERROR"
case ErrCodeStreamCreationError:
return "H3_STREAM_CREATION_ERROR"
case ErrCodeClosedCriticalStream:
return "H3_CLOSED_CRITICAL_STREAM"
case ErrCodeFrameUnexpected:
return "H3_FRAME_UNEXPECTED"
case ErrCodeFrameError:
return "H3_FRAME_ERROR"
case ErrCodeExcessiveLoad:
return "H3_EXCESSIVE_LOAD"
case ErrCodeIDError:
return "H3_ID_ERROR"
case ErrCodeSettingsError:
return "H3_SETTINGS_ERROR"
case ErrCodeMissingSettings:
return "H3_MISSING_SETTINGS"
case ErrCodeRequestRejected:
return "H3_REQUEST_REJECTED"
case ErrCodeRequestCanceled:
return "H3_REQUEST_CANCELLED"
case ErrCodeRequestIncomplete:
return "H3_INCOMPLETE_REQUEST"
case ErrCodeMessageError:
return "H3_MESSAGE_ERROR"
case ErrCodeConnectError:
return "H3_CONNECT_ERROR"
case ErrCodeVersionFallback:
return "H3_VERSION_FALLBACK"
case ErrCodeDatagramError:
return "H3_DATAGRAM_ERROR"
default:
return ""
}
}

164
vendor/github.com/quic-go/quic-go/http3/frames.go generated vendored Normal file
View File

@@ -0,0 +1,164 @@
package http3
import (
"bytes"
"errors"
"fmt"
"io"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/quicvarint"
)
// FrameType is the frame type of a HTTP/3 frame
type FrameType uint64
type unknownFrameHandlerFunc func(FrameType, error) (processed bool, err error)
type frame interface{}
var errHijacked = errors.New("hijacked")
func parseNextFrame(r io.Reader, unknownFrameHandler unknownFrameHandlerFunc) (frame, error) {
qr := quicvarint.NewReader(r)
for {
t, err := quicvarint.Read(qr)
if err != nil {
if unknownFrameHandler != nil {
hijacked, err := unknownFrameHandler(0, err)
if err != nil {
return nil, err
}
if hijacked {
return nil, errHijacked
}
}
return nil, err
}
// Call the unknownFrameHandler for frames not defined in the HTTP/3 spec
if t > 0xd && unknownFrameHandler != nil {
hijacked, err := unknownFrameHandler(FrameType(t), nil)
if err != nil {
return nil, err
}
if hijacked {
return nil, errHijacked
}
// If the unknownFrameHandler didn't process the frame, it is our responsibility to skip it.
}
l, err := quicvarint.Read(qr)
if err != nil {
return nil, err
}
switch t {
case 0x0:
return &dataFrame{Length: l}, nil
case 0x1:
return &headersFrame{Length: l}, nil
case 0x4:
return parseSettingsFrame(r, l)
case 0x3: // CANCEL_PUSH
case 0x5: // PUSH_PROMISE
case 0x7: // GOAWAY
case 0xd: // MAX_PUSH_ID
}
// skip over unknown frames
if _, err := io.CopyN(io.Discard, qr, int64(l)); err != nil {
return nil, err
}
}
}
type dataFrame struct {
Length uint64
}
func (f *dataFrame) Append(b []byte) []byte {
b = quicvarint.Append(b, 0x0)
return quicvarint.Append(b, f.Length)
}
type headersFrame struct {
Length uint64
}
func (f *headersFrame) Append(b []byte) []byte {
b = quicvarint.Append(b, 0x1)
return quicvarint.Append(b, f.Length)
}
const settingDatagram = 0x33
type settingsFrame struct {
Datagram bool
Other map[uint64]uint64 // all settings that we don't explicitly recognize
}
func parseSettingsFrame(r io.Reader, l uint64) (*settingsFrame, error) {
if l > 8*(1<<10) {
return nil, fmt.Errorf("unexpected size for SETTINGS frame: %d", l)
}
buf := make([]byte, l)
if _, err := io.ReadFull(r, buf); err != nil {
if err == io.ErrUnexpectedEOF {
return nil, io.EOF
}
return nil, err
}
frame := &settingsFrame{}
b := bytes.NewReader(buf)
var readDatagram bool
for b.Len() > 0 {
id, err := quicvarint.Read(b)
if err != nil { // should not happen. We allocated the whole frame already.
return nil, err
}
val, err := quicvarint.Read(b)
if err != nil { // should not happen. We allocated the whole frame already.
return nil, err
}
switch id {
case settingDatagram:
if readDatagram {
return nil, fmt.Errorf("duplicate setting: %d", id)
}
readDatagram = true
if val != 0 && val != 1 {
return nil, fmt.Errorf("invalid value for H3_DATAGRAM: %d", val)
}
frame.Datagram = val == 1
default:
if _, ok := frame.Other[id]; ok {
return nil, fmt.Errorf("duplicate setting: %d", id)
}
if frame.Other == nil {
frame.Other = make(map[uint64]uint64)
}
frame.Other[id] = val
}
}
return frame, nil
}
func (f *settingsFrame) Append(b []byte) []byte {
b = quicvarint.Append(b, 0x4)
var l protocol.ByteCount
for id, val := range f.Other {
l += quicvarint.Len(id) + quicvarint.Len(val)
}
if f.Datagram {
l += quicvarint.Len(settingDatagram) + quicvarint.Len(1)
}
b = quicvarint.Append(b, uint64(l))
if f.Datagram {
b = quicvarint.Append(b, settingDatagram)
b = quicvarint.Append(b, 1)
}
for id, val := range f.Other {
b = quicvarint.Append(b, id)
b = quicvarint.Append(b, val)
}
return b
}

39
vendor/github.com/quic-go/quic-go/http3/gzip_reader.go generated vendored Normal file
View File

@@ -0,0 +1,39 @@
package http3
// copied from net/transport.go
// gzipReader wraps a response body so it can lazily
// call gzip.NewReader on the first call to Read
import (
"compress/gzip"
"io"
)
// call gzip.NewReader on the first call to Read
type gzipReader struct {
body io.ReadCloser // underlying Response.Body
zr *gzip.Reader // lazily-initialized gzip reader
zerr error // sticky error
}
func newGzipReader(body io.ReadCloser) io.ReadCloser {
return &gzipReader{body: body}
}
func (gz *gzipReader) Read(p []byte) (n int, err error) {
if gz.zerr != nil {
return 0, gz.zerr
}
if gz.zr == nil {
gz.zr, err = gzip.NewReader(gz.body)
if err != nil {
gz.zerr = err
return 0, err
}
}
return gz.zr.Read(p)
}
func (gz *gzipReader) Close() error {
return gz.body.Close()
}

198
vendor/github.com/quic-go/quic-go/http3/headers.go generated vendored Normal file
View File

@@ -0,0 +1,198 @@
package http3
import (
"errors"
"fmt"
"net/http"
"net/url"
"strconv"
"strings"
"golang.org/x/net/http/httpguts"
"github.com/quic-go/qpack"
)
type header struct {
// Pseudo header fields defined in RFC 9114
Path string
Method string
Authority string
Scheme string
Status string
// for Extended connect
Protocol string
// parsed and deduplicated
ContentLength int64
// all non-pseudo headers
Headers http.Header
}
func parseHeaders(headers []qpack.HeaderField, isRequest bool) (header, error) {
hdr := header{Headers: make(http.Header, len(headers))}
var readFirstRegularHeader, readContentLength bool
var contentLengthStr string
for _, h := range headers {
// field names need to be lowercase, see section 4.2 of RFC 9114
if strings.ToLower(h.Name) != h.Name {
return header{}, fmt.Errorf("header field is not lower-case: %s", h.Name)
}
if !httpguts.ValidHeaderFieldValue(h.Value) {
return header{}, fmt.Errorf("invalid header field value for %s: %q", h.Name, h.Value)
}
if h.IsPseudo() {
if readFirstRegularHeader {
// all pseudo headers must appear before regular header fields, see section 4.3 of RFC 9114
return header{}, fmt.Errorf("received pseudo header %s after a regular header field", h.Name)
}
var isResponsePseudoHeader bool // pseudo headers are either valid for requests or for responses
switch h.Name {
case ":path":
hdr.Path = h.Value
case ":method":
hdr.Method = h.Value
case ":authority":
hdr.Authority = h.Value
case ":protocol":
hdr.Protocol = h.Value
case ":scheme":
hdr.Scheme = h.Value
case ":status":
hdr.Status = h.Value
isResponsePseudoHeader = true
default:
return header{}, fmt.Errorf("unknown pseudo header: %s", h.Name)
}
if isRequest && isResponsePseudoHeader {
return header{}, fmt.Errorf("invalid request pseudo header: %s", h.Name)
}
if !isRequest && !isResponsePseudoHeader {
return header{}, fmt.Errorf("invalid response pseudo header: %s", h.Name)
}
} else {
if !httpguts.ValidHeaderFieldName(h.Name) {
return header{}, fmt.Errorf("invalid header field name: %q", h.Name)
}
readFirstRegularHeader = true
switch h.Name {
case "content-length":
// Ignore duplicate Content-Length headers.
// Fail if the duplicates differ.
if !readContentLength {
readContentLength = true
contentLengthStr = h.Value
} else if contentLengthStr != h.Value {
return header{}, fmt.Errorf("contradicting content lengths (%s and %s)", contentLengthStr, h.Value)
}
default:
hdr.Headers.Add(h.Name, h.Value)
}
}
}
if len(contentLengthStr) > 0 {
// use ParseUint instead of ParseInt, so that parsing fails on negative values
cl, err := strconv.ParseUint(contentLengthStr, 10, 63)
if err != nil {
return header{}, fmt.Errorf("invalid content length: %w", err)
}
hdr.Headers.Set("Content-Length", contentLengthStr)
hdr.ContentLength = int64(cl)
}
return hdr, nil
}
func requestFromHeaders(headerFields []qpack.HeaderField) (*http.Request, error) {
hdr, err := parseHeaders(headerFields, true)
if err != nil {
return nil, err
}
// concatenate cookie headers, see https://tools.ietf.org/html/rfc6265#section-5.4
if len(hdr.Headers["Cookie"]) > 0 {
hdr.Headers.Set("Cookie", strings.Join(hdr.Headers["Cookie"], "; "))
}
isConnect := hdr.Method == http.MethodConnect
// Extended CONNECT, see https://datatracker.ietf.org/doc/html/rfc8441#section-4
isExtendedConnected := isConnect && hdr.Protocol != ""
if isExtendedConnected {
if hdr.Scheme == "" || hdr.Path == "" || hdr.Authority == "" {
return nil, errors.New("extended CONNECT: :scheme, :path and :authority must not be empty")
}
} else if isConnect {
if hdr.Path != "" || hdr.Authority == "" { // normal CONNECT
return nil, errors.New(":path must be empty and :authority must not be empty")
}
} else if len(hdr.Path) == 0 || len(hdr.Authority) == 0 || len(hdr.Method) == 0 {
return nil, errors.New(":path, :authority and :method must not be empty")
}
var u *url.URL
var requestURI string
var protocol string
if isConnect {
u = &url.URL{}
if isExtendedConnected {
u, err = url.ParseRequestURI(hdr.Path)
if err != nil {
return nil, err
}
} else {
u.Path = hdr.Path
}
u.Scheme = hdr.Scheme
u.Host = hdr.Authority
requestURI = hdr.Authority
protocol = hdr.Protocol
} else {
protocol = "HTTP/3.0"
u, err = url.ParseRequestURI(hdr.Path)
if err != nil {
return nil, fmt.Errorf("invalid content length: %w", err)
}
requestURI = hdr.Path
}
return &http.Request{
Method: hdr.Method,
URL: u,
Proto: protocol,
ProtoMajor: 3,
ProtoMinor: 0,
Header: hdr.Headers,
Body: nil,
ContentLength: hdr.ContentLength,
Host: hdr.Authority,
RequestURI: requestURI,
}, nil
}
func hostnameFromRequest(req *http.Request) string {
if req.URL != nil {
return req.URL.Host
}
return ""
}
func responseFromHeaders(headerFields []qpack.HeaderField) (*http.Response, error) {
hdr, err := parseHeaders(headerFields, false)
if err != nil {
return nil, err
}
if hdr.Status == "" {
return nil, errors.New("missing status field")
}
rsp := &http.Response{
Proto: "HTTP/3.0",
ProtoMajor: 3,
Header: hdr.Headers,
ContentLength: hdr.ContentLength,
}
status, err := strconv.Atoi(hdr.Status)
if err != nil {
return nil, fmt.Errorf("invalid status code: %w", err)
}
rsp.StatusCode = status
rsp.Status = hdr.Status + " " + http.StatusText(status)
return rsp, nil
}

124
vendor/github.com/quic-go/quic-go/http3/http_stream.go generated vendored Normal file
View File

@@ -0,0 +1,124 @@
package http3
import (
"errors"
"fmt"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/internal/utils"
)
// A Stream is a HTTP/3 stream.
// When writing to and reading from the stream, data is framed in HTTP/3 DATA frames.
type Stream quic.Stream
// The stream conforms to the quic.Stream interface, but instead of writing to and reading directly
// from the QUIC stream, it writes to and reads from the HTTP stream.
type stream struct {
quic.Stream
buf []byte
onFrameError func()
bytesRemainingInFrame uint64
}
var _ Stream = &stream{}
func newStream(str quic.Stream, onFrameError func()) *stream {
return &stream{
Stream: str,
onFrameError: onFrameError,
buf: make([]byte, 0, 16),
}
}
func (s *stream) Read(b []byte) (int, error) {
if s.bytesRemainingInFrame == 0 {
parseLoop:
for {
frame, err := parseNextFrame(s.Stream, nil)
if err != nil {
return 0, err
}
switch f := frame.(type) {
case *headersFrame:
// skip HEADERS frames
continue
case *dataFrame:
s.bytesRemainingInFrame = f.Length
break parseLoop
default:
s.onFrameError()
// parseNextFrame skips over unknown frame types
// Therefore, this condition is only entered when we parsed another known frame type.
return 0, fmt.Errorf("peer sent an unexpected frame: %T", f)
}
}
}
var n int
var err error
if s.bytesRemainingInFrame < uint64(len(b)) {
n, err = s.Stream.Read(b[:s.bytesRemainingInFrame])
} else {
n, err = s.Stream.Read(b)
}
s.bytesRemainingInFrame -= uint64(n)
return n, err
}
func (s *stream) hasMoreData() bool {
return s.bytesRemainingInFrame > 0
}
func (s *stream) Write(b []byte) (int, error) {
s.buf = s.buf[:0]
s.buf = (&dataFrame{Length: uint64(len(b))}).Append(s.buf)
if _, err := s.Stream.Write(s.buf); err != nil {
return 0, err
}
return s.Stream.Write(b)
}
var errTooMuchData = errors.New("peer sent too much data")
type lengthLimitedStream struct {
*stream
contentLength int64
read int64
resetStream bool
}
var _ Stream = &lengthLimitedStream{}
func newLengthLimitedStream(str *stream, contentLength int64) *lengthLimitedStream {
return &lengthLimitedStream{
stream: str,
contentLength: contentLength,
}
}
func (s *lengthLimitedStream) checkContentLengthViolation() error {
if s.read > s.contentLength || s.read == s.contentLength && s.hasMoreData() {
if !s.resetStream {
s.CancelRead(quic.StreamErrorCode(ErrCodeMessageError))
s.CancelWrite(quic.StreamErrorCode(ErrCodeMessageError))
s.resetStream = true
}
return errTooMuchData
}
return nil
}
func (s *lengthLimitedStream) Read(b []byte) (int, error) {
if err := s.checkContentLengthViolation(); err != nil {
return 0, err
}
n, err := s.stream.Read(b[:utils.Min(int64(len(b)), s.contentLength-s.read)])
s.read += int64(n)
if err := s.checkContentLengthViolation(); err != nil {
return n, err
}
return n, err
}

8
vendor/github.com/quic-go/quic-go/http3/mockgen.go generated vendored Normal file
View File

@@ -0,0 +1,8 @@
//go:build gomock || generate
package http3
//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package http3 -destination mock_roundtripcloser_test.go github.com/quic-go/quic-go/http3 RoundTripCloser"
type RoundTripCloser = roundTripCloser
//go:generate sh -c "go run go.uber.org/mock/mockgen -package http3 -destination mock_quic_early_listener_test.go github.com/quic-go/quic-go/http3 QUICEarlyListener"

View File

@@ -0,0 +1,287 @@
package http3
import (
"bytes"
"errors"
"fmt"
"io"
"net"
"net/http"
"strconv"
"strings"
"sync"
"golang.org/x/net/http/httpguts"
"golang.org/x/net/http2/hpack"
"golang.org/x/net/idna"
"github.com/quic-go/qpack"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/internal/utils"
)
const bodyCopyBufferSize = 8 * 1024
type requestWriter struct {
mutex sync.Mutex
encoder *qpack.Encoder
headerBuf *bytes.Buffer
logger utils.Logger
}
func newRequestWriter(logger utils.Logger) *requestWriter {
headerBuf := &bytes.Buffer{}
encoder := qpack.NewEncoder(headerBuf)
return &requestWriter{
encoder: encoder,
headerBuf: headerBuf,
logger: logger,
}
}
func (w *requestWriter) WriteRequestHeader(str quic.Stream, req *http.Request, gzip bool) error {
// TODO: figure out how to add support for trailers
buf := &bytes.Buffer{}
if err := w.writeHeaders(buf, req, gzip); err != nil {
return err
}
_, err := str.Write(buf.Bytes())
return err
}
func (w *requestWriter) writeHeaders(wr io.Writer, req *http.Request, gzip bool) error {
w.mutex.Lock()
defer w.mutex.Unlock()
defer w.encoder.Close()
defer w.headerBuf.Reset()
if err := w.encodeHeaders(req, gzip, "", actualContentLength(req)); err != nil {
return err
}
b := make([]byte, 0, 128)
b = (&headersFrame{Length: uint64(w.headerBuf.Len())}).Append(b)
if _, err := wr.Write(b); err != nil {
return err
}
_, err := wr.Write(w.headerBuf.Bytes())
return err
}
// copied from net/transport.go
// Modified to support Extended CONNECT:
// Contrary to what the godoc for the http.Request says,
// we do respect the Proto field if the method is CONNECT.
func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, trailers string, contentLength int64) error {
host := req.Host
if host == "" {
host = req.URL.Host
}
host, err := httpguts.PunycodeHostPort(host)
if err != nil {
return err
}
if !httpguts.ValidHostHeader(host) {
return errors.New("http3: invalid Host header")
}
// http.NewRequest sets this field to HTTP/1.1
isExtendedConnect := req.Method == http.MethodConnect && req.Proto != "" && req.Proto != "HTTP/1.1"
var path string
if req.Method != http.MethodConnect || isExtendedConnect {
path = req.URL.RequestURI()
if !validPseudoPath(path) {
orig := path
path = strings.TrimPrefix(path, req.URL.Scheme+"://"+host)
if !validPseudoPath(path) {
if req.URL.Opaque != "" {
return fmt.Errorf("invalid request :path %q from URL.Opaque = %q", orig, req.URL.Opaque)
} else {
return fmt.Errorf("invalid request :path %q", orig)
}
}
}
}
// Check for any invalid headers and return an error before we
// potentially pollute our hpack state. (We want to be able to
// continue to reuse the hpack encoder for future requests)
for k, vv := range req.Header {
if !httpguts.ValidHeaderFieldName(k) {
return fmt.Errorf("invalid HTTP header name %q", k)
}
for _, v := range vv {
if !httpguts.ValidHeaderFieldValue(v) {
return fmt.Errorf("invalid HTTP header value %q for header %q", v, k)
}
}
}
enumerateHeaders := func(f func(name, value string)) {
// 8.1.2.3 Request Pseudo-Header Fields
// The :path pseudo-header field includes the path and query parts of the
// target URI (the path-absolute production and optionally a '?' character
// followed by the query production (see Sections 3.3 and 3.4 of
// [RFC3986]).
f(":authority", host)
f(":method", req.Method)
if req.Method != http.MethodConnect || isExtendedConnect {
f(":path", path)
f(":scheme", req.URL.Scheme)
}
if isExtendedConnect {
f(":protocol", req.Proto)
}
if trailers != "" {
f("trailer", trailers)
}
var didUA bool
for k, vv := range req.Header {
if strings.EqualFold(k, "host") || strings.EqualFold(k, "content-length") {
// Host is :authority, already sent.
// Content-Length is automatic, set below.
continue
} else if strings.EqualFold(k, "connection") || strings.EqualFold(k, "proxy-connection") ||
strings.EqualFold(k, "transfer-encoding") || strings.EqualFold(k, "upgrade") ||
strings.EqualFold(k, "keep-alive") {
// Per 8.1.2.2 Connection-Specific Header
// Fields, don't send connection-specific
// fields. We have already checked if any
// are error-worthy so just ignore the rest.
continue
} else if strings.EqualFold(k, "user-agent") {
// Match Go's http1 behavior: at most one
// User-Agent. If set to nil or empty string,
// then omit it. Otherwise if not mentioned,
// include the default (below).
didUA = true
if len(vv) < 1 {
continue
}
vv = vv[:1]
if vv[0] == "" {
continue
}
}
for _, v := range vv {
f(k, v)
}
}
if shouldSendReqContentLength(req.Method, contentLength) {
f("content-length", strconv.FormatInt(contentLength, 10))
}
if addGzipHeader {
f("accept-encoding", "gzip")
}
if !didUA {
f("user-agent", defaultUserAgent)
}
}
// Do a first pass over the headers counting bytes to ensure
// we don't exceed cc.peerMaxHeaderListSize. This is done as a
// separate pass before encoding the headers to prevent
// modifying the hpack state.
hlSize := uint64(0)
enumerateHeaders(func(name, value string) {
hf := hpack.HeaderField{Name: name, Value: value}
hlSize += uint64(hf.Size())
})
// TODO: check maximum header list size
// if hlSize > cc.peerMaxHeaderListSize {
// return errRequestHeaderListSize
// }
// trace := httptrace.ContextClientTrace(req.Context())
// traceHeaders := traceHasWroteHeaderField(trace)
// Header list size is ok. Write the headers.
enumerateHeaders(func(name, value string) {
name = strings.ToLower(name)
w.encoder.WriteField(qpack.HeaderField{Name: name, Value: value})
// if traceHeaders {
// traceWroteHeaderField(trace, name, value)
// }
})
return nil
}
// authorityAddr returns a given authority (a host/IP, or host:port / ip:port)
// and returns a host:port. The port 443 is added if needed.
func authorityAddr(scheme string, authority string) (addr string) {
host, port, err := net.SplitHostPort(authority)
if err != nil { // authority didn't have a port
port = "443"
if scheme == "http" {
port = "80"
}
host = authority
}
if a, err := idna.ToASCII(host); err == nil {
host = a
}
// IPv6 address literal, without a port:
if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
return host + ":" + port
}
return net.JoinHostPort(host, port)
}
// validPseudoPath reports whether v is a valid :path pseudo-header
// value. It must be either:
//
// *) a non-empty string starting with '/'
// *) the string '*', for OPTIONS requests.
//
// For now this is only used a quick check for deciding when to clean
// up Opaque URLs before sending requests from the Transport.
// See golang.org/issue/16847
//
// We used to enforce that the path also didn't start with "//", but
// Google's GFE accepts such paths and Chrome sends them, so ignore
// that part of the spec. See golang.org/issue/19103.
func validPseudoPath(v string) bool {
return (len(v) > 0 && v[0] == '/') || v == "*"
}
// actualContentLength returns a sanitized version of
// req.ContentLength, where 0 actually means zero (not unknown) and -1
// means unknown.
func actualContentLength(req *http.Request) int64 {
if req.Body == nil {
return 0
}
if req.ContentLength != 0 {
return req.ContentLength
}
return -1
}
// shouldSendReqContentLength reports whether the http2.Transport should send
// a "content-length" request header. This logic is basically a copy of the net/http
// transferWriter.shouldSendContentLength.
// The contentLength is the corrected contentLength (so 0 means actually 0, not unknown).
// -1 means unknown.
func shouldSendReqContentLength(method string, contentLength int64) bool {
if contentLength > 0 {
return true
}
if contentLength < 0 {
return false
}
// For zero bodies, whether we send a content-length depends on the method.
// It also kinda doesn't matter for http2 either way, with END_STREAM.
switch method {
case "POST", "PUT", "PATCH":
return true
default:
return false
}
}

View File

@@ -0,0 +1,219 @@
package http3
import (
"bufio"
"bytes"
"fmt"
"net/http"
"strconv"
"strings"
"time"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/qpack"
)
// The maximum length of an encoded HTTP/3 frame header is 16:
// The frame has a type and length field, both QUIC varints (maximum 8 bytes in length)
const frameHeaderLen = 16
// headerWriter wraps the stream, so that the first Write call flushes the header to the stream
type headerWriter struct {
str quic.Stream
header http.Header
status int // status code passed to WriteHeader
written bool
logger utils.Logger
}
// writeHeader encodes and flush header to the stream
func (hw *headerWriter) writeHeader() error {
var headers bytes.Buffer
enc := qpack.NewEncoder(&headers)
enc.WriteField(qpack.HeaderField{Name: ":status", Value: strconv.Itoa(hw.status)})
for k, v := range hw.header {
for index := range v {
enc.WriteField(qpack.HeaderField{Name: strings.ToLower(k), Value: v[index]})
}
}
buf := make([]byte, 0, frameHeaderLen+headers.Len())
buf = (&headersFrame{Length: uint64(headers.Len())}).Append(buf)
hw.logger.Infof("Responding with %d", hw.status)
buf = append(buf, headers.Bytes()...)
_, err := hw.str.Write(buf)
return err
}
// first Write will trigger flushing header
func (hw *headerWriter) Write(p []byte) (int, error) {
if !hw.written {
if err := hw.writeHeader(); err != nil {
return 0, err
}
hw.written = true
}
return hw.str.Write(p)
}
type responseWriter struct {
*headerWriter
conn quic.Connection
bufferedStr *bufio.Writer
buf []byte
headerWritten bool
contentLen int64 // if handler set valid Content-Length header
numWritten int64 // bytes written
}
var (
_ http.ResponseWriter = &responseWriter{}
_ http.Flusher = &responseWriter{}
_ Hijacker = &responseWriter{}
)
func newResponseWriter(str quic.Stream, conn quic.Connection, logger utils.Logger) *responseWriter {
hw := &headerWriter{
str: str,
header: http.Header{},
logger: logger,
}
return &responseWriter{
headerWriter: hw,
buf: make([]byte, frameHeaderLen),
conn: conn,
bufferedStr: bufio.NewWriter(hw),
}
}
func (w *responseWriter) Header() http.Header {
return w.header
}
func (w *responseWriter) WriteHeader(status int) {
if w.headerWritten {
return
}
// http status must be 3 digits
if status < 100 || status > 999 {
panic(fmt.Sprintf("invalid WriteHeader code %v", status))
}
if status >= 200 {
w.headerWritten = true
// Add Date header.
// This is what the standard library does.
// Can be disabled by setting the Date header to nil.
if _, ok := w.header["Date"]; !ok {
w.header.Set("Date", time.Now().UTC().Format(http.TimeFormat))
}
// Content-Length checking
// use ParseUint instead of ParseInt, as negative values are invalid
if clen := w.header.Get("Content-Length"); clen != "" {
if cl, err := strconv.ParseUint(clen, 10, 63); err == nil {
w.contentLen = int64(cl)
} else {
// emit a warning for malformed Content-Length and remove it
w.logger.Errorf("Malformed Content-Length %s", clen)
w.header.Del("Content-Length")
}
}
}
w.status = status
if !w.headerWritten {
w.writeHeader()
}
}
func (w *responseWriter) Write(p []byte) (int, error) {
bodyAllowed := bodyAllowedForStatus(w.status)
if !w.headerWritten {
// If body is not allowed, we don't need to (and we can't) sniff the content type.
if bodyAllowed {
// If no content type, apply sniffing algorithm to body.
// We can't use `w.header.Get` here since if the Content-Type was set to nil, we shoundn't do sniffing.
_, haveType := w.header["Content-Type"]
// If the Transfer-Encoding or Content-Encoding was set and is non-blank,
// we shouldn't sniff the body.
hasTE := w.header.Get("Transfer-Encoding") != ""
hasCE := w.header.Get("Content-Encoding") != ""
if !hasCE && !haveType && !hasTE && len(p) > 0 {
w.header.Set("Content-Type", http.DetectContentType(p))
}
}
w.WriteHeader(http.StatusOK)
bodyAllowed = true
}
if !bodyAllowed {
return 0, http.ErrBodyNotAllowed
}
w.numWritten += int64(len(p))
if w.contentLen != 0 && w.numWritten > w.contentLen {
return 0, http.ErrContentLength
}
df := &dataFrame{Length: uint64(len(p))}
w.buf = w.buf[:0]
w.buf = df.Append(w.buf)
if _, err := w.bufferedStr.Write(w.buf); err != nil {
return 0, maybeReplaceError(err)
}
n, err := w.bufferedStr.Write(p)
return n, maybeReplaceError(err)
}
func (w *responseWriter) FlushError() error {
if !w.headerWritten {
w.WriteHeader(http.StatusOK)
}
if !w.written {
if err := w.writeHeader(); err != nil {
return maybeReplaceError(err)
}
w.written = true
}
return w.bufferedStr.Flush()
}
func (w *responseWriter) Flush() {
if err := w.FlushError(); err != nil {
w.logger.Errorf("could not flush to stream: %s", err.Error())
}
}
func (w *responseWriter) StreamCreator() StreamCreator {
return w.conn
}
func (w *responseWriter) SetReadDeadline(deadline time.Time) error {
return w.str.SetReadDeadline(deadline)
}
func (w *responseWriter) SetWriteDeadline(deadline time.Time) error {
return w.str.SetWriteDeadline(deadline)
}
// copied from http2/http2.go
// bodyAllowedForStatus reports whether a given response status code
// permits a body. See RFC 2616, section 4.4.
func bodyAllowedForStatus(status int) bool {
switch {
case status >= 100 && status <= 199:
return false
case status == http.StatusNoContent:
return false
case status == http.StatusNotModified:
return false
}
return true
}

301
vendor/github.com/quic-go/quic-go/http3/roundtrip.go generated vendored Normal file
View File

@@ -0,0 +1,301 @@
package http3
import (
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/http"
"strings"
"sync"
"sync/atomic"
"golang.org/x/net/http/httpguts"
"github.com/quic-go/quic-go"
)
type roundTripCloser interface {
RoundTripOpt(*http.Request, RoundTripOpt) (*http.Response, error)
HandshakeComplete() bool
io.Closer
}
type roundTripCloserWithCount struct {
roundTripCloser
useCount atomic.Int64
}
// RoundTripper implements the http.RoundTripper interface
type RoundTripper struct {
mutex sync.Mutex
// DisableCompression, if true, prevents the Transport from
// requesting compression with an "Accept-Encoding: gzip"
// request header when the Request contains no existing
// Accept-Encoding value. If the Transport requests gzip on
// its own and gets a gzipped response, it's transparently
// decoded in the Response.Body. However, if the user
// explicitly requested gzip it is not automatically
// uncompressed.
DisableCompression bool
// TLSClientConfig specifies the TLS configuration to use with
// tls.Client. If nil, the default configuration is used.
TLSClientConfig *tls.Config
// QuicConfig is the quic.Config used for dialing new connections.
// If nil, reasonable default values will be used.
QuicConfig *quic.Config
// Enable support for HTTP/3 datagrams.
// If set to true, QuicConfig.EnableDatagram will be set.
// See https://datatracker.ietf.org/doc/html/rfc9297.
EnableDatagrams bool
// Additional HTTP/3 settings.
// It is invalid to specify any settings defined by the HTTP/3 draft and the datagram draft.
AdditionalSettings map[uint64]uint64
// When set, this callback is called for the first unknown frame parsed on a bidirectional stream.
// It is called right after parsing the frame type.
// If parsing the frame type fails, the error is passed to the callback.
// In that case, the frame type will not be set.
// Callers can either ignore the frame and return control of the stream back to HTTP/3
// (by returning hijacked false).
// Alternatively, callers can take over the QUIC stream (by returning hijacked true).
StreamHijacker func(FrameType, quic.Connection, quic.Stream, error) (hijacked bool, err error)
// When set, this callback is called for unknown unidirectional stream of unknown stream type.
// If parsing the stream type fails, the error is passed to the callback.
// In that case, the stream type will not be set.
UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream, error) (hijacked bool)
// Dial specifies an optional dial function for creating QUIC
// connections for requests.
// If Dial is nil, a UDPConn will be created at the first request
// and will be reused for subsequent connections to other servers.
Dial func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error)
// MaxResponseHeaderBytes specifies a limit on how many response bytes are
// allowed in the server's response header.
// Zero means to use a default limit.
MaxResponseHeaderBytes int64
newClient func(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) // so we can mock it in tests
clients map[string]*roundTripCloserWithCount
transport *quic.Transport
}
// RoundTripOpt are options for the Transport.RoundTripOpt method.
type RoundTripOpt struct {
// OnlyCachedConn controls whether the RoundTripper may create a new QUIC connection.
// If set true and no cached connection is available, RoundTripOpt will return ErrNoCachedConn.
OnlyCachedConn bool
// DontCloseRequestStream controls whether the request stream is closed after sending the request.
// If set, context cancellations have no effect after the response headers are received.
DontCloseRequestStream bool
}
var (
_ http.RoundTripper = &RoundTripper{}
_ io.Closer = &RoundTripper{}
)
// ErrNoCachedConn is returned when RoundTripper.OnlyCachedConn is set
var ErrNoCachedConn = errors.New("http3: no cached connection was available")
// RoundTripOpt is like RoundTrip, but takes options.
func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
if req.URL == nil {
closeRequestBody(req)
return nil, errors.New("http3: nil Request.URL")
}
if req.URL.Scheme != "https" {
closeRequestBody(req)
return nil, fmt.Errorf("http3: unsupported protocol scheme: %s", req.URL.Scheme)
}
if req.URL.Host == "" {
closeRequestBody(req)
return nil, errors.New("http3: no Host in request URL")
}
if req.Header == nil {
closeRequestBody(req)
return nil, errors.New("http3: nil Request.Header")
}
for k, vv := range req.Header {
if !httpguts.ValidHeaderFieldName(k) {
return nil, fmt.Errorf("http3: invalid http header field name %q", k)
}
for _, v := range vv {
if !httpguts.ValidHeaderFieldValue(v) {
return nil, fmt.Errorf("http3: invalid http header field value %q for key %v", v, k)
}
}
}
if req.Method != "" && !validMethod(req.Method) {
closeRequestBody(req)
return nil, fmt.Errorf("http3: invalid method %q", req.Method)
}
hostname := authorityAddr("https", hostnameFromRequest(req))
cl, isReused, err := r.getClient(hostname, opt.OnlyCachedConn)
if err != nil {
return nil, err
}
defer cl.useCount.Add(-1)
rsp, err := cl.RoundTripOpt(req, opt)
if err != nil {
r.removeClient(hostname)
if isReused {
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
return r.RoundTripOpt(req, opt)
}
}
}
return rsp, err
}
// RoundTrip does a round trip.
func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
return r.RoundTripOpt(req, RoundTripOpt{})
}
func (r *RoundTripper) getClient(hostname string, onlyCached bool) (rtc *roundTripCloserWithCount, isReused bool, err error) {
r.mutex.Lock()
defer r.mutex.Unlock()
if r.clients == nil {
r.clients = make(map[string]*roundTripCloserWithCount)
}
client, ok := r.clients[hostname]
if !ok {
if onlyCached {
return nil, false, ErrNoCachedConn
}
var err error
newCl := newClient
if r.newClient != nil {
newCl = r.newClient
}
dial := r.Dial
if dial == nil {
if r.transport == nil {
udpConn, err := net.ListenUDP("udp", nil)
if err != nil {
return nil, false, err
}
r.transport = &quic.Transport{Conn: udpConn}
}
dial = r.makeDialer()
}
c, err := newCl(
hostname,
r.TLSClientConfig,
&roundTripperOpts{
EnableDatagram: r.EnableDatagrams,
DisableCompression: r.DisableCompression,
MaxHeaderBytes: r.MaxResponseHeaderBytes,
StreamHijacker: r.StreamHijacker,
UniStreamHijacker: r.UniStreamHijacker,
},
r.QuicConfig,
dial,
)
if err != nil {
return nil, false, err
}
client = &roundTripCloserWithCount{roundTripCloser: c}
r.clients[hostname] = client
} else if client.HandshakeComplete() {
isReused = true
}
client.useCount.Add(1)
return client, isReused, nil
}
func (r *RoundTripper) removeClient(hostname string) {
r.mutex.Lock()
defer r.mutex.Unlock()
if r.clients == nil {
return
}
delete(r.clients, hostname)
}
// Close closes the QUIC connections that this RoundTripper has used.
// It also closes the underlying UDPConn if it is not nil.
func (r *RoundTripper) Close() error {
r.mutex.Lock()
defer r.mutex.Unlock()
for _, client := range r.clients {
if err := client.Close(); err != nil {
return err
}
}
r.clients = nil
if r.transport != nil {
if err := r.transport.Close(); err != nil {
return err
}
if err := r.transport.Conn.Close(); err != nil {
return err
}
r.transport = nil
}
return nil
}
func closeRequestBody(req *http.Request) {
if req.Body != nil {
req.Body.Close()
}
}
func validMethod(method string) bool {
/*
Method = "OPTIONS" ; Section 9.2
| "GET" ; Section 9.3
| "HEAD" ; Section 9.4
| "POST" ; Section 9.5
| "PUT" ; Section 9.6
| "DELETE" ; Section 9.7
| "TRACE" ; Section 9.8
| "CONNECT" ; Section 9.9
| extension-method
extension-method = token
token = 1*<any CHAR except CTLs or separators>
*/
return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1
}
// copied from net/http/http.go
func isNotToken(r rune) bool {
return !httpguts.IsTokenRune(r)
}
// makeDialer makes a QUIC dialer using r.udpConn.
func (r *RoundTripper) makeDialer() func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
return func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
return r.transport.DialEarly(ctx, udpAddr, tlsCfg, cfg)
}
}
func (r *RoundTripper) CloseIdleConnections() {
r.mutex.Lock()
defer r.mutex.Unlock()
for hostname, client := range r.clients {
if client.useCount.Load() == 0 {
client.Close()
delete(r.clients, hostname)
}
}
}

767
vendor/github.com/quic-go/quic-go/http3/server.go generated vendored Normal file
View File

@@ -0,0 +1,767 @@
package http3
import (
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/http"
"runtime"
"strconv"
"strings"
"sync"
"time"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/quicvarint"
"github.com/quic-go/qpack"
)
// allows mocking of quic.Listen and quic.ListenAddr
var (
quicListen = func(conn net.PacketConn, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) {
return quic.ListenEarly(conn, tlsConf, config)
}
quicListenAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) {
return quic.ListenAddrEarly(addr, tlsConf, config)
}
)
// NextProtoH3 is the ALPN protocol negotiated during the TLS handshake, for QUIC v1 and v2.
const NextProtoH3 = "h3"
// StreamType is the stream type of a unidirectional stream.
type StreamType uint64
const (
streamTypeControlStream = 0
streamTypePushStream = 1
streamTypeQPACKEncoderStream = 2
streamTypeQPACKDecoderStream = 3
)
// A QUICEarlyListener listens for incoming QUIC connections.
type QUICEarlyListener interface {
Accept(context.Context) (quic.EarlyConnection, error)
Addr() net.Addr
io.Closer
}
var _ QUICEarlyListener = &quic.EarlyListener{}
func versionToALPN(v protocol.VersionNumber) string {
//nolint:exhaustive // These are all the versions we care about.
switch v {
case protocol.Version1, protocol.Version2:
return NextProtoH3
default:
return ""
}
}
// ConfigureTLSConfig creates a new tls.Config which can be used
// to create a quic.Listener meant for serving http3. The created
// tls.Config adds the functionality of detecting the used QUIC version
// in order to set the correct ALPN value for the http3 connection.
func ConfigureTLSConfig(tlsConf *tls.Config) *tls.Config {
// The tls.Config used to setup the quic.Listener needs to have the GetConfigForClient callback set.
// That way, we can get the QUIC version and set the correct ALPN value.
return &tls.Config{
GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) {
// determine the ALPN from the QUIC version used
proto := NextProtoH3
val := ch.Context().Value(quic.QUICVersionContextKey)
if v, ok := val.(quic.VersionNumber); ok {
proto = versionToALPN(v)
}
config := tlsConf
if tlsConf.GetConfigForClient != nil {
getConfigForClient := tlsConf.GetConfigForClient
var err error
conf, err := getConfigForClient(ch)
if err != nil {
return nil, err
}
if conf != nil {
config = conf
}
}
if config == nil {
return nil, nil
}
config = config.Clone()
config.NextProtos = []string{proto}
return config, nil
},
}
}
// contextKey is a value for use with context.WithValue. It's used as
// a pointer so it fits in an interface{} without allocation.
type contextKey struct {
name string
}
func (k *contextKey) String() string { return "quic-go/http3 context value " + k.name }
// ServerContextKey is a context key. It can be used in HTTP
// handlers with Context.Value to access the server that
// started the handler. The associated value will be of
// type *http3.Server.
var ServerContextKey = &contextKey{"http3-server"}
type requestError struct {
err error
streamErr ErrCode
connErr ErrCode
}
func newStreamError(code ErrCode, err error) requestError {
return requestError{err: err, streamErr: code}
}
func newConnError(code ErrCode, err error) requestError {
return requestError{err: err, connErr: code}
}
// listenerInfo contains info about specific listener added with addListener
type listenerInfo struct {
port int // 0 means that no info about port is available
}
// Server is a HTTP/3 server.
type Server struct {
// Addr optionally specifies the UDP address for the server to listen on,
// in the form "host:port".
//
// When used by ListenAndServe and ListenAndServeTLS methods, if empty,
// ":https" (port 443) is used. See net.Dial for details of the address
// format.
//
// Otherwise, if Port is not set and underlying QUIC listeners do not
// have valid port numbers, the port part is used in Alt-Svc headers set
// with SetQuicHeaders.
Addr string
// Port is used in Alt-Svc response headers set with SetQuicHeaders. If
// needed Port can be manually set when the Server is created.
//
// This is useful when a Layer 4 firewall is redirecting UDP traffic and
// clients must use a port different from the port the Server is
// listening on.
Port int
// TLSConfig provides a TLS configuration for use by server. It must be
// set for ListenAndServe and Serve methods.
TLSConfig *tls.Config
// QuicConfig provides the parameters for QUIC connection created with
// Serve. If nil, it uses reasonable default values.
//
// Configured versions are also used in Alt-Svc response header set with
// SetQuicHeaders.
QuicConfig *quic.Config
// Handler is the HTTP request handler to use. If not set, defaults to
// http.NotFound.
Handler http.Handler
// EnableDatagrams enables support for HTTP/3 datagrams.
// If set to true, QuicConfig.EnableDatagram will be set.
// See https://datatracker.ietf.org/doc/html/rfc9297.
EnableDatagrams bool
// MaxHeaderBytes controls the maximum number of bytes the server will
// read parsing the request HEADERS frame. It does not limit the size of
// the request body. If zero or negative, http.DefaultMaxHeaderBytes is
// used.
MaxHeaderBytes int
// AdditionalSettings specifies additional HTTP/3 settings.
// It is invalid to specify any settings defined by the HTTP/3 draft and the datagram draft.
AdditionalSettings map[uint64]uint64
// StreamHijacker, when set, is called for the first unknown frame parsed on a bidirectional stream.
// It is called right after parsing the frame type.
// If parsing the frame type fails, the error is passed to the callback.
// In that case, the frame type will not be set.
// Callers can either ignore the frame and return control of the stream back to HTTP/3
// (by returning hijacked false).
// Alternatively, callers can take over the QUIC stream (by returning hijacked true).
StreamHijacker func(FrameType, quic.Connection, quic.Stream, error) (hijacked bool, err error)
// UniStreamHijacker, when set, is called for unknown unidirectional stream of unknown stream type.
// If parsing the stream type fails, the error is passed to the callback.
// In that case, the stream type will not be set.
UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream, error) (hijacked bool)
mutex sync.RWMutex
listeners map[*QUICEarlyListener]listenerInfo
closed bool
altSvcHeader string
logger utils.Logger
}
// ListenAndServe listens on the UDP address s.Addr and calls s.Handler to handle HTTP/3 requests on incoming connections.
//
// If s.Addr is blank, ":https" is used.
func (s *Server) ListenAndServe() error {
return s.serveConn(s.TLSConfig, nil)
}
// ListenAndServeTLS listens on the UDP address s.Addr and calls s.Handler to handle HTTP/3 requests on incoming connections.
//
// If s.Addr is blank, ":https" is used.
func (s *Server) ListenAndServeTLS(certFile, keyFile string) error {
var err error
certs := make([]tls.Certificate, 1)
certs[0], err = tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return err
}
// We currently only use the cert-related stuff from tls.Config,
// so we don't need to make a full copy.
config := &tls.Config{
Certificates: certs,
}
return s.serveConn(config, nil)
}
// Serve an existing UDP connection.
// It is possible to reuse the same connection for outgoing connections.
// Closing the server does not close the connection.
func (s *Server) Serve(conn net.PacketConn) error {
return s.serveConn(s.TLSConfig, conn)
}
// ServeQUICConn serves a single QUIC connection.
func (s *Server) ServeQUICConn(conn quic.Connection) error {
s.mutex.Lock()
if s.logger == nil {
s.logger = utils.DefaultLogger.WithPrefix("server")
}
s.mutex.Unlock()
return s.handleConn(conn)
}
// ServeListener serves an existing QUIC listener.
// Make sure you use http3.ConfigureTLSConfig to configure a tls.Config
// and use it to construct a http3-friendly QUIC listener.
// Closing the server does close the listener.
// ServeListener always returns a non-nil error. After Shutdown or Close, the returned error is http.ErrServerClosed.
func (s *Server) ServeListener(ln QUICEarlyListener) error {
if err := s.addListener(&ln); err != nil {
return err
}
defer s.removeListener(&ln)
for {
conn, err := ln.Accept(context.Background())
if err == quic.ErrServerClosed {
return http.ErrServerClosed
}
if err != nil {
return err
}
go func() {
if err := s.handleConn(conn); err != nil {
s.logger.Debugf(err.Error())
}
}()
}
}
var errServerWithoutTLSConfig = errors.New("use of http3.Server without TLSConfig")
func (s *Server) serveConn(tlsConf *tls.Config, conn net.PacketConn) error {
if tlsConf == nil {
return errServerWithoutTLSConfig
}
s.mutex.Lock()
closed := s.closed
s.mutex.Unlock()
if closed {
return http.ErrServerClosed
}
baseConf := ConfigureTLSConfig(tlsConf)
quicConf := s.QuicConfig
if quicConf == nil {
quicConf = &quic.Config{Allow0RTT: true}
} else {
quicConf = s.QuicConfig.Clone()
}
if s.EnableDatagrams {
quicConf.EnableDatagrams = true
}
var ln QUICEarlyListener
var err error
if conn == nil {
addr := s.Addr
if addr == "" {
addr = ":https"
}
ln, err = quicListenAddr(addr, baseConf, quicConf)
} else {
ln, err = quicListen(conn, baseConf, quicConf)
}
if err != nil {
return err
}
return s.ServeListener(ln)
}
func extractPort(addr string) (int, error) {
_, portStr, err := net.SplitHostPort(addr)
if err != nil {
return 0, err
}
portInt, err := net.LookupPort("tcp", portStr)
if err != nil {
return 0, err
}
return portInt, nil
}
func (s *Server) generateAltSvcHeader() {
if len(s.listeners) == 0 {
// Don't announce any ports since no one is listening for connections
s.altSvcHeader = ""
return
}
// This code assumes that we will use protocol.SupportedVersions if no quic.Config is passed.
supportedVersions := protocol.SupportedVersions
if s.QuicConfig != nil && len(s.QuicConfig.Versions) > 0 {
supportedVersions = s.QuicConfig.Versions
}
// keep track of which have been seen so we don't yield duplicate values
seen := make(map[string]struct{}, len(supportedVersions))
var versionStrings []string
for _, version := range supportedVersions {
if v := versionToALPN(version); len(v) > 0 {
if _, ok := seen[v]; !ok {
versionStrings = append(versionStrings, v)
seen[v] = struct{}{}
}
}
}
var altSvc []string
addPort := func(port int) {
for _, v := range versionStrings {
altSvc = append(altSvc, fmt.Sprintf(`%s=":%d"; ma=2592000`, v, port))
}
}
if s.Port != 0 {
// if Port is specified, we must use it instead of the
// listener addresses since there's a reason it's specified.
addPort(s.Port)
} else {
// if we have some listeners assigned, try to find ports
// which we can announce, otherwise nothing should be announced
validPortsFound := false
for _, info := range s.listeners {
if info.port != 0 {
addPort(info.port)
validPortsFound = true
}
}
if !validPortsFound {
if port, err := extractPort(s.Addr); err == nil {
addPort(port)
}
}
}
s.altSvcHeader = strings.Join(altSvc, ",")
}
// We store a pointer to interface in the map set. This is safe because we only
// call trackListener via Serve and can track+defer untrack the same pointer to
// local variable there. We never need to compare a Listener from another caller.
func (s *Server) addListener(l *QUICEarlyListener) error {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.closed {
return http.ErrServerClosed
}
if s.logger == nil {
s.logger = utils.DefaultLogger.WithPrefix("server")
}
if s.listeners == nil {
s.listeners = make(map[*QUICEarlyListener]listenerInfo)
}
if port, err := extractPort((*l).Addr().String()); err == nil {
s.listeners[l] = listenerInfo{port}
} else {
s.logger.Errorf("Unable to extract port from listener %+v, will not be announced using SetQuicHeaders: %s", err)
s.listeners[l] = listenerInfo{}
}
s.generateAltSvcHeader()
return nil
}
func (s *Server) removeListener(l *QUICEarlyListener) {
s.mutex.Lock()
defer s.mutex.Unlock()
delete(s.listeners, l)
s.generateAltSvcHeader()
}
func (s *Server) handleConn(conn quic.Connection) error {
decoder := qpack.NewDecoder(nil)
// send a SETTINGS frame
str, err := conn.OpenUniStream()
if err != nil {
return fmt.Errorf("opening the control stream failed: %w", err)
}
b := make([]byte, 0, 64)
b = quicvarint.Append(b, streamTypeControlStream) // stream type
b = (&settingsFrame{Datagram: s.EnableDatagrams, Other: s.AdditionalSettings}).Append(b)
str.Write(b)
go s.handleUnidirectionalStreams(conn)
// Process all requests immediately.
// It's the client's responsibility to decide which requests are eligible for 0-RTT.
for {
str, err := conn.AcceptStream(context.Background())
if err != nil {
var appErr *quic.ApplicationError
if errors.As(err, &appErr) && appErr.ErrorCode == quic.ApplicationErrorCode(ErrCodeNoError) {
return nil
}
return fmt.Errorf("accepting stream failed: %w", err)
}
go func() {
rerr := s.handleRequest(conn, str, decoder, func() {
conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "")
})
if rerr.err == errHijacked {
return
}
if rerr.err != nil || rerr.streamErr != 0 || rerr.connErr != 0 {
s.logger.Debugf("Handling request failed: %s", err)
if rerr.streamErr != 0 {
str.CancelWrite(quic.StreamErrorCode(rerr.streamErr))
}
if rerr.connErr != 0 {
var reason string
if rerr.err != nil {
reason = rerr.err.Error()
}
conn.CloseWithError(quic.ApplicationErrorCode(rerr.connErr), reason)
}
return
}
str.Close()
}()
}
}
func (s *Server) handleUnidirectionalStreams(conn quic.Connection) {
for {
str, err := conn.AcceptUniStream(context.Background())
if err != nil {
s.logger.Debugf("accepting unidirectional stream failed: %s", err)
return
}
go func(str quic.ReceiveStream) {
streamType, err := quicvarint.Read(quicvarint.NewReader(str))
if err != nil {
if s.UniStreamHijacker != nil && s.UniStreamHijacker(StreamType(streamType), conn, str, err) {
return
}
s.logger.Debugf("reading stream type on stream %d failed: %s", str.StreamID(), err)
return
}
// We're only interested in the control stream here.
switch streamType {
case streamTypeControlStream:
case streamTypeQPACKEncoderStream, streamTypeQPACKDecoderStream:
// Our QPACK implementation doesn't use the dynamic table yet.
// TODO: check that only one stream of each type is opened.
return
case streamTypePushStream: // only the server can push
conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "")
return
default:
if s.UniStreamHijacker != nil && s.UniStreamHijacker(StreamType(streamType), conn, str, nil) {
return
}
str.CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError))
return
}
f, err := parseNextFrame(str, nil)
if err != nil {
conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameError), "")
return
}
sf, ok := f.(*settingsFrame)
if !ok {
conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeMissingSettings), "")
return
}
if !sf.Datagram {
return
}
// If datagram support was enabled on our side as well as on the client side,
// we can expect it to have been negotiated both on the transport and on the HTTP/3 layer.
// Note: ConnectionState() will block until the handshake is complete (relevant when using 0-RTT).
if s.EnableDatagrams && !conn.ConnectionState().SupportsDatagrams {
conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeSettingsError), "missing QUIC Datagram support")
}
}(str)
}
}
func (s *Server) maxHeaderBytes() uint64 {
if s.MaxHeaderBytes <= 0 {
return http.DefaultMaxHeaderBytes
}
return uint64(s.MaxHeaderBytes)
}
func (s *Server) handleRequest(conn quic.Connection, str quic.Stream, decoder *qpack.Decoder, onFrameError func()) requestError {
var ufh unknownFrameHandlerFunc
if s.StreamHijacker != nil {
ufh = func(ft FrameType, e error) (processed bool, err error) { return s.StreamHijacker(ft, conn, str, e) }
}
frame, err := parseNextFrame(str, ufh)
if err != nil {
if err == errHijacked {
return requestError{err: errHijacked}
}
return newStreamError(ErrCodeRequestIncomplete, err)
}
hf, ok := frame.(*headersFrame)
if !ok {
return newConnError(ErrCodeFrameUnexpected, errors.New("expected first frame to be a HEADERS frame"))
}
if hf.Length > s.maxHeaderBytes() {
return newStreamError(ErrCodeFrameError, fmt.Errorf("HEADERS frame too large: %d bytes (max: %d)", hf.Length, s.maxHeaderBytes()))
}
headerBlock := make([]byte, hf.Length)
if _, err := io.ReadFull(str, headerBlock); err != nil {
return newStreamError(ErrCodeRequestIncomplete, err)
}
hfs, err := decoder.DecodeFull(headerBlock)
if err != nil {
// TODO: use the right error code
return newConnError(ErrCodeGeneralProtocolError, err)
}
req, err := requestFromHeaders(hfs)
if err != nil {
return newStreamError(ErrCodeMessageError, err)
}
connState := conn.ConnectionState().TLS
req.TLS = &connState
req.RemoteAddr = conn.RemoteAddr().String()
// Check that the client doesn't send more data in DATA frames than indicated by the Content-Length header (if set).
// See section 4.1.2 of RFC 9114.
var httpStr Stream
if _, ok := req.Header["Content-Length"]; ok && req.ContentLength >= 0 {
httpStr = newLengthLimitedStream(newStream(str, onFrameError), req.ContentLength)
} else {
httpStr = newStream(str, onFrameError)
}
body := newRequestBody(httpStr)
req.Body = body
if s.logger.Debug() {
s.logger.Infof("%s %s%s, on stream %d", req.Method, req.Host, req.RequestURI, str.StreamID())
} else {
s.logger.Infof("%s %s%s", req.Method, req.Host, req.RequestURI)
}
ctx := str.Context()
ctx = context.WithValue(ctx, ServerContextKey, s)
ctx = context.WithValue(ctx, http.LocalAddrContextKey, conn.LocalAddr())
req = req.WithContext(ctx)
r := newResponseWriter(str, conn, s.logger)
handler := s.Handler
if handler == nil {
handler = http.DefaultServeMux
}
var panicked bool
func() {
defer func() {
if p := recover(); p != nil {
panicked = true
if p == http.ErrAbortHandler {
return
}
// Copied from net/http/server.go
const size = 64 << 10
buf := make([]byte, size)
buf = buf[:runtime.Stack(buf, false)]
s.logger.Errorf("http: panic serving: %v\n%s", p, buf)
}
}()
handler.ServeHTTP(r, req)
}()
if body.wasStreamHijacked() {
return requestError{err: errHijacked}
}
// only write response when there is no panic
if !panicked {
// response not written to the client yet, set Content-Length
if !r.written {
if _, haveCL := r.header["Content-Length"]; !haveCL {
r.header.Set("Content-Length", strconv.FormatInt(r.numWritten, 10))
}
}
r.Flush()
}
// If the EOF was read by the handler, CancelRead() is a no-op.
str.CancelRead(quic.StreamErrorCode(ErrCodeNoError))
return requestError{}
}
// Close the server immediately, aborting requests and sending CONNECTION_CLOSE frames to connected clients.
// Close in combination with ListenAndServe() (instead of Serve()) may race if it is called before a UDP socket is established.
func (s *Server) Close() error {
s.mutex.Lock()
defer s.mutex.Unlock()
s.closed = true
var err error
for ln := range s.listeners {
if cerr := (*ln).Close(); cerr != nil && err == nil {
err = cerr
}
}
return err
}
// CloseGracefully shuts down the server gracefully. The server sends a GOAWAY frame first, then waits for either timeout to trigger, or for all running requests to complete.
// CloseGracefully in combination with ListenAndServe() (instead of Serve()) may race if it is called before a UDP socket is established.
func (s *Server) CloseGracefully(timeout time.Duration) error {
// TODO: implement
return nil
}
// ErrNoAltSvcPort is the error returned by SetQuicHeaders when no port was found
// for Alt-Svc to announce. This can happen if listening on a PacketConn without a port
// (UNIX socket, for example) and no port is specified in Server.Port or Server.Addr.
var ErrNoAltSvcPort = errors.New("no port can be announced, specify it explicitly using Server.Port or Server.Addr")
// SetQuicHeaders can be used to set the proper headers that announce that this server supports HTTP/3.
// The values set by default advertise all of the ports the server is listening on, but can be
// changed to a specific port by setting Server.Port before launching the serverr.
// If no listener's Addr().String() returns an address with a valid port, Server.Addr will be used
// to extract the port, if specified.
// For example, a server launched using ListenAndServe on an address with port 443 would set:
//
// Alt-Svc: h3=":443"; ma=2592000,h3-29=":443"; ma=2592000
func (s *Server) SetQuicHeaders(hdr http.Header) error {
s.mutex.RLock()
defer s.mutex.RUnlock()
if s.altSvcHeader == "" {
return ErrNoAltSvcPort
}
// use the map directly to avoid constant canonicalization
// since the key is already canonicalized
hdr["Alt-Svc"] = append(hdr["Alt-Svc"], s.altSvcHeader)
return nil
}
// ListenAndServeQUIC listens on the UDP network address addr and calls the
// handler for HTTP/3 requests on incoming connections. http.DefaultServeMux is
// used when handler is nil.
func ListenAndServeQUIC(addr, certFile, keyFile string, handler http.Handler) error {
server := &Server{
Addr: addr,
Handler: handler,
}
return server.ListenAndServeTLS(certFile, keyFile)
}
// ListenAndServe listens on the given network address for both, TLS and QUIC
// connections in parallel. It returns if one of the two returns an error.
// http.DefaultServeMux is used when handler is nil.
// The correct Alt-Svc headers for QUIC are set.
func ListenAndServe(addr, certFile, keyFile string, handler http.Handler) error {
// Load certs
var err error
certs := make([]tls.Certificate, 1)
certs[0], err = tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return err
}
// We currently only use the cert-related stuff from tls.Config,
// so we don't need to make a full copy.
config := &tls.Config{
Certificates: certs,
}
if addr == "" {
addr = ":https"
}
// Open the listeners
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return err
}
udpConn, err := net.ListenUDP("udp", udpAddr)
if err != nil {
return err
}
defer udpConn.Close()
if handler == nil {
handler = http.DefaultServeMux
}
// Start the servers
quicServer := &Server{
TLSConfig: config,
Handler: handler,
}
hErr := make(chan error)
qErr := make(chan error)
go func() {
hErr <- http.ListenAndServeTLS(addr, certFile, keyFile, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
quicServer.SetQuicHeaders(w.Header())
handler.ServeHTTP(w, r)
}))
}()
go func() {
qErr <- quicServer.Serve(udpConn)
}()
select {
case err := <-hErr:
quicServer.Close()
return err
case err := <-qErr:
// Cannot close the HTTP server or wait for requests to complete properly :/
return err
}
}

349
vendor/github.com/quic-go/quic-go/interface.go generated vendored Normal file
View File

@@ -0,0 +1,349 @@
package quic
import (
"context"
"crypto/tls"
"errors"
"io"
"net"
"time"
"github.com/quic-go/quic-go/internal/handshake"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/logging"
)
// The StreamID is the ID of a QUIC stream.
type StreamID = protocol.StreamID
// A VersionNumber is a QUIC version number.
type VersionNumber = protocol.VersionNumber
const (
// Version1 is RFC 9000
Version1 = protocol.Version1
// Version2 is RFC 9369
Version2 = protocol.Version2
)
// A ClientToken is a token received by the client.
// It can be used to skip address validation on future connection attempts.
type ClientToken struct {
data []byte
}
type TokenStore interface {
// Pop searches for a ClientToken associated with the given key.
// Since tokens are not supposed to be reused, it must remove the token from the cache.
// It returns nil when no token is found.
Pop(key string) (token *ClientToken)
// Put adds a token to the cache with the given key. It might get called
// multiple times in a connection.
Put(key string, token *ClientToken)
}
// Err0RTTRejected is the returned from:
// * Open{Uni}Stream{Sync}
// * Accept{Uni}Stream
// * Stream.Read and Stream.Write
// when the server rejects a 0-RTT connection attempt.
var Err0RTTRejected = errors.New("0-RTT rejected")
// ConnectionTracingKey can be used to associate a ConnectionTracer with a Connection.
// It is set on the Connection.Context() context,
// as well as on the context passed to logging.Tracer.NewConnectionTracer.
var ConnectionTracingKey = connTracingCtxKey{}
type connTracingCtxKey struct{}
// QUICVersionContextKey can be used to find out the QUIC version of a TLS handshake from the
// context returned by tls.Config.ClientHelloInfo.Context.
var QUICVersionContextKey = handshake.QUICVersionContextKey
// Stream is the interface implemented by QUIC streams
// In addition to the errors listed on the Connection,
// calls to stream functions can return a StreamError if the stream is canceled.
type Stream interface {
ReceiveStream
SendStream
// SetDeadline sets the read and write deadlines associated
// with the connection. It is equivalent to calling both
// SetReadDeadline and SetWriteDeadline.
SetDeadline(t time.Time) error
}
// A ReceiveStream is a unidirectional Receive Stream.
type ReceiveStream interface {
// StreamID returns the stream ID.
StreamID() StreamID
// Read reads data from the stream.
// Read can be made to time out and return a net.Error with Timeout() == true
// after a fixed time limit; see SetDeadline and SetReadDeadline.
// If the stream was canceled by the peer, the error implements the StreamError
// interface, and Canceled() == true.
// If the connection was closed due to a timeout, the error satisfies
// the net.Error interface, and Timeout() will be true.
io.Reader
// CancelRead aborts receiving on this stream.
// It will ask the peer to stop transmitting stream data.
// Read will unblock immediately, and future Read calls will fail.
// When called multiple times or after reading the io.EOF it is a no-op.
CancelRead(StreamErrorCode)
// SetReadDeadline sets the deadline for future Read calls and
// any currently-blocked Read call.
// A zero value for t means Read will not time out.
SetReadDeadline(t time.Time) error
}
// A SendStream is a unidirectional Send Stream.
type SendStream interface {
// StreamID returns the stream ID.
StreamID() StreamID
// Write writes data to the stream.
// Write can be made to time out and return a net.Error with Timeout() == true
// after a fixed time limit; see SetDeadline and SetWriteDeadline.
// If the stream was canceled by the peer, the error implements the StreamError
// interface, and Canceled() == true.
// If the connection was closed due to a timeout, the error satisfies
// the net.Error interface, and Timeout() will be true.
io.Writer
// Close closes the write-direction of the stream.
// Future calls to Write are not permitted after calling Close.
// It must not be called concurrently with Write.
// It must not be called after calling CancelWrite.
io.Closer
// CancelWrite aborts sending on this stream.
// Data already written, but not yet delivered to the peer is not guaranteed to be delivered reliably.
// Write will unblock immediately, and future calls to Write will fail.
// When called multiple times or after closing the stream it is a no-op.
CancelWrite(StreamErrorCode)
// The Context is canceled as soon as the write-side of the stream is closed.
// This happens when Close() or CancelWrite() is called, or when the peer
// cancels the read-side of their stream.
// The cancellation cause is set to the error that caused the stream to
// close, or `context.Canceled` in case the stream is closed without error.
Context() context.Context
// SetWriteDeadline sets the deadline for future Write calls
// and any currently-blocked Write call.
// Even if write times out, it may return n > 0, indicating that
// some data was successfully written.
// A zero value for t means Write will not time out.
SetWriteDeadline(t time.Time) error
}
// A Connection is a QUIC connection between two peers.
// Calls to the connection (and to streams) can return the following types of errors:
// * ApplicationError: for errors triggered by the application running on top of QUIC
// * TransportError: for errors triggered by the QUIC transport (in many cases a misbehaving peer)
// * IdleTimeoutError: when the peer goes away unexpectedly (this is a net.Error timeout error)
// * HandshakeTimeoutError: when the cryptographic handshake takes too long (this is a net.Error timeout error)
// * StatelessResetError: when we receive a stateless reset (this is a net.Error temporary error)
// * VersionNegotiationError: returned by the client, when there's no version overlap between the peers
type Connection interface {
// AcceptStream returns the next stream opened by the peer, blocking until one is available.
// If the connection was closed due to a timeout, the error satisfies
// the net.Error interface, and Timeout() will be true.
AcceptStream(context.Context) (Stream, error)
// AcceptUniStream returns the next unidirectional stream opened by the peer, blocking until one is available.
// If the connection was closed due to a timeout, the error satisfies
// the net.Error interface, and Timeout() will be true.
AcceptUniStream(context.Context) (ReceiveStream, error)
// OpenStream opens a new bidirectional QUIC stream.
// There is no signaling to the peer about new streams:
// The peer can only accept the stream after data has been sent on the stream.
// If the error is non-nil, it satisfies the net.Error interface.
// When reaching the peer's stream limit, err.Temporary() will be true.
// If the connection was closed due to a timeout, Timeout() will be true.
OpenStream() (Stream, error)
// OpenStreamSync opens a new bidirectional QUIC stream.
// It blocks until a new stream can be opened.
// If the error is non-nil, it satisfies the net.Error interface.
// If the connection was closed due to a timeout, Timeout() will be true.
OpenStreamSync(context.Context) (Stream, error)
// OpenUniStream opens a new outgoing unidirectional QUIC stream.
// If the error is non-nil, it satisfies the net.Error interface.
// When reaching the peer's stream limit, Temporary() will be true.
// If the connection was closed due to a timeout, Timeout() will be true.
OpenUniStream() (SendStream, error)
// OpenUniStreamSync opens a new outgoing unidirectional QUIC stream.
// It blocks until a new stream can be opened.
// If the error is non-nil, it satisfies the net.Error interface.
// If the connection was closed due to a timeout, Timeout() will be true.
OpenUniStreamSync(context.Context) (SendStream, error)
// LocalAddr returns the local address.
LocalAddr() net.Addr
// RemoteAddr returns the address of the peer.
RemoteAddr() net.Addr
// CloseWithError closes the connection with an error.
// The error string will be sent to the peer.
CloseWithError(ApplicationErrorCode, string) error
// Context returns a context that is cancelled when the connection is closed.
// The cancellation cause is set to the error that caused the connection to
// close, or `context.Canceled` in case the listener is closed first.
Context() context.Context
// ConnectionState returns basic details about the QUIC connection.
// Warning: This API should not be considered stable and might change soon.
ConnectionState() ConnectionState
// SendMessage sends a message as a datagram, as specified in RFC 9221.
SendMessage([]byte) error
// ReceiveMessage gets a message received in a datagram, as specified in RFC 9221.
ReceiveMessage(context.Context) ([]byte, error)
}
// An EarlyConnection is a connection that is handshaking.
// Data sent during the handshake is encrypted using the forward secure keys.
// When using client certificates, the client's identity is only verified
// after completion of the handshake.
type EarlyConnection interface {
Connection
// HandshakeComplete blocks until the handshake completes (or fails).
// For the client, data sent before completion of the handshake is encrypted with 0-RTT keys.
// For the server, data sent before completion of the handshake is encrypted with 1-RTT keys,
// however the client's identity is only verified once the handshake completes.
HandshakeComplete() <-chan struct{}
NextConnection() Connection
}
// StatelessResetKey is a key used to derive stateless reset tokens.
type StatelessResetKey [32]byte
// TokenGeneratorKey is a key used to encrypt session resumption tokens.
type TokenGeneratorKey = handshake.TokenProtectorKey
// A ConnectionID is a QUIC Connection ID, as defined in RFC 9000.
// It is not able to handle QUIC Connection IDs longer than 20 bytes,
// as they are allowed by RFC 8999.
type ConnectionID = protocol.ConnectionID
// ConnectionIDFromBytes interprets b as a Connection ID. It panics if b is
// longer than 20 bytes.
func ConnectionIDFromBytes(b []byte) ConnectionID {
return protocol.ParseConnectionID(b)
}
// A ConnectionIDGenerator is an interface that allows clients to implement their own format
// for the Connection IDs that servers/clients use as SrcConnectionID in QUIC packets.
//
// Connection IDs generated by an implementation should always produce IDs of constant size.
type ConnectionIDGenerator interface {
// GenerateConnectionID generates a new ConnectionID.
// Generated ConnectionIDs should be unique and observers should not be able to correlate two ConnectionIDs.
GenerateConnectionID() (ConnectionID, error)
// ConnectionIDLen tells what is the length of the ConnectionIDs generated by the implementation of
// this interface.
// Effectively, this means that implementations of ConnectionIDGenerator must always return constant-size
// connection IDs. Valid lengths are between 0 and 20 and calls to GenerateConnectionID.
// 0-length ConnectionsIDs can be used when an endpoint (server or client) does not require multiplexing connections
// in the presence of a connection migration environment.
ConnectionIDLen() int
}
// Config contains all configuration data needed for a QUIC server or client.
type Config struct {
// GetConfigForClient is called for incoming connections.
// If the error is not nil, the connection attempt is refused.
GetConfigForClient func(info *ClientHelloInfo) (*Config, error)
// The QUIC versions that can be negotiated.
// If not set, it uses all versions available.
Versions []VersionNumber
// HandshakeIdleTimeout is the idle timeout before completion of the handshake.
// If we don't receive any packet from the peer within this time, the connection attempt is aborted.
// Additionally, if the handshake doesn't complete in twice this time, the connection attempt is also aborted.
// If this value is zero, the timeout is set to 5 seconds.
HandshakeIdleTimeout time.Duration
// MaxIdleTimeout is the maximum duration that may pass without any incoming network activity.
// The actual value for the idle timeout is the minimum of this value and the peer's.
// This value only applies after the handshake has completed.
// If the timeout is exceeded, the connection is closed.
// If this value is zero, the timeout is set to 30 seconds.
MaxIdleTimeout time.Duration
// RequireAddressValidation determines if a QUIC Retry packet is sent.
// This allows the server to verify the client's address, at the cost of increasing the handshake latency by 1 RTT.
// See https://datatracker.ietf.org/doc/html/rfc9000#section-8 for details.
// If not set, every client is forced to prove its remote address.
RequireAddressValidation func(net.Addr) bool
// The TokenStore stores tokens received from the server.
// Tokens are used to skip address validation on future connection attempts.
// The key used to store tokens is the ServerName from the tls.Config, if set
// otherwise the token is associated with the server's IP address.
TokenStore TokenStore
// InitialStreamReceiveWindow is the initial size of the stream-level flow control window for receiving data.
// If the application is consuming data quickly enough, the flow control auto-tuning algorithm
// will increase the window up to MaxStreamReceiveWindow.
// If this value is zero, it will default to 512 KB.
// Values larger than the maximum varint (quicvarint.Max) will be clipped to that value.
InitialStreamReceiveWindow uint64
// MaxStreamReceiveWindow is the maximum stream-level flow control window for receiving data.
// If this value is zero, it will default to 6 MB.
// Values larger than the maximum varint (quicvarint.Max) will be clipped to that value.
MaxStreamReceiveWindow uint64
// InitialConnectionReceiveWindow is the initial size of the stream-level flow control window for receiving data.
// If the application is consuming data quickly enough, the flow control auto-tuning algorithm
// will increase the window up to MaxConnectionReceiveWindow.
// If this value is zero, it will default to 512 KB.
// Values larger than the maximum varint (quicvarint.Max) will be clipped to that value.
InitialConnectionReceiveWindow uint64
// MaxConnectionReceiveWindow is the connection-level flow control window for receiving data.
// If this value is zero, it will default to 15 MB.
// Values larger than the maximum varint (quicvarint.Max) will be clipped to that value.
MaxConnectionReceiveWindow uint64
// AllowConnectionWindowIncrease is called every time the connection flow controller attempts
// to increase the connection flow control window.
// If set, the caller can prevent an increase of the window. Typically, it would do so to
// limit the memory usage.
// To avoid deadlocks, it is not valid to call other functions on the connection or on streams
// in this callback.
AllowConnectionWindowIncrease func(conn Connection, delta uint64) bool
// MaxIncomingStreams is the maximum number of concurrent bidirectional streams that a peer is allowed to open.
// If not set, it will default to 100.
// If set to a negative value, it doesn't allow any bidirectional streams.
// Values larger than 2^60 will be clipped to that value.
MaxIncomingStreams int64
// MaxIncomingUniStreams is the maximum number of concurrent unidirectional streams that a peer is allowed to open.
// If not set, it will default to 100.
// If set to a negative value, it doesn't allow any unidirectional streams.
// Values larger than 2^60 will be clipped to that value.
MaxIncomingUniStreams int64
// KeepAlivePeriod defines whether this peer will periodically send a packet to keep the connection alive.
// If set to 0, then no keep alive is sent. Otherwise, the keep alive is sent on that period (or at most
// every half of MaxIdleTimeout, whichever is smaller).
KeepAlivePeriod time.Duration
// DisablePathMTUDiscovery disables Path MTU Discovery (RFC 8899).
// This allows the sending of QUIC packets that fully utilize the available MTU of the path.
// Path MTU discovery is only available on systems that allow setting of the Don't Fragment (DF) bit.
// If unavailable or disabled, packets will be at most 1252 (IPv4) / 1232 (IPv6) bytes in size.
DisablePathMTUDiscovery bool
// Allow0RTT allows the application to decide if a 0-RTT connection attempt should be accepted.
// Only valid for the server.
Allow0RTT bool
// Enable QUIC datagram support (RFC 9221).
EnableDatagrams bool
Tracer func(context.Context, logging.Perspective, ConnectionID) *logging.ConnectionTracer
}
type ClientHelloInfo struct {
RemoteAddr net.Addr
}
// ConnectionState records basic details about a QUIC connection
type ConnectionState struct {
// TLS contains information about the TLS connection state, incl. the tls.ConnectionState.
TLS tls.ConnectionState
// SupportsDatagrams says if support for QUIC datagrams (RFC 9221) was negotiated.
// This requires both nodes to support and enable the datagram extensions (via Config.EnableDatagrams).
// If datagram support was negotiated, datagrams can be sent and received using the
// SendMessage and ReceiveMessage methods on the Connection.
SupportsDatagrams bool
// Used0RTT says if 0-RTT resumption was used.
Used0RTT bool
// Version is the QUIC version of the QUIC connection.
Version VersionNumber
// GSO says if generic segmentation offload is used
GSO bool
}

View File

@@ -0,0 +1,20 @@
package ackhandler
import "github.com/quic-go/quic-go/internal/wire"
// IsFrameAckEliciting returns true if the frame is ack-eliciting.
func IsFrameAckEliciting(f wire.Frame) bool {
_, isAck := f.(*wire.AckFrame)
_, isConnectionClose := f.(*wire.ConnectionCloseFrame)
return !isAck && !isConnectionClose
}
// HasAckElicitingFrames returns true if at least one frame is ack-eliciting.
func HasAckElicitingFrames(fs []Frame) bool {
for _, f := range fs {
if IsFrameAckEliciting(f.Frame) {
return true
}
}
return false
}

View File

@@ -0,0 +1,24 @@
package ackhandler
import (
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/logging"
)
// NewAckHandler creates a new SentPacketHandler and a new ReceivedPacketHandler.
// clientAddressValidated indicates whether the address was validated beforehand by an address validation token.
// clientAddressValidated has no effect for a client.
func NewAckHandler(
initialPacketNumber protocol.PacketNumber,
initialMaxDatagramSize protocol.ByteCount,
rttStats *utils.RTTStats,
clientAddressValidated bool,
enableECN bool,
pers protocol.Perspective,
tracer *logging.ConnectionTracer,
logger utils.Logger,
) (SentPacketHandler, ReceivedPacketHandler) {
sph := newSentPacketHandler(initialPacketNumber, initialMaxDatagramSize, rttStats, clientAddressValidated, enableECN, pers, tracer, logger)
return sph, newReceivedPacketHandler(sph, rttStats, logger)
}

View File

@@ -0,0 +1,296 @@
package ackhandler
import (
"fmt"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/logging"
)
type ecnState uint8
const (
ecnStateInitial ecnState = iota
ecnStateTesting
ecnStateUnknown
ecnStateCapable
ecnStateFailed
)
// must fit into an uint8, otherwise numSentTesting and numLostTesting must have a larger type
const numECNTestingPackets = 10
type ecnHandler interface {
SentPacket(protocol.PacketNumber, protocol.ECN)
Mode() protocol.ECN
HandleNewlyAcked(packets []*packet, ect0, ect1, ecnce int64) (congested bool)
LostPacket(protocol.PacketNumber)
}
// The ecnTracker performs ECN validation of a path.
// Once failed, it doesn't do any re-validation of the path.
// It is designed only work for 1-RTT packets, it doesn't handle multiple packet number spaces.
// In order to avoid revealing any internal state to on-path observers,
// callers should make sure to start using ECN (i.e. calling Mode) for the very first 1-RTT packet sent.
// The validation logic implemented here strictly follows the algorithm described in RFC 9000 section 13.4.2 and A.4.
type ecnTracker struct {
state ecnState
numSentTesting, numLostTesting uint8
firstTestingPacket protocol.PacketNumber
lastTestingPacket protocol.PacketNumber
firstCapablePacket protocol.PacketNumber
numSentECT0, numSentECT1 int64
numAckedECT0, numAckedECT1, numAckedECNCE int64
tracer *logging.ConnectionTracer
logger utils.Logger
}
var _ ecnHandler = &ecnTracker{}
func newECNTracker(logger utils.Logger, tracer *logging.ConnectionTracer) *ecnTracker {
return &ecnTracker{
firstTestingPacket: protocol.InvalidPacketNumber,
lastTestingPacket: protocol.InvalidPacketNumber,
firstCapablePacket: protocol.InvalidPacketNumber,
state: ecnStateInitial,
logger: logger,
tracer: tracer,
}
}
func (e *ecnTracker) SentPacket(pn protocol.PacketNumber, ecn protocol.ECN) {
//nolint:exhaustive // These are the only ones we need to take care of.
switch ecn {
case protocol.ECNNon:
return
case protocol.ECT0:
e.numSentECT0++
case protocol.ECT1:
e.numSentECT1++
case protocol.ECNUnsupported:
if e.state != ecnStateFailed {
panic("didn't expect ECN to be unsupported")
}
default:
panic(fmt.Sprintf("sent packet with unexpected ECN marking: %s", ecn))
}
if e.state == ecnStateCapable && e.firstCapablePacket == protocol.InvalidPacketNumber {
e.firstCapablePacket = pn
}
if e.state != ecnStateTesting {
return
}
e.numSentTesting++
if e.firstTestingPacket == protocol.InvalidPacketNumber {
e.firstTestingPacket = pn
}
if e.numSentECT0+e.numSentECT1 >= numECNTestingPackets {
if e.tracer != nil && e.tracer.ECNStateUpdated != nil {
e.tracer.ECNStateUpdated(logging.ECNStateUnknown, logging.ECNTriggerNoTrigger)
}
e.state = ecnStateUnknown
e.lastTestingPacket = pn
}
}
func (e *ecnTracker) Mode() protocol.ECN {
switch e.state {
case ecnStateInitial:
if e.tracer != nil && e.tracer.ECNStateUpdated != nil {
e.tracer.ECNStateUpdated(logging.ECNStateTesting, logging.ECNTriggerNoTrigger)
}
e.state = ecnStateTesting
return e.Mode()
case ecnStateTesting, ecnStateCapable:
return protocol.ECT0
case ecnStateUnknown, ecnStateFailed:
return protocol.ECNNon
default:
panic(fmt.Sprintf("unknown ECN state: %d", e.state))
}
}
func (e *ecnTracker) LostPacket(pn protocol.PacketNumber) {
if e.state != ecnStateTesting && e.state != ecnStateUnknown {
return
}
if !e.isTestingPacket(pn) {
return
}
e.numLostTesting++
// Only proceed if we have sent all 10 testing packets.
if e.state != ecnStateUnknown {
return
}
if e.numLostTesting >= e.numSentTesting {
e.logger.Debugf("Disabling ECN. All testing packets were lost.")
if e.tracer != nil && e.tracer.ECNStateUpdated != nil {
e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedLostAllTestingPackets)
}
e.state = ecnStateFailed
return
}
// Path validation also fails if some testing packets are lost, and all other testing packets where CE-marked
e.failIfMangled()
}
// HandleNewlyAcked handles the ECN counts on an ACK frame.
// It must only be called for ACK frames that increase the largest acknowledged packet number,
// see section 13.4.2.1 of RFC 9000.
func (e *ecnTracker) HandleNewlyAcked(packets []*packet, ect0, ect1, ecnce int64) (congested bool) {
if e.state == ecnStateFailed {
return false
}
// ECN validation can fail if the received total count for either ECT(0) or ECT(1) exceeds
// the total number of packets sent with each corresponding ECT codepoint.
if ect0 > e.numSentECT0 || ect1 > e.numSentECT1 {
e.logger.Debugf("Disabling ECN. Received more ECT(0) / ECT(1) acknowledgements than packets sent.")
if e.tracer != nil && e.tracer.ECNStateUpdated != nil {
e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedMoreECNCountsThanSent)
}
e.state = ecnStateFailed
return false
}
// Count ECT0 and ECT1 marks that we used when sending the packets that are now being acknowledged.
var ackedECT0, ackedECT1 int64
for _, p := range packets {
//nolint:exhaustive // We only ever send ECT(0) and ECT(1).
switch e.ecnMarking(p.PacketNumber) {
case protocol.ECT0:
ackedECT0++
case protocol.ECT1:
ackedECT1++
}
}
// If an ACK frame newly acknowledges a packet that the endpoint sent with either the ECT(0) or ECT(1)
// codepoint set, ECN validation fails if the corresponding ECN counts are not present in the ACK frame.
// This check detects:
// * paths that bleach all ECN marks, and
// * peers that don't report any ECN counts
if (ackedECT0 > 0 || ackedECT1 > 0) && ect0 == 0 && ect1 == 0 && ecnce == 0 {
e.logger.Debugf("Disabling ECN. ECN-marked packet acknowledged, but no ECN counts on ACK frame.")
if e.tracer != nil && e.tracer.ECNStateUpdated != nil {
e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedNoECNCounts)
}
e.state = ecnStateFailed
return false
}
// Determine the increase in ECT0, ECT1 and ECNCE marks
newECT0 := ect0 - e.numAckedECT0
newECT1 := ect1 - e.numAckedECT1
newECNCE := ecnce - e.numAckedECNCE
// We're only processing ACKs that increase the Largest Acked.
// Therefore, the ECN counters should only ever increase.
// Any decrease means that the peer's counting logic is broken.
if newECT0 < 0 || newECT1 < 0 || newECNCE < 0 {
e.logger.Debugf("Disabling ECN. ECN counts decreased unexpectedly.")
if e.tracer != nil && e.tracer.ECNStateUpdated != nil {
e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedDecreasedECNCounts)
}
e.state = ecnStateFailed
return false
}
// ECN validation also fails if the sum of the increase in ECT(0) and ECN-CE counts is less than the number
// of newly acknowledged packets that were originally sent with an ECT(0) marking.
// This could be the result of (partial) bleaching.
if newECT0+newECNCE < ackedECT0 {
e.logger.Debugf("Disabling ECN. Received less ECT(0) + ECN-CE than packets sent with ECT(0).")
if e.tracer != nil && e.tracer.ECNStateUpdated != nil {
e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedTooFewECNCounts)
}
e.state = ecnStateFailed
return false
}
// Similarly, ECN validation fails if the sum of the increases to ECT(1) and ECN-CE counts is less than
// the number of newly acknowledged packets sent with an ECT(1) marking.
if newECT1+newECNCE < ackedECT1 {
e.logger.Debugf("Disabling ECN. Received less ECT(1) + ECN-CE than packets sent with ECT(1).")
if e.tracer != nil && e.tracer.ECNStateUpdated != nil {
e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedTooFewECNCounts)
}
e.state = ecnStateFailed
return false
}
// update our counters
e.numAckedECT0 = ect0
e.numAckedECT1 = ect1
e.numAckedECNCE = ecnce
// Detect mangling (a path remarking all ECN-marked testing packets as CE),
// once all 10 testing packets have been sent out.
if e.state == ecnStateUnknown {
e.failIfMangled()
if e.state == ecnStateFailed {
return false
}
}
if e.state == ecnStateTesting || e.state == ecnStateUnknown {
var ackedTestingPacket bool
for _, p := range packets {
if e.isTestingPacket(p.PacketNumber) {
ackedTestingPacket = true
break
}
}
// This check won't succeed if the path is mangling ECN-marks (i.e. rewrites all ECN-marked packets to CE).
if ackedTestingPacket && (newECT0 > 0 || newECT1 > 0) {
e.logger.Debugf("ECN capability confirmed.")
if e.tracer != nil && e.tracer.ECNStateUpdated != nil {
e.tracer.ECNStateUpdated(logging.ECNStateCapable, logging.ECNTriggerNoTrigger)
}
e.state = ecnStateCapable
}
}
// Don't trust CE marks before having confirmed ECN capability of the path.
// Otherwise, mangling would be misinterpreted as actual congestion.
return e.state == ecnStateCapable && newECNCE > 0
}
// failIfMangled fails ECN validation if all testing packets are lost or CE-marked.
func (e *ecnTracker) failIfMangled() {
numAckedECNCE := e.numAckedECNCE + int64(e.numLostTesting)
if e.numSentECT0+e.numSentECT1 > numAckedECNCE {
return
}
if e.tracer != nil && e.tracer.ECNStateUpdated != nil {
e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedManglingDetected)
}
e.state = ecnStateFailed
}
func (e *ecnTracker) ecnMarking(pn protocol.PacketNumber) protocol.ECN {
if pn < e.firstTestingPacket || e.firstTestingPacket == protocol.InvalidPacketNumber {
return protocol.ECNNon
}
if pn < e.lastTestingPacket || e.lastTestingPacket == protocol.InvalidPacketNumber {
return protocol.ECT0
}
if pn < e.firstCapablePacket || e.firstCapablePacket == protocol.InvalidPacketNumber {
return protocol.ECNNon
}
// We don't need to deal with the case when ECN validation fails,
// since we're ignoring any ECN counts reported in ACK frames in that case.
return protocol.ECT0
}
func (e *ecnTracker) isTestingPacket(pn protocol.PacketNumber) bool {
if e.firstTestingPacket == protocol.InvalidPacketNumber {
return false
}
return pn >= e.firstTestingPacket && (pn <= e.lastTestingPacket || e.lastTestingPacket == protocol.InvalidPacketNumber)
}

View File

@@ -0,0 +1,21 @@
package ackhandler
import (
"github.com/quic-go/quic-go/internal/wire"
)
// FrameHandler handles the acknowledgement and the loss of a frame.
type FrameHandler interface {
OnAcked(wire.Frame)
OnLost(wire.Frame)
}
type Frame struct {
Frame wire.Frame // nil if the frame has already been acknowledged in another packet
Handler FrameHandler
}
type StreamFrame struct {
Frame *wire.StreamFrame
Handler FrameHandler
}

View File

@@ -0,0 +1,53 @@
package ackhandler
import (
"time"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/wire"
)
// SentPacketHandler handles ACKs received for outgoing packets
type SentPacketHandler interface {
// SentPacket may modify the packet
SentPacket(t time.Time, pn, largestAcked protocol.PacketNumber, streamFrames []StreamFrame, frames []Frame, encLevel protocol.EncryptionLevel, ecn protocol.ECN, size protocol.ByteCount, isPathMTUProbePacket bool)
// ReceivedAck processes an ACK frame.
// It does not store a copy of the frame.
ReceivedAck(f *wire.AckFrame, encLevel protocol.EncryptionLevel, rcvTime time.Time) (bool /* 1-RTT packet acked */, error)
ReceivedBytes(protocol.ByteCount)
DropPackets(protocol.EncryptionLevel)
ResetForRetry(rcvTime time.Time) error
SetHandshakeConfirmed()
// The SendMode determines if and what kind of packets can be sent.
SendMode(now time.Time) SendMode
// TimeUntilSend is the time when the next packet should be sent.
// It is used for pacing packets.
TimeUntilSend() time.Time
SetMaxDatagramSize(count protocol.ByteCount)
// only to be called once the handshake is complete
QueueProbePacket(protocol.EncryptionLevel) bool /* was a packet queued */
ECNMode(isShortHeaderPacket bool) protocol.ECN // isShortHeaderPacket should only be true for non-coalesced 1-RTT packets
PeekPacketNumber(protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen)
PopPacketNumber(protocol.EncryptionLevel) protocol.PacketNumber
GetLossDetectionTimeout() time.Time
OnLossDetectionTimeout() error
}
type sentPacketTracker interface {
GetLowestPacketNotConfirmedAcked() protocol.PacketNumber
ReceivedPacket(protocol.EncryptionLevel)
}
// ReceivedPacketHandler handles ACKs needed to send for incoming packets
type ReceivedPacketHandler interface {
IsPotentiallyDuplicate(protocol.PacketNumber, protocol.EncryptionLevel) bool
ReceivedPacket(pn protocol.PacketNumber, ecn protocol.ECN, encLevel protocol.EncryptionLevel, rcvTime time.Time, ackEliciting bool) error
DropPackets(protocol.EncryptionLevel)
GetAlarmTimeout() time.Time
GetAckFrame(encLevel protocol.EncryptionLevel, onlyIfQueued bool) *wire.AckFrame
}

View File

@@ -0,0 +1,9 @@
//go:build gomock || generate
package ackhandler
//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package ackhandler -destination mock_sent_packet_tracker_test.go github.com/quic-go/quic-go/internal/ackhandler SentPacketTracker"
type SentPacketTracker = sentPacketTracker
//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package ackhandler -destination mock_ecn_handler_test.go github.com/quic-go/quic-go/internal/ackhandler ECNHandler"
type ECNHandler = ecnHandler

View File

@@ -0,0 +1,55 @@
package ackhandler
import (
"sync"
"time"
"github.com/quic-go/quic-go/internal/protocol"
)
// A Packet is a packet
type packet struct {
SendTime time.Time
PacketNumber protocol.PacketNumber
StreamFrames []StreamFrame
Frames []Frame
LargestAcked protocol.PacketNumber // InvalidPacketNumber if the packet doesn't contain an ACK
Length protocol.ByteCount
EncryptionLevel protocol.EncryptionLevel
IsPathMTUProbePacket bool // We don't report the loss of Path MTU probe packets to the congestion controller.
includedInBytesInFlight bool
declaredLost bool
skippedPacket bool
}
func (p *packet) outstanding() bool {
return !p.declaredLost && !p.skippedPacket && !p.IsPathMTUProbePacket
}
var packetPool = sync.Pool{New: func() any { return &packet{} }}
func getPacket() *packet {
p := packetPool.Get().(*packet)
p.PacketNumber = 0
p.StreamFrames = nil
p.Frames = nil
p.LargestAcked = 0
p.Length = 0
p.EncryptionLevel = protocol.EncryptionLevel(0)
p.SendTime = time.Time{}
p.IsPathMTUProbePacket = false
p.includedInBytesInFlight = false
p.declaredLost = false
p.skippedPacket = false
return p
}
// We currently only return Packets back into the pool when they're acknowledged (not when they're lost).
// This simplifies the code, and gives the vast majority of the performance benefit we can gain from using the pool.
func putPacket(p *packet) {
p.Frames = nil
p.StreamFrames = nil
packetPool.Put(p)
}

View File

@@ -0,0 +1,84 @@
package ackhandler
import (
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
)
type packetNumberGenerator interface {
Peek() protocol.PacketNumber
// Pop pops the packet number.
// It reports if the packet number (before the one just popped) was skipped.
// It never skips more than one packet number in a row.
Pop() (skipped bool, _ protocol.PacketNumber)
}
type sequentialPacketNumberGenerator struct {
next protocol.PacketNumber
}
var _ packetNumberGenerator = &sequentialPacketNumberGenerator{}
func newSequentialPacketNumberGenerator(initial protocol.PacketNumber) packetNumberGenerator {
return &sequentialPacketNumberGenerator{next: initial}
}
func (p *sequentialPacketNumberGenerator) Peek() protocol.PacketNumber {
return p.next
}
func (p *sequentialPacketNumberGenerator) Pop() (bool, protocol.PacketNumber) {
next := p.next
p.next++
return false, next
}
// The skippingPacketNumberGenerator generates the packet number for the next packet
// it randomly skips a packet number every averagePeriod packets (on average).
// It is guaranteed to never skip two consecutive packet numbers.
type skippingPacketNumberGenerator struct {
period protocol.PacketNumber
maxPeriod protocol.PacketNumber
next protocol.PacketNumber
nextToSkip protocol.PacketNumber
rng utils.Rand
}
var _ packetNumberGenerator = &skippingPacketNumberGenerator{}
func newSkippingPacketNumberGenerator(initial, initialPeriod, maxPeriod protocol.PacketNumber) packetNumberGenerator {
g := &skippingPacketNumberGenerator{
next: initial,
period: initialPeriod,
maxPeriod: maxPeriod,
}
g.generateNewSkip()
return g
}
func (p *skippingPacketNumberGenerator) Peek() protocol.PacketNumber {
if p.next == p.nextToSkip {
return p.next + 1
}
return p.next
}
func (p *skippingPacketNumberGenerator) Pop() (bool, protocol.PacketNumber) {
next := p.next
if p.next == p.nextToSkip {
next++
p.next += 2
p.generateNewSkip()
return true, next
}
p.next++ // generate a new packet number for the next packet
return false, next
}
func (p *skippingPacketNumberGenerator) generateNewSkip() {
// make sure that there are never two consecutive packet numbers that are skipped
p.nextToSkip = p.next + 3 + protocol.PacketNumber(p.rng.Int31n(int32(2*p.period)))
p.period = utils.Min(2*p.period, p.maxPeriod)
}

View File

@@ -0,0 +1,142 @@
package ackhandler
import (
"fmt"
"time"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/internal/wire"
)
type receivedPacketHandler struct {
sentPackets sentPacketTracker
initialPackets *receivedPacketTracker
handshakePackets *receivedPacketTracker
appDataPackets *receivedPacketTracker
lowest1RTTPacket protocol.PacketNumber
}
var _ ReceivedPacketHandler = &receivedPacketHandler{}
func newReceivedPacketHandler(
sentPackets sentPacketTracker,
rttStats *utils.RTTStats,
logger utils.Logger,
) ReceivedPacketHandler {
return &receivedPacketHandler{
sentPackets: sentPackets,
initialPackets: newReceivedPacketTracker(rttStats, logger),
handshakePackets: newReceivedPacketTracker(rttStats, logger),
appDataPackets: newReceivedPacketTracker(rttStats, logger),
lowest1RTTPacket: protocol.InvalidPacketNumber,
}
}
func (h *receivedPacketHandler) ReceivedPacket(
pn protocol.PacketNumber,
ecn protocol.ECN,
encLevel protocol.EncryptionLevel,
rcvTime time.Time,
ackEliciting bool,
) error {
h.sentPackets.ReceivedPacket(encLevel)
switch encLevel {
case protocol.EncryptionInitial:
return h.initialPackets.ReceivedPacket(pn, ecn, rcvTime, ackEliciting)
case protocol.EncryptionHandshake:
// The Handshake packet number space might already have been dropped as a result
// of processing the CRYPTO frame that was contained in this packet.
if h.handshakePackets == nil {
return nil
}
return h.handshakePackets.ReceivedPacket(pn, ecn, rcvTime, ackEliciting)
case protocol.Encryption0RTT:
if h.lowest1RTTPacket != protocol.InvalidPacketNumber && pn > h.lowest1RTTPacket {
return fmt.Errorf("received packet number %d on a 0-RTT packet after receiving %d on a 1-RTT packet", pn, h.lowest1RTTPacket)
}
return h.appDataPackets.ReceivedPacket(pn, ecn, rcvTime, ackEliciting)
case protocol.Encryption1RTT:
if h.lowest1RTTPacket == protocol.InvalidPacketNumber || pn < h.lowest1RTTPacket {
h.lowest1RTTPacket = pn
}
if err := h.appDataPackets.ReceivedPacket(pn, ecn, rcvTime, ackEliciting); err != nil {
return err
}
h.appDataPackets.IgnoreBelow(h.sentPackets.GetLowestPacketNotConfirmedAcked())
return nil
default:
panic(fmt.Sprintf("received packet with unknown encryption level: %s", encLevel))
}
}
func (h *receivedPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) {
//nolint:exhaustive // 1-RTT packet number space is never dropped.
switch encLevel {
case protocol.EncryptionInitial:
h.initialPackets = nil
case protocol.EncryptionHandshake:
h.handshakePackets = nil
case protocol.Encryption0RTT:
// Nothing to do here.
// If we are rejecting 0-RTT, no 0-RTT packets will have been decrypted.
default:
panic(fmt.Sprintf("Cannot drop keys for encryption level %s", encLevel))
}
}
func (h *receivedPacketHandler) GetAlarmTimeout() time.Time {
var initialAlarm, handshakeAlarm time.Time
if h.initialPackets != nil {
initialAlarm = h.initialPackets.GetAlarmTimeout()
}
if h.handshakePackets != nil {
handshakeAlarm = h.handshakePackets.GetAlarmTimeout()
}
oneRTTAlarm := h.appDataPackets.GetAlarmTimeout()
return utils.MinNonZeroTime(utils.MinNonZeroTime(initialAlarm, handshakeAlarm), oneRTTAlarm)
}
func (h *receivedPacketHandler) GetAckFrame(encLevel protocol.EncryptionLevel, onlyIfQueued bool) *wire.AckFrame {
var ack *wire.AckFrame
//nolint:exhaustive // 0-RTT packets can't contain ACK frames.
switch encLevel {
case protocol.EncryptionInitial:
if h.initialPackets != nil {
ack = h.initialPackets.GetAckFrame(onlyIfQueued)
}
case protocol.EncryptionHandshake:
if h.handshakePackets != nil {
ack = h.handshakePackets.GetAckFrame(onlyIfQueued)
}
case protocol.Encryption1RTT:
// 0-RTT packets can't contain ACK frames
return h.appDataPackets.GetAckFrame(onlyIfQueued)
default:
return nil
}
// For Initial and Handshake ACKs, the delay time is ignored by the receiver.
// Set it to 0 in order to save bytes.
if ack != nil {
ack.DelayTime = 0
}
return ack
}
func (h *receivedPacketHandler) IsPotentiallyDuplicate(pn protocol.PacketNumber, encLevel protocol.EncryptionLevel) bool {
switch encLevel {
case protocol.EncryptionInitial:
if h.initialPackets != nil {
return h.initialPackets.IsPotentiallyDuplicate(pn)
}
case protocol.EncryptionHandshake:
if h.handshakePackets != nil {
return h.handshakePackets.IsPotentiallyDuplicate(pn)
}
case protocol.Encryption0RTT, protocol.Encryption1RTT:
return h.appDataPackets.IsPotentiallyDuplicate(pn)
}
panic("unexpected encryption level")
}

View File

@@ -0,0 +1,151 @@
package ackhandler
import (
"sync"
"github.com/quic-go/quic-go/internal/protocol"
list "github.com/quic-go/quic-go/internal/utils/linkedlist"
"github.com/quic-go/quic-go/internal/wire"
)
// interval is an interval from one PacketNumber to the other
type interval struct {
Start protocol.PacketNumber
End protocol.PacketNumber
}
var intervalElementPool sync.Pool
func init() {
intervalElementPool = *list.NewPool[interval]()
}
// The receivedPacketHistory stores if a packet number has already been received.
// It generates ACK ranges which can be used to assemble an ACK frame.
// It does not store packet contents.
type receivedPacketHistory struct {
ranges *list.List[interval]
deletedBelow protocol.PacketNumber
}
func newReceivedPacketHistory() *receivedPacketHistory {
return &receivedPacketHistory{
ranges: list.NewWithPool[interval](&intervalElementPool),
}
}
// ReceivedPacket registers a packet with PacketNumber p and updates the ranges
func (h *receivedPacketHistory) ReceivedPacket(p protocol.PacketNumber) bool /* is a new packet (and not a duplicate / delayed packet) */ {
// ignore delayed packets, if we already deleted the range
if p < h.deletedBelow {
return false
}
isNew := h.addToRanges(p)
h.maybeDeleteOldRanges()
return isNew
}
func (h *receivedPacketHistory) addToRanges(p protocol.PacketNumber) bool /* is a new packet (and not a duplicate / delayed packet) */ {
if h.ranges.Len() == 0 {
h.ranges.PushBack(interval{Start: p, End: p})
return true
}
for el := h.ranges.Back(); el != nil; el = el.Prev() {
// p already included in an existing range. Nothing to do here
if p >= el.Value.Start && p <= el.Value.End {
return false
}
if el.Value.End == p-1 { // extend a range at the end
el.Value.End = p
return true
}
if el.Value.Start == p+1 { // extend a range at the beginning
el.Value.Start = p
prev := el.Prev()
if prev != nil && prev.Value.End+1 == el.Value.Start { // merge two ranges
prev.Value.End = el.Value.End
h.ranges.Remove(el)
}
return true
}
// create a new range at the end
if p > el.Value.End {
h.ranges.InsertAfter(interval{Start: p, End: p}, el)
return true
}
}
// create a new range at the beginning
h.ranges.InsertBefore(interval{Start: p, End: p}, h.ranges.Front())
return true
}
// Delete old ranges, if we're tracking more than 500 of them.
// This is a DoS defense against a peer that sends us too many gaps.
func (h *receivedPacketHistory) maybeDeleteOldRanges() {
for h.ranges.Len() > protocol.MaxNumAckRanges {
h.ranges.Remove(h.ranges.Front())
}
}
// DeleteBelow deletes all entries below (but not including) p
func (h *receivedPacketHistory) DeleteBelow(p protocol.PacketNumber) {
if p < h.deletedBelow {
return
}
h.deletedBelow = p
nextEl := h.ranges.Front()
for el := h.ranges.Front(); nextEl != nil; el = nextEl {
nextEl = el.Next()
if el.Value.End < p { // delete a whole range
h.ranges.Remove(el)
} else if p > el.Value.Start && p <= el.Value.End {
el.Value.Start = p
return
} else { // no ranges affected. Nothing to do
return
}
}
}
// AppendAckRanges appends to a slice of all AckRanges that can be used in an AckFrame
func (h *receivedPacketHistory) AppendAckRanges(ackRanges []wire.AckRange) []wire.AckRange {
if h.ranges.Len() > 0 {
for el := h.ranges.Back(); el != nil; el = el.Prev() {
ackRanges = append(ackRanges, wire.AckRange{Smallest: el.Value.Start, Largest: el.Value.End})
}
}
return ackRanges
}
func (h *receivedPacketHistory) GetHighestAckRange() wire.AckRange {
ackRange := wire.AckRange{}
if h.ranges.Len() > 0 {
r := h.ranges.Back().Value
ackRange.Smallest = r.Start
ackRange.Largest = r.End
}
return ackRange
}
func (h *receivedPacketHistory) IsPotentiallyDuplicate(p protocol.PacketNumber) bool {
if p < h.deletedBelow {
return true
}
for el := h.ranges.Back(); el != nil; el = el.Prev() {
if p > el.Value.End {
return false
}
if p <= el.Value.End && p >= el.Value.Start {
return true
}
}
return false
}

View File

@@ -0,0 +1,196 @@
package ackhandler
import (
"fmt"
"time"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/internal/wire"
)
// number of ack-eliciting packets received before sending an ack.
const packetsBeforeAck = 2
type receivedPacketTracker struct {
largestObserved protocol.PacketNumber
ignoreBelow protocol.PacketNumber
largestObservedRcvdTime time.Time
ect0, ect1, ecnce uint64
packetHistory *receivedPacketHistory
maxAckDelay time.Duration
rttStats *utils.RTTStats
hasNewAck bool // true as soon as we received an ack-eliciting new packet
ackQueued bool // true once we received more than 2 (or later in the connection 10) ack-eliciting packets
ackElicitingPacketsReceivedSinceLastAck int
ackAlarm time.Time
lastAck *wire.AckFrame
logger utils.Logger
}
func newReceivedPacketTracker(
rttStats *utils.RTTStats,
logger utils.Logger,
) *receivedPacketTracker {
return &receivedPacketTracker{
packetHistory: newReceivedPacketHistory(),
maxAckDelay: protocol.MaxAckDelay,
rttStats: rttStats,
logger: logger,
}
}
func (h *receivedPacketTracker) ReceivedPacket(pn protocol.PacketNumber, ecn protocol.ECN, rcvTime time.Time, ackEliciting bool) error {
if isNew := h.packetHistory.ReceivedPacket(pn); !isNew {
return fmt.Errorf("recevedPacketTracker BUG: ReceivedPacket called for old / duplicate packet %d", pn)
}
isMissing := h.isMissing(pn)
if pn >= h.largestObserved {
h.largestObserved = pn
h.largestObservedRcvdTime = rcvTime
}
if ackEliciting {
h.hasNewAck = true
}
if ackEliciting {
h.maybeQueueACK(pn, rcvTime, isMissing)
}
//nolint:exhaustive // Only need to count ECT(0), ECT(1) and ECNCE.
switch ecn {
case protocol.ECT0:
h.ect0++
case protocol.ECT1:
h.ect1++
case protocol.ECNCE:
h.ecnce++
}
return nil
}
// IgnoreBelow sets a lower limit for acknowledging packets.
// Packets with packet numbers smaller than p will not be acked.
func (h *receivedPacketTracker) IgnoreBelow(pn protocol.PacketNumber) {
if pn <= h.ignoreBelow {
return
}
h.ignoreBelow = pn
h.packetHistory.DeleteBelow(pn)
if h.logger.Debug() {
h.logger.Debugf("\tIgnoring all packets below %d.", pn)
}
}
// isMissing says if a packet was reported missing in the last ACK.
func (h *receivedPacketTracker) isMissing(p protocol.PacketNumber) bool {
if h.lastAck == nil || p < h.ignoreBelow {
return false
}
return p < h.lastAck.LargestAcked() && !h.lastAck.AcksPacket(p)
}
func (h *receivedPacketTracker) hasNewMissingPackets() bool {
if h.lastAck == nil {
return false
}
highestRange := h.packetHistory.GetHighestAckRange()
return highestRange.Smallest > h.lastAck.LargestAcked()+1 && highestRange.Len() == 1
}
// maybeQueueACK queues an ACK, if necessary.
func (h *receivedPacketTracker) maybeQueueACK(pn protocol.PacketNumber, rcvTime time.Time, wasMissing bool) {
// always acknowledge the first packet
if h.lastAck == nil {
if !h.ackQueued {
h.logger.Debugf("\tQueueing ACK because the first packet should be acknowledged.")
}
h.ackQueued = true
return
}
if h.ackQueued {
return
}
h.ackElicitingPacketsReceivedSinceLastAck++
// Send an ACK if this packet was reported missing in an ACK sent before.
// Ack decimation with reordering relies on the timer to send an ACK, but if
// missing packets we reported in the previous ack, send an ACK immediately.
if wasMissing {
if h.logger.Debug() {
h.logger.Debugf("\tQueueing ACK because packet %d was missing before.", pn)
}
h.ackQueued = true
}
// send an ACK every 2 ack-eliciting packets
if h.ackElicitingPacketsReceivedSinceLastAck >= packetsBeforeAck {
if h.logger.Debug() {
h.logger.Debugf("\tQueueing ACK because packet %d packets were received after the last ACK (using initial threshold: %d).", h.ackElicitingPacketsReceivedSinceLastAck, packetsBeforeAck)
}
h.ackQueued = true
} else if h.ackAlarm.IsZero() {
if h.logger.Debug() {
h.logger.Debugf("\tSetting ACK timer to max ack delay: %s", h.maxAckDelay)
}
h.ackAlarm = rcvTime.Add(h.maxAckDelay)
}
// Queue an ACK if there are new missing packets to report.
if h.hasNewMissingPackets() {
h.logger.Debugf("\tQueuing ACK because there's a new missing packet to report.")
h.ackQueued = true
}
if h.ackQueued {
// cancel the ack alarm
h.ackAlarm = time.Time{}
}
}
func (h *receivedPacketTracker) GetAckFrame(onlyIfQueued bool) *wire.AckFrame {
if !h.hasNewAck {
return nil
}
now := time.Now()
if onlyIfQueued {
if !h.ackQueued && (h.ackAlarm.IsZero() || h.ackAlarm.After(now)) {
return nil
}
if h.logger.Debug() && !h.ackQueued && !h.ackAlarm.IsZero() {
h.logger.Debugf("Sending ACK because the ACK timer expired.")
}
}
// This function always returns the same ACK frame struct, filled with the most recent values.
ack := h.lastAck
if ack == nil {
ack = &wire.AckFrame{}
}
ack.Reset()
ack.DelayTime = utils.Max(0, now.Sub(h.largestObservedRcvdTime))
ack.ECT0 = h.ect0
ack.ECT1 = h.ect1
ack.ECNCE = h.ecnce
ack.AckRanges = h.packetHistory.AppendAckRanges(ack.AckRanges)
h.lastAck = ack
h.ackAlarm = time.Time{}
h.ackQueued = false
h.hasNewAck = false
h.ackElicitingPacketsReceivedSinceLastAck = 0
return ack
}
func (h *receivedPacketTracker) GetAlarmTimeout() time.Time { return h.ackAlarm }
func (h *receivedPacketTracker) IsPotentiallyDuplicate(pn protocol.PacketNumber) bool {
return h.packetHistory.IsPotentiallyDuplicate(pn)
}

View File

@@ -0,0 +1,46 @@
package ackhandler
import "fmt"
// The SendMode says what kind of packets can be sent.
type SendMode uint8
const (
// SendNone means that no packets should be sent
SendNone SendMode = iota
// SendAck means an ACK-only packet should be sent
SendAck
// SendPTOInitial means that an Initial probe packet should be sent
SendPTOInitial
// SendPTOHandshake means that a Handshake probe packet should be sent
SendPTOHandshake
// SendPTOAppData means that an Application data probe packet should be sent
SendPTOAppData
// SendPacingLimited means that the pacer doesn't allow sending of a packet right now,
// but will do in a little while.
// The timestamp when sending is allowed again can be obtained via the SentPacketHandler.TimeUntilSend.
SendPacingLimited
// SendAny means that any packet should be sent
SendAny
)
func (s SendMode) String() string {
switch s {
case SendNone:
return "none"
case SendAck:
return "ack"
case SendPTOInitial:
return "pto (Initial)"
case SendPTOHandshake:
return "pto (Handshake)"
case SendPTOAppData:
return "pto (Application Data)"
case SendAny:
return "any"
case SendPacingLimited:
return "pacing limited"
default:
return fmt.Sprintf("invalid send mode: %d", s)
}
}

View File

@@ -0,0 +1,928 @@
package ackhandler
import (
"errors"
"fmt"
"time"
"github.com/quic-go/quic-go/internal/congestion"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/internal/wire"
"github.com/quic-go/quic-go/logging"
)
const (
// Maximum reordering in time space before time based loss detection considers a packet lost.
// Specified as an RTT multiplier.
timeThreshold = 9.0 / 8
// Maximum reordering in packets before packet threshold loss detection considers a packet lost.
packetThreshold = 3
// Before validating the client's address, the server won't send more than 3x bytes than it received.
amplificationFactor = 3
// We use Retry packets to derive an RTT estimate. Make sure we don't set the RTT to a super low value yet.
minRTTAfterRetry = 5 * time.Millisecond
// The PTO duration uses exponential backoff, but is truncated to a maximum value, as allowed by RFC 8961, section 4.4.
maxPTODuration = 60 * time.Second
)
type packetNumberSpace struct {
history *sentPacketHistory
pns packetNumberGenerator
lossTime time.Time
lastAckElicitingPacketTime time.Time
largestAcked protocol.PacketNumber
largestSent protocol.PacketNumber
}
func newPacketNumberSpace(initialPN protocol.PacketNumber, skipPNs bool) *packetNumberSpace {
var pns packetNumberGenerator
if skipPNs {
pns = newSkippingPacketNumberGenerator(initialPN, protocol.SkipPacketInitialPeriod, protocol.SkipPacketMaxPeriod)
} else {
pns = newSequentialPacketNumberGenerator(initialPN)
}
return &packetNumberSpace{
history: newSentPacketHistory(),
pns: pns,
largestSent: protocol.InvalidPacketNumber,
largestAcked: protocol.InvalidPacketNumber,
}
}
type sentPacketHandler struct {
initialPackets *packetNumberSpace
handshakePackets *packetNumberSpace
appDataPackets *packetNumberSpace
// Do we know that the peer completed address validation yet?
// Always true for the server.
peerCompletedAddressValidation bool
bytesReceived protocol.ByteCount
bytesSent protocol.ByteCount
// Have we validated the peer's address yet?
// Always true for the client.
peerAddressValidated bool
handshakeConfirmed bool
// lowestNotConfirmedAcked is the lowest packet number that we sent an ACK for, but haven't received confirmation, that this ACK actually arrived
// example: we send an ACK for packets 90-100 with packet number 20
// once we receive an ACK from the peer for packet 20, the lowestNotConfirmedAcked is 101
// Only applies to the application-data packet number space.
lowestNotConfirmedAcked protocol.PacketNumber
ackedPackets []*packet // to avoid allocations in detectAndRemoveAckedPackets
bytesInFlight protocol.ByteCount
congestion congestion.SendAlgorithmWithDebugInfos
rttStats *utils.RTTStats
// The number of times a PTO has been sent without receiving an ack.
ptoCount uint32
ptoMode SendMode
// The number of PTO probe packets that should be sent.
// Only applies to the application-data packet number space.
numProbesToSend int
// The alarm timeout
alarm time.Time
enableECN bool
ecnTracker ecnHandler
perspective protocol.Perspective
tracer *logging.ConnectionTracer
logger utils.Logger
}
var (
_ SentPacketHandler = &sentPacketHandler{}
_ sentPacketTracker = &sentPacketHandler{}
)
// clientAddressValidated indicates whether the address was validated beforehand by an address validation token.
// If the address was validated, the amplification limit doesn't apply. It has no effect for a client.
func newSentPacketHandler(
initialPN protocol.PacketNumber,
initialMaxDatagramSize protocol.ByteCount,
rttStats *utils.RTTStats,
clientAddressValidated bool,
enableECN bool,
pers protocol.Perspective,
tracer *logging.ConnectionTracer,
logger utils.Logger,
) *sentPacketHandler {
congestion := congestion.NewCubicSender(
congestion.DefaultClock{},
rttStats,
initialMaxDatagramSize,
true, // use Reno
tracer,
)
h := &sentPacketHandler{
peerCompletedAddressValidation: pers == protocol.PerspectiveServer,
peerAddressValidated: pers == protocol.PerspectiveClient || clientAddressValidated,
initialPackets: newPacketNumberSpace(initialPN, false),
handshakePackets: newPacketNumberSpace(0, false),
appDataPackets: newPacketNumberSpace(0, true),
rttStats: rttStats,
congestion: congestion,
perspective: pers,
tracer: tracer,
logger: logger,
}
if enableECN {
h.enableECN = true
h.ecnTracker = newECNTracker(logger, tracer)
}
return h
}
func (h *sentPacketHandler) removeFromBytesInFlight(p *packet) {
if p.includedInBytesInFlight {
if p.Length > h.bytesInFlight {
panic("negative bytes_in_flight")
}
h.bytesInFlight -= p.Length
p.includedInBytesInFlight = false
}
}
func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) {
// The server won't await address validation after the handshake is confirmed.
// This applies even if we didn't receive an ACK for a Handshake packet.
if h.perspective == protocol.PerspectiveClient && encLevel == protocol.EncryptionHandshake {
h.peerCompletedAddressValidation = true
}
// remove outstanding packets from bytes_in_flight
if encLevel == protocol.EncryptionInitial || encLevel == protocol.EncryptionHandshake {
pnSpace := h.getPacketNumberSpace(encLevel)
// We might already have dropped this packet number space.
if pnSpace == nil {
return
}
pnSpace.history.Iterate(func(p *packet) (bool, error) {
h.removeFromBytesInFlight(p)
return true, nil
})
}
// drop the packet history
//nolint:exhaustive // Not every packet number space can be dropped.
switch encLevel {
case protocol.EncryptionInitial:
h.initialPackets = nil
case protocol.EncryptionHandshake:
h.handshakePackets = nil
case protocol.Encryption0RTT:
// This function is only called when 0-RTT is rejected,
// and not when the client drops 0-RTT keys when the handshake completes.
// When 0-RTT is rejected, all application data sent so far becomes invalid.
// Delete the packets from the history and remove them from bytes_in_flight.
h.appDataPackets.history.Iterate(func(p *packet) (bool, error) {
if p.EncryptionLevel != protocol.Encryption0RTT && !p.skippedPacket {
return false, nil
}
h.removeFromBytesInFlight(p)
h.appDataPackets.history.Remove(p.PacketNumber)
return true, nil
})
default:
panic(fmt.Sprintf("Cannot drop keys for encryption level %s", encLevel))
}
if h.tracer != nil && h.tracer.UpdatedPTOCount != nil && h.ptoCount != 0 {
h.tracer.UpdatedPTOCount(0)
}
h.ptoCount = 0
h.numProbesToSend = 0
h.ptoMode = SendNone
h.setLossDetectionTimer()
}
func (h *sentPacketHandler) ReceivedBytes(n protocol.ByteCount) {
wasAmplificationLimit := h.isAmplificationLimited()
h.bytesReceived += n
if wasAmplificationLimit && !h.isAmplificationLimited() {
h.setLossDetectionTimer()
}
}
func (h *sentPacketHandler) ReceivedPacket(l protocol.EncryptionLevel) {
if h.perspective == protocol.PerspectiveServer && l == protocol.EncryptionHandshake && !h.peerAddressValidated {
h.peerAddressValidated = true
h.setLossDetectionTimer()
}
}
func (h *sentPacketHandler) packetsInFlight() int {
packetsInFlight := h.appDataPackets.history.Len()
if h.handshakePackets != nil {
packetsInFlight += h.handshakePackets.history.Len()
}
if h.initialPackets != nil {
packetsInFlight += h.initialPackets.history.Len()
}
return packetsInFlight
}
func (h *sentPacketHandler) SentPacket(
t time.Time,
pn, largestAcked protocol.PacketNumber,
streamFrames []StreamFrame,
frames []Frame,
encLevel protocol.EncryptionLevel,
ecn protocol.ECN,
size protocol.ByteCount,
isPathMTUProbePacket bool,
) {
h.bytesSent += size
pnSpace := h.getPacketNumberSpace(encLevel)
if h.logger.Debug() && pnSpace.history.HasOutstandingPackets() {
for p := utils.Max(0, pnSpace.largestSent+1); p < pn; p++ {
h.logger.Debugf("Skipping packet number %d", p)
}
}
pnSpace.largestSent = pn
isAckEliciting := len(streamFrames) > 0 || len(frames) > 0
if isAckEliciting {
pnSpace.lastAckElicitingPacketTime = t
h.bytesInFlight += size
if h.numProbesToSend > 0 {
h.numProbesToSend--
}
}
h.congestion.OnPacketSent(t, h.bytesInFlight, pn, size, isAckEliciting)
if encLevel == protocol.Encryption1RTT && h.ecnTracker != nil {
h.ecnTracker.SentPacket(pn, ecn)
}
if !isAckEliciting {
pnSpace.history.SentNonAckElicitingPacket(pn)
if !h.peerCompletedAddressValidation {
h.setLossDetectionTimer()
}
return
}
p := getPacket()
p.SendTime = t
p.PacketNumber = pn
p.EncryptionLevel = encLevel
p.Length = size
p.LargestAcked = largestAcked
p.StreamFrames = streamFrames
p.Frames = frames
p.IsPathMTUProbePacket = isPathMTUProbePacket
p.includedInBytesInFlight = true
pnSpace.history.SentAckElicitingPacket(p)
if h.tracer != nil && h.tracer.UpdatedMetrics != nil {
h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight())
}
h.setLossDetectionTimer()
}
func (h *sentPacketHandler) getPacketNumberSpace(encLevel protocol.EncryptionLevel) *packetNumberSpace {
switch encLevel {
case protocol.EncryptionInitial:
return h.initialPackets
case protocol.EncryptionHandshake:
return h.handshakePackets
case protocol.Encryption0RTT, protocol.Encryption1RTT:
return h.appDataPackets
default:
panic("invalid packet number space")
}
}
func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.EncryptionLevel, rcvTime time.Time) (bool /* contained 1-RTT packet */, error) {
pnSpace := h.getPacketNumberSpace(encLevel)
largestAcked := ack.LargestAcked()
if largestAcked > pnSpace.largestSent {
return false, &qerr.TransportError{
ErrorCode: qerr.ProtocolViolation,
ErrorMessage: "received ACK for an unsent packet",
}
}
// Servers complete address validation when a protected packet is received.
if h.perspective == protocol.PerspectiveClient && !h.peerCompletedAddressValidation &&
(encLevel == protocol.EncryptionHandshake || encLevel == protocol.Encryption1RTT) {
h.peerCompletedAddressValidation = true
h.logger.Debugf("Peer doesn't await address validation any longer.")
// Make sure that the timer is reset, even if this ACK doesn't acknowledge any (ack-eliciting) packets.
h.setLossDetectionTimer()
}
priorInFlight := h.bytesInFlight
ackedPackets, err := h.detectAndRemoveAckedPackets(ack, encLevel)
if err != nil || len(ackedPackets) == 0 {
return false, err
}
// update the RTT, if the largest acked is newly acknowledged
if len(ackedPackets) > 0 {
if p := ackedPackets[len(ackedPackets)-1]; p.PacketNumber == ack.LargestAcked() {
// don't use the ack delay for Initial and Handshake packets
var ackDelay time.Duration
if encLevel == protocol.Encryption1RTT {
ackDelay = utils.Min(ack.DelayTime, h.rttStats.MaxAckDelay())
}
h.rttStats.UpdateRTT(rcvTime.Sub(p.SendTime), ackDelay, rcvTime)
if h.logger.Debug() {
h.logger.Debugf("\tupdated RTT: %s (σ: %s)", h.rttStats.SmoothedRTT(), h.rttStats.MeanDeviation())
}
h.congestion.MaybeExitSlowStart()
}
}
// Only inform the ECN tracker about new 1-RTT ACKs if the ACK increases the largest acked.
if encLevel == protocol.Encryption1RTT && h.ecnTracker != nil && largestAcked > pnSpace.largestAcked {
congested := h.ecnTracker.HandleNewlyAcked(ackedPackets, int64(ack.ECT0), int64(ack.ECT1), int64(ack.ECNCE))
if congested {
h.congestion.OnCongestionEvent(largestAcked, 0, priorInFlight)
}
}
pnSpace.largestAcked = utils.Max(pnSpace.largestAcked, largestAcked)
if err := h.detectLostPackets(rcvTime, encLevel); err != nil {
return false, err
}
var acked1RTTPacket bool
for _, p := range ackedPackets {
if p.includedInBytesInFlight && !p.declaredLost {
h.congestion.OnPacketAcked(p.PacketNumber, p.Length, priorInFlight, rcvTime)
}
if p.EncryptionLevel == protocol.Encryption1RTT {
acked1RTTPacket = true
}
h.removeFromBytesInFlight(p)
putPacket(p)
}
// After this point, we must not use ackedPackets any longer!
// We've already returned the buffers.
ackedPackets = nil //nolint:ineffassign // This is just to be on the safe side.
// Reset the pto_count unless the client is unsure if the server has validated the client's address.
if h.peerCompletedAddressValidation {
if h.tracer != nil && h.tracer.UpdatedPTOCount != nil && h.ptoCount != 0 {
h.tracer.UpdatedPTOCount(0)
}
h.ptoCount = 0
}
h.numProbesToSend = 0
if h.tracer != nil && h.tracer.UpdatedMetrics != nil {
h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight())
}
h.setLossDetectionTimer()
return acked1RTTPacket, nil
}
func (h *sentPacketHandler) GetLowestPacketNotConfirmedAcked() protocol.PacketNumber {
return h.lowestNotConfirmedAcked
}
// Packets are returned in ascending packet number order.
func (h *sentPacketHandler) detectAndRemoveAckedPackets(ack *wire.AckFrame, encLevel protocol.EncryptionLevel) ([]*packet, error) {
pnSpace := h.getPacketNumberSpace(encLevel)
h.ackedPackets = h.ackedPackets[:0]
ackRangeIndex := 0
lowestAcked := ack.LowestAcked()
largestAcked := ack.LargestAcked()
err := pnSpace.history.Iterate(func(p *packet) (bool, error) {
// Ignore packets below the lowest acked
if p.PacketNumber < lowestAcked {
return true, nil
}
// Break after largest acked is reached
if p.PacketNumber > largestAcked {
return false, nil
}
if ack.HasMissingRanges() {
ackRange := ack.AckRanges[len(ack.AckRanges)-1-ackRangeIndex]
for p.PacketNumber > ackRange.Largest && ackRangeIndex < len(ack.AckRanges)-1 {
ackRangeIndex++
ackRange = ack.AckRanges[len(ack.AckRanges)-1-ackRangeIndex]
}
if p.PacketNumber < ackRange.Smallest { // packet not contained in ACK range
return true, nil
}
if p.PacketNumber > ackRange.Largest {
return false, fmt.Errorf("BUG: ackhandler would have acked wrong packet %d, while evaluating range %d -> %d", p.PacketNumber, ackRange.Smallest, ackRange.Largest)
}
}
if p.skippedPacket {
return false, &qerr.TransportError{
ErrorCode: qerr.ProtocolViolation,
ErrorMessage: fmt.Sprintf("received an ACK for skipped packet number: %d (%s)", p.PacketNumber, encLevel),
}
}
h.ackedPackets = append(h.ackedPackets, p)
return true, nil
})
if h.logger.Debug() && len(h.ackedPackets) > 0 {
pns := make([]protocol.PacketNumber, len(h.ackedPackets))
for i, p := range h.ackedPackets {
pns[i] = p.PacketNumber
}
h.logger.Debugf("\tnewly acked packets (%d): %d", len(pns), pns)
}
for _, p := range h.ackedPackets {
if p.LargestAcked != protocol.InvalidPacketNumber && encLevel == protocol.Encryption1RTT {
h.lowestNotConfirmedAcked = utils.Max(h.lowestNotConfirmedAcked, p.LargestAcked+1)
}
for _, f := range p.Frames {
if f.Handler != nil {
f.Handler.OnAcked(f.Frame)
}
}
for _, f := range p.StreamFrames {
if f.Handler != nil {
f.Handler.OnAcked(f.Frame)
}
}
if err := pnSpace.history.Remove(p.PacketNumber); err != nil {
return nil, err
}
if h.tracer != nil && h.tracer.AcknowledgedPacket != nil {
h.tracer.AcknowledgedPacket(encLevel, p.PacketNumber)
}
}
return h.ackedPackets, err
}
func (h *sentPacketHandler) getLossTimeAndSpace() (time.Time, protocol.EncryptionLevel) {
var encLevel protocol.EncryptionLevel
var lossTime time.Time
if h.initialPackets != nil {
lossTime = h.initialPackets.lossTime
encLevel = protocol.EncryptionInitial
}
if h.handshakePackets != nil && (lossTime.IsZero() || (!h.handshakePackets.lossTime.IsZero() && h.handshakePackets.lossTime.Before(lossTime))) {
lossTime = h.handshakePackets.lossTime
encLevel = protocol.EncryptionHandshake
}
if lossTime.IsZero() || (!h.appDataPackets.lossTime.IsZero() && h.appDataPackets.lossTime.Before(lossTime)) {
lossTime = h.appDataPackets.lossTime
encLevel = protocol.Encryption1RTT
}
return lossTime, encLevel
}
func (h *sentPacketHandler) getScaledPTO(includeMaxAckDelay bool) time.Duration {
pto := h.rttStats.PTO(includeMaxAckDelay) << h.ptoCount
if pto > maxPTODuration || pto <= 0 {
return maxPTODuration
}
return pto
}
// same logic as getLossTimeAndSpace, but for lastAckElicitingPacketTime instead of lossTime
func (h *sentPacketHandler) getPTOTimeAndSpace() (pto time.Time, encLevel protocol.EncryptionLevel, ok bool) {
// We only send application data probe packets once the handshake is confirmed,
// because before that, we don't have the keys to decrypt ACKs sent in 1-RTT packets.
if !h.handshakeConfirmed && !h.hasOutstandingCryptoPackets() {
if h.peerCompletedAddressValidation {
return
}
t := time.Now().Add(h.getScaledPTO(false))
if h.initialPackets != nil {
return t, protocol.EncryptionInitial, true
}
return t, protocol.EncryptionHandshake, true
}
if h.initialPackets != nil {
encLevel = protocol.EncryptionInitial
if t := h.initialPackets.lastAckElicitingPacketTime; !t.IsZero() {
pto = t.Add(h.getScaledPTO(false))
}
}
if h.handshakePackets != nil && !h.handshakePackets.lastAckElicitingPacketTime.IsZero() {
t := h.handshakePackets.lastAckElicitingPacketTime.Add(h.getScaledPTO(false))
if pto.IsZero() || (!t.IsZero() && t.Before(pto)) {
pto = t
encLevel = protocol.EncryptionHandshake
}
}
if h.handshakeConfirmed && !h.appDataPackets.lastAckElicitingPacketTime.IsZero() {
t := h.appDataPackets.lastAckElicitingPacketTime.Add(h.getScaledPTO(true))
if pto.IsZero() || (!t.IsZero() && t.Before(pto)) {
pto = t
encLevel = protocol.Encryption1RTT
}
}
return pto, encLevel, true
}
func (h *sentPacketHandler) hasOutstandingCryptoPackets() bool {
if h.initialPackets != nil && h.initialPackets.history.HasOutstandingPackets() {
return true
}
if h.handshakePackets != nil && h.handshakePackets.history.HasOutstandingPackets() {
return true
}
return false
}
func (h *sentPacketHandler) hasOutstandingPackets() bool {
return h.appDataPackets.history.HasOutstandingPackets() || h.hasOutstandingCryptoPackets()
}
func (h *sentPacketHandler) setLossDetectionTimer() {
oldAlarm := h.alarm // only needed in case tracing is enabled
lossTime, encLevel := h.getLossTimeAndSpace()
if !lossTime.IsZero() {
// Early retransmit timer or time loss detection.
h.alarm = lossTime
if h.tracer != nil && h.tracer.SetLossTimer != nil && h.alarm != oldAlarm {
h.tracer.SetLossTimer(logging.TimerTypeACK, encLevel, h.alarm)
}
return
}
// Cancel the alarm if amplification limited.
if h.isAmplificationLimited() {
h.alarm = time.Time{}
if !oldAlarm.IsZero() {
h.logger.Debugf("Canceling loss detection timer. Amplification limited.")
if h.tracer != nil && h.tracer.LossTimerCanceled != nil {
h.tracer.LossTimerCanceled()
}
}
return
}
// Cancel the alarm if no packets are outstanding
if !h.hasOutstandingPackets() && h.peerCompletedAddressValidation {
h.alarm = time.Time{}
if !oldAlarm.IsZero() {
h.logger.Debugf("Canceling loss detection timer. No packets in flight.")
if h.tracer != nil && h.tracer.LossTimerCanceled != nil {
h.tracer.LossTimerCanceled()
}
}
return
}
// PTO alarm
ptoTime, encLevel, ok := h.getPTOTimeAndSpace()
if !ok {
if !oldAlarm.IsZero() {
h.alarm = time.Time{}
h.logger.Debugf("Canceling loss detection timer. No PTO needed..")
if h.tracer != nil && h.tracer.LossTimerCanceled != nil {
h.tracer.LossTimerCanceled()
}
}
return
}
h.alarm = ptoTime
if h.tracer != nil && h.tracer.SetLossTimer != nil && h.alarm != oldAlarm {
h.tracer.SetLossTimer(logging.TimerTypePTO, encLevel, h.alarm)
}
}
func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.EncryptionLevel) error {
pnSpace := h.getPacketNumberSpace(encLevel)
pnSpace.lossTime = time.Time{}
maxRTT := float64(utils.Max(h.rttStats.LatestRTT(), h.rttStats.SmoothedRTT()))
lossDelay := time.Duration(timeThreshold * maxRTT)
// Minimum time of granularity before packets are deemed lost.
lossDelay = utils.Max(lossDelay, protocol.TimerGranularity)
// Packets sent before this time are deemed lost.
lostSendTime := now.Add(-lossDelay)
priorInFlight := h.bytesInFlight
return pnSpace.history.Iterate(func(p *packet) (bool, error) {
if p.PacketNumber > pnSpace.largestAcked {
return false, nil
}
var packetLost bool
if p.SendTime.Before(lostSendTime) {
packetLost = true
if !p.skippedPacket {
if h.logger.Debug() {
h.logger.Debugf("\tlost packet %d (time threshold)", p.PacketNumber)
}
if h.tracer != nil && h.tracer.LostPacket != nil {
h.tracer.LostPacket(p.EncryptionLevel, p.PacketNumber, logging.PacketLossTimeThreshold)
}
}
} else if pnSpace.largestAcked >= p.PacketNumber+packetThreshold {
packetLost = true
if !p.skippedPacket {
if h.logger.Debug() {
h.logger.Debugf("\tlost packet %d (reordering threshold)", p.PacketNumber)
}
if h.tracer != nil && h.tracer.LostPacket != nil {
h.tracer.LostPacket(p.EncryptionLevel, p.PacketNumber, logging.PacketLossReorderingThreshold)
}
}
} else if pnSpace.lossTime.IsZero() {
// Note: This conditional is only entered once per call
lossTime := p.SendTime.Add(lossDelay)
if h.logger.Debug() {
h.logger.Debugf("\tsetting loss timer for packet %d (%s) to %s (in %s)", p.PacketNumber, encLevel, lossDelay, lossTime)
}
pnSpace.lossTime = lossTime
}
if packetLost {
pnSpace.history.DeclareLost(p.PacketNumber)
if !p.skippedPacket {
// the bytes in flight need to be reduced no matter if the frames in this packet will be retransmitted
h.removeFromBytesInFlight(p)
h.queueFramesForRetransmission(p)
if !p.IsPathMTUProbePacket {
h.congestion.OnCongestionEvent(p.PacketNumber, p.Length, priorInFlight)
}
if encLevel == protocol.Encryption1RTT && h.ecnTracker != nil {
h.ecnTracker.LostPacket(p.PacketNumber)
}
}
}
return true, nil
})
}
func (h *sentPacketHandler) OnLossDetectionTimeout() error {
defer h.setLossDetectionTimer()
earliestLossTime, encLevel := h.getLossTimeAndSpace()
if !earliestLossTime.IsZero() {
if h.logger.Debug() {
h.logger.Debugf("Loss detection alarm fired in loss timer mode. Loss time: %s", earliestLossTime)
}
if h.tracer != nil && h.tracer.LossTimerExpired != nil {
h.tracer.LossTimerExpired(logging.TimerTypeACK, encLevel)
}
// Early retransmit or time loss detection
return h.detectLostPackets(time.Now(), encLevel)
}
// PTO
// When all outstanding are acknowledged, the alarm is canceled in
// setLossDetectionTimer. This doesn't reset the timer in the session though.
// When OnAlarm is called, we therefore need to make sure that there are
// actually packets outstanding.
if h.bytesInFlight == 0 && !h.peerCompletedAddressValidation {
h.ptoCount++
h.numProbesToSend++
if h.initialPackets != nil {
h.ptoMode = SendPTOInitial
} else if h.handshakePackets != nil {
h.ptoMode = SendPTOHandshake
} else {
return errors.New("sentPacketHandler BUG: PTO fired, but bytes_in_flight is 0 and Initial and Handshake already dropped")
}
return nil
}
_, encLevel, ok := h.getPTOTimeAndSpace()
if !ok {
return nil
}
if ps := h.getPacketNumberSpace(encLevel); !ps.history.HasOutstandingPackets() && !h.peerCompletedAddressValidation {
return nil
}
h.ptoCount++
if h.logger.Debug() {
h.logger.Debugf("Loss detection alarm for %s fired in PTO mode. PTO count: %d", encLevel, h.ptoCount)
}
if h.tracer != nil {
if h.tracer.LossTimerExpired != nil {
h.tracer.LossTimerExpired(logging.TimerTypePTO, encLevel)
}
if h.tracer.UpdatedPTOCount != nil {
h.tracer.UpdatedPTOCount(h.ptoCount)
}
}
h.numProbesToSend += 2
//nolint:exhaustive // We never arm a PTO timer for 0-RTT packets.
switch encLevel {
case protocol.EncryptionInitial:
h.ptoMode = SendPTOInitial
case protocol.EncryptionHandshake:
h.ptoMode = SendPTOHandshake
case protocol.Encryption1RTT:
// skip a packet number in order to elicit an immediate ACK
pn := h.PopPacketNumber(protocol.Encryption1RTT)
h.getPacketNumberSpace(protocol.Encryption1RTT).history.SkippedPacket(pn)
h.ptoMode = SendPTOAppData
default:
return fmt.Errorf("PTO timer in unexpected encryption level: %s", encLevel)
}
return nil
}
func (h *sentPacketHandler) GetLossDetectionTimeout() time.Time {
return h.alarm
}
func (h *sentPacketHandler) ECNMode(isShortHeaderPacket bool) protocol.ECN {
if !h.enableECN {
return protocol.ECNUnsupported
}
if !isShortHeaderPacket {
return protocol.ECNNon
}
return h.ecnTracker.Mode()
}
func (h *sentPacketHandler) PeekPacketNumber(encLevel protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen) {
pnSpace := h.getPacketNumberSpace(encLevel)
pn := pnSpace.pns.Peek()
// See section 17.1 of RFC 9000.
return pn, protocol.GetPacketNumberLengthForHeader(pn, pnSpace.largestAcked)
}
func (h *sentPacketHandler) PopPacketNumber(encLevel protocol.EncryptionLevel) protocol.PacketNumber {
pnSpace := h.getPacketNumberSpace(encLevel)
skipped, pn := pnSpace.pns.Pop()
if skipped {
skippedPN := pn - 1
pnSpace.history.SkippedPacket(skippedPN)
if h.logger.Debug() {
h.logger.Debugf("Skipping packet number %d", skippedPN)
}
}
return pn
}
func (h *sentPacketHandler) SendMode(now time.Time) SendMode {
numTrackedPackets := h.appDataPackets.history.Len()
if h.initialPackets != nil {
numTrackedPackets += h.initialPackets.history.Len()
}
if h.handshakePackets != nil {
numTrackedPackets += h.handshakePackets.history.Len()
}
if h.isAmplificationLimited() {
h.logger.Debugf("Amplification window limited. Received %d bytes, already sent out %d bytes", h.bytesReceived, h.bytesSent)
return SendNone
}
// Don't send any packets if we're keeping track of the maximum number of packets.
// Note that since MaxOutstandingSentPackets is smaller than MaxTrackedSentPackets,
// we will stop sending out new data when reaching MaxOutstandingSentPackets,
// but still allow sending of retransmissions and ACKs.
if numTrackedPackets >= protocol.MaxTrackedSentPackets {
if h.logger.Debug() {
h.logger.Debugf("Limited by the number of tracked packets: tracking %d packets, maximum %d", numTrackedPackets, protocol.MaxTrackedSentPackets)
}
return SendNone
}
if h.numProbesToSend > 0 {
return h.ptoMode
}
// Only send ACKs if we're congestion limited.
if !h.congestion.CanSend(h.bytesInFlight) {
if h.logger.Debug() {
h.logger.Debugf("Congestion limited: bytes in flight %d, window %d", h.bytesInFlight, h.congestion.GetCongestionWindow())
}
return SendAck
}
if numTrackedPackets >= protocol.MaxOutstandingSentPackets {
if h.logger.Debug() {
h.logger.Debugf("Max outstanding limited: tracking %d packets, maximum: %d", numTrackedPackets, protocol.MaxOutstandingSentPackets)
}
return SendAck
}
if !h.congestion.HasPacingBudget(now) {
return SendPacingLimited
}
return SendAny
}
func (h *sentPacketHandler) TimeUntilSend() time.Time {
return h.congestion.TimeUntilSend(h.bytesInFlight)
}
func (h *sentPacketHandler) SetMaxDatagramSize(s protocol.ByteCount) {
h.congestion.SetMaxDatagramSize(s)
}
func (h *sentPacketHandler) isAmplificationLimited() bool {
if h.peerAddressValidated {
return false
}
return h.bytesSent >= amplificationFactor*h.bytesReceived
}
func (h *sentPacketHandler) QueueProbePacket(encLevel protocol.EncryptionLevel) bool {
pnSpace := h.getPacketNumberSpace(encLevel)
p := pnSpace.history.FirstOutstanding()
if p == nil {
return false
}
h.queueFramesForRetransmission(p)
// TODO: don't declare the packet lost here.
// Keep track of acknowledged frames instead.
h.removeFromBytesInFlight(p)
pnSpace.history.DeclareLost(p.PacketNumber)
return true
}
func (h *sentPacketHandler) queueFramesForRetransmission(p *packet) {
if len(p.Frames) == 0 && len(p.StreamFrames) == 0 {
panic("no frames")
}
for _, f := range p.Frames {
if f.Handler != nil {
f.Handler.OnLost(f.Frame)
}
}
for _, f := range p.StreamFrames {
if f.Handler != nil {
f.Handler.OnLost(f.Frame)
}
}
p.StreamFrames = nil
p.Frames = nil
}
func (h *sentPacketHandler) ResetForRetry(now time.Time) error {
h.bytesInFlight = 0
var firstPacketSendTime time.Time
h.initialPackets.history.Iterate(func(p *packet) (bool, error) {
if firstPacketSendTime.IsZero() {
firstPacketSendTime = p.SendTime
}
if p.declaredLost || p.skippedPacket {
return true, nil
}
h.queueFramesForRetransmission(p)
return true, nil
})
// All application data packets sent at this point are 0-RTT packets.
// In the case of a Retry, we can assume that the server dropped all of them.
h.appDataPackets.history.Iterate(func(p *packet) (bool, error) {
if !p.declaredLost && !p.skippedPacket {
h.queueFramesForRetransmission(p)
}
return true, nil
})
// Only use the Retry to estimate the RTT if we didn't send any retransmission for the Initial.
// Otherwise, we don't know which Initial the Retry was sent in response to.
if h.ptoCount == 0 {
// Don't set the RTT to a value lower than 5ms here.
h.rttStats.UpdateRTT(utils.Max(minRTTAfterRetry, now.Sub(firstPacketSendTime)), 0, now)
if h.logger.Debug() {
h.logger.Debugf("\tupdated RTT: %s (σ: %s)", h.rttStats.SmoothedRTT(), h.rttStats.MeanDeviation())
}
if h.tracer != nil && h.tracer.UpdatedMetrics != nil {
h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight())
}
}
h.initialPackets = newPacketNumberSpace(h.initialPackets.pns.Peek(), false)
h.appDataPackets = newPacketNumberSpace(h.appDataPackets.pns.Peek(), true)
oldAlarm := h.alarm
h.alarm = time.Time{}
if h.tracer != nil {
if h.tracer.UpdatedPTOCount != nil {
h.tracer.UpdatedPTOCount(0)
}
if !oldAlarm.IsZero() && h.tracer.LossTimerCanceled != nil {
h.tracer.LossTimerCanceled()
}
}
h.ptoCount = 0
return nil
}
func (h *sentPacketHandler) SetHandshakeConfirmed() {
if h.initialPackets != nil {
panic("didn't drop initial correctly")
}
if h.handshakePackets != nil {
panic("didn't drop handshake correctly")
}
h.handshakeConfirmed = true
// We don't send PTOs for application data packets before the handshake completes.
// Make sure the timer is armed now, if necessary.
h.setLossDetectionTimer()
}

View File

@@ -0,0 +1,177 @@
package ackhandler
import (
"fmt"
"github.com/quic-go/quic-go/internal/protocol"
)
type sentPacketHistory struct {
packets []*packet
numOutstanding int
highestPacketNumber protocol.PacketNumber
}
func newSentPacketHistory() *sentPacketHistory {
return &sentPacketHistory{
packets: make([]*packet, 0, 32),
highestPacketNumber: protocol.InvalidPacketNumber,
}
}
func (h *sentPacketHistory) checkSequentialPacketNumberUse(pn protocol.PacketNumber) {
if h.highestPacketNumber != protocol.InvalidPacketNumber {
if pn != h.highestPacketNumber+1 {
panic("non-sequential packet number use")
}
}
}
func (h *sentPacketHistory) SkippedPacket(pn protocol.PacketNumber) {
h.checkSequentialPacketNumberUse(pn)
h.highestPacketNumber = pn
h.packets = append(h.packets, &packet{
PacketNumber: pn,
skippedPacket: true,
})
}
func (h *sentPacketHistory) SentNonAckElicitingPacket(pn protocol.PacketNumber) {
h.checkSequentialPacketNumberUse(pn)
h.highestPacketNumber = pn
if len(h.packets) > 0 {
h.packets = append(h.packets, nil)
}
}
func (h *sentPacketHistory) SentAckElicitingPacket(p *packet) {
h.checkSequentialPacketNumberUse(p.PacketNumber)
h.highestPacketNumber = p.PacketNumber
h.packets = append(h.packets, p)
if p.outstanding() {
h.numOutstanding++
}
}
// Iterate iterates through all packets.
func (h *sentPacketHistory) Iterate(cb func(*packet) (cont bool, err error)) error {
for _, p := range h.packets {
if p == nil {
continue
}
cont, err := cb(p)
if err != nil {
return err
}
if !cont {
return nil
}
}
return nil
}
// FirstOutstanding returns the first outstanding packet.
func (h *sentPacketHistory) FirstOutstanding() *packet {
if !h.HasOutstandingPackets() {
return nil
}
for _, p := range h.packets {
if p != nil && p.outstanding() {
return p
}
}
return nil
}
func (h *sentPacketHistory) Len() int {
return len(h.packets)
}
func (h *sentPacketHistory) Remove(pn protocol.PacketNumber) error {
idx, ok := h.getIndex(pn)
if !ok {
return fmt.Errorf("packet %d not found in sent packet history", pn)
}
p := h.packets[idx]
if p.outstanding() {
h.numOutstanding--
if h.numOutstanding < 0 {
panic("negative number of outstanding packets")
}
}
h.packets[idx] = nil
// clean up all skipped packets directly before this packet number
for idx > 0 {
idx--
p := h.packets[idx]
if p == nil || !p.skippedPacket {
break
}
h.packets[idx] = nil
}
if idx == 0 {
h.cleanupStart()
}
if len(h.packets) > 0 && h.packets[0] == nil {
panic("remove failed")
}
return nil
}
// getIndex gets the index of packet p in the packets slice.
func (h *sentPacketHistory) getIndex(p protocol.PacketNumber) (int, bool) {
if len(h.packets) == 0 {
return 0, false
}
first := h.packets[0].PacketNumber
if p < first {
return 0, false
}
index := int(p - first)
if index > len(h.packets)-1 {
return 0, false
}
return index, true
}
func (h *sentPacketHistory) HasOutstandingPackets() bool {
return h.numOutstanding > 0
}
// delete all nil entries at the beginning of the packets slice
func (h *sentPacketHistory) cleanupStart() {
for i, p := range h.packets {
if p != nil {
h.packets = h.packets[i:]
return
}
}
h.packets = h.packets[:0]
}
func (h *sentPacketHistory) LowestPacketNumber() protocol.PacketNumber {
if len(h.packets) == 0 {
return protocol.InvalidPacketNumber
}
return h.packets[0].PacketNumber
}
func (h *sentPacketHistory) DeclareLost(pn protocol.PacketNumber) {
idx, ok := h.getIndex(pn)
if !ok {
return
}
p := h.packets[idx]
if p.outstanding() {
h.numOutstanding--
if h.numOutstanding < 0 {
panic("negative number of outstanding packets")
}
}
h.packets[idx] = nil
if idx == 0 {
h.cleanupStart()
}
}

View File

@@ -0,0 +1,25 @@
package congestion
import (
"math"
"time"
"github.com/quic-go/quic-go/internal/protocol"
)
// Bandwidth of a connection
type Bandwidth uint64
const infBandwidth Bandwidth = math.MaxUint64
const (
// BitsPerSecond is 1 bit per second
BitsPerSecond Bandwidth = 1
// BytesPerSecond is 1 byte per second
BytesPerSecond = 8 * BitsPerSecond
)
// BandwidthFromDelta calculates the bandwidth from a number of bytes and a time delta
func BandwidthFromDelta(bytes protocol.ByteCount, delta time.Duration) Bandwidth {
return Bandwidth(bytes) * Bandwidth(time.Second) / Bandwidth(delta) * BytesPerSecond
}

View File

@@ -0,0 +1,18 @@
package congestion
import "time"
// A Clock returns the current time
type Clock interface {
Now() time.Time
}
// DefaultClock implements the Clock interface using the Go stdlib clock.
type DefaultClock struct{}
var _ Clock = DefaultClock{}
// Now gets the current time
func (DefaultClock) Now() time.Time {
return time.Now()
}

View File

@@ -0,0 +1,214 @@
package congestion
import (
"math"
"time"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
)
// This cubic implementation is based on the one found in Chromiums's QUIC
// implementation, in the files net/quic/congestion_control/cubic.{hh,cc}.
// Constants based on TCP defaults.
// The following constants are in 2^10 fractions of a second instead of ms to
// allow a 10 shift right to divide.
// 1024*1024^3 (first 1024 is from 0.100^3)
// where 0.100 is 100 ms which is the scaling round trip time.
const (
cubeScale = 40
cubeCongestionWindowScale = 410
cubeFactor protocol.ByteCount = 1 << cubeScale / cubeCongestionWindowScale / maxDatagramSize
// TODO: when re-enabling cubic, make sure to use the actual packet size here
maxDatagramSize = protocol.ByteCount(protocol.InitialPacketSizeIPv4)
)
const defaultNumConnections = 1
// Default Cubic backoff factor
const beta float32 = 0.7
// Additional backoff factor when loss occurs in the concave part of the Cubic
// curve. This additional backoff factor is expected to give up bandwidth to
// new concurrent flows and speed up convergence.
const betaLastMax float32 = 0.85
// Cubic implements the cubic algorithm from TCP
type Cubic struct {
clock Clock
// Number of connections to simulate.
numConnections int
// Time when this cycle started, after last loss event.
epoch time.Time
// Max congestion window used just before last loss event.
// Note: to improve fairness to other streams an additional back off is
// applied to this value if the new value is below our latest value.
lastMaxCongestionWindow protocol.ByteCount
// Number of acked bytes since the cycle started (epoch).
ackedBytesCount protocol.ByteCount
// TCP Reno equivalent congestion window in packets.
estimatedTCPcongestionWindow protocol.ByteCount
// Origin point of cubic function.
originPointCongestionWindow protocol.ByteCount
// Time to origin point of cubic function in 2^10 fractions of a second.
timeToOriginPoint uint32
// Last congestion window in packets computed by cubic function.
lastTargetCongestionWindow protocol.ByteCount
}
// NewCubic returns a new Cubic instance
func NewCubic(clock Clock) *Cubic {
c := &Cubic{
clock: clock,
numConnections: defaultNumConnections,
}
c.Reset()
return c
}
// Reset is called after a timeout to reset the cubic state
func (c *Cubic) Reset() {
c.epoch = time.Time{}
c.lastMaxCongestionWindow = 0
c.ackedBytesCount = 0
c.estimatedTCPcongestionWindow = 0
c.originPointCongestionWindow = 0
c.timeToOriginPoint = 0
c.lastTargetCongestionWindow = 0
}
func (c *Cubic) alpha() float32 {
// TCPFriendly alpha is described in Section 3.3 of the CUBIC paper. Note that
// beta here is a cwnd multiplier, and is equal to 1-beta from the paper.
// We derive the equivalent alpha for an N-connection emulation as:
b := c.beta()
return 3 * float32(c.numConnections) * float32(c.numConnections) * (1 - b) / (1 + b)
}
func (c *Cubic) beta() float32 {
// kNConnectionBeta is the backoff factor after loss for our N-connection
// emulation, which emulates the effective backoff of an ensemble of N
// TCP-Reno connections on a single loss event. The effective multiplier is
// computed as:
return (float32(c.numConnections) - 1 + beta) / float32(c.numConnections)
}
func (c *Cubic) betaLastMax() float32 {
// betaLastMax is the additional backoff factor after loss for our
// N-connection emulation, which emulates the additional backoff of
// an ensemble of N TCP-Reno connections on a single loss event. The
// effective multiplier is computed as:
return (float32(c.numConnections) - 1 + betaLastMax) / float32(c.numConnections)
}
// OnApplicationLimited is called on ack arrival when sender is unable to use
// the available congestion window. Resets Cubic state during quiescence.
func (c *Cubic) OnApplicationLimited() {
// When sender is not using the available congestion window, the window does
// not grow. But to be RTT-independent, Cubic assumes that the sender has been
// using the entire window during the time since the beginning of the current
// "epoch" (the end of the last loss recovery period). Since
// application-limited periods break this assumption, we reset the epoch when
// in such a period. This reset effectively freezes congestion window growth
// through application-limited periods and allows Cubic growth to continue
// when the entire window is being used.
c.epoch = time.Time{}
}
// CongestionWindowAfterPacketLoss computes a new congestion window to use after
// a loss event. Returns the new congestion window in packets. The new
// congestion window is a multiplicative decrease of our current window.
func (c *Cubic) CongestionWindowAfterPacketLoss(currentCongestionWindow protocol.ByteCount) protocol.ByteCount {
if currentCongestionWindow+maxDatagramSize < c.lastMaxCongestionWindow {
// We never reached the old max, so assume we are competing with another
// flow. Use our extra back off factor to allow the other flow to go up.
c.lastMaxCongestionWindow = protocol.ByteCount(c.betaLastMax() * float32(currentCongestionWindow))
} else {
c.lastMaxCongestionWindow = currentCongestionWindow
}
c.epoch = time.Time{} // Reset time.
return protocol.ByteCount(float32(currentCongestionWindow) * c.beta())
}
// CongestionWindowAfterAck computes a new congestion window to use after a received ACK.
// Returns the new congestion window in packets. The new congestion window
// follows a cubic function that depends on the time passed since last
// packet loss.
func (c *Cubic) CongestionWindowAfterAck(
ackedBytes protocol.ByteCount,
currentCongestionWindow protocol.ByteCount,
delayMin time.Duration,
eventTime time.Time,
) protocol.ByteCount {
c.ackedBytesCount += ackedBytes
if c.epoch.IsZero() {
// First ACK after a loss event.
c.epoch = eventTime // Start of epoch.
c.ackedBytesCount = ackedBytes // Reset count.
// Reset estimated_tcp_congestion_window_ to be in sync with cubic.
c.estimatedTCPcongestionWindow = currentCongestionWindow
if c.lastMaxCongestionWindow <= currentCongestionWindow {
c.timeToOriginPoint = 0
c.originPointCongestionWindow = currentCongestionWindow
} else {
c.timeToOriginPoint = uint32(math.Cbrt(float64(cubeFactor * (c.lastMaxCongestionWindow - currentCongestionWindow))))
c.originPointCongestionWindow = c.lastMaxCongestionWindow
}
}
// Change the time unit from microseconds to 2^10 fractions per second. Take
// the round trip time in account. This is done to allow us to use shift as a
// divide operator.
elapsedTime := int64(eventTime.Add(delayMin).Sub(c.epoch)/time.Microsecond) << 10 / (1000 * 1000)
// Right-shifts of negative, signed numbers have implementation-dependent
// behavior, so force the offset to be positive, as is done in the kernel.
offset := int64(c.timeToOriginPoint) - elapsedTime
if offset < 0 {
offset = -offset
}
deltaCongestionWindow := protocol.ByteCount(cubeCongestionWindowScale*offset*offset*offset) * maxDatagramSize >> cubeScale
var targetCongestionWindow protocol.ByteCount
if elapsedTime > int64(c.timeToOriginPoint) {
targetCongestionWindow = c.originPointCongestionWindow + deltaCongestionWindow
} else {
targetCongestionWindow = c.originPointCongestionWindow - deltaCongestionWindow
}
// Limit the CWND increase to half the acked bytes.
targetCongestionWindow = utils.Min(targetCongestionWindow, currentCongestionWindow+c.ackedBytesCount/2)
// Increase the window by approximately Alpha * 1 MSS of bytes every
// time we ack an estimated tcp window of bytes. For small
// congestion windows (less than 25), the formula below will
// increase slightly slower than linearly per estimated tcp window
// of bytes.
c.estimatedTCPcongestionWindow += protocol.ByteCount(float32(c.ackedBytesCount) * c.alpha() * float32(maxDatagramSize) / float32(c.estimatedTCPcongestionWindow))
c.ackedBytesCount = 0
// We have a new cubic congestion window.
c.lastTargetCongestionWindow = targetCongestionWindow
// Compute target congestion_window based on cubic target and estimated TCP
// congestion_window, use highest (fastest).
if targetCongestionWindow < c.estimatedTCPcongestionWindow {
targetCongestionWindow = c.estimatedTCPcongestionWindow
}
return targetCongestionWindow
}
// SetNumConnections sets the number of emulated connections
func (c *Cubic) SetNumConnections(n int) {
c.numConnections = n
}

View File

@@ -0,0 +1,316 @@
package congestion
import (
"fmt"
"time"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/logging"
)
const (
// maxDatagramSize is the default maximum packet size used in the Linux TCP implementation.
// Used in QUIC for congestion window computations in bytes.
initialMaxDatagramSize = protocol.ByteCount(protocol.InitialPacketSizeIPv4)
maxBurstPackets = 3
renoBeta = 0.7 // Reno backoff factor.
minCongestionWindowPackets = 2
initialCongestionWindow = 32
)
type cubicSender struct {
hybridSlowStart HybridSlowStart
rttStats *utils.RTTStats
cubic *Cubic
pacer *pacer
clock Clock
reno bool
// Track the largest packet that has been sent.
largestSentPacketNumber protocol.PacketNumber
// Track the largest packet that has been acked.
largestAckedPacketNumber protocol.PacketNumber
// Track the largest packet number outstanding when a CWND cutback occurs.
largestSentAtLastCutback protocol.PacketNumber
// Whether the last loss event caused us to exit slowstart.
// Used for stats collection of slowstartPacketsLost
lastCutbackExitedSlowstart bool
// Congestion window in bytes.
congestionWindow protocol.ByteCount
// Slow start congestion window in bytes, aka ssthresh.
slowStartThreshold protocol.ByteCount
// ACK counter for the Reno implementation.
numAckedPackets uint64
initialCongestionWindow protocol.ByteCount
initialMaxCongestionWindow protocol.ByteCount
maxDatagramSize protocol.ByteCount
lastState logging.CongestionState
tracer *logging.ConnectionTracer
}
var (
_ SendAlgorithm = &cubicSender{}
_ SendAlgorithmWithDebugInfos = &cubicSender{}
)
// NewCubicSender makes a new cubic sender
func NewCubicSender(
clock Clock,
rttStats *utils.RTTStats,
initialMaxDatagramSize protocol.ByteCount,
reno bool,
tracer *logging.ConnectionTracer,
) *cubicSender {
return newCubicSender(
clock,
rttStats,
reno,
initialMaxDatagramSize,
initialCongestionWindow*initialMaxDatagramSize,
protocol.MaxCongestionWindowPackets*initialMaxDatagramSize,
tracer,
)
}
func newCubicSender(
clock Clock,
rttStats *utils.RTTStats,
reno bool,
initialMaxDatagramSize,
initialCongestionWindow,
initialMaxCongestionWindow protocol.ByteCount,
tracer *logging.ConnectionTracer,
) *cubicSender {
c := &cubicSender{
rttStats: rttStats,
largestSentPacketNumber: protocol.InvalidPacketNumber,
largestAckedPacketNumber: protocol.InvalidPacketNumber,
largestSentAtLastCutback: protocol.InvalidPacketNumber,
initialCongestionWindow: initialCongestionWindow,
initialMaxCongestionWindow: initialMaxCongestionWindow,
congestionWindow: initialCongestionWindow,
slowStartThreshold: protocol.MaxByteCount,
cubic: NewCubic(clock),
clock: clock,
reno: reno,
tracer: tracer,
maxDatagramSize: initialMaxDatagramSize,
}
c.pacer = newPacer(c.BandwidthEstimate)
if c.tracer != nil && c.tracer.UpdatedCongestionState != nil {
c.lastState = logging.CongestionStateSlowStart
c.tracer.UpdatedCongestionState(logging.CongestionStateSlowStart)
}
return c
}
// TimeUntilSend returns when the next packet should be sent.
func (c *cubicSender) TimeUntilSend(_ protocol.ByteCount) time.Time {
return c.pacer.TimeUntilSend()
}
func (c *cubicSender) HasPacingBudget(now time.Time) bool {
return c.pacer.Budget(now) >= c.maxDatagramSize
}
func (c *cubicSender) maxCongestionWindow() protocol.ByteCount {
return c.maxDatagramSize * protocol.MaxCongestionWindowPackets
}
func (c *cubicSender) minCongestionWindow() protocol.ByteCount {
return c.maxDatagramSize * minCongestionWindowPackets
}
func (c *cubicSender) OnPacketSent(
sentTime time.Time,
_ protocol.ByteCount,
packetNumber protocol.PacketNumber,
bytes protocol.ByteCount,
isRetransmittable bool,
) {
c.pacer.SentPacket(sentTime, bytes)
if !isRetransmittable {
return
}
c.largestSentPacketNumber = packetNumber
c.hybridSlowStart.OnPacketSent(packetNumber)
}
func (c *cubicSender) CanSend(bytesInFlight protocol.ByteCount) bool {
return bytesInFlight < c.GetCongestionWindow()
}
func (c *cubicSender) InRecovery() bool {
return c.largestAckedPacketNumber != protocol.InvalidPacketNumber && c.largestAckedPacketNumber <= c.largestSentAtLastCutback
}
func (c *cubicSender) InSlowStart() bool {
return c.GetCongestionWindow() < c.slowStartThreshold
}
func (c *cubicSender) GetCongestionWindow() protocol.ByteCount {
return c.congestionWindow
}
func (c *cubicSender) MaybeExitSlowStart() {
if c.InSlowStart() &&
c.hybridSlowStart.ShouldExitSlowStart(c.rttStats.LatestRTT(), c.rttStats.MinRTT(), c.GetCongestionWindow()/c.maxDatagramSize) {
// exit slow start
c.slowStartThreshold = c.congestionWindow
c.maybeTraceStateChange(logging.CongestionStateCongestionAvoidance)
}
}
func (c *cubicSender) OnPacketAcked(
ackedPacketNumber protocol.PacketNumber,
ackedBytes protocol.ByteCount,
priorInFlight protocol.ByteCount,
eventTime time.Time,
) {
c.largestAckedPacketNumber = utils.Max(ackedPacketNumber, c.largestAckedPacketNumber)
if c.InRecovery() {
return
}
c.maybeIncreaseCwnd(ackedPacketNumber, ackedBytes, priorInFlight, eventTime)
if c.InSlowStart() {
c.hybridSlowStart.OnPacketAcked(ackedPacketNumber)
}
}
func (c *cubicSender) OnCongestionEvent(packetNumber protocol.PacketNumber, lostBytes, priorInFlight protocol.ByteCount) {
// TCP NewReno (RFC6582) says that once a loss occurs, any losses in packets
// already sent should be treated as a single loss event, since it's expected.
if packetNumber <= c.largestSentAtLastCutback {
return
}
c.lastCutbackExitedSlowstart = c.InSlowStart()
c.maybeTraceStateChange(logging.CongestionStateRecovery)
if c.reno {
c.congestionWindow = protocol.ByteCount(float64(c.congestionWindow) * renoBeta)
} else {
c.congestionWindow = c.cubic.CongestionWindowAfterPacketLoss(c.congestionWindow)
}
if minCwnd := c.minCongestionWindow(); c.congestionWindow < minCwnd {
c.congestionWindow = minCwnd
}
c.slowStartThreshold = c.congestionWindow
c.largestSentAtLastCutback = c.largestSentPacketNumber
// reset packet count from congestion avoidance mode. We start
// counting again when we're out of recovery.
c.numAckedPackets = 0
}
// Called when we receive an ack. Normal TCP tracks how many packets one ack
// represents, but quic has a separate ack for each packet.
func (c *cubicSender) maybeIncreaseCwnd(
_ protocol.PacketNumber,
ackedBytes protocol.ByteCount,
priorInFlight protocol.ByteCount,
eventTime time.Time,
) {
// Do not increase the congestion window unless the sender is close to using
// the current window.
if !c.isCwndLimited(priorInFlight) {
c.cubic.OnApplicationLimited()
c.maybeTraceStateChange(logging.CongestionStateApplicationLimited)
return
}
if c.congestionWindow >= c.maxCongestionWindow() {
return
}
if c.InSlowStart() {
// TCP slow start, exponential growth, increase by one for each ACK.
c.congestionWindow += c.maxDatagramSize
c.maybeTraceStateChange(logging.CongestionStateSlowStart)
return
}
// Congestion avoidance
c.maybeTraceStateChange(logging.CongestionStateCongestionAvoidance)
if c.reno {
// Classic Reno congestion avoidance.
c.numAckedPackets++
if c.numAckedPackets >= uint64(c.congestionWindow/c.maxDatagramSize) {
c.congestionWindow += c.maxDatagramSize
c.numAckedPackets = 0
}
} else {
c.congestionWindow = utils.Min(c.maxCongestionWindow(), c.cubic.CongestionWindowAfterAck(ackedBytes, c.congestionWindow, c.rttStats.MinRTT(), eventTime))
}
}
func (c *cubicSender) isCwndLimited(bytesInFlight protocol.ByteCount) bool {
congestionWindow := c.GetCongestionWindow()
if bytesInFlight >= congestionWindow {
return true
}
availableBytes := congestionWindow - bytesInFlight
slowStartLimited := c.InSlowStart() && bytesInFlight > congestionWindow/2
return slowStartLimited || availableBytes <= maxBurstPackets*c.maxDatagramSize
}
// BandwidthEstimate returns the current bandwidth estimate
func (c *cubicSender) BandwidthEstimate() Bandwidth {
srtt := c.rttStats.SmoothedRTT()
if srtt == 0 {
// If we haven't measured an rtt, the bandwidth estimate is unknown.
return infBandwidth
}
return BandwidthFromDelta(c.GetCongestionWindow(), srtt)
}
// OnRetransmissionTimeout is called on an retransmission timeout
func (c *cubicSender) OnRetransmissionTimeout(packetsRetransmitted bool) {
c.largestSentAtLastCutback = protocol.InvalidPacketNumber
if !packetsRetransmitted {
return
}
c.hybridSlowStart.Restart()
c.cubic.Reset()
c.slowStartThreshold = c.congestionWindow / 2
c.congestionWindow = c.minCongestionWindow()
}
// OnConnectionMigration is called when the connection is migrated (?)
func (c *cubicSender) OnConnectionMigration() {
c.hybridSlowStart.Restart()
c.largestSentPacketNumber = protocol.InvalidPacketNumber
c.largestAckedPacketNumber = protocol.InvalidPacketNumber
c.largestSentAtLastCutback = protocol.InvalidPacketNumber
c.lastCutbackExitedSlowstart = false
c.cubic.Reset()
c.numAckedPackets = 0
c.congestionWindow = c.initialCongestionWindow
c.slowStartThreshold = c.initialMaxCongestionWindow
}
func (c *cubicSender) maybeTraceStateChange(new logging.CongestionState) {
if c.tracer == nil || c.tracer.UpdatedCongestionState == nil || new == c.lastState {
return
}
c.tracer.UpdatedCongestionState(new)
c.lastState = new
}
func (c *cubicSender) SetMaxDatagramSize(s protocol.ByteCount) {
if s < c.maxDatagramSize {
panic(fmt.Sprintf("congestion BUG: decreased max datagram size from %d to %d", c.maxDatagramSize, s))
}
cwndIsMinCwnd := c.congestionWindow == c.minCongestionWindow()
c.maxDatagramSize = s
if cwndIsMinCwnd {
c.congestionWindow = c.minCongestionWindow()
}
c.pacer.SetMaxDatagramSize(s)
}

View File

@@ -0,0 +1,113 @@
package congestion
import (
"time"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
)
// Note(pwestin): the magic clamping numbers come from the original code in
// tcp_cubic.c.
const hybridStartLowWindow = protocol.ByteCount(16)
// Number of delay samples for detecting the increase of delay.
const hybridStartMinSamples = uint32(8)
// Exit slow start if the min rtt has increased by more than 1/8th.
const hybridStartDelayFactorExp = 3 // 2^3 = 8
// The original paper specifies 2 and 8ms, but those have changed over time.
const (
hybridStartDelayMinThresholdUs = int64(4000)
hybridStartDelayMaxThresholdUs = int64(16000)
)
// HybridSlowStart implements the TCP hybrid slow start algorithm
type HybridSlowStart struct {
endPacketNumber protocol.PacketNumber
lastSentPacketNumber protocol.PacketNumber
started bool
currentMinRTT time.Duration
rttSampleCount uint32
hystartFound bool
}
// StartReceiveRound is called for the start of each receive round (burst) in the slow start phase.
func (s *HybridSlowStart) StartReceiveRound(lastSent protocol.PacketNumber) {
s.endPacketNumber = lastSent
s.currentMinRTT = 0
s.rttSampleCount = 0
s.started = true
}
// IsEndOfRound returns true if this ack is the last packet number of our current slow start round.
func (s *HybridSlowStart) IsEndOfRound(ack protocol.PacketNumber) bool {
return s.endPacketNumber < ack
}
// ShouldExitSlowStart should be called on every new ack frame, since a new
// RTT measurement can be made then.
// rtt: the RTT for this ack packet.
// minRTT: is the lowest delay (RTT) we have seen during the session.
// congestionWindow: the congestion window in packets.
func (s *HybridSlowStart) ShouldExitSlowStart(latestRTT time.Duration, minRTT time.Duration, congestionWindow protocol.ByteCount) bool {
if !s.started {
// Time to start the hybrid slow start.
s.StartReceiveRound(s.lastSentPacketNumber)
}
if s.hystartFound {
return true
}
// Second detection parameter - delay increase detection.
// Compare the minimum delay (s.currentMinRTT) of the current
// burst of packets relative to the minimum delay during the session.
// Note: we only look at the first few(8) packets in each burst, since we
// only want to compare the lowest RTT of the burst relative to previous
// bursts.
s.rttSampleCount++
if s.rttSampleCount <= hybridStartMinSamples {
if s.currentMinRTT == 0 || s.currentMinRTT > latestRTT {
s.currentMinRTT = latestRTT
}
}
// We only need to check this once per round.
if s.rttSampleCount == hybridStartMinSamples {
// Divide minRTT by 8 to get a rtt increase threshold for exiting.
minRTTincreaseThresholdUs := int64(minRTT / time.Microsecond >> hybridStartDelayFactorExp)
// Ensure the rtt threshold is never less than 2ms or more than 16ms.
minRTTincreaseThresholdUs = utils.Min(minRTTincreaseThresholdUs, hybridStartDelayMaxThresholdUs)
minRTTincreaseThreshold := time.Duration(utils.Max(minRTTincreaseThresholdUs, hybridStartDelayMinThresholdUs)) * time.Microsecond
if s.currentMinRTT > (minRTT + minRTTincreaseThreshold) {
s.hystartFound = true
}
}
// Exit from slow start if the cwnd is greater than 16 and
// increasing delay is found.
return congestionWindow >= hybridStartLowWindow && s.hystartFound
}
// OnPacketSent is called when a packet was sent
func (s *HybridSlowStart) OnPacketSent(packetNumber protocol.PacketNumber) {
s.lastSentPacketNumber = packetNumber
}
// OnPacketAcked gets invoked after ShouldExitSlowStart, so it's best to end
// the round when the final packet of the burst is received and start it on
// the next incoming ack.
func (s *HybridSlowStart) OnPacketAcked(ackedPacketNumber protocol.PacketNumber) {
if s.IsEndOfRound(ackedPacketNumber) {
s.started = false
}
}
// Started returns true if started
func (s *HybridSlowStart) Started() bool {
return s.started
}
// Restart the slow start phase
func (s *HybridSlowStart) Restart() {
s.started = false
s.hystartFound = false
}

View File

@@ -0,0 +1,28 @@
package congestion
import (
"time"
"github.com/quic-go/quic-go/internal/protocol"
)
// A SendAlgorithm performs congestion control
type SendAlgorithm interface {
TimeUntilSend(bytesInFlight protocol.ByteCount) time.Time
HasPacingBudget(now time.Time) bool
OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool)
CanSend(bytesInFlight protocol.ByteCount) bool
MaybeExitSlowStart()
OnPacketAcked(number protocol.PacketNumber, ackedBytes protocol.ByteCount, priorInFlight protocol.ByteCount, eventTime time.Time)
OnCongestionEvent(number protocol.PacketNumber, lostBytes protocol.ByteCount, priorInFlight protocol.ByteCount)
OnRetransmissionTimeout(packetsRetransmitted bool)
SetMaxDatagramSize(protocol.ByteCount)
}
// A SendAlgorithmWithDebugInfos is a SendAlgorithm that exposes some debug infos
type SendAlgorithmWithDebugInfos interface {
SendAlgorithm
InSlowStart() bool
InRecovery() bool
GetCongestionWindow() protocol.ByteCount
}

View File

@@ -0,0 +1,80 @@
package congestion
import (
"math"
"time"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
)
const maxBurstSizePackets = 10
// The pacer implements a token bucket pacing algorithm.
type pacer struct {
budgetAtLastSent protocol.ByteCount
maxDatagramSize protocol.ByteCount
lastSentTime time.Time
adjustedBandwidth func() uint64 // in bytes/s
}
func newPacer(getBandwidth func() Bandwidth) *pacer {
p := &pacer{
maxDatagramSize: initialMaxDatagramSize,
adjustedBandwidth: func() uint64 {
// Bandwidth is in bits/s. We need the value in bytes/s.
bw := uint64(getBandwidth() / BytesPerSecond)
// Use a slightly higher value than the actual measured bandwidth.
// RTT variations then won't result in under-utilization of the congestion window.
// Ultimately, this will result in sending packets as acknowledgments are received rather than when timers fire,
// provided the congestion window is fully utilized and acknowledgments arrive at regular intervals.
return bw * 5 / 4
},
}
p.budgetAtLastSent = p.maxBurstSize()
return p
}
func (p *pacer) SentPacket(sendTime time.Time, size protocol.ByteCount) {
budget := p.Budget(sendTime)
if size > budget {
p.budgetAtLastSent = 0
} else {
p.budgetAtLastSent = budget - size
}
p.lastSentTime = sendTime
}
func (p *pacer) Budget(now time.Time) protocol.ByteCount {
if p.lastSentTime.IsZero() {
return p.maxBurstSize()
}
budget := p.budgetAtLastSent + (protocol.ByteCount(p.adjustedBandwidth())*protocol.ByteCount(now.Sub(p.lastSentTime).Nanoseconds()))/1e9
if budget < 0 { // protect against overflows
budget = protocol.MaxByteCount
}
return utils.Min(p.maxBurstSize(), budget)
}
func (p *pacer) maxBurstSize() protocol.ByteCount {
return utils.Max(
protocol.ByteCount(uint64((protocol.MinPacingDelay+protocol.TimerGranularity).Nanoseconds())*p.adjustedBandwidth())/1e9,
maxBurstSizePackets*p.maxDatagramSize,
)
}
// TimeUntilSend returns when the next packet should be sent.
// It returns the zero value of time.Time if a packet can be sent immediately.
func (p *pacer) TimeUntilSend() time.Time {
if p.budgetAtLastSent >= p.maxDatagramSize {
return time.Time{}
}
return p.lastSentTime.Add(utils.Max(
protocol.MinPacingDelay,
time.Duration(math.Ceil(float64(p.maxDatagramSize-p.budgetAtLastSent)*1e9/float64(p.adjustedBandwidth())))*time.Nanosecond,
))
}
func (p *pacer) SetMaxDatagramSize(s protocol.ByteCount) {
p.maxDatagramSize = s
}

View File

@@ -0,0 +1,125 @@
package flowcontrol
import (
"sync"
"time"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
)
type baseFlowController struct {
// for sending data
bytesSent protocol.ByteCount
sendWindow protocol.ByteCount
lastBlockedAt protocol.ByteCount
// for receiving data
//nolint:structcheck // The mutex is used both by the stream and the connection flow controller
mutex sync.Mutex
bytesRead protocol.ByteCount
highestReceived protocol.ByteCount
receiveWindow protocol.ByteCount
receiveWindowSize protocol.ByteCount
maxReceiveWindowSize protocol.ByteCount
allowWindowIncrease func(size protocol.ByteCount) bool
epochStartTime time.Time
epochStartOffset protocol.ByteCount
rttStats *utils.RTTStats
logger utils.Logger
}
// IsNewlyBlocked says if it is newly blocked by flow control.
// For every offset, it only returns true once.
// If it is blocked, the offset is returned.
func (c *baseFlowController) IsNewlyBlocked() (bool, protocol.ByteCount) {
if c.sendWindowSize() != 0 || c.sendWindow == c.lastBlockedAt {
return false, 0
}
c.lastBlockedAt = c.sendWindow
return true, c.sendWindow
}
func (c *baseFlowController) AddBytesSent(n protocol.ByteCount) {
c.bytesSent += n
}
// UpdateSendWindow is called after receiving a MAX_{STREAM_}DATA frame.
func (c *baseFlowController) UpdateSendWindow(offset protocol.ByteCount) {
if offset > c.sendWindow {
c.sendWindow = offset
}
}
func (c *baseFlowController) sendWindowSize() protocol.ByteCount {
// this only happens during connection establishment, when data is sent before we receive the peer's transport parameters
if c.bytesSent > c.sendWindow {
return 0
}
return c.sendWindow - c.bytesSent
}
// needs to be called with locked mutex
func (c *baseFlowController) addBytesRead(n protocol.ByteCount) {
// pretend we sent a WindowUpdate when reading the first byte
// this way auto-tuning of the window size already works for the first WindowUpdate
if c.bytesRead == 0 {
c.startNewAutoTuningEpoch(time.Now())
}
c.bytesRead += n
}
func (c *baseFlowController) hasWindowUpdate() bool {
bytesRemaining := c.receiveWindow - c.bytesRead
// update the window when more than the threshold was consumed
return bytesRemaining <= protocol.ByteCount(float64(c.receiveWindowSize)*(1-protocol.WindowUpdateThreshold))
}
// getWindowUpdate updates the receive window, if necessary
// it returns the new offset
func (c *baseFlowController) getWindowUpdate() protocol.ByteCount {
if !c.hasWindowUpdate() {
return 0
}
c.maybeAdjustWindowSize()
c.receiveWindow = c.bytesRead + c.receiveWindowSize
return c.receiveWindow
}
// maybeAdjustWindowSize increases the receiveWindowSize if we're sending updates too often.
// For details about auto-tuning, see https://docs.google.com/document/d/1SExkMmGiz8VYzV3s9E35JQlJ73vhzCekKkDi85F1qCE/edit?usp=sharing.
func (c *baseFlowController) maybeAdjustWindowSize() {
bytesReadInEpoch := c.bytesRead - c.epochStartOffset
// don't do anything if less than half the window has been consumed
if bytesReadInEpoch <= c.receiveWindowSize/2 {
return
}
rtt := c.rttStats.SmoothedRTT()
if rtt == 0 {
return
}
fraction := float64(bytesReadInEpoch) / float64(c.receiveWindowSize)
now := time.Now()
if now.Sub(c.epochStartTime) < time.Duration(4*fraction*float64(rtt)) {
// window is consumed too fast, try to increase the window size
newSize := utils.Min(2*c.receiveWindowSize, c.maxReceiveWindowSize)
if newSize > c.receiveWindowSize && (c.allowWindowIncrease == nil || c.allowWindowIncrease(newSize-c.receiveWindowSize)) {
c.receiveWindowSize = newSize
}
}
c.startNewAutoTuningEpoch(now)
}
func (c *baseFlowController) startNewAutoTuningEpoch(now time.Time) {
c.epochStartTime = now
c.epochStartOffset = c.bytesRead
}
func (c *baseFlowController) checkFlowControlViolation() bool {
return c.highestReceived > c.receiveWindow
}

View File

@@ -0,0 +1,112 @@
package flowcontrol
import (
"errors"
"fmt"
"time"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/utils"
)
type connectionFlowController struct {
baseFlowController
queueWindowUpdate func()
}
var _ ConnectionFlowController = &connectionFlowController{}
// NewConnectionFlowController gets a new flow controller for the connection
// It is created before we receive the peer's transport parameters, thus it starts with a sendWindow of 0.
func NewConnectionFlowController(
receiveWindow protocol.ByteCount,
maxReceiveWindow protocol.ByteCount,
queueWindowUpdate func(),
allowWindowIncrease func(size protocol.ByteCount) bool,
rttStats *utils.RTTStats,
logger utils.Logger,
) ConnectionFlowController {
return &connectionFlowController{
baseFlowController: baseFlowController{
rttStats: rttStats,
receiveWindow: receiveWindow,
receiveWindowSize: receiveWindow,
maxReceiveWindowSize: maxReceiveWindow,
allowWindowIncrease: allowWindowIncrease,
logger: logger,
},
queueWindowUpdate: queueWindowUpdate,
}
}
func (c *connectionFlowController) SendWindowSize() protocol.ByteCount {
return c.baseFlowController.sendWindowSize()
}
// IncrementHighestReceived adds an increment to the highestReceived value
func (c *connectionFlowController) IncrementHighestReceived(increment protocol.ByteCount) error {
c.mutex.Lock()
defer c.mutex.Unlock()
c.highestReceived += increment
if c.checkFlowControlViolation() {
return &qerr.TransportError{
ErrorCode: qerr.FlowControlError,
ErrorMessage: fmt.Sprintf("received %d bytes for the connection, allowed %d bytes", c.highestReceived, c.receiveWindow),
}
}
return nil
}
func (c *connectionFlowController) AddBytesRead(n protocol.ByteCount) {
c.mutex.Lock()
c.baseFlowController.addBytesRead(n)
shouldQueueWindowUpdate := c.hasWindowUpdate()
c.mutex.Unlock()
if shouldQueueWindowUpdate {
c.queueWindowUpdate()
}
}
func (c *connectionFlowController) GetWindowUpdate() protocol.ByteCount {
c.mutex.Lock()
oldWindowSize := c.receiveWindowSize
offset := c.baseFlowController.getWindowUpdate()
if oldWindowSize < c.receiveWindowSize {
c.logger.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowSize/(1<<10))
}
c.mutex.Unlock()
return offset
}
// EnsureMinimumWindowSize sets a minimum window size
// it should make sure that the connection-level window is increased when a stream-level window grows
func (c *connectionFlowController) EnsureMinimumWindowSize(inc protocol.ByteCount) {
c.mutex.Lock()
if inc > c.receiveWindowSize {
c.logger.Debugf("Increasing receive flow control window for the connection to %d kB, in response to stream flow control window increase", c.receiveWindowSize/(1<<10))
newSize := utils.Min(inc, c.maxReceiveWindowSize)
if delta := newSize - c.receiveWindowSize; delta > 0 && c.allowWindowIncrease(delta) {
c.receiveWindowSize = newSize
}
c.startNewAutoTuningEpoch(time.Now())
}
c.mutex.Unlock()
}
// Reset rests the flow controller. This happens when 0-RTT is rejected.
// All stream data is invalidated, it's if we had never opened a stream and never sent any data.
// At that point, we only have sent stream data, but we didn't have the keys to open 1-RTT keys yet.
func (c *connectionFlowController) Reset() error {
c.mutex.Lock()
defer c.mutex.Unlock()
if c.bytesRead > 0 || c.highestReceived > 0 || !c.epochStartTime.IsZero() {
return errors.New("flow controller reset after reading data")
}
c.bytesSent = 0
c.lastBlockedAt = 0
return nil
}

View File

@@ -0,0 +1,42 @@
package flowcontrol
import "github.com/quic-go/quic-go/internal/protocol"
type flowController interface {
// for sending
SendWindowSize() protocol.ByteCount
UpdateSendWindow(protocol.ByteCount)
AddBytesSent(protocol.ByteCount)
// for receiving
AddBytesRead(protocol.ByteCount)
GetWindowUpdate() protocol.ByteCount // returns 0 if no update is necessary
IsNewlyBlocked() (bool, protocol.ByteCount)
}
// A StreamFlowController is a flow controller for a QUIC stream.
type StreamFlowController interface {
flowController
// for receiving
// UpdateHighestReceived should be called when a new highest offset is received
// final has to be to true if this is the final offset of the stream,
// as contained in a STREAM frame with FIN bit, and the RESET_STREAM frame
UpdateHighestReceived(offset protocol.ByteCount, final bool) error
// Abandon should be called when reading from the stream is aborted early,
// and there won't be any further calls to AddBytesRead.
Abandon()
}
// The ConnectionFlowController is the flow controller for the connection.
type ConnectionFlowController interface {
flowController
Reset() error
}
type connectionFlowControllerI interface {
ConnectionFlowController
// The following two methods are not supposed to be called from outside this packet, but are needed internally
// for sending
EnsureMinimumWindowSize(protocol.ByteCount)
// for receiving
IncrementHighestReceived(protocol.ByteCount) error
}

View File

@@ -0,0 +1,149 @@
package flowcontrol
import (
"fmt"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/utils"
)
type streamFlowController struct {
baseFlowController
streamID protocol.StreamID
queueWindowUpdate func()
connection connectionFlowControllerI
receivedFinalOffset bool
}
var _ StreamFlowController = &streamFlowController{}
// NewStreamFlowController gets a new flow controller for a stream
func NewStreamFlowController(
streamID protocol.StreamID,
cfc ConnectionFlowController,
receiveWindow protocol.ByteCount,
maxReceiveWindow protocol.ByteCount,
initialSendWindow protocol.ByteCount,
queueWindowUpdate func(protocol.StreamID),
rttStats *utils.RTTStats,
logger utils.Logger,
) StreamFlowController {
return &streamFlowController{
streamID: streamID,
connection: cfc.(connectionFlowControllerI),
queueWindowUpdate: func() { queueWindowUpdate(streamID) },
baseFlowController: baseFlowController{
rttStats: rttStats,
receiveWindow: receiveWindow,
receiveWindowSize: receiveWindow,
maxReceiveWindowSize: maxReceiveWindow,
sendWindow: initialSendWindow,
logger: logger,
},
}
}
// UpdateHighestReceived updates the highestReceived value, if the offset is higher.
func (c *streamFlowController) UpdateHighestReceived(offset protocol.ByteCount, final bool) error {
// If the final offset for this stream is already known, check for consistency.
if c.receivedFinalOffset {
// If we receive another final offset, check that it's the same.
if final && offset != c.highestReceived {
return &qerr.TransportError{
ErrorCode: qerr.FinalSizeError,
ErrorMessage: fmt.Sprintf("received inconsistent final offset for stream %d (old: %d, new: %d bytes)", c.streamID, c.highestReceived, offset),
}
}
// Check that the offset is below the final offset.
if offset > c.highestReceived {
return &qerr.TransportError{
ErrorCode: qerr.FinalSizeError,
ErrorMessage: fmt.Sprintf("received offset %d for stream %d, but final offset was already received at %d", offset, c.streamID, c.highestReceived),
}
}
}
if final {
c.receivedFinalOffset = true
}
if offset == c.highestReceived {
return nil
}
// A higher offset was received before.
// This can happen due to reordering.
if offset <= c.highestReceived {
if final {
return &qerr.TransportError{
ErrorCode: qerr.FinalSizeError,
ErrorMessage: fmt.Sprintf("received final offset %d for stream %d, but already received offset %d before", offset, c.streamID, c.highestReceived),
}
}
return nil
}
increment := offset - c.highestReceived
c.highestReceived = offset
if c.checkFlowControlViolation() {
return &qerr.TransportError{
ErrorCode: qerr.FlowControlError,
ErrorMessage: fmt.Sprintf("received %d bytes on stream %d, allowed %d bytes", offset, c.streamID, c.receiveWindow),
}
}
return c.connection.IncrementHighestReceived(increment)
}
func (c *streamFlowController) AddBytesRead(n protocol.ByteCount) {
c.mutex.Lock()
c.baseFlowController.addBytesRead(n)
shouldQueueWindowUpdate := c.shouldQueueWindowUpdate()
c.mutex.Unlock()
if shouldQueueWindowUpdate {
c.queueWindowUpdate()
}
c.connection.AddBytesRead(n)
}
func (c *streamFlowController) Abandon() {
c.mutex.Lock()
unread := c.highestReceived - c.bytesRead
c.mutex.Unlock()
if unread > 0 {
c.connection.AddBytesRead(unread)
}
}
func (c *streamFlowController) AddBytesSent(n protocol.ByteCount) {
c.baseFlowController.AddBytesSent(n)
c.connection.AddBytesSent(n)
}
func (c *streamFlowController) SendWindowSize() protocol.ByteCount {
return utils.Min(c.baseFlowController.sendWindowSize(), c.connection.SendWindowSize())
}
func (c *streamFlowController) shouldQueueWindowUpdate() bool {
return !c.receivedFinalOffset && c.hasWindowUpdate()
}
func (c *streamFlowController) GetWindowUpdate() protocol.ByteCount {
// If we already received the final offset for this stream, the peer won't need any additional flow control credit.
if c.receivedFinalOffset {
return 0
}
// Don't use defer for unlocking the mutex here, GetWindowUpdate() is called frequently and defer shows up in the profiler
c.mutex.Lock()
oldWindowSize := c.receiveWindowSize
offset := c.baseFlowController.getWindowUpdate()
if c.receiveWindowSize > oldWindowSize { // auto-tuning enlarged the window size
c.logger.Debugf("Increasing receive flow control window for stream %d to %d kB", c.streamID, c.receiveWindowSize/(1<<10))
c.connection.EnsureMinimumWindowSize(protocol.ByteCount(float64(c.receiveWindowSize) * protocol.ConnectionFlowControlMultiplier))
}
c.mutex.Unlock()
return offset
}

View File

@@ -0,0 +1,94 @@
package handshake
import (
"crypto/cipher"
"encoding/binary"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
)
func createAEAD(suite *cipherSuite, trafficSecret []byte, v protocol.VersionNumber) cipher.AEAD {
keyLabel := hkdfLabelKeyV1
ivLabel := hkdfLabelIVV1
if v == protocol.Version2 {
keyLabel = hkdfLabelKeyV2
ivLabel = hkdfLabelIVV2
}
key := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, keyLabel, suite.KeyLen)
iv := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, ivLabel, suite.IVLen())
return suite.AEAD(key, iv)
}
type longHeaderSealer struct {
aead cipher.AEAD
headerProtector headerProtector
// use a single slice to avoid allocations
nonceBuf []byte
}
var _ LongHeaderSealer = &longHeaderSealer{}
func newLongHeaderSealer(aead cipher.AEAD, headerProtector headerProtector) LongHeaderSealer {
return &longHeaderSealer{
aead: aead,
headerProtector: headerProtector,
nonceBuf: make([]byte, aead.NonceSize()),
}
}
func (s *longHeaderSealer) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []byte {
binary.BigEndian.PutUint64(s.nonceBuf[len(s.nonceBuf)-8:], uint64(pn))
// The AEAD we're using here will be the qtls.aeadAESGCM13.
// It uses the nonce provided here and XOR it with the IV.
return s.aead.Seal(dst, s.nonceBuf, src, ad)
}
func (s *longHeaderSealer) EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte) {
s.headerProtector.EncryptHeader(sample, firstByte, pnBytes)
}
func (s *longHeaderSealer) Overhead() int {
return s.aead.Overhead()
}
type longHeaderOpener struct {
aead cipher.AEAD
headerProtector headerProtector
highestRcvdPN protocol.PacketNumber // highest packet number received (which could be successfully unprotected)
// use a single slice to avoid allocations
nonceBuf []byte
}
var _ LongHeaderOpener = &longHeaderOpener{}
func newLongHeaderOpener(aead cipher.AEAD, headerProtector headerProtector) LongHeaderOpener {
return &longHeaderOpener{
aead: aead,
headerProtector: headerProtector,
nonceBuf: make([]byte, aead.NonceSize()),
}
}
func (o *longHeaderOpener) DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber {
return protocol.DecodePacketNumber(wirePNLen, o.highestRcvdPN, wirePN)
}
func (o *longHeaderOpener) Open(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) {
binary.BigEndian.PutUint64(o.nonceBuf[len(o.nonceBuf)-8:], uint64(pn))
// The AEAD we're using here will be the qtls.aeadAESGCM13.
// It uses the nonce provided here and XOR it with the IV.
dec, err := o.aead.Open(dst, o.nonceBuf, src, ad)
if err == nil {
o.highestRcvdPN = utils.Max(o.highestRcvdPN, pn)
} else {
err = ErrDecryptionFailed
}
return dec, err
}
func (o *longHeaderOpener) DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) {
o.headerProtector.DecryptHeader(sample, firstByte, pnBytes)
}

View File

@@ -0,0 +1,104 @@
package handshake
import (
"crypto"
"crypto/aes"
"crypto/cipher"
"crypto/tls"
"fmt"
"golang.org/x/crypto/chacha20poly1305"
)
// These cipher suite implementations are copied from the standard library crypto/tls package.
const aeadNonceLength = 12
type cipherSuite struct {
ID uint16
Hash crypto.Hash
KeyLen int
AEAD func(key, nonceMask []byte) cipher.AEAD
}
func (s cipherSuite) IVLen() int { return aeadNonceLength }
func getCipherSuite(id uint16) *cipherSuite {
switch id {
case tls.TLS_AES_128_GCM_SHA256:
return &cipherSuite{ID: tls.TLS_AES_128_GCM_SHA256, Hash: crypto.SHA256, KeyLen: 16, AEAD: aeadAESGCMTLS13}
case tls.TLS_CHACHA20_POLY1305_SHA256:
return &cipherSuite{ID: tls.TLS_CHACHA20_POLY1305_SHA256, Hash: crypto.SHA256, KeyLen: 32, AEAD: aeadChaCha20Poly1305}
case tls.TLS_AES_256_GCM_SHA384:
return &cipherSuite{ID: tls.TLS_AES_256_GCM_SHA384, Hash: crypto.SHA384, KeyLen: 32, AEAD: aeadAESGCMTLS13}
default:
panic(fmt.Sprintf("unknown cypher suite: %d", id))
}
}
func aeadAESGCMTLS13(key, nonceMask []byte) cipher.AEAD {
if len(nonceMask) != aeadNonceLength {
panic("tls: internal error: wrong nonce length")
}
aes, err := aes.NewCipher(key)
if err != nil {
panic(err)
}
aead, err := cipher.NewGCM(aes)
if err != nil {
panic(err)
}
ret := &xorNonceAEAD{aead: aead}
copy(ret.nonceMask[:], nonceMask)
return ret
}
func aeadChaCha20Poly1305(key, nonceMask []byte) cipher.AEAD {
if len(nonceMask) != aeadNonceLength {
panic("tls: internal error: wrong nonce length")
}
aead, err := chacha20poly1305.New(key)
if err != nil {
panic(err)
}
ret := &xorNonceAEAD{aead: aead}
copy(ret.nonceMask[:], nonceMask)
return ret
}
// xorNonceAEAD wraps an AEAD by XORing in a fixed pattern to the nonce
// before each call.
type xorNonceAEAD struct {
nonceMask [aeadNonceLength]byte
aead cipher.AEAD
}
func (f *xorNonceAEAD) NonceSize() int { return 8 } // 64-bit sequence number
func (f *xorNonceAEAD) Overhead() int { return f.aead.Overhead() }
func (f *xorNonceAEAD) explicitNonceLen() int { return 0 }
func (f *xorNonceAEAD) Seal(out, nonce, plaintext, additionalData []byte) []byte {
for i, b := range nonce {
f.nonceMask[4+i] ^= b
}
result := f.aead.Seal(out, f.nonceMask[:], plaintext, additionalData)
for i, b := range nonce {
f.nonceMask[4+i] ^= b
}
return result
}
func (f *xorNonceAEAD) Open(out, nonce, ciphertext, additionalData []byte) ([]byte, error) {
for i, b := range nonce {
f.nonceMask[4+i] ^= b
}
result, err := f.aead.Open(out, f.nonceMask[:], ciphertext, additionalData)
for i, b := range nonce {
f.nonceMask[4+i] ^= b
}
return result, err
}

View File

@@ -0,0 +1,21 @@
package handshake
import (
"net"
"time"
)
type conn struct {
localAddr, remoteAddr net.Addr
}
var _ net.Conn = &conn{}
func (c *conn) Read([]byte) (int, error) { return 0, nil }
func (c *conn) Write([]byte) (int, error) { return 0, nil }
func (c *conn) Close() error { return nil }
func (c *conn) RemoteAddr() net.Addr { return c.remoteAddr }
func (c *conn) LocalAddr() net.Addr { return c.localAddr }
func (c *conn) SetReadDeadline(time.Time) error { return nil }
func (c *conn) SetWriteDeadline(time.Time) error { return nil }
func (c *conn) SetDeadline(time.Time) error { return nil }

View File

@@ -0,0 +1,681 @@
package handshake
import (
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/qtls"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/internal/wire"
"github.com/quic-go/quic-go/logging"
"github.com/quic-go/quic-go/quicvarint"
)
type quicVersionContextKey struct{}
var QUICVersionContextKey = &quicVersionContextKey{}
const clientSessionStateRevision = 3
type cryptoSetup struct {
tlsConf *tls.Config
conn *qtls.QUICConn
events []Event
version protocol.VersionNumber
ourParams *wire.TransportParameters
peerParams *wire.TransportParameters
zeroRTTParameters *wire.TransportParameters
allow0RTT bool
rttStats *utils.RTTStats
tracer *logging.ConnectionTracer
logger utils.Logger
perspective protocol.Perspective
mutex sync.Mutex // protects all members below
handshakeCompleteTime time.Time
zeroRTTOpener LongHeaderOpener // only set for the server
zeroRTTSealer LongHeaderSealer // only set for the client
initialOpener LongHeaderOpener
initialSealer LongHeaderSealer
handshakeOpener LongHeaderOpener
handshakeSealer LongHeaderSealer
used0RTT atomic.Bool
aead *updatableAEAD
has1RTTSealer bool
has1RTTOpener bool
}
var _ CryptoSetup = &cryptoSetup{}
// NewCryptoSetupClient creates a new crypto setup for the client
func NewCryptoSetupClient(
connID protocol.ConnectionID,
tp *wire.TransportParameters,
tlsConf *tls.Config,
enable0RTT bool,
rttStats *utils.RTTStats,
tracer *logging.ConnectionTracer,
logger utils.Logger,
version protocol.VersionNumber,
) CryptoSetup {
cs := newCryptoSetup(
connID,
tp,
rttStats,
tracer,
logger,
protocol.PerspectiveClient,
version,
)
tlsConf = tlsConf.Clone()
tlsConf.MinVersion = tls.VersionTLS13
quicConf := &qtls.QUICConfig{TLSConfig: tlsConf}
qtls.SetupConfigForClient(quicConf, cs.marshalDataForSessionState, cs.handleDataFromSessionState)
cs.tlsConf = tlsConf
cs.conn = qtls.QUICClient(quicConf)
cs.conn.SetTransportParameters(cs.ourParams.Marshal(protocol.PerspectiveClient))
return cs
}
// NewCryptoSetupServer creates a new crypto setup for the server
func NewCryptoSetupServer(
connID protocol.ConnectionID,
localAddr, remoteAddr net.Addr,
tp *wire.TransportParameters,
tlsConf *tls.Config,
allow0RTT bool,
rttStats *utils.RTTStats,
tracer *logging.ConnectionTracer,
logger utils.Logger,
version protocol.VersionNumber,
) CryptoSetup {
cs := newCryptoSetup(
connID,
tp,
rttStats,
tracer,
logger,
protocol.PerspectiveServer,
version,
)
cs.allow0RTT = allow0RTT
quicConf := &qtls.QUICConfig{TLSConfig: tlsConf}
qtls.SetupConfigForServer(quicConf, cs.allow0RTT, cs.getDataForSessionTicket, cs.handleSessionTicket)
addConnToClientHelloInfo(quicConf.TLSConfig, localAddr, remoteAddr)
cs.tlsConf = quicConf.TLSConfig
cs.conn = qtls.QUICServer(quicConf)
return cs
}
// The tls.Config contains two callbacks that pass in a tls.ClientHelloInfo.
// Since crypto/tls doesn't do it, we need to make sure to set the Conn field with a fake net.Conn
// that allows the caller to get the local and the remote address.
func addConnToClientHelloInfo(conf *tls.Config, localAddr, remoteAddr net.Addr) {
if conf.GetConfigForClient != nil {
gcfc := conf.GetConfigForClient
conf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) {
info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr}
c, err := gcfc(info)
if c != nil {
c = c.Clone()
// This won't be necessary anymore once https://github.com/golang/go/issues/63722 is accepted.
c.MinVersion = tls.VersionTLS13
// We're returning a tls.Config here, so we need to apply this recursively.
addConnToClientHelloInfo(c, localAddr, remoteAddr)
}
return c, err
}
}
if conf.GetCertificate != nil {
gc := conf.GetCertificate
conf.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr}
return gc(info)
}
}
}
func newCryptoSetup(
connID protocol.ConnectionID,
tp *wire.TransportParameters,
rttStats *utils.RTTStats,
tracer *logging.ConnectionTracer,
logger utils.Logger,
perspective protocol.Perspective,
version protocol.VersionNumber,
) *cryptoSetup {
initialSealer, initialOpener := NewInitialAEAD(connID, perspective, version)
if tracer != nil && tracer.UpdatedKeyFromTLS != nil {
tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveClient)
tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveServer)
}
return &cryptoSetup{
initialSealer: initialSealer,
initialOpener: initialOpener,
aead: newUpdatableAEAD(rttStats, tracer, logger, version),
events: make([]Event, 0, 16),
ourParams: tp,
rttStats: rttStats,
tracer: tracer,
logger: logger,
perspective: perspective,
version: version,
}
}
func (h *cryptoSetup) ChangeConnectionID(id protocol.ConnectionID) {
initialSealer, initialOpener := NewInitialAEAD(id, h.perspective, h.version)
h.initialSealer = initialSealer
h.initialOpener = initialOpener
if h.tracer != nil && h.tracer.UpdatedKeyFromTLS != nil {
h.tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveClient)
h.tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveServer)
}
}
func (h *cryptoSetup) SetLargest1RTTAcked(pn protocol.PacketNumber) error {
return h.aead.SetLargestAcked(pn)
}
func (h *cryptoSetup) StartHandshake() error {
err := h.conn.Start(context.WithValue(context.Background(), QUICVersionContextKey, h.version))
if err != nil {
return wrapError(err)
}
for {
ev := h.conn.NextEvent()
done, err := h.handleEvent(ev)
if err != nil {
return wrapError(err)
}
if done {
break
}
}
if h.perspective == protocol.PerspectiveClient {
if h.zeroRTTSealer != nil && h.zeroRTTParameters != nil {
h.logger.Debugf("Doing 0-RTT.")
h.events = append(h.events, Event{Kind: EventRestoredTransportParameters, TransportParameters: h.zeroRTTParameters})
} else {
h.logger.Debugf("Not doing 0-RTT. Has sealer: %t, has params: %t", h.zeroRTTSealer != nil, h.zeroRTTParameters != nil)
}
}
return nil
}
// Close closes the crypto setup.
// It aborts the handshake, if it is still running.
func (h *cryptoSetup) Close() error {
return h.conn.Close()
}
// HandleMessage handles a TLS handshake message.
// It is called by the crypto streams when a new message is available.
func (h *cryptoSetup) HandleMessage(data []byte, encLevel protocol.EncryptionLevel) error {
if err := h.handleMessage(data, encLevel); err != nil {
return wrapError(err)
}
return nil
}
func (h *cryptoSetup) handleMessage(data []byte, encLevel protocol.EncryptionLevel) error {
if err := h.conn.HandleData(qtls.ToTLSEncryptionLevel(encLevel), data); err != nil {
return err
}
for {
ev := h.conn.NextEvent()
done, err := h.handleEvent(ev)
if err != nil {
return err
}
if done {
return nil
}
}
}
func (h *cryptoSetup) handleEvent(ev qtls.QUICEvent) (done bool, err error) {
switch ev.Kind {
case qtls.QUICNoEvent:
return true, nil
case qtls.QUICSetReadSecret:
h.SetReadKey(ev.Level, ev.Suite, ev.Data)
return false, nil
case qtls.QUICSetWriteSecret:
h.SetWriteKey(ev.Level, ev.Suite, ev.Data)
return false, nil
case qtls.QUICTransportParameters:
return false, h.handleTransportParameters(ev.Data)
case qtls.QUICTransportParametersRequired:
h.conn.SetTransportParameters(h.ourParams.Marshal(h.perspective))
return false, nil
case qtls.QUICRejectedEarlyData:
h.rejected0RTT()
return false, nil
case qtls.QUICWriteData:
h.WriteRecord(ev.Level, ev.Data)
return false, nil
case qtls.QUICHandshakeDone:
h.handshakeComplete()
return false, nil
default:
return false, fmt.Errorf("unexpected event: %d", ev.Kind)
}
}
func (h *cryptoSetup) NextEvent() Event {
if len(h.events) == 0 {
return Event{Kind: EventNoEvent}
}
ev := h.events[0]
h.events = h.events[1:]
return ev
}
func (h *cryptoSetup) handleTransportParameters(data []byte) error {
var tp wire.TransportParameters
if err := tp.Unmarshal(data, h.perspective.Opposite()); err != nil {
return err
}
h.peerParams = &tp
h.events = append(h.events, Event{Kind: EventReceivedTransportParameters, TransportParameters: h.peerParams})
return nil
}
// must be called after receiving the transport parameters
func (h *cryptoSetup) marshalDataForSessionState() []byte {
b := make([]byte, 0, 256)
b = quicvarint.Append(b, clientSessionStateRevision)
b = quicvarint.Append(b, uint64(h.rttStats.SmoothedRTT().Microseconds()))
return h.peerParams.MarshalForSessionTicket(b)
}
func (h *cryptoSetup) handleDataFromSessionState(data []byte) {
tp, err := h.handleDataFromSessionStateImpl(data)
if err != nil {
h.logger.Debugf("Restoring of transport parameters from session ticket failed: %s", err.Error())
return
}
h.zeroRTTParameters = tp
}
func (h *cryptoSetup) handleDataFromSessionStateImpl(data []byte) (*wire.TransportParameters, error) {
r := bytes.NewReader(data)
ver, err := quicvarint.Read(r)
if err != nil {
return nil, err
}
if ver != clientSessionStateRevision {
return nil, fmt.Errorf("mismatching version. Got %d, expected %d", ver, clientSessionStateRevision)
}
rtt, err := quicvarint.Read(r)
if err != nil {
return nil, err
}
h.rttStats.SetInitialRTT(time.Duration(rtt) * time.Microsecond)
var tp wire.TransportParameters
if err := tp.UnmarshalFromSessionTicket(r); err != nil {
return nil, err
}
return &tp, nil
}
func (h *cryptoSetup) getDataForSessionTicket() []byte {
ticket := &sessionTicket{
RTT: h.rttStats.SmoothedRTT(),
}
if h.allow0RTT {
ticket.Parameters = h.ourParams
}
return ticket.Marshal()
}
// GetSessionTicket generates a new session ticket.
// Due to limitations in crypto/tls, it's only possible to generate a single session ticket per connection.
// It is only valid for the server.
func (h *cryptoSetup) GetSessionTicket() ([]byte, error) {
if err := qtls.SendSessionTicket(h.conn, h.allow0RTT); err != nil {
// Session tickets might be disabled by tls.Config.SessionTicketsDisabled.
// We can't check h.tlsConfig here, since the actual config might have been obtained from
// the GetConfigForClient callback.
// See https://github.com/golang/go/issues/62032.
// Once that issue is resolved, this error assertion can be removed.
if strings.Contains(err.Error(), "session ticket keys unavailable") {
return nil, nil
}
return nil, err
}
ev := h.conn.NextEvent()
if ev.Kind != qtls.QUICWriteData || ev.Level != qtls.QUICEncryptionLevelApplication {
panic("crypto/tls bug: where's my session ticket?")
}
ticket := ev.Data
if ev := h.conn.NextEvent(); ev.Kind != qtls.QUICNoEvent {
panic("crypto/tls bug: why more than one ticket?")
}
return ticket, nil
}
// handleSessionTicket is called for the server when receiving the client's session ticket.
// It reads parameters from the session ticket and decides whether to accept 0-RTT when the session ticket is used for 0-RTT.
func (h *cryptoSetup) handleSessionTicket(sessionTicketData []byte, using0RTT bool) bool {
var t sessionTicket
if err := t.Unmarshal(sessionTicketData, using0RTT); err != nil {
h.logger.Debugf("Unmarshalling session ticket failed: %s", err.Error())
return false
}
h.rttStats.SetInitialRTT(t.RTT)
if !using0RTT {
return false
}
valid := h.ourParams.ValidFor0RTT(t.Parameters)
if !valid {
h.logger.Debugf("Transport parameters changed. Rejecting 0-RTT.")
return false
}
if !h.allow0RTT {
h.logger.Debugf("0-RTT not allowed. Rejecting 0-RTT.")
return false
}
h.logger.Debugf("Accepting 0-RTT. Restoring RTT from session ticket: %s", t.RTT)
return true
}
// rejected0RTT is called for the client when the server rejects 0-RTT.
func (h *cryptoSetup) rejected0RTT() {
h.logger.Debugf("0-RTT was rejected. Dropping 0-RTT keys.")
h.mutex.Lock()
had0RTTKeys := h.zeroRTTSealer != nil
h.zeroRTTSealer = nil
h.mutex.Unlock()
if had0RTTKeys {
h.events = append(h.events, Event{Kind: EventDiscard0RTTKeys})
}
}
func (h *cryptoSetup) SetReadKey(el qtls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) {
suite := getCipherSuite(suiteID)
h.mutex.Lock()
//nolint:exhaustive // The TLS stack doesn't export Initial keys.
switch el {
case qtls.QUICEncryptionLevelEarly:
if h.perspective == protocol.PerspectiveClient {
panic("Received 0-RTT read key for the client")
}
h.zeroRTTOpener = newLongHeaderOpener(
createAEAD(suite, trafficSecret, h.version),
newHeaderProtector(suite, trafficSecret, true, h.version),
)
h.used0RTT.Store(true)
if h.logger.Debug() {
h.logger.Debugf("Installed 0-RTT Read keys (using %s)", tls.CipherSuiteName(suite.ID))
}
case qtls.QUICEncryptionLevelHandshake:
h.handshakeOpener = newLongHeaderOpener(
createAEAD(suite, trafficSecret, h.version),
newHeaderProtector(suite, trafficSecret, true, h.version),
)
if h.logger.Debug() {
h.logger.Debugf("Installed Handshake Read keys (using %s)", tls.CipherSuiteName(suite.ID))
}
case qtls.QUICEncryptionLevelApplication:
h.aead.SetReadKey(suite, trafficSecret)
h.has1RTTOpener = true
if h.logger.Debug() {
h.logger.Debugf("Installed 1-RTT Read keys (using %s)", tls.CipherSuiteName(suite.ID))
}
default:
panic("unexpected read encryption level")
}
h.mutex.Unlock()
h.events = append(h.events, Event{Kind: EventReceivedReadKeys})
if h.tracer != nil && h.tracer.UpdatedKeyFromTLS != nil {
h.tracer.UpdatedKeyFromTLS(qtls.FromTLSEncryptionLevel(el), h.perspective.Opposite())
}
}
func (h *cryptoSetup) SetWriteKey(el qtls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) {
suite := getCipherSuite(suiteID)
h.mutex.Lock()
//nolint:exhaustive // The TLS stack doesn't export Initial keys.
switch el {
case qtls.QUICEncryptionLevelEarly:
if h.perspective == protocol.PerspectiveServer {
panic("Received 0-RTT write key for the server")
}
h.zeroRTTSealer = newLongHeaderSealer(
createAEAD(suite, trafficSecret, h.version),
newHeaderProtector(suite, trafficSecret, true, h.version),
)
h.mutex.Unlock()
if h.logger.Debug() {
h.logger.Debugf("Installed 0-RTT Write keys (using %s)", tls.CipherSuiteName(suite.ID))
}
if h.tracer != nil && h.tracer.UpdatedKeyFromTLS != nil {
h.tracer.UpdatedKeyFromTLS(protocol.Encryption0RTT, h.perspective)
}
// don't set used0RTT here. 0-RTT might still get rejected.
return
case qtls.QUICEncryptionLevelHandshake:
h.handshakeSealer = newLongHeaderSealer(
createAEAD(suite, trafficSecret, h.version),
newHeaderProtector(suite, trafficSecret, true, h.version),
)
if h.logger.Debug() {
h.logger.Debugf("Installed Handshake Write keys (using %s)", tls.CipherSuiteName(suite.ID))
}
case qtls.QUICEncryptionLevelApplication:
h.aead.SetWriteKey(suite, trafficSecret)
h.has1RTTSealer = true
if h.logger.Debug() {
h.logger.Debugf("Installed 1-RTT Write keys (using %s)", tls.CipherSuiteName(suite.ID))
}
if h.zeroRTTSealer != nil {
// Once we receive handshake keys, we know that 0-RTT was not rejected.
h.used0RTT.Store(true)
h.zeroRTTSealer = nil
h.logger.Debugf("Dropping 0-RTT keys.")
if h.tracer != nil && h.tracer.DroppedEncryptionLevel != nil {
h.tracer.DroppedEncryptionLevel(protocol.Encryption0RTT)
}
}
default:
panic("unexpected write encryption level")
}
h.mutex.Unlock()
if h.tracer != nil && h.tracer.UpdatedKeyFromTLS != nil {
h.tracer.UpdatedKeyFromTLS(qtls.FromTLSEncryptionLevel(el), h.perspective)
}
}
// WriteRecord is called when TLS writes data
func (h *cryptoSetup) WriteRecord(encLevel qtls.QUICEncryptionLevel, p []byte) {
//nolint:exhaustive // handshake records can only be written for Initial and Handshake.
switch encLevel {
case qtls.QUICEncryptionLevelInitial:
h.events = append(h.events, Event{Kind: EventWriteInitialData, Data: p})
case qtls.QUICEncryptionLevelHandshake:
h.events = append(h.events, Event{Kind: EventWriteHandshakeData, Data: p})
case qtls.QUICEncryptionLevelApplication:
panic("unexpected write")
default:
panic(fmt.Sprintf("unexpected write encryption level: %s", encLevel))
}
}
func (h *cryptoSetup) DiscardInitialKeys() {
h.mutex.Lock()
dropped := h.initialOpener != nil
h.initialOpener = nil
h.initialSealer = nil
h.mutex.Unlock()
if dropped {
h.logger.Debugf("Dropping Initial keys.")
}
}
func (h *cryptoSetup) handshakeComplete() {
h.handshakeCompleteTime = time.Now()
h.events = append(h.events, Event{Kind: EventHandshakeComplete})
}
func (h *cryptoSetup) SetHandshakeConfirmed() {
h.aead.SetHandshakeConfirmed()
// drop Handshake keys
var dropped bool
h.mutex.Lock()
if h.handshakeOpener != nil {
h.handshakeOpener = nil
h.handshakeSealer = nil
dropped = true
}
h.mutex.Unlock()
if dropped {
h.logger.Debugf("Dropping Handshake keys.")
}
}
func (h *cryptoSetup) GetInitialSealer() (LongHeaderSealer, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
if h.initialSealer == nil {
return nil, ErrKeysDropped
}
return h.initialSealer, nil
}
func (h *cryptoSetup) Get0RTTSealer() (LongHeaderSealer, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
if h.zeroRTTSealer == nil {
return nil, ErrKeysDropped
}
return h.zeroRTTSealer, nil
}
func (h *cryptoSetup) GetHandshakeSealer() (LongHeaderSealer, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
if h.handshakeSealer == nil {
if h.initialSealer == nil {
return nil, ErrKeysDropped
}
return nil, ErrKeysNotYetAvailable
}
return h.handshakeSealer, nil
}
func (h *cryptoSetup) Get1RTTSealer() (ShortHeaderSealer, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
if !h.has1RTTSealer {
return nil, ErrKeysNotYetAvailable
}
return h.aead, nil
}
func (h *cryptoSetup) GetInitialOpener() (LongHeaderOpener, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
if h.initialOpener == nil {
return nil, ErrKeysDropped
}
return h.initialOpener, nil
}
func (h *cryptoSetup) Get0RTTOpener() (LongHeaderOpener, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
if h.zeroRTTOpener == nil {
if h.initialOpener != nil {
return nil, ErrKeysNotYetAvailable
}
// if the initial opener is also not available, the keys were already dropped
return nil, ErrKeysDropped
}
return h.zeroRTTOpener, nil
}
func (h *cryptoSetup) GetHandshakeOpener() (LongHeaderOpener, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
if h.handshakeOpener == nil {
if h.initialOpener != nil {
return nil, ErrKeysNotYetAvailable
}
// if the initial opener is also not available, the keys were already dropped
return nil, ErrKeysDropped
}
return h.handshakeOpener, nil
}
func (h *cryptoSetup) Get1RTTOpener() (ShortHeaderOpener, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
if h.zeroRTTOpener != nil && time.Since(h.handshakeCompleteTime) > 3*h.rttStats.PTO(true) {
h.zeroRTTOpener = nil
h.logger.Debugf("Dropping 0-RTT keys.")
if h.tracer != nil && h.tracer.DroppedEncryptionLevel != nil {
h.tracer.DroppedEncryptionLevel(protocol.Encryption0RTT)
}
}
if !h.has1RTTOpener {
return nil, ErrKeysNotYetAvailable
}
return h.aead, nil
}
func (h *cryptoSetup) ConnectionState() ConnectionState {
return ConnectionState{
ConnectionState: h.conn.ConnectionState(),
Used0RTT: h.used0RTT.Load(),
}
}
func wrapError(err error) error {
// alert 80 is an internal error
if alertErr := qtls.AlertError(0); errors.As(err, &alertErr) && alertErr != 80 {
return qerr.NewLocalCryptoError(uint8(alertErr), err)
}
return &qerr.TransportError{ErrorCode: qerr.InternalError, ErrorMessage: err.Error()}
}

View File

@@ -0,0 +1,135 @@
package handshake
import (
"crypto/aes"
"crypto/cipher"
"crypto/tls"
"encoding/binary"
"fmt"
"golang.org/x/crypto/chacha20"
"github.com/quic-go/quic-go/internal/protocol"
)
type headerProtector interface {
EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte)
DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte)
}
func hkdfHeaderProtectionLabel(v protocol.VersionNumber) string {
if v == protocol.Version2 {
return "quicv2 hp"
}
return "quic hp"
}
func newHeaderProtector(suite *cipherSuite, trafficSecret []byte, isLongHeader bool, v protocol.VersionNumber) headerProtector {
hkdfLabel := hkdfHeaderProtectionLabel(v)
switch suite.ID {
case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384:
return newAESHeaderProtector(suite, trafficSecret, isLongHeader, hkdfLabel)
case tls.TLS_CHACHA20_POLY1305_SHA256:
return newChaChaHeaderProtector(suite, trafficSecret, isLongHeader, hkdfLabel)
default:
panic(fmt.Sprintf("Invalid cipher suite id: %d", suite.ID))
}
}
type aesHeaderProtector struct {
mask []byte
block cipher.Block
isLongHeader bool
}
var _ headerProtector = &aesHeaderProtector{}
func newAESHeaderProtector(suite *cipherSuite, trafficSecret []byte, isLongHeader bool, hkdfLabel string) headerProtector {
hpKey := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, hkdfLabel, suite.KeyLen)
block, err := aes.NewCipher(hpKey)
if err != nil {
panic(fmt.Sprintf("error creating new AES cipher: %s", err))
}
return &aesHeaderProtector{
block: block,
mask: make([]byte, block.BlockSize()),
isLongHeader: isLongHeader,
}
}
func (p *aesHeaderProtector) DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
p.apply(sample, firstByte, hdrBytes)
}
func (p *aesHeaderProtector) EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
p.apply(sample, firstByte, hdrBytes)
}
func (p *aesHeaderProtector) apply(sample []byte, firstByte *byte, hdrBytes []byte) {
if len(sample) != len(p.mask) {
panic("invalid sample size")
}
p.block.Encrypt(p.mask, sample)
if p.isLongHeader {
*firstByte ^= p.mask[0] & 0xf
} else {
*firstByte ^= p.mask[0] & 0x1f
}
for i := range hdrBytes {
hdrBytes[i] ^= p.mask[i+1]
}
}
type chachaHeaderProtector struct {
mask [5]byte
key [32]byte
isLongHeader bool
}
var _ headerProtector = &chachaHeaderProtector{}
func newChaChaHeaderProtector(suite *cipherSuite, trafficSecret []byte, isLongHeader bool, hkdfLabel string) headerProtector {
hpKey := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, hkdfLabel, suite.KeyLen)
p := &chachaHeaderProtector{
isLongHeader: isLongHeader,
}
copy(p.key[:], hpKey)
return p
}
func (p *chachaHeaderProtector) DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
p.apply(sample, firstByte, hdrBytes)
}
func (p *chachaHeaderProtector) EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
p.apply(sample, firstByte, hdrBytes)
}
func (p *chachaHeaderProtector) apply(sample []byte, firstByte *byte, hdrBytes []byte) {
if len(sample) != 16 {
panic("invalid sample size")
}
for i := 0; i < 5; i++ {
p.mask[i] = 0
}
cipher, err := chacha20.NewUnauthenticatedCipher(p.key[:], sample[4:])
if err != nil {
panic(err)
}
cipher.SetCounter(binary.LittleEndian.Uint32(sample[:4]))
cipher.XORKeyStream(p.mask[:], p.mask[:])
p.applyMask(firstByte, hdrBytes)
}
func (p *chachaHeaderProtector) applyMask(firstByte *byte, hdrBytes []byte) {
if p.isLongHeader {
*firstByte ^= p.mask[0] & 0xf
} else {
*firstByte ^= p.mask[0] & 0x1f
}
for i := range hdrBytes {
hdrBytes[i] ^= p.mask[i+1]
}
}

View File

@@ -0,0 +1,29 @@
package handshake
import (
"crypto"
"encoding/binary"
"golang.org/x/crypto/hkdf"
)
// hkdfExpandLabel HKDF expands a label.
// Since this implementation avoids using a cryptobyte.Builder, it is about 15% faster than the
// hkdfExpandLabel in the standard library.
func hkdfExpandLabel(hash crypto.Hash, secret, context []byte, label string, length int) []byte {
b := make([]byte, 3, 3+6+len(label)+1+len(context))
binary.BigEndian.PutUint16(b, uint16(length))
b[2] = uint8(6 + len(label))
b = append(b, []byte("tls13 ")...)
b = append(b, []byte(label)...)
b = b[:3+6+len(label)+1]
b[3+6+len(label)] = uint8(len(context))
b = append(b, context...)
out := make([]byte, length)
n, err := hkdf.Expand(hash.New, secret, b).Read(out)
if err != nil || n != length {
panic("quic: HKDF-Expand-Label invocation failed unexpectedly")
}
return out
}

Some files were not shown because too many files have changed in this diff Show More