直播:后台 JWT 推流、前台画中画;WebRTC 服务与 Nginx WebSocket 代理

Made-with: Cursor
This commit is contained in:
whm
2026-03-25 15:00:14 +08:00
parent b83ec91b1a
commit 7811adca66
1050 changed files with 146524 additions and 37 deletions

View File

@@ -0,0 +1,286 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
// Package allocation contains all CRUD operations for allocations
package allocation
import (
"net"
"sync"
"sync/atomic"
"time"
"github.com/pion/logging"
"github.com/pion/stun"
"github.com/pion/turn/v2/internal/ipnet"
"github.com/pion/turn/v2/internal/proto"
)
type allocationResponse struct {
transactionID [stun.TransactionIDSize]byte
responseAttrs []stun.Setter
}
// Allocation is tied to a FiveTuple and relays traffic
// use CreateAllocation and GetAllocation to operate
type Allocation struct {
RelayAddr net.Addr
Protocol Protocol
TurnSocket net.PacketConn
RelaySocket net.PacketConn
fiveTuple *FiveTuple
permissionsLock sync.RWMutex
permissions map[string]*Permission
channelBindingsLock sync.RWMutex
channelBindings []*ChannelBind
lifetimeTimer *time.Timer
closed chan interface{}
log logging.LeveledLogger
// Some clients (Firefox or others using resiprocate's nICE lib) may retry allocation
// with same 5 tuple when received 413, for compatible with these clients,
// cache for response lost and client retry to implement 'stateless stack approach'
// See: https://datatracker.ietf.org/doc/html/rfc5766#section-6.2
responseCache atomic.Value // *allocationResponse
}
// NewAllocation creates a new instance of NewAllocation.
func NewAllocation(turnSocket net.PacketConn, fiveTuple *FiveTuple, log logging.LeveledLogger) *Allocation {
return &Allocation{
TurnSocket: turnSocket,
fiveTuple: fiveTuple,
permissions: make(map[string]*Permission, 64),
closed: make(chan interface{}),
log: log,
}
}
// GetPermission gets the Permission from the allocation
func (a *Allocation) GetPermission(addr net.Addr) *Permission {
a.permissionsLock.RLock()
defer a.permissionsLock.RUnlock()
return a.permissions[ipnet.FingerprintAddr(addr)]
}
// AddPermission adds a new permission to the allocation
func (a *Allocation) AddPermission(p *Permission) {
fingerprint := ipnet.FingerprintAddr(p.Addr)
a.permissionsLock.RLock()
existedPermission, ok := a.permissions[fingerprint]
a.permissionsLock.RUnlock()
if ok {
existedPermission.refresh(permissionTimeout)
return
}
p.allocation = a
a.permissionsLock.Lock()
a.permissions[fingerprint] = p
a.permissionsLock.Unlock()
p.start(permissionTimeout)
}
// RemovePermission removes the net.Addr's fingerprint from the allocation's permissions
func (a *Allocation) RemovePermission(addr net.Addr) {
a.permissionsLock.Lock()
defer a.permissionsLock.Unlock()
delete(a.permissions, ipnet.FingerprintAddr(addr))
}
// AddChannelBind adds a new ChannelBind to the allocation, it also updates the
// permissions needed for this ChannelBind
func (a *Allocation) AddChannelBind(c *ChannelBind, lifetime time.Duration) error {
// Check that this channel id isn't bound to another transport address, and
// that this transport address isn't bound to another channel number.
channelByNumber := a.GetChannelByNumber(c.Number)
if channelByNumber != a.GetChannelByAddr(c.Peer) {
return errSameChannelDifferentPeer
}
// Add or refresh this channel.
if channelByNumber == nil {
a.channelBindingsLock.Lock()
defer a.channelBindingsLock.Unlock()
c.allocation = a
a.channelBindings = append(a.channelBindings, c)
c.start(lifetime)
// Channel binds also refresh permissions.
a.AddPermission(NewPermission(c.Peer, a.log))
} else {
channelByNumber.refresh(lifetime)
// Channel binds also refresh permissions.
a.AddPermission(NewPermission(channelByNumber.Peer, a.log))
}
return nil
}
// RemoveChannelBind removes the ChannelBind from this allocation by id
func (a *Allocation) RemoveChannelBind(number proto.ChannelNumber) bool {
a.channelBindingsLock.Lock()
defer a.channelBindingsLock.Unlock()
for i := len(a.channelBindings) - 1; i >= 0; i-- {
if a.channelBindings[i].Number == number {
a.channelBindings = append(a.channelBindings[:i], a.channelBindings[i+1:]...)
return true
}
}
return false
}
// GetChannelByNumber gets the ChannelBind from this allocation by id
func (a *Allocation) GetChannelByNumber(number proto.ChannelNumber) *ChannelBind {
a.channelBindingsLock.RLock()
defer a.channelBindingsLock.RUnlock()
for _, cb := range a.channelBindings {
if cb.Number == number {
return cb
}
}
return nil
}
// GetChannelByAddr gets the ChannelBind from this allocation by net.Addr
func (a *Allocation) GetChannelByAddr(addr net.Addr) *ChannelBind {
a.channelBindingsLock.RLock()
defer a.channelBindingsLock.RUnlock()
for _, cb := range a.channelBindings {
if ipnet.AddrEqual(cb.Peer, addr) {
return cb
}
}
return nil
}
// Refresh updates the allocations lifetime
func (a *Allocation) Refresh(lifetime time.Duration) {
if !a.lifetimeTimer.Reset(lifetime) {
a.log.Errorf("Failed to reset allocation timer for %v", a.fiveTuple)
}
}
// SetResponseCache cache allocation response for retransmit allocation request
func (a *Allocation) SetResponseCache(transactionID [stun.TransactionIDSize]byte, attrs []stun.Setter) {
a.responseCache.Store(&allocationResponse{
transactionID: transactionID,
responseAttrs: attrs,
})
}
// GetResponseCache return response cache for retransmit allocation request
func (a *Allocation) GetResponseCache() (id [stun.TransactionIDSize]byte, attrs []stun.Setter) {
if res, ok := a.responseCache.Load().(*allocationResponse); ok && res != nil {
id, attrs = res.transactionID, res.responseAttrs
}
return
}
// Close closes the allocation
func (a *Allocation) Close() error {
select {
case <-a.closed:
return nil
default:
}
close(a.closed)
a.lifetimeTimer.Stop()
a.permissionsLock.RLock()
for _, p := range a.permissions {
p.lifetimeTimer.Stop()
}
a.permissionsLock.RUnlock()
a.channelBindingsLock.RLock()
for _, c := range a.channelBindings {
c.lifetimeTimer.Stop()
}
a.channelBindingsLock.RUnlock()
return a.RelaySocket.Close()
}
// https://tools.ietf.org/html/rfc5766#section-10.3
// When the server receives a UDP datagram at a currently allocated
// relayed transport address, the server looks up the allocation
// associated with the relayed transport address. The server then
// checks to see whether the set of permissions for the allocation allow
// the relaying of the UDP datagram as described in Section 8.
//
// If relaying is permitted, then the server checks if there is a
// channel bound to the peer that sent the UDP datagram (see
// Section 11). If a channel is bound, then processing proceeds as
// described in Section 11.7.
//
// If relaying is permitted but no channel is bound to the peer, then
// the server forms and sends a Data indication. The Data indication
// MUST contain both an XOR-PEER-ADDRESS and a DATA attribute. The DATA
// attribute is set to the value of the 'data octets' field from the
// datagram, and the XOR-PEER-ADDRESS attribute is set to the source
// transport address of the received UDP datagram. The Data indication
// is then sent on the 5-tuple associated with the allocation.
const rtpMTU = 1600
func (a *Allocation) packetHandler(m *Manager) {
buffer := make([]byte, rtpMTU)
for {
n, srcAddr, err := a.RelaySocket.ReadFrom(buffer)
if err != nil {
m.DeleteAllocation(a.fiveTuple)
return
}
a.log.Debugf("Relay socket %s received %d bytes from %s",
a.RelaySocket.LocalAddr(),
n,
srcAddr)
if channel := a.GetChannelByAddr(srcAddr); channel != nil {
channelData := &proto.ChannelData{
Data: buffer[:n],
Number: channel.Number,
}
channelData.Encode()
if _, err = a.TurnSocket.WriteTo(channelData.Raw, a.fiveTuple.SrcAddr); err != nil {
a.log.Errorf("Failed to send ChannelData from allocation %v %v", srcAddr, err)
}
} else if p := a.GetPermission(srcAddr); p != nil {
udpAddr, ok := srcAddr.(*net.UDPAddr)
if !ok {
a.log.Errorf("Failed to send DataIndication from allocation %v %v", srcAddr, err)
return
}
peerAddressAttr := proto.PeerAddress{IP: udpAddr.IP, Port: udpAddr.Port}
dataAttr := proto.Data(buffer[:n])
msg, err := stun.Build(stun.TransactionID, stun.NewType(stun.MethodData, stun.ClassIndication), peerAddressAttr, dataAttr)
if err != nil {
a.log.Errorf("Failed to send DataIndication from allocation %v %v", srcAddr, err)
return
}
a.log.Debugf("Relaying message from %s to client at %s",
srcAddr,
a.fiveTuple.SrcAddr)
if _, err = a.TurnSocket.WriteTo(msg.Raw, a.fiveTuple.SrcAddr); err != nil {
a.log.Errorf("Failed to send DataIndication from allocation %v %v", srcAddr, err)
}
} else {
a.log.Infof("No Permission or Channel exists for %v on allocation %v", srcAddr, a.RelayAddr)
}
}
}

View File

@@ -0,0 +1,218 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
package allocation
import (
"fmt"
"net"
"sync"
"time"
"github.com/pion/logging"
)
// ManagerConfig a bag of config params for Manager.
type ManagerConfig struct {
LeveledLogger logging.LeveledLogger
AllocatePacketConn func(network string, requestedPort int) (net.PacketConn, net.Addr, error)
AllocateConn func(network string, requestedPort int) (net.Conn, net.Addr, error)
PermissionHandler func(sourceAddr net.Addr, peerIP net.IP) bool
}
type reservation struct {
token string
port int
}
// Manager is used to hold active allocations
type Manager struct {
lock sync.RWMutex
log logging.LeveledLogger
allocations map[FiveTupleFingerprint]*Allocation
reservations []*reservation
allocatePacketConn func(network string, requestedPort int) (net.PacketConn, net.Addr, error)
allocateConn func(network string, requestedPort int) (net.Conn, net.Addr, error)
permissionHandler func(sourceAddr net.Addr, peerIP net.IP) bool
}
// NewManager creates a new instance of Manager.
func NewManager(config ManagerConfig) (*Manager, error) {
switch {
case config.AllocatePacketConn == nil:
return nil, errAllocatePacketConnMustBeSet
case config.AllocateConn == nil:
return nil, errAllocateConnMustBeSet
case config.LeveledLogger == nil:
return nil, errLeveledLoggerMustBeSet
}
return &Manager{
log: config.LeveledLogger,
allocations: make(map[FiveTupleFingerprint]*Allocation, 64),
allocatePacketConn: config.AllocatePacketConn,
allocateConn: config.AllocateConn,
permissionHandler: config.PermissionHandler,
}, nil
}
// GetAllocation fetches the allocation matching the passed FiveTuple
func (m *Manager) GetAllocation(fiveTuple *FiveTuple) *Allocation {
m.lock.RLock()
defer m.lock.RUnlock()
return m.allocations[fiveTuple.Fingerprint()]
}
// AllocationCount returns the number of existing allocations
func (m *Manager) AllocationCount() int {
m.lock.RLock()
defer m.lock.RUnlock()
return len(m.allocations)
}
// Close closes the manager and closes all allocations it manages
func (m *Manager) Close() error {
m.lock.Lock()
defer m.lock.Unlock()
for _, a := range m.allocations {
if err := a.Close(); err != nil {
return err
}
}
return nil
}
// CreateAllocation creates a new allocation and starts relaying
func (m *Manager) CreateAllocation(fiveTuple *FiveTuple, turnSocket net.PacketConn, requestedPort int, lifetime time.Duration) (*Allocation, error) {
switch {
case fiveTuple == nil:
return nil, errNilFiveTuple
case fiveTuple.SrcAddr == nil:
return nil, errNilFiveTupleSrcAddr
case fiveTuple.DstAddr == nil:
return nil, errNilFiveTupleDstAddr
case turnSocket == nil:
return nil, errNilTurnSocket
case lifetime == 0:
return nil, errLifetimeZero
}
if a := m.GetAllocation(fiveTuple); a != nil {
return nil, fmt.Errorf("%w: %v", errDupeFiveTuple, fiveTuple)
}
a := NewAllocation(turnSocket, fiveTuple, m.log)
conn, relayAddr, err := m.allocatePacketConn("udp4", requestedPort)
if err != nil {
return nil, err
}
a.RelaySocket = conn
a.RelayAddr = relayAddr
m.log.Debugf("Listening on relay address: %s", a.RelayAddr)
a.lifetimeTimer = time.AfterFunc(lifetime, func() {
m.DeleteAllocation(a.fiveTuple)
})
m.lock.Lock()
m.allocations[fiveTuple.Fingerprint()] = a
m.lock.Unlock()
go a.packetHandler(m)
return a, nil
}
// DeleteAllocation removes an allocation
func (m *Manager) DeleteAllocation(fiveTuple *FiveTuple) {
fingerprint := fiveTuple.Fingerprint()
m.lock.Lock()
allocation := m.allocations[fingerprint]
delete(m.allocations, fingerprint)
m.lock.Unlock()
if allocation == nil {
return
}
if err := allocation.Close(); err != nil {
m.log.Errorf("Failed to close allocation: %v", err)
}
}
// CreateReservation stores the reservation for the token+port
func (m *Manager) CreateReservation(reservationToken string, port int) {
time.AfterFunc(30*time.Second, func() {
m.lock.Lock()
defer m.lock.Unlock()
for i := len(m.reservations) - 1; i >= 0; i-- {
if m.reservations[i].token == reservationToken {
m.reservations = append(m.reservations[:i], m.reservations[i+1:]...)
return
}
}
})
m.lock.Lock()
m.reservations = append(m.reservations, &reservation{
token: reservationToken,
port: port,
})
m.lock.Unlock()
}
// GetReservation returns the port for a given reservation if it exists
func (m *Manager) GetReservation(reservationToken string) (int, bool) {
m.lock.RLock()
defer m.lock.RUnlock()
for _, r := range m.reservations {
if r.token == reservationToken {
return r.port, true
}
}
return 0, false
}
// GetRandomEvenPort returns a random un-allocated udp4 port
func (m *Manager) GetRandomEvenPort() (int, error) {
for i := 0; i < 128; i++ {
conn, addr, err := m.allocatePacketConn("udp4", 0)
if err != nil {
return 0, err
}
udpAddr, ok := addr.(*net.UDPAddr)
err = conn.Close()
if err != nil {
return 0, err
}
if !ok {
return 0, errFailedToCastUDPAddr
}
if udpAddr.Port%2 == 0 {
return udpAddr.Port, nil
}
}
return 0, errFailedToAllocateEvenPort
}
// GrantPermission handles permission requests by calling the permission handler callback
// associated with the TURN server listener socket
func (m *Manager) GrantPermission(sourceAddr net.Addr, peerIP net.IP) error {
// No permission handler: open
if m.permissionHandler == nil {
return nil
}
if m.permissionHandler(sourceAddr, peerIP) {
return nil
}
return errAdminProhibited
}

View File

@@ -0,0 +1,46 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
package allocation
import (
"net"
"time"
"github.com/pion/logging"
"github.com/pion/turn/v2/internal/proto"
)
// ChannelBind represents a TURN Channel
// See: https://tools.ietf.org/html/rfc5766#section-2.5
type ChannelBind struct {
Peer net.Addr
Number proto.ChannelNumber
allocation *Allocation
lifetimeTimer *time.Timer
log logging.LeveledLogger
}
// NewChannelBind creates a new ChannelBind
func NewChannelBind(number proto.ChannelNumber, peer net.Addr, log logging.LeveledLogger) *ChannelBind {
return &ChannelBind{
Number: number,
Peer: peer,
log: log,
}
}
func (c *ChannelBind) start(lifetime time.Duration) {
c.lifetimeTimer = time.AfterFunc(lifetime, func() {
if !c.allocation.RemoveChannelBind(c.Number) {
c.log.Errorf("Failed to remove ChannelBind for %v %x %v", c.Number, c.Peer, c.allocation.fiveTuple)
}
})
}
func (c *ChannelBind) refresh(lifetime time.Duration) {
if !c.lifetimeTimer.Reset(lifetime) {
c.log.Errorf("Failed to reset ChannelBind timer for %v %x %v", c.Number, c.Peer, c.allocation.fiveTuple)
}
}

View File

@@ -0,0 +1,22 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
package allocation
import "errors"
var (
errAllocatePacketConnMustBeSet = errors.New("AllocatePacketConn must be set")
errAllocateConnMustBeSet = errors.New("AllocateConn must be set")
errLeveledLoggerMustBeSet = errors.New("LeveledLogger must be set")
errSameChannelDifferentPeer = errors.New("you cannot use the same channel number with different peer")
errNilFiveTuple = errors.New("allocations must not be created with nil FivTuple")
errNilFiveTupleSrcAddr = errors.New("allocations must not be created with nil FiveTuple.SrcAddr")
errNilFiveTupleDstAddr = errors.New("allocations must not be created with nil FiveTuple.DstAddr")
errNilTurnSocket = errors.New("allocations must not be created with nil turnSocket")
errLifetimeZero = errors.New("allocations must not be created with a lifetime of 0")
errDupeFiveTuple = errors.New("allocation attempt created with duplicate FiveTuple")
errFailedToCastUDPAddr = errors.New("failed to cast net.Addr to *net.UDPAddr")
errFailedToAllocateEvenPort = errors.New("failed to allocate an even port")
errAdminProhibited = errors.New("permission request administratively prohibited")
)

View File

@@ -0,0 +1,63 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
package allocation
import (
"net"
)
// Protocol is an enum for relay protocol
type Protocol uint8
// Network protocols for relay
const (
UDP Protocol = iota
TCP
)
// FiveTuple is the combination (client IP address and port, server IP
// address and port, and transport protocol (currently one of UDP,
// TCP, or TLS)) used to communicate between the client and the
// server. The 5-tuple uniquely identifies this communication
// stream. The 5-tuple also uniquely identifies the Allocation on
// the server.
type FiveTuple struct {
Protocol
SrcAddr, DstAddr net.Addr
}
// Equal asserts if two FiveTuples are equal
func (f *FiveTuple) Equal(b *FiveTuple) bool {
return f.Fingerprint() == b.Fingerprint()
}
// FiveTupleFingerprint is a comparable representation of a FiveTuple
type FiveTupleFingerprint struct {
srcIP, dstIP [16]byte
srcPort, dstPort uint16
protocol Protocol
}
// Fingerprint is the identity of a FiveTuple
func (f *FiveTuple) Fingerprint() (fp FiveTupleFingerprint) {
srcIP, srcPort := netAddrIPAndPort(f.SrcAddr)
copy(fp.srcIP[:], srcIP)
fp.srcPort = srcPort
dstIP, dstPort := netAddrIPAndPort(f.DstAddr)
copy(fp.dstIP[:], dstIP)
fp.dstPort = dstPort
fp.protocol = f.Protocol
return
}
func netAddrIPAndPort(addr net.Addr) (net.IP, uint16) {
switch a := addr.(type) {
case *net.UDPAddr:
return a.IP.To16(), uint16(a.Port)
case *net.TCPAddr:
return a.IP.To16(), uint16(a.Port)
default:
return nil, 0
}
}

View File

@@ -0,0 +1,43 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
package allocation
import (
"net"
"time"
"github.com/pion/logging"
)
const permissionTimeout = time.Duration(5) * time.Minute
// Permission represents a TURN permission. TURN permissions mimic the address-restricted
// filtering mechanism of NATs that comply with [RFC4787].
// See: https://tools.ietf.org/html/rfc5766#section-2.3
type Permission struct {
Addr net.Addr
allocation *Allocation
lifetimeTimer *time.Timer
log logging.LeveledLogger
}
// NewPermission create a new Permission
func NewPermission(addr net.Addr, log logging.LeveledLogger) *Permission {
return &Permission{
Addr: addr,
log: log,
}
}
func (p *Permission) start(lifetime time.Duration) {
p.lifetimeTimer = time.AfterFunc(lifetime, func() {
p.allocation.RemovePermission(p.Addr)
})
}
func (p *Permission) refresh(lifetime time.Duration) {
if !p.lifetimeTimer.Reset(lifetime) {
p.log.Errorf("Failed to reset permission timer for %v %v", p.Addr, p.allocation.fiveTuple)
}
}

View File

@@ -0,0 +1,189 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
package client
import (
"errors"
"fmt"
"net"
"sync"
"time"
"github.com/pion/logging"
"github.com/pion/stun"
"github.com/pion/transport/v2"
"github.com/pion/turn/v2/internal/proto"
)
// AllocationConfig is a set of configuration params use by NewUDPConn and NewTCPAllocation
type AllocationConfig struct {
Client Client
RelayedAddr net.Addr
ServerAddr net.Addr
Integrity stun.MessageIntegrity
Nonce stun.Nonce
Username stun.Username
Realm stun.Realm
Lifetime time.Duration
Net transport.Net
Log logging.LeveledLogger
}
type allocation struct {
client Client // Read-only
relayedAddr net.Addr // Read-only
serverAddr net.Addr // Read-only
permMap *permissionMap // Thread-safe
integrity stun.MessageIntegrity // Read-only
username stun.Username // Read-only
realm stun.Realm // Read-only
_nonce stun.Nonce // Needs mutex x
_lifetime time.Duration // Needs mutex x
net transport.Net // Thread-safe
refreshAllocTimer *PeriodicTimer // Thread-safe
refreshPermsTimer *PeriodicTimer // Thread-safe
readTimer *time.Timer // Thread-safe
mutex sync.RWMutex // Thread-safe
log logging.LeveledLogger // Read-only
}
func (a *allocation) setNonceFromMsg(msg *stun.Message) {
// Update nonce
var nonce stun.Nonce
if err := nonce.GetFrom(msg); err == nil {
a.setNonce(nonce)
a.log.Debug("Refresh allocation: 438, got new nonce.")
} else {
a.log.Warn("Refresh allocation: 438 but no nonce.")
}
}
func (a *allocation) refreshAllocation(lifetime time.Duration, dontWait bool) error {
msg, err := stun.Build(
stun.TransactionID,
stun.NewType(stun.MethodRefresh, stun.ClassRequest),
proto.Lifetime{Duration: lifetime},
a.username,
a.realm,
a.nonce(),
a.integrity,
stun.Fingerprint,
)
if err != nil {
return fmt.Errorf("%w: %s", errFailedToBuildRefreshRequest, err.Error())
}
a.log.Debugf("Send refresh request (dontWait=%v)", dontWait)
trRes, err := a.client.PerformTransaction(msg, a.serverAddr, dontWait)
if err != nil {
return fmt.Errorf("%w: %s", errFailedToRefreshAllocation, err.Error())
}
if dontWait {
a.log.Debug("Refresh request sent")
return nil
}
a.log.Debug("Refresh request sent, and waiting response")
res := trRes.Msg
if res.Type.Class == stun.ClassErrorResponse {
var code stun.ErrorCodeAttribute
if err = code.GetFrom(res); err == nil {
if code.Code == stun.CodeStaleNonce {
a.setNonceFromMsg(res)
return errTryAgain
}
return err
}
return fmt.Errorf("%s", res.Type) //nolint:goerr113
}
// Getting lifetime from response
var updatedLifetime proto.Lifetime
if err := updatedLifetime.GetFrom(res); err != nil {
return fmt.Errorf("%w: %s", errFailedToGetLifetime, err.Error())
}
a.setLifetime(updatedLifetime.Duration)
a.log.Debugf("Updated lifetime: %d seconds", int(a.lifetime().Seconds()))
return nil
}
func (a *allocation) refreshPermissions() error {
addrs := a.permMap.addrs()
if len(addrs) == 0 {
a.log.Debug("No permission to refresh")
return nil
}
if err := a.CreatePermissions(addrs...); err != nil {
if errors.Is(err, errTryAgain) {
return errTryAgain
}
a.log.Errorf("Fail to refresh permissions: %s", err)
return err
}
a.log.Debug("Refresh permissions successful")
return nil
}
func (a *allocation) onRefreshTimers(id int) {
a.log.Debugf("Refresh timer %d expired", id)
switch id {
case timerIDRefreshAlloc:
var err error
lifetime := a.lifetime()
// Limit the max retries on errTryAgain to 3
// when stale nonce returns, sencond retry should succeed
for i := 0; i < maxRetryAttempts; i++ {
err = a.refreshAllocation(lifetime, false)
if !errors.Is(err, errTryAgain) {
break
}
}
if err != nil {
a.log.Warnf("Failed to refresh allocation: %s", err)
}
case timerIDRefreshPerms:
var err error
for i := 0; i < maxRetryAttempts; i++ {
err = a.refreshPermissions()
if !errors.Is(err, errTryAgain) {
break
}
}
if err != nil {
a.log.Warnf("Failed to refresh permissions: %s", err)
}
}
}
func (a *allocation) nonce() stun.Nonce {
a.mutex.RLock()
defer a.mutex.RUnlock()
return a._nonce
}
func (a *allocation) setNonce(nonce stun.Nonce) {
a.mutex.Lock()
defer a.mutex.Unlock()
a.log.Debugf("Set new nonce with %d bytes", len(nonce))
a._nonce = nonce
}
func (a *allocation) lifetime() time.Duration {
a.mutex.RLock()
defer a.mutex.RUnlock()
return a._lifetime
}
func (a *allocation) setLifetime(lifetime time.Duration) {
a.mutex.Lock()
defer a.mutex.Unlock()
a._lifetime = lifetime
}

View File

@@ -0,0 +1,155 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
package client
import (
"net"
"sync"
"sync/atomic"
"time"
)
// Channel number:
//
// 0x4000 through 0x7FFF: These values are the allowed channel
// numbers (16,383 possible values).
const (
minChannelNumber uint16 = 0x4000
maxChannelNumber uint16 = 0x7fff
)
type bindingState int32
const (
bindingStateIdle bindingState = iota
bindingStateRequest
bindingStateReady
bindingStateRefresh
bindingStateFailed
)
type binding struct {
number uint16 // Read-only
st bindingState // Thread-safe (atomic op)
addr net.Addr // Read-only
mgr *bindingManager // Read-only
muBind sync.Mutex // Thread-safe, for ChannelBind ops
_refreshedAt time.Time // Protected by mutex
mutex sync.RWMutex // Thread-safe
}
func (b *binding) setState(state bindingState) {
atomic.StoreInt32((*int32)(&b.st), int32(state))
}
func (b *binding) state() bindingState {
return bindingState(atomic.LoadInt32((*int32)(&b.st)))
}
func (b *binding) setRefreshedAt(at time.Time) {
b.mutex.Lock()
defer b.mutex.Unlock()
b._refreshedAt = at
}
func (b *binding) refreshedAt() time.Time {
b.mutex.RLock()
defer b.mutex.RUnlock()
return b._refreshedAt
}
// Thread-safe binding map
type bindingManager struct {
chanMap map[uint16]*binding
addrMap map[string]*binding
next uint16
mutex sync.RWMutex
}
func newBindingManager() *bindingManager {
return &bindingManager{
chanMap: map[uint16]*binding{},
addrMap: map[string]*binding{},
next: minChannelNumber,
}
}
func (mgr *bindingManager) assignChannelNumber() uint16 {
n := mgr.next
if mgr.next == maxChannelNumber {
mgr.next = minChannelNumber
} else {
mgr.next++
}
return n
}
func (mgr *bindingManager) create(addr net.Addr) *binding {
mgr.mutex.Lock()
defer mgr.mutex.Unlock()
b := &binding{
number: mgr.assignChannelNumber(),
addr: addr,
mgr: mgr,
_refreshedAt: time.Now(),
}
mgr.chanMap[b.number] = b
mgr.addrMap[b.addr.String()] = b
return b
}
func (mgr *bindingManager) findByAddr(addr net.Addr) (*binding, bool) {
mgr.mutex.RLock()
defer mgr.mutex.RUnlock()
b, ok := mgr.addrMap[addr.String()]
return b, ok
}
func (mgr *bindingManager) findByNumber(number uint16) (*binding, bool) {
mgr.mutex.RLock()
defer mgr.mutex.RUnlock()
b, ok := mgr.chanMap[number]
return b, ok
}
func (mgr *bindingManager) deleteByAddr(addr net.Addr) bool {
mgr.mutex.Lock()
defer mgr.mutex.Unlock()
b, ok := mgr.addrMap[addr.String()]
if !ok {
return false
}
delete(mgr.addrMap, addr.String())
delete(mgr.chanMap, b.number)
return true
}
func (mgr *bindingManager) deleteByNumber(number uint16) bool {
mgr.mutex.Lock()
defer mgr.mutex.Unlock()
b, ok := mgr.chanMap[number]
if !ok {
return false
}
delete(mgr.addrMap, b.addr.String())
delete(mgr.chanMap, number)
return true
}
func (mgr *bindingManager) size() int {
mgr.mutex.RLock()
defer mgr.mutex.RUnlock()
return len(mgr.chanMap)
}

View File

@@ -0,0 +1,18 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
// Package client implements the API for a TURN client
package client
import (
"net"
"github.com/pion/stun"
)
// Client is an interface for the public turn.Client in order to break cyclic dependencies
type Client interface {
WriteTo(data []byte, to net.Addr) (int, error)
PerformTransaction(msg *stun.Message, to net.Addr, dontWait bool) (TransactionResult, error)
OnDeallocated(relayedAddr net.Addr)
}

View File

@@ -0,0 +1,43 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
package client
import (
"errors"
)
var (
errFake = errors.New("fake error")
errTryAgain = errors.New("try again")
errClosed = errors.New("use of closed network connection")
errTCPAddrCast = errors.New("addr is not a TCP address")
errUDPAddrCast = errors.New("addr is not a UDP address")
errAlreadyClosed = errors.New("already closed")
errDoubleLock = errors.New("try-lock is already locked")
errTransactionClosed = errors.New("transaction closed")
errWaitForResultOnNonResultTransaction = errors.New("WaitForResult called on non-result transaction")
errFailedToBuildRefreshRequest = errors.New("failed to build refresh request")
errFailedToRefreshAllocation = errors.New("failed to refresh allocation")
errFailedToGetLifetime = errors.New("failed to get lifetime from refresh response")
errInvalidTURNAddress = errors.New("invalid TURN server address")
errUnexpectedSTUNRequestMessage = errors.New("unexpected STUN request message")
)
type timeoutError struct {
msg string
}
func newTimeoutError(msg string) error {
return &timeoutError{
msg: msg,
}
}
func (e *timeoutError) Error() string {
return e.msg
}
func (e *timeoutError) Timeout() bool {
return true
}

View File

@@ -0,0 +1,85 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
package client
import (
"sync"
"time"
)
// PeriodicTimerTimeoutHandler is a handler called on timeout
type PeriodicTimerTimeoutHandler func(timerID int)
// PeriodicTimer is a periodic timer
type PeriodicTimer struct {
id int
interval time.Duration
timeoutHandler PeriodicTimerTimeoutHandler
stopFunc func()
mutex sync.RWMutex
}
// NewPeriodicTimer create a new timer
func NewPeriodicTimer(id int, timeoutHandler PeriodicTimerTimeoutHandler, interval time.Duration) *PeriodicTimer {
return &PeriodicTimer{
id: id,
interval: interval,
timeoutHandler: timeoutHandler,
}
}
// Start starts the timer.
func (t *PeriodicTimer) Start() bool {
t.mutex.Lock()
defer t.mutex.Unlock()
// This is a noop if the timer is always running
if t.stopFunc != nil {
return false
}
cancelCh := make(chan struct{})
go func() {
canceling := false
for !canceling {
timer := time.NewTimer(t.interval)
select {
case <-timer.C:
t.timeoutHandler(t.id)
case <-cancelCh:
canceling = true
timer.Stop()
}
}
}()
t.stopFunc = func() {
close(cancelCh)
}
return true
}
// Stop stops the timer.
func (t *PeriodicTimer) Stop() {
t.mutex.Lock()
defer t.mutex.Unlock()
if t.stopFunc != nil {
t.stopFunc()
t.stopFunc = nil
}
}
// IsRunning tests if the timer is running.
// Debug purpose only
func (t *PeriodicTimer) IsRunning() bool {
t.mutex.RLock()
defer t.mutex.RUnlock()
return (t.stopFunc != nil)
}

View File

@@ -0,0 +1,77 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
package client
import (
"net"
"sync"
"sync/atomic"
"github.com/pion/turn/v2/internal/ipnet"
)
type permState int32
const (
permStateIdle permState = iota
permStatePermitted
)
type permission struct {
addr net.Addr
st permState // Thread-safe (atomic op)
mutex sync.RWMutex // Thread-safe
}
func (p *permission) setState(state permState) {
atomic.StoreInt32((*int32)(&p.st), int32(state))
}
func (p *permission) state() permState {
return permState(atomic.LoadInt32((*int32)(&p.st)))
}
// Thread-safe permission map
type permissionMap struct {
permMap map[string]*permission
mutex sync.RWMutex
}
func (m *permissionMap) insert(addr net.Addr, p *permission) bool {
m.mutex.Lock()
defer m.mutex.Unlock()
p.addr = addr
m.permMap[ipnet.FingerprintAddr(addr)] = p
return true
}
func (m *permissionMap) find(addr net.Addr) (*permission, bool) {
m.mutex.RLock()
defer m.mutex.RUnlock()
p, ok := m.permMap[ipnet.FingerprintAddr(addr)]
return p, ok
}
func (m *permissionMap) delete(addr net.Addr) {
m.mutex.Lock()
defer m.mutex.Unlock()
delete(m.permMap, ipnet.FingerprintAddr(addr))
}
func (m *permissionMap) addrs() []net.Addr {
m.mutex.RLock()
defer m.mutex.RUnlock()
addrs := []net.Addr{}
for _, p := range m.permMap {
addrs = append(addrs, p.addr)
}
return addrs
}
func newPermissionMap() *permissionMap {
return &permissionMap{
permMap: map[string]*permission{},
}
}

View File

@@ -0,0 +1,371 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
package client
import (
"encoding/binary"
"errors"
"fmt"
"math"
"net"
"time"
"github.com/pion/stun"
"github.com/pion/transport/v2"
"github.com/pion/turn/v2/internal/proto"
)
var (
_ transport.TCPListener = (*TCPAllocation)(nil) // Includes type check for net.Listener
_ transport.Dialer = (*TCPAllocation)(nil)
)
func noDeadline() time.Time {
return time.Time{}
}
// TCPAllocation is an active TCP allocation on the TURN server
// as specified by RFC 6062.
// The allocation can be used to Dial/Accept relayed outgoing/incoming TCP connections.
type TCPAllocation struct {
connAttemptCh chan *connectionAttempt
acceptTimer *time.Timer
allocation
}
// NewTCPAllocation creates a new instance of TCPConn
func NewTCPAllocation(config *AllocationConfig) *TCPAllocation {
a := &TCPAllocation{
connAttemptCh: make(chan *connectionAttempt, 10),
acceptTimer: time.NewTimer(time.Duration(math.MaxInt64)),
allocation: allocation{
client: config.Client,
relayedAddr: config.RelayedAddr,
serverAddr: config.ServerAddr,
username: config.Username,
realm: config.Realm,
permMap: newPermissionMap(),
integrity: config.Integrity,
_nonce: config.Nonce,
_lifetime: config.Lifetime,
net: config.Net,
log: config.Log,
},
}
a.log.Debugf("Initial lifetime: %d seconds", int(a.lifetime().Seconds()))
a.refreshAllocTimer = NewPeriodicTimer(
timerIDRefreshAlloc,
a.onRefreshTimers,
a.lifetime()/2,
)
a.refreshPermsTimer = NewPeriodicTimer(
timerIDRefreshPerms,
a.onRefreshTimers,
permRefreshInterval,
)
if a.refreshAllocTimer.Start() {
a.log.Debug("Started refreshAllocTimer")
}
if a.refreshPermsTimer.Start() {
a.log.Debug("Started refreshPermsTimer")
}
return a
}
// Connect sends a Connect request to the turn server and returns a chosen connection ID
func (a *TCPAllocation) Connect(peer net.Addr) (proto.ConnectionID, error) {
setters := []stun.Setter{
stun.TransactionID,
stun.NewType(stun.MethodConnect, stun.ClassRequest),
addr2PeerAddress(peer),
a.username,
a.realm,
a.nonce(),
a.integrity,
stun.Fingerprint,
}
msg, err := stun.Build(setters...)
if err != nil {
return 0, err
}
a.log.Debugf("Send connect request (peer=%v)", peer)
trRes, err := a.client.PerformTransaction(msg, a.serverAddr, false)
if err != nil {
return 0, err
}
res := trRes.Msg
if res.Type.Class == stun.ClassErrorResponse {
var code stun.ErrorCodeAttribute
if err = code.GetFrom(res); err == nil {
return 0, fmt.Errorf("%s (error %s)", res.Type, code) //nolint:goerr113
}
return 0, fmt.Errorf("%s", res.Type) //nolint:goerr113
}
var cid proto.ConnectionID
if err := cid.GetFrom(res); err != nil {
return 0, err
}
a.log.Debugf("Connect request successful (cid=%v)", cid)
return cid, nil
}
// Dial connects to the address on the named network.
func (a *TCPAllocation) Dial(network, rAddrStr string) (net.Conn, error) {
rAddr, err := net.ResolveTCPAddr(network, rAddrStr)
if err != nil {
return nil, err
}
return a.DialTCP(network, nil, rAddr)
}
// DialWithConn connects to the address on the named network with an already existing connection.
// The provided connection must be an already connected TCP connection to the TURN server.
func (a *TCPAllocation) DialWithConn(conn net.Conn, network, rAddrStr string) (*TCPConn, error) {
rAddr, err := net.ResolveTCPAddr(network, rAddrStr)
if err != nil {
return nil, err
}
return a.DialTCPWithConn(conn, network, rAddr)
}
// DialTCP acts like Dial for TCP networks.
func (a *TCPAllocation) DialTCP(network string, lAddr, rAddr *net.TCPAddr) (*TCPConn, error) {
var rAddrServer *net.TCPAddr
if addr, ok := a.serverAddr.(*net.TCPAddr); ok {
rAddrServer = &net.TCPAddr{
IP: addr.IP,
Port: addr.Port,
}
} else {
return nil, errInvalidTURNAddress
}
conn, err := a.net.DialTCP(network, lAddr, rAddrServer)
if err != nil {
return nil, err
}
dataConn, err := a.DialTCPWithConn(conn, network, rAddr)
if err != nil {
conn.Close() //nolint:errcheck,gosec
}
return dataConn, err
}
// DialTCPWithConn acts like DialWithConn for TCP networks.
func (a *TCPAllocation) DialTCPWithConn(conn net.Conn, _ string, rAddr *net.TCPAddr) (*TCPConn, error) {
var err error
// Check if we have a permission for the destination IP addr
perm, ok := a.permMap.find(rAddr)
if !ok {
perm = &permission{}
a.permMap.insert(rAddr, perm)
}
for i := 0; i < maxRetryAttempts; i++ {
if err = a.createPermission(perm, rAddr); !errors.Is(err, errTryAgain) {
break
}
}
if err != nil {
return nil, err
}
// Send connect request if haven't done so.
cid, err := a.Connect(rAddr)
if err != nil {
return nil, err
}
tcpConn, ok := conn.(transport.TCPConn)
if !ok {
return nil, errTCPAddrCast
}
dataConn := &TCPConn{
TCPConn: tcpConn,
ConnectionID: cid,
remoteAddress: rAddr,
allocation: a,
}
if err := a.BindConnection(dataConn, cid); err != nil {
return nil, fmt.Errorf("failed to bind connection: %w", err)
}
return dataConn, nil
}
// BindConnection associates the provided connection
func (a *TCPAllocation) BindConnection(dataConn *TCPConn, cid proto.ConnectionID) error {
msg, err := stun.Build(
stun.TransactionID,
stun.NewType(stun.MethodConnectionBind, stun.ClassRequest),
cid,
a.username,
a.realm,
a.nonce(),
a.integrity,
stun.Fingerprint,
)
if err != nil {
return err
}
a.log.Debugf("Send connectionBind request (cid=%v)", cid)
_, err = dataConn.Write(msg.Raw)
if err != nil {
return err
}
// Read exactly one STUN message, any data after belongs to the user
b := make([]byte, stunHeaderSize)
n, err := dataConn.Read(b)
if n != stunHeaderSize {
return errIncompleteTURNFrame
} else if err != nil {
return err
}
if !stun.IsMessage(b) {
return errInvalidTURNFrame
}
datagramSize := binary.BigEndian.Uint16(b[2:4]) + stunHeaderSize
raw := make([]byte, datagramSize)
copy(raw, b)
_, err = dataConn.Read(raw[stunHeaderSize:])
if err != nil {
return err
}
res := &stun.Message{Raw: raw}
if err = res.Decode(); err != nil {
return fmt.Errorf("failed to decode STUN message: %w", err)
}
switch res.Type.Class {
case stun.ClassErrorResponse:
var code stun.ErrorCodeAttribute
if err = code.GetFrom(res); err == nil {
return fmt.Errorf("%s (error %s)", res.Type, code) //nolint:goerr113
}
return fmt.Errorf("%s", res.Type) //nolint:goerr113
case stun.ClassSuccessResponse:
a.log.Debug("Successful connectionBind request")
return nil
default:
return fmt.Errorf("%w: %s", errUnexpectedSTUNRequestMessage, res.String())
}
}
// Accept waits for and returns the next connection to the listener.
func (a *TCPAllocation) Accept() (net.Conn, error) {
return a.AcceptTCP()
}
// AcceptTCP accepts the next incoming call and returns the new connection.
func (a *TCPAllocation) AcceptTCP() (transport.TCPConn, error) {
addr, err := net.ResolveTCPAddr("tcp4", a.serverAddr.String())
if err != nil {
return nil, err
}
tcpConn, err := a.net.DialTCP("tcp", nil, addr)
if err != nil {
return nil, err
}
dataConn, err := a.AcceptTCPWithConn(tcpConn)
if err != nil {
tcpConn.Close() //nolint:errcheck,gosec
}
return dataConn, err
}
// AcceptTCPWithConn accepts the next incoming call and returns the new connection.
func (a *TCPAllocation) AcceptTCPWithConn(conn net.Conn) (*TCPConn, error) {
select {
case attempt := <-a.connAttemptCh:
tcpConn, ok := conn.(transport.TCPConn)
if !ok {
return nil, errTCPAddrCast
}
dataConn := &TCPConn{
TCPConn: tcpConn,
ConnectionID: attempt.cid,
remoteAddress: attempt.from,
allocation: a,
}
if err := a.BindConnection(dataConn, attempt.cid); err != nil {
return nil, fmt.Errorf("failed to bind connection: %w", err)
}
return dataConn, nil
case <-a.acceptTimer.C:
return nil, &net.OpError{
Op: "accept",
Net: a.Addr().Network(),
Addr: a.Addr(),
Err: newTimeoutError("i/o timeout"),
}
}
}
// SetDeadline sets the deadline associated with the listener. A zero time value disables the deadline.
func (a *TCPAllocation) SetDeadline(t time.Time) error {
var d time.Duration
if t == noDeadline() {
d = time.Duration(math.MaxInt64)
} else {
d = time.Until(t)
}
a.acceptTimer.Reset(d)
return nil
}
// Close releases the allocation
// Any blocked Accept operations will be unblocked and return errors.
// Any opened connection via Dial/Accept will be closed.
func (a *TCPAllocation) Close() error {
a.refreshAllocTimer.Stop()
a.refreshPermsTimer.Stop()
a.client.OnDeallocated(a.relayedAddr)
return a.refreshAllocation(0, true /* dontWait=true */)
}
// Addr returns the relayed address of the allocation
func (a *TCPAllocation) Addr() net.Addr {
return a.relayedAddr
}
// HandleConnectionAttempt is called by the TURN client
// when it receives a ConnectionAttempt indication.
func (a *TCPAllocation) HandleConnectionAttempt(from *net.TCPAddr, cid proto.ConnectionID) {
a.connAttemptCh <- &connectionAttempt{
from: from,
cid: cid,
}
}

View File

@@ -0,0 +1,49 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
package client
import (
"errors"
"net"
"github.com/pion/transport/v2"
"github.com/pion/turn/v2/internal/proto"
)
var (
errInvalidTURNFrame = errors.New("data is not a valid TURN frame, no STUN or ChannelData found")
errIncompleteTURNFrame = errors.New("data contains incomplete STUN or TURN frame")
)
const (
stunHeaderSize = 20
)
var _ transport.TCPConn = (*TCPConn)(nil) // Includes type check for net.Conn
// TCPConn wraps a transport.TCPConn and returns the allocations relayed
// transport address in response to TCPConn.LocalAddress()
type TCPConn struct {
transport.TCPConn
remoteAddress *net.TCPAddr
allocation *TCPAllocation
ConnectionID proto.ConnectionID
}
type connectionAttempt struct {
from *net.TCPAddr
cid proto.ConnectionID
}
// LocalAddr returns the local network address.
// The Addr returned is shared by all invocations of LocalAddr, so do not modify it.
func (c *TCPConn) LocalAddr() net.Addr {
return c.allocation.Addr()
}
// RemoteAddr returns the remote network address.
// The Addr returned is shared by all invocations of RemoteAddr, so do not modify it.
func (c *TCPConn) RemoteAddr() net.Addr {
return c.remoteAddress
}

View File

@@ -0,0 +1,188 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
package client
import (
"net"
"sync"
"time"
"github.com/pion/stun"
)
const (
maxRtxInterval time.Duration = 1600 * time.Millisecond
)
// TransactionResult is a bag of result values of a transaction
type TransactionResult struct {
Msg *stun.Message
From net.Addr
Retries int
Err error
}
// TransactionConfig is a set of config params used by NewTransaction
type TransactionConfig struct {
Key string
Raw []byte
To net.Addr
Interval time.Duration
IgnoreResult bool // True to throw away the result of this transaction (it will not be readable using WaitForResult)
}
// Transaction represents a transaction
type Transaction struct {
Key string // Read-only
Raw []byte // Read-only
To net.Addr // Read-only
nRtx int // Modified only by the timer thread
interval time.Duration // Modified only by the timer thread
timer *time.Timer // Thread-safe, set only by the creator, and stopper
resultCh chan TransactionResult // Thread-safe
mutex sync.RWMutex
}
// NewTransaction creates a new instance of Transaction
func NewTransaction(config *TransactionConfig) *Transaction {
var resultCh chan TransactionResult
if !config.IgnoreResult {
resultCh = make(chan TransactionResult)
}
return &Transaction{
Key: config.Key, // Read-only
Raw: config.Raw, // Read-only
To: config.To, // Read-only
interval: config.Interval, // Modified only by the timer thread
resultCh: resultCh, // Thread-safe
}
}
// StartRtxTimer starts the transaction timer
func (t *Transaction) StartRtxTimer(onTimeout func(trKey string, nRtx int)) {
t.mutex.Lock()
defer t.mutex.Unlock()
t.timer = time.AfterFunc(t.interval, func() {
t.mutex.Lock()
t.nRtx++
nRtx := t.nRtx
t.interval *= 2
if t.interval > maxRtxInterval {
t.interval = maxRtxInterval
}
t.mutex.Unlock()
onTimeout(t.Key, nRtx)
})
}
// StopRtxTimer stop the transaction timer
func (t *Transaction) StopRtxTimer() {
t.mutex.Lock()
defer t.mutex.Unlock()
if t.timer != nil {
t.timer.Stop()
}
}
// WriteResult writes the result to the result channel
func (t *Transaction) WriteResult(res TransactionResult) bool {
if t.resultCh == nil {
return false
}
t.resultCh <- res
return true
}
// WaitForResult waits for the transaction result
func (t *Transaction) WaitForResult() TransactionResult {
if t.resultCh == nil {
return TransactionResult{
Err: errWaitForResultOnNonResultTransaction,
}
}
result, ok := <-t.resultCh
if !ok {
result.Err = errTransactionClosed
}
return result
}
// Close closes the transaction
func (t *Transaction) Close() {
if t.resultCh != nil {
close(t.resultCh)
}
}
// Retries returns the number of retransmission it has made
func (t *Transaction) Retries() int {
t.mutex.RLock()
defer t.mutex.RUnlock()
return t.nRtx
}
// TransactionMap is a thread-safe transaction map
type TransactionMap struct {
trMap map[string]*Transaction
mutex sync.RWMutex
}
// NewTransactionMap create a new instance of the transaction map
func NewTransactionMap() *TransactionMap {
return &TransactionMap{
trMap: map[string]*Transaction{},
}
}
// Insert inserts a transaction to the map
func (m *TransactionMap) Insert(key string, tr *Transaction) bool {
m.mutex.Lock()
defer m.mutex.Unlock()
m.trMap[key] = tr
return true
}
// Find looks up a transaction by its key
func (m *TransactionMap) Find(key string) (*Transaction, bool) {
m.mutex.RLock()
defer m.mutex.RUnlock()
tr, ok := m.trMap[key]
return tr, ok
}
// Delete deletes a transaction by its key
func (m *TransactionMap) Delete(key string) {
m.mutex.Lock()
defer m.mutex.Unlock()
delete(m.trMap, key)
}
// CloseAndDeleteAll closes and deletes all transactions
func (m *TransactionMap) CloseAndDeleteAll() {
m.mutex.Lock()
defer m.mutex.Unlock()
for trKey, tr := range m.trMap {
tr.Close()
delete(m.trMap, trKey)
}
}
// Size returns the length of the transaction map
func (m *TransactionMap) Size() int {
m.mutex.RLock()
defer m.mutex.RUnlock()
return len(m.trMap)
}

View File

@@ -0,0 +1,27 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
package client
import (
"sync/atomic"
)
// TryLock implement the classic "try-lock" operation.
type TryLock struct {
n int32
}
// Lock tries to lock the try-lock. If successful, it returns true.
// Otherwise, it returns false immediately.
func (c *TryLock) Lock() error {
if !atomic.CompareAndSwapInt32(&c.n, 0, 1) {
return errDoubleLock
}
return nil
}
// Unlock unlocks the try-lock.
func (c *TryLock) Unlock() {
atomic.StoreInt32(&c.n, 0)
}

View File

@@ -0,0 +1,455 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
// Package client implements the API for a TURN client
package client
import (
"errors"
"fmt"
"io"
"math"
"net"
"time"
"github.com/pion/stun"
"github.com/pion/turn/v2/internal/proto"
)
const (
maxReadQueueSize = 1024
permRefreshInterval = 120 * time.Second
maxRetryAttempts = 3
)
const (
timerIDRefreshAlloc int = iota
timerIDRefreshPerms
)
type inboundData struct {
data []byte
from net.Addr
}
// UDPConn is the implementation of the Conn and PacketConn interfaces for UDP network connections.
// compatible with net.PacketConn and net.Conn
type UDPConn struct {
bindingMgr *bindingManager // Thread-safe
readCh chan *inboundData // Thread-safe
closeCh chan struct{} // Thread-safe
allocation
}
// NewUDPConn creates a new instance of UDPConn
func NewUDPConn(config *AllocationConfig) *UDPConn {
c := &UDPConn{
bindingMgr: newBindingManager(),
readCh: make(chan *inboundData, maxReadQueueSize),
closeCh: make(chan struct{}),
allocation: allocation{
client: config.Client,
relayedAddr: config.RelayedAddr,
serverAddr: config.ServerAddr,
readTimer: time.NewTimer(time.Duration(math.MaxInt64)),
permMap: newPermissionMap(),
username: config.Username,
realm: config.Realm,
integrity: config.Integrity,
_nonce: config.Nonce,
_lifetime: config.Lifetime,
net: config.Net,
log: config.Log,
},
}
c.log.Debugf("Initial lifetime: %d seconds", int(c.lifetime().Seconds()))
c.refreshAllocTimer = NewPeriodicTimer(
timerIDRefreshAlloc,
c.onRefreshTimers,
c.lifetime()/2,
)
c.refreshPermsTimer = NewPeriodicTimer(
timerIDRefreshPerms,
c.onRefreshTimers,
permRefreshInterval,
)
if c.refreshAllocTimer.Start() {
c.log.Debugf("Started refresh allocation timer")
}
if c.refreshPermsTimer.Start() {
c.log.Debugf("Started refresh permission timer")
}
return c
}
// ReadFrom reads a packet from the connection,
// copying the payload into p. It returns the number of
// bytes copied into p and the return address that
// was on the packet.
// It returns the number of bytes read (0 <= n <= len(p))
// and any error encountered. Callers should always process
// the n > 0 bytes returned before considering the error err.
// ReadFrom can be made to time out and return
// an Error with Timeout() == true after a fixed time limit;
// see SetDeadline and SetReadDeadline.
func (c *UDPConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
for {
select {
case ibData := <-c.readCh:
n := copy(p, ibData.data)
if n < len(ibData.data) {
return 0, nil, io.ErrShortBuffer
}
return n, ibData.from, nil
case <-c.readTimer.C:
return 0, nil, &net.OpError{
Op: "read",
Net: c.LocalAddr().Network(),
Addr: c.LocalAddr(),
Err: newTimeoutError("i/o timeout"),
}
case <-c.closeCh:
return 0, nil, &net.OpError{
Op: "read",
Net: c.LocalAddr().Network(),
Addr: c.LocalAddr(),
Err: errClosed,
}
}
}
}
func (a *allocation) createPermission(perm *permission, addr net.Addr) error {
perm.mutex.Lock()
defer perm.mutex.Unlock()
if perm.state() == permStateIdle {
// Punch a hole! (this would block a bit..)
if err := a.CreatePermissions(addr); err != nil {
a.permMap.delete(addr)
return err
}
perm.setState(permStatePermitted)
}
return nil
}
// WriteTo writes a packet with payload p to addr.
// WriteTo can be made to time out and return
// an Error with Timeout() == true after a fixed time limit;
// see SetDeadline and SetWriteDeadline.
// On packet-oriented connections, write timeouts are rare.
func (c *UDPConn) WriteTo(p []byte, addr net.Addr) (int, error) { //nolint: gocognit
var err error
_, ok := addr.(*net.UDPAddr)
if !ok {
return 0, errUDPAddrCast
}
// Check if we have a permission for the destination IP addr
perm, ok := c.permMap.find(addr)
if !ok {
perm = &permission{}
c.permMap.insert(addr, perm)
}
for i := 0; i < maxRetryAttempts; i++ {
// c.createPermission() would block, per destination IP (, or perm),
// until the perm state becomes "requested". Purpose of this is to
// guarantee the order of packets (within the same perm).
// Note that CreatePermission transaction may not be complete before
// all the data transmission. This is done assuming that the request
// will be most likely successful and we can tolerate some loss of
// UDP packet (or reorder), inorder to minimize the latency in most cases.
if err = c.createPermission(perm, addr); !errors.Is(err, errTryAgain) {
break
}
}
if err != nil {
return 0, err
}
// Bind channel
b, ok := c.bindingMgr.findByAddr(addr)
if !ok {
b = c.bindingMgr.create(addr)
}
bindSt := b.state()
if bindSt == bindingStateIdle || bindSt == bindingStateRequest || bindSt == bindingStateFailed {
func() {
// Block only callers with the same binding until
// the binding transaction has been complete
b.muBind.Lock()
defer b.muBind.Unlock()
// Binding state may have been changed while waiting. check again.
if b.state() == bindingStateIdle {
b.setState(bindingStateRequest)
go func() {
err2 := c.bind(b)
if err2 != nil {
c.log.Warnf("Failed to bind bind(): %s", err2)
b.setState(bindingStateFailed)
// Keep going...
} else {
b.setState(bindingStateReady)
}
}()
}
}()
// Send data using SendIndication
peerAddr := addr2PeerAddress(addr)
var msg *stun.Message
msg, err = stun.Build(
stun.TransactionID,
stun.NewType(stun.MethodSend, stun.ClassIndication),
proto.Data(p),
peerAddr,
stun.Fingerprint,
)
if err != nil {
return 0, err
}
// Indication has no transaction (fire-and-forget)
return c.client.WriteTo(msg.Raw, c.serverAddr)
}
// Binding is either ready
// Check if the binding needs a refresh
func() {
b.muBind.Lock()
defer b.muBind.Unlock()
if b.state() == bindingStateReady && time.Since(b.refreshedAt()) > 5*time.Minute {
b.setState(bindingStateRefresh)
go func() {
err = c.bind(b)
if err != nil {
c.log.Warnf("Failed to bind() for refresh: %s", err)
b.setState(bindingStateFailed)
// Keep going...
} else {
b.setRefreshedAt(time.Now())
b.setState(bindingStateReady)
}
}()
}
}()
// Send via ChannelData
_, err = c.sendChannelData(p, b.number)
if err != nil {
return 0, err
}
return len(p), nil
}
// Close closes the connection.
// Any blocked ReadFrom or WriteTo operations will be unblocked and return errors.
func (c *UDPConn) Close() error {
c.refreshAllocTimer.Stop()
c.refreshPermsTimer.Stop()
select {
case <-c.closeCh:
return errAlreadyClosed
default:
close(c.closeCh)
}
c.client.OnDeallocated(c.relayedAddr)
return c.refreshAllocation(0, true /* dontWait=true */)
}
// LocalAddr returns the local network address.
func (c *UDPConn) LocalAddr() net.Addr {
return c.relayedAddr
}
// SetDeadline sets the read and write deadlines associated
// with the connection. It is equivalent to calling both
// SetReadDeadline and SetWriteDeadline.
//
// A deadline is an absolute time after which I/O operations
// fail with a timeout (see type Error) instead of
// blocking. The deadline applies to all future and pending
// I/O, not just the immediately following call to ReadFrom or
// WriteTo. After a deadline has been exceeded, the connection
// can be refreshed by setting a deadline in the future.
//
// An idle timeout can be implemented by repeatedly extending
// the deadline after successful ReadFrom or WriteTo calls.
//
// A zero value for t means I/O operations will not time out.
func (c *UDPConn) SetDeadline(t time.Time) error {
return c.SetReadDeadline(t)
}
// SetReadDeadline sets the deadline for future ReadFrom calls
// and any currently-blocked ReadFrom call.
// A zero value for t means ReadFrom will not time out.
func (c *UDPConn) SetReadDeadline(t time.Time) error {
var d time.Duration
if t == noDeadline() {
d = time.Duration(math.MaxInt64)
} else {
d = time.Until(t)
}
c.readTimer.Reset(d)
return nil
}
// SetWriteDeadline sets the deadline for future WriteTo calls
// and any currently-blocked WriteTo call.
// Even if write times out, it may return n > 0, indicating that
// some of the data was successfully written.
// A zero value for t means WriteTo will not time out.
func (c *UDPConn) SetWriteDeadline(time.Time) error {
// Write never blocks.
return nil
}
func addr2PeerAddress(addr net.Addr) proto.PeerAddress {
var peerAddr proto.PeerAddress
switch a := addr.(type) {
case *net.UDPAddr:
peerAddr.IP = a.IP
peerAddr.Port = a.Port
case *net.TCPAddr:
peerAddr.IP = a.IP
peerAddr.Port = a.Port
}
return peerAddr
}
// CreatePermissions Issues a CreatePermission request for the supplied addresses
// as described in https://datatracker.ietf.org/doc/html/rfc5766#section-9
func (a *allocation) CreatePermissions(addrs ...net.Addr) error {
setters := []stun.Setter{
stun.TransactionID,
stun.NewType(stun.MethodCreatePermission, stun.ClassRequest),
}
for _, addr := range addrs {
setters = append(setters, addr2PeerAddress(addr))
}
setters = append(setters,
a.username,
a.realm,
a.nonce(),
a.integrity,
stun.Fingerprint)
msg, err := stun.Build(setters...)
if err != nil {
return err
}
trRes, err := a.client.PerformTransaction(msg, a.serverAddr, false)
if err != nil {
return err
}
res := trRes.Msg
if res.Type.Class == stun.ClassErrorResponse {
var code stun.ErrorCodeAttribute
if err = code.GetFrom(res); err == nil {
if code.Code == stun.CodeStaleNonce {
a.setNonceFromMsg(res)
return errTryAgain
}
return fmt.Errorf("%s (error %s)", res.Type, code) //nolint:goerr113
}
return fmt.Errorf("%s", res.Type) //nolint:goerr113
}
return nil
}
// HandleInbound passes inbound data in UDPConn
func (c *UDPConn) HandleInbound(data []byte, from net.Addr) {
// Copy data
copied := make([]byte, len(data))
copy(copied, data)
select {
case c.readCh <- &inboundData{data: copied, from: from}:
default:
c.log.Warnf("Receive buffer full")
}
}
// FindAddrByChannelNumber returns a peer address associated with the
// channel number on this UDPConn
func (c *UDPConn) FindAddrByChannelNumber(chNum uint16) (net.Addr, bool) {
b, ok := c.bindingMgr.findByNumber(chNum)
if !ok {
return nil, false
}
return b.addr, true
}
func (c *UDPConn) bind(b *binding) error {
setters := []stun.Setter{
stun.TransactionID,
stun.NewType(stun.MethodChannelBind, stun.ClassRequest),
addr2PeerAddress(b.addr),
proto.ChannelNumber(b.number),
c.username,
c.realm,
c.nonce(),
c.integrity,
stun.Fingerprint,
}
msg, err := stun.Build(setters...)
if err != nil {
return err
}
trRes, err := c.client.PerformTransaction(msg, c.serverAddr, false)
if err != nil {
c.bindingMgr.deleteByAddr(b.addr)
return err
}
res := trRes.Msg
if res.Type != stun.NewType(stun.MethodChannelBind, stun.ClassSuccessResponse) {
return fmt.Errorf("unexpected response type %s", res.Type) //nolint:goerr113
}
c.log.Debugf("Channel binding successful: %s %d", b.addr, b.number)
// Success.
return nil
}
func (c *UDPConn) sendChannelData(data []byte, chNum uint16) (int, error) {
chData := &proto.ChannelData{
Data: data,
Number: proto.ChannelNumber(chNum),
}
chData.Encode()
_, err := c.client.WriteTo(chData.Raw, c.serverAddr)
if err != nil {
return 0, err
}
return len(data), nil
}

View File

@@ -0,0 +1,55 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
// Package ipnet contains helper functions around net and IP
package ipnet
import (
"errors"
"net"
)
var errFailedToCastAddr = errors.New("failed to cast net.Addr to *net.UDPAddr or *net.TCPAddr")
// AddrIPPort extracts the IP and Port from a net.Addr
func AddrIPPort(a net.Addr) (net.IP, int, error) {
aUDP, ok := a.(*net.UDPAddr)
if ok {
return aUDP.IP, aUDP.Port, nil
}
aTCP, ok := a.(*net.TCPAddr)
if ok {
return aTCP.IP, aTCP.Port, nil
}
return nil, 0, errFailedToCastAddr
}
// AddrEqual asserts that two net.Addrs are equal
// Currently only supports UDP but will be extended in the future to support others
func AddrEqual(a, b net.Addr) bool {
aUDP, ok := a.(*net.UDPAddr)
if !ok {
return false
}
bUDP, ok := b.(*net.UDPAddr)
if !ok {
return false
}
return aUDP.IP.Equal(bUDP.IP) && aUDP.Port == bUDP.Port
}
// FingerprintAddr generates a fingerprint from net.UDPAddr or net.TCPAddr's
// which can be used for indexing maps.
func FingerprintAddr(addr net.Addr) string {
switch a := addr.(type) {
case *net.UDPAddr:
return a.IP.String()
case *net.TCPAddr: // Do we really need this case?
return a.IP.String()
}
return "" // Should never happen
}

View File

@@ -0,0 +1,68 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
package proto
import (
"fmt"
"net"
)
// Addr is ip:port.
type Addr struct {
IP net.IP
Port int
}
// Network implements net.Addr.
func (Addr) Network() string { return "turn" }
// FromUDPAddr sets addr to UDPAddr.
func (a *Addr) FromUDPAddr(n *net.UDPAddr) {
a.IP = n.IP
a.Port = n.Port
}
// Equal returns true if b == a.
func (a Addr) Equal(b Addr) bool {
if a.Port != b.Port {
return false
}
return a.IP.Equal(b.IP)
}
// EqualIP returns true if a and b have equal IP addresses.
func (a Addr) EqualIP(b Addr) bool {
return a.IP.Equal(b.IP)
}
func (a Addr) String() string {
return fmt.Sprintf("%s:%d", a.IP, a.Port)
}
// FiveTuple represents 5-TUPLE value.
type FiveTuple struct {
Client Addr
Server Addr
Proto Protocol
}
func (t FiveTuple) String() string {
return fmt.Sprintf("%s->%s (%s)",
t.Client, t.Server, t.Proto,
)
}
// Equal returns true if b == t.
func (t FiveTuple) Equal(b FiveTuple) bool {
if t.Proto != b.Proto {
return false
}
if !t.Client.Equal(b.Client) {
return false
}
if !t.Server.Equal(b.Server) {
return false
}
return true
}

View File

@@ -0,0 +1,143 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
package proto
import (
"bytes"
"encoding/binary"
"errors"
"io"
)
// ChannelData represents The ChannelData Message.
//
// See RFC 5766 Section 11.4
type ChannelData struct {
Data []byte // Can be sub slice of Raw
Length int // Ignored while encoding, len(Data) is used
Number ChannelNumber
Raw []byte
}
// Equal returns true if b == c.
func (c *ChannelData) Equal(b *ChannelData) bool {
if c == nil && b == nil {
return true
}
if c == nil || b == nil {
return false
}
if c.Number != b.Number {
return false
}
if len(c.Data) != len(b.Data) {
return false
}
return bytes.Equal(c.Data, b.Data)
}
// Grow ensures that internal buffer will fit v more bytes and
// increases it capacity if necessary.
//
// Similar to stun.Message.grow method.
func (c *ChannelData) grow(v int) {
n := len(c.Raw) + v
for cap(c.Raw) < n {
c.Raw = append(c.Raw, 0)
}
c.Raw = c.Raw[:n]
}
// Reset resets Length, Data and Raw length.
func (c *ChannelData) Reset() {
c.Raw = c.Raw[:0]
c.Length = 0
c.Data = c.Data[:0]
}
// Encode encodes ChannelData Message to Raw.
func (c *ChannelData) Encode() {
c.Raw = c.Raw[:0]
c.WriteHeader()
c.Raw = append(c.Raw, c.Data...)
padded := nearestPaddedValueLength(len(c.Raw))
if bytesToAdd := padded - len(c.Raw); bytesToAdd > 0 {
for i := 0; i < bytesToAdd; i++ {
c.Raw = append(c.Raw, 0)
}
}
}
const padding = 4
func nearestPaddedValueLength(l int) int {
n := padding * (l / padding)
if n < l {
n += padding
}
return n
}
// WriteHeader writes channel number and length.
func (c *ChannelData) WriteHeader() {
if len(c.Raw) < channelDataHeaderSize {
// Making WriteHeader call valid even when c.Raw
// is nil or len(c.Raw) is less than needed for header.
c.grow(channelDataHeaderSize)
}
// Early bounds check to guarantee safety of writes below.
_ = c.Raw[:channelDataHeaderSize]
binary.BigEndian.PutUint16(c.Raw[:channelDataNumberSize], uint16(c.Number))
binary.BigEndian.PutUint16(c.Raw[channelDataNumberSize:channelDataHeaderSize],
uint16(len(c.Data)),
)
}
// ErrBadChannelDataLength means that channel data length is not equal
// to actual data length.
var ErrBadChannelDataLength = errors.New("channelData length != len(Data)")
// Decode decodes The ChannelData Message from Raw.
func (c *ChannelData) Decode() error {
buf := c.Raw
if len(buf) < channelDataHeaderSize {
return io.ErrUnexpectedEOF
}
num := binary.BigEndian.Uint16(buf[:channelDataNumberSize])
c.Number = ChannelNumber(num)
l := binary.BigEndian.Uint16(buf[channelDataNumberSize:channelDataHeaderSize])
c.Data = buf[channelDataHeaderSize:]
c.Length = int(l)
if !c.Number.Valid() {
return ErrInvalidChannelNumber
}
if int(l) < len(c.Data) {
c.Data = c.Data[:int(l)]
}
if int(l) > len(buf[channelDataHeaderSize:]) {
return ErrBadChannelDataLength
}
return nil
}
const (
channelDataLengthSize = 2
channelDataNumberSize = channelDataLengthSize
channelDataHeaderSize = channelDataLengthSize + channelDataNumberSize
)
// IsChannelData returns true if buf looks like the ChannelData Message.
func IsChannelData(buf []byte) bool {
if len(buf) < channelDataHeaderSize {
return false
}
if int(binary.BigEndian.Uint16(buf[channelDataNumberSize:channelDataHeaderSize])) > len(buf[channelDataHeaderSize:]) {
return false
}
// Quick check for channel number.
num := binary.BigEndian.Uint16(buf[0:channelNumberSize])
return isChannelNumberValid(num)
}

View File

@@ -0,0 +1,70 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
package proto
import (
"encoding/binary"
"errors"
"strconv"
"github.com/pion/stun"
)
// ChannelNumber represents CHANNEL-NUMBER attribute.
//
// The CHANNEL-NUMBER attribute contains the number of the channel.
//
// RFC 5766 Section 14.1
type ChannelNumber uint16 // Encoded as uint16
func (n ChannelNumber) String() string { return strconv.Itoa(int(n)) }
// 16 bits of uint + 16 bits of RFFU = 0.
const channelNumberSize = 4
// AddTo adds CHANNEL-NUMBER to message.
func (n ChannelNumber) AddTo(m *stun.Message) error {
v := make([]byte, channelNumberSize)
binary.BigEndian.PutUint16(v[:2], uint16(n))
// v[2:4] are zeroes (RFFU = 0)
m.Add(stun.AttrChannelNumber, v)
return nil
}
// GetFrom decodes CHANNEL-NUMBER from message.
func (n *ChannelNumber) GetFrom(m *stun.Message) error {
v, err := m.Get(stun.AttrChannelNumber)
if err != nil {
return err
}
if err = stun.CheckSize(stun.AttrChannelNumber, len(v), channelNumberSize); err != nil {
return err
}
_ = v[channelNumberSize-1] // Asserting length
*n = ChannelNumber(binary.BigEndian.Uint16(v[:2]))
// v[2:4] is RFFU and equals to 0.
return nil
}
// See https://tools.ietf.org/html/rfc5766#section-11:
//
// 0x4000 through 0x7FFF: These values are the allowed channel
// numbers (16,383 possible values).
const (
MinChannelNumber = 0x4000
MaxChannelNumber = 0x7FFF
)
// ErrInvalidChannelNumber means that channel number is not valid as by RFC 5766 Section 11.
var ErrInvalidChannelNumber = errors.New("channel number not in [0x4000, 0x7FFF]")
// isChannelNumberValid returns true if c in [0x4000, 0x7FFF].
func isChannelNumberValid(c uint16) bool {
return c >= MinChannelNumber && c <= MaxChannelNumber
}
// Valid returns true if channel number has correct value that complies RFC 5766 Section 11 range.
func (n ChannelNumber) Valid() bool {
return isChannelNumberValid(uint16(n))
}

View File

@@ -0,0 +1,42 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
package proto
import (
"encoding/binary"
"github.com/pion/stun"
)
// ConnectionID represents CONNECTION-ID attribute.
//
// The CONNECTION-ID attribute uniquely identifies a peer data
// connection. It is a 32-bit unsigned integral value.
//
// RFC 6062 Section 6.2.1
type ConnectionID uint32
const connectionIDSize = 4 // uint32: 4 bytes, 32 bits
// AddTo adds CONNECTION-ID to message.
func (c ConnectionID) AddTo(m *stun.Message) error {
v := make([]byte, lifetimeSize)
binary.BigEndian.PutUint32(v, uint32(c))
m.Add(stun.AttrConnectionID, v)
return nil
}
// GetFrom decodes CONNECTION-ID from message.
func (c *ConnectionID) GetFrom(m *stun.Message) error {
v, err := m.Get(stun.AttrConnectionID)
if err != nil {
return err
}
if err = stun.CheckSize(stun.AttrConnectionID, len(v), connectionIDSize); err != nil {
return err
}
_ = v[connectionIDSize-1] // Asserting length
*(*uint32)(c) = binary.BigEndian.Uint32(v)
return nil
}

View File

@@ -0,0 +1,33 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
package proto
import "github.com/pion/stun"
// Data represents DATA attribute.
//
// The DATA attribute is present in all Send and Data indications. The
// value portion of this attribute is variable length and consists of
// the application data (that is, the data that would immediately follow
// the UDP header if the data was been sent directly between the client
// and the peer).
//
// RFC 5766 Section 14.4
type Data []byte
// AddTo adds DATA to message.
func (d Data) AddTo(m *stun.Message) error {
m.Add(stun.AttrData, d)
return nil
}
// GetFrom decodes DATA from message.
func (d *Data) GetFrom(m *stun.Message) error {
v, err := m.Get(stun.AttrData)
if err != nil {
return err
}
*d = v
return nil
}

View File

@@ -0,0 +1,45 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
package proto
import (
"github.com/pion/stun"
)
// DontFragmentAttr is a deprecated alias for DontFragment
// Deprecated: Please use DontFragment
type DontFragmentAttr = DontFragment
// DontFragment represents DONT-FRAGMENT attribute.
//
// This attribute is used by the client to request that the server set
// the DF (Don't Fragment) bit in the IP header when relaying the
// application data onward to the peer. This attribute has no value
// part and thus the attribute length field is 0.
//
// RFC 5766 Section 14.8
type DontFragment struct{}
const dontFragmentSize = 0
// AddTo adds DONT-FRAGMENT attribute to message.
func (DontFragment) AddTo(m *stun.Message) error {
m.Add(stun.AttrDontFragment, nil)
return nil
}
// GetFrom decodes DONT-FRAGMENT from message.
func (d *DontFragment) GetFrom(m *stun.Message) error {
v, err := m.Get(stun.AttrDontFragment)
if err != nil {
return err
}
return stun.CheckSize(stun.AttrDontFragment, len(v), dontFragmentSize)
}
// IsSet returns true if DONT-FRAGMENT attribute is set.
func (DontFragment) IsSet(m *stun.Message) bool {
_, err := m.Get(stun.AttrDontFragment)
return err == nil
}

View File

@@ -0,0 +1,58 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
package proto
import "github.com/pion/stun"
// EvenPort represents EVEN-PORT attribute.
//
// This attribute allows the client to request that the port in the
// relayed transport address be even, and (optionally) that the server
// reserve the next-higher port number.
//
// RFC 5766 Section 14.6
type EvenPort struct {
// ReservePort means that the server is requested to reserve
// the next-higher port number (on the same IP address)
// for a subsequent allocation.
ReservePort bool
}
func (p EvenPort) String() string {
if p.ReservePort {
return "reserve: true"
}
return "reserve: false"
}
const (
evenPortSize = 1
firstBitSet = (1 << 8) - 1 // 0b100000000
)
// AddTo adds EVEN-PORT to message.
func (p EvenPort) AddTo(m *stun.Message) error {
v := make([]byte, evenPortSize)
if p.ReservePort {
// Set first bit to 1.
v[0] = firstBitSet
}
m.Add(stun.AttrEvenPort, v)
return nil
}
// GetFrom decodes EVEN-PORT from message.
func (p *EvenPort) GetFrom(m *stun.Message) error {
v, err := m.Get(stun.AttrEvenPort)
if err != nil {
return err
}
if err = stun.CheckSize(stun.AttrEvenPort, len(v), evenPortSize); err != nil {
return err
}
if v[0]&firstBitSet > 0 {
p.ReservePort = true
}
return nil
}

View File

@@ -0,0 +1,55 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
package proto
import (
"encoding/binary"
"time"
"github.com/pion/stun"
)
// DefaultLifetime in RFC 5766 is 10 minutes.
//
// RFC 5766 Section 2.2
const DefaultLifetime = time.Minute * 10
// Lifetime represents LIFETIME attribute.
//
// The LIFETIME attribute represents the duration for which the server
// will maintain an allocation in the absence of a refresh. The value
// portion of this attribute is 4-bytes long and consists of a 32-bit
// unsigned integral value representing the number of seconds remaining
// until expiration.
//
// RFC 5766 Section 14.2
type Lifetime struct {
time.Duration
}
// Seconds in uint32
const lifetimeSize = 4 // 4 bytes, 32 bits
// AddTo adds LIFETIME to message.
func (l Lifetime) AddTo(m *stun.Message) error {
v := make([]byte, lifetimeSize)
binary.BigEndian.PutUint32(v, uint32(l.Seconds()))
m.Add(stun.AttrLifetime, v)
return nil
}
// GetFrom decodes LIFETIME from message.
func (l *Lifetime) GetFrom(m *stun.Message) error {
v, err := m.Get(stun.AttrLifetime)
if err != nil {
return err
}
if err = stun.CheckSize(stun.AttrLifetime, len(v), lifetimeSize); err != nil {
return err
}
_ = v[lifetimeSize-1] // Asserting length
seconds := binary.BigEndian.Uint32(v)
l.Duration = time.Second * time.Duration(seconds)
return nil
}

View File

@@ -0,0 +1,45 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
package proto
import (
"net"
"github.com/pion/stun"
)
// PeerAddress implements XOR-PEER-ADDRESS attribute.
//
// The XOR-PEER-ADDRESS specifies the address and port of the peer as
// seen from the TURN server. (For example, the peer's server-reflexive
// transport address if the peer is behind a NAT.)
//
// RFC 5766 Section 14.3
type PeerAddress struct {
IP net.IP
Port int
}
func (a PeerAddress) String() string {
return stun.XORMappedAddress(a).String()
}
// AddTo adds XOR-PEER-ADDRESS to message.
func (a PeerAddress) AddTo(m *stun.Message) error {
return stun.XORMappedAddress(a).AddToAs(m, stun.AttrXORPeerAddress)
}
// GetFrom decodes XOR-PEER-ADDRESS from message.
func (a *PeerAddress) GetFrom(m *stun.Message) error {
return (*stun.XORMappedAddress)(a).GetFromAs(m, stun.AttrXORPeerAddress)
}
// XORPeerAddress implements XOR-PEER-ADDRESS attribute.
//
// The XOR-PEER-ADDRESS specifies the address and port of the peer as
// seen from the TURN server. (For example, the peer's server-reflexive
// transport address if the peer is behind a NAT.)
//
// RFC 5766 Section 14.3
type XORPeerAddress = PeerAddress

View File

@@ -0,0 +1,31 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
// Package proto implements RFC 5766 Traversal Using Relays around NAT.
package proto
import (
"github.com/pion/stun"
)
// Default ports for TURN from RFC 5766 Section 4.
const (
// DefaultPort for TURN is same as STUN.
DefaultPort = stun.DefaultPort
// DefaultTLSPort is for TURN over TLS and is same as STUN.
DefaultTLSPort = stun.DefaultTLSPort
)
// CreatePermissionRequest is shorthand for create permission request type.
func CreatePermissionRequest() stun.MessageType {
return stun.NewType(stun.MethodCreatePermission, stun.ClassRequest)
}
// AllocateRequest is shorthand for allocation request message type.
func AllocateRequest() stun.MessageType { return stun.NewType(stun.MethodAllocate, stun.ClassRequest) }
// SendIndication is shorthand for send indication message type.
func SendIndication() stun.MessageType { return stun.NewType(stun.MethodSend, stun.ClassIndication) }
// RefreshRequest is shorthand for refresh request message type.
func RefreshRequest() stun.MessageType { return stun.NewType(stun.MethodRefresh, stun.ClassRequest) }

View File

@@ -0,0 +1,43 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
package proto
import (
"net"
"github.com/pion/stun"
)
// RelayedAddress implements XOR-RELAYED-ADDRESS attribute.
//
// It specifies the address and port that the server allocated to the
// client. It is encoded in the same way as XOR-MAPPED-ADDRESS.
//
// RFC 5766 Section 14.5
type RelayedAddress struct {
IP net.IP
Port int
}
func (a RelayedAddress) String() string {
return stun.XORMappedAddress(a).String()
}
// AddTo adds XOR-PEER-ADDRESS to message.
func (a RelayedAddress) AddTo(m *stun.Message) error {
return stun.XORMappedAddress(a).AddToAs(m, stun.AttrXORRelayedAddress)
}
// GetFrom decodes XOR-PEER-ADDRESS from message.
func (a *RelayedAddress) GetFrom(m *stun.Message) error {
return (*stun.XORMappedAddress)(a).GetFromAs(m, stun.AttrXORRelayedAddress)
}
// XORRelayedAddress implements XOR-RELAYED-ADDRESS attribute.
//
// It specifies the address and port that the server allocated to the
// client. It is encoded in the same way as XOR-MAPPED-ADDRESS.
//
// RFC 5766 Section 14.5
type XORRelayedAddress = RelayedAddress

View File

@@ -0,0 +1,64 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
package proto
import (
"errors"
"github.com/pion/stun"
)
// RequestedAddressFamily represents the REQUESTED-ADDRESS-FAMILY Attribute as
// defined in RFC 6156 Section 4.1.1.
type RequestedAddressFamily byte
const requestedFamilySize = 4
var errInvalidRequestedFamilyValue = errors.New("invalid value for requested family attribute")
// GetFrom decodes REQUESTED-ADDRESS-FAMILY from message.
func (f *RequestedAddressFamily) GetFrom(m *stun.Message) error {
v, err := m.Get(stun.AttrRequestedAddressFamily)
if err != nil {
return err
}
if err = stun.CheckSize(stun.AttrRequestedAddressFamily, len(v), requestedFamilySize); err != nil {
return err
}
switch v[0] {
case byte(RequestedFamilyIPv4), byte(RequestedFamilyIPv6):
*f = RequestedAddressFamily(v[0])
default:
return errInvalidRequestedFamilyValue
}
return nil
}
func (f RequestedAddressFamily) String() string {
switch f {
case RequestedFamilyIPv4:
return "IPv4"
case RequestedFamilyIPv6:
return "IPv6"
default:
return "unknown"
}
}
// AddTo adds REQUESTED-ADDRESS-FAMILY to message.
func (f RequestedAddressFamily) AddTo(m *stun.Message) error {
v := make([]byte, requestedFamilySize)
v[0] = byte(f)
// b[1:4] is RFFU = 0.
// The RFFU field MUST be set to zero on transmission and MUST be
// ignored on reception. It is reserved for future uses.
m.Add(stun.AttrRequestedAddressFamily, v)
return nil
}
// Values for RequestedAddressFamily as defined in RFC 6156 Section 4.1.1.
const (
RequestedFamilyIPv4 RequestedAddressFamily = 0x01
RequestedFamilyIPv6 RequestedAddressFamily = 0x02
)

View File

@@ -0,0 +1,72 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
package proto
import (
"strconv"
"github.com/pion/stun"
)
// Protocol is IANA assigned protocol number.
type Protocol byte
const (
// ProtoTCP is IANA assigned protocol number for TCP.
ProtoTCP Protocol = 6
// ProtoUDP is IANA assigned protocol number for UDP.
ProtoUDP Protocol = 17
)
func (p Protocol) String() string {
switch p {
case ProtoTCP:
return "TCP"
case ProtoUDP:
return "UDP"
default:
return strconv.Itoa(int(p))
}
}
// RequestedTransport represents REQUESTED-TRANSPORT attribute.
//
// This attribute is used by the client to request a specific transport
// protocol for the allocated transport address. RFC 5766 only allows the use of
// code point 17 (User Datagram Protocol).
//
// RFC 5766 Section 14.7
type RequestedTransport struct {
Protocol Protocol
}
func (t RequestedTransport) String() string {
return "protocol: " + t.Protocol.String()
}
const requestedTransportSize = 4
// AddTo adds REQUESTED-TRANSPORT to message.
func (t RequestedTransport) AddTo(m *stun.Message) error {
v := make([]byte, requestedTransportSize)
v[0] = byte(t.Protocol)
// b[1:4] is RFFU = 0.
// The RFFU field MUST be set to zero on transmission and MUST be
// ignored on reception. It is reserved for future uses.
m.Add(stun.AttrRequestedTransport, v)
return nil
}
// GetFrom decodes REQUESTED-TRANSPORT from message.
func (t *RequestedTransport) GetFrom(m *stun.Message) error {
v, err := m.Get(stun.AttrRequestedTransport)
if err != nil {
return err
}
if err = stun.CheckSize(stun.AttrRequestedTransport, len(v), requestedTransportSize); err != nil {
return err
}
t.Protocol = Protocol(v[0])
return nil
}

View File

@@ -0,0 +1,42 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
package proto
import "github.com/pion/stun"
// ReservationToken represents RESERVATION-TOKEN attribute.
//
// The RESERVATION-TOKEN attribute contains a token that uniquely
// identifies a relayed transport address being held in reserve by the
// server. The server includes this attribute in a success response to
// tell the client about the token, and the client includes this
// attribute in a subsequent Allocate request to request the server use
// that relayed transport address for the allocation.
//
// RFC 5766 Section 14.9
type ReservationToken []byte
const reservationTokenSize = 8 // 8 bytes
// AddTo adds RESERVATION-TOKEN to message.
func (t ReservationToken) AddTo(m *stun.Message) error {
if err := stun.CheckSize(stun.AttrReservationToken, len(t), reservationTokenSize); err != nil {
return err
}
m.Add(stun.AttrReservationToken, t)
return nil
}
// GetFrom decodes RESERVATION-TOKEN from message.
func (t *ReservationToken) GetFrom(m *stun.Message) error {
v, err := m.Get(stun.AttrReservationToken)
if err != nil {
return err
}
if err = stun.CheckSize(stun.AttrReservationToken, len(v), reservationTokenSize); err != nil {
return err
}
*t = v
return nil
}

View File

@@ -0,0 +1,29 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
package server
import "errors"
var (
errFailedToGenerateNonce = errors.New("failed to generate nonce")
errInvalidNonce = errors.New("invalid nonce")
errFailedToSendError = errors.New("failed to send error message")
errNoSuchUser = errors.New("no such user exists")
errUnexpectedClass = errors.New("unexpected class")
errUnexpectedMethod = errors.New("unexpected method")
errFailedToHandle = errors.New("failed to handle")
errUnhandledSTUNPacket = errors.New("unhandled STUN packet")
errUnableToHandleChannelData = errors.New("unable to handle ChannelData")
errFailedToCreateSTUNPacket = errors.New("failed to create stun message from packet")
errFailedToCreateChannelData = errors.New("failed to create channel data from packet")
errRelayAlreadyAllocatedForFiveTuple = errors.New("relay already allocated for 5-TUPLE")
errUnsupportedTransportProtocol = errors.New("RequestedTransport must be UDP or TCP")
errNoDontFragmentSupport = errors.New("no support for DONT-FRAGMENT")
errRequestWithReservationTokenAndEvenPort = errors.New("Request must not contain RESERVATION-TOKEN and EVEN-PORT")
errNoAllocationFound = errors.New("no allocation found")
errNoPermission = errors.New("unable to handle send-indication, no permission added")
errShortWrite = errors.New("packet write smaller than packet")
errNoSuchChannelBind = errors.New("no such channel bind")
errFailedWriteSocket = errors.New("failed writing to socket")
)

View File

@@ -0,0 +1,71 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
package server
import (
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/binary"
"encoding/hex"
"fmt"
"time"
)
const (
nonceLifetime = time.Hour // See: https://tools.ietf.org/html/rfc5766#section-4
nonceLength = 40
nonceKeyLength = 64
)
// NewNonceHash creates a NonceHash
func NewNonceHash() (*NonceHash, error) {
key := make([]byte, nonceKeyLength)
if _, err := rand.Read(key); err != nil {
return nil, err
}
return &NonceHash{key}, nil
}
// NonceHash is used to create and verify nonces
type NonceHash struct {
key []byte
}
// Generate a nonce
func (n *NonceHash) Generate() (string, error) {
nonce := make([]byte, 8, nonceLength)
binary.BigEndian.PutUint64(nonce, uint64(time.Now().UnixMilli()))
hash := hmac.New(sha256.New, n.key)
if _, err := hash.Write(nonce[:8]); err != nil {
return "", fmt.Errorf("%w: %v", errFailedToGenerateNonce, err) //nolint:errorlint
}
nonce = hash.Sum(nonce)
return hex.EncodeToString(nonce), nil
}
// Validate checks that nonce is signed and is not expired
func (n *NonceHash) Validate(nonce string) error {
b, err := hex.DecodeString(nonce)
if err != nil || len(b) != nonceLength {
return fmt.Errorf("%w: %v", errInvalidNonce, err) //nolint:errorlint
}
if ts := time.UnixMilli(int64(binary.BigEndian.Uint64(b))); time.Since(ts) > nonceLifetime {
return errInvalidNonce
}
hash := hmac.New(sha256.New, n.key)
if _, err = hash.Write(b[:8]); err != nil {
return fmt.Errorf("%w: %v", errInvalidNonce, err) //nolint:errorlint
}
if !hmac.Equal(b[8:], hash.Sum(nil)) {
return errInvalidNonce
}
return nil
}

View File

@@ -0,0 +1,111 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
// Package server implements the private API to implement a TURN server
package server
import (
"fmt"
"net"
"time"
"github.com/pion/logging"
"github.com/pion/stun"
"github.com/pion/turn/v2/internal/allocation"
"github.com/pion/turn/v2/internal/proto"
)
// Request contains all the state needed to process a single incoming datagram
type Request struct {
// Current Request State
Conn net.PacketConn
SrcAddr net.Addr
Buff []byte
// Server State
AllocationManager *allocation.Manager
NonceHash *NonceHash
// User Configuration
AuthHandler func(username string, realm string, srcAddr net.Addr) (key []byte, ok bool)
Log logging.LeveledLogger
Realm string
ChannelBindTimeout time.Duration
}
// HandleRequest processes the give Request
func HandleRequest(r Request) error {
r.Log.Debugf("Received %d bytes of udp from %s on %s", len(r.Buff), r.SrcAddr, r.Conn.LocalAddr())
if proto.IsChannelData(r.Buff) {
return handleDataPacket(r)
}
return handleTURNPacket(r)
}
func handleDataPacket(r Request) error {
r.Log.Debugf("Received DataPacket from %s", r.SrcAddr.String())
c := proto.ChannelData{Raw: r.Buff}
if err := c.Decode(); err != nil {
return fmt.Errorf("%w: %v", errFailedToCreateChannelData, err) //nolint:errorlint
}
err := handleChannelData(r, &c)
if err != nil {
err = fmt.Errorf("%w from %v: %v", errUnableToHandleChannelData, r.SrcAddr, err) //nolint:errorlint
}
return err
}
func handleTURNPacket(r Request) error {
r.Log.Debug("Handling TURN packet")
m := &stun.Message{Raw: append([]byte{}, r.Buff...)}
if err := m.Decode(); err != nil {
return fmt.Errorf("%w: %v", errFailedToCreateSTUNPacket, err) //nolint:errorlint
}
h, err := getMessageHandler(m.Type.Class, m.Type.Method)
if err != nil {
return fmt.Errorf("%w %v-%v from %v: %v", errUnhandledSTUNPacket, m.Type.Method, m.Type.Class, r.SrcAddr, err) //nolint:errorlint
}
err = h(r, m)
if err != nil {
return fmt.Errorf("%w %v-%v from %v: %v", errFailedToHandle, m.Type.Method, m.Type.Class, r.SrcAddr, err) //nolint:errorlint
}
return nil
}
func getMessageHandler(class stun.MessageClass, method stun.Method) (func(r Request, m *stun.Message) error, error) {
switch class {
case stun.ClassIndication:
switch method {
case stun.MethodSend:
return handleSendIndication, nil
default:
return nil, fmt.Errorf("%w: %s", errUnexpectedMethod, method)
}
case stun.ClassRequest:
switch method {
case stun.MethodAllocate:
return handleAllocateRequest, nil
case stun.MethodRefresh:
return handleRefreshRequest, nil
case stun.MethodCreatePermission:
return handleCreatePermissionRequest, nil
case stun.MethodChannelBind:
return handleChannelBindRequest, nil
case stun.MethodBinding:
return handleBindingRequest, nil
default:
return nil, fmt.Errorf("%w: %s", errUnexpectedMethod, method)
}
default:
return nil, fmt.Errorf("%w: %s", errUnexpectedClass, class)
}
}

View File

@@ -0,0 +1,25 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
package server
import (
"github.com/pion/stun"
"github.com/pion/turn/v2/internal/ipnet"
)
func handleBindingRequest(r Request, m *stun.Message) error {
r.Log.Debugf("Received BindingRequest from %s", r.SrcAddr)
ip, port, err := ipnet.AddrIPPort(r.SrcAddr)
if err != nil {
return err
}
attrs := buildMsg(m.TransactionID, stun.BindingSuccess, &stun.XORMappedAddress{
IP: ip,
Port: port,
}, stun.Fingerprint)
return buildAndSend(r.Conn, r.SrcAddr, attrs...)
}

View File

@@ -0,0 +1,376 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
package server
import (
"fmt"
"net"
"github.com/pion/stun"
"github.com/pion/turn/v2/internal/allocation"
"github.com/pion/turn/v2/internal/ipnet"
"github.com/pion/turn/v2/internal/proto"
)
// See: https://tools.ietf.org/html/rfc5766#section-6.2
func handleAllocateRequest(r Request, m *stun.Message) error {
r.Log.Debugf("Received AllocateRequest from %s", r.SrcAddr)
// 1. The server MUST require that the request be authenticated. This
// authentication MUST be done using the long-term credential
// mechanism of [https://tools.ietf.org/html/rfc5389#section-10.2.2]
// unless the client and server agree to use another mechanism through
// some procedure outside the scope of this document.
messageIntegrity, hasAuth, err := authenticateRequest(r, m, stun.MethodAllocate)
if !hasAuth {
return err
}
fiveTuple := &allocation.FiveTuple{
SrcAddr: r.SrcAddr,
DstAddr: r.Conn.LocalAddr(),
Protocol: allocation.UDP,
}
requestedPort := 0
reservationToken := ""
badRequestMsg := buildMsg(m.TransactionID, stun.NewType(stun.MethodAllocate, stun.ClassErrorResponse), &stun.ErrorCodeAttribute{Code: stun.CodeBadRequest})
insufficientCapacityMsg := buildMsg(m.TransactionID, stun.NewType(stun.MethodAllocate, stun.ClassErrorResponse), &stun.ErrorCodeAttribute{Code: stun.CodeInsufficientCapacity})
// 2. The server checks if the 5-tuple is currently in use by an
// existing allocation. If yes, the server rejects the request with
// a 437 (Allocation Mismatch) error.
if alloc := r.AllocationManager.GetAllocation(fiveTuple); alloc != nil {
id, attrs := alloc.GetResponseCache()
if id != m.TransactionID {
msg := buildMsg(m.TransactionID, stun.NewType(stun.MethodAllocate, stun.ClassErrorResponse), &stun.ErrorCodeAttribute{Code: stun.CodeAllocMismatch})
return buildAndSendErr(r.Conn, r.SrcAddr, errRelayAlreadyAllocatedForFiveTuple, msg...)
}
// A retry allocation
msg := buildMsg(m.TransactionID, stun.NewType(stun.MethodAllocate, stun.ClassSuccessResponse), append(attrs, messageIntegrity)...)
return buildAndSend(r.Conn, r.SrcAddr, msg...)
}
// 3. The server checks if the request contains a REQUESTED-TRANSPORT
// attribute. If the REQUESTED-TRANSPORT attribute is not included
// or is malformed, the server rejects the request with a 400 (Bad
// Request) error. Otherwise, if the attribute is included but
// specifies a protocol other that UDP/TCP, the server rejects the
// request with a 442 (Unsupported Transport Protocol) error.
var requestedTransport proto.RequestedTransport
if err = requestedTransport.GetFrom(m); err != nil {
return buildAndSendErr(r.Conn, r.SrcAddr, err, badRequestMsg...)
} else if requestedTransport.Protocol != proto.ProtoUDP && requestedTransport.Protocol != proto.ProtoTCP {
msg := buildMsg(m.TransactionID, stun.NewType(stun.MethodAllocate, stun.ClassErrorResponse), &stun.ErrorCodeAttribute{Code: stun.CodeUnsupportedTransProto})
return buildAndSendErr(r.Conn, r.SrcAddr, errUnsupportedTransportProtocol, msg...)
}
// 4. The request may contain a DONT-FRAGMENT attribute. If it does,
// but the server does not support sending UDP datagrams with the DF
// bit set to 1 (see Section 12), then the server treats the DONT-
// FRAGMENT attribute in the Allocate request as an unknown
// comprehension-required attribute.
if m.Contains(stun.AttrDontFragment) {
msg := buildMsg(m.TransactionID, stun.NewType(stun.MethodAllocate, stun.ClassErrorResponse), &stun.ErrorCodeAttribute{Code: stun.CodeUnknownAttribute}, &stun.UnknownAttributes{stun.AttrDontFragment})
return buildAndSendErr(r.Conn, r.SrcAddr, errNoDontFragmentSupport, msg...)
}
// 5. The server checks if the request contains a RESERVATION-TOKEN
// attribute. If yes, and the request also contains an EVEN-PORT
// attribute, then the server rejects the request with a 400 (Bad
// Request) error. Otherwise, it checks to see if the token is
// valid (i.e., the token is in range and has not expired and the
// corresponding relayed transport address is still available). If
// the token is not valid for some reason, the server rejects the
// request with a 508 (Insufficient Capacity) error.
var reservationTokenAttr proto.ReservationToken
if err = reservationTokenAttr.GetFrom(m); err == nil {
var evenPort proto.EvenPort
if err = evenPort.GetFrom(m); err == nil {
return buildAndSendErr(r.Conn, r.SrcAddr, errRequestWithReservationTokenAndEvenPort, badRequestMsg...)
}
}
// 6. The server checks if the request contains an EVEN-PORT attribute.
// If yes, then the server checks that it can satisfy the request
// (i.e., can allocate a relayed transport address as described
// below). If the server cannot satisfy the request, then the
// server rejects the request with a 508 (Insufficient Capacity)
// error.
var evenPort proto.EvenPort
if err = evenPort.GetFrom(m); err == nil {
var randomPort int
randomPort, err = r.AllocationManager.GetRandomEvenPort()
if err != nil {
return buildAndSendErr(r.Conn, r.SrcAddr, err, insufficientCapacityMsg...)
}
requestedPort = randomPort
reservationToken = randSeq(8)
}
// 7. At any point, the server MAY choose to reject the request with a
// 486 (Allocation Quota Reached) error if it feels the client is
// trying to exceed some locally defined allocation quota. The
// server is free to define this allocation quota any way it wishes,
// but SHOULD define it based on the username used to authenticate
// the request, and not on the client's transport address.
// 8. Also at any point, the server MAY choose to reject the request
// with a 300 (Try Alternate) error if it wishes to redirect the
// client to a different server. The use of this error code and
// attribute follow the specification in [RFC5389].
lifetimeDuration := allocationLifeTime(m)
a, err := r.AllocationManager.CreateAllocation(
fiveTuple,
r.Conn,
requestedPort,
lifetimeDuration)
if err != nil {
return buildAndSendErr(r.Conn, r.SrcAddr, err, insufficientCapacityMsg...)
}
// Once the allocation is created, the server replies with a success
// response.
// The success response contains:
// * An XOR-RELAYED-ADDRESS attribute containing the relayed transport
// address.
// * A LIFETIME attribute containing the current value of the time-to-
// expiry timer.
// * A RESERVATION-TOKEN attribute (if a second relayed transport
// address was reserved).
// * An XOR-MAPPED-ADDRESS attribute containing the client's IP address
// and port (from the 5-tuple).
srcIP, srcPort, err := ipnet.AddrIPPort(r.SrcAddr)
if err != nil {
return buildAndSendErr(r.Conn, r.SrcAddr, err, badRequestMsg...)
}
relayIP, relayPort, err := ipnet.AddrIPPort(a.RelayAddr)
if err != nil {
return buildAndSendErr(r.Conn, r.SrcAddr, err, badRequestMsg...)
}
responseAttrs := []stun.Setter{
&proto.RelayedAddress{
IP: relayIP,
Port: relayPort,
},
&proto.Lifetime{
Duration: lifetimeDuration,
},
&stun.XORMappedAddress{
IP: srcIP,
Port: srcPort,
},
}
if reservationToken != "" {
r.AllocationManager.CreateReservation(reservationToken, relayPort)
responseAttrs = append(responseAttrs, proto.ReservationToken([]byte(reservationToken)))
}
msg := buildMsg(m.TransactionID, stun.NewType(stun.MethodAllocate, stun.ClassSuccessResponse), append(responseAttrs, messageIntegrity)...)
a.SetResponseCache(m.TransactionID, responseAttrs)
return buildAndSend(r.Conn, r.SrcAddr, msg...)
}
func handleRefreshRequest(r Request, m *stun.Message) error {
r.Log.Debugf("Received RefreshRequest from %s", r.SrcAddr)
messageIntegrity, hasAuth, err := authenticateRequest(r, m, stun.MethodRefresh)
if !hasAuth {
return err
}
lifetimeDuration := allocationLifeTime(m)
fiveTuple := &allocation.FiveTuple{
SrcAddr: r.SrcAddr,
DstAddr: r.Conn.LocalAddr(),
Protocol: allocation.UDP,
}
if lifetimeDuration != 0 {
a := r.AllocationManager.GetAllocation(fiveTuple)
if a == nil {
return fmt.Errorf("%w %v:%v", errNoAllocationFound, r.SrcAddr, r.Conn.LocalAddr())
}
a.Refresh(lifetimeDuration)
} else {
r.AllocationManager.DeleteAllocation(fiveTuple)
}
return buildAndSend(r.Conn, r.SrcAddr, buildMsg(m.TransactionID, stun.NewType(stun.MethodRefresh, stun.ClassSuccessResponse), []stun.Setter{
&proto.Lifetime{
Duration: lifetimeDuration,
},
messageIntegrity,
}...)...)
}
func handleCreatePermissionRequest(r Request, m *stun.Message) error {
r.Log.Debugf("Received CreatePermission from %s", r.SrcAddr)
a := r.AllocationManager.GetAllocation(&allocation.FiveTuple{
SrcAddr: r.SrcAddr,
DstAddr: r.Conn.LocalAddr(),
Protocol: allocation.UDP,
})
if a == nil {
return fmt.Errorf("%w %v:%v", errNoAllocationFound, r.SrcAddr, r.Conn.LocalAddr())
}
messageIntegrity, hasAuth, err := authenticateRequest(r, m, stun.MethodCreatePermission)
if !hasAuth {
return err
}
addCount := 0
if err := m.ForEach(stun.AttrXORPeerAddress, func(m *stun.Message) error {
var peerAddress proto.PeerAddress
if err := peerAddress.GetFrom(m); err != nil {
return err
}
if err := r.AllocationManager.GrantPermission(r.SrcAddr, peerAddress.IP); err != nil {
r.Log.Infof("permission denied for client %s to peer %s", r.SrcAddr, peerAddress.IP)
return err
}
r.Log.Debugf("Adding permission for %s", fmt.Sprintf("%s:%d",
peerAddress.IP, peerAddress.Port))
a.AddPermission(allocation.NewPermission(
&net.UDPAddr{
IP: peerAddress.IP,
Port: peerAddress.Port,
},
r.Log,
))
addCount++
return nil
}); err != nil {
addCount = 0
}
respClass := stun.ClassSuccessResponse
if addCount == 0 {
respClass = stun.ClassErrorResponse
}
return buildAndSend(r.Conn, r.SrcAddr, buildMsg(m.TransactionID, stun.NewType(stun.MethodCreatePermission, respClass), []stun.Setter{messageIntegrity}...)...)
}
func handleSendIndication(r Request, m *stun.Message) error {
r.Log.Debugf("Received SendIndication from %s", r.SrcAddr)
a := r.AllocationManager.GetAllocation(&allocation.FiveTuple{
SrcAddr: r.SrcAddr,
DstAddr: r.Conn.LocalAddr(),
Protocol: allocation.UDP,
})
if a == nil {
return fmt.Errorf("%w %v:%v", errNoAllocationFound, r.SrcAddr, r.Conn.LocalAddr())
}
dataAttr := proto.Data{}
if err := dataAttr.GetFrom(m); err != nil {
return err
}
peerAddress := proto.PeerAddress{}
if err := peerAddress.GetFrom(m); err != nil {
return err
}
msgDst := &net.UDPAddr{IP: peerAddress.IP, Port: peerAddress.Port}
if perm := a.GetPermission(msgDst); perm == nil {
return fmt.Errorf("%w: %v", errNoPermission, msgDst)
}
l, err := a.RelaySocket.WriteTo(dataAttr, msgDst)
if l != len(dataAttr) {
return fmt.Errorf("%w %d != %d (expected) err: %v", errShortWrite, l, len(dataAttr), err) //nolint:errorlint
}
return err
}
func handleChannelBindRequest(r Request, m *stun.Message) error {
r.Log.Debugf("Received ChannelBindRequest from %s", r.SrcAddr)
a := r.AllocationManager.GetAllocation(&allocation.FiveTuple{
SrcAddr: r.SrcAddr,
DstAddr: r.Conn.LocalAddr(),
Protocol: allocation.UDP,
})
if a == nil {
return fmt.Errorf("%w %v:%v", errNoAllocationFound, r.SrcAddr, r.Conn.LocalAddr())
}
badRequestMsg := buildMsg(m.TransactionID, stun.NewType(stun.MethodChannelBind, stun.ClassErrorResponse), &stun.ErrorCodeAttribute{Code: stun.CodeBadRequest})
messageIntegrity, hasAuth, err := authenticateRequest(r, m, stun.MethodChannelBind)
if !hasAuth {
return err
}
var channel proto.ChannelNumber
if err = channel.GetFrom(m); err != nil {
return buildAndSendErr(r.Conn, r.SrcAddr, err, badRequestMsg...)
}
peerAddr := proto.PeerAddress{}
if err = peerAddr.GetFrom(m); err != nil {
return buildAndSendErr(r.Conn, r.SrcAddr, err, badRequestMsg...)
}
if err = r.AllocationManager.GrantPermission(r.SrcAddr, peerAddr.IP); err != nil {
r.Log.Infof("permission denied for client %s to peer %s", r.SrcAddr, peerAddr.IP)
unauthorizedRequestMsg := buildMsg(m.TransactionID,
stun.NewType(stun.MethodChannelBind, stun.ClassErrorResponse),
&stun.ErrorCodeAttribute{Code: stun.CodeUnauthorized})
return buildAndSendErr(r.Conn, r.SrcAddr, err, unauthorizedRequestMsg...)
}
r.Log.Debugf("Binding channel %d to %s", channel, peerAddr)
err = a.AddChannelBind(allocation.NewChannelBind(
channel,
&net.UDPAddr{IP: peerAddr.IP, Port: peerAddr.Port},
r.Log,
), r.ChannelBindTimeout)
if err != nil {
return buildAndSendErr(r.Conn, r.SrcAddr, err, badRequestMsg...)
}
return buildAndSend(r.Conn, r.SrcAddr, buildMsg(m.TransactionID, stun.NewType(stun.MethodChannelBind, stun.ClassSuccessResponse), []stun.Setter{messageIntegrity}...)...)
}
func handleChannelData(r Request, c *proto.ChannelData) error {
r.Log.Debugf("Received ChannelData from %s", r.SrcAddr)
a := r.AllocationManager.GetAllocation(&allocation.FiveTuple{
SrcAddr: r.SrcAddr,
DstAddr: r.Conn.LocalAddr(),
Protocol: allocation.UDP,
})
if a == nil {
return fmt.Errorf("%w %v:%v", errNoAllocationFound, r.SrcAddr, r.Conn.LocalAddr())
}
channel := a.GetChannelByNumber(c.Number)
if channel == nil {
return fmt.Errorf("%w %x", errNoSuchChannelBind, uint16(c.Number))
}
l, err := a.RelaySocket.WriteTo(c.Data, channel.Peer)
if err != nil {
return fmt.Errorf("%w: %s", errFailedWriteSocket, err.Error())
} else if l != len(c.Data) {
return fmt.Errorf("%w %d != %d (expected)", errShortWrite, l, len(c.Data))
}
return nil
}

View File

@@ -0,0 +1,117 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
package server
import (
"errors"
"fmt"
"math/rand"
"net"
"time"
"github.com/pion/stun"
"github.com/pion/turn/v2/internal/proto"
)
const (
maximumAllocationLifetime = time.Hour // See: https://tools.ietf.org/html/rfc5766#section-6.2 defines 3600 seconds recommendation
)
func randSeq(n int) string {
letters := []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
b := make([]rune, n)
for i := range b {
b[i] = letters[rand.Intn(len(letters))] //nolint:gosec
}
return string(b)
}
func buildAndSend(conn net.PacketConn, dst net.Addr, attrs ...stun.Setter) error {
msg, err := stun.Build(attrs...)
if err != nil {
return err
}
_, err = conn.WriteTo(msg.Raw, dst)
if errors.Is(err, net.ErrClosed) {
return nil
}
return err
}
// Send a STUN packet and return the original error to the caller
func buildAndSendErr(conn net.PacketConn, dst net.Addr, err error, attrs ...stun.Setter) error {
if sendErr := buildAndSend(conn, dst, attrs...); sendErr != nil {
err = fmt.Errorf("%w %v %v", errFailedToSendError, sendErr, err) //nolint:errorlint
}
return err
}
func buildMsg(transactionID [stun.TransactionIDSize]byte, msgType stun.MessageType, additional ...stun.Setter) []stun.Setter {
return append([]stun.Setter{&stun.Message{TransactionID: transactionID}, msgType}, additional...)
}
func authenticateRequest(r Request, m *stun.Message, callingMethod stun.Method) (stun.MessageIntegrity, bool, error) {
respondWithNonce := func(responseCode stun.ErrorCode) (stun.MessageIntegrity, bool, error) {
nonce, err := r.NonceHash.Generate()
if err != nil {
return nil, false, err
}
return nil, false, buildAndSend(r.Conn, r.SrcAddr, buildMsg(m.TransactionID,
stun.NewType(callingMethod, stun.ClassErrorResponse),
&stun.ErrorCodeAttribute{Code: responseCode},
stun.NewNonce(nonce),
stun.NewRealm(r.Realm),
)...)
}
if !m.Contains(stun.AttrMessageIntegrity) {
return respondWithNonce(stun.CodeUnauthorized)
}
nonceAttr := &stun.Nonce{}
usernameAttr := &stun.Username{}
realmAttr := &stun.Realm{}
badRequestMsg := buildMsg(m.TransactionID, stun.NewType(callingMethod, stun.ClassErrorResponse), &stun.ErrorCodeAttribute{Code: stun.CodeBadRequest})
if err := nonceAttr.GetFrom(m); err != nil {
return nil, false, buildAndSendErr(r.Conn, r.SrcAddr, err, badRequestMsg...)
}
// Assert Nonce is signed and is not expired
if err := r.NonceHash.Validate(nonceAttr.String()); err != nil {
return respondWithNonce(stun.CodeStaleNonce)
}
if err := realmAttr.GetFrom(m); err != nil {
return nil, false, buildAndSendErr(r.Conn, r.SrcAddr, err, badRequestMsg...)
} else if err := usernameAttr.GetFrom(m); err != nil {
return nil, false, buildAndSendErr(r.Conn, r.SrcAddr, err, badRequestMsg...)
}
ourKey, ok := r.AuthHandler(usernameAttr.String(), realmAttr.String(), r.SrcAddr)
if !ok {
return nil, false, buildAndSendErr(r.Conn, r.SrcAddr, fmt.Errorf("%w %s", errNoSuchUser, usernameAttr.String()), badRequestMsg...)
}
if err := stun.MessageIntegrity(ourKey).Check(m); err != nil {
return nil, false, buildAndSendErr(r.Conn, r.SrcAddr, err, badRequestMsg...)
}
return stun.MessageIntegrity(ourKey), true, nil
}
func allocationLifeTime(m *stun.Message) time.Duration {
lifetimeDuration := proto.DefaultLifetime
var lifetime proto.Lifetime
if err := lifetime.GetFrom(m); err == nil {
if lifetime.Duration < maximumAllocationLifetime {
lifetimeDuration = lifetime.Duration
}
}
return lifetimeDuration
}