编写自定义函数

本指南将介绍如何在 SymPy 中创建自定义函数类。自定义用户定义函数使用与 SymPy 中包含的 函数 相同的机制,例如常见的 基本函数,如 exp()sin()特殊函数,如 gamma()Si(),以及 组合函数数论函数,如 factorial()primepi()。因此,本指南既是希望定义自己自定义函数的最终用户的指南,也是希望扩展 SymPy 中包含的函数的 SymPy 开发人员的指南。

本指南介绍了如何定义复值函数,即将 \(\mathbb{C}^n\) 的一个子集映射到 \(\mathbb{C}\) 的函数。接受或返回除复数以外的其他类型的对象的函数应该子类化另一个类,例如 BooleanMatrixExprExprBasic。这里写的一些内容将适用于一般的 BasicExpr 子类,但其中大部分内容只适用于 Function 子类。

简单情况:完全符号化或完全评估

在深入研究自定义函数的更高级功能之前,我们应该提到两种常见情况,一种是函数完全符号化,另一种是函数完全评估。这两种情况都有比本指南中描述的完整机制简单得多的替代方案。

完全符号化的情况

如果您的函数 f 没有您想定义的数学属性,并且不应该对任何参数进行评估,则可以使用 Function('f') 创建一个未定义的函数

>>> from sympy import symbols, Function
>>> x = symbols('x')
>>> f = Function('f')
>>> f(x)
f(x)
>>> f(0)
f(0)

例如,这在求解 ODE 时很有用。

这在您只想创建一个依赖于另一个符号的符号以用于微分时也很有用。默认情况下,SymPy 假设所有符号都是相互独立的

>>> from sympy.abc import x, y
>>> y.diff(x)
0

要创建一个依赖于另一个符号的符号,您可以使用一个显式依赖于该符号的函数。

>>> y = Function('y')
>>> y(x).diff(x)
Derivative(y(x), x)

如果您希望您的函数具有其他行为,例如具有自定义导数,或在某些参数上进行评估,则应该创建一个自定义的 Function 子类,如 下面所述。但是,未定义的函数确实支持一项附加功能,即可以使用与符号相同的语法在其上定义假设。这定义了函数输出的假设,而不是输入(也就是说,它定义了函数的范围,而不是它的域)。

>>> g = Function('g', real=True)
>>> g(x)
g(x)
>>> g(x).is_real
True

要使函数的假设以某种方式依赖于其输入,您应该创建一个自定义的 Function 子类,并定义假设处理程序,如 下面所述

完全评估的情况

另一方面,有些函数无论输入是什么,都会始终评估为某个值。这些函数永远不会以未评估的符号形式保留下来,例如 f(x)

在这种情况下,您应该使用使用 def 关键字的普通 Python 函数

>>> def f(x):
...     if x == 0:
...         return 0
...     else:
...         return x + 1
>>> f(0)
0
>>> f(1)
2
>>> f(x)
x + 1

如果您发现自己在 Function 子类上定义了一个 eval() 方法,您始终返回一个值,并且从未返回 None,您应该考虑只使用一个普通的 Python 函数,因为在这种情况下使用符号 Function 子类没有任何好处(请参阅下面关于 eval() 的最佳实践 部分)。

请注意,在许多情况下,此类函数可以直接使用 SymPy 类表示。例如,上面的函数可以使用 Piecewise 符号表示。可以使用 subs()x 的特定值评估 Piecewise 表达式。

>>> from sympy import Piecewise, Eq, pprint
>>> f = Piecewise((0, Eq(x, 0)), (x + 1, True))
>>> pprint(f, use_unicode=True)
⎧  0    for x = 0

⎩x + 1  otherwise
>>> f.subs(x, 0)
0
>>> f.subs(x, 1)
2

Piecewise 这样的完全符号表示具有准确表示符号值的优点。例如,在上面 Python def 定义的 f 中,f(x) 隐式地假设 x 不为零。Piecewise 版本正确处理这种情况,并且不会评估到 \(x \neq 0\) 的情况,除非已知 x 不为零。

另一个选择,如果您希望一个函数不仅评估,而且始终评估为一个数值,那么可以使用 lambdify()。这会将 SymPy 表达式转换为可以使用 NumPy 评估的函数。

>>> from sympy import lambdify
>>> func = lambdify(x, Piecewise((0, Eq(x, 0)), (x + 1, True)))
>>> import numpy as np 
>>> func(np.arange(5)) 
array([0., 2., 3., 4., 5.])

最终,正确的工具取决于您正在做什么以及您想要的确切行为。

创建自定义函数

创建自定义函数的第一步是子类化 Function。子类的名称将是函数的名称。然后应该在这个子类上定义不同的方法,具体取决于您想要提供的功能。

作为本文档的示例,让我们创建一个表示 正矢函数 的自定义函数类。正矢是一个三角函数,历史上与正弦和余弦等一些更熟悉的三角函数一起使用。它在今天很少使用。正矢可以通过恒等式定义

\[\operatorname{versin}(x) = 1 - \cos(x).\]

SymPy 没有包含正矢,因为它在现代数学中很少使用,而且它很容易用更熟悉的余弦来定义。

让我们从子类化 Function 开始。

>>> class versin(Function):
...     pass

此时,versin 没有在其上定义任何行为。它与我们上面讨论的 未定义函数 非常相似。请注意,versin 是一个类,而 versin(x) 是该类的实例。

>>> versin(x)
versin(x)
>>> isinstance(versin(x), versin)
True

注意

下面描述的所有方法都是可选的。如果要定义给定行为,可以将其包含在内,但如果省略,SymPy 将默认保留未评估状态。例如,如果您没有定义 微分diff() 将只返回一个未评估的 Derivative

使用 eval() 定义自动评估

在我们的自定义函数上,我们可能想定义的第一件事,也是最常见的事,就是自动评估,也就是说,它将在某些情况下返回一个实际值,而不是仅仅保留为未评估状态。

这可以通过定义类方法 eval() 来完成。eval() 应该接受函数的参数,并返回一个值或 None。如果它返回 None,该函数将保持在这种情况下未评估状态。这也用于定义函数的签名(默认情况下,如果没有 eval() 方法,Function 子类将接受任意数量的参数)。

对于我们的函数 versin,我们可能回想起对于整数 \(n\)\(\cos(n\pi) = (-1)^n\),因此 \(\operatorname{versin}(n\pi) = 1 - (-1)^n.\) 我们可以使 versin 在传递给它一个 pi 的整数倍数时自动评估为该值

>>> from sympy import pi, Integer
>>> class versin(Function):
...    @classmethod
...    def eval(cls, x):
...        # If x is an integer multiple of pi, x/pi will cancel and be an Integer
...        n = x/pi
...        if isinstance(n, Integer):
...            return 1 - (-1)**n
>>> versin(pi)
2
>>> versin(2*pi)
0

这里利用了这样一个事实:如果 Python 函数没有显式地返回值,它会自动返回 None。因此,在 if isinstance(n, Integer) 语句没有被触发的情况下,eval() 会返回 None,而 versin 则保持未计算。

>>> versin(x*pi)
versin(pi*x)

注意

Function 子类不应该重新定义 __new____init__。如果你想要实现 eval() 无法实现的行为,那么对 Expr 子类化可能比 Function 更合理。

eval() 可以接受任意数量的参数,包括使用 *args 的任意数量的参数和可选的关键字参数。函数的 .args 将始终是用户传入的参数。例如

>>> class f(Function):
...     @classmethod
...     def eval(cls, x, y=1, *args):
...         return None
>>> f(1).args
(1,)
>>> f(1, 2).args
(1, 2)
>>> f(1, 2, 3).args
(1, 2, 3)

最后,请注意,一旦 evalf() 被定义,对浮点数输入的自动计算就会自动发生,因此你不需要在 eval() 中显式处理它。

eval() 的最佳实践

在定义 eval() 方法时,某些反模式很常见,应该避免。

  • 不要只返回一个表达式。

    在上面的例子中,我们可能会想写

    >>> from sympy import cos
    >>> class versin(Function):
    ...     @classmethod
    ...     def eval(cls, x):
    ...         # !! Not actually a good eval() method !!
    ...         return 1 - cos(x)
    

    但是,这会导致 versin(x) 始终 返回 1 - cos(x),而不管 x 是什么。如果你只需要一个对 1 - cos(x) 的快速简写,那就可以了,但使用上面提到的 Python 函数会更简单、更明确 使用上面提到的 Python 函数。如果我们像这样定义 versin,它永远不会实际表示为 versin(x),我们下面定义的其他行为也不会起作用,因为我们将在 versin 类上定义的其他行为仅在返回的对象实际上是 versin 实例时才会应用。因此,例如,versin(x).diff(x) 实际上就是 (1 - cos(x)).diff(x),而不是调用 我们下面定义的 fdiff() 方法

    关键点

    eval() 的目的不是定义函数是什么,而是指定在哪些输入上它应该自动计算。函数的数学定义是通过下面概述的方法来指定各种数学属性来确定的,例如 数值计算微分等等。

    如果你发现自己正在这样做,你应该考虑你真正想要实现什么。如果你只需要一个表达式的简写函数,那么使用 定义一个 Python 函数 会更简单。如果你真的想要一个符号函数,考虑一下你想要它在什么时候计算为其他东西,以及你想要它在什么时候保持未计算。一种选择是在 eval() 中使你的函数保持未计算,并定义一个 doit() 方法 来计算它。

  • 避免过多的自动计算。

    建议尽量减少 eval() 自动计算的内容。通常将更高级的简化放在 其他方法 中,例如 doit()。请记住,你为自动计算定义的任何内容都将始终计算 [1]。如前一点所述,如果你计算了所有值,那么使用符号函数本身就没有太大意义。例如,我们可能会想在 eval() 中计算 versin 上的某些三角恒等式,但这样这些恒等式就会始终计算,而无法表示恒等式的其中一半。

    还应该避免在 eval() 中执行任何计算速度慢的操作。SymPy 通常假设创建表达式是廉价的,如果这不是真的,它会导致性能问题。

    最后,建议避免在 eval() 中基于假设执行自动计算。相反,eval() 通常只应该计算显式的数值特殊值,并对其他所有内容返回 None。你可能已经注意到,在 上面的例子 中,我们使用了 isinstance(n, Integer) 而不是使用假设系统检查 n.is_integer。我们本来可以那样做,这会使得 versin(n*pi) 即使在 n = Symbol('n', integer=True) 时也会计算。但这是一种我们可能并不总是希望进行计算的情况,如果 n 是一个更复杂的表达式,那么 n.is_integer 的计算成本可能更高。

    让我们考虑一个例子。利用恒等式 \(\cos(x + y) = \cos(x)\cos(y) - \sin(x)\sin(y)\),我们可以推导出恒等式

    \[\operatorname{versin}(x + y) = \operatorname{versin}(x)\operatorname{versin}(y) - \operatorname{versin}(x) - \operatorname{versin}(y) - \sin(x)\sin(y) + 1.\]

    假设我们决定在 eval() 中自动展开它

    >>> from sympy import Add, sin
    >>> class versin(Function):
    ...     @classmethod
    ...     def eval(cls, x):
    ...         # !! Not actually a good eval() method !!
    ...         if isinstance(x, Add):
    ...             a, b = x.as_two_terms()
    ...             return (versin(a)*versin(b) - versin(a) - versin(b)
    ...                     - sin(a)*sin(b) + 1)
    

    此方法递归地将 Add 项拆分为两部分,并应用上述恒等式。

    >>> x, y, z = symbols('x y z')
    >>> versin(x + y)
    -sin(x)*sin(y) + versin(x)*versin(y) - versin(x) - versin(y) + 1
    

    但现在,无法表示 versin(x + y) 而不将其展开。这也会影响其他方法。例如,假设我们定义了 微分(见下文)

    >>> class versin(Function):
    ...     @classmethod
    ...     def eval(cls, x):
    ...         # !! Not actually a good eval() method !!
    ...         if isinstance(x, Add):
    ...             a, b = x.as_two_terms()
    ...             return (versin(a)*versin(b) - versin(a) - versin(b)
    ...                     - sin(a)*sin(b) + 1)
    ...
    ...     def fdiff(self, argindex=1):
    ...         return sin(self.args[0])
    

    我们预计 versin(x + y).diff(x) 会返回 sin(x + y),事实上,如果我们没有在 eval() 中展开这个恒等式,它会。但是在这个版本中,versin(x + y) 在调用 diff() 之前被自动展开,因此我们得到一个更复杂的表达式

    >>> versin(x + y).diff(x)
    sin(x)*versin(y) - sin(x) - sin(y)*cos(x)
    

    情况比这更糟。让我们尝试一个包含三项的 Add

    >>> versin(x + y + z)
    (-sin(y)*sin(z) + versin(y)*versin(z) - versin(y) - versin(z) +
    1)*versin(x) - sin(x)*sin(y + z) + sin(y)*sin(z) - versin(x) -
    versin(y)*versin(z) + versin(y) + versin(z)
    

    我们可以看到,事情很快就会失控。实际上,versin(Add(*symbols('x:100')))(包含 100 项的 versin())需要一秒钟以上才能计算,而这仅仅是创建表达式的时间,甚至还没有对它进行任何操作。

    这样的恒等式最好放在 eval 之外,并在其他方法中实现(对于这个恒等式,expand_trig())。

  • 在限制输入域时:允许 None 输入假设。

    我们的示例函数 \(\operatorname{versin}(x)\) 是从 \(\mathbb{C}\)\(\mathbb{C}\) 的函数,因此它可以接受任何输入。但是假设我们有一个只有在特定输入下才有意义的函数。作为第二个例子,让我们定义一个函数 divides 如下

    \[\begin{split}\operatorname{divides}(m, n) = \begin{cases} 1 & \text{for}\: m \mid n \\ 0 & \text{for}\: m\not\mid n \end{cases}.\end{split}\]

    也就是说,divides(m, n) 如果 m 能被 n 整除,则为 1,否则为 0。显然,divides 只有在 mn 是整数的情况下才有意义。

    我们可能会想这样定义 divideseval() 方法

    >>> class divides(Function):
    ...     @classmethod
    ...     def eval(cls, m, n):
    ...         # !! Not actually a good eval() method !!
    ...
    ...         # Evaluate for explicit integer m and n. This part is fine.
    ...         if isinstance(m, Integer) and isinstance(n, Integer):
    ...             return int(n % m == 0)
    ...
    ...         # For symbolic arguments, require m and n to be integer.
    ...         # If we write the logic this way, we will run into trouble.
    ...         if not m.is_integer or not n.is_integer:
    ...             raise TypeError("m and n should be integers")
    

    这里的问题是,通过使用 if not m.is_integer,我们要求 m.is_integerTrue。如果它为 None,它会失败(有关假设为 None 的含义,请参阅 布尔值和三值逻辑指南)。这存在两个问题。首先,它强制用户为任何输入变量定义假设。如果用户省略它们,它会失败

    >>> n, m = symbols('n m')
    >>> print(n.is_integer)
    None
    >>> divides(m, n)
    Traceback (most recent call last):
    ...
    TypeError: m and n should be integers
    

    相反,他们必须写

    >>> n, m = symbols('n m', integer=True)
    >>> divides(m, n)
    divides(m, n)
    

    这看起来可能是一个可以接受的限制,但还有一个更大的问题。有时,SymPy 的假设系统无法推导出一个假设,即使它在数学上是正确的。在这种情况下,它会给出 None(在 SymPy 的假设中,None 既表示“未定义”,也表示“无法计算”)。例如

    >>> # n and m are still defined as integer=True as above
    >>> divides(2, (m**2 + m)/2)
    Traceback (most recent call last):
    ...
    TypeError: m and n should be integers
    

    这里的表达式 (m**2 + m)/2 始终是整数,但 SymPy 的假设系统无法推导出这一点

    >>> print(((m**2 + m)/2).is_integer)
    None
    

    SymPy 的假设系统一直在改进,但由于问题的基本计算复杂性以及一般问题通常不可判定,总会存在它无法推断出的情况。

    因此,应始终对输入变量测试否定的假设,也就是说,如果假设为False则失败,但允许假设为None

    >>> class divides(Function):
    ...     @classmethod
    ...     def eval(cls, m, n):
    ...         # Evaluate for explicit integer m and n. This part is fine.
    ...         if isinstance(m, Integer) and isinstance(n, Integer):
    ...             return int(n % m == 0)
    ...
    ...         # For symbolic arguments, require m and n to be integer.
    ...         # This is the better way to write this logic.
    ...         if m.is_integer is False or n.is_integer is False:
    ...             raise TypeError("m and n should be integers")
    

    这仍然如预期那样禁止非整数输入

    >>> divides(1.5, 1)
    Traceback (most recent call last):
    ...
    TypeError: m and n should be integers
    

    但它不会在假设为None的情况下失败。

    >>> divides(2, (m**2 + m)/2)
    divides(2, m**2/2 + m/2)
    >>> _.subs(m, 2)
    0
    >>> n, m = symbols('n m') # Redefine n and m without the integer assumption
    >>> divides(m, n)
    divides(m, n)
    

    注意

    仅当会引发异常时,例如对输入域进行类型检查,才允许None假设。在进行简化或其他操作的情况下,应将None假设视为“可以是TrueFalse”,并且不执行可能在数学上无效的操作。

假设

您可能要定义的下一件事是对函数的假设。假设系统允许您定义给定输入时函数具有的数学属性,例如,“\(f(x)\)\(x\)实数时为”。

有关假设系统的指南详细介绍了假设系统。建议您先阅读该指南以了解不同假设的含义以及假设系统的运作方式。

最简单的情况是函数始终具有给定的假设,而与输入无关。在这种情况下,您可以在类上直接定义is_assumption

例如,我们的示例divides函数始终为整数,因为它的值始终为 0 或 1

>>> class divides(Function):
...     is_integer = True
...     is_negative = False
>>> divides(m, n).is_integer
True
>>> divides(m, n).is_nonnegative
True

但是,通常情况下,函数的假设取决于其输入的假设。在这种情况下,您应该定义一个_eval_assumption方法。

对于我们的\(\operatorname{versin}(x)\)示例,当 \(x\)为实数时,该函数始终位于 \([0, 2]\) 中,当 \(x\)\(\pi\) 的偶数倍时,该函数正好为 0。因此,versin(x)应在x实数时为非负,并在x实数且不是 π 的偶数倍时为。请记住,默认情况下,函数的域为整个 \(\mathbb{C}\),并且实际上,versin(x)在非实数x的情况下仍然有意义。

要查看x是否为pi的偶数倍,我们可以使用as_independent()按结构将x匹配为coeff*pi。在假设处理程序中以这种结构方式拆分子表达式比使用类似(x/pi).is_even的方法更可取,因为这将创建一个新的表达式x/pi。创建新表达式会慢得多。此外,每当创建表达式时,在创建表达式时调用的构造函数本身通常会导致查询假设。如果您不小心,这会导致无限递归。因此,假设处理程序的一个好的通用规则是,切勿在假设处理程序中创建新表达式。始终使用结构化方法(如as_independent)拆分函数的参数。

请注意,\(\operatorname{versin}(x)\)对于非实数 \(x\)可以是非负数,例如

>>> from sympy import I
>>> 1 - cos(pi + I*pi)
1 + cosh(pi)
>>> (1 - cos(pi + I*pi)).evalf()
12.5919532755215

因此,对于_eval_is_nonnegative处理程序,如果x.is_realTrue,则我们希望返回True,但如果x.is_realFalseNone,则返回None。使用类似于_eval_is_positive处理程序中的逻辑,留待读者练习处理使versin(x)非负的非实数x的情况。

在假设处理程序方法中,与所有方法一样,我们可以使用self.args访问函数的参数。

>>> from sympy.core.logic import fuzzy_and, fuzzy_not
>>> class versin(Function):
...     def _eval_is_nonnegative(self):
...         # versin(x) is nonnegative if x is real
...         x = self.args[0]
...         if x.is_real is True:
...             return True
...
...     def _eval_is_positive(self):
...         # versin(x) is positive if x is real and not an even multiple of pi
...         x = self.args[0]
...
...         # x.as_independent(pi, as_Add=False) will split x as a Mul of the
...         # form coeff*pi
...         coeff, pi_ = x.as_independent(pi, as_Add=False)
...         # If pi_ = pi, x = coeff*pi. Otherwise x is not (structurally) of
...         # the form coeff*pi.
...         if pi_ == pi:
...             return fuzzy_and([x.is_real, fuzzy_not(coeff.is_even)])
...         elif x.is_real is False:
...             return False
...         # else: return None. We do not know for sure whether x is an even
...         # multiple of pi
>>> versin(1).is_nonnegative
True
>>> versin(2*pi).is_positive
False
>>> versin(3*pi).is_positive
True

请注意,在更复杂的_eval_is_positive()处理程序中使用fuzzy_函数,以及对if/elif的谨慎处理。在使用假设时,始终要谨慎处理三值逻辑至关重要。这确保当x.is_realcoeff.is_evenNone时,该方法返回正确的结果。

警告

切勿将is_assumption定义为@property方法。这样做会导致自动推断其他假设出现故障。is_assumption仅应定义为等于TrueFalse的类变量。如果假设以某种方式取决于函数的.args,请定义_eval_assumption方法。

在本例中,不需要定义_eval_is_real(),因为它会根据其他假设自动推断出来,因为nonnegative -> real。通常情况下,您应该避免定义假设系统可以根据其已知事实自动推断出的假设。

>>> versin(1).is_real
True

假设系统通常能够推断出您可能没有想到的更多内容。例如,从上面可以推断出,当n为整数时,versin(2*n*pi)为零。

>>> n = symbols('n', integer=True)
>>> versin(2*n*pi).is_zero
True

在手动编码之前,始终值得检查假设系统是否可以自动推断出某些内容。

最后,提醒一句:在编码假设时要格外小心。确保使用各种假设的准确定义,并始终检查您是否使用模糊的三值逻辑函数正确处理了None情况。假设不正确或不一致会导致难以察觉的错误。建议您使用单元测试检查所有各种情况,尤其是在您的函数具有非平凡的假设处理程序时。SymPy 本身定义的所有函数都要求进行广泛的测试。

使用evalf()进行数值评估

这里我们将展示如何定义函数应如何数值评估为浮点Float值,例如,通过evalf()。实现数值评估会在 SymPy 中启用几种行为。例如,一旦定义了evalf(),您就可以绘制函数,并且不等式之类的内容会评估为显式值。

如果您的函数与 mpmath中的函数同名,大多数包含在 SymPy 中的函数都是这种情况,则数值评估会自动发生,您无需执行任何操作。

如果不是这种情况,可以通过定义方法_eval_evalf(self, prec)来指定数值评估,其中prec是输入的二进制精度。该方法应返回以给定精度评估的表达式,如果无法实现,则返回None

注意

传递给_eval_evalf()prec参数是二进制精度,即浮点表示中的位数。这与evalf()方法的第一个参数不同,第一个参数是十进制精度,或dps。例如,Float的默认二进制精度为 53,对应于十进制精度为 15。因此,如果您的_eval_evalf()方法递归地调用其他表达式的 evalf,则应调用expr._eval_evalf(prec)而不是expr.evalf(prec),因为后者会错误地将prec用作十进制精度。

我们可以通过递归计算 \(2\sin^2\left(\frac{x}{2}\right)\) 来定义 我们的示例 \(\operatorname{versin}(x)\) 函数 的数值评估,这是一种数值上更稳定的写法,等同于 \(1 - \cos(x)\).

>>> from sympy import sin
>>> class versin(Function):
...     def _eval_evalf(self, prec):
...         return (2*sin(self.args[0]/2)**2)._eval_evalf(prec)
>>> versin(1).evalf()
0.459697694131860

一旦定义了 _eval_evalf(),就可以自动评估浮点输入。无需在 eval() 中手动实现。

>>> versin(1.)
0.459697694131860

请注意,evalf() 可以传递任何表达式,而不仅仅是可以数值评估的表达式。在这种情况下,预计表达式中的数值部分将被评估。一个通用的模式是递归调用 _eval_evalf(prec) 来评估函数的参数。

尽可能地,最好重用 SymPy 函数中定义的 evalf 功能。但是,在某些情况下,需要直接使用 mpmath。

改写和简化

各种简化函数和方法允许指定它们对自定义子类的行为。并非 SymPy 中的每个函数都有这样的钩子。有关详细信息,请参阅每个函数的文档。

rewrite()

使用 rewrite() 方法可以将表达式改写成特定函数或规则的形式。例如,

>>> sin(x).rewrite(cos)
cos(x - pi/2)

要实现改写,请定义一个方法 _eval_rewrite(self, rule, args, **hints),其中

  • rule 是传递给 rewrite() 方法的规则。通常情况下,rule 将是待改写对象的类,但对于更复杂的改写,它可以是任何东西。每个定义了 _eval_rewrite() 的对象都定义了它支持的规则。许多 SymPy 函数会改写为常见的类,例如 expr.rewrite(Add),用于执行简化或其他计算。

  • args 是用于改写的函数参数。应该使用它而不是 self.args,因为 args 中的任何递归表达式都将在 args 中被改写(假设调用者使用了 rewrite(deep=True),这是默认设置)。

  • **hints 是额外的关键字参数,可用于指定改写的行为。未知提示应该被忽略,因为它们可能会传递给其他 _eval_rewrite() 方法。如果递归调用 rewrite,则应将 **hints 传递下去。

该方法应该返回一个使用 args 作为函数参数的改写表达式,或者返回 None 表示表达式应该保持不变。

对于我们的 versin 示例,我们可以实现一个明显的改写,即将 versin(x) 改写成 1 - cos(x)

>>> class versin(Function):
...     def _eval_rewrite(self, rule, args, **hints):
...         if rule == cos:
...             return 1 - cos(*args)
>>> versin(x).rewrite(cos)
1 - cos(x)

定义好这个之后,simplify() 现在能够简化包含 versin 的一些表达式

>>> from sympy import simplify
>>> simplify(versin(x) + cos(x))
1

doit()

使用 doit() 方法来评估“未评估”的函数。要定义 doit(),请实现 doit(self, deep=True, **hints)。如果 deep=Truedoit() 应该递归调用 doit() 来评估参数。 **hints 将是传递给用户的任何其他关键字参数,应该将它们传递给 doit() 的任何递归调用。可以使用 hints 来允许用户指定 doit() 的特定行为。

在自定义 Function 子类中,doit() 的典型用法是执行更高级别的评估,这些评估在 eval() 中没有执行。

例如,对于我们的 divides 示例,可以通过一些恒等式来简化多个实例。例如,我们定义了 eval() 来评估显式整数,但我们可能还想评估像 divides(k, k*n) 这样的示例,其中可除性在符号上是正确的。 eval() 的最佳实践 之一是避免过多的自动评估。在这种情况下,自动评估可能被认为过多了,因为它会使用假设系统,这可能会很昂贵。此外,我们可能希望能够表示 divides(k, k*n) 而不总是评估它。

解决方案是在 doit() 中实现这些更高级别的评估。这样,我们可以通过调用 expr.doit() 来显式地执行它们,但它们不会默认发生。一个针对 dividesdoit() 示例,它执行这个简化(以及 eval() 的上述定义)可能如下所示

注意

如果 doit() 返回一个 Python int 字面量,请将其转换为一个 Integer,以确保返回的对象是 SymPy 类型。

>>> from sympy import Integer
>>> class divides(Function):
...     # Define evaluation on basic inputs, as well as type checking that the
...     # inputs are not nonintegral.
...     @classmethod
...     def eval(cls, m, n):
...         # Evaluate for explicit integer m and n.
...         if isinstance(m, Integer) and isinstance(n, Integer):
...             return int(n % m == 0)
...
...         # For symbolic arguments, require m and n to be integer.
...         if m.is_integer is False or n.is_integer is False:
...             raise TypeError("m and n should be integers")
...
...     # Define doit() as further evaluation on symbolic arguments using
...     # assumptions.
...     def doit(self, deep=False, **hints):
...         m, n = self.args
...         # Recursively call doit() on the args whenever deep=True.
...         # Be sure to pass deep=True and **hints through here.
...         if deep:
...            m, n = m.doit(deep=deep, **hints), n.doit(deep=deep, **hints)
...
...         # divides(m, n) is 1 iff n/m is an integer. Note that m and n are
...         # already assumed to be integers because of the logic in eval().
...         isint = (n/m).is_integer
...         if isint is True:
...             return Integer(1)
...         elif isint is False:
...             return Integer(0)
...         else:
...             return divides(m, n)

(请注意,这使用的是 约定,即对所有 \(k\)\(k \mid 0\),因此我们不需要检查 mn 是否为非零。如果使用不同的约定,我们需要在执行简化之前检查 m.is_zeron.is_zero。)

>>> n, m, k = symbols('n m k', integer=True)
>>> divides(k, k*n)
divides(k, k*n)
>>> divides(k, k*n).doit()
1

实现 doit() 的另一种常见方法是让它始终返回另一个表达式。这实际上将该函数视为另一个表达式的“未评估”形式。

例如,让我们为 融合乘加 定义一个函数:\(\operatorname{FMA}(x, y, z) = xy + z\)。将此函数表示为一个不同的函数可能很有用,例如,出于代码生成的考虑,但在某些情况下,将 FMA(x, y, z) “评估”为 x*y + z 也很有用,以便它能够与其他表达式正确简化。

>>> from sympy import Number
>>> class FMA(Function):
...     """
...     FMA(x, y, z) = x*y + z
...     """
...     @classmethod
...     def eval(cls, x, y, z):
...         # Number is the base class of Integer, Rational, and Float
...         if all(isinstance(i, Number) for i in [x, y, z]):
...            return x*y + z
...
...     def doit(self, deep=True, **hints):
...         x, y, z = self.args
...         # Recursively call doit() on the args whenever deep=True.
...         # Be sure to pass deep=True and **hints through here.
...         if deep:
...             x = x.doit(deep=deep, **hints)
...             y = y.doit(deep=deep, **hints)
...             z = z.doit(deep=deep, **hints)
...         return x*y + z
>>> x, y, z = symbols('x y z')
>>> FMA(x, y, z)
FMA(x, y, z)
>>> FMA(x, y, z).doit()
x*y + z

大多数自定义函数不会以这种方式定义 doit()。但是,这可以在始终评估的函数和从不评估的函数之间提供一个折衷方案,生成一个默认情况下不评估但可以按需评估的函数(参见 上面的讨论)。

expand()

函数 expand() 以各种方式“展开”表达式。它实际上是几个子展开提示的包装器。每个函数对应于 expand() 函数/方法的一个提示。可以通过定义 _eval_expand_hint(self, **hints) 来在自定义函数中定义一个特定的展开提示。有关已定义的提示以及每个特定 expand_hint() 函数的文档(例如,expand_trig())的详细信息,请参阅 expand() 的文档。

关键字参数 **hints 是可以传递给扩展函数以指定附加行为的附加提示(这些提示独立于上一段中描述的预定义 *hints*)。未知提示应被忽略,因为它们可能适用于其他函数的自定义 expand() 方法。一个常见的提示是定义 force,其中 force=True 将强制执行扩展,而该扩展对于所有给定的输入假设可能在数学上无效。例如,expand_log(log(x*y), force=True) 生成 log(x) + log(y),即使对于所有复数 xy,此恒等式都不成立(通常 force=False 是默认值)。

请注意,expand() 使用其自身的 deep 标志自动处理递归扩展表达式,因此 _eval_expand_* 方法不应递归地调用函数参数的扩展。

对于我们的 versin 示例,我们可以通过定义一个 _eval_expand_trig 方法来定义基本的 trig 扩展,该方法递归地调用 expand_trig()1 - cos(x)

>>> from sympy import expand_trig
>>> y = symbols('y')
>>> class versin(Function):
...    def _eval_expand_trig(self, **hints):
...        x = self.args[0]
...        return expand_trig(1 - cos(x))
>>> versin(x + y).expand(trig=True)
sin(x)*sin(y) - cos(x)*cos(y) + 1

更复杂的实现可能会尝试将 expand_trig(1 - cos(x)) 的结果重写回 versin 函数。这留给读者作为练习。

微分

要通过 diff() 定义微分,请定义一个方法 fdiff(self, argindex)fdiff() 应返回函数的导数,不考虑链式法则,相对于第 argindex 个变量。 argindex 的索引从 1 开始。

也就是说,f(x1, ..., xi, ..., xn).fdiff(i) 应该返回 \(\frac{d}{d x_i} f(x_1, \ldots, x_i, \ldots, x_n)\),其中 \(x_k\) 彼此独立。 diff() 将使用 fdiff() 的结果自动应用链式法则。用户代码应该使用 diff() 而不是直接调用 fdiff()

注意

Function 子类应使用 fdiff() 定义微分。不是 Function 子类的 Expr 子类将需要定义 _eval_derivative()。不建议在 Function 子类上重新定义 _eval_derivative()

对于我们的 \(\operatorname{versin}\) 示例函数,导数为 \(\sin(x)\)

>>> class versin(Function):
...     def fdiff(self, argindex=1):
...         # argindex indexes the args, starting at 1
...         return sin(self.args[0])
>>> versin(x).diff(x)
sin(x)
>>> versin(x**2).diff(x)
2*x*sin(x**2)
>>> versin(x + y).diff(x)
sin(x + y)

作为具有多个参数的函数的示例,请考虑上面定义的 融合乘加 (FMA) 示例\(\operatorname{FMA}(x, y, z) = xy + z\))。

我们有

\[\frac{d}{dx} \operatorname{FMA}(x, y, z) = y,\]
\[\frac{d}{dy} \operatorname{FMA}(x, y, z) = x,\]
\[\frac{d}{dz} \operatorname{FMA}(x, y, z) = 1.\]

因此 FMAfdiff() 方法将如下所示

>>> from sympy import Number, symbols
>>> x, y, z = symbols('x y z')
>>> class FMA(Function):
...     """
...     FMA(x, y, z) = x*y + z
...     """
...     def fdiff(self, argindex):
...         # argindex indexes the args, starting at 1
...         x, y, z = self.args
...         if argindex == 1:
...             return y
...         elif argindex == 2:
...             return x
...         elif argindex == 3:
...             return 1
>>> FMA(x, y, z).diff(x)
y
>>> FMA(x, y, z).diff(y)
x
>>> FMA(x, y, z).diff(z)
1
>>> FMA(x**2, x + 1, y).diff(x)
x**2 + 2*x*(x + 1)

要使导数不求值,请引发 sympy.core.function.ArgumentIndexError(self, argindex)。如果未定义 fdiff(),这是默认行为。以下是一个示例函数 \(f(x, y)\),该函数在第一个参数中是线性的,并且在第二个参数上具有未求值的导数。

>>> from sympy.core.function import ArgumentIndexError
>>> class f(Function):
...    @classmethod
...    def eval(cls, x, y):
...        pass
...
...    def fdiff(self, argindex):
...        if argindex == 1:
...           return 1
...        raise ArgumentIndexError(self, argindex)
>>> f(x, y).diff(x)
1
>>> f(x, y).diff(y)
Derivative(f(x, y), y)

打印

您可以使用各种 打印机(例如 字符串 打印机漂亮 打印机LaTeX 打印机,以及各种语言(如 CFortran)的代码打印机)来定义函数如何打印自身。

在大多数情况下,您不需要定义任何打印方法。默认行为是使用函数名称打印函数。但是,在某些情况下,我们可能希望为函数定义特殊的打印。

例如,对于我们的 上面的 divides 示例,我们可能希望 LaTeX 打印机打印更多数学表达式。让我们让 LaTeX 打印机将 divides(m, n) 表示为 \left [ m \middle | n \right ],它看起来像 \(\left [ m \middle | n \right ]\)(这里 \( [P]\)艾弗森括号,如果 \(P\) 为真,则为 \(1\),如果 \(P\) 为假,则为 \(0\))。

有两种主要方法可以为 SymPy 对象定义打印。一种是在打印机类上定义打印机。SymPy 库中大多数类都应使用这种方法,通过在 sympy.printing 中的相应类上定义打印机。对于用户代码,如果您正在定义自定义打印机,或者您有许多希望为其定义打印的自定义函数,这可能更可取。请参阅 自定义打印机示例,了解如何以这种方式定义打印机的示例。

另一种方法是在函数类上将打印定义为方法。为此,首先查找您要为其定义打印的打印机的 printmethod 属性。这是您应该为该打印机定义的方法的名称。对于 LaTeX 打印机,LatexPrinter.printmethod'_latex'。打印方法始终接受一个参数 printerprinter._print 应用于递归地打印任何其他表达式,包括函数的参数。

因此,要定义我们的 divides LaTeX 打印机,我们将定义类上的函数 _latex(self, printer),如下所示

>>> from sympy import latex
>>> class divides(Function):
...     def _latex(self, printer):
...         m, n = self.args
...         _m, _n = printer._print(m), printer._print(n)
...         return r'\left [ %s \middle | %s \right ]' % (_m, _n)
>>> print(latex(divides(m, n)))
\left [ m \middle | n \right ]

请参阅 自定义打印方法示例,了解有关如何定义打印机方法以及一些需要避免的陷阱的更多详细信息。最重要的是,您应该始终使用 printer._print() 在自定义打印机内递归地打印函数的参数。

其他方法

可以在自定义函数上定义其他几个方法来指定各种行为。

inverse()

可以定义 inverse(self, argindex=1) 方法来指定函数的逆函数。这由 solve()solveset() 使用。 argindex 参数是函数的参数,从 1 开始(类似于 fdiff() 方法 的相同参数名称)。

inverse() 应该返回一个函数(而不是一个表达式)来表示逆函数。如果逆函数比单个函数更大,则可以返回一个 lambda 函数。

inverse() 应该只定义在单调函数上。换句话说,f(x).inverse()f(x)左逆。在非单调函数上定义 inverse() 可能会导致 solve() 无法给出包含该函数的表达式的所有可能解。

我们的 versine 示例函数 不是单调的(因为余弦不是),但它的逆函数 \(\operatorname{arcversin}\) 是。我们可以将其定义如下(使用与 SymPy 中其他反三角函数相同的命名约定)

>>> class aversin(Function):
...     def inverse(self, argindex=1):
...         return versin

这使得 solve()aversin(x) 上工作

>>> from sympy import solve
>>> solve(aversin(x) - y, x)
[versin(y)]

as_real_imag()

方法 as_real_imag() 定义了如何将一个函数拆分成实部和虚部。它被 SymPy 中各种对表达式实部和虚部分别进行操作的函数所使用。

as_real_imag(self, deep=True, **hints) 应该返回一个包含函数实部和虚部的 2 元组。也就是说 expr.as_real_imag() 返回 (re(expr), im(expr)),其中 expr == re(expr) + im(expr)*I,而 re(expr)im(expr) 为实数。

如果 deep=True,它应该递归地对它的参数调用 as_real_imag(deep=True, **hints)。与 doit()the _eval_expand_*() methods 一样,**hints 可以是任何提示,允许用户指定方法的行为。未知提示应被忽略并传递到任何递归调用中,以防它们是为其他 as_real_imag() 方法准备的。

对于我们的 versin example,我们可以递归地使用已经为 1 - cos(x) 定义的 as_real_imag()

>>> class versin(Function):
...     def as_real_imag(self, deep=True, **hints):
...         return (1 - cos(self.args[0])).as_real_imag(deep=deep, **hints)
>>> versin(x).as_real_imag()
(-cos(re(x))*cosh(im(x)) + 1, sin(re(x))*sinh(im(x)))

定义 as_real_imag() 也会自动使 expand_complex() 起作用。

>>> versin(x).expand(complex=True)
I*sin(re(x))*sinh(im(x)) - cos(re(x))*cosh(im(x)) + 1

各种 _eval_* 方法

SymPy 中有许多其他函数,它们的行为可以通过自定义 _eval_* 方法在自定义函数上定义,类似于上面描述的那些。有关如何定义每个方法的详细信息,请参阅特定函数的文档。

完整示例

以下是本指南中定义的示例函数的完整示例。有关每个方法的详细信息,请参阅上述部分。

Versine

versine(正矢)函数定义为

\[\operatorname{versin}(x) = 1 - \cos(x).\]

Versine 是一个为所有复数定义的简单函数的示例。它的数学定义很简单,这使得在它上面定义所有上述方法变得很直观(在大多数情况下,我们只需重用为 1 - cos(x) 定义的现有 SymPy 逻辑)。

定义

>>> from sympy import Function, cos, expand_trig, Integer, pi, sin
>>> from sympy.core.logic import fuzzy_and, fuzzy_not
>>> class versin(Function):
...     r"""
...     The versine function.
...
...     $\operatorname{versin}(x) = 1 - \cos(x) = 2\sin(x/2)^2.$
...
...     Geometrically, given a standard right triangle with angle x in the
...     unit circle, the versine of x is the positive horizontal distance from
...     the right angle of the triangle to the rightmost point on the unit
...     circle. It was historically used as a more numerically accurate way to
...     compute 1 - cos(x), but it is rarely used today.
...
...     References
...     ==========
...
...     .. [1] https://en.wikipedia.org/wiki/Versine
...     .. [2] https://blogs.scientificamerican.com/roots-of-unity/10-secret-trig-functions-your-math-teachers-never-taught-you/
...     """
...     # Define evaluation on basic inputs.
...     @classmethod
...     def eval(cls, x):
...         # If x is an explicit integer multiple of pi, x/pi will cancel and
...         # be an Integer.
...         n = x/pi
...         if isinstance(n, Integer):
...             return 1 - (-1)**n
...
...     # Define numerical evaluation with evalf().
...     def _eval_evalf(self, prec):
...         return (2*sin(self.args[0]/2)**2)._eval_evalf(prec)
...
...     # Define basic assumptions.
...     def _eval_is_nonnegative(self):
...         # versin(x) is nonnegative if x is real
...         x = self.args[0]
...         if x.is_real is True:
...             return True
...
...     def _eval_is_positive(self):
...         # versin(x) is positive if x is real and not an even multiple of pi
...         x = self.args[0]
...
...         # x.as_independent(pi, as_Add=False) will split x as a Mul of the
...         # form n*pi
...         coeff, pi_ = x.as_independent(pi, as_Add=False)
...         # If pi_ = pi, x = coeff*pi. Otherwise pi_ = 1 and x is not
...         # (structurally) of the form n*pi.
...         if pi_ == pi:
...             return fuzzy_and([x.is_real, fuzzy_not(coeff.is_even)])
...         elif x.is_real is False:
...             return False
...         # else: return None. We do not know for sure whether x is an even
...         # multiple of pi
...
...     # Define the behavior for various simplification and rewriting
...     # functions.
...     def _eval_rewrite(self, rule, args, **hints):
...         if rule == cos:
...             return 1 - cos(*args)
...         elif rule == sin:
...             return 2*sin(x/2)**2
...
...     def _eval_expand_trig(self, **hints):
...         x = self.args[0]
...         return expand_trig(1 - cos(x))
...
...     def as_real_imag(self, deep=True, **hints):
...         # reuse _eval_rewrite(cos) defined above
...         return self.rewrite(cos).as_real_imag(deep=deep, **hints)
...
...     # Define differentiation.
...     def fdiff(self, argindex=1):
...         return sin(self.args[0])

示例

评估

>>> x, y = symbols('x y')
>>> versin(x)
versin(x)
>>> versin(2*pi)
0
>>> versin(1.0)
0.459697694131860

假设

>>> n = symbols('n', integer=True)
>>> versin(n).is_real
True
>>> versin((2*n + 1)*pi).is_positive
True
>>> versin(2*n*pi).is_zero
True
>>> print(versin(n*pi).is_positive)
None
>>> r = symbols('r', real=True)
>>> print(versin(r).is_positive)
None
>>> nr = symbols('nr', real=False)
>>> print(versin(nr).is_nonnegative)
None

简化

>>> a, b = symbols('a b', real=True)
>>> from sympy import I
>>> versin(x).rewrite(cos)
1 - cos(x)
>>> versin(x).rewrite(sin)
2*sin(x/2)**2
>>> versin(2*x).expand(trig=True)
2 - 2*cos(x)**2
>>> versin(a + b*I).expand(complex=True)
I*sin(a)*sinh(b) - cos(a)*cosh(b) + 1

微分

>>> versin(x).diff(x)
sin(x)

求解

(aversin 的更一般版本也会定义所有上述方法)

>>> class aversin(Function):
...     def inverse(self, argindex=1):
...         return versin
>>> from sympy import solve
>>> solve(aversin(x**2) - y, x)
[-sqrt(versin(y)), sqrt(versin(y))]

divides

divides 函数定义为

\[\begin{split}\operatorname{divides}(m, n) = \begin{cases} 1 & \text{for}\: m \mid n \\ 0 & \text{for}\: m\not\mid n \end{cases},\end{split}\]

也就是说,如果 m 可以整除 n,则 divides(m, n) 为 1;如果 m 不能整除 m,则 divides(m, n) 为 0。它仅针对整数 mn 定义。为了简单起见,我们采用以下约定:对于所有整数 \(m\)\(m \mid 0\)

divides 是一个仅针对特定输入值(整数)定义的函数的示例。 divides 还提供了一个定义自定义打印机(_latex())的示例。

定义

>>> from sympy import Function, Integer
>>> from sympy.core.logic import fuzzy_not
>>> class divides(Function):
...     r"""
...     $$\operatorname{divides}(m, n) = \begin{cases} 1 & \text{for}\: m \mid n \\ 0 & \text{for}\: m\not\mid n  \end{cases}.$$
...
...     That is, ``divides(m, n)`` is ``1`` if ``m`` divides ``n`` and ``0``
...     if ``m`` does not divide ``n`. It is undefined if ``m`` or ``n`` are
...     not integers. For simplicity, the convention is used that
...     ``divides(m, 0) = 1`` for all integers ``m``.
...
...     References
...     ==========
...
...     .. [1] https://en.wikipedia.org/wiki/Divisor#Definition
...     """
...     # Define evaluation on basic inputs, as well as type checking that the
...     # inputs are not nonintegral.
...     @classmethod
...     def eval(cls, m, n):
...         # Evaluate for explicit integer m and n.
...         if isinstance(m, Integer) and isinstance(n, Integer):
...             return int(n % m == 0)
...
...         # For symbolic arguments, require m and n to be integer.
...         if m.is_integer is False or n.is_integer is False:
...             raise TypeError("m and n should be integers")
...
...     # Define basic assumptions.
...
...     # divides is always either 0 or 1.
...     is_integer = True
...     is_negative = False
...
...     # Whether divides(m, n) is 0 or 1 depends on m and n. Note that this
...     # method only makes sense because we don't automatically evaluate on
...     # such cases, but instead simplify these cases in doit() below.
...     def _eval_is_zero(self):
...         m, n = self.args
...         if m.is_integer and n.is_integer:
...              return fuzzy_not((n/m).is_integer)
...
...     # Define doit() as further evaluation on symbolic arguments using
...     # assumptions.
...     def doit(self, deep=False, **hints):
...         m, n = self.args
...         # Recursively call doit() on the args whenever deep=True.
...         # Be sure to pass deep=True and **hints through here.
...         if deep:
...            m, n = m.doit(deep=deep, **hints), n.doit(deep=deep, **hints)
...
...         # divides(m, n) is 1 iff n/m is an integer. Note that m and n are
...         # already assumed to be integers because of the logic in eval().
...         isint = (n/m).is_integer
...         if isint is True:
...             return Integer(1)
...         elif isint is False:
...             return Integer(0)
...         else:
...             return divides(m, n)
...
...     # Define LaTeX printing for use with the latex() function and the
...     # Jupyter notebook.
...     def _latex(self, printer):
...         m, n = self.args
...         _m, _n = printer._print(m), printer._print(n)
...         return r'\left [ %s \middle | %s \right ]' % (_m, _n)
...

示例

评估

>>> from sympy import symbols
>>> n, m, k = symbols('n m k', integer=True)
>>> divides(3, 10)
0
>>> divides(3, 12)
1
>>> divides(m, n).is_integer
True
>>> divides(k, 2*k)
divides(k, 2*k)
>>> divides(k, 2*k).is_zero
False
>>> divides(k, 2*k).doit()
1

打印

>>> str(divides(m, n)) # This is using the default str printer
'divides(m, n)'
>>> print(latex(divides(m, n)))
\left [ m \middle | n \right ]

融合乘加 (FMA)

融合乘加 (FMA) 是一个乘法运算后接加法运算

\[\operatorname{FMA}(x, y, z) = xy + z.\]

它通常在硬件中实现为一个单一的浮点运算,与等效的乘法和加法运算组合相比,它具有更好的舍入和性能。

FMA 是一个自定义函数的示例,它被定义为另一个函数的未评估的“简写”。这是因为 doit() 方法定义为返回 x*y + z,这意味着 FMA 函数可以轻松地评估为它所代表的表达式,但 eval() 方法什么也不返回(除了当 xyz 都是显式的数值时),这意味着它在默认情况下保持未评估状态。

将它与 versine 示例进行对比,该示例将 versin 视为一个独立的一流函数。即使 versin(x) 可以用其他函数表示(1 - cos(x)),它也不会在 versin.eval() 中对一般的符号输入进行评估,而 versin.doit() 根本没有定义。

FMA 还代表一个在多个变量上定义的连续函数的示例,它演示了 fdiff 示例中的 argindex 的工作原理。

最后,FMA 展示了为 CC++ 定义一些代码打印机的示例(使用来自 C99CodePrinter.printmethodCXX11CodePrinter.printmethod 的方法名),因为这是该函数的典型用例。

FMA 的数学定义非常简单,很容易在它上面定义每个方法,但这里只显示了少数方法。 versinedivides 示例展示了如何定义本指南中讨论的其他重要方法。

请注意,如果您想实际使用融合乘加来生成代码,SymPy 中已经存在一个版本 sympy.codegen.cfunctions.fma(),它受现有代码打印机的支持。这里提供的版本仅作为示例。

定义

>>> from sympy import Number, symbols, Add, Mul
>>> x, y, z = symbols('x y z')
>>> class FMA(Function):
...     """
...     FMA(x, y, z) = x*y + z
...
...     FMA is often defined as a single operation in hardware for better
...     rounding and performance.
...
...     FMA can be evaluated by using the doit() method.
...
...     References
...     ==========
...
...     .. [1] https://en.wikipedia.org/wiki/Multiply%E2%80%93accumulate_operation#Fused_multiply%E2%80%93add
...     """
...     # Define automatic evaluation on explicit numbers
...     @classmethod
...     def eval(cls, x, y, z):
...         # Number is the base class of Integer, Rational, and Float
...         if all(isinstance(i, Number) for i in [x, y, z]):
...            return x*y + z
...
...     # Define numerical evaluation with evalf().
...     def _eval_evalf(self, prec):
...         return self.doit(deep=False)._eval_evalf(prec)
...
...     # Define full evaluation to Add and Mul in doit(). This effectively
...     # treats FMA(x, y, z) as just a shorthand for x*y + z that is useful
...     # to have as a separate expression in some contexts and which can be
...     # evaluated to its expanded form in other contexts.
...     def doit(self, deep=True, **hints):
...         x, y, z = self.args
...         # Recursively call doit() on the args whenever deep=True.
...         # Be sure to pass deep=True and **hints through here.
...         if deep:
...             x = x.doit(deep=deep, **hints)
...             y = y.doit(deep=deep, **hints)
...             z = z.doit(deep=deep, **hints)
...         return x*y + z
...
...     # Define FMA.rewrite(Add) and FMA.rewrite(Mul).
...     def _eval_rewrite(self, rule, args, **hints):
...         x, y, z = self.args
...         if rule in [Add, Mul]:
...             return self.doit()
...
...     # Define differentiation.
...     def fdiff(self, argindex):
...         # argindex indexes the args, starting at 1
...         x, y, z = self.args
...         if argindex == 1:
...             return y
...         elif argindex == 2:
...             return x
...         elif argindex == 3:
...             return 1
...
...     # Define code printers for ccode() and cxxcode()
...     def _ccode(self, printer):
...         x, y, z = self.args
...         _x, _y, _z = printer._print(x), printer._print(y), printer._print(z)
...         return "fma(%s, %s, %s)" % (_x, _y, _z)
...
...     def _cxxcode(self, printer):
...         x, y, z = self.args
...         _x, _y, _z = printer._print(x), printer._print(y), printer._print(z)
...         return "std::fma(%s, %s, %s)" % (_x, _y, _z)

示例

评估

>>> x, y, z = symbols('x y z')
>>> FMA(2, 3, 4)
10
>>> FMA(x, y, z)
FMA(x, y, z)
>>> FMA(x, y, z).doit()
x*y + z
>>> FMA(x, y, z).rewrite(Add)
x*y + z
>>> FMA(2, pi, 1).evalf()
7.28318530717959

微分

>>> FMA(x, x, y).diff(x)
2*x
>>> FMA(x, y, x).diff(x)
y + 1

代码打印机

>>> from sympy import ccode, cxxcode
>>> ccode(FMA(x, y, z))
'fma(x, y, z)'
>>> cxxcode(FMA(x, y, z))
'std::fma(x, y, z)'

其他提示

  • SymPy 包含数十个函数。它们可以作为编写自定义函数的有用示例,特别是如果函数与已实现的函数类似。请记住,本指南中的所有内容都同样适用于 SymPy 附带的函数和用户定义的函数。实际上,本指南旨在作为 SymPy 贡献者的开发指南和 SymPy 终端用户的指南。

  • 如果您有许多共享通用逻辑的自定义函数,您可以使用一个通用的基类来包含此共享逻辑。有关这方面的示例,请参阅 SymPy 中三角函数的源代码,它使用 TrigonometricFunctionInverseTrigonometricFunctionReciprocalTrigonometricFunction 基类,其中包含一些共享逻辑。

  • 与任何代码一样,为您的函数编写大量测试是一个好主意。 SymPy 测试套件 是如何为此类函数编写测试的良好资源。SymPy 本身包含的所有代码都需要进行测试。包含在 SymPy 中的函数还应该始终包含一个文档字符串,其中包含引用、数学定义和 doctest 示例。