# Source code for rule_engine.ast.expression.control

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
#  rule_engine/ast/expression/control.py
#
#  Redistribution and use in source and binary forms, with or without
#  modification, are permitted provided that the following conditions are
#  met:
#
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above
#    copyright notice, this list of conditions and the following disclaimer
#    in the documentation and/or other materials provided with the
#    distribution.
#  * Neither the name of the project nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
#  "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
#  LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
#  A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
#  OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
#  SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
#  LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
#  DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
#  THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
#  (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
#  OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import collections
import datetime
import functools
import operator
from typing import TYPE_CHECKING, Any, Callable

from ... import errors
from ...types import DataType, is_numeric
from ...types import _DataTypeDef

from ..base import (
        Assignment,
        ExpressionBase,
        LiteralExpressionBase,
        _assert_not_nullable,
        _is_reduced,
        _propagate_nullable,
        _resolve_type,
)
from ..literal import BooleanExpression, FloatExpression, TimedeltaExpression

if TYPE_CHECKING:
    from ...engine.context import Context

class ComprehensionExpression(ExpressionBase):
    """
    A class for representing comprehension expressions from the grammar text. Each member produced by
    :py:attr:`.iterable` is assigned to the symbol named by :py:attr:`.variable`; members for which the optional
    :py:attr:`.condition` is truthy have :py:attr:`.result` evaluated, and the collected values are returned as a
    tuple.
    """
    result_type: _DataTypeDef = DataType.ARRAY

    def __init__(
        self,
        context: 'Context',
        result: ExpressionBase,
        variable: str,
        iterable: ExpressionBase,
        condition: ExpressionBase | None = None
    ) -> None:
        """
        :param context: The context to use for evaluating the expression.
        :type context: :py:class:`~rule_engine.engine.Context`
        :param result: The expression evaluated for each member that passes *condition*.
        :param str variable: The symbol name each member is assigned to while *condition* and *result* are evaluated.
        :param iterable: The expression whose value supplies the members to iterate over.
        :param condition: An optional filter expression; members for which it is falsy are skipped.
        """
        self.context = context
        self.result = result
        self.variable = variable
        self.iterable = iterable
        self.condition = condition
        # the output array's member type is whatever the result expression produces
        self.result_type = DataType.ARRAY(result.result_type)

    @classmethod
    def build(  # type: ignore[override]
        cls,
        context: 'Context',
        result: ExpressionBase,
        variable: str,
        iterable: ExpressionBase,
        condition: ExpressionBase | None = None
    ) -> ExpressionBase:
        """
        Build the child expressions and return the reduced comprehension.

        The iterable is built first and, when its type is known (not UNDEFINED), it must be an iterable type. The
        condition (if any) and the result are then built while *variable* is assigned with the iterable's member type
        (UNDEFINED when the resolved type exposes no ``iterable_type``).

        :raises errors.EvaluationError: If the iterable's known type is not iterable.
        """
        source = iterable.build()
        assert isinstance(source, ExpressionBase)
        source_type = source.result_type
        if source_type is not DataType.UNDEFINED and not source_type.is_iterable:
            raise errors.EvaluationError('data type mismatch (comprehension requires an iterable)')
        member_type = getattr(_resolve_type(source_type, context), 'iterable_type', DataType.UNDEFINED)
        with context.assignments(Assignment(variable, value_type=member_type)):
            if condition is not None:
                built_condition = condition.build()
                assert isinstance(built_condition, ExpressionBase)
                condition = built_condition
            built_result = result.build()
            assert isinstance(built_result, ExpressionBase)
        reduced = cls(context, built_result, variable, source, condition=condition).reduce()
        assert isinstance(reduced, ExpressionBase)
        return reduced

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__} iterable={self.iterable!r} result={self.result!r} condition={self.condition!r} >"

    def evaluate(self, thing: Any) -> Any:
        """
        Evaluate the comprehension against *thing*.

        :return: A tuple holding the result expression's value for every member that passes the condition.
        :raises errors.EvaluationError: If the iterable expression does not evaluate to an iterable value.
        """
        members = self.iterable.evaluate(thing)
        if not DataType.from_value(members).is_iterable:
            raise errors.EvaluationError('data type mismatch (comprehension requires an iterable)')
        collected: list[Any] = []
        for member in members:
            # the member is only assigned to the variable for the duration of this iteration
            with self.context.assignments(Assignment(self.variable, value=member)):
                if self.condition is not None and not self.condition.evaluate(thing):
                    continue
                collected.append(self.result.evaluate(thing))
        return tuple(collected)

    def to_graphviz(self, digraph: Any, *args: Any, **kwargs: Any) -> None:
        node_id = str(id(self))
        digraph.node(node_id, f"{self.__class__.__name__}\nvariable={self.variable!r}")
        children = [('result', self.result), ('iterable', self.iterable)]
        if self.condition is not None:
            children.append(('condition', self.condition))
        # each child is rendered before the edge that links it to this node
        for label, child in children:
            child.to_graphviz(digraph, *args, **kwargs)
            digraph.edge(node_id, str(id(child)), label=label)
class TernaryExpression(ExpressionBase):
    """
    A class for representing ternary expressions from the grammar text. These involve evaluating
    :py:attr:`.condition` before evaluating either :py:attr:`.case_true` or :py:attr:`.case_false` based on the
    results.
    """
    def __init__(self, context: 'Context', condition: ExpressionBase, case_true: ExpressionBase, case_false: ExpressionBase) -> None:
        """
        :param context: The context to use for evaluating the expression.
        :type context: :py:class:`~rule_engine.engine.Context`
        :param condition: The condition expression whose evaluation determines whether the *case_true* or
            *case_false* expression is evaluated.
        :param case_true: The expression that's evaluated when *condition* is True.
        :param case_false: The expression that's evaluated when *condition* is False.
        """
        self.context = context
        self.condition = condition
        self.case_true = case_true
        self.case_false = case_false
        # the branch types are compared with any NULLABLE wrapper removed
        true_type = DataType.NULLABLE.unwrap(case_true.result_type)
        false_type = DataType.NULLABLE.unwrap(case_false.result_type)
        if true_type == false_type:
            self.result_type = true_type
        elif DataType.is_type(true_type, DataType.ARRAY) and DataType.is_type(false_type, DataType.ARRAY):
            # both branches are arrays but not the same array type, so only the outer type is known
            self.result_type = DataType.ARRAY
        # todo: the other compound types should be checked here as well.
        # when neither branch above matched, self.result_type is still the inherited class-level value;
        # NOTE(review): _propagate_nullable presumably re-wraps the type as nullable when a branch is nullable
        self.result_type = _propagate_nullable(self.result_type, case_true.result_type, case_false.result_type)

    @classmethod
    def build(cls, context: 'Context', condition: ExpressionBase, case_true: ExpressionBase, case_false: ExpressionBase) -> ExpressionBase:  # type: ignore[override]
        """
        Build the condition, the true case and the false case (in that order) and return the reduced expression.
        """
        operands: list[ExpressionBase] = []
        for operand in (condition, case_true, case_false):
            built = operand.build()
            assert isinstance(built, ExpressionBase)
            operands.append(built)
        reduced = cls(context, *operands).reduce()
        assert isinstance(reduced, ExpressionBase)
        return reduced

    def evaluate(self, thing: Any) -> Any:
        """Evaluate the condition, then evaluate and return only the selected branch."""
        if self.condition.evaluate(thing):
            selected = self.case_true
        else:
            selected = self.case_false
        return selected.evaluate(thing)

    def reduce(self) -> ExpressionBase:
        """
        Collapse to the selected (reduced) branch when the condition is already a literal, otherwise return this
        expression unchanged.
        """
        if not _is_reduced(self.condition):
            return self
        assert isinstance(self.condition, LiteralExpressionBase)
        branch = self.case_true if self.condition.value else self.case_false
        reduced = branch.reduce()
        assert isinstance(reduced, ExpressionBase)
        return reduced

    def to_graphviz(self, digraph: Any, *args: Any, **kwargs: Any) -> None:
        super().to_graphviz(digraph, *args, **kwargs)
        branches = (
            ('condition', self.condition),
            ('true case', self.case_true),
            ('false case', self.case_false),
        )
        # all child nodes are rendered first, then the labelled edges are added
        for _, child in branches:
            child.to_graphviz(digraph, *args, **kwargs)
        node_id = str(id(self))
        for label, child in branches:
            digraph.edge(node_id, str(id(child)), label=label)
class UnaryExpression(ExpressionBase):
    """
    A class for representing unary expressions from the grammar text. These involve a single operator on the left
    side. The supported operator types are ``not`` (logical negation) and ``uminus`` (arithmetic negation).
    """
    def __init__(self, context: 'Context', type_: str, right: ExpressionBase) -> None:
        """
        :param context: The context to use for evaluating the expression.
        :type context: :py:class:`~rule_engine.engine.Context`
        :param str type_: The grammar type of operator to the left of the expression. It is lower-cased before use
            and must be ``not`` or ``uminus``.
        :param right: The expression to the right of the operator.
        :type right: :py:class:`~.ExpressionBase`
        :raises ValueError: If *type_* is not a supported operator type.
        """
        self.context = context
        self.type = operator_type = type_.lower()
        if operator_type == 'uminus':
            # NOTE(review): _assert_not_nullable presumably rejects nullable operand types
            _assert_not_nullable(right.result_type, role='unary minus operand')
            self.result_type = right.result_type
        elif operator_type == 'not':
            self.result_type = DataType.BOOLEAN
        else:
            raise ValueError('unknown unary expression type')
        # bound to one of the _op_* partial methods declared on this class
        self._evaluator = getattr(self, f'_op_{operator_type}')
        self.right = right

    @classmethod
    def build(cls, context: 'Context', type_: str, right: ExpressionBase) -> ExpressionBase:  # type: ignore[override]
        """Build the operand and return the reduced unary expression."""
        built_right = right.build()
        assert isinstance(built_right, ExpressionBase)
        reduced = cls(context, type_, built_right).reduce()
        assert isinstance(reduced, ExpressionBase)
        return reduced

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__} type={self.type!r} >"

    def evaluate(self, thing: Any) -> Any:
        """Apply the operator selected at construction time to the evaluated operand."""
        return self._evaluator(thing)

    def __op(self, op: Callable[[Any], Any], thing: Any) -> Any:
        # apply *op* to the evaluated operand with no type checking
        operand = self.right.evaluate(thing)
        return op(operand)

    _op_not = functools.partialmethod(__op, operator.not_)

    def __op_arithmetic(self, op: Callable[[Any], Any], thing: Any) -> Any:
        # apply *op* only to numeric or timedelta operands
        operand = self.right.evaluate(thing)
        if is_numeric(operand) or isinstance(operand, datetime.timedelta):
            return op(operand)
        raise errors.EvaluationError('data type mismatch (not a numeric or timedelta value)')

    _op_uminus = functools.partialmethod(__op_arithmetic, operator.neg)

    def reduce(self) -> ExpressionBase:
        """
        Fold the operator into a literal when the operand is already reduced, otherwise return this expression.

        :raises errors.EvaluationError: If ``uminus`` is applied to a reduced operand that is neither a float nor a
            timedelta literal.
        :raises errors.EngineError: If the operator type is not supported.
        """
        if not _is_reduced(self.right):
            return self
        operator_type = self.type.lower()
        if operator_type == 'not':
            return BooleanExpression(self.context, self.evaluate(None))
        if operator_type != 'uminus':
            raise errors.EngineError('unsupported unary expression type')
        # the negated literal keeps the literal class of its operand
        for literal_class in (FloatExpression, TimedeltaExpression):
            if isinstance(self.right, literal_class):
                return literal_class(self.context, self.evaluate(None))
        raise errors.EvaluationError('data type mismatch (not a float or timedelta expression)')

    def to_graphviz(self, digraph: Any, *args: Any, **kwargs: Any) -> None:
        node_id = str(id(self))
        digraph.node(node_id, f"{self.__class__.__name__}\ntype={self.type.lower()!r}")
        self.right.to_graphviz(digraph, *args, **kwargs)
        digraph.edge(node_id, str(id(self.right)))