97 lines
2.5 KiB
Go
97 lines
2.5 KiB
Go
package auth
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"errors"
|
|
"log"
|
|
"net/http"
|
|
"regexp"
|
|
|
|
"github.com/99designs/gqlgen/handler"
|
|
"github.com/viktorstrate/photoview/api/graphql/models"
|
|
)
|
|
|
|
// ErrUnauthorized is this package's exported sentinel error with the message
// "unauthorized". Compare against it with errors.Is.
var (
	ErrUnauthorized = errors.New("unauthorized")
)
|
|
|
|
// contextKey is an unexported key type for values this package stores in a
// context.Context. Because no other package can construct a value of this
// type, keys of it cannot collide with context values set elsewhere.
type contextKey struct {
	name string
}

// userCtxKey is the context key under which the authenticated user is stored.
var userCtxKey = &contextKey{name: "user"}
|
|
|
|
// Middleware decodes the share session cookie and packs the session into context
|
|
func Middleware(db *sql.DB) func(http.Handler) http.Handler {
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
|
|
if tokenCookie, err := r.Cookie("auth-token"); err == nil {
|
|
log.Println("Found auth-token cookie")
|
|
user, err := models.VerifyTokenAndGetUser(db, tokenCookie.Value)
|
|
if err != nil {
|
|
log.Printf("Invalid token: %s\n", err)
|
|
http.Error(w, "invalid authorization token", http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
// put it in context
|
|
ctx := context.WithValue(r.Context(), userCtxKey, user)
|
|
|
|
// and call the next with our new context
|
|
r = r.WithContext(ctx)
|
|
} else {
|
|
log.Println("Did not find auth-token cookie")
|
|
}
|
|
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
}
|
|
|
|
// bearerTokenPattern matches a value of the form "Bearer <token>", where the
// token is exactly 24 ASCII letters or digits. It is compiled once at package
// initialisation instead of on every call; the pattern is a constant, so
// MustCompile cannot fail at runtime.
var bearerTokenPattern = regexp.MustCompile(`^Bearer ([a-zA-Z0-9]{24})$`)

// TokenFromBearer extracts the token from a "Bearer <token>" string.
//
// The token must consist of exactly 24 ASCII letters or digits and nothing
// else may surround it. On success it returns a pointer to the token. A nil
// bearer, or one that does not match that format, yields a nil token and an
// "invalid bearer format" error.
func TokenFromBearer(bearer *string) (*string, error) {
	// Guard against a nil pointer instead of panicking on the dereference.
	if bearer == nil {
		return nil, errors.New("invalid bearer format")
	}

	matches := bearerTokenPattern.FindStringSubmatch(*bearer)
	if len(matches) != 2 {
		return nil, errors.New("invalid bearer format")
	}

	token := matches[1]
	return &token, nil
}
|
|
|
|
// UserFromContext finds the user from the context. REQUIRES Middleware to have run.
|
|
func UserFromContext(ctx context.Context) *models.User {
|
|
raw, _ := ctx.Value(userCtxKey).(*models.User)
|
|
return raw
|
|
}
|
|
|
|
func AuthWebsocketInit(db *sql.DB) func(context.Context, handler.InitPayload) (context.Context, error) {
|
|
return func(ctx context.Context, initPayload handler.InitPayload) (context.Context, error) {
|
|
|
|
bearer, exists := initPayload["Authorization"].(string)
|
|
if !exists {
|
|
return ctx, nil
|
|
}
|
|
|
|
token, err := TokenFromBearer(&bearer)
|
|
if err != nil {
|
|
log.Printf("Invalid bearer format (websocket): %s\n", bearer)
|
|
return nil, err
|
|
}
|
|
|
|
user, err := models.VerifyTokenAndGetUser(db, *token)
|
|
if err != nil {
|
|
log.Printf("Invalid token in websocket: %s\n", err)
|
|
return nil, errors.New("invalid authorization token")
|
|
}
|
|
|
|
// put it in context
|
|
userCtx := context.WithValue(ctx, userCtxKey, user)
|
|
|
|
// and return it so the resolvers can see it
|
|
return userCtx, nil
|
|
}
|
|
}
|