189 lines
3.8 KiB
Go
189 lines
3.8 KiB
Go
package weblive
|
||
|
||
import (
|
||
"sync"
|
||
|
||
"github.com/gorilla/websocket"
|
||
"github.com/pion/rtp"
|
||
"github.com/pion/webrtc/v3"
|
||
)
|
||
|
||
// trackForwarder 从主播轨读 RTP,复制到所有观众本地轨
|
||
type trackForwarder struct {
|
||
remote *webrtc.TrackRemote
|
||
mu sync.Mutex
|
||
locals map[string]*webrtc.TrackLocalStaticRTP
|
||
stopCh chan struct{}
|
||
}
|
||
|
||
func newTrackForwarder(track *webrtc.TrackRemote) *trackForwarder {
|
||
return &trackForwarder{
|
||
remote: track,
|
||
locals: make(map[string]*webrtc.TrackLocalStaticRTP),
|
||
stopCh: make(chan struct{}),
|
||
}
|
||
}
|
||
|
||
func (tf *trackForwarder) addViewer(id string, t *webrtc.TrackLocalStaticRTP) {
|
||
tf.mu.Lock()
|
||
defer tf.mu.Unlock()
|
||
tf.locals[id] = t
|
||
}
|
||
|
||
func (tf *trackForwarder) removeViewer(id string) {
|
||
tf.mu.Lock()
|
||
defer tf.mu.Unlock()
|
||
delete(tf.locals, id)
|
||
}
|
||
|
||
func (tf *trackForwarder) close() {
|
||
select {
|
||
case <-tf.stopCh:
|
||
default:
|
||
close(tf.stopCh)
|
||
}
|
||
}
|
||
|
||
func (tf *trackForwarder) runReadLoop() {
|
||
buf := make([]byte, 1500)
|
||
for {
|
||
select {
|
||
case <-tf.stopCh:
|
||
return
|
||
default:
|
||
}
|
||
n, _, err := tf.remote.Read(buf)
|
||
if err != nil {
|
||
return
|
||
}
|
||
tf.mu.Lock()
|
||
for _, lt := range tf.locals {
|
||
cp := &rtp.Packet{}
|
||
if err := cp.Unmarshal(buf[:n]); err != nil {
|
||
continue
|
||
}
|
||
_ = lt.WriteRTP(cp)
|
||
}
|
||
tf.mu.Unlock()
|
||
}
|
||
}
|
||
|
||
// Hub 单房间:一名主播、多名观众(进程内内存态,重启清空)
|
||
type Hub struct {
|
||
mu sync.RWMutex
|
||
|
||
api *webrtc.API
|
||
cfg webrtc.Configuration
|
||
|
||
publishConn *websocket.Conn
|
||
pubPC *webrtc.PeerConnection
|
||
// 开播 WebSocket 上 quality= 参数,供 GET /live/info 只读输出
|
||
publishQuality string
|
||
forwarders []*trackForwarder
|
||
|
||
viewers map[string]*viewerSession
|
||
}
|
||
|
||
type viewerSession struct {
|
||
id string
|
||
ws *websocket.Conn
|
||
pc *webrtc.PeerConnection
|
||
pending []webrtc.ICECandidateInit
|
||
answered bool
|
||
}
|
||
|
||
func newHub(api *webrtc.API) *Hub {
|
||
return &Hub{
|
||
api: api,
|
||
cfg: webrtc.Configuration{ICEServers: iceServersFromEnv()},
|
||
viewers: make(map[string]*viewerSession),
|
||
}
|
||
}
|
||
|
||
var (
|
||
defaultHub *Hub
|
||
hubOnce sync.Once
|
||
hubInitErr error
|
||
)
|
||
|
||
func getHub() (*Hub, error) {
|
||
hubOnce.Do(func() {
|
||
var api *webrtc.API
|
||
api, hubInitErr = buildAPI()
|
||
if hubInitErr != nil {
|
||
return
|
||
}
|
||
defaultHub = newHub(api)
|
||
})
|
||
return defaultHub, hubInitErr
|
||
}
|
||
|
||
func (h *Hub) clearPublisher() {
|
||
h.mu.Lock()
|
||
defer h.mu.Unlock()
|
||
for _, tf := range h.forwarders {
|
||
tf.close()
|
||
}
|
||
h.forwarders = nil
|
||
if h.pubPC != nil {
|
||
_ = h.pubPC.Close()
|
||
h.pubPC = nil
|
||
}
|
||
h.publishConn = nil
|
||
h.publishQuality = ""
|
||
}
|
||
|
||
func (h *Hub) removeViewer(id string) {
|
||
h.mu.Lock()
|
||
vs, ok := h.viewers[id]
|
||
if ok {
|
||
delete(h.viewers, id)
|
||
}
|
||
for _, tf := range h.forwarders {
|
||
tf.removeViewer(id)
|
||
}
|
||
h.mu.Unlock()
|
||
if ok && vs != nil && vs.pc != nil {
|
||
_ = vs.pc.Close()
|
||
}
|
||
}
|
||
|
||
func (h *Hub) onPublisherTrack(track *webrtc.TrackRemote) {
|
||
if track.Kind() != webrtc.RTPCodecTypeVideo && track.Kind() != webrtc.RTPCodecTypeAudio {
|
||
return
|
||
}
|
||
tf := newTrackForwarder(track)
|
||
h.mu.Lock()
|
||
h.forwarders = append(h.forwarders, tf)
|
||
h.mu.Unlock()
|
||
goSafe("trackRead", tf.runReadLoop)
|
||
// 观众仅在「已开播」后拉流:首次协商时 attachForwardersToViewerPC 会带上当前全部轨,无需在此重协商
|
||
}
|
||
|
||
func (h *Hub) attachForwardersToViewerPC(v *viewerSession) {
|
||
h.mu.RLock()
|
||
fwd := append([]*trackForwarder(nil), h.forwarders...)
|
||
h.mu.RUnlock()
|
||
for _, tf := range fwd {
|
||
cap := tf.remote.Codec().RTPCodecCapability
|
||
lt, err := webrtc.NewTrackLocalStaticRTP(cap, tf.remote.ID()+"_"+v.id, tf.remote.StreamID())
|
||
if err != nil {
|
||
continue
|
||
}
|
||
rtpSender, err := v.pc.AddTrack(lt)
|
||
if err != nil {
|
||
continue
|
||
}
|
||
// Drain RTCP feedback to keep interceptors/senders healthy.
|
||
goSafe("viewerRTCP", func() {
|
||
rtcpBuf := make([]byte, 1500)
|
||
for {
|
||
if _, _, e := rtpSender.Read(rtcpBuf); e != nil {
|
||
return
|
||
}
|
||
}
|
||
})
|
||
tf.addViewer(v.id, lt)
|
||
}
|
||
}
|