From 1e050fa4e3e0b1082a42f3b3b14e6a540c7b8693 Mon Sep 17 00:00:00 2001 From: Juliusz Chroboczek Date: Mon, 10 Jul 2023 16:24:30 +0200 Subject: [PATCH] Implement the WHIP protocol. --- rtpconn/whipclient.go | 213 ++++++++++++++++++++++ webserver/webserver.go | 10 ++ webserver/webserver_test.go | 78 ++++++++ webserver/whip.go | 343 ++++++++++++++++++++++++++++++++++++ 4 files changed, 644 insertions(+) create mode 100644 rtpconn/whipclient.go create mode 100644 webserver/whip.go diff --git a/rtpconn/whipclient.go b/rtpconn/whipclient.go new file mode 100644 index 0000000..36e2404 --- /dev/null +++ b/rtpconn/whipclient.go @@ -0,0 +1,213 @@ +package rtpconn + +import ( + "context" + "errors" + "sync" + + "github.com/jech/galene/conn" + "github.com/jech/galene/group" + "github.com/pion/webrtc/v3" +) + +type WhipClient struct { + group *group.Group + id string + token string + username string + + mu sync.Mutex + permissions []string + connection *rtpUpConnection +} + +func NewWhipClient(g *group.Group, id string, token string) *WhipClient { + return &WhipClient{group: g, id: id, token: token} +} + +func (c *WhipClient) Group() *group.Group { + return c.group +} + +func (c *WhipClient) Id() string { + return c.id +} + +func (c *WhipClient) Token() string { + return c.token +} + +func (c *WhipClient) Username() string { + return c.username +} + +func (c *WhipClient) SetUsername(username string) { + c.username = username +} + +func (c *WhipClient) Permissions() []string { + c.mu.Lock() + defer c.mu.Unlock() + return c.permissions +} + +func (c *WhipClient) SetPermissions(perms []string) { + c.mu.Lock() + defer c.mu.Unlock() + c.permissions = perms +} + +func (c *WhipClient) Data() map[string]interface{} { + return nil +} + +func (c *WhipClient) PushConn(g *group.Group, id string, conn conn.Up, tracks []conn.UpTrack, replace string) error { + return nil +} + +func (c *WhipClient) RequestConns(target group.Client, g *group.Group, id string) error { + if g != c.group { + return nil + } + + c.mu.Lock() + up := c.connection + c.mu.Unlock() + if up == nil { + return nil + } + tracks := up.getTracks() + ts := make([]conn.UpTrack, len(tracks)) + for i, t := range tracks { + ts[i] = t + } + target.PushConn(g, up.Id(), up, ts, "") + return nil +} + +func (c *WhipClient) Joined(group, kind string) error { + return nil +} + +func (c *WhipClient) PushClient(group, kind, id, username string, permissions []string, status map[string]interface{}) error { + return nil +} + +func (c *WhipClient) Kick(id string, user *string, message string) error { + return c.Close() +} + +func (c *WhipClient) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + g := c.group + if g == nil { + return nil + } + if c.connection != nil { + id := c.connection.Id() + c.connection.pc.OnICEConnectionStateChange(nil) + c.connection.pc.Close() + c.connection = nil + for _, c := range g.GetClients(c) { + c.PushConn(g, id, nil, nil, "") + } + c.connection = nil + } + group.DelClient(c) + c.group = nil + return nil +} + +func (c *WhipClient) NewConnection(ctx context.Context, offer []byte) ([]byte, error) { + conn, err := newUpConn(c, c.id, "", string(offer)) + if err != nil { + return nil, err + } + + conn.pc.OnICEConnectionStateChange( + func(state webrtc.ICEConnectionState) { + switch state { + case webrtc.ICEConnectionStateFailed, + webrtc.ICEConnectionStateClosed: + c.Close() + } + }) + + c.mu.Lock() + defer c.mu.Unlock() + if c.connection != nil { + conn.pc.OnICEConnectionStateChange(nil) + conn.pc.Close() + return nil, errors.New("duplicate connection") + } + c.connection = conn + + answer, err := c.gotOffer(ctx, offer) + if err != nil { + conn.pc.OnICEConnectionStateChange(nil) + conn.pc.Close() + return nil, err + } + + return answer, nil +} + +func (c *WhipClient) GotOffer(ctx context.Context, offer []byte) ([]byte, error) { + c.mu.Lock() + defer c.mu.Unlock() + return c.gotOffer(ctx, offer) +} + +// called locked +func (c *WhipClient) gotOffer(ctx context.Context, offer []byte) ([]byte, error) { + conn := c.connection + err := conn.pc.SetRemoteDescription(webrtc.SessionDescription{ + Type: webrtc.SDPTypeOffer, + SDP: string(offer), + }) + if err != nil { + return nil, err + } + + answer, err := conn.pc.CreateAnswer(nil) + if err != nil { + return nil, err + } + + gatherComplete := webrtc.GatheringCompletePromise(conn.pc) + + err = conn.pc.SetLocalDescription(answer) + if err != nil { + return nil, err + } + + conn.flushICECandidates() + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-gatherComplete: + } + + return []byte(conn.pc.CurrentLocalDescription().SDP), nil +} + +func (c *WhipClient) GotICECandidate(candidate, ufrag []byte) error { + zero := uint16(0) + init := webrtc.ICECandidateInit{ + Candidate: string(candidate), + SDPMLineIndex: &zero, + } + if ufrag != nil { + u := string(ufrag) + init.UsernameFragment = &u + } + + c.mu.Lock() + defer c.mu.Unlock() + if c.connection == nil { + return nil + } + return c.connection.addICECandidate(&init) +} diff --git a/webserver/webserver.go b/webserver/webserver.go index 01bab76..ec9518b 100644 --- a/webserver/webserver.go +++ b/webserver/webserver.go @@ -304,6 +304,16 @@ func groupHandler(w http.ResponseWriter, r *http.Request) { return } + dir, id := parseWhip(r.URL.Path) + if dir != "" { + if id == "" { + whipEndpointHandler(w, r) + } else { + whipResourceHandler(w, r) + } + return + } + name := parseGroupName("/group/", r.URL.Path) if name == "" { notFound(w) diff --git a/webserver/webserver_test.go b/webserver/webserver_test.go index f7d36c2..13a263e 100644 --- a/webserver/webserver_test.go +++ b/webserver/webserver_test.go @@ -2,6 +2,8 @@ package webserver import ( "testing" + + "github.com/pion/webrtc/v3" ) func TestParseGroupName(t *testing.T) { @@ -29,3 +31,79 @@ func TestParseGroupName(t *testing.T) { }) } } + +func TestParseWhip(t *testing.T) { + a := []struct{ p, d, b string }{ + {"", "", ""}, + {"/", "", ""}, + {"/foo", "", ""}, + {"/foo/", "", ""}, + {"/foo/bar", "", ""}, + {"/foo/bar/", "", ""}, + {"/foo/bar/baz", "", ""}, + {"/foo/bar/baz/", "", ""}, + {"/foo/.whip", "/foo/", ""}, + {"/foo/.whip/", "/foo/", ""}, + {"/foo/.whip/bar", "/foo/", "bar"}, + {"/foo/.whip/bar/", "/foo/", "bar"}, + {"/foo/.whip/bar/baz", "", ""}, + {"/foo/.whip/bar/baz/", "", ""}, + } + + for _, pdb := range a { + t.Run(pdb.p, func(t *testing.T) { + d, b := parseWhip(pdb.p) + if d != pdb.d || b != pdb.b { + t.Errorf("Path %v, got %v %v, expected %v %v", + pdb.p, d, b, pdb.d, pdb.b) + } + }) + } +} + +func TestFormatICEServer(t *testing.T) { + a := []struct { + s webrtc.ICEServer + v string + }{ + { + webrtc.ICEServer{ + URLs: []string{"stun:stun.example.org:3478"}, + }, "; rel=\"ice-server\"", + }, + { + webrtc.ICEServer{ + URLs: []string{"turn:turn.example.org:3478"}, + Username: "toto", + Credential: "titi", + CredentialType: webrtc.ICECredentialTypePassword, + }, "; rel=\"ice-server\"; " + + "username=\"toto\"; credential=\"titi\"; " + + "credential-type=\"password\"", + }, + { + webrtc.ICEServer{ + URLs: []string{"turns:turn.example.org:5349"}, + Username: "toto", + Credential: "titi", + CredentialType: webrtc.ICECredentialTypePassword, + }, "; rel=\"ice-server\"; " + + "username=\"toto\"; credential=\"titi\"; " + + "credential-type=\"password\"", + }, + { + webrtc.ICEServer{ + URLs: []string{"https://stun.example.org"}, + }, "", + }, + } + + for _, sv := range a { + t.Run(sv.s.URLs[0], func(t *testing.T) { + v := formatICEServer(sv.s, sv.s.URLs[0]) + if v != sv.v { + t.Errorf("Got %v, expected %v", v, sv.v) + } + }) + } +} diff --git a/webserver/whip.go b/webserver/whip.go new file mode 100644 index 0000000..aaac79d --- /dev/null +++ b/webserver/whip.go @@ -0,0 +1,343 @@ +package webserver + +import ( + "bytes" + crand "crypto/rand" + "encoding/base64" + "errors" + "fmt" + "io" + "log" + "net/http" + "net/url" + "os" + "path" + "strings" + + "github.com/pion/webrtc/v3" + + "github.com/jech/galene/group" + "github.com/jech/galene/ice" + "github.com/jech/galene/rtpconn" +) + +func parseWhip(pth string) (string, string) { + if pth != "/" { + pth = strings.TrimSuffix(pth, "/") + } + dir := path.Dir(pth) + base := path.Base(pth) + if base == ".whip" { + return dir + "/", "" + } + + if path.Base(dir) == ".whip" { + return path.Dir(dir) + "/", base + } + + return "", "" +} + +func newId() string { + b := make([]byte, 16) + crand.Read(b) + return base64.RawURLEncoding.EncodeToString(b) +} + +const sdpLimit = 1024 * 1024 + +func readLimited(r io.Reader) ([]byte, error) { + v, err := io.ReadAll(io.LimitReader(r, sdpLimit)) + if len(v) == sdpLimit { + err = errors.New("SDP too large") + } + return v, err +} + +func canPresent(perms []string) bool { + for _, p := range perms { + if p == "present" { + return true + } + } + return false +} + +func getBearerToken(r *http.Request) string { + auth := r.Header.Get("Authorization") + auths := strings.Split(auth, ",") + for _, a := range auths { + a = strings.Trim(a, " \t") + s := strings.Split(a, " ") + if len(s) == 2 && strings.EqualFold(s[0], "bearer") { + return s[1] + } + } + return "" +} + +var iceServerReplacer = strings.NewReplacer(`\`, `\\`, `"`, `\"`) + +func formatICEServer(server webrtc.ICEServer, u string) string { + quote := func(s string) string { + return iceServerReplacer.Replace(s) + } + uu, err := url.Parse(u) + if err != nil { + return "" + } + + if strings.EqualFold(uu.Scheme, "stun") { + return fmt.Sprintf("<%v>; rel=\"ice-server\"", u) + } else if strings.EqualFold(uu.Scheme, "turn") || + strings.EqualFold(uu.Scheme, "turns") { + pw, ok := server.Credential.(string) + if !ok { + return "" + } + return fmt.Sprintf("<%v>; rel=\"ice-server\"; "+ + "username=\"%v\"; "+ + "credential=\"%v\"; "+ + "credential-type=\"%v\"", + u, + quote(server.Username), + quote(pw), + quote(server.CredentialType.String())) + } + return "" +} + +func whipICEServers(w http.ResponseWriter) { + conf := ice.ICEConfiguration() + for _, server := range conf.ICEServers { + for _, u := range server.URLs { + v := formatICEServer(server, u) + if v != "" { + w.Header().Add("Link", v) + } + } + } +} + +func whipEndpointHandler(w http.ResponseWriter, r *http.Request) { + if redirect(w, r) { + return + } + + pth, pthid := parseWhip(r.URL.Path) + if pthid != "" { + http.Error(w, "Internal server error", + http.StatusInternalServerError) + return + } + + name := parseGroupName("/group/", pth) + if name == "" { + notFound(w) + return + } + + g, err := group.Add(name, nil) + if err != nil { + if os.IsNotExist(err) { + notFound(w) + return + } + log.Printf("group.Add: %v", err) + http.Error(w, "Internal server error", + http.StatusInternalServerError) + return + } + + conf, err := group.GetConfiguration() + if err != nil { + http.Error(w, "Internal server error", + http.StatusInternalServerError) + return + } + + if conf.PublicServer { + w.Header().Set("Access-Control-Allow-Origin", "*") + } + + if r.Method == "OPTIONS" { + w.Header().Set("Access-Control-Allow-Methods", "OPTIONS, POST") + w.Header().Set("Access-Control-Allow-Headers", + "Authorization, Content-Type", + ) + w.Header().Set("Access-Control-Expose-Headers", "Link") + whipICEServers(w) + return + } + + if r.Method != "POST" { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + ctype := r.Header.Get("content-type") + if !strings.EqualFold(ctype, "application/sdp") { + http.Error(w, "bad content type", http.StatusBadRequest) + return + } + + body, err := readLimited(r.Body) + if err != nil { + httpError(w, err) + return + } + + token := getBearerToken(r) + whip := "whip" + creds := group.ClientCredentials{ + Username: &whip, + Token: token, + } + + id := newId() + c := rtpconn.NewWhipClient(g, id, token) + + _, err = group.AddClient(g.Name(), c, creds) + if err == group.ErrNotAuthorised || + err == group.ErrAnonymousNotAuthorised { + http.Error(w, "Authentication failed", http.StatusUnauthorized) + return + } else if err != nil { + log.Printf("WHIP: %v", err) + http.Error(w, "Internal Server Error", + http.StatusInternalServerError) + return + } + + if !canPresent(c.Permissions()) { + group.DelClient(c) + http.Error(w, "Forbidden", http.StatusForbidden) + return + } + + answer, err := c.NewConnection(r.Context(), body) + if err != nil { + group.DelClient(c) + log.Printf("WHIP offer: %v", err) + http.Error(w, "Internal Server Error", + http.StatusInternalServerError) + } + + w.Header().Set("Location", path.Join(r.URL.Path, id)) + w.Header().Set("Access-Control-Expose-Headers", + "Location, Content-Type, Link") + whipICEServers(w) + w.Header().Set("Content-Type", "application/sdp") + w.WriteHeader(http.StatusCreated) + w.Write(answer) + + return +} + +func whipResourceHandler(w http.ResponseWriter, r *http.Request) { + pth, id := parseWhip(r.URL.Path) + if pth == "" || id == "" { + http.Error(w, "Internal server error", + http.StatusInternalServerError) + return + } + + name := parseGroupName("/group/", pth) + if name == "" { + notFound(w) + return + } + + g := group.Get(name) + if g == nil { + notFound(w) + return + } + + cc := g.GetClient(id) + if cc == nil { + notFound(w) + return + } + + c, ok := cc.(*rtpconn.WhipClient) + if !ok { + notFound(w) + return + } + + if t := c.Token(); t != "" { + token := getBearerToken(r) + if token != t { + http.Error(w, "Forbidden", http.StatusForbidden) + return + } + } + + conf, err := group.GetConfiguration() + if err != nil { + http.Error(w, "Internal server error", + http.StatusInternalServerError) + return + } + + if conf.PublicServer { + w.Header().Set("Access-Control-Allow-Origin", "*") + } + + if r.Method == "OPTIONS" { + w.Header().Set("Access-Control-Allow-Methods", + "OPTIONS, PATCH, DELETE", + ) + w.Header().Set("Access-Control-Allow-Headers", + "Authorization, Content-Type", + ) + return + } + + if r.Method == "DELETE" { + c.Close() + return + } + + if r.Method != "PATCH" { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + + } + + ctype := r.Header.Get("content-type") + if !strings.EqualFold(ctype, "application/trickle-ice-sdpfrag") { + http.Error(w, "bad content type", http.StatusBadRequest) + return + } + + body, err := readLimited(r.Body) + if err != nil { + httpError(w, err) + return + } + + if len(body) < 2 { + http.Error(w, "SDP truncated", http.StatusBadRequest) + return + } + + // RFC 8840 + lines := bytes.Split(body, []byte{'\n'}) + var ufrag []byte + for _, l := range lines { + l = bytes.TrimRight(l, " \r") + if bytes.HasPrefix(l, []byte("a=ice-ufrag:")) { + ufrag = l[len("a=ice-ufrag:"):] + } else if bytes.HasPrefix(l, []byte("a=candidate:")) { + err := c.GotICECandidate(l[2:], ufrag) + if err != nil { + log.Printf("WHIP candidate: %v", err) + } + } else if bytes.Equal(l, []byte("a=end-of-candidates")) { + c.GotICECandidate(nil, ufrag) + } + } + w.WriteHeader(http.StatusNoContent) + return +}