diff --git a/client.go b/client.go index 686d5a5..d9d30f5 100644 --- a/client.go +++ b/client.go @@ -1,5 +1,9 @@ package main +import ( + "sfu/conn" +) + type clientCredentials struct { Username string `json:"username,omitempty"` Password string `json:"password,omitempty"` @@ -16,7 +20,7 @@ type client interface { Id() string Credentials() clientCredentials SetPermissions(clientPermissions) - pushConn(id string, conn upConnection, tracks []upTrack, label string) error + pushConn(id string, conn conn.Up, tracks []conn.UpTrack, label string) error pushClient(id, username string, add bool) error } diff --git a/conn.go b/conn.go deleted file mode 100644 index de559e4..0000000 --- a/conn.go +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright (c) 2020 by Juliusz Chroboczek. - -// This is not open source software. Copy it, and I'll break into your -// house and tell your three year-old that Santa doesn't exist. - -package main - -import ( - "errors" - - "github.com/pion/rtp" - "github.com/pion/webrtc/v3" -) - -var ErrConnectionClosed = errors.New("connection is closed") -var ErrKeyframeNeeded = errors.New("keyframe needed") - -type upConnection interface { - addLocal(downConnection) error - delLocal(downConnection) bool - Id() string - Label() string -} - -type upTrack interface { - addLocal(downTrack) error - delLocal(downTrack) bool - Label() string - Codec() *webrtc.RTPCodec - // get a recent packet. Returns 0 if the packet is not in cache. - getRTP(seqno uint16, result []byte) uint16 -} - -type downConnection interface { - GetMaxBitrate(now uint64) uint64 -} - -type downTrack interface { - WriteRTP(packat *rtp.Packet) error - Accumulate(bytes uint32) - setTimeOffset(ntp uint64, rtp uint32) - setCname(string) -} diff --git a/conn/conn.go b/conn/conn.go new file mode 100644 index 0000000..c4de7e6 --- /dev/null +++ b/conn/conn.go @@ -0,0 +1,48 @@ +// Copyright (c) 2020 by Juliusz Chroboczek. + +// This is not open source software. Copy it, and I'll break into your +// house and tell your three year-old that Santa doesn't exist. + +// Package conn defines interfaces for connections and tracks. +package conn + +import ( + "errors" + + "github.com/pion/rtp" + "github.com/pion/webrtc/v3" +) + +var ErrConnectionClosed = errors.New("connection is closed") +var ErrKeyframeNeeded = errors.New("keyframe needed") + +// Type Up represents a connection in the client to server direction. +type Up interface { + AddLocal(Down) error + DelLocal(Down) bool + Id() string + Label() string +} + +// Type UpTrack represents a track in the client to server direction. +type UpTrack interface { + AddLocal(DownTrack) error + DelLocal(DownTrack) bool + Label() string + Codec() *webrtc.RTPCodec + // get a recent packet. Returns 0 if the packet is not in cache. + GetRTP(seqno uint16, result []byte) uint16 +} + +// Type Down represents a connection in the server to client direction. +type Down interface { + GetMaxBitrate(now uint64) uint64 +} + +// Type DownTrack represents a track in the server to client direction. +type DownTrack interface { + WriteRTP(packat *rtp.Packet) error + Accumulate(bytes uint32) + SetTimeOffset(ntp uint64, rtp uint32) + SetCname(string) +} diff --git a/disk.go b/disk.go index 880bd49..ea9f922 100644 --- a/disk.go +++ b/disk.go @@ -14,6 +14,8 @@ import ( "github.com/pion/rtp/codecs" "github.com/pion/webrtc/v3" "github.com/pion/webrtc/v3/pkg/media/samplebuilder" + + "sfu/conn" ) type diskClient struct { @@ -81,7 +83,7 @@ func (client *diskClient) kick(message string) error { return err } -func (client *diskClient) pushConn(id string, conn upConnection, tracks []upTrack, label string) error { +func (client *diskClient) pushConn(id string, up conn.Up, tracks []conn.UpTrack, label string) error { client.mu.Lock() defer client.mu.Unlock() @@ -95,7 +97,7 @@ func (client *diskClient) pushConn(id string, conn upConnection, tracks []upTrac delete(client.down, id) } - if conn == nil { + if up == nil { return nil } @@ -109,12 +111,12 @@ func (client *diskClient) pushConn(id string, conn upConnection, tracks []upTrac client.down = make(map[string]*diskConn) } - down, err := newDiskConn(directory, label, conn, tracks) + down, err := newDiskConn(directory, label, up, tracks) if err != nil { return err } - client.down[conn.Id()] = down + client.down[up.Id()] = down return nil } @@ -125,7 +127,7 @@ type diskConn struct { mu sync.Mutex file *os.File - remote upConnection + remote conn.Up tracks []*diskTrack width, height uint32 } @@ -150,7 +152,7 @@ func (conn *diskConn) reopen() error { } func (conn *diskConn) Close() error { - conn.remote.delLocal(conn) + conn.remote.DelLocal(conn) conn.mu.Lock() tracks := make([]*diskTrack, 0, len(conn.tracks)) @@ -164,7 +166,7 @@ func (conn *diskConn) Close() error { conn.mu.Unlock() for _, t := range tracks { - t.remote.delLocal(t) + t.remote.DelLocal(t) } return nil } @@ -196,7 +198,7 @@ func openDiskFile(directory, label string) (*os.File, error) { } type diskTrack struct { - remote upTrack + remote conn.UpTrack conn *diskConn writer webm.BlockWriteCloser @@ -206,7 +208,7 @@ type diskTrack struct { origin uint64 } -func newDiskConn(directory, label string, up upConnection, remoteTracks []upTrack) (*diskConn, error) { +func newDiskConn(directory, label string, up conn.Up, remoteTracks []conn.UpTrack) (*diskConn, error) { conn := diskConn{ directory: directory, label: label, @@ -231,10 +233,10 @@ func newDiskConn(directory, label string, up upConnection, remoteTracks []upTrac conn: &conn, } conn.tracks = append(conn.tracks, track) - remote.addLocal(track) + remote.AddLocal(track) } - err := up.addLocal(&conn) + err := up.AddLocal(&conn) if err != nil { return nil, err } @@ -242,10 +244,10 @@ func newDiskConn(directory, label string, up upConnection, remoteTracks []upTrac return &conn, nil } -func (t *diskTrack) setTimeOffset(ntp uint64, rtp uint32) { +func (t *diskTrack) SetTimeOffset(ntp uint64, rtp uint32) { } -func (t *diskTrack) setCname(string) { +func (t *diskTrack) SetCname(string) { } func clonePacket(packet *rtp.Packet) *rtp.Packet { @@ -310,7 +312,7 @@ func (t *diskTrack) WriteRTP(packet *rtp.Packet) error { if t.writer == nil { if !keyframe { - return ErrKeyframeNeeded + return conn.ErrKeyframeNeeded } return nil } diff --git a/rtpconn.go b/rtpconn.go index 5104b4d..193405f 100644 --- a/rtpconn.go +++ b/rtpconn.go @@ -14,14 +14,15 @@ import ( "sync/atomic" "time" + "github.com/pion/rtcp" + "github.com/pion/rtp" + "github.com/pion/webrtc/v3" + + "sfu/conn" "sfu/estimator" "sfu/jitter" "sfu/packetcache" "sfu/rtptime" - - "github.com/pion/rtcp" - "github.com/pion/rtp" - "github.com/pion/webrtc/v3" ) type bitrate struct { @@ -71,7 +72,7 @@ type iceConnection interface { type rtpDownTrack struct { track *webrtc.Track - remote upTrack + remote conn.UpTrack maxBitrate *bitrate rate *estimator.Estimator stats *receiverStats @@ -91,25 +92,25 @@ func (down *rtpDownTrack) Accumulate(bytes uint32) { down.rate.Accumulate(bytes) } -func (down *rtpDownTrack) setTimeOffset(ntp uint64, rtp uint32) { +func (down *rtpDownTrack) SetTimeOffset(ntp uint64, rtp uint32) { atomic.StoreUint64(&down.remoteNTPTime, ntp) atomic.StoreUint32(&down.remoteRTPTime, rtp) } -func (down *rtpDownTrack) setCname(cname string) { +func (down *rtpDownTrack) SetCname(cname string) { down.cname.Store(cname) } type rtpDownConnection struct { id string pc *webrtc.PeerConnection - remote upConnection + remote conn.Up tracks []*rtpDownTrack maxREMBBitrate *bitrate iceCandidates []*webrtc.ICECandidateInit } -func newDownConn(c client, id string, remote upConnection) (*rtpDownConnection, error) { +func newDownConn(c client, id string, remote conn.Up) (*rtpDownConnection, error) { pc, err := c.Group().API().NewPeerConnection(iceConfiguration()) if err != nil { return nil, err @@ -193,7 +194,7 @@ type rtpUpTrack struct { mu sync.Mutex cname string - local []downTrack + local []conn.DownTrack srTime uint64 srNTPTime uint64 srRTPTime uint32 @@ -201,17 +202,17 @@ type rtpUpTrack struct { type localTrackAction struct { add bool - track downTrack + track conn.DownTrack } -func (up *rtpUpTrack) notifyLocal(add bool, track downTrack) { +func (up *rtpUpTrack) notifyLocal(add bool, track conn.DownTrack) { select { case up.localCh <- localTrackAction{add, track}: case <-up.readerDone: } } -func (up *rtpUpTrack) addLocal(local downTrack) error { +func (up *rtpUpTrack) AddLocal(local conn.DownTrack) error { up.mu.Lock() for _, t := range up.local { if t == local { @@ -226,7 +227,7 @@ func (up *rtpUpTrack) addLocal(local downTrack) error { return nil } -func (up *rtpUpTrack) delLocal(local downTrack) bool { +func (up *rtpUpTrack) DelLocal(local conn.DownTrack) bool { up.mu.Lock() for i, l := range up.local { if l == local { @@ -240,15 +241,15 @@ func (up *rtpUpTrack) delLocal(local downTrack) bool { return false } -func (up *rtpUpTrack) getLocal() []downTrack { +func (up *rtpUpTrack) getLocal() []conn.DownTrack { up.mu.Lock() defer up.mu.Unlock() - local := make([]downTrack, len(up.local)) + local := make([]conn.DownTrack, len(up.local)) copy(local, up.local) return local } -func (up *rtpUpTrack) getRTP(seqno uint16, result []byte) uint16 { +func (up *rtpUpTrack) GetRTP(seqno uint16, result []byte) uint16 { return up.cache.Get(seqno, result) } @@ -278,7 +279,7 @@ type rtpUpConnection struct { mu sync.Mutex tracks []*rtpUpTrack - local []downConnection + local []conn.Down } func (up *rtpUpConnection) getTracks() []*rtpUpTrack { @@ -297,7 +298,7 @@ func (up *rtpUpConnection) Label() string { return up.label } -func (up *rtpUpConnection) addLocal(local downConnection) error { +func (up *rtpUpConnection) AddLocal(local conn.Down) error { up.mu.Lock() defer up.mu.Unlock() for _, t := range up.local { @@ -309,7 +310,7 @@ func (up *rtpUpConnection) addLocal(local downConnection) error { return nil } -func (up *rtpUpConnection) delLocal(local downConnection) bool { +func (up *rtpUpConnection) DelLocal(local conn.Down) bool { up.mu.Lock() defer up.mu.Unlock() for i, l := range up.local { @@ -321,10 +322,10 @@ func (up *rtpUpConnection) delLocal(local downConnection) bool { return false } -func (up *rtpUpConnection) getLocal() []downConnection { +func (up *rtpUpConnection) getLocal() []conn.Down { up.mu.Lock() defer up.mu.Unlock() - local := make([]downConnection, len(up.local)) + local := make([]conn.Down, len(up.local)) copy(local, up.local) return local } @@ -396,10 +397,10 @@ func newUpConn(c client, id string) (*rtpUpConnection, error) { return nil, err } - conn := &rtpUpConnection{id: id, pc: pc} + up := &rtpUpConnection{id: id, pc: pc} pc.OnTrack(func(remote *webrtc.Track, receiver *webrtc.RTPReceiver) { - conn.mu.Lock() + up.mu.Lock() mid := getTrackMid(pc, remote) if mid == "" { @@ -407,7 +408,7 @@ func newUpConn(c client, id string) (*rtpUpConnection, error) { return } - label, ok := conn.labels[mid] + label, ok := up.labels[mid] if !ok { log.Printf("Couldn't get track's label") isvideo := remote.Kind() == webrtc.RTPCodecTypeVideo @@ -428,34 +429,34 @@ func newUpConn(c client, id string) (*rtpUpConnection, error) { readerDone: make(chan struct{}), } - conn.tracks = append(conn.tracks, track) + up.tracks = append(up.tracks, track) - go readLoop(conn, track) + go readLoop(up, track) - go rtcpUpListener(conn, track, receiver) + go rtcpUpListener(up, track, receiver) - complete := conn.complete() - var tracks []upTrack - if(complete) { - tracks = make([]upTrack, len(conn.tracks)) - for i, t := range conn.tracks { + complete := up.complete() + var tracks []conn.UpTrack + if complete { + tracks = make([]conn.UpTrack, len(up.tracks)) + for i, t := range up.tracks { tracks[i] = t } } // pushConn might need to take the lock - conn.mu.Unlock() + up.mu.Unlock() if complete { clients := c.Group().getClients(c) for _, cc := range clients { - cc.pushConn(conn.id, conn, tracks, conn.label) + cc.pushConn(up.id, up, tracks, up.label) } - go rtcpUpSender(conn) + go rtcpUpSender(up) } }) - return conn, nil + return up, nil } func readLoop(conn *rtpUpConnection, track *rtpUpTrack) { @@ -606,7 +607,7 @@ func sendRecovery(p *rtcp.TransportLayerNack, track *rtpDownTrack) { buf := make([]byte, packetcache.BufSize) for _, nack := range p.Nacks { for _, seqno := range nack.PacketList() { - l := track.remote.getRTP(seqno, buf) + l := track.remote.GetRTP(seqno, buf) if l == 0 { continue } @@ -650,7 +651,7 @@ func rtcpUpListener(conn *rtpUpConnection, track *rtpUpTrack, r *webrtc.RTPRecei track.srRTPTime = p.RTPTime track.mu.Unlock() for _, l := range local { - l.setTimeOffset(p.NTPTime, p.RTPTime) + l.SetTimeOffset(p.NTPTime, p.RTPTime) } case *rtcp.SourceDescription: for _, c := range p.Chunks { @@ -665,7 +666,7 @@ func rtcpUpListener(conn *rtpUpConnection, track *rtpUpTrack, r *webrtc.RTPRecei track.cname = i.Text track.mu.Unlock() for _, l := range local { - l.setCname(i.Text) + l.SetCname(i.Text) } } } diff --git a/rtpwriter.go b/rtpwriter.go index 3e1d27b..1910cea 100644 --- a/rtpwriter.go +++ b/rtpwriter.go @@ -7,6 +7,7 @@ import ( "github.com/pion/rtp" + "sfu/conn" "sfu/packetcache" "sfu/rtptime" ) @@ -43,7 +44,7 @@ func sqrt(n int) int { } // add adds or removes a track from a writer pool -func (wp *rtpWriterPool) add(track downTrack, add bool) error { +func (wp *rtpWriterPool) add(track conn.DownTrack, add bool) error { n := 4 if wp.count > 16 { n = sqrt(wp.count) @@ -166,7 +167,7 @@ var ErrUnknownTrack = errors.New("unknown track") type writerAction struct { add bool - track downTrack + track conn.DownTrack maxTracks int ch chan error } @@ -192,7 +193,7 @@ func newRtpWriter(conn *rtpUpConnection, track *rtpUpTrack) *rtpWriter { } // add adds or removes a track from a writer. -func (writer *rtpWriter) add(track downTrack, add bool, max int) error { +func (writer *rtpWriter) add(track conn.DownTrack, add bool, max int) error { ch := make(chan error, 1) select { case writer.action <- writerAction{add, track, max, ch}: @@ -208,13 +209,13 @@ func (writer *rtpWriter) add(track downTrack, add bool, max int) error { } // rtpWriterLoop is the main loop of an rtpWriter. -func rtpWriterLoop(writer *rtpWriter, conn *rtpUpConnection, track *rtpUpTrack) { +func rtpWriterLoop(writer *rtpWriter, up *rtpUpConnection, track *rtpUpTrack) { defer close(writer.done) buf := make([]byte, packetcache.BufSize) var packet rtp.Packet - local := make([]downTrack, 0) + local := make([]conn.DownTrack, 0) // reset whenever a new track is inserted firSent := false @@ -239,10 +240,10 @@ func rtpWriterLoop(writer *rtpWriter, conn *rtpUpConnection, track *rtpUpTrack) cname := track.cname track.mu.Unlock() if ntp != 0 { - action.track.setTimeOffset(ntp, rtp) + action.track.SetTimeOffset(ntp, rtp) } if cname != "" { - action.track.setCname(cname) + action.track.SetCname(cname) } } else { found := false @@ -283,7 +284,7 @@ func rtpWriterLoop(writer *rtpWriter, conn *rtpUpConnection, track *rtpUpTrack) for _, l := range local { err := l.WriteRTP(&packet) if err != nil { - if err == ErrKeyframeNeeded { + if err == conn.ErrKeyframeNeeded { kfNeeded = true } continue @@ -292,9 +293,9 @@ func rtpWriterLoop(writer *rtpWriter, conn *rtpUpConnection, track *rtpUpTrack) } if kfNeeded { - err := conn.sendFIR(track, !firSent) + err := up.sendFIR(track, !firSent) if err == ErrUnsupportedFeedback { - conn.sendPLI(track) + up.sendPLI(track) } firSent = true } diff --git a/webclient.go b/webclient.go index 7b81959..5e2d62c 100644 --- a/webclient.go +++ b/webclient.go @@ -14,10 +14,11 @@ import ( "sync" "time" - "sfu/estimator" - "github.com/gorilla/websocket" "github.com/pion/webrtc/v3" + + "sfu/conn" + "sfu/estimator" ) var iceConf webrtc.Configuration @@ -300,7 +301,7 @@ func getConn(c *webClient, id string) iceConnection { return nil } -func addDownConn(c *webClient, id string, remote upConnection) (*rtpDownConnection, error) { +func addDownConn(c *webClient, id string, remote conn.Up) (*rtpDownConnection, error) { conn, err := newDownConn(c, id, remote) if err != nil { return nil, err @@ -333,7 +334,7 @@ func addDownConn(c *webClient, id string, remote upConnection) (*rtpDownConnecti } }) - err = remote.addLocal(conn) + err = remote.AddLocal(conn) if err != nil { conn.pc.Close() return nil, err @@ -355,18 +356,18 @@ func delDownConn(c *webClient, id string) bool { return false } - conn.remote.delLocal(conn) + conn.remote.DelLocal(conn) for _, track := range conn.tracks { // we only insert the track after we get an answer, so // ignore errors here. - track.remote.delLocal(track) + track.remote.DelLocal(track) } conn.pc.Close() delete(c.down, id) return true } -func addDownTrack(c *webClient, conn *rtpDownConnection, remoteTrack upTrack, remoteConn upConnection) (*webrtc.RTPSender, error) { +func addDownTrack(c *webClient, conn *rtpDownConnection, remoteTrack conn.UpTrack, remoteConn conn.Up) (*webrtc.RTPSender, error) { var pt uint8 var ssrc uint32 var id, label string @@ -524,7 +525,7 @@ func gotAnswer(c *webClient, id string, answer webrtc.SessionDescription) error } for _, t := range down.tracks { - t.remote.addLocal(t) + t.remote.AddLocal(t) } return nil } @@ -568,7 +569,7 @@ func (c *webClient) isRequested(label string) bool { return c.requested[label] != 0 } -func addDownConnTracks(c *webClient, remote upConnection, tracks []upTrack) (*rtpDownConnection, error) { +func addDownConnTracks(c *webClient, remote conn.Up, tracks []conn.UpTrack) (*rtpDownConnection, error) { requested := false for _, t := range tracks { if c.isRequested(t.Label()) { @@ -601,13 +602,13 @@ func addDownConnTracks(c *webClient, remote upConnection, tracks []upTrack) (*rt return down, nil } -func (c *webClient) pushConn(id string, conn upConnection, tracks []upTrack, label string) error { - err := c.action(pushConnAction{id, conn, tracks}) +func (c *webClient) pushConn(id string, up conn.Up, tracks []conn.UpTrack, label string) error { + err := c.action(pushConnAction{id, up, tracks}) if err != nil { return err } - if conn != nil && label != "" { - err := c.action(addLabelAction{conn.Id(), conn.Label()}) + if up != nil && label != "" { + err := c.action(addLabelAction{up.Id(), up.Label()}) if err != nil { return err } @@ -726,8 +727,8 @@ func startClient(conn *websocket.Conn) (err error) { type pushConnAction struct { id string - conn upConnection - tracks []upTrack + conn conn.Up + tracks []conn.UpTrack } type addLabelAction struct { @@ -749,9 +750,9 @@ type kickAction struct { message string } -func clientLoop(c *webClient, conn *websocket.Conn) error { +func clientLoop(c *webClient, ws *websocket.Conn) error { read := make(chan interface{}, 1) - go clientReader(conn, read, c.done) + go clientReader(ws, read, c.done) defer func() { c.setRequested(map[string]uint32{}) @@ -848,7 +849,7 @@ func clientLoop(c *webClient, conn *websocket.Conn) error { case pushConnsAction: for _, u := range c.up { tracks := u.getTracks() - ts := make([]upTrack, len(tracks)) + ts := make([]conn.UpTrack, len(tracks)) for i, t := range tracks { ts[i] = t } @@ -861,7 +862,7 @@ func clientLoop(c *webClient, conn *websocket.Conn) error { return err } tracks := make( - []upTrack, len(down.tracks), + []conn.UpTrack, len(down.tracks), ) for i, t := range down.tracks { tracks[i] = t.remote