mirror of
https://github.com/jech/galene.git
synced 2024-11-09 18:25:58 +01:00
Implement accessors for stateful tokens.
This commit is contained in:
parent
fe15057252
commit
2f5c21d161
5 changed files with 153 additions and 99 deletions
|
@ -1753,7 +1753,7 @@ func handleClientMessage(c *webClient, m clientMessage) error {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
tok.IssuedAt = &now
|
tok.IssuedAt = &now
|
||||||
|
|
||||||
new, err := token.Add(tok)
|
new, err := token.Update(tok, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return terror("error", err.Error())
|
return terror("error", err.Error())
|
||||||
}
|
}
|
||||||
|
@ -1783,19 +1783,26 @@ func handleClientMessage(c *webClient, m clientMessage) error {
|
||||||
}
|
}
|
||||||
if tok.Group != "" || tok.Username != nil ||
|
if tok.Group != "" || tok.Username != nil ||
|
||||||
tok.Permissions != nil ||
|
tok.Permissions != nil ||
|
||||||
tok.NotBefore != nil ||
|
|
||||||
tok.IssuedBy != nil ||
|
tok.IssuedBy != nil ||
|
||||||
tok.IssuedAt != nil {
|
tok.IssuedAt != nil {
|
||||||
return terror(
|
return terror(
|
||||||
"error", "this field cannot be edited",
|
"error", "this field cannot be edited",
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
if tok.Expires == nil {
|
|
||||||
return terror("error", "trying to edit nothing")
|
old, etag, err := token.Get(tok.Token)
|
||||||
|
if err != nil {
|
||||||
|
return terror("error", err.Error())
|
||||||
}
|
}
|
||||||
new, err := token.Extend(
|
t := old.Clone()
|
||||||
c.group.Name(), tok.Token, *tok.Expires,
|
if tok.Expires != nil {
|
||||||
)
|
t.Expires = tok.Expires
|
||||||
|
}
|
||||||
|
if tok.NotBefore != nil {
|
||||||
|
t.NotBefore = tok.NotBefore
|
||||||
|
}
|
||||||
|
|
||||||
|
new, err := token.Update(t, etag)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return terror("error", err.Error())
|
return terror("error", err.Error())
|
||||||
}
|
}
|
||||||
|
@ -1819,7 +1826,7 @@ func handleClientMessage(c *webClient, m clientMessage) error {
|
||||||
!member("token", c.permissions) {
|
!member("token", c.permissions) {
|
||||||
return terror("not-authorised", "not authorised")
|
return terror("not-authorised", "not authorised")
|
||||||
}
|
}
|
||||||
tokens, err := token.List(c.group.Name())
|
tokens, _, err := token.List(c.group.Name())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return terror("error", err.Error())
|
return terror("error", err.Error())
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,6 +3,7 @@ package token
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
@ -11,6 +12,8 @@ import (
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var ErrTagMismatch = errors.New("tag mismatch")
|
||||||
|
|
||||||
// A stateful token
|
// A stateful token
|
||||||
type Stateful struct {
|
type Stateful struct {
|
||||||
Token string `json:"token"`
|
Token string `json:"token"`
|
||||||
|
@ -57,26 +60,26 @@ func SetStatefulFilename(filename string) {
|
||||||
tokens.modTime = time.Time{}
|
tokens.modTime = time.Time{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (state *state) Get(token string) (*Stateful, error) {
|
func (state *state) Get(token string) (*Stateful, string, error) {
|
||||||
state.mu.Lock()
|
state.mu.Lock()
|
||||||
defer state.mu.Unlock()
|
defer state.mu.Unlock()
|
||||||
err := state.load()
|
etag, err := state.load()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
if state.tokens == nil {
|
if state.tokens == nil {
|
||||||
return nil, os.ErrNotExist
|
return nil, "", os.ErrNotExist
|
||||||
}
|
}
|
||||||
t := state.tokens[token]
|
t := state.tokens[token]
|
||||||
if t == nil {
|
if t == nil {
|
||||||
return nil, os.ErrNotExist
|
return nil, "", os.ErrNotExist
|
||||||
}
|
}
|
||||||
return t, nil
|
return t, etag, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get fetches a stateful token.
|
// Get fetches a stateful token.
|
||||||
// It returns os.ErrNotExist if the token doesn't exist.
|
// It returns os.ErrNotExist if the token doesn't exist.
|
||||||
func Get(token string) (*Stateful, error) {
|
func Get(token string) (*Stateful, string, error) {
|
||||||
return tokens.Get(token)
|
return tokens.Get(token)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -103,12 +106,13 @@ func (token *Stateful) Check(host, group string, username *string) (string, []st
|
||||||
return user, token.Permissions, nil
|
return user, token.Permissions, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// load updates the state from the corresponding file.
|
||||||
// called locked
|
// called locked
|
||||||
func (state *state) load() error {
|
func (state *state) load() (string, error) {
|
||||||
if state.filename == "" {
|
if state.filename == "" {
|
||||||
state.modTime = time.Time{}
|
state.modTime = time.Time{}
|
||||||
state.tokens = nil
|
state.tokens = nil
|
||||||
return nil
|
return state.etag(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
fi, err := os.Stat(state.filename)
|
fi, err := os.Stat(state.filename)
|
||||||
|
@ -117,14 +121,13 @@ func (state *state) load() error {
|
||||||
state.fileSize = 0
|
state.fileSize = 0
|
||||||
state.tokens = nil
|
state.tokens = nil
|
||||||
if errors.Is(err, os.ErrNotExist) {
|
if errors.Is(err, os.ErrNotExist) {
|
||||||
return nil
|
return "", nil
|
||||||
}
|
}
|
||||||
return err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
if state.modTime.Equal(fi.ModTime()) &&
|
if state.modTime.Equal(fi.ModTime()) && state.fileSize == fi.Size() {
|
||||||
state.fileSize == fi.Size() {
|
return state.etag(), nil
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
f, err := os.Open(state.filename)
|
f, err := os.Open(state.filename)
|
||||||
|
@ -133,9 +136,9 @@ func (state *state) load() error {
|
||||||
state.fileSize = 0
|
state.fileSize = 0
|
||||||
state.tokens = nil
|
state.tokens = nil
|
||||||
if errors.Is(err, os.ErrNotExist) {
|
if errors.Is(err, os.ErrNotExist) {
|
||||||
return nil
|
return state.etag(), nil
|
||||||
}
|
}
|
||||||
return err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
defer f.Close()
|
defer f.Close()
|
||||||
|
@ -150,7 +153,7 @@ func (state *state) load() error {
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
state.modTime = time.Time{}
|
state.modTime = time.Time{}
|
||||||
state.fileSize = 0
|
state.fileSize = 0
|
||||||
return err
|
return "", err
|
||||||
}
|
}
|
||||||
ts[t.Token] = &t
|
ts[t.Token] = &t
|
||||||
}
|
}
|
||||||
|
@ -161,35 +164,108 @@ func (state *state) load() error {
|
||||||
state.fileSize = 0
|
state.fileSize = 0
|
||||||
state.tokens = nil
|
state.tokens = nil
|
||||||
if errors.Is(err, os.ErrNotExist) {
|
if errors.Is(err, os.ErrNotExist) {
|
||||||
return nil
|
return state.etag(), nil
|
||||||
}
|
}
|
||||||
return err
|
return "", err
|
||||||
}
|
}
|
||||||
state.modTime = fi.ModTime()
|
state.modTime = fi.ModTime()
|
||||||
state.fileSize = fi.Size()
|
state.fileSize = fi.Size()
|
||||||
return nil
|
return state.etag(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (state *state) Add(token *Stateful) (*Stateful, error) {
|
func (state *state) etag() string {
|
||||||
|
if state.modTime.Equal(time.Time{}) {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("\"%v-%v\"",
|
||||||
|
state.fileSize, state.modTime.UnixNano(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update adds or updates a token.
|
||||||
|
// If etag is the empty string, it is added if it didn't exist. If etag
|
||||||
|
// is not empty, it is added if it matches the state's etag.
|
||||||
|
func (state *state) Update(token *Stateful, etag string) (*Stateful, error) {
|
||||||
tokens.mu.Lock()
|
tokens.mu.Lock()
|
||||||
defer tokens.mu.Unlock()
|
defer tokens.mu.Unlock()
|
||||||
|
|
||||||
if state.filename == "" {
|
if state.filename == "" {
|
||||||
|
if etag != "" {
|
||||||
|
return nil, ErrTagMismatch
|
||||||
|
}
|
||||||
return nil, os.ErrNotExist
|
return nil, os.ErrNotExist
|
||||||
}
|
}
|
||||||
|
|
||||||
err := state.load()
|
_, err := state.load()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if state.tokens != nil {
|
if state.tokens == nil {
|
||||||
if _, ok := state.tokens[token.Token]; ok {
|
state.tokens = make(map[string]*Stateful)
|
||||||
return nil, os.ErrExist
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = os.MkdirAll(filepath.Dir(state.filename), 0700)
|
old, ok := state.tokens[token.Token]
|
||||||
|
if ok {
|
||||||
|
if etag != state.etag() {
|
||||||
|
return nil, ErrTagMismatch
|
||||||
|
}
|
||||||
|
state.tokens[token.Token] = token
|
||||||
|
err = state.rewrite()
|
||||||
|
if err != nil {
|
||||||
|
state.tokens[token.Token] = old
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if etag != "" {
|
||||||
|
return nil, ErrTagMismatch
|
||||||
|
}
|
||||||
|
return state.add(token)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Delete(token string, etag string) error {
|
||||||
|
return tokens.Delete(token, etag)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (state *state) Delete(token string, etag string) error {
|
||||||
|
tokens.mu.Lock()
|
||||||
|
defer tokens.mu.Unlock()
|
||||||
|
|
||||||
|
if state.filename == "" {
|
||||||
|
return os.ErrNotExist
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := state.load()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if state.tokens == nil {
|
||||||
|
return os.ErrNotExist
|
||||||
|
}
|
||||||
|
|
||||||
|
old, ok := state.tokens[token]
|
||||||
|
if !ok {
|
||||||
|
return os.ErrNotExist
|
||||||
|
}
|
||||||
|
if etag != state.etag() {
|
||||||
|
return ErrTagMismatch
|
||||||
|
}
|
||||||
|
delete(state.tokens, token)
|
||||||
|
err = state.rewrite()
|
||||||
|
if err != nil {
|
||||||
|
state.tokens[token] = old
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// add unconditionally adds a token, which is assumed to not exist.
|
||||||
|
// called locked
|
||||||
|
func (state *state) add(token *Stateful) (*Stateful, error) {
|
||||||
|
err := os.MkdirAll(filepath.Dir(state.filename), 0700)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -220,49 +296,8 @@ func (state *state) Add(token *Stateful) (*Stateful, error) {
|
||||||
return token, nil
|
return token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func Add(token *Stateful) (*Stateful, error) {
|
func Update(token *Stateful, etag string) (*Stateful, error) {
|
||||||
return tokens.Add(token)
|
return tokens.Update(token, etag)
|
||||||
}
|
|
||||||
|
|
||||||
func Extend(group, token string, expires time.Time) (*Stateful, error) {
|
|
||||||
return tokens.Extend(group, token, expires)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (state *state) Extend(group, token string, expires time.Time) (*Stateful, error) {
|
|
||||||
tokens.mu.Lock()
|
|
||||||
defer tokens.mu.Unlock()
|
|
||||||
return state.extend(group, token, expires)
|
|
||||||
}
|
|
||||||
|
|
||||||
// called locked
|
|
||||||
func (state *state) extend(group, token string, expires time.Time) (*Stateful, error) {
|
|
||||||
err := state.load()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if state.tokens == nil {
|
|
||||||
return nil, os.ErrNotExist
|
|
||||||
}
|
|
||||||
|
|
||||||
old := state.tokens[token]
|
|
||||||
if old == nil {
|
|
||||||
return nil, os.ErrNotExist
|
|
||||||
}
|
|
||||||
if old.Group != group {
|
|
||||||
return nil, os.ErrPermission
|
|
||||||
}
|
|
||||||
|
|
||||||
new := old.Clone()
|
|
||||||
new.Expires = &expires
|
|
||||||
state.tokens[token] = new
|
|
||||||
|
|
||||||
err = state.rewrite()
|
|
||||||
if err != nil {
|
|
||||||
state.tokens[token] = old
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return new, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// called locked
|
// called locked
|
||||||
|
@ -280,7 +315,7 @@ func (state *state) rewrite() error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
a, err := state.list("")
|
a, _, err := state.list("")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
os.Remove(tmpfile.Name())
|
os.Remove(tmpfile.Name())
|
||||||
return err
|
return err
|
||||||
|
@ -321,15 +356,15 @@ func (state *state) rewrite() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// called locked
|
// called locked
|
||||||
func (state *state) list(group string) ([]*Stateful, error) {
|
func (state *state) list(group string) ([]*Stateful, string, error) {
|
||||||
err := state.load()
|
_, err := state.load()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
a := make([]*Stateful, 0)
|
a := make([]*Stateful, 0)
|
||||||
if state.tokens == nil {
|
if state.tokens == nil {
|
||||||
return a, nil
|
return a, state.etag(), nil
|
||||||
}
|
}
|
||||||
for _, t := range state.tokens {
|
for _, t := range state.tokens {
|
||||||
if group != "" {
|
if group != "" {
|
||||||
|
@ -348,16 +383,16 @@ func (state *state) list(group string) ([]*Stateful, error) {
|
||||||
}
|
}
|
||||||
return (*a[i].Expires).Before(*a[j].Expires)
|
return (*a[i].Expires).Before(*a[j].Expires)
|
||||||
})
|
})
|
||||||
return a, nil
|
return a, state.etag(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (state *state) List(group string) ([]*Stateful, error) {
|
func (state *state) List(group string) ([]*Stateful, string, error) {
|
||||||
state.mu.Lock()
|
state.mu.Lock()
|
||||||
defer state.mu.Unlock()
|
defer state.mu.Unlock()
|
||||||
return state.list(group)
|
return state.list(group)
|
||||||
}
|
}
|
||||||
|
|
||||||
func List(group string) ([]*Stateful, error) {
|
func List(group string) ([]*Stateful, string, error) {
|
||||||
return tokens.List(group)
|
return tokens.List(group)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -365,7 +400,7 @@ func (state *state) Expire() error {
|
||||||
state.mu.Lock()
|
state.mu.Lock()
|
||||||
defer state.mu.Unlock()
|
defer state.mu.Unlock()
|
||||||
|
|
||||||
err := state.load()
|
_, err := state.load()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -252,7 +252,7 @@ func TestTokenStorage(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for i, token := range tokens {
|
for i, token := range tokens {
|
||||||
new, err := s.Add(token)
|
new, err := s.Update(token, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Add: %v", err)
|
t.Errorf("Add: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -264,19 +264,30 @@ func TestTokenStorage(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
s.modTime = time.Time{}
|
s.modTime = time.Time{}
|
||||||
err := s.load()
|
_, err := s.load()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Load: %v", err)
|
t.Errorf("Load: %v", err)
|
||||||
}
|
}
|
||||||
expectTokens(t, s.tokens, tokens)
|
expectTokens(t, s.tokens, tokens)
|
||||||
|
|
||||||
_, err = s.Extend("test2", tokens[1].Token, now.Add(time.Hour))
|
t1, etag, err := s.Get("tok2")
|
||||||
if err == nil {
|
|
||||||
t.Errorf("Edit succeeded with wrong group")
|
|
||||||
}
|
|
||||||
new, err := s.Extend("test", tokens[1].Token, now.Add(time.Hour))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Edit: %v", err)
|
t.Fatalf("Get: %v", err)
|
||||||
|
}
|
||||||
|
t2 := t1.Clone()
|
||||||
|
soon := now.Add(time.Hour)
|
||||||
|
t2.Expires = &soon
|
||||||
|
_, err = s.Update(t2, "")
|
||||||
|
if !errors.Is(err, ErrTagMismatch) {
|
||||||
|
t.Errorf("Update: got %v, expected ErrTagMismatch", err)
|
||||||
|
}
|
||||||
|
_, err = s.Update(t2, "\"bad\"")
|
||||||
|
if !errors.Is(err, ErrTagMismatch) {
|
||||||
|
t.Errorf("Update: got %v, expected ErrTagMismatch", err)
|
||||||
|
}
|
||||||
|
new, err := s.Update(t2, etag)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Update: %v", err)
|
||||||
}
|
}
|
||||||
tokens[1].Expires = &future
|
tokens[1].Expires = &future
|
||||||
if !equal(new, tokens[1]) {
|
if !equal(new, tokens[1]) {
|
||||||
|
@ -350,7 +361,7 @@ func TestExpire(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, token := range tokens {
|
for _, token := range tokens {
|
||||||
_, err := s.Add(token)
|
_, err := s.Update(token, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Add: %v", err)
|
t.Errorf("Add: %v", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,5 +22,6 @@ func Parse(token string, keys []map[string]interface{}) (Token, error) {
|
||||||
return jwt, nil
|
return jwt, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return Get(token)
|
s, _, err := Get(token)
|
||||||
|
return s, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,13 +23,13 @@ func TestToken(t *testing.T) {
|
||||||
|
|
||||||
future := time.Now().Add(time.Hour)
|
future := time.Now().Add(time.Hour)
|
||||||
user := "user"
|
user := "user"
|
||||||
_, err = Add(&Stateful{
|
_, err = Update(&Stateful{
|
||||||
Token: "token",
|
Token: "token",
|
||||||
Group: "group",
|
Group: "group",
|
||||||
Username: &user,
|
Username: &user,
|
||||||
Permissions: []string{"present"},
|
Permissions: []string{"present"},
|
||||||
Expires: &future,
|
Expires: &future,
|
||||||
})
|
}, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Add: %v", err)
|
t.Fatalf("Add: %v", err)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue