高级表达式操作

在本节中,我们将讨论一些执行高级表达式操作的方法。

理解表达式树

在我们开始之前,我们需要了解 SymPy 中表达式的表示方式。数学表达式用树来表示。让我们以表达式 \(x^2 + xy\) 为例,即 x**2 + x*y。我们可以使用 srepr 查看此表达式在内部的表示方式。

>>> from sympy import *
>>> x, y, z = symbols('x y z')
>>> expr = x**2 + x*y
>>> srepr(expr)
"Add(Pow(Symbol('x'), Integer(2)), Mul(Symbol('x'), Symbol('y')))"

最简单的方法是查看表达式树的图表

digraph{ # Graph style "ordering"="out" "rankdir"="TD" ######### # Nodes # ######### "Add(Pow(Symbol('x'), Integer(2)), Mul(Symbol('x'), Symbol('y')))_()" ["color"="black", "label"="Add", "shape"="ellipse"]; "Pow(Symbol('x'), Integer(2))_(0,)" ["color"="black", "label"="Pow", "shape"="ellipse"]; "Symbol('x')_(0, 0)" ["color"="black", "label"="Symbol('x')", "shape"="ellipse"]; "Integer(2)_(0, 1)" ["color"="black", "label"="Integer(2)", "shape"="ellipse"]; "Mul(Symbol('x'), Symbol('y'))_(1,)" ["color"="black", "label"="Mul", "shape"="ellipse"]; "Symbol('x')_(1, 0)" ["color"="black", "label"="Symbol('x')", "shape"="ellipse"]; "Symbol('y')_(1, 1)" ["color"="black", "label"="Symbol('y')", "shape"="ellipse"]; ######### # Edges # ######### "Add(Pow(Symbol('x'), Integer(2)), Mul(Symbol('x'), Symbol('y')))_()" -> "Pow(Symbol('x'), Integer(2))_(0,)"; "Add(Pow(Symbol('x'), Integer(2)), Mul(Symbol('x'), Symbol('y')))_()" -> "Mul(Symbol('x'), Symbol('y'))_(1,)"; "Pow(Symbol('x'), Integer(2))_(0,)" -> "Symbol('x')_(0, 0)"; "Pow(Symbol('x'), Integer(2))_(0,)" -> "Integer(2)_(0, 1)"; "Mul(Symbol('x'), Symbol('y'))_(1,)" -> "Symbol('x')_(1, 0)"; "Mul(Symbol('x'), Symbol('y'))_(1,)" -> "Symbol('y')_(1, 1)"; }

注意

上面的图像是使用 Graphvizdotprint 函数生成的。

首先,让我们看看这棵树的叶子。符号是 Symbol 类的实例。虽然我们一直在做

>>> x = symbols('x')

但我们也可以做

>>> x = Symbol('x')

无论哪种方式,我们都会得到一个名为“x”的符号 [1]。对于表达式中的数字 2,我们得到了 Integer(2)Integer 是 SymPy 中整数的类。它类似于 Python 内置类型 int,只是 Integer 与其他 SymPy 类型配合得很好。

当我们写 x**2 时,它会创建一个 Pow 对象。 Pow 是 “power” 的缩写。

>>> srepr(x**2)
"Pow(Symbol('x'), Integer(2))"

我们也可以通过调用 Pow(x, 2) 来创建相同对象。

>>> Pow(x, 2)
x**2

请注意,在 srepr 输出中,我们看到了 Integer(2),即 SymPy 版本的整数,即使从技术上讲,我们输入的是 2,一个 Python int。通常,当您通过某个函数或操作将 SymPy 对象与非 SymPy 对象组合时,非 SymPy 对象将被转换为 SymPy 对象。执行此操作的函数是 sympify [2]

>>> type(2)
<... 'int'>
>>> type(sympify(2))
<class 'sympy.core.numbers.Integer'>

我们已经看到 x**2 表示为 Pow(x, 2)。那么 x*y 呢?正如我们所预期的那样,这是 xy 的乘法。SymPy 中乘法的类是 Mul

>>> srepr(x*y)
"Mul(Symbol('x'), Symbol('y'))"

因此,我们可以通过编写 Mul(x, y) 来创建相同对象。

>>> Mul(x, y)
x*y

现在我们到了最终的表达式, x**2 + x*y。这是我们最后两个对象的加法,Pow(x, 2)Mul(x, y)。SymPy 中加法的类是 Add,因此,正如您所预期的那样,要创建此对象,我们使用 Add(Pow(x, 2), Mul(x, y))

>>> Add(Pow(x, 2), Mul(x, y))
x**2 + x*y

SymPy 表达式树可以有很多分支,并且可以非常深或非常广。下面是一个更复杂的例子

>>> expr = sin(x*y)/2 - x**2 + 1/y
>>> srepr(expr)
"Add(Mul(Integer(-1), Pow(Symbol('x'), Integer(2))), Mul(Rational(1, 2),
sin(Mul(Symbol('x'), Symbol('y')))), Pow(Symbol('y'), Integer(-1)))"

下面是一个图表

digraph{ # Graph style "rankdir"="TD" ######### # Nodes # ######### "Half()_(0, 0)" ["color"="black", "label"="Rational(1, 2)", "shape"="ellipse"]; "Symbol(y)_(2, 0)" ["color"="black", "label"="Symbol('y')", "shape"="ellipse"]; "Symbol(x)_(1, 1, 0)" ["color"="black", "label"="Symbol('x')", "shape"="ellipse"]; "Integer(2)_(1, 1, 1)" ["color"="black", "label"="Integer(2)", "shape"="ellipse"]; "NegativeOne()_(2, 1)" ["color"="black", "label"="Integer(-1)", "shape"="ellipse"]; "NegativeOne()_(1, 0)" ["color"="black", "label"="Integer(-1)", "shape"="ellipse"]; "Symbol(y)_(0, 1, 0, 1)" ["color"="black", "label"="Symbol('y')", "shape"="ellipse"]; "Symbol(x)_(0, 1, 0, 0)" ["color"="black", "label"="Symbol('x')", "shape"="ellipse"]; "Pow(Symbol(x), Integer(2))_(1, 1)" ["color"="black", "label"="Pow", "shape"="ellipse"]; "Pow(Symbol(y), NegativeOne())_(2,)" ["color"="black", "label"="Pow", "shape"="ellipse"]; "Mul(Symbol(x), Symbol(y))_(0, 1, 0)" ["color"="black", "label"="Mul", "shape"="ellipse"]; "sin(Mul(Symbol(x), Symbol(y)))_(0, 1)" ["color"="black", "label"="sin", "shape"="ellipse"]; "Mul(Half(), sin(Mul(Symbol(x), Symbol(y))))_(0,)" ["color"="black", "label"="Mul", "shape"="ellipse"]; "Mul(NegativeOne(), Pow(Symbol(x), Integer(2)))_(1,)" ["color"="black", "label"="Mul", "shape"="ellipse"]; "Add(Mul(Half(), sin(Mul(Symbol(x), Symbol(y)))), Mul(NegativeOne(), Pow(Symbol(x), Integer(2))), Pow(Symbol(y), NegativeOne()))_()" ["color"="black", "label"="Add", "shape"="ellipse"]; ######### # Edges # ######### "Pow(Symbol(y), NegativeOne())_(2,)" -> "Symbol(y)_(2, 0)"; "Pow(Symbol(x), Integer(2))_(1, 1)" -> "Symbol(x)_(1, 1, 0)"; "Pow(Symbol(x), Integer(2))_(1, 1)" -> "Integer(2)_(1, 1, 1)"; "Pow(Symbol(y), NegativeOne())_(2,)" -> "NegativeOne()_(2, 1)"; "Mul(Symbol(x), Symbol(y))_(0, 1, 0)" -> "Symbol(x)_(0, 1, 0, 0)"; "Mul(Symbol(x), Symbol(y))_(0, 1, 0)" -> "Symbol(y)_(0, 1, 0, 1)"; "Mul(Half(), sin(Mul(Symbol(x), Symbol(y))))_(0,)" -> "Half()_(0, 0)"; "Mul(NegativeOne(), Pow(Symbol(x), Integer(2)))_(1,)" -> "NegativeOne()_(1, 0)"; "sin(Mul(Symbol(x), Symbol(y)))_(0, 1)" -> "Mul(Symbol(x), Symbol(y))_(0, 1, 0)"; "Mul(NegativeOne(), Pow(Symbol(x), Integer(2)))_(1,)" -> "Pow(Symbol(x), Integer(2))_(1, 1)"; "Mul(Half(), sin(Mul(Symbol(x), Symbol(y))))_(0,)" -> "sin(Mul(Symbol(x), Symbol(y)))_(0, 1)"; "Add(Mul(Half(), sin(Mul(Symbol(x), Symbol(y)))), Mul(NegativeOne(), Pow(Symbol(x), Integer(2))), Pow(Symbol(y), NegativeOne()))_()" -> "Pow(Symbol(y), NegativeOne())_(2,)"; "Add(Mul(Half(), sin(Mul(Symbol(x), Symbol(y)))), Mul(NegativeOne(), Pow(Symbol(x), Integer(2))), Pow(Symbol(y), NegativeOne()))_()" -> "Mul(Half(), sin(Mul(Symbol(x), Symbol(y))))_(0,)"; "Add(Mul(Half(), sin(Mul(Symbol(x), Symbol(y)))), Mul(NegativeOne(), Pow(Symbol(x), Integer(2))), Pow(Symbol(y), NegativeOne()))_()" -> "Mul(NegativeOne(), Pow(Symbol(x), Integer(2)))_(1,)"; }

此表达式揭示了关于 SymPy 表达式树的一些有趣的事情。让我们逐一介绍。

首先让我们看一下 x**2 项。正如我们预期的那样,我们看到了 Pow(x, 2)。上一级,我们看到我们有 Mul(-1, Pow(x, 2))。SymPy 中没有减法类。 x - y 表示为 x + -y,或者更完整地, x + -1*y,即 Add(x, Mul(-1, y))

>>> srepr(x - y)
"Add(Symbol('x'), Mul(Integer(-1), Symbol('y')))"

digraph{ # Graph style "rankdir"="TD" ######### # Nodes # ######### "Symbol(x)_(1,)" ["color"="black", "label"="Symbol('x')", "shape"="ellipse"]; "Symbol(y)_(0, 1)" ["color"="black", "label"="Symbol('y')", "shape"="ellipse"]; "NegativeOne()_(0, 0)" ["color"="black", "label"="Integer(-1)", "shape"="ellipse"]; "Mul(NegativeOne(), Symbol(y))_(0,)" ["color"="black", "label"="Mul", "shape"="ellipse"]; "Add(Mul(NegativeOne(), Symbol(y)), Symbol(x))_()" ["color"="black", "label"="Add", "shape"="ellipse"]; ######### # Edges # ######### "Mul(NegativeOne(), Symbol(y))_(0,)" -> "Symbol(y)_(0, 1)"; "Mul(NegativeOne(), Symbol(y))_(0,)" -> "NegativeOne()_(0, 0)"; "Add(Mul(NegativeOne(), Symbol(y)), Symbol(x))_()" -> "Symbol(x)_(1,)"; "Add(Mul(NegativeOne(), Symbol(y)), Symbol(x))_()" -> "Mul(NegativeOne(), Symbol(y))_(0,)"; }

接下来,看一下 1/y。我们可能会预期看到类似于 Div(1, y) 的内容,但与减法类似,SymPy 中没有除法类。相反,除法用 -1 的幂表示。因此,我们有 Pow(y, -1)。如果我们用 y 除以 1 之外的其他东西,比如 x/y 会怎样?让我们看看。

>>> expr = x/y
>>> srepr(expr)
"Mul(Symbol('x'), Pow(Symbol('y'), Integer(-1)))"

digraph{ # Graph style "rankdir"="TD" ######### # Nodes # ######### "Symbol(x)_(0,)" ["color"="black", "label"="Symbol('x')", "shape"="ellipse"]; "Symbol(y)_(1, 0)" ["color"="black", "label"="Symbol('y')", "shape"="ellipse"]; "NegativeOne()_(1, 1)" ["color"="black", "label"="Integer(-1)", "shape"="ellipse"]; "Pow(Symbol(y), NegativeOne())_(1,)" ["color"="black", "label"="Pow", "shape"="ellipse"]; "Mul(Symbol(x), Pow(Symbol(y), NegativeOne()))_()" ["color"="black", "label"="Mul", "shape"="ellipse"]; ######### # Edges # ######### "Pow(Symbol(y), NegativeOne())_(1,)" -> "Symbol(y)_(1, 0)"; "Pow(Symbol(y), NegativeOne())_(1,)" -> "NegativeOne()_(1, 1)"; "Mul(Symbol(x), Pow(Symbol(y), NegativeOne()))_()" -> "Symbol(x)_(0,)"; "Mul(Symbol(x), Pow(Symbol(y), NegativeOne()))_()" -> "Pow(Symbol(y), NegativeOne())_(1,)"; }

我们看到 x/y 表示为 x*y**-1,即 Mul(x, Pow(y, -1))

最后,让我们看一下 sin(x*y)/2 项。按照上一个示例的模式,我们可能会预期看到 Mul(sin(x*y), Pow(Integer(2), -1))。但相反,我们有 Mul(Rational(1, 2), sin(x*y))。有理数总是合并到乘法中的一个术语,因此,当我们除以 2 时,它表示为乘以 1/2。

最后,还有一点需要注意。您可能已经注意到,我们输入表达式的顺序和从 srepr 或图表中输出的顺序不同。您可能还在本教程的前面部分注意到过这种现象。例如

>>> 1 + x
x + 1

这是因为在 SymPy 中,可交换运算符 AddMul 的参数以任意(但一致!)的顺序存储,该顺序独立于输入顺序(如果您担心不可交换乘法,请不要担心。在 SymPy 中,您可以使用 Symbol('A', commutative=False) 创建不可交换符号,不可交换符号的乘法顺序将与输入保持一致)。此外,正如我们在下一节中将要看到的,打印顺序和内部存储顺序也可能不相同。

通常,在使用 SymPy 表达式树时,需要牢记的一点是:表达式的内部表示方式和打印方式可能不同。输入形式也是如此。如果某些表达式操作算法没有按预期的方式工作,则很可能是对象内部表示方式与您认为的不同。

递归遍历表达式树

现在您已经了解了 SymPy 中表达式树的工作原理,让我们看一下如何深入了解表达式树。SymPy 中的每个对象都有两个非常重要的属性,funcargs

func

func 是对象的头部。例如,(x*y).funcMul。通常它与对象的类相同(尽管存在例外)。

关于 func 有两点需要注意。首先,对象的类不必与用于创建它的类相同。例如

>>> expr = Add(x, x)
>>> expr.func
<class 'sympy.core.mul.Mul'>

我们创建了 Add(x, x),因此我们可能会预期 expr.funcAdd,但相反我们得到了 Mul。为什么呢?让我们仔细看一下 expr

>>> expr
2*x

Add(x, x),即 x + x,会自动转换为 Mul(2, x),即 2*x,这是一个 Mul。SymPy 类大量使用了 __new__ 类构造函数,与 __init__ 不同,它允许从构造函数中返回不同的类。

其次,有些类是特殊情况,通常是为了效率原因 [3]

>>> Integer(2).func
<class 'sympy.core.numbers.Integer'>
>>> Integer(0).func
<class 'sympy.core.numbers.Zero'>
>>> Integer(-1).func
<class 'sympy.core.numbers.NegativeOne'>

在大多数情况下,这些问题不会困扰我们。特殊类 ZeroOneNegativeOne 等是 Integer 的子类,因此只要您使用 isinstance,它就不会成为问题。

args

args 是对象的顶层参数。 (x*y).args 将是 (x, y)。让我们看一些例子

>>> expr = 3*y**2*x
>>> expr.func
<class 'sympy.core.mul.Mul'>
>>> expr.args
(3, x, y**2)

由此我们可以看到 expr == Mul(3, y**2, x)。事实上,我们可以看到,我们可以从 exprfuncargs 中完全重建 expr

>>> expr.func(*expr.args)
3*x*y**2
>>> expr == expr.func(*expr.args)
True

请注意,尽管我们输入了 3*y**2*x,但 args(3, x, y**2)。在 Mul 中,Rational 系数将在 args 中排在第一位,但除此之外,其他所有内容的顺序没有特殊模式。不过,请务必确保存在一个顺序。

>>> expr = y**2*3*x
>>> expr.args
(3, x, y**2)

Mul 的 args 是排序的,以便相同的 Mul 具有相同的 args。但排序基于一些旨在使排序唯一且高效的标准,这些标准没有数学意义。

我们 exprsrepr 形式是 Mul(3, x, Pow(y, 2))。如果我们想获得 Pow(y, 2)args 会怎样。请注意,y**2 位于 expr.args 的第三个位置,即 expr.args[2]

>>> expr.args[2]
y**2

因此,要获得它的 args,我们调用 expr.args[2].args

>>> expr.args[2].args
(y, 2)

现在我们尝试深入一点。y 的参数是什么? 或者 2 呢? 让我们看看。

>>> y.args
()
>>> Integer(2).args
()

它们都有空的 args。 在 SymPy 中,空的 args 表示我们已经到达了表达式树的叶子节点。

因此,SymPy 表达式有两种可能性。它要么具有空的 args,在这种情况下它是任何表达式树中的叶子节点;要么它具有 args,在这种情况下,它是任何表达式树的分支节点。 当它具有 args 时,它可以完全从它的 funcargs 重新构建。 这在关键不变式中得到了体现。

(回想一下,在 Python 中,如果 a 是一个元组,那么 f(*a) 表示使用 a 的元素作为参数调用 f,例如,f(*(1, 2, 3)) 等同于 f(1, 2, 3)。)

这个关键不变式使我们可以编写简单的算法,这些算法可以遍历表达式树、更改它们并将其重新构建成新的表达式。

遍历树

有了这些知识,让我们看看如何递归遍历表达式树。 args 的嵌套性质非常适合递归函数。 基本情况将是空的 args。 让我们编写一个简单的函数,它遍历表达式并打印每个级别上的所有 args

>>> def pre(expr):
...     print(expr)
...     for arg in expr.args:
...         pre(arg)

看看 () 如何在表达式树中表示叶子节点。 我们甚至不需要为递归编写基本情况;它由 for 循环自动处理。

让我们测试一下我们的函数。

>>> expr = x*y + 1
>>> pre(expr)
x*y + 1
1
x*y
x
y

你能猜到我们为什么把函数命名为 pre 吗? 我们刚刚为表达式树编写了一个前序遍历函数。 看看你是否可以编写一个后序遍历函数。

这种遍历在 SymPy 中非常常见,因此提供了生成器函数 preorder_traversalpostorder_traversal 来简化这种遍历。 我们也可以将我们的算法写成

>>> for arg in preorder_traversal(expr):
...     print(arg)
x*y + 1
1
x*y
x
y

阻止表达式求值

通常有两种方法可以阻止求值,要么在构建表达式时传递一个 evaluate=False 参数,要么通过用 UnevaluatedExpr 包裹表达式来创建求值停止器。

例如

>>> from sympy import Add
>>> from sympy.abc import x, y, z
>>> x + x
2*x
>>> Add(x, x)
2*x
>>> Add(x, x, evaluate=False)
x + x

如果你不记得要构建的表达式的对应类(运算符重载通常假设 evaluate=True),只需使用 sympify 并传递一个字符串

>>> from sympy import sympify
>>> sympify("x + x", evaluate=False)
x + x

请注意,evaluate=False 不会阻止表达式在以后使用时的进一步求值

>>> expr = Add(x, x, evaluate=False)
>>> expr
x + x
>>> expr + x
3*x

这就是 UnevaluatedExpr 类派上用场的地方。 UnevaluatedExpr 是 SymPy 提供的一种方法,允许用户保留未求值的表达式。 所谓的未求值是指它内部的值不会与它外部的表达式交互以给出简化的输出。 例如

>>> from sympy import UnevaluatedExpr
>>> expr = x + UnevaluatedExpr(x)
>>> expr
x + x
>>> x + expr
2*x + x

单独的 \(x\) 是被 UnevaluatedExpr 包裹的 \(x\)。 要释放它

>>> (x + expr).doit()
3*x

其他示例

>>> from sympy import *
>>> from sympy.abc import x, y, z
>>> uexpr = UnevaluatedExpr(S.One*5/7)*UnevaluatedExpr(S.One*3/4)
>>> uexpr
(5/7)*(3/4)
>>> x*UnevaluatedExpr(1/x)
x*1/x

需要注意的是,UnevaluatedExpr 无法阻止作为参数给出的表达式的求值。 例如

>>> expr1 = UnevaluatedExpr(x + x)
>>> expr1
2*x
>>> expr2 = sympify('x + x', evaluate=False)
>>> expr2
x + x

请记住,如果 expr2 包含在另一个表达式中,它将被求值。 结合这两种方法来阻止内部和外部求值

>>> UnevaluatedExpr(sympify("x + x", evaluate=False)) + y
y + (x + x)

UnevaluatedExpr 受 SymPy 打印机支持,可用于以不同的输出形式打印结果。 例如

>>> from sympy import latex
>>> uexpr = UnevaluatedExpr(S.One*5/7)*UnevaluatedExpr(S.One*3/4)
>>> print(latex(uexpr))
\frac{5}{7} \cdot \frac{3}{4}

为了释放表达式并获得求值的 LaTeX 形式,只需使用 .doit()

>>> print(latex(uexpr.doit()))
\frac{15}{28}

脚注