diff --git a/group/group.go b/group/group.go index 9b4e992..540f72e 100644 --- a/group/group.go +++ b/group/group.go @@ -26,9 +26,25 @@ var Directory, DataDirectory string var UseMDNS bool var UDPMin, UDPMax uint16 -var ErrNotAuthorised = errors.New("not authorised") -var ErrAnonymousNotAuthorised = errors.New("anonymous users not authorised in this group") -var ErrDuplicateUsername = errors.New("this username is taken") +type NotAuthorisedError struct { + err error +} +func (err *NotAuthorisedError) Error() string { + if err.err != nil { + return "not authorised: " + err.err.Error() + } + return "not authorised" +} +func (err *NotAuthorisedError) Unwrap() error { + return err.err +} + +var ErrAnonymousNotAuthorised = &NotAuthorisedError{ + err: errors.New("anonymous users not authorised in this group"), +} +var ErrDuplicateUsername = &NotAuthorisedError{ + errors.New("this username is taken"), +} type UserError string @@ -935,7 +951,7 @@ func (g *Group) getPasswordPermission(creds ClientCredentials) ([]string, error) } return p, nil } - return nil, ErrNotAuthorised + return nil, &NotAuthorisedError{} } if found, good := matchClient(creds, desc.Presenter); found { if good { @@ -945,7 +961,7 @@ func (g *Group) getPasswordPermission(creds ClientCredentials) ([]string, error) } return p, nil } - return nil, ErrNotAuthorised + return nil, &NotAuthorisedError{} } if found, good := matchClient(creds, desc.Other); found { if good { @@ -955,9 +971,10 @@ func (g *Group) getPasswordPermission(creds ClientCredentials) ([]string, error) } return p, nil } - return nil, ErrNotAuthorised + return nil, &NotAuthorisedError{} + } - return nil, ErrNotAuthorised + return nil, &NotAuthorisedError{} } // Return true if there is a user entry with the given username. @@ -1006,7 +1023,7 @@ func (g *Group) getPermission(creds ClientCredentials) (string, []string, error) username, perms, err = tok.Check(conf.CanonicalHost, g.name, creds.Username) if err != nil { - return "", nil, err + return "", nil, &NotAuthorisedError{err: err} } if username == "" && creds.Username != nil { if g.userExists(*creds.Username) { diff --git a/group/group_test.go b/group/group_test.go index ae94f8b..5b77c0c 100644 --- a/group/group_test.go +++ b/group/group_test.go @@ -2,6 +2,7 @@ package group import ( "encoding/json" + "errors" "fmt" "reflect" "testing" @@ -157,8 +158,9 @@ func TestPermissions(t *testing.T) { for _, c := range badClients { t.Run("bad "+*c.Username, func(t *testing.T) { + var autherr *NotAuthorisedError _, p, err := g.GetPermission(c) - if err != ErrNotAuthorised { + if !errors.As(err, &autherr) { t.Errorf("GetPermission %v: %v %v", c, err, p) } }) diff --git a/rtpconn/webclient.go b/rtpconn/webclient.go index f3e1cfc..3c4e61a 100644 --- a/rtpconn/webclient.go +++ b/rtpconn/webclient.go @@ -1063,7 +1063,7 @@ func pushDownConn(c *webClient, id string, up conn.Up, tracks []conn.UpTrack, re down, _, err := addDownConn(c, up) if err != nil { - if err == os.ErrClosed { + if errors.Is(err, os.ErrClosed) { return nil } return err @@ -1412,21 +1412,22 @@ func handleClientMessage(c *webClient, m clientMessage) error { ) if err != nil { var e, s string + var autherr *group.NotAuthorisedError if os.IsNotExist(err) { s = "group does not exist" - } else if err == group.ErrNotAuthorised { - s = "not authorised" - time.Sleep(200 * time.Millisecond) - } else if err == group.ErrAnonymousNotAuthorised { + } else if errors.Is(err, group.ErrAnonymousNotAuthorised) { s = "please choose a username" - } else if _, ok := err.(group.UserError); ok { - s = err.Error() - } else if err == token.ErrUsernameRequired { + } else if errors.Is(err, token.ErrUsernameRequired) { s = err.Error() e = "need-username" - } else if err == group.ErrDuplicateUsername { + } else if errors.Is(err, group.ErrDuplicateUsername) { s = err.Error() e = "duplicate-username" + } else if errors.As(err, &autherr) { + s = "not authorised" + time.Sleep(200 * time.Millisecond) + } else if _, ok := err.(group.UserError); ok { + s = err.Error() } else { s = "internal server error" log.Printf("Join group: %v", err) diff --git a/webserver/webserver.go b/webserver/webserver.go index 1ce65a8..55c82f0 100644 --- a/webserver/webserver.go +++ b/webserver/webserver.go @@ -132,8 +132,10 @@ func httpError(w http.ResponseWriter, err error) { notFound(w) return } - if os.IsPermission(err) { - http.Error(w, "Forbidden", http.StatusForbidden) + var autherr *group.NotAuthorisedError + if errors.As(err, &autherr) { + log.Printf("HTTP server error: %v", err) + http.Error(w, "not authorised", http.StatusUnauthorized) return } var mberr *http.MaxBytesError @@ -230,7 +232,8 @@ func (fh *fileHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { if err != nil { // return 403 if index.html doesn't exist if os.IsNotExist(err) { - err = os.ErrPermission + http.Error(w, "Forbidden", http.StatusForbidden) + return } httpError(w, err) return @@ -706,7 +709,8 @@ func checkGroupPermissions(w http.ResponseWriter, r *http.Request, groupname str } } if err != nil || !record { - if err == group.ErrNotAuthorised { + var autherr *group.NotAuthorisedError + if errors.As(err, &autherr) { time.Sleep(200 * time.Millisecond) } return false diff --git a/webserver/whip.go b/webserver/whip.go index 324d7d6..15f69dc 100644 --- a/webserver/whip.go +++ b/webserver/whip.go @@ -219,11 +219,7 @@ func whipEndpointHandler(w http.ResponseWriter, r *http.Request) { c := rtpconn.NewWhipClient(g, id, token) _, err = group.AddClient(g.Name(), c, creds) - if err == group.ErrNotAuthorised || - err == group.ErrAnonymousNotAuthorised { - http.Error(w, "Authentication failed", http.StatusUnauthorized) - return - } else if err != nil { + if err != nil { log.Printf("WHIP: %v", err) httpError(w, err) return