diff --git a/conn/conn.go b/conn/conn.go index 64ce618..cab6fa2 100644 --- a/conn/conn.go +++ b/conn/conn.go @@ -4,7 +4,6 @@ package conn import ( "errors" - "github.com/pion/rtp" "github.com/pion/webrtc/v3" ) @@ -38,7 +37,7 @@ type Down interface { // Type DownTrack represents a track in the server to client direction. type DownTrack interface { - WriteRTP(packat *rtp.Packet) error + Write(buf []byte) (int, error) SetTimeOffset(ntp uint64, rtp uint32) SetCname(string) GetMaxBitrate() uint64 diff --git a/diskwriter/diskwriter.go b/diskwriter/diskwriter.go index db8bd2b..2296c2b 100644 --- a/diskwriter/diskwriter.go +++ b/diskwriter/diskwriter.go @@ -338,19 +338,6 @@ func (t *diskTrack) SetTimeOffset(ntp uint64, rtp uint32) { func (t *diskTrack) SetCname(string) { } -func clonePacket(packet *rtp.Packet) *rtp.Packet { - buf, err := packet.Marshal() - if err != nil { - return nil - } - var p rtp.Packet - err = p.Unmarshal(buf) - if err != nil { - return nil - } - return &p -} - func isKeyframe(codec string, data []byte) bool { switch strings.ToLower(codec) { case "video/vp8": @@ -417,20 +404,24 @@ func keyframeDimensions(codec string, data []byte, packet *rtp.Packet) (uint32, } } -func (t *diskTrack) WriteRTP(packet *rtp.Packet) error { +func (t *diskTrack) Write(buf []byte) (int, error) { // since we call initWriter, we take the connection lock for simplicity. t.conn.mu.Lock() defer t.conn.mu.Unlock() if t.builder == nil { - return nil + return 0, nil } codec := t.remote.Codec() - p := clonePacket(packet) - if p == nil { - return nil + data := make([]byte, len(buf)) + copy(data, buf) + p := new(rtp.Packet) + err := p.Unmarshal(data) + if err != nil { + log.Printf("Diskwriter: %v", err) + return 0, nil } if strings.ToLower(codec.MimeType) == "video/vp9" { @@ -459,8 +450,9 @@ func (t *diskTrack) WriteRTP(packet *rtp.Packet) error { if sample == nil { if kfNeeded { t.remote.RequestKeyframe() + return 0, nil } - return nil + return len(buf), nil } keyframe := true @@ -479,7 +471,7 @@ func (t *diskTrack) WriteRTP(packet *rtp.Packet) error { t.conn.warn( "Write to disk " + err.Error(), ) - return err + return 0, err } t.lastKf = ts } else if t.writer != nil { @@ -498,7 +490,7 @@ func (t *diskTrack) WriteRTP(packet *rtp.Packet) error { "Write to disk " + err.Error(), ) - return err + return 0, err } } } @@ -508,7 +500,7 @@ func (t *diskTrack) WriteRTP(packet *rtp.Packet) error { if !keyframe { t.remote.RequestKeyframe() } - return nil + return 0, nil } if t.origin == 0 { @@ -519,7 +511,7 @@ func (t *diskTrack) WriteRTP(packet *rtp.Packet) error { tm := ts / (t.remote.Codec().ClockRate / 1000) _, err := t.writer.Write(keyframe, int64(tm), sample.Data) if err != nil { - return err + return 0, err } } } diff --git a/rtpconn/rtpconn.go b/rtpconn/rtpconn.go index 10c6e61..b4ca691 100644 --- a/rtpconn/rtpconn.go +++ b/rtpconn/rtpconn.go @@ -89,15 +89,12 @@ type rtpDownTrack struct { cname atomic.Value } -func (down *rtpDownTrack) WriteRTP(packet *rtp.Packet) error { - err := down.track.WriteRTP(packet) +func (down *rtpDownTrack) Write(buf []byte) (int, error) { + n, err := down.track.Write(buf) if err == nil { - // we should account for extensions - down.rate.Accumulate( - uint32(12 + 4*len(packet.CSRC) + len(packet.Payload)), - ) + down.rate.Accumulate(uint32(n)) } - return err + return n, err } func (down *rtpDownTrack) SetTimeOffset(ntp uint64, rtp uint32) { diff --git a/rtpconn/rtpwriter.go b/rtpconn/rtpwriter.go index bc4166d..81c47e6 100644 --- a/rtpconn/rtpwriter.go +++ b/rtpconn/rtpwriter.go @@ -6,8 +6,6 @@ import ( "sort" "time" - "github.com/pion/rtp" - "github.com/jech/galene/conn" "github.com/jech/galene/packetcache" "github.com/jech/galene/rtptime" @@ -211,17 +209,13 @@ func (writer *rtpWriter) add(track conn.DownTrack, add bool, max int) error { func sendKeyframe(kf []uint16, track conn.DownTrack, cache *packetcache.Cache) { buf := make([]byte, packetcache.BufSize) - var packet rtp.Packet for _, seqno := range kf { bytes := cache.Get(seqno, buf) if bytes == 0 { return } - err := packet.Unmarshal(buf[:bytes]) - if err != nil { - return - } - err = track.WriteRTP(&packet) + + _, err := track.Write(buf[:bytes]) if err != nil { return } @@ -233,8 +227,6 @@ func rtpWriterLoop(writer *rtpWriter, up *rtpUpConnection, track *rtpUpTrack) { defer close(writer.done) buf := make([]byte, packetcache.BufSize) - var packet rtp.Packet - local := make([]conn.DownTrack, 0) for { @@ -310,13 +302,8 @@ func rtpWriterLoop(writer *rtpWriter, up *rtpUpConnection, track *rtpUpTrack) { continue } - err := packet.Unmarshal(buf[:bytes]) - if err != nil { - continue - } - for _, l := range local { - err := l.WriteRTP(&packet) + _, err := l.Write(buf[:bytes]) if err != nil { continue }