Files
web/server/pkg/weblive/hub.go

187 lines
3.7 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package weblive
import (
"sync"
"github.com/gorilla/websocket"
"github.com/pion/rtp"
"github.com/pion/webrtc/v3"
)
// trackForwarder reads RTP packets from one publisher track and copies them to
// every registered viewer's local track.
type trackForwarder struct {
	// remote is the publisher's incoming track that packets are read from.
	remote *webrtc.TrackRemote

	// mu guards locals.
	mu sync.Mutex

	// locals maps a viewer ID to the local track that receives this
	// forwarder's packets on behalf of that viewer.
	locals map[string]*webrtc.TrackLocalStaticRTP

	// stopCh is closed by close to ask runReadLoop to exit.
	stopCh chan struct{}
}
// newTrackForwarder builds a forwarder for the given publisher track with no
// viewers attached and an open stop channel. The read loop is not started;
// callers run runReadLoop themselves.
func newTrackForwarder(track *webrtc.TrackRemote) *trackForwarder {
	tf := new(trackForwarder)
	tf.remote = track
	tf.locals = map[string]*webrtc.TrackLocalStaticRTP{}
	tf.stopCh = make(chan struct{})
	return tf
}
// addViewer registers t as the destination track for viewer id. Any track
// previously registered under the same id is replaced.
func (tf *trackForwarder) addViewer(id string, t *webrtc.TrackLocalStaticRTP) {
	tf.mu.Lock()
	defer tf.mu.Unlock()

	tf.locals[id] = t
}
// removeViewer stops forwarding packets to viewer id. Removing an id that was
// never registered is a no-op.
func (tf *trackForwarder) removeViewer(id string) {
	tf.mu.Lock()
	defer tf.mu.Unlock()

	delete(tf.locals, id)
}
// close signals runReadLoop to stop by closing stopCh.
//
// It is idempotent and safe to call from several goroutines. The "already
// closed?" check and the close itself run under tf.mu. Without the lock, two
// concurrent callers could both see an open channel and both close it, which
// panics with "close of closed channel".
//
// Because runReadLoop holds tf.mu while fanning out a packet, close may wait
// for that fan-out to finish. Callers such as Hub.clearPublisher invoke close
// while holding the hub lock (hub lock -> tf.mu), so tf.mu must never be held
// while acquiring the hub lock.
func (tf *trackForwarder) close() {
	tf.mu.Lock()
	defer tf.mu.Unlock()

	select {
	case <-tf.stopCh:
		// Already closed; nothing to do.
	default:
		close(tf.stopCh)
	}
}
// runReadLoop pumps RTP packets from the publisher's remote track to every
// registered viewer track. It returns once stopCh has been closed or reading
// from the remote track fails.
//
// Each viewer gets its own freshly unmarshaled *rtp.Packet. A packet that fails
// to unmarshal is simply not sent. Write errors to individual viewers are
// ignored on purpose, so one broken viewer does not stop delivery to the rest.
//
// NOTE(review): the stop signal is only checked between reads. A blocked Read
// keeps this goroutine alive until the remote track returns an error
// (presumably when the publisher PeerConnection is closed).
func (tf *trackForwarder) runReadLoop() {
	packetBuf := make([]byte, 1500)

	stopped := func() bool {
		select {
		case <-tf.stopCh:
			return true
		default:
			return false
		}
	}

	for !stopped() {
		size, _, readErr := tf.remote.Read(packetBuf)
		if readErr != nil {
			return
		}
		raw := packetBuf[:size]

		tf.mu.Lock()
		for _, local := range tf.locals {
			pkt := &rtp.Packet{}
			if pkt.Unmarshal(raw) == nil {
				_ = local.WriteRTP(pkt)
			}
		}
		tf.mu.Unlock()
	}
}
// Hub is a single room: one publisher and any number of viewers. All state
// lives in process memory, so it is cleared on restart.
type Hub struct {
	// mu guards the publisher fields, forwarders and viewers.
	mu sync.RWMutex

	// api is the WebRTC API the hub was created with (see newHub);
	// presumably used elsewhere to create PeerConnections — confirm.
	api *webrtc.API

	// cfg is the PeerConnection configuration; newHub fills in its ICE
	// servers from iceServersFromEnv.
	cfg webrtc.Configuration

	// publishConn is presumably the current publisher's signaling WebSocket,
	// or nil; clearPublisher resets it to nil.
	publishConn *websocket.Conn

	// pubPC is the current publisher's PeerConnection, or nil; clearPublisher
	// closes and clears it.
	pubPC *webrtc.PeerConnection

	// forwarders holds one trackForwarder per forwarded publisher track
	// (onPublisherTrack appends one per video track).
	forwarders []*trackForwarder

	// viewers maps a viewer ID to its session.
	viewers map[string]*viewerSession
}
// viewerSession is the per-viewer state kept by the Hub.
type viewerSession struct {
	// id identifies the viewer. It is the key used in each trackForwarder's
	// locals map (see attachForwardersToViewerPC) and presumably also in
	// Hub.viewers.
	id string

	// ws is presumably the viewer's signaling WebSocket — confirm.
	ws *websocket.Conn

	// pc is the viewer's PeerConnection. attachForwardersToViewerPC adds the
	// outbound tracks to it, and Hub.removeViewer closes it.
	pc *webrtc.PeerConnection

	// pending presumably buffers remote ICE candidates that arrive before they
	// can be applied — NOTE(review): usage not visible here, confirm.
	pending []webrtc.ICECandidateInit

	// answered presumably records whether the viewer's SDP answer has been
	// processed — NOTE(review): usage not visible here, confirm.
	answered bool
}
// newHub returns an empty Hub bound to api. ICE servers for the
// PeerConnection configuration are taken from iceServersFromEnv, and the
// viewer map starts empty.
func newHub(api *webrtc.API) *Hub {
	h := &Hub{api: api}
	h.cfg = webrtc.Configuration{ICEServers: iceServersFromEnv()}
	h.viewers = make(map[string]*viewerSession)
	return h
}
// defaultHub is the process-wide Hub built lazily by getHub. It stays nil if
// that initialization failed.
var defaultHub *Hub

// hubOnce guarantees the initialization in getHub runs at most once.
var hubOnce sync.Once

// hubInitErr holds the error, if any, from that one-time initialization.
var hubInitErr error
// getHub returns the process-wide Hub and builds it on first use.
//
// If buildAPI fails, no Hub is created. Every call then returns a nil Hub
// together with that same cached error, because initialization is never
// retried.
func getHub() (*Hub, error) {
	hubOnce.Do(func() {
		api, err := buildAPI()
		if err != nil {
			hubInitErr = err
			return
		}
		defaultHub = newHub(api)
	})
	return defaultHub, hubInitErr
}
// clearPublisher tears down the current publisher session. It stops every
// track forwarder, closes the publisher PeerConnection (if any) and forgets
// the publisher's WebSocket connection. Calling it with no active publisher
// is harmless.
//
// Hub state is detached while holding h.mu. The PeerConnection is closed only
// after the lock is released, matching removeViewer, so that a slow Close
// cannot block other hub operations. The PeerConnection is still closed
// before clearPublisher returns.
//
// NOTE(review): publishConn is only dropped, not closed. Presumably the
// WebSocket handler that owns it closes it — confirm against callers.
func (h *Hub) clearPublisher() {
	h.mu.Lock()
	for _, tf := range h.forwarders {
		tf.close()
	}
	h.forwarders = nil
	pc := h.pubPC
	h.pubPC = nil
	h.publishConn = nil
	h.mu.Unlock()

	if pc != nil {
		// Best-effort teardown: the session is already detached from the hub,
		// so a Close error leaves nothing further to clean up here.
		_ = pc.Close()
	}
}
// removeViewer forgets viewer id. It unregisters the viewer from the hub and
// from every track forwarder, then closes the viewer's PeerConnection.
//
// The PeerConnection is closed after h.mu is released, so a slow Close does
// not block other hub operations. Unknown ids are ignored.
func (h *Hub) removeViewer(id string) {
	h.mu.Lock()
	session, found := h.viewers[id]
	// delete is a no-op for a missing key, so no extra guard is needed.
	delete(h.viewers, id)
	for _, fwd := range h.forwarders {
		fwd.removeViewer(id)
	}
	h.mu.Unlock()

	if !found || session == nil || session.pc == nil {
		return
	}
	_ = session.pc.Close()
}
// onPublisherTrack starts forwarding a newly received publisher track. It
// creates a trackForwarder, records it on the hub and launches its read loop
// in a new goroutine.
func (h *Hub) onPublisherTrack(track *webrtc.TrackRemote) {
	// Only video tracks are forwarded, to keep negotiation simple.
	if track.Kind() != webrtc.RTPCodecTypeVideo {
		return
	}

	fwd := newTrackForwarder(track)
	func() {
		h.mu.Lock()
		defer h.mu.Unlock()
		h.forwarders = append(h.forwarders, fwd)
	}()
	go fwd.runReadLoop()

	// Viewers only pull the stream once the broadcast is live. Their first
	// negotiation (attachForwardersToViewerPC) picks up every current track,
	// so existing viewers are not renegotiated here.
}
// attachForwardersToViewerPC adds one outbound track to the viewer's
// PeerConnection for every publisher track currently being forwarded. Each
// new track is also registered with its forwarder, so it starts receiving
// RTP packets.
//
// The forwarder list is copied under a read lock. Tracks the publisher adds
// later are therefore not attached to this viewer; no renegotiation happens
// here.
//
// Attaching is best-effort. A forwarder whose local track cannot be created,
// or cannot be added to v.pc, is skipped and the rest are still attached.
func (h *Hub) attachForwardersToViewerPC(v *viewerSession) {
	h.mu.RLock()
	fwd := append([]*trackForwarder(nil), h.forwarders...)
	h.mu.RUnlock()

	for _, tf := range fwd {
		// Reuse the publisher track's codec capability so its RTP packets can
		// be relayed unchanged. Named codecCap so it does not shadow the
		// builtin cap.
		codecCap := tf.remote.Codec().RTPCodecCapability

		// The track ID is made unique per viewer; the stream ID is kept from
		// the publisher's track.
		lt, err := webrtc.NewTrackLocalStaticRTP(codecCap, tf.remote.ID()+"_"+v.id, tf.remote.StreamID())
		if err != nil {
			continue
		}
		rtpSender, err := v.pc.AddTrack(lt)
		if err != nil {
			continue
		}

		// Drain RTCP feedback to keep interceptors/senders healthy. The
		// goroutine exits once Read returns an error (presumably when the
		// viewer's PeerConnection is closed).
		go func() {
			rtcpBuf := make([]byte, 1500)
			for {
				if _, _, e := rtpSender.Read(rtcpBuf); e != nil {
					return
				}
			}
		}()

		tf.addViewer(v.id, lt)
	}
}