# -*- coding: utf-8 -*-
# Copyright 2016-2017, 2025 Juca Crispim <juca@poraodojuca.dev>
# This file is part of mongomotor.
# mongomotor is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# mongomotor is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
# You should have received a copy of the GNU General Public License
# along with mongomotor. If not, see <http://www.gnu.org/licenses/>.
from bson import DBRef
import gridfs
from mongoengine import fields
from mongoengine.base import get_document
from mongoengine.base.datastructures import (
    BaseDict, BaseList, EmbeddedDocumentList)
from mongoengine.common import _import_class
from mongoengine.connection import get_db
from mongoengine.errors import DoesNotExist
from mongoengine.fields import GridFSError
from mongoengine.fields import *  # noqa f403 for the sake of the api
[docs]
class BaseAsyncReferenceField:
    """Base class for async reference fields."""
    def __get__(self, instance, owner):
        if instance is None:
            return self
        auto_dereference = instance._fields[self.name]._auto_dereference
        if not auto_dereference:
            return instance._data.get(self.name)
        async def get():
            if getattr(instance._data[self.name], "_dereferenced", False):
                return instance._data.get(self.name)
            ref_value = instance._data.get(self.name)
            if isinstance(ref_value, dict) and '_ref' in ref_value.keys():
                ref = ref_value['_ref']
                cls = get_document(ref_value['_cls'])
                instance._data[self.name] = await self._lazy_load_ref(
                    cls, ref)
            elif auto_dereference and isinstance(ref_value, DBRef):
                if hasattr(ref_value, "cls"):
                    # Dereference using the class type specified in the
                    # reference
                    cls = get_document(ref_value.cls)
                else:
                    cls = self.document_type
                instance._data[self.name] = await self._lazy_load_ref(
                    cls, ref_value)
                instance._data[self.name]._dereferenced = True
            return instance._data.get(self.name)
        return get()
    @staticmethod
    async def _lazy_load_ref(ref_cls, dbref):
        dereferenced_son = await ref_cls._get_db().dereference(dbref)
        if dereferenced_son is None:
            raise DoesNotExist(
                f"Trying to dereference unknown document {dbref}")
        return ref_cls._from_son(dereferenced_son) 
[docs]
class ReferenceField(BaseAsyncReferenceField, fields.ReferenceField):
    """A reference to a document that will be dereferenced on
    access.
    Use the `reverse_delete_rule` to handle what should happen if the document
    the field is referencing is deleted.  EmbeddedDocuments, DictFields and
    MapFields does not support reverse_delete_rule and an
    `InvalidDocumentError` will be raised if trying to set on one of these
    Document / Field types.
    The options are:
      * DO_NOTHING (0)  - don't do anything (default).
      * NULLIFY    (1)  - Updates the reference to null.
      * CASCADE    (2)  - Deletes the documents associated with the reference.
      * DENY       (3)  - Prevent the deletion of the reference object.
      * PULL       (4)  - Pull the reference from a
        :class:`~mongomotor.fields.ListField` of references
    Alternative syntax for registering delete rules (useful when implementing
    bi-directional delete rules)
    .. code-block:: python
        class Bar(Document):
            content = StringField()
            foo = ReferenceField('Foo')
        Foo.register_delete_rule(Bar, 'foo', NULLIFY)
    """ 
[docs]
class GenericReferenceField(
        BaseAsyncReferenceField, fields.GenericReferenceField):
    pass 
[docs]
class ComplexBaseField(fields.ComplexBaseField):
    def __get__(self, instance, owner):
        if instance is None:
            return self
        auto_dereference = instance._fields[self.name]._auto_dereference
        dereference = auto_dereference and isinstance(
            self.field, (GenericReferenceField, ReferenceField))
        if not dereference:
            val = instance._data.get(self.name)
            if val is not None:
                self._convert_value(instance, val)
            return instance._data.get(self.name)
        async def get():
            if getattr(instance._data[self.name], "_dereferenced", False):
                return instance._data.get(self.name)
            ref_values = instance._data.get(self.name)
            instance._data[self.name] = await self._lazy_load_refs(
                ref_values=ref_values, instance=instance, name=self.name,
                max_depth=1
            )
            if hasattr(instance._data[self.name], "_dereferenced"):
                instance._data[self.name]._dereferenced = True
            value = instance._data[self.name]
            self._convert_value(instance, value)
            value = instance._data[self.name]
            if (
                auto_dereference
                and instance._initialised
                and isinstance(value, (BaseList, BaseDict))
                and not value._dereferenced
            ):
                value = await self._lazy_load_refs(
                    ref_values=value, instance=instance, name=self.name,
                    max_depth=1
                )
                value._dereferenced = True
                instance._data[self.name] = value
            return value
        return get()
    @staticmethod
    async def _lazy_load_refs(instance, name, ref_values, *, max_depth):
        _dereference = _import_class("DeReference")()
        documents = await _dereference(
            ref_values,
            max_depth=max_depth,
            instance=instance,
            name=name,
        )
        return documents
    def _convert_value(self, instance, value):
        # Convert lists / values so we can watch for any changes on them
        if isinstance(value, (list, tuple)):
            if issubclass(type(self), fields.EmbeddedDocumentListField) \
               
and not isinstance(value, EmbeddedDocumentList):
                value = EmbeddedDocumentList(value, instance, self.name)
            elif not isinstance(value, BaseList):
                value = BaseList(value, instance, self.name)
                instance._data[self.name] = value
        elif isinstance(value, dict) and not isinstance(value, BaseDict):
            value = BaseDict(value, instance, self.name)
            instance._data[self.name] = value 
[docs]
class ListField(ComplexBaseField, fields.ListField):
    pass 
[docs]
class DictField(ComplexBaseField, fields.DictField):
    pass 
[docs]
class GridFSProxy(fields.GridFSProxy):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
    async def __aenter__(self):
        return self
    async def __aexit__(self, exc, exc_type, exc_tb):
        await self.close()
    @property
    def fs(self):
        if not self._fs:
            self._fs = gridfs.AsyncGridFS(
                get_db(self.db_alias), collection=self.collection_name)
        return self._fs
[docs]
    async def get(self, grid_id=None):
        if grid_id:
            self.grid_id = grid_id
        if self.grid_id is None:
            return None
        try:
            if self.gridout is None:
                self.gridout = await self.fs.get(self.grid_id)
            return self.gridout
        except Exception:
            # File has been deleted
            return None 
[docs]
    async def close(self):
        if self.newfile:
            await self.newfile.close()
            self.newfile = None 
[docs]
    async def write(self, data):
        """Writes ``data`` to gridfs.
        :param data: String or bytes to write to gridfs."""
        if self.grid_id:
            if not self.newfile:
                raise GridFSError(  # noqa f405
                    'This document already has a file. Either '
                    'delete it or call replace to overwrite it')
        else:
            self.new_file()
        await self.newfile.write(data) 
[docs]
    async def put(self, file_obj, **kwargs):
        if self.grid_id:
            raise GridFSError(
                "This document already has a file. Either delete "
                "it or call replace to overwrite it"
            )
        self.grid_id = await self.fs.put(file_obj, **kwargs)
        self._mark_as_changed() 
[docs]
    async def read(self, size=-1):
        gridout = await self.get()
        if gridout is None:
            return None
        else:
            try:
                return await gridout.read(size)
            except Exception:
                return "" 
[docs]
    async def delete(self):
        # Delete file from GridFS, FileField still remains
        await self.fs.delete(self.grid_id)
        self.grid_in = None
        self.grid_id = None
        self.grid_out = None
        self._mark_as_changed() 
[docs]
    async def replace(self, data, **metadata):
        """Replaces the contents of the file with ``data``.
        :param data: A byte-string to write to gridfs.
        :param metatada: File metadata.
        """
        await self.delete()
        await self.put(data, **metadata) 
 
[docs]
class FileField(fields.FileField):
    proxy_class = GridFSProxy