# -*- coding: utf-8 -*-
# Copyright 2016 Juca Crispim <juca@poraodojuca.net>
# This file is part of mongomotor.
# mongomotor is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# mongomotor is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
# You should have received a copy of the GNU General Public License
# along with mongomotor. If not, see <http://www.gnu.org/licenses/>.
from bson.code import Code
from bson import SON
import functools
import os
from mongoengine import signals, DENY, CASCADE, NULLIFY, PULL
from mongoengine.connection import get_db
from mongoengine.queryset.queryset import QuerySet as MEQuerySet
from mongoengine.errors import OperationError
from motor.core import coroutine_annotation
from mongomotor.exceptions import ConfusionError
from mongomotor.metaprogramming import (get_future, AsyncGenericMetaclass,
Async, asynchronize)
from mongomotor.monkey import MonkeyPatcher
# for tests
TEST_ENV = os.environ.get('MONGOMOTOR_TEST_ENV')


class QuerySet(MEQuerySet, metaclass=AsyncGenericMetaclass):
distinct = Async()
explain = Async()
in_bulk = Async()
map_reduce = Async()
modify = Async()
update = Async()

    def __repr__(self):  # pragma: no cover
        return self.__class__.__name__

    def __len__(self):
        raise TypeError('len() is not supported. Use count()')

    def _iter_results(self):
try:
return super()._iter_results()
except StopIteration:
raise StopAsyncIteration

    def __getitem__(self, index):
        # If we received a slice we return a queryset, and since we
        # will not touch the db now we do not need a future here.
if isinstance(index, slice):
return super().__getitem__(index)
else:
sync_getitem = MEQuerySet.__getitem__
async_getitem = asynchronize(sync_getitem)
return async_getitem(self, index)
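
    # Usage sketch for ``__getitem__`` (the ``User`` document is
    # hypothetical; run the awaits inside a coroutine):
    #
    #   qs = User.objects[2:5]        # slicing returns a QuerySet, no I/O
    #   user = await User.objects[0]  # indexing hits the db, returns a future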

    def __aiter__(self):
        return self

    async def __anext__(self):
        # Returns the next raw document from the motor cursor, already
        # wrapped as a mongomotor document. The ``else`` clause runs when
        # the cursor is exhausted, ending the async iteration.
        async for doc in self._cursor:
            mm_doc = self._document._from_son(
                doc,
                _auto_dereference=self._auto_dereference)
            return mm_doc
        else:
            raise StopAsyncIteration()
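
    # With ``__aiter__``/``__anext__`` in place a queryset can be consumed
    # with ``async for`` (hypothetical ``User`` document, inside a coroutine):
    #
    #   async for user in User.objects.filter(active=True):
    #       print(user.name)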

    @coroutine_annotation
    def get(self, *q_objs, **query):
        """Retrieve the matching object, raising
        :class:`~mongoengine.queryset.MultipleObjectsReturned` or
        `DocumentName.MultipleObjectsReturned` if multiple results are
        found, and :class:`~mongoengine.queryset.DoesNotExist` or
        `DocumentName.DoesNotExist` if no results are found.
        """
queryset = self.clone()
queryset = queryset.order_by().limit(2)
queryset = queryset.filter(*q_objs, **query)
future = get_future(self)

        def _get_cb(done_future):
docs = done_future.result()
if len(docs) < 1:
msg = ("%s matching query does not exist."
% queryset._document._class_name)
future.set_exception(queryset._document.DoesNotExist(msg))
elif len(docs) > 1:
msg = 'More than 1 item returned'
future.set_exception(
queryset._document.MultipleObjectsReturned(msg))
else:
future.set_result(docs[0])
list_future = queryset.to_list(length=2)
        list_future.add_done_callback(_get_cb)  # pragma: no cover
return future
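
    # Usage sketch for ``get`` (hypothetical ``User`` document with a
    # ``name`` field, inside a coroutine):
    #
    #   try:
    #       user = await User.objects.get(name='john')
    #   except User.DoesNotExist:
    #       user = None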

    @coroutine_annotation
    def first(self):
        """Retrieve the first object matching the query, or ``None`` if
        there are no results.
        """
queryset = self.clone()
first_future = queryset[0]
future = get_future(self)

        def first_cb(first_future):
try:
result = first_future.result()
future.set_result(result)
except IndexError:
result = None
future.set_result(result)
except Exception as e:
future.set_exception(e)
first_future.add_done_callback(first_cb)
return future
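
    # Usage sketch for ``first`` (hypothetical ``User`` document, inside a
    # coroutine):
    #
    #   user = await User.objects.filter(age__gte=18).first()
    #   if user is None:
    #       ...  # no matching document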

    @coroutine_annotation
    def count(self, with_limit_and_skip=True):
        """Counts the documents in the queryset.

        :param with_limit_and_skip: Indicates if limit and skip applied to
            the queryset should be taken into account."""
        if (self._limit == 0 and with_limit_and_skip) or self._none:
            return 0
kw = {}
if with_limit_and_skip and self._limit:
kw['limit'] = self._limit
if with_limit_and_skip and self._skip:
kw['skip'] = self._skip
return self._collection.count_documents(self._query, **kw)
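
    # Usage sketch for ``count`` (hypothetical ``User`` document, inside a
    # coroutine):
    #
    #   total = await User.objects.count()
    #   page = await User.objects.skip(10).limit(10).count()  # at most 10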

    @coroutine_annotation
    def insert(self, doc_or_docs, load_bulk=True, write_concern=None):
        """Bulk insert documents.

        :param doc_or_docs: a document or list of documents to be inserted
        :param load_bulk (optional): If True, returns the list of document
            instances.
        :param write_concern: Extra keyword arguments are passed down to
            :meth:`~pymongo.collection.Collection.insert`
            which will be used as options for the resultant
            ``getLastError`` command. For example,
            ``insert(..., {w: 2, fsync: True})`` will wait until at least
            two servers have recorded the write and will force an fsync on
            each server being written to.

        By default returns document instances; set ``load_bulk`` to False to
        return just the ``ObjectId``s.
        """
super_insert = MEQuerySet.insert
async_in_bulk = self.in_bulk
        # This "sync" method is not truly synchronous: it uses motor sockets
        # and greenlet events under the hood, but it exposes a blocking-style
        # API.
        sync_in_bulk = functools.partial(self.in_bulk.__wrapped__, self)
insert_future = get_future(self)
with MonkeyPatcher() as patcher:
            # Here we swap the method with the async api for the method
            # with a sync api, so the mongoengine method does not need
            # to be rewritten.
patcher.patch_item(self, 'in_bulk', sync_in_bulk, undo=False)
future = asynchronize(super_insert)(self, doc_or_docs,
load_bulk=load_bulk,
write_concern=write_concern)

        def cb(future):
try:
result = future.result()
insert_future.set_result(result)
except Exception as e:
insert_future.set_exception(e)
finally:
patcher.patch_item(self, 'in_bulk', async_in_bulk,
undo=False)
future.add_done_callback(cb)
return insert_future
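
    # Usage sketch for ``insert`` (hypothetical ``User`` document, inside a
    # coroutine):
    #
    #   saved = await User.objects.insert([User(name='a'), User(name='b')])
    #   ids = await User.objects.insert([User(name='c')], load_bulk=False)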

    async def delete(self, write_concern=None, _from_doc_delete=False,
                     cascade_refs=None):
        """Deletes the documents matched by the query.

        :param write_concern: Extra keyword arguments are passed down which
            will be used as options for the resultant
            ``getLastError`` command. For example,
            ``save(..., write_concern={w: 2, fsync: True}, ...)`` will
            wait until at least two servers have recorded the write and
            will force an fsync on the primary server.
        :param _from_doc_delete: True when called from a document's delete
            method; signals have already been triggered, so do not call the
            document delete path again.
        :returns: the number of deleted documents.
        """
queryset = self.clone()
doc = queryset._document
if write_concern is None:
write_concern = {}
# Handle deletes where skips or limits have been applied or
# there is an untriggered delete signal
has_delete_signal = signals.signals_available and (
signals.pre_delete.has_receivers_for(self._document) or
signals.post_delete.has_receivers_for(self._document))
call_document_delete = (queryset._skip or queryset._limit or
has_delete_signal) and not _from_doc_delete
        if call_document_delete:
            async_method = asynchronize(self._document_delete)
            # Await the asynchronized call so callers get the deleted
            # count directly instead of a nested future.
            return (await async_method(queryset, write_concern))
await self._check_delete_rules(doc, queryset, cascade_refs,
write_concern)
        r = await queryset._collection.delete_many(
            queryset._query, **write_concern)
        return r.deleted_count
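
    # Usage sketch for ``delete`` (hypothetical ``User`` document, inside a
    # coroutine):
    #
    #   n_deleted = await User.objects.filter(active=False).delete()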

    @coroutine_annotation
    def upsert_one(self, write_concern=None, **update):
        """Overwrite or add the first document matched by the query.

        :param write_concern: Extra keyword arguments are passed down which
            will be used as options for the resultant
            ``getLastError`` command. For example,
            ``save(..., write_concern={w: 2, fsync: True}, ...)`` will
            wait until at least two servers have recorded the write and
            will force an fsync on the primary server.
        :param update: Django-style update keyword arguments.
        :returns: the new or overwritten document.
        """
update_future = self.update(multi=False, upsert=True,
write_concern=write_concern,
full_result=True, **update)
upsert_future = get_future(self)

        def update_cb(update_future):
try:
result = update_future.result().raw_result
if result['updatedExisting']:
document_future = self.first()
else:
document_future = self._document.objects.with_id(
result['upserted'])

                def doc_cb(document_future):
try:
result = document_future.result()
upsert_future.set_result(result)
except Exception as e:
upsert_future.set_exception(e)
document_future.add_done_callback(doc_cb)
except Exception as e:
upsert_future.set_exception(e)
update_future.add_done_callback(update_cb)
return upsert_future
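
    # Usage sketch for ``upsert_one`` (hypothetical ``User`` document with
    # ``name``/``age`` fields, inside a coroutine):
    #
    #   user = await User.objects.filter(name='john').upsert_one(set__age=30)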

    @coroutine_annotation
    def to_list(self, length=100):
        """Returns a list of the current documents in the queryset.

        :param length: maximum number of documents to return for this call."""
list_future = get_future(self)

        def _to_list_cb(future):
            # Transforms mongo's raw documents into mongomotor documents.
docs_list = future.result()
final_list = [self._document._from_son(
d, _auto_dereference=self._auto_dereference)
for d in docs_list]
list_future.set_result(final_list)
cursor = self._cursor
future = cursor.to_list(length)
future.add_done_callback(_to_list_cb)
return list_future
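
    # Usage sketch for ``to_list`` (hypothetical ``User`` document, inside a
    # coroutine):
    #
    #   users = await User.objects.filter(active=True).to_list(length=50)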

    async def item_frequencies(self, field, normalize=False):
        """Returns a dictionary of all items present in a field across
        the whole queried set of documents, and their corresponding frequency.
        This is useful for generating tag clouds, or searching documents.

        .. note::

            Can only do direct simple mappings and cannot map across
            :class:`~mongoengine.fields.ReferenceField` or
            :class:`~mongoengine.fields.GenericReferenceField`; for more
            complex counting a manual aggregation call would be required.

            If the field is a :class:`~mongoengine.fields.ListField`,
            the items within each list will be counted individually.

        :param field: the field to use
        :param normalize: normalize the results so they add to 1.0
        """
cursor = self._document._get_collection().aggregate([
{'$match': self._query},
{'$unwind': f'${field}'},
{'$group': {'_id': '$' + field, 'total': {'$sum': 1}}}
])
freqs = {}
async for doc in cursor:
freqs[doc['_id']] = doc['total']
if normalize:
count = sum(freqs.values())
            freqs = {k: v / count for k, v in freqs.items()}
return freqs
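
    # Usage sketch for ``item_frequencies`` (hypothetical ``Post`` document
    # with a ``tags`` ListField, inside a coroutine):
    #
    #   freqs = await Post.objects.item_frequencies('tags', normalize=True)
    #   # e.g. {'python': 0.5, 'mongo': 0.5}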

    async def average(self, field):
        """Average over the values of the specified field.

        :param field: the field to average over; use dot-notation to refer to
            embedded document fields

        This method is more performant than a map-reduce based ``average``
        because it uses the aggregation framework.
        """
cursor = self._document._get_collection().aggregate([
{'$match': self._query},
{'$group': {'_id': 'avg', 'total': {'$avg': '$' + field}}}
])
avg = 0
async for doc in cursor:
avg = doc['total']
break
return avg

    async def sum(self, field):
        """Sum over the values of the specified field.

        :param field: the field to sum over; use dot-notation to refer to
            embedded document fields

        This method is more performant than a map-reduce based ``sum``
        because it uses the aggregation framework.
        """
cursor = self._document._get_collection().aggregate([
{'$match': self._query},
{'$group': {'_id': 'sum', 'total': {'$sum': '$' + field}}}
])
r = 0
async for doc in cursor:
r = doc['total']
break
return r
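
    # Usage sketch for ``average``/``sum`` (hypothetical ``Employee``
    # document with a numeric ``salary`` field, inside a coroutine):
    #
    #   avg = await Employee.objects.average('salary')
    #   total = await Employee.objects.sum('salary')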

    @property
    @coroutine_annotation
    def fetch_next(self):
        """A future that resolves to True if the underlying cursor has
        more documents."""
        return self._cursor.fetch_next

    def next_object(self):
        """Returns the next document from the cursor, wrapped as a
        mongomotor document."""
        raw = self._cursor.next_object()
        return self._document._from_son(
            raw, _auto_dereference=self._auto_dereference)
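
    # Usage sketch for motor-style manual iteration (hypothetical ``User``
    # document, inside a coroutine):
    #
    #   qs = User.objects.filter(active=True)
    #   while (await qs.fetch_next):
    #       user = qs.next_object()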

    def no_cache(self):
        """Convert to a non-caching queryset.
        """
if self._result_cache is not None:
raise OperationError('QuerySet already cached')
return self._clone_into(QuerySetNoCache(self._document,
self._collection))

    def _get_code(self, func):
f_scope = {}
if isinstance(func, Code):
f_scope = func.scope
func = str(func)
func = Code(self._sub_js_fields(func), f_scope)
return func

    def _get_output(self, output):
        if isinstance(output, (str, SON)):
out = output
elif isinstance(output, dict):
ordered_output = []
for part in ('replace', 'merge', 'reduce'):
value = output.get(part)
if value:
ordered_output.append((part, value))
break
else:
raise OperationError("actionData not specified for output")
db_alias = output.get('db_alias')
            remaining_args = ['db', 'sharded', 'nonAtomic']
            if db_alias:
                ordered_output.append(('db', get_db(db_alias).name))
                del remaining_args[0]
            for part in remaining_args:
value = output.get(part)
if value:
ordered_output.append((part, value))
out = SON(ordered_output)
else:
raise ConfusionError('Bad output type {}'.format(type(output)))
return out

    async def _check_delete_rules(self, doc, queryset, cascade_refs,
                                  write_concern):
        """Checks the delete rules for documents being deleted in a queryset.
        Raises an exception if any document has a DENY rule."""
delete_rules = doc._meta.get('delete_rules') or {}
# Check for DENY rules before actually deleting/nullifying any other
# references
delete_rules = delete_rules.copy()
for rule_entry in delete_rules:
document_cls, field_name = rule_entry
if document_cls._meta.get('abstract'):
continue
rule = doc._meta['delete_rules'][rule_entry]
if rule == DENY and document_cls.objects(
**{field_name + '__in': self}).count() > 0:
msg = ("Could not delete document (%s.%s refers to it)"
% (document_cls.__name__, field_name))
raise OperationError(msg)
if not delete_rules:
return
r = None
for rule_entry in delete_rules:
document_cls, field_name = rule_entry
if document_cls._meta.get('abstract'):
continue
rule = doc._meta['delete_rules'][rule_entry]
if rule == CASCADE:
cascade_refs = set() if cascade_refs is None else cascade_refs
for ref in queryset:
cascade_refs.add(ref.id)
ref_q = document_cls.objects(**{field_name + '__in': self,
'id__nin': cascade_refs})
count = await ref_q.count()
if count > 0:
r = await ref_q.delete(write_concern=write_concern,
cascade_refs=cascade_refs)
elif rule in (NULLIFY, PULL):
if rule == NULLIFY:
updatekw = {'unset__%s' % field_name: 1}
else:
updatekw = {'pull_all__%s' % field_name: self}
r = await document_cls.objects(
**{field_name + '__in': self}).update(
write_concern=write_concern, **updatekw)
return r

    def _document_delete(self, queryset, write_concern):
        """Delete the documents in the queryset by calling the document's
        delete method."""
cnt = 0
for doc in queryset:
doc.delete(**write_concern)
cnt += 1
return cnt

    def _get_loop(self):
        """Returns the ioloop for this queryset."""
db = self._document._get_db()
loop = db.get_io_loop()
return loop


class QuerySetNoCache(QuerySet):
    """A non-caching QuerySet."""

    def cache(self):
        """Convert to a caching queryset.
        """
        return self._clone_into(QuerySet(self._document, self._collection))

    def __iter__(self):
queryset = self
if queryset._iter:
queryset = self.clone()
queryset.rewind()
return queryset
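
# Usage sketch for cache control (hypothetical ``User`` document):
#
#   qs = User.objects.no_cache()   # QuerySetNoCache: results are not cached
#   qs = qs.cache()                # back to a caching QuerySet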