package weblive

import (
	"sync"

	"github.com/gorilla/websocket"
	"github.com/pion/rtp"
	"github.com/pion/webrtc/v3"
)

// trackForwarder reads RTP packets from one publisher track and copies them
// to every viewer's local track.
type trackForwarder struct {
	remote *webrtc.TrackRemote

	mu     sync.Mutex
	locals map[string]*webrtc.TrackLocalStaticRTP // keyed by viewer id; guarded by mu

	stopCh    chan struct{}
	closeOnce sync.Once // guarantees stopCh is closed exactly once
}

// newTrackForwarder creates a forwarder for track with no viewers attached.
// Forwarding does not start until runReadLoop is called.
func newTrackForwarder(track *webrtc.TrackRemote) *trackForwarder {
	return &trackForwarder{
		remote: track,
		locals: make(map[string]*webrtc.TrackLocalStaticRTP),
		stopCh: make(chan struct{}),
	}
}

// addViewer registers t as the destination track for viewer id, replacing any
// track previously registered under the same id.
func (tf *trackForwarder) addViewer(id string, t *webrtc.TrackLocalStaticRTP) {
	tf.mu.Lock()
	defer tf.mu.Unlock()
	tf.locals[id] = t
}

// removeViewer stops forwarding to viewer id. Unknown ids are a no-op.
func (tf *trackForwarder) removeViewer(id string) {
	tf.mu.Lock()
	defer tf.mu.Unlock()
	delete(tf.locals, id)
}

// close signals runReadLoop to stop. It may be called any number of times,
// including concurrently: sync.Once (rather than a select-then-close check)
// ensures two racing callers cannot both reach close(stopCh) and panic.
func (tf *trackForwarder) close() {
	tf.closeOnce.Do(func() { close(tf.stopCh) })
}

// runReadLoop reads packets from the remote track and writes a copy to every
// registered viewer track. It returns when close has been called or when Read
// returns an error.
//
// stopCh is only checked between reads, so close does not interrupt a Read
// that is already blocked; the loop exits once that Read returns.
// NOTE(review): presumably Read fails once the publisher PeerConnection is
// closed (as clearPublisher does) — confirm against pion's TrackRemote docs.
func (tf *trackForwarder) runReadLoop() {
	// Assumes each RTP packet fits in 1500 bytes (typical MTU) — confirm.
	buf := make([]byte, 1500)
	for {
		select {
		case <-tf.stopCh:
			return
		default:
		}

		n, _, err := tf.remote.Read(buf)
		if err != nil {
			return
		}

		tf.mu.Lock()
		for _, lt := range tf.locals {
			// A separate packet is parsed per viewer, presumably so that any
			// header changes made while writing to one viewer cannot leak into
			// another viewer's packet.
			cp := &rtp.Packet{}
			if err := cp.Unmarshal(buf[:n]); err != nil {
				continue
			}
			// Best-effort fan-out: a failed write to one viewer must not stop
			// delivery to the others.
			_ = lt.WriteRTP(cp)
		}
		tf.mu.Unlock()
	}
}

// Hub is a single room: one publisher and many viewers. All state lives in
// process memory and is lost on restart.
type Hub struct {
	mu  sync.RWMutex // guards publishConn, pubPC, forwarders and viewers
	api *webrtc.API
	cfg webrtc.Configuration

	publishConn *websocket.Conn
	pubPC       *webrtc.PeerConnection
	forwarders  []*trackForwarder
	viewers     map[string]*viewerSession
}

// viewerSession is the per-viewer connection state kept by the Hub.
type viewerSession struct {
	id string
	ws *websocket.Conn
	pc *webrtc.PeerConnection
	// NOTE(review): presumably ICE candidates buffered until the answer has
	// been applied (see answered) — confirm against the signaling code.
	pending  []webrtc.ICECandidateInit
	answered bool
}

// newHub creates an empty Hub that builds PeerConnections with api, using ICE
// servers taken from iceServersFromEnv.
func newHub(api *webrtc.API) *Hub {
	return &Hub{
		api:     api,
		cfg:     webrtc.Configuration{ICEServers: iceServersFromEnv()},
		viewers: make(map[string]*viewerSession),
	}
}

// Process-wide Hub singleton, initialized lazily by getHub.
var (
	defaultHub *Hub
	hubOnce    sync.Once
	hubInitErr error
)

// getHub returns the process-wide Hub, building it on first call. The outcome
// is cached, including any error from buildAPI: a failed initialization is
// not retried, and every later call returns the same (nil, err).
func getHub() (*Hub, error) {
	hubOnce.Do(func() {
		var api *webrtc.API
		api, hubInitErr = buildAPI()
		if hubInitErr != nil {
			return
		}
		defaultHub = newHub(api)
	})
	return defaultHub, hubInitErr
}

// clearPublisher tears down the current publisher. It stops every track
// forwarder, closes the publisher PeerConnection, and drops the reference to
// the publisher websocket. The websocket itself is not closed here.
//
// State is detached while holding h.mu, but forwarders and the PeerConnection
// are closed only after the lock is released. This mirrors removeViewer, so a
// slow Close, or one that runs a handler re-entering the hub, cannot stall or
// deadlock other hub operations.
func (h *Hub) clearPublisher() {
	h.mu.Lock()
	forwarders := h.forwarders
	pc := h.pubPC
	h.forwarders = nil
	h.pubPC = nil
	h.publishConn = nil
	h.mu.Unlock()

	for _, tf := range forwarders {
		tf.close()
	}
	if pc != nil {
		_ = pc.Close() // best-effort teardown; the publisher is gone either way
	}
}

// removeViewer forgets viewer id, detaches it from every forwarder and closes
// its PeerConnection. The PeerConnection is closed outside h.mu so a slow
// Close does not block other hub operations. Unknown ids only detach from the
// forwarders.
func (h *Hub) removeViewer(id string) {
	h.mu.Lock()
	vs, ok := h.viewers[id]
	if ok {
		delete(h.viewers, id)
	}
	for _, tf := range h.forwarders {
		tf.removeViewer(id)
	}
	h.mu.Unlock()

	if ok && vs != nil && vs.pc != nil {
		_ = vs.pc.Close() // best-effort teardown
	}
}

// onPublisherTrack starts forwarding a newly received publisher track.
func (h *Hub) onPublisherTrack(track *webrtc.TrackRemote) {
	// Only video tracks are forwarded, to keep negotiation simple.
	if track.Kind() != webrtc.RTPCodecTypeVideo {
		return
	}
	tf := newTrackForwarder(track)
	h.mu.Lock()
	h.forwarders = append(h.forwarders, tf)
	h.mu.Unlock()
	go tf.runReadLoop()
	// Viewers only pull the stream once the broadcast is live: their first
	// negotiation goes through attachForwardersToViewerPC, which adds every
	// current track, so no renegotiation is triggered here.
}

// attachForwardersToViewerPC adds one local track per current forwarder to
// the viewer's PeerConnection and registers it with that forwarder. A track
// that cannot be created or added is skipped, so the viewer simply does not
// receive it.
func (h *Hub) attachForwardersToViewerPC(v *viewerSession) {
	// Snapshot under the read lock so AddTrack runs without holding h.mu.
	h.mu.RLock()
	fwd := append([]*trackForwarder(nil), h.forwarders...)
	h.mu.RUnlock()

	for _, tf := range fwd {
		codecCap := tf.remote.Codec().RTPCodecCapability
		lt, err := webrtc.NewTrackLocalStaticRTP(codecCap, tf.remote.ID()+"_"+v.id, tf.remote.StreamID())
		if err != nil {
			continue
		}
		if _, err := v.pc.AddTrack(lt); err != nil {
			continue
		}
		tf.addViewer(v.id, lt)
	}
}