1
Fork 0
photoview/api/scanner/face_detection/face_detector.go

282 lines
7.0 KiB
Go
Raw Normal View History

2021-02-15 17:35:28 +01:00
package face_detection
import (
"log"
"path/filepath"
"sync"
"github.com/Kagami/go-face"
"github.com/photoview/photoview/api/graphql/models"
"github.com/pkg/errors"
"gorm.io/gorm"
)
type FaceDetector struct {
mutex sync.Mutex
db *gorm.DB
rec *face.Recognizer
faceDescriptors []face.Descriptor
faceGroupIDs []int32
imageFaceIDs []int
2021-02-15 17:35:28 +01:00
}
var GlobalFaceDetector FaceDetector
func InitializeFaceDetector(db *gorm.DB) error {
log.Println("Initializing face detector")
rec, err := face.NewRecognizer(filepath.Join("data", "models"))
if err != nil {
return errors.Wrap(err, "initialize facedetect recognizer")
}
faceDescriptors, faceGroupIDs, imageFaceIDs, err := getSamplesFromDatabase(db)
2021-02-15 17:35:28 +01:00
if err != nil {
return errors.Wrap(err, "get face detection samples from database")
}
GlobalFaceDetector = FaceDetector{
db: db,
rec: rec,
faceDescriptors: faceDescriptors,
faceGroupIDs: faceGroupIDs,
imageFaceIDs: imageFaceIDs,
2021-02-15 17:35:28 +01:00
}
return nil
}
func getSamplesFromDatabase(db *gorm.DB) (samples []face.Descriptor, faceGroupIDs []int32, imageFaceIDs []int, err error) {
2021-02-15 20:31:17 +01:00
var imageFaces []*models.ImageFace
if err = db.Find(&imageFaces).Error; err != nil {
return
}
samples = make([]face.Descriptor, len(imageFaces))
faceGroupIDs = make([]int32, len(imageFaces))
imageFaceIDs = make([]int, len(imageFaces))
2021-02-15 20:31:17 +01:00
for i, imgFace := range imageFaces {
samples[i] = face.Descriptor(imgFace.Descriptor)
faceGroupIDs[i] = int32(imgFace.FaceGroupID)
imageFaceIDs[i] = imgFace.ID
2021-02-15 20:31:17 +01:00
}
2021-02-15 17:35:28 +01:00
return
}
2021-02-15 20:31:17 +01:00
// DetectFaces finds the faces in the given image and saves them to the database
2021-02-15 17:35:28 +01:00
func (fd *FaceDetector) DetectFaces(media *models.Media) error {
if err := fd.db.Model(media).Preload("MediaURL").First(&media).Error; err != nil {
return err
}
var thumbnailURL *models.MediaURL
for _, url := range media.MediaURL {
if url.Purpose == models.PhotoThumbnail {
thumbnailURL = &url
thumbnailURL.Media = media
break
}
}
if thumbnailURL == nil {
return errors.New("thumbnail url is missing")
}
thumbnailPath, err := thumbnailURL.CachedPath()
if err != nil {
return err
}
fd.mutex.Lock()
faces, err := fd.rec.RecognizeFile(thumbnailPath)
fd.mutex.Unlock()
if err != nil {
return errors.Wrap(err, "error read faces")
}
for _, face := range faces {
2021-02-16 11:27:28 +01:00
fd.classifyFace(&face, media, thumbnailPath)
2021-02-15 17:35:28 +01:00
}
return nil
}
2021-02-19 19:24:31 +01:00
func (fd *FaceDetector) classifyDescriptor(descriptor face.Descriptor) int32 {
2021-02-19 23:30:43 +01:00
return int32(fd.rec.ClassifyThreshold(descriptor, 0.3))
2021-02-19 19:24:31 +01:00
}
2021-02-16 11:27:28 +01:00
func (fd *FaceDetector) classifyFace(face *face.Face, media *models.Media, imagePath string) error {
2021-02-15 17:35:28 +01:00
fd.mutex.Lock()
defer fd.mutex.Unlock()
2021-02-19 19:24:31 +01:00
match := fd.classifyDescriptor(face.Descriptor)
2021-02-15 17:35:28 +01:00
2021-02-16 11:27:28 +01:00
faceRect, err := models.ToDBFaceRectangle(face.Rectangle, imagePath)
if err != nil {
return err
}
2021-02-15 17:35:28 +01:00
imageFace := models.ImageFace{
MediaID: media.ID,
Descriptor: models.FaceDescriptor(face.Descriptor),
2021-02-16 11:27:28 +01:00
Rectangle: *faceRect,
2021-02-15 17:35:28 +01:00
}
var faceGroup models.FaceGroup
// If no match add it new to samples
if match < 0 {
log.Println("No match, assigning new face")
faceGroup = models.FaceGroup{
ImageFaces: []models.ImageFace{imageFace},
}
if err := fd.db.Create(&faceGroup).Error; err != nil {
return err
}
} else {
log.Println("Found match")
if err := fd.db.First(&faceGroup, int(match)).Error; err != nil {
return err
}
if err := fd.db.Model(&faceGroup).Association("ImageFaces").Append(&imageFace); err != nil {
return err
}
}
fd.faceDescriptors = append(fd.faceDescriptors, face.Descriptor)
fd.faceGroupIDs = append(fd.faceGroupIDs, int32(faceGroup.ID))
fd.imageFaceIDs = append(fd.imageFaceIDs, imageFace.ID)
2021-02-15 17:35:28 +01:00
fd.rec.SetSamples(fd.faceDescriptors, fd.faceGroupIDs)
2021-02-15 17:35:28 +01:00
return nil
}
2021-02-19 19:24:31 +01:00
// MergeCategories moves every known face sample in face group sourceID into
// face group destID. It then pushes the updated samples to the recognizer,
// so later classifications return destID instead of sourceID.
func (fd *FaceDetector) MergeCategories(sourceID int32, destID int32) {
	fd.mutex.Lock()
	defer fd.mutex.Unlock()

	for i, groupID := range fd.faceGroupIDs {
		if groupID == sourceID {
			fd.faceGroupIDs[i] = destID
		}
	}

	// SetSamples hands the recognizer its own copy of the samples. Editing
	// fd.faceGroupIDs alone would leave it classifying faces into sourceID.
	fd.rec.SetSamples(fd.faceDescriptors, fd.faceGroupIDs)
}
func (fd *FaceDetector) MergeImageFaces(imageFaceIDs []int, destFaceGroupID int32) {
fd.mutex.Lock()
defer fd.mutex.Unlock()
for i := range fd.faceGroupIDs {
imageFaceID := fd.imageFaceIDs[i]
for _, id := range imageFaceIDs {
if imageFaceID == id {
fd.faceGroupIDs[i] = destFaceGroupID
break
}
2021-02-19 19:24:31 +01:00
}
}
}
func (fd *FaceDetector) RecognizeUnlabeledFaces(tx *gorm.DB, user *models.User) ([]*models.ImageFace, error) {
unrecognizedDescriptors := make([]face.Descriptor, 0)
unrecognizedFaceGroupIDs := make([]int32, 0)
unrecognizedImageFaceIDs := make([]int, 0)
2021-02-19 19:24:31 +01:00
newFaceGroupIDs := make([]int32, 0)
newDescriptors := make([]face.Descriptor, 0)
newImageFaceIDs := make([]int, 0)
2021-02-19 19:24:31 +01:00
var unlabeledFaceGroups []*models.FaceGroup
err := tx.
Joins("JOIN image_faces ON image_faces.face_group_id = face_groups.id").
Joins("JOIN media ON image_faces.media_id = media.id").
Where("face_groups.label IS NULL").
Where("media.album_id IN (?)",
tx.Select("album_id").Table("user_albums").Where("user_id = ?", user.ID),
).
Find(&unlabeledFaceGroups).Error
if err != nil {
return nil, err
}
fd.mutex.Lock()
defer fd.mutex.Unlock()
for i := range fd.faceDescriptors {
descriptor := fd.faceDescriptors[i]
faceGroupID := fd.faceGroupIDs[i]
imageFaceID := fd.imageFaceIDs[i]
2021-02-19 19:24:31 +01:00
isUnlabeled := false
2021-02-19 19:24:31 +01:00
for _, unlabeledFaceGroup := range unlabeledFaceGroups {
if faceGroupID == int32(unlabeledFaceGroup.ID) {
isUnlabeled = true
2021-02-19 19:24:31 +01:00
continue
}
}
if isUnlabeled {
unrecognizedFaceGroupIDs = append(unrecognizedFaceGroupIDs, faceGroupID)
unrecognizedDescriptors = append(unrecognizedDescriptors, descriptor)
unrecognizedImageFaceIDs = append(unrecognizedImageFaceIDs, imageFaceID)
2021-02-19 19:24:31 +01:00
} else {
newFaceGroupIDs = append(newFaceGroupIDs, faceGroupID)
newDescriptors = append(newDescriptors, descriptor)
newImageFaceIDs = append(newImageFaceIDs, imageFaceID)
2021-02-19 19:24:31 +01:00
}
}
fd.faceGroupIDs = newFaceGroupIDs
fd.faceDescriptors = newDescriptors
fd.imageFaceIDs = newImageFaceIDs
2021-02-19 19:24:31 +01:00
updatedImageFaces := make([]*models.ImageFace, 0)
for i := range unrecognizedDescriptors {
descriptor := unrecognizedDescriptors[i]
faceGroupID := unrecognizedFaceGroupIDs[i]
imageFaceID := unrecognizedImageFaceIDs[i]
2021-02-19 19:24:31 +01:00
match := fd.classifyDescriptor(descriptor)
2021-02-19 19:24:31 +01:00
if match < 0 {
// still no match, we can readd it to the list
fd.faceGroupIDs = append(fd.faceGroupIDs, faceGroupID)
fd.faceDescriptors = append(fd.faceDescriptors, descriptor)
fd.imageFaceIDs = append(fd.imageFaceIDs, imageFaceID)
2021-02-19 19:24:31 +01:00
} else {
// found new match, update the database
var imageFace models.ImageFace
if err := tx.Model(&models.ImageFace{}).First(imageFace, imageFaceID).Error; err != nil {
2021-02-19 19:24:31 +01:00
return nil, err
}
if err := tx.Model(&imageFace).Update("face_group_id", int(faceGroupID)).Error; err != nil {
2021-02-19 19:24:31 +01:00
return nil, err
}
updatedImageFaces = append(updatedImageFaces, &imageFace)
fd.faceGroupIDs = append(fd.faceGroupIDs, match)
fd.faceDescriptors = append(fd.faceDescriptors, descriptor)
fd.imageFaceIDs = append(fd.imageFaceIDs, imageFaceID)
2021-02-19 19:24:31 +01:00
}
}
return updatedImageFaces, nil
}