1
Fork 0
photoview/api/graphql/resolvers/faces.go

196 lines
5.1 KiB
Go
Raw Normal View History

2021-02-16 12:01:10 +01:00
package resolvers
import (
"context"
"github.com/photoview/photoview/api/graphql/auth"
"github.com/photoview/photoview/api/graphql/models"
2021-02-19 19:24:31 +01:00
"github.com/photoview/photoview/api/scanner/face_detection"
2021-02-19 17:49:41 +01:00
"github.com/pkg/errors"
"gorm.io/gorm"
2021-02-16 12:01:10 +01:00
)
func (r *queryResolver) MyFaceGroups(ctx context.Context, paginate *models.Pagination) ([]*models.FaceGroup, error) {
user := auth.UserFromContext(ctx)
if user == nil {
return nil, errors.New("unauthorized")
}
if err := user.FillAlbums(r.Database); err != nil {
return nil, err
}
userAlbumIDs := make([]int, len(user.Albums))
for i, album := range user.Albums {
userAlbumIDs[i] = album.ID
}
2021-02-16 17:13:08 +01:00
imageFaceQuery := r.Database.
Joins("Media").
Where("media.album_id IN (?)", userAlbumIDs)
2021-02-16 12:01:10 +01:00
var imageFaces []*models.ImageFace
2021-02-16 17:13:08 +01:00
if err := imageFaceQuery.Find(&imageFaces).Error; err != nil {
2021-02-16 12:01:10 +01:00
return nil, err
}
faceGroupMap := make(map[int][]models.ImageFace)
for _, face := range imageFaces {
2021-02-16 12:41:34 +01:00
_, found := faceGroupMap[face.FaceGroupID]
2021-02-16 12:01:10 +01:00
if found {
2021-02-16 12:41:34 +01:00
faceGroupMap[face.FaceGroupID] = append(faceGroupMap[face.FaceGroupID], *face)
2021-02-16 12:01:10 +01:00
} else {
faceGroupMap[face.FaceGroupID] = make([]models.ImageFace, 1)
faceGroupMap[face.FaceGroupID][0] = *face
}
}
faceGroupIDs := make([]int, len(faceGroupMap))
i := 0
for groupID := range faceGroupMap {
faceGroupIDs[i] = groupID
i++
}
2021-02-16 17:13:08 +01:00
faceGroupQuery := r.Database.
2021-02-17 13:50:32 +01:00
Joins("LEFT JOIN image_faces ON image_faces.id = face_groups.id").
Where("face_groups.id IN (?)", faceGroupIDs).
2021-02-16 17:13:08 +01:00
Order("CASE WHEN label IS NULL THEN 1 ELSE 0 END")
2021-02-16 12:01:10 +01:00
var faceGroups []*models.FaceGroup
2021-02-16 17:13:08 +01:00
if err := faceGroupQuery.Find(&faceGroups).Error; err != nil {
2021-02-16 12:01:10 +01:00
return nil, err
}
for _, faceGroup := range faceGroups {
faceGroup.ImageFaces = faceGroupMap[faceGroup.ID]
}
return faceGroups, nil
}
2021-02-17 13:50:32 +01:00
func (r *mutationResolver) SetFaceGroupLabel(ctx context.Context, faceGroupID int, label *string) (*models.FaceGroup, error) {
user := auth.UserFromContext(ctx)
if user == nil {
return nil, errors.New("unauthorized")
}
2021-02-19 17:49:41 +01:00
faceGroup, err := userOwnedFaceGroup(r.Database, user, faceGroupID)
if err != nil {
return nil, err
}
if err := r.Database.Model(faceGroup).Update("label", label).Error; err != nil {
return nil, err
}
return faceGroup, nil
}
func (r *mutationResolver) CombineFaceGroups(ctx context.Context, destinationFaceGroupID int, sourceFaceGroupID int) (*models.FaceGroup, error) {
user := auth.UserFromContext(ctx)
if user == nil {
return nil, errors.New("unauthorized")
}
destinationFaceGroup, err := userOwnedFaceGroup(r.Database, user, destinationFaceGroupID)
if err != nil {
return nil, err
}
sourceFaceGroup, err := userOwnedFaceGroup(r.Database, user, sourceFaceGroupID)
if err != nil {
return nil, err
}
updateError := r.Database.Transaction(func(tx *gorm.DB) error {
if err := tx.Model(&models.ImageFace{}).Where("face_group_id = ?", sourceFaceGroup.ID).Update("face_group_id", destinationFaceGroup.ID).Error; err != nil {
return err
}
if err := tx.Delete(&sourceFaceGroup).Error; err != nil {
return err
}
return nil
})
if updateError != nil {
return nil, updateError
}
2021-02-19 19:24:31 +01:00
face_detection.GlobalFaceDetector.MergeCategories(int32(sourceFaceGroupID), int32(destinationFaceGroupID))
2021-02-19 17:49:41 +01:00
return destinationFaceGroup, nil
}
// MoveImageFace is not implemented yet. It always returns a "not implemented"
// error, so callers receive a regular error through the resolver's error
// result instead of a panic.
func (r *mutationResolver) MoveImageFace(ctx context.Context, imageFaceID int, newFaceGroupID int) (*models.ImageFace, error) {
	return nil, errors.New("not implemented")
}
func (r *mutationResolver) RecognizeUnlabeledFaces(ctx context.Context) ([]*models.ImageFace, error) {
2021-02-19 19:24:31 +01:00
user := auth.UserFromContext(ctx)
if user == nil {
return nil, errors.New("unauthorized")
}
var updatedImageFaces []*models.ImageFace
transactionError := r.Database.Transaction(func(tx *gorm.DB) error {
var err error
updatedImageFaces, err = face_detection.GlobalFaceDetector.RecognizeUnlabeledFaces(tx, user)
return err
})
if transactionError != nil {
return nil, transactionError
}
return updatedImageFaces, nil
2021-02-19 17:49:41 +01:00
}
func userOwnedFaceGroup(db *gorm.DB, user *models.User, faceGroupID int) (*models.FaceGroup, error) {
if user.Admin {
var faceGroup models.FaceGroup
if err := db.Where("id = ?", faceGroupID).Find(&faceGroup).Error; err != nil {
return nil, err
}
return &faceGroup, nil
}
if err := user.FillAlbums(db); err != nil {
2021-02-17 13:50:32 +01:00
return nil, err
}
userAlbumIDs := make([]int, len(user.Albums))
for i, album := range user.Albums {
userAlbumIDs[i] = album.ID
}
// Verify that user owns at leat one of the images in the face group
2021-02-19 17:49:41 +01:00
imageFaceQuery := db.
2021-02-17 13:50:32 +01:00
Select("image_faces.id").
Table("image_faces").
Joins("LEFT JOIN media ON media.id = image_faces.media_id").
Where("media.album_id IN (?)", userAlbumIDs)
2021-02-19 17:49:41 +01:00
faceGroupQuery := db.
2021-02-17 13:50:32 +01:00
Model(&models.FaceGroup{}).
Joins("JOIN image_faces ON face_groups.id = image_faces.face_group_id").
Where("face_groups.id = ?", faceGroupID).
Where("image_faces.id IN (?)", imageFaceQuery)
var faceGroup models.FaceGroup
if err := faceGroupQuery.Find(&faceGroup).Error; err != nil {
2021-02-19 17:49:41 +01:00
if err == gorm.ErrRecordNotFound {
return nil, errors.Wrap(err, "face group does not exist or is not owned by the user")
}
2021-02-17 13:50:32 +01:00
return nil, err
}
return &faceGroup, nil
}