package resolvers

import (
	"context"

	api "github.com/photoview/photoview/api/graphql"
	"github.com/photoview/photoview/api/graphql/auth"
	"github.com/photoview/photoview/api/graphql/models"
	"github.com/photoview/photoview/api/scanner/face_detection"
	"github.com/pkg/errors"
	"gorm.io/gorm"
)

// imageFaceResolver resolves fields of models.ImageFace.
type imageFaceResolver struct {
	*Resolver
}

// faceGroupResolver resolves fields of models.FaceGroup.
type faceGroupResolver struct {
	*Resolver
}

// ImageFace returns the field resolver for ImageFace objects.
func (r *Resolver) ImageFace() api.ImageFaceResolver {
	return imageFaceResolver{r}
}

// FaceGroup returns the field resolver for FaceGroup objects.
func (r *Resolver) FaceGroup() api.FaceGroupResolver {
	return faceGroupResolver{r}
}

// userOwnedAlbumIDs fills user.Albums through user.FillAlbums and returns the
// IDs of those albums. The returned slice is empty (not nil) when the user has
// no albums.
func userOwnedAlbumIDs(db *gorm.DB, user *models.User) ([]int, error) {
	if err := user.FillAlbums(db); err != nil {
		return nil, err
	}

	albumIDs := make([]int, len(user.Albums))
	for i, album := range user.Albums {
		albumIDs[i] = album.ID
	}

	return albumIDs, nil
}

// FaceGroup resolves the face group an image face belongs to. A face group
// already present on obj is returned as is; otherwise it is loaded through the
// "FaceGroup" association and cached on obj.
func (r imageFaceResolver) FaceGroup(ctx context.Context, obj *models.ImageFace) (*models.FaceGroup, error) {
	if obj.FaceGroup != nil {
		return obj.FaceGroup, nil
	}

	var faceGroup models.FaceGroup
	// obj is already a pointer to the model, so it is passed directly rather
	// than as a pointer to a pointer.
	if err := r.Database.Model(obj).Association("FaceGroup").Find(&faceGroup); err != nil {
		return nil, err
	}

	obj.FaceGroup = &faceGroup

	return &faceGroup, nil
}

// ImageFaces resolves the image faces of a face group, restricted to faces
// whose media lies in one of the current user's albums. It returns an
// "unauthorized" error when the context carries no user.
func (r faceGroupResolver) ImageFaces(ctx context.Context, obj *models.FaceGroup, paginate *models.Pagination) ([]*models.ImageFace, error) {
	user := auth.UserFromContext(ctx)
	if user == nil {
		return nil, errors.New("unauthorized")
	}

	userAlbumIDs, err := userOwnedAlbumIDs(r.Database, user)
	if err != nil {
		return nil, err
	}

	query := r.Database.
		Joins("Media").
		Where("face_group_id = ?", obj.ID).
		Where("media.album_id IN (?)", userAlbumIDs)

	query = models.FormatSQL(query, nil, paginate)

	var imageFaces []*models.ImageFace
	if err := query.Find(&imageFaces).Error; err != nil {
		return nil, err
	}

	return imageFaces, nil
}

// ImageFaceCount resolves the number of image faces in a face group, counting
// only faces whose media lies in one of the current user's albums. On error
// the returned count is -1.
func (r faceGroupResolver) ImageFaceCount(ctx context.Context, obj *models.FaceGroup) (int, error) {
	user := auth.UserFromContext(ctx)
	if user == nil {
		return -1, errors.New("unauthorized")
	}

	userAlbumIDs, err := userOwnedAlbumIDs(r.Database, user)
	if err != nil {
		return -1, err
	}

	query := r.Database.
		Model(&models.ImageFace{}).
		Joins("Media").
		Where("face_group_id = ?", obj.ID).
		Where("media.album_id IN (?)", userAlbumIDs)

	var count int64
	if err := query.Count(&count).Error; err != nil {
		return -1, err
	}

	return int(count), nil
}

// FaceGroup resolves a single face group by ID, restricted to groups that
// contain at least one image face whose media lies in one of the current
// user's albums.
func (r *queryResolver) FaceGroup(ctx context.Context, id int) (*models.FaceGroup, error) {
	user := auth.UserFromContext(ctx)
	if user == nil {
		return nil, errors.New("unauthorized")
	}

	userAlbumIDs, err := userOwnedAlbumIDs(r.Database, user)
	if err != nil {
		return nil, err
	}

	// The subquery must select media.id. Selecting "media_id" names a column
	// that the joins in this file do not use on the media table, and inside a
	// subquery such a name can bind to the outer image_faces.media_id, which
	// would make the ownership filter match every face group.
	userMediaIDs := r.Database.
		Select("media.id").
		Table("media").
		Where("media.album_id IN (?)", userAlbumIDs)

	faceGroupQuery := r.Database.
		Joins("LEFT JOIN image_faces ON image_faces.face_group_id = face_groups.id").
		Where("face_groups.id = ?", id).
		Where("image_faces.media_id IN (?)", userMediaIDs)

	var faceGroup models.FaceGroup
	if err := faceGroupQuery.Find(&faceGroup).Error; err != nil {
		return nil, err
	}

	return &faceGroup, nil
}

// MyFaceGroups resolves the face groups that contain at least one image face
// whose media lies in one of the current user's albums. Labeled groups are
// ordered before unlabeled ones, then groups are ordered by descending number
// of image faces.
func (r *queryResolver) MyFaceGroups(ctx context.Context, paginate *models.Pagination) ([]*models.FaceGroup, error) {
	user := auth.UserFromContext(ctx)
	if user == nil {
		return nil, errors.New("unauthorized")
	}

	userAlbumIDs, err := userOwnedAlbumIDs(r.Database, user)
	if err != nil {
		return nil, err
	}

	userMediaIDs := r.Database.
		Select("media.id").
		Table("media").
		Where("media.album_id IN (?)", userAlbumIDs)

	faceGroupQuery := r.Database.
		Joins("JOIN image_faces ON image_faces.face_group_id = face_groups.id").
		Where("image_faces.media_id IN (?)", userMediaIDs).
		Group("image_faces.face_group_id").
		Order("CASE WHEN label IS NULL THEN 1 ELSE 0 END").
		Order("COUNT(image_faces.id) DESC")

	faceGroupQuery = models.FormatSQL(faceGroupQuery, nil, paginate)

	var faceGroups []*models.FaceGroup
	if err := faceGroupQuery.Find(&faceGroups).Error; err != nil {
		return nil, err
	}

	return faceGroups, nil
}

// SetFaceGroupLabel updates the label of a face group accessible to the
// current user and returns that face group.
func (r *mutationResolver) SetFaceGroupLabel(ctx context.Context, faceGroupID int, label *string) (*models.FaceGroup, error) {
	user := auth.UserFromContext(ctx)
	if user == nil {
		return nil, errors.New("unauthorized")
	}

	faceGroup, err := userOwnedFaceGroup(r.Database, user, faceGroupID)
	if err != nil {
		return nil, err
	}

	if err := r.Database.Model(faceGroup).Update("label", label).Error; err != nil {
		return nil, err
	}

	return faceGroup, nil
}

// CombineFaceGroups moves every image face of the source face group into the
// destination face group and deletes the source group, all in one
// transaction. After the transaction succeeds the global face detector is told
// to merge the two categories. Both groups must be accessible to the current
// user, and they must be different groups.
func (r *mutationResolver) CombineFaceGroups(ctx context.Context, destinationFaceGroupID int, sourceFaceGroupID int) (*models.FaceGroup, error) {
	user := auth.UserFromContext(ctx)
	if user == nil {
		return nil, errors.New("unauthorized")
	}

	// Combining a group with itself would re-assign its faces to the same
	// group and then delete that group, leaving the faces without a group.
	if destinationFaceGroupID == sourceFaceGroupID {
		return nil, errors.New("cannot combine a face group with itself")
	}

	destinationFaceGroup, err := userOwnedFaceGroup(r.Database, user, destinationFaceGroupID)
	if err != nil {
		return nil, err
	}

	sourceFaceGroup, err := userOwnedFaceGroup(r.Database, user, sourceFaceGroupID)
	if err != nil {
		return nil, err
	}

	updateError := r.Database.Transaction(func(tx *gorm.DB) error {
		if err := tx.
			Model(&models.ImageFace{}).
			Where("face_group_id = ?", sourceFaceGroup.ID).
			Update("face_group_id", destinationFaceGroup.ID).Error; err != nil {
			return err
		}

		return tx.Delete(sourceFaceGroup).Error
	})
	if updateError != nil {
		return nil, updateError
	}

	face_detection.GlobalFaceDetector.MergeCategories(int32(sourceFaceGroupID), int32(destinationFaceGroupID))

	return destinationFaceGroup, nil
}

// MoveImageFaces moves the given image faces into the destination face group.
// Only image faces whose media lies in one of the current user's albums are
// moved; other IDs are silently ignored. Face groups that end up without any
// image face are deleted. After the transaction succeeds the global face
// detector is told about the moved faces.
func (r *mutationResolver) MoveImageFaces(ctx context.Context, imageFaceIDs []int, destinationFaceGroupID int) (*models.FaceGroup, error) {
	user := auth.UserFromContext(ctx)
	if user == nil {
		return nil, errors.New("unauthorized")
	}

	userOwnedImageFaceIDs := make([]int, 0)
	var destFaceGroup *models.FaceGroup

	transErr := r.Database.Transaction(func(tx *gorm.DB) error {
		var err error
		destFaceGroup, err = userOwnedFaceGroup(tx, user, destinationFaceGroupID)
		if err != nil {
			return err
		}

		userOwnedImageFaces, err := getUserOwnedImageFaces(tx, user, imageFaceIDs)
		if err != nil {
			return err
		}

		for _, imageFace := range userOwnedImageFaces {
			userOwnedImageFaceIDs = append(userOwnedImageFaceIDs, imageFace.ID)
		}

		// Remember the groups the faces currently belong to, so that groups
		// emptied by the move can be removed afterwards.
		var sourceFaceGroups []*models.FaceGroup
		if err := tx.
			Joins("LEFT JOIN image_faces ON image_faces.face_group_id = face_groups.id").
			Where("image_faces.id IN (?)", userOwnedImageFaceIDs).
			Find(&sourceFaceGroups).Error; err != nil {
			return err
		}

		if err := tx.
			Model(&models.ImageFace{}).
			Where("id IN (?)", userOwnedImageFaceIDs).
			Update("face_group_id", destFaceGroup.ID).Error; err != nil {
			return err
		}

		// delete face groups if they have become empty
		for _, faceGroup := range sourceFaceGroups {
			var count int64
			if err := tx.
				Model(&models.ImageFace{}).
				Where("face_group_id = ?", faceGroup.ID).
				Count(&count).Error; err != nil {
				return err
			}

			if count == 0 {
				if err := tx.Delete(faceGroup).Error; err != nil {
					return err
				}
			}
		}

		return nil
	})
	if transErr != nil {
		return nil, transErr
	}

	face_detection.GlobalFaceDetector.MergeImageFaces(userOwnedImageFaceIDs, int32(destFaceGroup.ID))

	return destFaceGroup, nil
}

// RecognizeUnlabeledFaces runs the global face detector's
// RecognizeUnlabeledFaces for the current user inside a database transaction
// and returns the image faces it reports as updated.
func (r *mutationResolver) RecognizeUnlabeledFaces(ctx context.Context) ([]*models.ImageFace, error) {
	user := auth.UserFromContext(ctx)
	if user == nil {
		return nil, errors.New("unauthorized")
	}

	var updatedImageFaces []*models.ImageFace

	transactionError := r.Database.Transaction(func(tx *gorm.DB) error {
		var err error
		updatedImageFaces, err = face_detection.GlobalFaceDetector.RecognizeUnlabeledFaces(tx, user)
		return err
	})
	if transactionError != nil {
		return nil, transactionError
	}

	return updatedImageFaces, nil
}

// DetachImageFaces creates a new face group and moves the given image faces
// into it. Only image faces whose media lies in one of the current user's
// albums are moved; other IDs are silently ignored. After the transaction
// succeeds the global face detector is told about the moved faces.
func (r *mutationResolver) DetachImageFaces(ctx context.Context, imageFaceIDs []int) (*models.FaceGroup, error) {
	user := auth.UserFromContext(ctx)
	if user == nil {
		return nil, errors.New("unauthorized")
	}

	userOwnedImageFaceIDs := make([]int, 0)
	newFaceGroup := models.FaceGroup{}

	transactionError := r.Database.Transaction(func(tx *gorm.DB) error {
		userOwnedImageFaces, err := getUserOwnedImageFaces(tx, user, imageFaceIDs)
		if err != nil {
			return err
		}

		for _, imageFace := range userOwnedImageFaces {
			userOwnedImageFaceIDs = append(userOwnedImageFaceIDs, imageFace.ID)
		}

		// Saving first populates newFaceGroup.ID for the update below.
		if err := tx.Save(&newFaceGroup).Error; err != nil {
			return err
		}

		return tx.
			Model(&models.ImageFace{}).
			Where("id IN (?)", userOwnedImageFaceIDs).
			Update("face_group_id", newFaceGroup.ID).Error
	})
	if transactionError != nil {
		return nil, transactionError
	}

	face_detection.GlobalFaceDetector.MergeImageFaces(userOwnedImageFaceIDs, int32(newFaceGroup.ID))

	return &newFaceGroup, nil
}

// userOwnedFaceGroup loads the face group with the given ID if the user may
// access it. Admin users may access any existing face group; other users only
// those containing at least one image face whose media lies in one of their
// albums. When no such face group is found, the returned error wraps
// gorm.ErrRecordNotFound.
func userOwnedFaceGroup(db *gorm.DB, user *models.User, faceGroupID int) (*models.FaceGroup, error) {
	var faceGroup models.FaceGroup

	if user.Admin {
		// First (unlike Find) reports gorm.ErrRecordNotFound when nothing
		// matches, so callers never receive a zero-valued face group.
		if err := db.Where("id = ?", faceGroupID).First(&faceGroup).Error; err != nil {
			if errors.Is(err, gorm.ErrRecordNotFound) {
				return nil, errors.Wrap(err, "face group does not exist")
			}
			return nil, err
		}

		return &faceGroup, nil
	}

	userAlbumIDs, err := userOwnedAlbumIDs(db, user)
	if err != nil {
		return nil, err
	}

	// Verify that user owns at least one of the images in the face group
	imageFaceQuery := db.
		Select("image_faces.id").
		Table("image_faces").
		Joins("LEFT JOIN media ON media.id = image_faces.media_id").
		Where("media.album_id IN (?)", userAlbumIDs)

	faceGroupQuery := db.
		Model(&models.FaceGroup{}).
		Joins("JOIN image_faces ON face_groups.id = image_faces.face_group_id").
		Where("face_groups.id = ?", faceGroupID).
		Where("image_faces.id IN (?)", imageFaceQuery)

	if err := faceGroupQuery.First(&faceGroup).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, errors.Wrap(err, "face group does not exist or is not owned by the user")
		}
		return nil, err
	}

	return &faceGroup, nil
}

// getUserOwnedImageFaces returns those image faces among imageFaceIDs whose
// media lies in one of the user's albums.
func getUserOwnedImageFaces(tx *gorm.DB, user *models.User, imageFaceIDs []int) ([]*models.ImageFace, error) {
	userAlbumIDs, err := userOwnedAlbumIDs(tx, user)
	if err != nil {
		return nil, err
	}

	var userOwnedImageFaces []*models.ImageFace
	if err := tx.
		Joins("JOIN media ON media.id = image_faces.media_id").
		Where("media.album_id IN (?)", userAlbumIDs).
		Where("image_faces.id IN (?)", imageFaceIDs).
		Find(&userOwnedImageFaces).Error; err != nil {
		return nil, err
	}

	return userOwnedImageFaces, nil
}