Add test for routes authenticateMedia
This commit is contained in:
parent
8539d48944
commit
8d2654997d
|
@ -38,7 +38,7 @@ func Middleware(db *gorm.DB) func(http.Handler) http.Handler {
|
|||
}
|
||||
|
||||
// put it in context
|
||||
ctx := context.WithValue(r.Context(), userCtxKey, user)
|
||||
ctx := AddUserToContext(r.Context(), user)
|
||||
|
||||
// and call the next with our new context
|
||||
r = r.WithContext(ctx)
|
||||
|
@ -51,6 +51,10 @@ func Middleware(db *gorm.DB) func(http.Handler) http.Handler {
|
|||
}
|
||||
}
|
||||
|
||||
func AddUserToContext(ctx context.Context, user *models.User) context.Context {
|
||||
return context.WithValue(ctx, userCtxKey, user)
|
||||
}
|
||||
|
||||
func TokenFromBearer(bearer *string) (*string, error) {
|
||||
regex, _ := regexp.Compile("^(?i)Bearer ([a-zA-Z0-9]{24})$")
|
||||
matches := regex.FindStringSubmatch(*bearer)
|
||||
|
|
|
@ -0,0 +1,115 @@
|
|||
package actions
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/photoview/photoview/api/graphql/auth"
|
||||
"github.com/photoview/photoview/api/graphql/models"
|
||||
"github.com/photoview/photoview/api/utils"
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func AddMediaShare(db *gorm.DB, userID int, mediaID int, expire *time.Time, password *string) (*models.ShareToken, error) {
|
||||
var media models.Media
|
||||
|
||||
var query string
|
||||
if db.Dialector.Name() == "postgres" {
|
||||
query = "EXISTS (SELECT * FROM user_albums WHERE user_albums.album_id = \"Album\".id AND user_albums.user_id = ?)"
|
||||
} else {
|
||||
query = "EXISTS (SELECT * FROM user_albums WHERE user_albums.album_id = Album.id AND user_albums.user_id = ?)"
|
||||
}
|
||||
|
||||
err := db.Joins("Album").
|
||||
Where(query, userID).
|
||||
First(&media, mediaID).
|
||||
Error
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, auth.ErrUnauthorized
|
||||
} else {
|
||||
return nil, errors.Wrap(err, "failed to validate media owner with database")
|
||||
}
|
||||
}
|
||||
|
||||
hashedPassword, err := hashSharePassword(password)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
shareToken := models.ShareToken{
|
||||
Value: utils.GenerateToken(),
|
||||
OwnerID: userID,
|
||||
Expire: expire,
|
||||
Password: hashedPassword,
|
||||
AlbumID: nil,
|
||||
MediaID: &mediaID,
|
||||
}
|
||||
|
||||
if err := db.Create(&shareToken).Error; err != nil {
|
||||
return nil, errors.Wrap(err, "failed to insert new share token into database")
|
||||
}
|
||||
|
||||
return &shareToken, nil
|
||||
}
|
||||
|
||||
func DeleteShareToken(db *gorm.DB, userID int, tokenValue string) (*models.ShareToken, error) {
|
||||
token, err := getUserToken(db, userID, tokenValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := db.Delete(&token).Error; err != nil {
|
||||
return nil, errors.Wrapf(err, "failed to delete share token (%s) from database", tokenValue)
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func ProtectShareToken(db *gorm.DB, userID int, tokenValue string, password *string) (*models.ShareToken, error) {
|
||||
token, err := getUserToken(db, userID, tokenValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
hashedPassword, err := hashSharePassword(password)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
token.Password = hashedPassword
|
||||
|
||||
if err := db.Save(&token).Error; err != nil {
|
||||
return nil, errors.Wrap(err, "failed to update password for share token")
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func hashSharePassword(password *string) (*string, error) {
|
||||
var hashedPassword *string = nil
|
||||
if password != nil {
|
||||
hashedPassBytes, err := bcrypt.GenerateFromPassword([]byte(*password), 12)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to generate hash for share password")
|
||||
}
|
||||
hashedStr := string(hashedPassBytes)
|
||||
hashedPassword = &hashedStr
|
||||
}
|
||||
|
||||
return hashedPassword, nil
|
||||
}
|
||||
|
||||
func getUserToken(db *gorm.DB, userID int, tokenValue string) (*models.ShareToken, error) {
|
||||
|
||||
var token models.ShareToken
|
||||
err := db.Where("share_tokens.value = ?", tokenValue).Joins("Owner").Where("Owner.id = ? OR Owner.admin = TRUE", userID).First(&token).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get user share token from database")
|
||||
}
|
||||
|
||||
return &token, nil
|
||||
}
|
|
@ -11,6 +11,7 @@ import (
|
|||
api "github.com/photoview/photoview/api/graphql"
|
||||
"github.com/photoview/photoview/api/graphql/auth"
|
||||
"github.com/photoview/photoview/api/graphql/models"
|
||||
"github.com/photoview/photoview/api/graphql/models/actions"
|
||||
"github.com/photoview/photoview/api/utils"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
@ -145,47 +146,7 @@ func (r *mutationResolver) ShareMedia(ctx context.Context, mediaID int, expire *
|
|||
return nil, auth.ErrUnauthorized
|
||||
}
|
||||
|
||||
var media models.Media
|
||||
|
||||
var query string
|
||||
if r.Database.Dialector.Name() == "postgres" {
|
||||
query = "EXISTS (SELECT * FROM user_albums WHERE user_albums.album_id = \"Album\".id AND user_albums.user_id = ?)"
|
||||
} else {
|
||||
query = "EXISTS (SELECT * FROM user_albums WHERE user_albums.album_id = Album.id AND user_albums.user_id = ?)"
|
||||
}
|
||||
|
||||
err := r.Database.Joins("Album").
|
||||
Where(query, user.ID).
|
||||
First(&media, mediaID).
|
||||
Error
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, auth.ErrUnauthorized
|
||||
} else {
|
||||
return nil, errors.Wrap(err, "failed to validate media owner with database")
|
||||
}
|
||||
}
|
||||
|
||||
hashedPassword, err := hashSharePassword(password)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
shareToken := models.ShareToken{
|
||||
Value: utils.GenerateToken(),
|
||||
OwnerID: user.ID,
|
||||
Expire: expire,
|
||||
Password: hashedPassword,
|
||||
AlbumID: nil,
|
||||
MediaID: &mediaID,
|
||||
}
|
||||
|
||||
if err := r.Database.Create(&shareToken).Error; err != nil {
|
||||
return nil, errors.Wrap(err, "failed to insert new share token into database")
|
||||
}
|
||||
|
||||
return &shareToken, nil
|
||||
return actions.AddMediaShare(r.Database, user.ID, mediaID, expire, password)
|
||||
}
|
||||
|
||||
func (r *mutationResolver) DeleteShareToken(ctx context.Context, tokenValue string) (*models.ShareToken, error) {
|
||||
|
@ -194,16 +155,7 @@ func (r *mutationResolver) DeleteShareToken(ctx context.Context, tokenValue stri
|
|||
return nil, auth.ErrUnauthorized
|
||||
}
|
||||
|
||||
token, err := getUserToken(r.Database, user, tokenValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := r.Database.Delete(&token).Error; err != nil {
|
||||
return nil, errors.Wrapf(err, "failed to delete share token (%s) from database", tokenValue)
|
||||
}
|
||||
|
||||
return token, nil
|
||||
return actions.DeleteShareToken(r.Database, user.ID, tokenValue)
|
||||
}
|
||||
|
||||
func (r *mutationResolver) ProtectShareToken(ctx context.Context, tokenValue string, password *string) (*models.ShareToken, error) {
|
||||
|
@ -212,47 +164,5 @@ func (r *mutationResolver) ProtectShareToken(ctx context.Context, tokenValue str
|
|||
return nil, auth.ErrUnauthorized
|
||||
}
|
||||
|
||||
token, err := getUserToken(r.Database, user, tokenValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
hashedPassword, err := hashSharePassword(password)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
token.Password = hashedPassword
|
||||
|
||||
if err := r.Database.Save(&token).Error; err != nil {
|
||||
return nil, errors.Wrap(err, "failed to update password for share token")
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func hashSharePassword(password *string) (*string, error) {
|
||||
var hashedPassword *string = nil
|
||||
if password != nil {
|
||||
hashedPassBytes, err := bcrypt.GenerateFromPassword([]byte(*password), 12)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to generate hash for share password")
|
||||
}
|
||||
hashedStr := string(hashedPassBytes)
|
||||
hashedPassword = &hashedStr
|
||||
}
|
||||
|
||||
return hashedPassword, nil
|
||||
}
|
||||
|
||||
func getUserToken(db *gorm.DB, user *models.User, tokenValue string) (*models.ShareToken, error) {
|
||||
|
||||
var token models.ShareToken
|
||||
err := db.Where("share_tokens.value = ?", tokenValue).Joins("Owner").Where("Owner.id = ? OR Owner.admin = TRUE", user.ID).First(&token).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get user share token from database")
|
||||
}
|
||||
|
||||
return &token, nil
|
||||
return actions.ProtectShareToken(r.Database, user.ID, tokenValue, password)
|
||||
}
|
||||
|
|
|
@ -0,0 +1,104 @@
|
|||
package routes
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/photoview/photoview/api/graphql/auth"
|
||||
"github.com/photoview/photoview/api/graphql/models"
|
||||
"github.com/photoview/photoview/api/graphql/models/actions"
|
||||
"github.com/photoview/photoview/api/test_utils"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestAuthenticateMedia(t *testing.T) {
|
||||
db := test_utils.DatabaseTest(t)
|
||||
|
||||
user, err := models.RegisterUser(db, "username", nil, false)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
|
||||
album := models.Album{
|
||||
Title: "my_album",
|
||||
}
|
||||
|
||||
if !assert.NoError(t, db.Model(&user).Association("Albums").Append(&album)) {
|
||||
return
|
||||
}
|
||||
|
||||
media := models.Media{
|
||||
Title: "my_media",
|
||||
Path: "/photos/image.jpg",
|
||||
AlbumID: album.ID,
|
||||
}
|
||||
|
||||
if !assert.NoError(t, db.Save(&media).Error) {
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("Authorized request", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/photo/image.jpg", strings.NewReader("IMAGE DATA"))
|
||||
ctx := auth.AddUserToContext(req.Context(), user)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
success, responseMessage, responseStatus, err := authenticateMedia(&media, db, req)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, success)
|
||||
assert.Equal(t, responseMessage, "success")
|
||||
assert.Equal(t, responseStatus, http.StatusAccepted)
|
||||
})
|
||||
|
||||
t.Run("Request without access token", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/photo/image.jpg", strings.NewReader("IMAGE DATA"))
|
||||
|
||||
success, responseMessage, responseStatus, err := authenticateMedia(&media, db, req)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, success)
|
||||
assert.Equal(t, responseMessage, "unauthorized")
|
||||
assert.Equal(t, responseStatus, http.StatusForbidden)
|
||||
})
|
||||
|
||||
t.Run("Request without access token", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/photo/image.jpg", strings.NewReader("IMAGE DATA"))
|
||||
|
||||
success, responseMessage, responseStatus, err := authenticateMedia(&media, db, req)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, success)
|
||||
assert.Equal(t, responseMessage, "unauthorized")
|
||||
assert.Equal(t, responseStatus, http.StatusForbidden)
|
||||
})
|
||||
|
||||
expire := time.Now().Add(time.Hour * 24 * 30)
|
||||
tokenPassword := "token-password-123"
|
||||
shareToken, err := actions.AddMediaShare(db, user.ID, media.ID, &expire, &tokenPassword)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("Request with share token", func(t *testing.T) {
|
||||
url := fmt.Sprintf("/photo/image.jpg?token=%s", shareToken.Value)
|
||||
req := httptest.NewRequest("GET", url, strings.NewReader("IMAGE DATA"))
|
||||
|
||||
cookie := http.Cookie{
|
||||
Name: fmt.Sprintf("share-token-pw-%s", shareToken.Value),
|
||||
Value: tokenPassword,
|
||||
}
|
||||
req.AddCookie(&cookie)
|
||||
|
||||
success, responseMessage, responseStatus, err := authenticateMedia(&media, db, req)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, success)
|
||||
assert.Equal(t, responseMessage, "success")
|
||||
assert.Equal(t, responseStatus, http.StatusAccepted)
|
||||
})
|
||||
|
||||
}
|
|
@ -0,0 +1,12 @@
|
|||
package routes_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/photoview/photoview/api/test_utils"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
os.Exit(test_utils.IntegrationTestRun(m))
|
||||
}
|
Loading…
Reference in New Issue