Source code for mongomotor.monkey

# -*- coding: utf-8 -*-

# Copyright 2016 Juca Crispim <juca@poraodojuca.net>

# This file is part of mongomotor.

# mongomotor is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# mongomotor is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with mongomotor. If not, see <http://www.gnu.org/licenses/>.

from copy import copy
from asyncblink import signal
from mongoengine import connection, dereference, signals
from mongoengine.queryset import base
from pymongo.mongo_client import MongoClient
from mongomotor.dereference import MongoMotorDeReference


class MonkeyPatcher:
    """Context manager that temporarily replaces attributes of objects.

    Patches applied with :meth:`patch_item` (and the ``patch_*`` helpers
    built on top of it) are recorded and, unless ``undo=False`` was
    requested, reverted when the ``with`` block is left.
    """

    # Sentinel recorded for attributes that did not exist before patching.
    _MISSING = object()

    def __init__(self):
        # {obj: {attr: original_value}} for every patch to revert on exit.
        self.patched = {}
        # if the original patched object is a dict, indicates if
        # we should merge the original dict with the dict existing
        # when leaving the context manager.
        self._update_original_dict = False

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Revert every recorded patch; exceptions are not suppressed."""
        for obj, patches in self.patched.items():
            for attr, origobj in patches.items():
                if origobj is self._MISSING:
                    # The attribute did not exist before patching, so
                    # restore that state instead of leaving it patched.
                    try:
                        delattr(obj, attr)
                    except AttributeError:
                        # Already removed while patched; nothing to undo.
                        pass
                    continue
                if self._update_original_dict:
                    current_obj = getattr(obj, attr, None)
                    # Only dicts are merged, and it is the *original*
                    # that receives ``update`` -- so check that one too.
                    if (isinstance(origobj, dict)
                            and isinstance(current_obj, dict)):
                        origobj.update(current_obj)
                setattr(obj, attr, origobj)
        # Forget the reverted patches so a reused patcher does not
        # restore stale values or keep merging on a later exit.
        self.patched = {}
        self._update_original_dict = False

    def patch_item(self, obj, attr, newitem, undo=True):
        """Sets ``attr`` in ``obj`` with ``newitem``.

        If not ``undo`` the item will continue patched after leaving the
        context manager. With ``undo`` the original value is restored on
        exit, or ``attr`` is removed if ``obj`` did not have it before.

        :param obj: Object (usually a module) whose attribute is patched.
        :param attr: Name of the attribute to replace.
        :param newitem: Value to set.
        :param undo: Whether to revert the patch when leaving the
          context manager.
        """
        olditem = getattr(obj, attr, self._MISSING)
        if undo:
            # setdefault keeps the first recorded value, so patching the
            # same attribute twice still restores the real original.
            self.patched.setdefault(obj, {}).setdefault(attr, olditem)
        setattr(obj, attr, newitem)

    def patch_get_mongodb_version(self):
        """Patches mongoengine's get_mongodb_version to use a function
        that does not reach the database.
        """
        from .connection import get_db_version
        from mongoengine import fields
        self.patch_item(fields, 'get_mongodb_version', get_db_version)

    def patch_db_clients(self, client):
        """Patches the db clients used to connect to mongodb.

        :param client: Which client should be used."""
        self.patch_item(connection, 'MongoClient', client)

    def _patch_connections(self, keep_sync):
        """Replace ``connection._connections`` with a filtered copy.

        :param keep_sync: If True keep only :class:`MongoClient`
          instances; otherwise keep only the non-``MongoClient`` ones.
        """
        connections = {
            alias: conn
            for alias, conn in connection._connections.items()
            if isinstance(conn, MongoClient) == keep_sync
        }
        # Merge the patched dict back into the original on exit, so
        # connections opened while patched are reused next time instead
        # of connecting again.
        self._update_original_dict = True
        self.patch_item(connection, '_connections', connections)

    def patch_async_connections(self):
        """Patches mongoengine.connection._connections removing all
        asynchronous connections from there. It is used when switching to
        a synchronous connection to avoid mongoengine returning an
        asynchronous connection with the same configuration."""
        self._patch_connections(keep_sync=True)

    def patch_sync_connections(self):
        """Patches mongoengine.connection._connections removing all
        synchronous connections from there.
        """
        self._patch_connections(keep_sync=False)

    def patch_dereference(self):
        """Permanently replaces mongoengine's ``DeReference`` with
        :class:`MongoMotorDeReference` (not undone on exit)."""
        self.patch_item(dereference, 'DeReference', MongoMotorDeReference,
                        undo=False)

    def patch_qs_stop_iteration(self):
        """Patches StopIterations raised by mongoengine's queryset
        replacing it by StopAsyncIteration so it can interact well
        with futures. These patches are not undone on exit."""
        self.patch_item(base, 'StopIteration', StopAsyncIteration,
                        undo=False)
        self.patch_item(dereference, 'StopIteration', StopAsyncIteration,
                        undo=False)

    def patch_signals(self):
        """Patches mongoengine signals to use asyncblink signals"""
        for name in ('pre_init', 'post_init', 'pre_save', 'post_save',
                     'pre_save_post_validation', 'pre_delete',
                     'post_delete', 'pre_bulk_insert', 'post_bulk_insert'):
            self.patch_item(signals, name, signal(name))