diff --git a/galene.go b/galene.go index bb9c8bf..a884c72 100644 --- a/galene.go +++ b/galene.go @@ -16,6 +16,7 @@ import ( "github.com/jech/galene/group" "github.com/jech/galene/ice" "github.com/jech/galene/limit" + "github.com/jech/galene/token" "github.com/jech/galene/turnserver" "github.com/jech/galene/webserver" ) @@ -111,6 +112,12 @@ func main() { } ice.ICEFilename = filepath.Join(group.DataDirectory, "ice-servers.json") + token.SetStatefulFilename( + filepath.Join( + filepath.Join(group.DataDirectory, "var"), + "tokens.jsonl", + ), + ) // make sure the list of public groups is updated early go group.Update() diff --git a/group/group.go b/group/group.go index 3eb43ae..7d6e8d2 100644 --- a/group/group.go +++ b/group/group.go @@ -27,6 +27,7 @@ var UDPMin, UDPMax uint16 var ErrNotAuthorised = errors.New("not authorised") var ErrAnonymousNotAuthorised = errors.New("anonymous users not authorised in this group") +var ErrDuplicateUsername = errors.New("this username is taken") type UserError string @@ -1136,6 +1137,33 @@ func (g *Group) getPasswordPermission(creds ClientCredentials) ([]string, error) return nil, ErrNotAuthorised } +// Return true if there is a user entry with the given username. +// Always return false for an empty username. +func (g *Group) UserExists(username string) bool { + g.mu.Lock() + defer g.mu.Unlock() + return g.userExists(username) +} + +// called locked +func (g *Group) userExists(username string) bool { + if username == "" { + return false + } + + desc := g.description + for _, ps := range [][]ClientPattern{ + desc.Op, desc.Presenter, desc.Other, + } { + for _, p := range ps { + if p.Username == username { + return true + } + } + } + return false +} + // called locked func (g *Group) getPermission(creds ClientCredentials) (string, []string, error) { desc := g.description @@ -1157,6 +1185,12 @@ func (g *Group) getPermission(creds ClientCredentials) (string, []string, error) if err != nil { return "", nil, err } + if username == "" && creds.Username != nil { + if g.userExists(*creds.Username) { + return "", nil, ErrDuplicateUsername + } + username = *creds.Username + } } else if creds.Username != nil { username = *creds.Username var err error diff --git a/group/group_test.go b/group/group_test.go index b6e5f07..7be200b 100644 --- a/group/group_test.go +++ b/group/group_test.go @@ -179,6 +179,27 @@ func TestPermissions(t *testing.T) { } +func TestUsernameTaken(t *testing.T) { + var g Group + err := json.Unmarshal([]byte(descJSON), &g.description) + if err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if g.UserExists("") { + t.Error("UserExists(\"\") is true, expected false") + } + if !g.UserExists("john") { + t.Error("UserExists(john) is false") + } + if !g.UserExists("john") { + t.Error("UserExists(james) is false") + } + if g.UserExists("paul") { + t.Error("UserExists(paul) is true") + } +} + func TestFmtpValue(t *testing.T) { type fmtpTest struct { fmtp string diff --git a/token/stateful.go b/token/stateful.go new file mode 100644 index 0000000..dfa3636 --- /dev/null +++ b/token/stateful.go @@ -0,0 +1,357 @@ +package token + +import ( + "encoding/json" + "errors" + "io" + "os" + "path/filepath" + "sort" + "sync" + "time" +) + +// A stateful token +type Stateful struct { + Token string `json:"token"` + Group string `json:"group"` + Username *string `json:"username,omitempty"` + Permissions []string `json:"permissions"` + Expires *time.Time `json:"expires"` + NotBefore *time.Time `json:"not-before,omitempty"` +} + +func (token *Stateful) Clone() *Stateful { + return &Stateful{ + Token: token.Token, + Group: token.Group, + Username: token.Username, + Permissions: append([]string(nil), token.Permissions...), + Expires: token.Expires, + NotBefore: token.NotBefore, + } +} + +// A set of stateful tokens, kept in sync with a JSONL representation in +// a file. The synchronisation is slightly racy, so both reading and +// modifying tokens are protected by a mutex. +type state struct { + filename string + mu sync.Mutex + fileSize int64 + modTime time.Time + tokens map[string]*Stateful +} + +var tokens state + +func SetStatefulFilename(filename string) { + tokens.mu.Lock() + defer tokens.mu.Unlock() + tokens.filename = filename + tokens.fileSize = 0 + tokens.modTime = time.Time{} +} + +func getStateful(token string) (Token, error) { + tokens.mu.Lock() + defer tokens.mu.Unlock() + err := tokens.load() + if err != nil { + return nil, err + } + if tokens.tokens == nil { + return nil, nil + } + return tokens.tokens[token], nil +} + +func (token *Stateful) Check(host, group string, username *string) (string, []string, error) { + if token.Group == "" || group != token.Group { + return "", nil, errors.New("token for bad group") + } + now := time.Now() + if token.Expires == nil || now.After(*token.Expires) { + return "", nil, errors.New("token has expired") + } + if token.NotBefore != nil && now.Before(*token.NotBefore) { + return "", nil, errors.New("token is in the future") + } + + // the username from the token overrides the one from the client. + user := "" + if token.Username != nil { + user = *token.Username + } else if username == nil { + return "", nil, ErrUsernameRequired + } + + return user, token.Permissions, nil +} + +// called locked +func (state *state) load() error { + if state.filename == "" { + state.modTime = time.Time{} + state.tokens = nil + return nil + } + + fi, err := os.Stat(state.filename) + if err != nil { + state.modTime = time.Time{} + state.fileSize = 0 + state.tokens = nil + if os.IsNotExist(err) { + return nil + } + return err + } + + if state.modTime.Equal(fi.ModTime()) && + state.fileSize == fi.Size() { + return nil + } + + f, err := os.Open(state.filename) + if err != nil { + state.modTime = time.Time{} + state.fileSize = 0 + state.tokens = nil + if os.IsNotExist(err) { + return nil + } + return err + } + + defer f.Close() + + ts := make(map[string]*Stateful) + decoder := json.NewDecoder(f) + for { + var t Stateful + err := decoder.Decode(&t) + if err == io.EOF { + break + } else if err != nil { + state.modTime = time.Time{} + state.fileSize = 0 + return err + } + ts[t.Token] = &t + } + state.tokens = ts + fi, err = f.Stat() + if err != nil { + state.modTime = time.Time{} + state.fileSize = 0 + state.tokens = nil + if os.IsNotExist(err) { + return nil + } + return err + } + state.modTime = fi.ModTime() + state.fileSize = fi.Size() + return nil +} + +func (state *state) Add(token *Stateful) (*Stateful, error) { + tokens.mu.Lock() + defer tokens.mu.Unlock() + + if state.filename == "" { + return nil, os.ErrNotExist + } + + err := state.load() + if err != nil { + return nil, err + } + + if state.tokens != nil { + if _, ok := state.tokens[token.Token]; ok { + return nil, os.ErrExist + } + } + + err = os.MkdirAll(filepath.Dir(state.filename), 0700) + if err != nil { + return nil, err + } + f, err := os.OpenFile(state.filename, + os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600, + ) + if err != nil { + return nil, err + } + defer f.Close() + + encoder := json.NewEncoder(f) + err = encoder.Encode(token) + if err != nil { + return nil, err + } + + if state.tokens == nil { + state.tokens = make(map[string]*Stateful) + } + state.tokens[token.Token] = token.Clone() + + fi, err := f.Stat() + if err != nil { + state.modTime = fi.ModTime() + state.fileSize = fi.Size() + } + return token, nil +} + +func Add(token *Stateful) (*Stateful, error) { + return tokens.Add(token) +} + +func (state *state) Del(group, token string) error { + tokens.mu.Lock() + defer tokens.mu.Unlock() + _, err := state.edit(group, token, nil) + return err +} + +func Del(group, token string) error { + return tokens.Del(group, token) +} + +func (state *state) Edit(group, token string, expires time.Time) (*Stateful, error) { + tokens.mu.Lock() + defer tokens.mu.Unlock() + return state.edit(group, token, &expires) +} + +func Edit(group, token string, expires time.Time) (*Stateful, error) { + return tokens.Edit(group, token, expires) +} + +// called locked +func (state *state) edit(group, token string, expires *time.Time) (*Stateful, error) { + if state.tokens == nil { + return nil, os.ErrNotExist + } + + old := state.tokens[token] + if old == nil { + return nil, os.ErrNotExist + } + if old.Group != group { + return nil, os.ErrPermission + } + var new *Stateful + if expires.Equal(time.Time{}) { + delete(state.tokens, token) + } else { + new = old.Clone() + new.Expires = expires + state.tokens[token] = new + } + err := state.rewrite() + if err != nil { + state.tokens[token] = old + return nil, err + } + return new, err +} + +// called locked +func (state *state) rewrite() error { + if state.tokens == nil || len(state.tokens) == 0 { + err := os.Remove(state.filename) + if err == nil || os.IsNotExist(err) { + return nil + } + return err + } + + dir := filepath.Dir(state.filename) + tmpfile, err := os.CreateTemp(dir, "tokens") + if err != nil { + return err + } + a, err := state.list("") + if err != nil { + os.Remove(tmpfile.Name()) + return err + } + encoder := json.NewEncoder(tmpfile) + for _, t := range a { + err := encoder.Encode(t) + if err != nil { + tmpfile.Close() + os.Remove(tmpfile.Name()) + return err + } + } + + err = tmpfile.Close() + if err != nil { + os.Remove(tmpfile.Name()) + return err + } + + err = os.Rename(tmpfile.Name(), state.filename) + if err != nil { + os.Remove(tmpfile.Name()) + return err + } + + fi, err := os.Stat(state.filename) + if err == nil { + state.modTime = fi.ModTime() + state.fileSize = fi.Size() + } else { + // force rereading next time + state.modTime = time.Time{} + state.fileSize = 0 + } + + return nil +} + +// called locked +func (state *state) list(group string) ([]*Stateful, error) { + err := state.load() + if err != nil { + return nil, err + } + + a := make([]*Stateful, 0) + if state.tokens == nil { + return a, nil + } + for _, t := range state.tokens { + if group != "" { + if t.Group != group { + continue + } + } + a = append(a, t) + } + sort.Slice(a, func(i, j int) bool { + if a[j].Expires == nil { + return false + } + if a[i].Expires == nil { + return true + } + return (*a[i].Expires).Before(*a[j].Expires) + }) + return a, nil +} + +func (state *state) List(group string) ([]*Stateful, error) { + state.mu.Lock() + defer state.mu.Unlock() + return state.list(group) +} + +func List(group string) ([]*Stateful, error) { + return tokens.List(group) +} diff --git a/token/stateful_test.go b/token/stateful_test.go new file mode 100644 index 0000000..4ba8212 --- /dev/null +++ b/token/stateful_test.go @@ -0,0 +1,294 @@ +package token + +import ( + "encoding/json" + "io" + "os" + "path/filepath" + "reflect" + "sort" + "testing" + "time" +) + +func equal(a, b *Stateful) bool { + if a.Token != b.Token || a.Group != b.Group || + !reflect.DeepEqual(a.Username, b.Username) || + !reflect.DeepEqual(a.Permissions, b.Permissions) { + return false + } + if a.Expires != nil && b.Expires != nil { + return (*a.Expires).Equal(*b.Expires) + } + if (a.Expires != nil) != (b.Expires != nil) { + return false + } + + if a.NotBefore != nil && b.NotBefore != nil { + return (*a.NotBefore).Equal(*b.NotBefore) + } + return (a.NotBefore != nil) == (b.NotBefore != nil) +} + +func TestStatefulCheck(t *testing.T) { + now := time.Now() + past := now.Add(-time.Hour) + nearFuture := now.Add(time.Hour / 2) + future := now.Add(time.Hour) + user := "user" + user2 := "user2" + token1 := &Stateful{ + Token: "token", + Group: "group", + Username: &user, + Permissions: []string{"present"}, + Expires: &future, + } + token2 := &Stateful{ + Token: "token", + Group: "group", + Permissions: []string{"present"}, + Expires: &future, + } + token3 := &Stateful{ + Token: "token", + Group: "group", + Username: &user, + Permissions: []string{"present"}, + Expires: &past, + } + token4 := &Stateful{ + Token: "token", + Group: "group", + Username: &user, + Permissions: []string{"present"}, + Expires: &future, + NotBefore: &nearFuture, + } + + success := []struct { + token *Stateful + group string + username *string + expUsername string + expPermissions []string + }{ + { + token: token1, + group: "group", + username: &user, + expUsername: user, + expPermissions: []string{"present"}, + }, + { + token: token1, + group: "group", + username: &user2, + expUsername: user, + expPermissions: []string{"present"}, + }, + { + token: token1, + group: "group", + expUsername: user, + expPermissions: []string{"present"}, + }, + { + token: token2, + group: "group", + username: &user, + expUsername: "", + expPermissions: []string{"present"}, + }, + } + + for i, s := range success { + u, p, err := s.token.Check("", s.group, s.username) + if err != nil || u != s.expUsername || + !reflect.DeepEqual(p, s.expPermissions) { + t.Errorf("Check %v failed: %v %v %v -> %v %v %v", + i, s.token, s.group, s.username, + u, p, err) + } + } + + failure := []struct { + token *Stateful + group string + username *string + }{ + { + token: token1, + group: "group2", + username: &user, + }, + { + token: token3, + group: "group", + username: &user, + }, + { + token: token4, + group: "group", + username: &user, + }, + } + + for i, s := range failure { + u, p, err := s.token.Check("", s.group, s.username) + if err == nil { + t.Errorf("Check %v succeded: %v %v %v -> %v %v %v", + i, s.token, s.group, s.username, + u, p, err) + } + } +} + +func readTokenFile(filename string) []*Stateful { + f, err := os.Open(filename) + if err != nil { + panic(err) + } + defer f.Close() + + a := make([]*Stateful, 0) + decoder := json.NewDecoder(f) + for { + var t Stateful + err := decoder.Decode(&t) + if err == io.EOF { + break + } else if err != nil { + panic(err) + } + a = append(a, &t) + } + return a +} + +func expectTokenArray(t *testing.T, a, b []*Stateful) { + if len(a) != len(b) { + t.Errorf("Bad length: %v != %v", len(a), len(b)) + } + aa := append([]*Stateful(nil), a...) + sort.Slice(aa, func(i, j int) bool { + return aa[i].Token < aa[j].Token + }) + bb := append([]*Stateful(nil), b...) + sort.Slice(bb, func(i, j int) bool { + return bb[i].Token < bb[j].Token + }) + + if len(aa) != len(bb) { + t.Errorf("Not equal: %v != %v", len(aa), len(bb)) + } + + for i, ta := range aa { + tb := bb[i] + if !equal(ta, tb) { + t.Errorf("Not equal: %v != %v", ta, tb) + } + } +} + +func expectTokens(t *testing.T, tokens map[string]*Stateful, value []*Stateful) { + a := make([]*Stateful, 0, len(tokens)) + for tok, token := range tokens { + if tok != token.Token { + t.Errorf("Inconsistent token: %v != %v", + tok, token.Token) + } + a = append(a, token) + } + expectTokenArray(t, a, value) +} + +func expectTokenFile(t *testing.T, filename string, value []*Stateful) { + a := readTokenFile(filename) + expectTokenArray(t, a, value) +} + +func TestTokenStorage(t *testing.T) { + d := t.TempDir() + s := state{ + filename: filepath.Join(d, "test.jsonl"), + } + now := time.Now() + past := now.Add(-time.Hour) + nearFuture := now.Add(time.Hour / 2) + future := now.Add(time.Hour) + user1 := "user1" + user2 := "user2" + user3 := "user3" + tokens := []*Stateful{ + &Stateful{ + Token: "tok1", + Group: "test", + Username: &user1, + Permissions: []string{"present"}, + Expires: &future, + }, + &Stateful{ + Token: "tok2", + Group: "test", + Username: &user2, + Permissions: []string{"present", "record"}, + Expires: &nearFuture, + NotBefore: &past, + }, + &Stateful{ + Token: "tok3", + Group: "test", + Username: &user3, + Permissions: []string{"present"}, + Expires: &nearFuture, + }, + } + for i, token := range tokens { + new, err := s.Add(token) + if err != nil { + t.Errorf("Add: %v", err) + } + if !equal(new, token) { + t.Errorf("Add: got %v, expected %v", new, token) + } + expectTokens(t, s.tokens, tokens[:i+1]) + expectTokenFile(t, s.filename, tokens[:i+1]) + } + + s.modTime = time.Time{} + err := s.load() + if err != nil { + t.Errorf("Load: %v", err) + } + expectTokens(t, s.tokens, tokens) + + _, err = s.Edit("test2", tokens[1].Token, now.Add(time.Hour)) + if err == nil { + t.Errorf("Edit succeeded with wrong group") + } + new, err := s.Edit("test", tokens[1].Token, now.Add(time.Hour)) + if err != nil { + t.Errorf("Edit: %v", err) + } + tokens[1].Expires = &future + if !equal(new, tokens[1]) { + t.Errorf("Edit: got %v, expected %v", tokens[1], new) + } + expectTokens(t, s.tokens, tokens) + expectTokenFile(t, s.filename, tokens) + + for t := range s.tokens { + delete(s.tokens, t) + } + + err = s.rewrite() + if err != nil { + t.Errorf("rewrite(empty): %v", err) + } + + _, err = os.Stat(s.filename) + if !os.IsNotExist(err) { + t.Errorf("existence check: %v", err) + } +} diff --git a/token/token.go b/token/token.go index 512df1b..65e7871 100644 --- a/token/token.go +++ b/token/token.go @@ -1,9 +1,19 @@ package token +import ( + "errors" +) + +var ErrUsernameRequired = errors.New("username required") + type Token interface { Check(host, group string, username *string) (string, []string, error) } func Parse(token string, keys []map[string]interface{}) (Token, error) { + t, err := getStateful(token) + if err == nil && t != nil { + return t, nil + } return parseJWT(token, keys) }