From 2831158f099017ddc1ee400083e48378a31141e7 Mon Sep 17 00:00:00 2001 From: Juliusz Chroboczek Date: Thu, 9 Sep 2021 16:24:14 +0200 Subject: [PATCH] WIP: WHIP support --- galene.go | 3 + rtpconn/rtpconn.go | 64 +++--- rtpconn/webclient.go | 30 +-- webserver/webserver.go | 7 + whip/whip.go | 427 +++++++++++++++++++++++++++++++++++++++++ 5 files changed, 489 insertions(+), 42 deletions(-) create mode 100644 whip/whip.go diff --git a/galene.go b/galene.go index bb9c8bf..b544b62 100644 --- a/galene.go +++ b/galene.go @@ -18,6 +18,7 @@ import ( "github.com/jech/galene/limit" "github.com/jech/galene/turnserver" "github.com/jech/galene/webserver" + "github.com/jech/galene/whip" ) func main() { @@ -48,6 +49,8 @@ func main() { "require use of TURN relays for all media traffic") flag.StringVar(&turnserver.Address, "turn", "auto", "built-in TURN server `address` (\"\" to disable)") + flag.BoolVar(&whip.PublicServer, "public-server", false, + "allow browser access from arbitrary origins") flag.Parse() if udpRange != "" { diff --git a/rtpconn/rtpconn.go b/rtpconn/rtpconn.go index ec29355..354afa1 100644 --- a/rtpconn/rtpconn.go +++ b/rtpconn/rtpconn.go @@ -397,7 +397,7 @@ func (down *rtpDownConnection) flushICECandidates() error { type rtpUpTrack struct { track *webrtc.TrackRemote receiver *webrtc.RTPReceiver - conn *rtpUpConnection + conn *UpConn rate *estimator.Estimator cache *packetcache.Cache jitter *jitter.Estimator @@ -504,11 +504,11 @@ func (up *rtpUpTrack) hasRtcpFb(tpe, parameter string) bool { return false } -type rtpUpConnection struct { +type UpConn struct { id string client group.Client label string - pc *webrtc.PeerConnection + PC *webrtc.PeerConnection iceCandidates []*webrtc.ICECandidateInit mu sync.Mutex @@ -519,7 +519,7 @@ type rtpUpConnection struct { local []conn.Down } -func (up *rtpUpConnection) getTracks() []*rtpUpTrack { +func (up *UpConn) getTracks() []*rtpUpTrack { up.mu.Lock() defer up.mu.Unlock() tracks := make([]*rtpUpTrack, len(up.tracks)) @@ -527,7 +527,7 @@ func (up *rtpUpConnection) getTracks() []*rtpUpTrack { return tracks } -func (up *rtpUpConnection) getReplace(reset bool) string { +func (up *UpConn) getReplace(reset bool) string { up.mu.Lock() defer up.mu.Unlock() replace := up.replace @@ -537,19 +537,19 @@ func (up *rtpUpConnection) getReplace(reset bool) string { return replace } -func (up *rtpUpConnection) Id() string { +func (up *UpConn) Id() string { return up.id } -func (up *rtpUpConnection) Label() string { +func (up *UpConn) Label() string { return up.label } -func (up *rtpUpConnection) User() (string, string) { +func (up *UpConn) User() (string, string) { return up.client.Id(), up.client.Username() } -func (up *rtpUpConnection) AddLocal(local conn.Down) error { +func (up *UpConn) AddLocal(local conn.Down) error { up.mu.Lock() defer up.mu.Unlock() // the connection may have been closed in the meantime, in which @@ -566,7 +566,7 @@ func (up *rtpUpConnection) AddLocal(local conn.Down) error { return nil } -func (up *rtpUpConnection) DelLocal(local conn.Down) bool { +func (up *UpConn) DelLocal(local conn.Down) bool { up.mu.Lock() defer up.mu.Unlock() for i, l := range up.local { @@ -578,7 +578,7 @@ func (up *rtpUpConnection) DelLocal(local conn.Down) bool { return false } -func (up *rtpUpConnection) getLocal() []conn.Down { +func (up *UpConn) getLocal() []conn.Down { up.mu.Lock() defer up.mu.Unlock() local := make([]conn.Down, len(up.local)) @@ -586,22 +586,22 @@ func (up *rtpUpConnection) getLocal() []conn.Down { return local } -func (up *rtpUpConnection) addICECandidate(candidate *webrtc.ICECandidateInit) error { - if up.pc.RemoteDescription() != nil { - return up.pc.AddICECandidate(*candidate) +func (up *UpConn) addICECandidate(candidate *webrtc.ICECandidateInit) error { + if up.PC.RemoteDescription() != nil { + return up.PC.AddICECandidate(*candidate) } up.iceCandidates = append(up.iceCandidates, candidate) return nil } -func (up *rtpUpConnection) flushICECandidates() error { - err := flushICECandidates(up.pc, up.iceCandidates) +func (up *UpConn) flushICECandidates() error { + err := flushICECandidates(up.PC, up.iceCandidates) up.iceCandidates = nil return err } // pushConnNow pushes a connection to all of the clients in a group -func pushConnNow(up *rtpUpConnection, g *group.Group, cs []group.Client) { +func pushConnNow(up *UpConn, g *group.Group, cs []group.Client) { up.mu.Lock() up.pushed = true replace := up.replace @@ -617,8 +617,18 @@ func pushConnNow(up *rtpUpConnection, g *group.Group, cs []group.Client) { } } +func (up *UpConn) GetTracks() []conn.UpTrack { + up.mu.Lock() + defer up.mu.Unlock() + ts := make([]conn.UpTrack, len(up.tracks)) + for i, t := range up.tracks { + ts[i] = t + } + return ts +} + // pushConn schedules a call to pushConnNow -func pushConn(up *rtpUpConnection, g *group.Group, cs []group.Client) { +func pushConn(up *UpConn, g *group.Group, cs []group.Client) { up.mu.Lock() up.pushed = false up.mu.Unlock() @@ -635,7 +645,7 @@ func pushConn(up *rtpUpConnection, g *group.Group, cs []group.Client) { }(g, cs) } -func newUpConn(c group.Client, id string, label string, offer string) (*rtpUpConnection, error) { +func NewUpConn(c group.Client, id string, label string, offer string) (*UpConn, error) { var o sdp.SessionDescription err := o.Unmarshal([]byte(offer)) if err != nil { @@ -664,7 +674,7 @@ func newUpConn(c group.Client, id string, label string, offer string) (*rtpUpCon } } - up := &rtpUpConnection{id: id, client: c, label: label, pc: pc} + up := &UpConn{id: id, client: c, label: label, PC: pc} pc.OnTrack(func(remote *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) { up.mu.Lock() @@ -704,7 +714,7 @@ func (track *rtpUpTrack) sendPLI() error { if !track.hasRtcpFb("nack", "pli") { return ErrUnsupportedFeedback } - return sendPLI(track.conn.pc, track.track.SSRC()) + return sendPLI(track.conn.PC, track.track.SSRC()) } func sendPLI(pc *webrtc.PeerConnection, ssrc webrtc.SSRC) error { @@ -718,7 +728,7 @@ func (track *rtpUpTrack) sendNACK(first uint16, bitmap uint16) error { return ErrUnsupportedFeedback } - err := sendNACKs(track.conn.pc, track.track.SSRC(), + err := sendNACKs(track.conn.PC, track.track.SSRC(), []rtcp.NackPair{{first, rtcp.PacketBitmap(bitmap)}}, ) if err == nil { @@ -748,7 +758,7 @@ func (track *rtpUpTrack) sendNACKs(seqnos []uint16) error { f, b, seqnos = packetcache.ToBitmap(seqnos) nacks = append(nacks, rtcp.NackPair{f, rtcp.PacketBitmap(b)}) } - err := sendNACKs(track.conn.pc, track.track.SSRC(), nacks) + err := sendNACKs(track.conn.PC, track.track.SSRC(), nacks) if err == nil { track.cache.Expect(count) } @@ -931,11 +941,11 @@ func maxUpBitrate(t *rtpUpTrack) uint64 { return maxrate } -func sendUpRTCP(up *rtpUpConnection) error { +func sendUpRTCP(up *UpConn) error { tracks := up.getTracks() if len(up.tracks) == 0 { - state := up.pc.ConnectionState() + state := up.PC.ConnectionState() if state == webrtc.PeerConnectionStateClosed { return io.ErrClosedPipe } @@ -1016,10 +1026,10 @@ func sendUpRTCP(up *rtpUpConnection) error { }, ) } - return up.pc.WriteRTCP(packets) + return up.PC.WriteRTCP(packets) } -func rtcpUpSender(conn *rtpUpConnection) { +func rtcpUpSender(conn *UpConn) { for { time.Sleep(time.Second) err := sendUpRTCP(conn) diff --git a/rtpconn/webclient.go b/rtpconn/webclient.go index a0a5792..a385a91 100644 --- a/rtpconn/webclient.go +++ b/rtpconn/webclient.go @@ -66,7 +66,7 @@ type webClient struct { mu sync.Mutex down map[string]*rtpDownConnection - up map[string]*rtpUpConnection + up map[string]*UpConn // action may be called with the group mutex taken, and therefore // actions needs to use its own mutex. @@ -131,7 +131,7 @@ type closeMessage struct { data []byte } -func getUpConn(c *webClient, id string) *rtpUpConnection { +func getUpConn(c *webClient, id string) *UpConn { c.mu.Lock() defer c.mu.Unlock() @@ -141,22 +141,22 @@ func getUpConn(c *webClient, id string) *rtpUpConnection { return c.up[id] } -func getUpConns(c *webClient) []*rtpUpConnection { +func getUpConns(c *webClient) []*UpConn { c.mu.Lock() defer c.mu.Unlock() - up := make([]*rtpUpConnection, 0, len(c.up)) + up := make([]*UpConn, 0, len(c.up)) for _, u := range c.up { up = append(up, u) } return up } -func addUpConn(c *webClient, id, label string, offer string) (*rtpUpConnection, bool, error) { +func addUpConn(c *webClient, id, label string, offer string) (*UpConn, bool, error) { c.mu.Lock() defer c.mu.Unlock() if c.up == nil { - c.up = make(map[string]*rtpUpConnection) + c.up = make(map[string]*UpConn) } if c.down != nil && c.down[id] != nil { return nil, false, errors.New("Adding duplicate connection") @@ -167,18 +167,18 @@ func addUpConn(c *webClient, id, label string, offer string) (*rtpUpConnection, return old, false, nil } - conn, err := newUpConn(c, id, label, offer) + conn, err := NewUpConn(c, id, label, offer) if err != nil { return nil, false, err } c.up[id] = conn - conn.pc.OnICECandidate(func(candidate *webrtc.ICECandidate) { + conn.PC.OnICECandidate(func(candidate *webrtc.ICECandidate) { sendICE(c, id, candidate) }) - conn.pc.OnICEConnectionStateChange(func(state webrtc.ICEConnectionState) { + conn.PC.OnICEConnectionStateChange(func(state webrtc.ICEConnectionState) { if state == webrtc.ICEConnectionStateFailed { c.action(connectionFailedAction{id: id}) } @@ -220,7 +220,7 @@ func delUpConn(c *webClient, id string, userId string, push bool) error { conn.closed = true conn.mu.Unlock() - conn.pc.Close() + conn.PC.Close() if push && g != nil { for _, c := range g.GetClients(c) { @@ -575,7 +575,7 @@ func gotOffer(c *webClient, id, label string, sdp string, replace string) error delUpConn(c, replace, c.Id(), false) } - err = up.pc.SetRemoteDescription(webrtc.SessionDescription{ + err = up.PC.SetRemoteDescription(webrtc.SessionDescription{ Type: webrtc.SDPTypeOffer, SDP: sdp, }) @@ -583,12 +583,12 @@ func gotOffer(c *webClient, id, label string, sdp string, replace string) error return err } - answer, err := up.pc.CreateAnswer(nil) + answer, err := up.PC.CreateAnswer(nil) if err != nil { return err } - err = up.pc.SetLocalDescription(answer) + err = up.PC.SetLocalDescription(answer) if err != nil { return err } @@ -601,7 +601,7 @@ func gotOffer(c *webClient, id, label string, sdp string, replace string) error return c.write(clientMessage{ Type: "answer", Id: id, - SDP: up.pc.LocalDescription().SDP, + SDP: up.PC.LocalDescription().SDP, }) } @@ -713,7 +713,7 @@ func (c *webClient) setRequested(requested map[string][]string) error { func (c *webClient) setRequestedStream(down *rtpDownConnection, requested []string) error { var remoteClient group.Client - remote, ok := down.remote.(*rtpUpConnection) + remote, ok := down.remote.(*UpConn) if ok { remoteClient = remote.client } diff --git a/webserver/webserver.go b/webserver/webserver.go index 537347b..f57a1e6 100644 --- a/webserver/webserver.go +++ b/webserver/webserver.go @@ -26,6 +26,7 @@ import ( "github.com/jech/galene/group" "github.com/jech/galene/rtpconn" "github.com/jech/galene/stats" + "github.com/jech/galene/whip" ) var server atomic.Value @@ -37,6 +38,7 @@ var Insecure bool func Serve(address string, dataDir string) error { http.Handle("/", &fileHandler{http.Dir(StaticRoot)}) http.HandleFunc("/group/", groupHandler) + http.HandleFunc("/whip/", whip.Handler) http.HandleFunc("/recordings", func(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, @@ -316,6 +318,11 @@ func groupHandler(w http.ResponseWriter, r *http.Request) { return } + if r.Method == "POST" || r.Method == "OPTIONS" { + whip.Endpoint(g, w, r) + return + } + cspHeader(w) serveFile(w, r, filepath.Join(StaticRoot, "galene.html")) } diff --git a/whip/whip.go b/whip/whip.go new file mode 100644 index 0000000..3c68d4c --- /dev/null +++ b/whip/whip.go @@ -0,0 +1,427 @@ +package whip + +import ( + "bytes" + crand "crypto/rand" + "encoding/hex" + "errors" + "fmt" + "io" + "io/ioutil" + "log" + "net/http" + "os" + "path" + "strings" + "sync" + + "github.com/jech/galene/conn" + "github.com/jech/galene/group" + "github.com/jech/galene/rtpconn" + "github.com/pion/webrtc/v3" +) + +var PublicServer bool + +type Client struct { + group *group.Group + id string + username string + + mu sync.Mutex + permissions group.ClientPermissions + connection *rtpconn.UpConn +} + +func (c *Client) Group() *group.Group { + return c.group +} + +func (c *Client) Id() string { + return c.id +} + +func (c *Client) Username() string { + return c.username +} + +func (c *Client) Permissions() group.ClientPermissions { + c.mu.Lock() + defer c.mu.Unlock() + return c.permissions +} + +func (c *Client) SetPermissions(perms group.ClientPermissions) { + c.mu.Lock() + defer c.mu.Unlock() + c.permissions = perms +} + +func (c *Client) Status() map[string]interface{} { + return nil +} + +func (c *Client) PushConn(g *group.Group, id string, conn conn.Up, tracks []conn.UpTrack, replace string) error { + return nil +} + +func (c *Client) RequestConns(target group.Client, g *group.Group, id string) error { + if g != c.group { + return nil + } + + c.mu.Lock() + conn := c.connection + c.mu.Unlock() + + if conn == nil { + return nil + } + target.PushConn(g, conn.Id(), conn, conn.GetTracks(), "") + return nil +} + +func (c *Client) Joined(group, kind string) error { + return nil +} + +func (c *Client) PushClient(group, kind, id, username string, permissions group.ClientPermissions, status map[string]interface{}) error { + return nil +} + +func (c *Client) Kick(id, user, message string) error { + return c.close() +} + +func (c *Client) conn() *rtpconn.UpConn { + c.mu.Lock() + defer c.mu.Unlock() + return c.connection +} + +func (c *Client) close() error { + c.mu.Lock() + defer c.mu.Unlock() + g := c.group + if g == nil { + return nil + } + if c.connection != nil { + id := c.connection.Id() + c.connection.PC.OnConnectionStateChange(nil) + c.connection.PC.Close() + c.connection = nil + for _, c := range g.GetClients(c) { + c.PushConn(g, id, nil, nil, "") + } + c.connection = nil + } + group.DelClient(c) + c.group = nil + return nil +} + +func newId() string { + b := make([]byte, 16) + crand.Read(b) + return hex.EncodeToString(b) +} + +func httpError(w http.ResponseWriter, err error) { + if os.IsNotExist(err) { + http.Error(w, "404 not found", http.StatusNotFound) + return + } + if os.IsPermission(err) { + http.Error(w, "403 forbidden", http.StatusForbidden) + return + } + log.Printf("WHIP: %v", err) + http.Error(w, "500 Internal Server Error", + http.StatusInternalServerError) + return +} + +const sdpLimit = 1024 * 1024 + +func readLimited(r io.Reader) ([]byte, error) { + v, err := ioutil.ReadAll(io.LimitReader(r, sdpLimit)) + if len(v) == sdpLimit { + err = errors.New("SDP too large") + } + return v, err +} + +func Endpoint(g *group.Group, w http.ResponseWriter, r *http.Request) { + if PublicServer { + w.Header().Set("Access-Control-Allow-Origin", "*") + } + + if r.Method == "OPTIONS" { + w.Header().Set("Access-Control-Allow-Methods", "GET, HEAD, POST") + w.Header().Set("Access-Control-Allow-Headers", + "Authorization, Content-Type", + ) + return + } + + if r.Method != "POST" { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + ctype := r.Header.Get("content-type") + if !strings.EqualFold(ctype, "application/sdp") { + http.Error(w, "bad content type", http.StatusBadRequest) + return + } + + body, err := readLimited(r.Body) + if err != nil { + httpError(w, err) + return + } + + id := newId() + c := &Client{ + group: g, + id: id, + } + + username, password, _ := r.BasicAuth() + c.username = username + creds := group.ClientCredentials{ + Username: username, + Password: password, + } + + _, err = group.AddClient(g.Name(), c, creds) + if err == group.ErrNotAuthorised || + err == group.ErrAnonymousNotAuthorised { + w.Header().Set("www-authenticate", + fmt.Sprintf("basic realm=\"%v\"", + path.Join("/whip/", g.Name()))) + http.Error(w, "Authentication failed", http.StatusUnauthorized) + return + } else if err != nil { + log.Printf("WHIP: %v", err) + http.Error(w, + "Internal Server Error", http.StatusInternalServerError, + ) + return + } + + if !c.Permissions().Present { + http.Error(w, "Not authorised", http.StatusUnauthorized) + return + } + + conn, err := rtpconn.NewUpConn(c, id, "", string(body)) + if err != nil { + group.DelClient(c) + httpError(w, err) + return + } + + conn.PC.OnICEConnectionStateChange(func(state webrtc.ICEConnectionState) { + switch state { + case webrtc.ICEConnectionStateFailed, + webrtc.ICEConnectionStateClosed: + c.close() + } + }) + + c.mu.Lock() + c.connection = conn + c.mu.Unlock() + + sdp, err := gotOffer(conn, body) + if err != nil { + group.DelClient(c) + httpError(w, err) + return + } + + w.Header().Set("Location", path.Join("/whip/", path.Join(g.Name(), id))) + w.Header().Set("Access-Control-Expose-Headers", "Location, Content-Type") + w.Header().Set("Content-Type", "application/sdp") + w.WriteHeader(http.StatusCreated) + w.Write(sdp) + return +} + +func Handler(w http.ResponseWriter, r *http.Request) { + p := path.Dir(r.URL.Path) + id := path.Base(r.URL.Path) + + if p[:6] != "/whip/" { + httpError(w, errors.New("bad URL")) + return + } + name := p[6:] + + g := group.Get(name) + if g == nil { + http.Error(w, "404 not found", http.StatusNotFound) + return + } + + cc := g.GetClient(id) + if cc == nil { + http.Error(w, "404 not found", http.StatusNotFound) + return + } + + c, ok := cc.(*Client) + if !ok { + httpError(w, errors.New("unexpected type for client")) + return + } + + if PublicServer { + w.Header().Set("Access-Control-Allow-Origin", "*") + } + + if r.Method == "OPTIONS" { + w.Header().Set("Access-Control-Allow-Methods", + "GET, HEAD, PATCH, DELETE", + ) + w.Header().Set("Access-Control-Allow-Headers", + "Authorization, Content-Type", + ) + return + } + + username, password, _ := r.BasicAuth() + if username != c.username { + http.Error(w, "Client changed username", http.StatusUnauthorized) + return + } + creds := group.ClientCredentials{ + Username: username, + Password: password, + } + perms, err := g.Description().GetPermission(name, creds) + if err == group.ErrNotAuthorised || + err == group.ErrAnonymousNotAuthorised { + w.Header().Set("www-authenticate", + fmt.Sprintf("basic realm=\"%v\"", + path.Join("/whip/", g.Name()))) + http.Error(w, "Authentication failed", http.StatusUnauthorized) + return + } else if err != nil { + log.Printf("WHIP: %v", err) + http.Error(w, + "Internal Server Error", http.StatusInternalServerError, + ) + return + } + if !perms.Present { + http.Error(w, "Not authorised", http.StatusUnauthorized) + return + } + + if r.Method == "DELETE" { + c.close() + return + } + + if r.Method != "PATCH" { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + ctype := r.Header.Get("content-type") + if !strings.EqualFold(ctype, "application/trickle-ice-sdpfrag") { + http.Error(w, "bad content type", http.StatusBadRequest) + return + } + + conn := c.conn() + if conn == nil { + http.Error(w, "connection closed", http.StatusNotFound) + return + } + + body, err := readLimited(r.Body) + if err != nil { + httpError(w, err) + return + } + + if len(body) < 2 { + http.Error(w, "SDP truncated", http.StatusBadRequest) + return + } + + if string(body[:2]) == "v=" { + answer, err := gotOffer(conn, body) + if err != nil { + httpError(w, err) + return + } + + w.Header().Set("Content-Type", "application/sdp") + w.Write(answer) + return + } + + // RFC 8840 + lines := bytes.Split(body, []byte{'\n'}) + var ufrag []byte + for _, l := range lines { + l = bytes.TrimRight(l, " \r") + if bytes.HasPrefix(l, []byte("a=ice-ufrag:")) { + ufrag = l[len("a=ice-ufrag:"):] + } else if bytes.HasPrefix(l, []byte("a=candidate:")) { + err := gotCandidate(conn, l[2:], ufrag) + if err != nil { + log.Printf("WHIP candidate: %v", err) + } + } else if bytes.Equal(l, []byte("a=end-of-candidates")) { + gotCandidate(conn, nil, ufrag) + } + } + w.WriteHeader(http.StatusNoContent) + return +} + +func gotOffer(conn *rtpconn.UpConn, offer []byte) ([]byte, error) { + err := conn.PC.SetRemoteDescription(webrtc.SessionDescription{ + Type: webrtc.SDPTypeOffer, + SDP: string(offer), + }) + if err != nil { + return nil, err + } + + answer, err := conn.PC.CreateAnswer(nil) + if err != nil { + return nil, err + } + + gatherComplete := webrtc.GatheringCompletePromise(conn.PC) + + err = conn.PC.SetLocalDescription(answer) + if err != nil { + return nil, err + } + + <-gatherComplete + + return []byte(conn.PC.CurrentLocalDescription().SDP), nil +} + +func gotCandidate(conn *rtpconn.UpConn, candidate, ufrag []byte) error { + zero := uint16(0) + init := webrtc.ICECandidateInit{ + Candidate: string(candidate), + SDPMLineIndex: &zero, + } + if ufrag != nil { + u := string(ufrag) + init.UsernameFragment = &u + } + + err := conn.PC.AddICECandidate(init) + return err +}