import logging
from itertools import chain
from pprint import pformat
from functools import wraps
import six
from elasticsearch import Elasticsearch
from elasticsearch.exceptions import NotFoundError
import transaction as zope_transaction
from zope.interface import implementer
from transaction.interfaces import ISavepointDataManager
from .query import ElasticQuery
from .result import ElasticResultRecord
log = logging.getLogger(__name__)

# Analysis configuration shared by every index this module creates: one
# custom token filter plus three custom analyzers.
ANALYZER_SETTINGS = {
    "analysis": {
        "filter": {
            "snowball": {"type": "snowball", "language": "English"},
        },
        "analyzer": {
            # Standard tokenization, lowercased.
            "lowercase": {
                "type": "custom",
                "tokenizer": "standard",
                "filter": ["standard", "lowercase"],
            },
            # Same filters as "lowercase", but tokenized with uax_url_email.
            "email": {
                "type": "custom",
                "tokenizer": "uax_url_email",
                "filter": ["standard", "lowercase"],
            },
            # HTML is stripped first, then stop-word and snowball filters
            # are applied on top of the lowercase chain.
            "content": {
                "type": "custom",
                "tokenizer": "standard",
                "char_filter": ["html_strip"],
                "filter": ["standard", "lowercase", "stop", "snowball"],
            },
        },
    },
}

# Settings sent when creating an index: the analysis section above (shared
# by reference, not deep-copied) plus the shard/replica layout.
CREATE_INDEX_SETTINGS = dict(
    ANALYZER_SETTINGS,
    index={
        "number_of_shards": 2,
        "number_of_replicas": 0,
    },
)

# Values stored in _CLIENT_STATE for a client that has a data manager.
STATUS_ACTIVE = 'active'
STATUS_CHANGED = 'changed'

# Maps id(client) -> status for clients currently joined to a transaction.
_CLIENT_STATE = {}
@implementer(ISavepointDataManager)
class ElasticDataManager(object):
    """
    Data manager that ties an Elasticsearch client to a transaction.

    On construction it joins the transaction manager's current transaction,
    registers the client in ``_CLIENT_STATE`` and empties the client's
    ``uncommitted`` queue. Queued ``(method name, args, kwargs)`` entries are
    replayed against the client in ``tpc_finish`` and discarded on abort.
    """

    def __init__(self, client, transaction_manager):
        self.client = client
        self.transaction_manager = transaction_manager
        t = transaction_manager.get()
        t.join(self)
        _CLIENT_STATE[id(client)] = STATUS_ACTIVE
        self._reset()

    def _reset(self):
        """
        Discard every queued operation on the client.
        """
        log.debug('_reset(%s)', self)
        self.client.uncommitted = []

    def _finish(self):
        """
        Forget the client so the next write sets up a fresh data manager.
        """
        log.debug('_finish(%s)', self)
        # pop() rather than del: cleanup may run more than once for the same
        # client, and a second call must not raise KeyError.
        _CLIENT_STATE.pop(id(self.client), None)

    def abort(self, transaction):
        log.debug('abort(%s)', self)
        self._reset()
        self._finish()

    def tpc_begin(self, transaction):
        log.debug('tpc_begin(%s)', self)

    def commit(self, transaction):
        log.debug('commit(%s)', self)

    def tpc_vote(self, transaction):
        log.debug('tpc_vote(%s)', self)
        # XXX Ideally, we'd try to check the uncommitted queue and make sure
        # everything looked ok. Not sure how we can do that, though.

    def tpc_finish(self, transaction):
        """
        Actually persist the uncommitted queue.

        Each queued entry is replayed by calling the named client method with
        ``immediate=True`` so that it is executed instead of re-queued. Any
        exception raised by a replayed call propagates to the caller, but the
        queue and the client registration are always cleaned up first.
        """
        log.debug('tpc_finish(%s)', self)
        log.debug("running: %r", self.client.uncommitted)
        try:
            for cmd, args, kwargs in self.client.uncommitted:
                # Build a new dict instead of mutating the queued one, which
                # may be shared with savepoint snapshots.
                call_kwargs = dict(kwargs, immediate=True)
                getattr(self.client, cmd)(*args, **call_kwargs)
        finally:
            # Without this, a failed replay would leave a stale
            # _CLIENT_STATE entry and later writes on this client would
            # never join a transaction again.
            self._reset()
            self._finish()

    def tpc_abort(self, transaction):
        log.debug('tpc_abort()')
        self._reset()
        self._finish()

    def sortKey(self):
        # NOTE: Ideally, we want this to sort *after* database-oriented data
        # managers, like the SQLAlchemy one. The double tilde should get us
        # to the end.
        return '~~elasticsearch' + str(id(self))

    def savepoint(self):
        """
        Return a savepoint capturing the current uncommitted queue.
        """
        return ElasticSavepoint(self)
class ElasticSavepoint(object):
    """
    Snapshot of a data manager's queue of uncommitted operations.

    ``dm`` is expected to expose ``dm.client.uncommitted`` as a sequence. The
    snapshot is a shallow copy: the queue itself is duplicated, the queued
    entries are shared.
    """

    def __init__(self, dm):
        self.dm = dm
        # list(...) instead of .copy(): list.copy() does not exist on
        # Python 2, which this module still supports via six.
        self.saved = list(dm.client.uncommitted)

    def rollback(self):
        """
        Restore the client's queue to the state captured at construction.
        """
        # Hand out a fresh copy each time so the snapshot stays intact and
        # the savepoint can be rolled back to more than once.
        self.dm.client.uncommitted = list(self.saved)
def join_transaction(client, transaction_manager):
    """
    Make sure ``client`` has a data manager registered for its writes.

    If the client has no entry in ``_CLIENT_STATE`` a new
    ``ElasticDataManager`` is created, which joins the transaction manager's
    current transaction and resets the client's queue. Otherwise the existing
    registration is reused and its status is set to ``STATUS_CHANGED``.
    """
    client_id = id(client)
    if _CLIENT_STATE.get(client_id) is None:
        # These messages trace normal control flow, so they are logged at
        # DEBUG rather than ERROR.
        log.debug('client %s not found, setting up new data manager',
                  client_id)
        ElasticDataManager(client, transaction_manager)
    else:
        log.debug('client %s found, using existing data manager',
                  client_id)
        _CLIENT_STATE[client_id] = STATUS_CHANGED
def transactional(f):
    """
    Decorate a client method so its execution can be deferred to commit time.

    The wrapper accepts an extra ``immediate`` keyword, which is removed
    before ``f`` is called. ``f`` runs right away, and its result is
    returned, when the client does not use transactions or when
    ``immediate`` is truthy. Otherwise the call is recorded on
    ``client.uncommitted`` as ``(f.__name__, args, kwargs)`` after joining
    the transaction, and the wrapper returns ``None``.
    """
    @wraps(f)
    def transactional_inner(client, *args, **kwargs):
        immediate = kwargs.pop('immediate', None)
        if not client.use_transaction or immediate:
            return f(client, *args, **kwargs)

        # Routine tracing of queued work belongs at DEBUG, not ERROR.
        log.debug('enqueueing action: %s: %r, %r', f.__name__, args,
                  kwargs)
        # Join first: creating a new data manager resets the queue, so the
        # append has to happen afterwards.
        join_transaction(client, client.transaction_manager)
        client.uncommitted.append((f.__name__, args, kwargs))
        return None
    return transactional_inner
class ElasticClient(object):
    """
    A handle for interacting with the Elasticsearch backend.

    All operations target the single index named at construction. Document
    writes (``index_document`` / ``delete_document``) are wrapped with
    ``transactional``: when ``use_transaction`` is true they are queued on
    ``self.uncommitted`` rather than executed, unless called with
    ``immediate=True``.
    """

    def __init__(self, servers, index, timeout=1.0, disable_indexing=False,
                 use_transaction=True,
                 transaction_manager=zope_transaction.manager):
        self.index = index
        self.disable_indexing = disable_indexing
        self.use_transaction = use_transaction
        self.transaction_manager = transaction_manager
        # Queue of (method name, args, kwargs) awaiting commit. Initialised
        # here so the attribute exists before any data manager is created.
        self.uncommitted = []
        # NOTE(review): ``timeout`` used to be accepted and silently dropped;
        # it is now forwarded to the underlying connection, so the 1.0s
        # default actually applies. Confirm this suits slow operations.
        self.es = Elasticsearch(servers, timeout=timeout)

    def ensure_index(self, recreate=False):
        """
        Ensure that the index exists on the ES server.

        A missing index is created with ``CREATE_INDEX_SETTINGS``. An
        existing index is left untouched unless ``recreate`` is true, in
        which case it is deleted and created again.
        """
        exists = self.es.indices.exists(self.index)
        if exists and not recreate:
            return
        if exists:
            self.es.indices.delete(self.index)
        self.es.indices.create(self.index,
                               body=dict(settings=CREATE_INDEX_SETTINGS))

    def delete_index(self):
        """
        Delete the index on the ES server.
        """
        self.es.indices.delete(self.index)

    def ensure_mapping(self, cls, recreate=False):
        """
        Put an explicit mapping for the given class.

        The document type is the class name and the mapping body comes from
        ``cls.elastic_mapping()``, with a ``_parent`` entry added when
        ``cls.elastic_parent`` is set. With ``recreate`` the existing mapping
        is deleted first; a missing mapping is not an error.
        """
        doc_type = cls.__name__
        # Copy so that adding "_parent" never mutates what the class returns.
        doc_mapping = dict(cls.elastic_mapping())
        if cls.elastic_parent:
            doc_mapping["_parent"] = {
                "type": cls.elastic_parent
            }
        doc_mapping = {doc_type: doc_mapping}
        log.debug('Putting mapping: \n%s', pformat(doc_mapping))
        if recreate:
            try:
                self.es.indices.delete_mapping(index=self.index,
                                               doc_type=doc_type)
            except NotFoundError:
                pass
        self.es.indices.put_mapping(index=self.index,
                                    doc_type=doc_type,
                                    body=doc_mapping)

    def delete_mapping(self, cls):
        """
        Delete the mapping corresponding to ``cls`` on the server. Does not
        delete subclass mappings.
        """
        doc_type = cls.__name__
        self.es.indices.delete_mapping(index=self.index,
                                       doc_type=doc_type)

    def ensure_all_mappings(self, base_class, recreate=False):
        """
        Initialize explicit mappings for all subclasses of the specified
        SQLAlchemy declarative base class.

        Registry entries without an ``elastic_mapping`` attribute are
        skipped.
        """
        for cls in base_class._decl_class_registry.values():
            if hasattr(cls, 'elastic_mapping'):
                self.ensure_mapping(cls, recreate=recreate)

    def get_mappings(self, cls=None):
        """
        Return the object mappings currently used by ES.

        When ``cls`` is given only the mapping for that class name is
        requested; otherwise all mappings of the index are.
        """
        doc_type = cls and cls.__name__
        raw = self.es.indices.get_mapping(index=self.index,
                                          doc_type=doc_type)
        return raw[self.index]['mappings']

    def index_object(self, obj, **kw):
        """
        Add or update the indexed document for an object.

        The document comes from ``obj.elastic_document()``; its ``_id`` key
        is removed from the body and used as the document id. Extra keyword
        arguments are passed on to ``index_document``.
        """
        doc = obj.elastic_document()
        doc_type = obj.__class__.__name__
        doc_id = doc.pop("_id")
        doc_parent = obj.elastic_parent

        log.debug('Indexing object:\n%s', pformat(doc))
        log.debug('Type is %r', doc_type)
        log.debug('ID is %r', doc_id)
        log.debug('Parent is %r', doc_parent)

        self.index_document(id=doc_id,
                            doc_type=doc_type,
                            doc=doc,
                            parent=doc_parent,
                            **kw)

    def delete_object(self, obj, safe=False, **kw):
        """
        Delete the indexed document for an object.

        With ``safe`` true, a document that does not exist is ignored
        instead of raising ``NotFoundError``. Extra keyword arguments are
        passed on to ``delete_document``.
        """
        doc = obj.elastic_document()
        doc_type = obj.__class__.__name__
        doc_id = doc.pop("_id")
        doc_parent = obj.elastic_parent

        self.delete_document(id=doc_id,
                             doc_type=doc_type,
                             parent=doc_parent,
                             safe=safe,
                             **kw)

    @transactional
    def index_document(self, id, doc_type, doc, parent=None):
        """
        Add or update the indexed document from a raw document source (not an
        object).

        Does nothing when ``disable_indexing`` is set.
        """
        if self.disable_indexing:
            return

        kwargs = dict(index=self.index,
                      body=doc,
                      doc_type=doc_type,
                      id=id)
        if parent:
            kwargs['parent'] = parent
        self.es.index(**kwargs)

    @transactional
    def delete_document(self, id, doc_type, parent=None, safe=False):
        """
        Delete the indexed document based on a raw document source (not an
        object).

        Does nothing when ``disable_indexing`` is set. ``NotFoundError`` is
        re-raised unless ``safe`` is true.
        """
        if self.disable_indexing:
            return

        kwargs = dict(index=self.index,
                      doc_type=doc_type,
                      id=id)
        if parent:
            # The parent is sent as the routing value for the delete.
            kwargs['routing'] = parent
        try:
            self.es.delete(**kwargs)
        except NotFoundError:
            if not safe:
                raise

    def index_objects(self, objects):
        """
        Add multiple objects to the index.
        """
        for obj in objects:
            self.index_object(obj)

    def flush(self, force=True):
        """
        Flush this client's index.
        """
        # Scoped to self.index, matching refresh(); without an index the
        # request would flush every index on the cluster.
        self.es.indices.flush(index=self.index, force=force)

    def get(self, obj, routing=None):
        """
        Retrieve the ES source document for a given object or (document type,
        id) pair.

        For an object, the document type is its class name, the id is
        ``obj.id``, and a truthy ``obj.elastic_parent`` overrides
        ``routing``. The raw response is wrapped in ``ElasticResultRecord``.
        """
        if isinstance(obj, tuple):
            doc_type, doc_id = obj
        else:
            doc_type, doc_id = obj.__class__.__name__, obj.id
            if obj.elastic_parent:
                routing = obj.elastic_parent

        kwargs = dict(index=self.index,
                      doc_type=doc_type,
                      id=doc_id)
        if routing:
            kwargs['routing'] = routing
        r = self.es.get(**kwargs)
        return ElasticResultRecord(r)

    def refresh(self):
        """
        Refresh the ES index.
        """
        self.es.indices.refresh(index=self.index)

    def subtype_names(self, cls):
        """
        Return a list of document types to query given an object class.

        Covers ``cls`` and the classes of its inheriting mappers, keeping
        only those that define ``elastic_mapping``.
        """
        classes = [cls] + [m.class_ for m in
                           cls.__mapper__._inheriting_mappers]
        return [c.__name__ for c in classes
                if hasattr(c, "elastic_mapping")]

    def search(self, body, classes=None, fields=None, **query_params):
        """
        Run ES search using default indexes.

        ``classes`` may mix document type names (strings) and mapped
        classes; each class is expanded with ``subtype_names``. When
        ``classes`` is omitted or yields no document types, no ``doc_type``
        restriction is sent. A truthy ``fields`` is passed as the ``fields``
        query parameter.
        """
        doc_types = []
        # ``classes or ()`` keeps the default classes=None from crashing in
        # the join below.
        for entry in classes or ():
            if isinstance(entry, six.string_types):
                doc_types.append(entry)
            else:
                doc_types.extend(self.subtype_names(entry))

        if fields:
            query_params['fields'] = fields

        return self.es.search(index=self.index,
                              doc_type=','.join(doc_types) or None,
                              body=body,
                              **query_params)

    def query(self, *classes, **kw):
        """
        Return an ElasticQuery against the specified class.

        A different query class can be supplied with the ``cls`` keyword;
        remaining keywords go to its constructor.
        """
        cls = kw.pop('cls', ElasticQuery)
        return cls(client=self, classes=classes, **kw)

    def analyze(self, text, analyzer):
        """
        Preview the result of analyzing a block of text using a given analyzer.
        """
        return self.es.indices.analyze(index=self.index,
                                       analyzer=analyzer,
                                       text=text)