Viewing File: /home/ubuntu/combine_ai/combine/lib/python3.10/site-packages/sympy/core/tests/test_expr.py
from sympy.assumptions.refine import refine
from sympy.concrete.summations import Sum
from sympy.core.add import Add
from sympy.core.basic import Basic
from sympy.core.containers import Tuple
from sympy.core.expr import (ExprBuilder, unchanged, Expr,
UnevaluatedExpr)
from sympy.core.function import (Function, expand, WildFunction,
AppliedUndef, Derivative, diff, Subs)
from sympy.core.mul import Mul
from sympy.core.numbers import (NumberSymbol, E, zoo, oo, Float, I,
Rational, nan, Integer, Number, pi, _illegal)
from sympy.core.power import Pow
from sympy.core.relational import Ge, Lt, Gt, Le
from sympy.core.singleton import S
from sympy.core.sorting import default_sort_key
from sympy.core.symbol import Symbol, symbols, Dummy, Wild
from sympy.core.sympify import sympify
from sympy.functions.combinatorial.factorials import factorial
from sympy.functions.elementary.exponential import exp_polar, exp, log
from sympy.functions.elementary.miscellaneous import sqrt, Max
from sympy.functions.elementary.piecewise import Piecewise
from sympy.functions.elementary.trigonometric import tan, sin, cos
from sympy.functions.special.delta_functions import (Heaviside,
DiracDelta)
from sympy.functions.special.error_functions import Si
from sympy.functions.special.gamma_functions import gamma
from sympy.integrals.integrals import integrate, Integral
from sympy.physics.secondquant import FockState
from sympy.polys.partfrac import apart
from sympy.polys.polytools import factor, cancel, Poly
from sympy.polys.rationaltools import together
from sympy.series.order import O
from sympy.sets.sets import FiniteSet
from sympy.simplify.combsimp import combsimp
from sympy.simplify.gammasimp import gammasimp
from sympy.simplify.powsimp import powsimp
from sympy.simplify.radsimp import collect, radsimp
from sympy.simplify.ratsimp import ratsimp
from sympy.simplify.simplify import simplify, nsimplify
from sympy.simplify.trigsimp import trigsimp
from sympy.tensor.indexed import Indexed
from sympy.physics.units import meter
from sympy.testing.pytest import raises, XFAIL
from sympy.abc import a, b, c, n, t, u, x, y, z
f, g, h = symbols('f,g,h', cls=Function)
class DummyNumber:
"""
Minimal implementation of a number that works with SymPy.
If one has a Number class (e.g. Sage Integer, or some other custom class)
that one wants to work well with SymPy, one has to implement at least the
methods of this class DummyNumber, resp. its subclasses I5 and F1_1.
Basically, one just needs to implement either __int__() or __float__() and
then one needs to make sure that the class works with Python integers and
with itself.
"""
def __radd__(self, a):
if isinstance(a, (int, float)):
return a + self.number
return NotImplemented
def __add__(self, a):
if isinstance(a, (int, float, DummyNumber)):
return self.number + a
return NotImplemented
def __rsub__(self, a):
if isinstance(a, (int, float)):
return a - self.number
return NotImplemented
def __sub__(self, a):
if isinstance(a, (int, float, DummyNumber)):
return self.number - a
return NotImplemented
def __rmul__(self, a):
if isinstance(a, (int, float)):
return a * self.number
return NotImplemented
def __mul__(self, a):
if isinstance(a, (int, float, DummyNumber)):
return self.number * a
return NotImplemented
def __rtruediv__(self, a):
if isinstance(a, (int, float)):
return a / self.number
return NotImplemented
def __truediv__(self, a):
if isinstance(a, (int, float, DummyNumber)):
return self.number / a
return NotImplemented
def __rpow__(self, a):
if isinstance(a, (int, float)):
return a ** self.number
return NotImplemented
def __pow__(self, a):
if isinstance(a, (int, float, DummyNumber)):
return self.number ** a
return NotImplemented
def __pos__(self):
return self.number
def __neg__(self):
return - self.number
class I5(DummyNumber):
number = 5
def __int__(self):
return self.number
class F1_1(DummyNumber):
number = 1.1
def __float__(self):
return self.number
i5 = I5()
f1_1 = F1_1()
# basic SymPy objects
basic_objs = [
Rational(2),
Float("1.3"),
x,
y,
pow(x, y)*y,
]
# all supported objects
all_objs = basic_objs + [
5,
5.5,
i5,
f1_1
]
def dotest(s):
for xo in all_objs:
for yo in all_objs:
s(xo, yo)
return True
def test_basic():
def j(a, b):
x = a
x = +a
x = -a
x = a + b
x = a - b
x = a*b
x = a/b
x = a**b
del x
assert dotest(j)
def test_ibasic():
def s(a, b):
x = a
x += b
x = a
x -= b
x = a
x *= b
x = a
x /= b
assert dotest(s)
class NonBasic:
'''This class represents an object that knows how to implement binary
operations like +, -, etc with Expr but is not a subclass of Basic itself.
The NonExpr subclass below does subclass Basic but not Expr.
For both NonBasic and NonExpr it should be possible for them to override
Expr.__add__ etc because Expr.__add__ should be returning NotImplemented
for non Expr classes. Otherwise Expr.__add__ would create meaningless
objects like Add(Integer(1), FiniteSet(2)) and it wouldn't be possible for
other classes to override these operations when interacting with Expr.
'''
def __add__(self, other):
return SpecialOp('+', self, other)
def __radd__(self, other):
return SpecialOp('+', other, self)
def __sub__(self, other):
return SpecialOp('-', self, other)
def __rsub__(self, other):
return SpecialOp('-', other, self)
def __mul__(self, other):
return SpecialOp('*', self, other)
def __rmul__(self, other):
return SpecialOp('*', other, self)
def __truediv__(self, other):
return SpecialOp('/', self, other)
def __rtruediv__(self, other):
return SpecialOp('/', other, self)
def __floordiv__(self, other):
return SpecialOp('//', self, other)
def __rfloordiv__(self, other):
return SpecialOp('//', other, self)
def __mod__(self, other):
return SpecialOp('%', self, other)
def __rmod__(self, other):
return SpecialOp('%', other, self)
def __divmod__(self, other):
return SpecialOp('divmod', self, other)
def __rdivmod__(self, other):
return SpecialOp('divmod', other, self)
def __pow__(self, other):
return SpecialOp('**', self, other)
def __rpow__(self, other):
return SpecialOp('**', other, self)
def __lt__(self, other):
return SpecialOp('<', self, other)
def __gt__(self, other):
return SpecialOp('>', self, other)
def __le__(self, other):
return SpecialOp('<=', self, other)
def __ge__(self, other):
return SpecialOp('>=', self, other)
class NonExpr(Basic, NonBasic):
'''Like NonBasic above except this is a subclass of Basic but not Expr'''
pass
class SpecialOp():
'''Represents the results of operations with NonBasic and NonExpr'''
def __new__(cls, op, arg1, arg2):
obj = object.__new__(cls)
obj.args = (op, arg1, arg2)
return obj
class NonArithmetic(Basic):
'''Represents a Basic subclass that does not support arithmetic operations'''
pass
def test_cooperative_operations():
'''Tests that Expr uses binary operations cooperatively.
In particular it should be possible for non-Expr classes to override
binary operators like +, - etc when used with Expr instances. This should
work for non-Expr classes whether they are Basic subclasses or not. Also
non-Expr classes that do not define binary operators with Expr should give
TypeError.
'''
# A bunch of instances of Expr subclasses
exprs = [
Expr(),
S.Zero,
S.One,
S.Infinity,
S.NegativeInfinity,
S.ComplexInfinity,
S.Half,
Float(0.5),
Integer(2),
Symbol('x'),
Mul(2, Symbol('x')),
Add(2, Symbol('x')),
Pow(2, Symbol('x')),
]
for e in exprs:
# Test that these classes can override arithmetic operations in
# combination with various Expr types.
for ne in [NonBasic(), NonExpr()]:
results = [
(ne + e, ('+', ne, e)),
(e + ne, ('+', e, ne)),
(ne - e, ('-', ne, e)),
(e - ne, ('-', e, ne)),
(ne * e, ('*', ne, e)),
(e * ne, ('*', e, ne)),
(ne / e, ('/', ne, e)),
(e / ne, ('/', e, ne)),
(ne // e, ('//', ne, e)),
(e // ne, ('//', e, ne)),
(ne % e, ('%', ne, e)),
(e % ne, ('%', e, ne)),
(divmod(ne, e), ('divmod', ne, e)),
(divmod(e, ne), ('divmod', e, ne)),
(ne ** e, ('**', ne, e)),
(e ** ne, ('**', e, ne)),
(e < ne, ('>', ne, e)),
(ne < e, ('<', ne, e)),
(e > ne, ('<', ne, e)),
(ne > e, ('>', ne, e)),
(e <= ne, ('>=', ne, e)),
(ne <= e, ('<=', ne, e)),
(e >= ne, ('<=', ne, e)),
(ne >= e, ('>=', ne, e)),
]
for res, args in results:
assert type(res) is SpecialOp and res.args == args
# These classes do not support binary operators with Expr. Every
# operation should raise in combination with any of the Expr types.
for na in [NonArithmetic(), object()]:
raises(TypeError, lambda : e + na)
raises(TypeError, lambda : na + e)
raises(TypeError, lambda : e - na)
raises(TypeError, lambda : na - e)
raises(TypeError, lambda : e * na)
raises(TypeError, lambda : na * e)
raises(TypeError, lambda : e / na)
raises(TypeError, lambda : na / e)
raises(TypeError, lambda : e // na)
raises(TypeError, lambda : na // e)
raises(TypeError, lambda : e % na)
raises(TypeError, lambda : na % e)
raises(TypeError, lambda : divmod(e, na))
raises(TypeError, lambda : divmod(na, e))
raises(TypeError, lambda : e ** na)
raises(TypeError, lambda : na ** e)
raises(TypeError, lambda : e > na)
raises(TypeError, lambda : na > e)
raises(TypeError, lambda : e < na)
raises(TypeError, lambda : na < e)
raises(TypeError, lambda : e >= na)
raises(TypeError, lambda : na >= e)
raises(TypeError, lambda : e <= na)
raises(TypeError, lambda : na <= e)
def test_relational():
from sympy.core.relational import Lt
assert (pi < 3) is S.false
assert (pi <= 3) is S.false
assert (pi > 3) is S.true
assert (pi >= 3) is S.true
assert (-pi < 3) is S.true
assert (-pi <= 3) is S.true
assert (-pi > 3) is S.false
assert (-pi >= 3) is S.false
r = Symbol('r', real=True)
assert (r - 2 < r - 3) is S.false
assert Lt(x + I, x + I + 2).func == Lt # issue 8288
def test_relational_assumptions():
m1 = Symbol("m1", nonnegative=False)
m2 = Symbol("m2", positive=False)
m3 = Symbol("m3", nonpositive=False)
m4 = Symbol("m4", negative=False)
assert (m1 < 0) == Lt(m1, 0)
assert (m2 <= 0) == Le(m2, 0)
assert (m3 > 0) == Gt(m3, 0)
assert (m4 >= 0) == Ge(m4, 0)
m1 = Symbol("m1", nonnegative=False, real=True)
m2 = Symbol("m2", positive=False, real=True)
m3 = Symbol("m3", nonpositive=False, real=True)
m4 = Symbol("m4", negative=False, real=True)
assert (m1 < 0) is S.true
assert (m2 <= 0) is S.true
assert (m3 > 0) is S.true
assert (m4 >= 0) is S.true
m1 = Symbol("m1", negative=True)
m2 = Symbol("m2", nonpositive=True)
m3 = Symbol("m3", positive=True)
m4 = Symbol("m4", nonnegative=True)
assert (m1 < 0) is S.true
assert (m2 <= 0) is S.true
assert (m3 > 0) is S.true
assert (m4 >= 0) is S.true
m1 = Symbol("m1", negative=False, real=True)
m2 = Symbol("m2", nonpositive=False, real=True)
m3 = Symbol("m3", positive=False, real=True)
m4 = Symbol("m4", nonnegative=False, real=True)
assert (m1 < 0) is S.false
assert (m2 <= 0) is S.false
assert (m3 > 0) is S.false
assert (m4 >= 0) is S.false
# See https://github.com/sympy/sympy/issues/17708
#def test_relational_noncommutative():
# from sympy import Lt, Gt, Le, Ge
# A, B = symbols('A,B', commutative=False)
# assert (A < B) == Lt(A, B)
# assert (A <= B) == Le(A, B)
# assert (A > B) == Gt(A, B)
# assert (A >= B) == Ge(A, B)
def test_basic_nostr():
for obj in basic_objs:
raises(TypeError, lambda: obj + '1')
raises(TypeError, lambda: obj - '1')
if obj == 2:
assert obj * '1' == '11'
else:
raises(TypeError, lambda: obj * '1')
raises(TypeError, lambda: obj / '1')
raises(TypeError, lambda: obj ** '1')
def test_series_expansion_for_uniform_order():
assert (1/x + y + x).series(x, 0, 0) == 1/x + O(1, x)
assert (1/x + y + x).series(x, 0, 1) == 1/x + y + O(x)
assert (1/x + 1 + x).series(x, 0, 0) == 1/x + O(1, x)
assert (1/x + 1 + x).series(x, 0, 1) == 1/x + 1 + O(x)
assert (1/x + x).series(x, 0, 0) == 1/x + O(1, x)
assert (1/x + y + y*x + x).series(x, 0, 0) == 1/x + O(1, x)
assert (1/x + y + y*x + x).series(x, 0, 1) == 1/x + y + O(x)
def test_leadterm():
assert (3 + 2*x**(log(3)/log(2) - 1)).leadterm(x) == (3, 0)
assert (1/x**2 + 1 + x + x**2).leadterm(x)[1] == -2
assert (1/x + 1 + x + x**2).leadterm(x)[1] == -1
assert (x**2 + 1/x).leadterm(x)[1] == -1
assert (1 + x**2).leadterm(x)[1] == 0
assert (x + 1).leadterm(x)[1] == 0
assert (x + x**2).leadterm(x)[1] == 1
assert (x**2).leadterm(x)[1] == 2
def test_as_leading_term():
assert (3 + 2*x**(log(3)/log(2) - 1)).as_leading_term(x) == 3
assert (1/x**2 + 1 + x + x**2).as_leading_term(x) == 1/x**2
assert (1/x + 1 + x + x**2).as_leading_term(x) == 1/x
assert (x**2 + 1/x).as_leading_term(x) == 1/x
assert (1 + x**2).as_leading_term(x) == 1
assert (x + 1).as_leading_term(x) == 1
assert (x + x**2).as_leading_term(x) == x
assert (x**2).as_leading_term(x) == x**2
assert (x + oo).as_leading_term(x) is oo
raises(ValueError, lambda: (x + 1).as_leading_term(1))
# https://github.com/sympy/sympy/issues/21177
e = -3*x + (x + Rational(3, 2) - sqrt(3)*S.ImaginaryUnit/2)**2\
- Rational(3, 2) + 3*sqrt(3)*S.ImaginaryUnit/2
assert e.as_leading_term(x) == \
(12*sqrt(3)*x - 12*S.ImaginaryUnit*x)/(4*sqrt(3) + 12*S.ImaginaryUnit)
# https://github.com/sympy/sympy/issues/21245
e = 1 - x - x**2
d = (1 + sqrt(5))/2
assert e.subs(x, y + 1/d).as_leading_term(y) == \
(-576*sqrt(5)*y - 1280*y)/(256*sqrt(5) + 576)
def test_leadterm2():
assert (x*cos(1)*cos(1 + sin(1)) + sin(1 + sin(1))).leadterm(x) == \
(sin(1 + sin(1)), 0)
def test_leadterm3():
assert (y + z + x).leadterm(x) == (y + z, 0)
def test_as_leading_term2():
assert (x*cos(1)*cos(1 + sin(1)) + sin(1 + sin(1))).as_leading_term(x) == \
sin(1 + sin(1))
def test_as_leading_term3():
assert (2 + pi + x).as_leading_term(x) == 2 + pi
assert (2*x + pi*x + x**2).as_leading_term(x) == 2*x + pi*x
def test_as_leading_term4():
# see issue 6843
n = Symbol('n', integer=True, positive=True)
r = -n**3/(2*n**2 + 4*n + 2) - n**2/(n**2 + 2*n + 1) + \
n**2/(n + 1) - n/(2*n**2 + 4*n + 2) + n/(n*x + x) + 2*n/(n + 1) - \
1 + 1/(n*x + x) + 1/(n + 1) - 1/x
assert r.as_leading_term(x).cancel() == n/2
def test_as_leading_term_stub():
class foo(Function):
pass
assert foo(1/x).as_leading_term(x) == foo(1/x)
assert foo(1).as_leading_term(x) == foo(1)
raises(NotImplementedError, lambda: foo(x).as_leading_term(x))
def test_as_leading_term_deriv_integral():
# related to issue 11313
assert Derivative(x ** 3, x).as_leading_term(x) == 3*x**2
assert Derivative(x ** 3, y).as_leading_term(x) == 0
assert Integral(x ** 3, x).as_leading_term(x) == x**4/4
assert Integral(x ** 3, y).as_leading_term(x) == y*x**3
assert Derivative(exp(x), x).as_leading_term(x) == 1
assert Derivative(log(x), x).as_leading_term(x) == (1/x).as_leading_term(x)
def test_atoms():
assert x.atoms() == {x}
assert (1 + x).atoms() == {x, S.One}
assert (1 + 2*cos(x)).atoms(Symbol) == {x}
assert (1 + 2*cos(x)).atoms(Symbol, Number) == {S.One, S(2), x}
assert (2*(x**(y**x))).atoms() == {S(2), x, y}
assert S.Half.atoms() == {S.Half}
assert S.Half.atoms(Symbol) == set()
assert sin(oo).atoms(oo) == set()
assert Poly(0, x).atoms() == {S.Zero, x}
assert Poly(1, x).atoms() == {S.One, x}
assert Poly(x, x).atoms() == {x}
assert Poly(x, x, y).atoms() == {x, y}
assert Poly(x + y, x, y).atoms() == {x, y}
assert Poly(x + y, x, y, z).atoms() == {x, y, z}
assert Poly(x + y*t, x, y, z).atoms() == {t, x, y, z}
assert (I*pi).atoms(NumberSymbol) == {pi}
assert (I*pi).atoms(NumberSymbol, I) == \
(I*pi).atoms(I, NumberSymbol) == {pi, I}
assert exp(exp(x)).atoms(exp) == {exp(exp(x)), exp(x)}
assert (1 + x*(2 + y) + exp(3 + z)).atoms(Add) == \
{1 + x*(2 + y) + exp(3 + z), 2 + y, 3 + z}
# issue 6132
e = (f(x) + sin(x) + 2)
assert e.atoms(AppliedUndef) == \
{f(x)}
assert e.atoms(AppliedUndef, Function) == \
{f(x), sin(x)}
assert e.atoms(Function) == \
{f(x), sin(x)}
assert e.atoms(AppliedUndef, Number) == \
{f(x), S(2)}
assert e.atoms(Function, Number) == \
{S(2), sin(x), f(x)}
def test_is_polynomial():
k = Symbol('k', nonnegative=True, integer=True)
assert Rational(2).is_polynomial(x, y, z) is True
assert (S.Pi).is_polynomial(x, y, z) is True
assert x.is_polynomial(x) is True
assert x.is_polynomial(y) is True
assert (x**2).is_polynomial(x) is True
assert (x**2).is_polynomial(y) is True
assert (x**(-2)).is_polynomial(x) is False
assert (x**(-2)).is_polynomial(y) is True
assert (2**x).is_polynomial(x) is False
assert (2**x).is_polynomial(y) is True
assert (x**k).is_polynomial(x) is False
assert (x**k).is_polynomial(k) is False
assert (x**x).is_polynomial(x) is False
assert (k**k).is_polynomial(k) is False
assert (k**x).is_polynomial(k) is False
assert (x**(-k)).is_polynomial(x) is False
assert ((2*x)**k).is_polynomial(x) is False
assert (x**2 + 3*x - 8).is_polynomial(x) is True
assert (x**2 + 3*x - 8).is_polynomial(y) is True
assert (x**2 + 3*x - 8).is_polynomial() is True
assert sqrt(x).is_polynomial(x) is False
assert (sqrt(x)**3).is_polynomial(x) is False
assert (x**2 + 3*x*sqrt(y) - 8).is_polynomial(x) is True
assert (x**2 + 3*x*sqrt(y) - 8).is_polynomial(y) is False
assert ((x**2)*(y**2) + x*(y**2) + y*x + exp(2)).is_polynomial() is True
assert ((x**2)*(y**2) + x*(y**2) + y*x + exp(x)).is_polynomial() is False
assert (
(x**2)*(y**2) + x*(y**2) + y*x + exp(2)).is_polynomial(x, y) is True
assert (
(x**2)*(y**2) + x*(y**2) + y*x + exp(x)).is_polynomial(x, y) is False
assert (1/f(x) + 1).is_polynomial(f(x)) is False
def test_is_rational_function():
assert Integer(1).is_rational_function() is True
assert Integer(1).is_rational_function(x) is True
assert Rational(17, 54).is_rational_function() is True
assert Rational(17, 54).is_rational_function(x) is True
assert (12/x).is_rational_function() is True
assert (12/x).is_rational_function(x) is True
assert (x/y).is_rational_function() is True
assert (x/y).is_rational_function(x) is True
assert (x/y).is_rational_function(x, y) is True
assert (x**2 + 1/x/y).is_rational_function() is True
assert (x**2 + 1/x/y).is_rational_function(x) is True
assert (x**2 + 1/x/y).is_rational_function(x, y) is True
assert (sin(y)/x).is_rational_function() is False
assert (sin(y)/x).is_rational_function(y) is False
assert (sin(y)/x).is_rational_function(x) is True
assert (sin(y)/x).is_rational_function(x, y) is False
for i in _illegal:
assert not i.is_rational_function()
for d in (1, x):
assert not (i/d).is_rational_function()
def test_is_meromorphic():
f = a/x**2 + b + x + c*x**2
assert f.is_meromorphic(x, 0) is True
assert f.is_meromorphic(x, 1) is True
assert f.is_meromorphic(x, zoo) is True
g = 3 + 2*x**(log(3)/log(2) - 1)
assert g.is_meromorphic(x, 0) is False
assert g.is_meromorphic(x, 1) is True
assert g.is_meromorphic(x, zoo) is False
n = Symbol('n', integer=True)
e = sin(1/x)**n*x
assert e.is_meromorphic(x, 0) is False
assert e.is_meromorphic(x, 1) is True
assert e.is_meromorphic(x, zoo) is False
e = log(x)**pi
assert e.is_meromorphic(x, 0) is False
assert e.is_meromorphic(x, 1) is False
assert e.is_meromorphic(x, 2) is True
assert e.is_meromorphic(x, zoo) is False
assert (log(x)**a).is_meromorphic(x, 0) is False
assert (log(x)**a).is_meromorphic(x, 1) is False
assert (a**log(x)).is_meromorphic(x, 0) is None
assert (3**log(x)).is_meromorphic(x, 0) is False
assert (3**log(x)).is_meromorphic(x, 1) is True
def test_is_algebraic_expr():
assert sqrt(3).is_algebraic_expr(x) is True
assert sqrt(3).is_algebraic_expr() is True
eq = ((1 + x**2)/(1 - y**2))**(S.One/3)
assert eq.is_algebraic_expr(x) is True
assert eq.is_algebraic_expr(y) is True
assert (sqrt(x) + y**(S(2)/3)).is_algebraic_expr(x) is True
assert (sqrt(x) + y**(S(2)/3)).is_algebraic_expr(y) is True
assert (sqrt(x) + y**(S(2)/3)).is_algebraic_expr() is True
assert (cos(y)/sqrt(x)).is_algebraic_expr() is False
assert (cos(y)/sqrt(x)).is_algebraic_expr(x) is True
assert (cos(y)/sqrt(x)).is_algebraic_expr(y) is False
assert (cos(y)/sqrt(x)).is_algebraic_expr(x, y) is False
def test_SAGE1():
#see https://github.com/sympy/sympy/issues/3346
class MyInt:
def _sympy_(self):
return Integer(5)
m = MyInt()
e = Rational(2)*m
assert e == 10
raises(TypeError, lambda: Rational(2)*MyInt)
def test_SAGE2():
class MyInt:
def __int__(self):
return 5
assert sympify(MyInt()) == 5
e = Rational(2)*MyInt()
assert e == 10
raises(TypeError, lambda: Rational(2)*MyInt)
def test_SAGE3():
class MySymbol:
def __rmul__(self, other):
return ('mys', other, self)
o = MySymbol()
e = x*o
assert e == ('mys', x, o)
def test_len():
e = x*y
assert len(e.args) == 2
e = x + y + z
assert len(e.args) == 3
def test_doit():
a = Integral(x**2, x)
assert isinstance(a.doit(), Integral) is False
assert isinstance(a.doit(integrals=True), Integral) is False
assert isinstance(a.doit(integrals=False), Integral) is True
assert (2*Integral(x, x)).doit() == x**2
def test_attribute_error():
raises(AttributeError, lambda: x.cos())
raises(AttributeError, lambda: x.sin())
raises(AttributeError, lambda: x.exp())
def test_args():
assert (x*y).args in ((x, y), (y, x))
assert (x + y).args in ((x, y), (y, x))
assert (x*y + 1).args in ((x*y, 1), (1, x*y))
assert sin(x*y).args == (x*y,)
assert sin(x*y).args[0] == x*y
assert (x**y).args == (x, y)
assert (x**y).args[0] == x
assert (x**y).args[1] == y
def test_noncommutative_expand_issue_3757():
A, B, C = symbols('A,B,C', commutative=False)
assert A*B - B*A != 0
assert (A*(A + B)*B).expand() == A**2*B + A*B**2
assert (A*(A + B + C)*B).expand() == A**2*B + A*B**2 + A*C*B
def test_as_numer_denom():
a, b, c = symbols('a, b, c')
assert nan.as_numer_denom() == (nan, 1)
assert oo.as_numer_denom() == (oo, 1)
assert (-oo).as_numer_denom() == (-oo, 1)
assert zoo.as_numer_denom() == (zoo, 1)
assert (-zoo).as_numer_denom() == (zoo, 1)
assert x.as_numer_denom() == (x, 1)
assert (1/x).as_numer_denom() == (1, x)
assert (x/y).as_numer_denom() == (x, y)
assert (x/2).as_numer_denom() == (x, 2)
assert (x*y/z).as_numer_denom() == (x*y, z)
assert (x/(y*z)).as_numer_denom() == (x, y*z)
assert S.Half.as_numer_denom() == (1, 2)
assert (1/y**2).as_numer_denom() == (1, y**2)
assert (x/y**2).as_numer_denom() == (x, y**2)
assert ((x**2 + 1)/y).as_numer_denom() == (x**2 + 1, y)
assert (x*(y + 1)/y**7).as_numer_denom() == (x*(y + 1), y**7)
assert (x**-2).as_numer_denom() == (1, x**2)
assert (a/x + b/2/x + c/3/x).as_numer_denom() == \
(6*a + 3*b + 2*c, 6*x)
assert (a/x + b/2/x + c/3/y).as_numer_denom() == \
(2*c*x + y*(6*a + 3*b), 6*x*y)
assert (a/x + b/2/x + c/.5/x).as_numer_denom() == \
(2*a + b + 4.0*c, 2*x)
# this should take no more than a few seconds
assert int(log(Add(*[Dummy()/i/x for i in range(1, 705)]
).as_numer_denom()[1]/x).n(4)) == 705
for i in [S.Infinity, S.NegativeInfinity, S.ComplexInfinity]:
assert (i + x/3).as_numer_denom() == \
(x + i, 3)
assert (S.Infinity + x/3 + y/4).as_numer_denom() == \
(4*x + 3*y + S.Infinity, 12)
assert (oo*x + zoo*y).as_numer_denom() == \
(zoo*y + oo*x, 1)
A, B, C = symbols('A,B,C', commutative=False)
assert (A*B*C**-1).as_numer_denom() == (A*B*C**-1, 1)
assert (A*B*C**-1/x).as_numer_denom() == (A*B*C**-1, x)
assert (C**-1*A*B).as_numer_denom() == (C**-1*A*B, 1)
assert (C**-1*A*B/x).as_numer_denom() == (C**-1*A*B, x)
assert ((A*B*C)**-1).as_numer_denom() == ((A*B*C)**-1, 1)
assert ((A*B*C)**-1/x).as_numer_denom() == ((A*B*C)**-1, x)
# the following morphs from Add to Mul during processing
assert Add(0, (x + y)/z/-2, evaluate=False).as_numer_denom(
) == (-x - y, 2*z)
def test_trunc():
import math
x, y = symbols('x y')
assert math.trunc(2) == 2
assert math.trunc(4.57) == 4
assert math.trunc(-5.79) == -5
assert math.trunc(pi) == 3
assert math.trunc(log(7)) == 1
assert math.trunc(exp(5)) == 148
assert math.trunc(cos(pi)) == -1
assert math.trunc(sin(5)) == 0
raises(TypeError, lambda: math.trunc(x))
raises(TypeError, lambda: math.trunc(x + y**2))
raises(TypeError, lambda: math.trunc(oo))
def test_as_independent():
assert S.Zero.as_independent(x, as_Add=True) == (0, 0)
assert S.Zero.as_independent(x, as_Add=False) == (0, 0)
assert (2*x*sin(x) + y + x).as_independent(x) == (y, x + 2*x*sin(x))
assert (2*x*sin(x) + y + x).as_independent(y) == (x + 2*x*sin(x), y)
assert (2*x*sin(x) + y + x).as_independent(x, y) == (0, y + x + 2*x*sin(x))
assert (x*sin(x)*cos(y)).as_independent(x) == (cos(y), x*sin(x))
assert (x*sin(x)*cos(y)).as_independent(y) == (x*sin(x), cos(y))
assert (x*sin(x)*cos(y)).as_independent(x, y) == (1, x*sin(x)*cos(y))
assert (sin(x)).as_independent(x) == (1, sin(x))
assert (sin(x)).as_independent(y) == (sin(x), 1)
assert (2*sin(x)).as_independent(x) == (2, sin(x))
assert (2*sin(x)).as_independent(y) == (2*sin(x), 1)
# issue 4903 = 1766b
n1, n2, n3 = symbols('n1 n2 n3', commutative=False)
assert (n1 + n1*n2).as_independent(n2) == (n1, n1*n2)
assert (n2*n1 + n1*n2).as_independent(n2) == (0, n1*n2 + n2*n1)
assert (n1*n2*n1).as_independent(n2) == (n1, n2*n1)
assert (n1*n2*n1).as_independent(n1) == (1, n1*n2*n1)
assert (3*x).as_independent(x, as_Add=True) == (0, 3*x)
assert (3*x).as_independent(x, as_Add=False) == (3, x)
assert (3 + x).as_independent(x, as_Add=True) == (3, x)
assert (3 + x).as_independent(x, as_Add=False) == (1, 3 + x)
# issue 5479
assert (3*x).as_independent(Symbol) == (3, x)
# issue 5648
assert (n1*x*y).as_independent(x) == (n1*y, x)
assert ((x + n1)*(x - y)).as_independent(x) == (1, (x + n1)*(x - y))
assert ((x + n1)*(x - y)).as_independent(y) == (x + n1, x - y)
assert (DiracDelta(x - n1)*DiracDelta(x - y)).as_independent(x) \
== (1, DiracDelta(x - n1)*DiracDelta(x - y))
assert (x*y*n1*n2*n3).as_independent(n2) == (x*y*n1, n2*n3)
assert (x*y*n1*n2*n3).as_independent(n1) == (x*y, n1*n2*n3)
assert (x*y*n1*n2*n3).as_independent(n3) == (x*y*n1*n2, n3)
assert (DiracDelta(x - n1)*DiracDelta(y - n1)*DiracDelta(x - n2)).as_independent(y) == \
(DiracDelta(x - n1)*DiracDelta(x - n2), DiracDelta(y - n1))
# issue 5784
assert (x + Integral(x, (x, 1, 2))).as_independent(x, strict=True) == \
(Integral(x, (x, 1, 2)), x)
eq = Add(x, -x, 2, -3, evaluate=False)
assert eq.as_independent(x) == (-1, Add(x, -x, evaluate=False))
eq = Mul(x, 1/x, 2, -3, evaluate=False)
assert eq.as_independent(x) == (-6, Mul(x, 1/x, evaluate=False))
assert (x*y).as_independent(z, as_Add=True) == (x*y, 0)
@XFAIL
def test_call_2():
# TODO UndefinedFunction does not subclass Expr
assert (2*f)(x) == 2*f(x)
def test_replace():
e = log(sin(x)) + tan(sin(x**2))
assert e.replace(sin, cos) == log(cos(x)) + tan(cos(x**2))
assert e.replace(
sin, lambda a: sin(2*a)) == log(sin(2*x)) + tan(sin(2*x**2))
a = Wild('a')
b = Wild('b')
assert e.replace(sin(a), cos(a)) == log(cos(x)) + tan(cos(x**2))
assert e.replace(
sin(a), lambda a: sin(2*a)) == log(sin(2*x)) + tan(sin(2*x**2))
# test exact
assert (2*x).replace(a*x + b, b - a, exact=True) == 2*x
assert (2*x).replace(a*x + b, b - a) == 2*x
assert (2*x).replace(a*x + b, b - a, exact=False) == 2/x
assert (2*x).replace(a*x + b, lambda a, b: b - a, exact=True) == 2*x
assert (2*x).replace(a*x + b, lambda a, b: b - a) == 2*x
assert (2*x).replace(a*x + b, lambda a, b: b - a, exact=False) == 2/x
g = 2*sin(x**3)
assert g.replace(
lambda expr: expr.is_Number, lambda expr: expr**2) == 4*sin(x**9)
assert cos(x).replace(cos, sin, map=True) == (sin(x), {cos(x): sin(x)})
assert sin(x).replace(cos, sin) == sin(x)
cond, func = lambda x: x.is_Mul, lambda x: 2*x
assert (x*y).replace(cond, func, map=True) == (2*x*y, {x*y: 2*x*y})
assert (x*(1 + x*y)).replace(cond, func, map=True) == \
(2*x*(2*x*y + 1), {x*(2*x*y + 1): 2*x*(2*x*y + 1), x*y: 2*x*y})
assert (y*sin(x)).replace(sin, lambda expr: sin(expr)/y, map=True) == \
(sin(x), {sin(x): sin(x)/y})
# if not simultaneous then y*sin(x) -> y*sin(x)/y = sin(x) -> sin(x)/y
assert (y*sin(x)).replace(sin, lambda expr: sin(expr)/y,
simultaneous=False) == sin(x)/y
assert (x**2 + O(x**3)).replace(Pow, lambda b, e: b**e/e
) == x**2/2 + O(x**3)
assert (x**2 + O(x**3)).replace(Pow, lambda b, e: b**e/e,
simultaneous=False) == x**2/2 + O(x**3)
assert (x*(x*y + 3)).replace(lambda x: x.is_Mul, lambda x: 2 + x) == \
x*(x*y + 5) + 2
e = (x*y + 1)*(2*x*y + 1) + 1
assert e.replace(cond, func, map=True) == (
2*((2*x*y + 1)*(4*x*y + 1)) + 1,
{2*x*y: 4*x*y, x*y: 2*x*y, (2*x*y + 1)*(4*x*y + 1):
2*((2*x*y + 1)*(4*x*y + 1))})
assert x.replace(x, y) == y
assert (x + 1).replace(1, 2) == x + 2
# https://groups.google.com/forum/#!topic/sympy/8wCgeC95tz0
n1, n2, n3 = symbols('n1:4', commutative=False)
assert (n1*f(n2)).replace(f, lambda x: x) == n1*n2
assert (n3*f(n2)).replace(f, lambda x: x) == n3*n2
# issue 16725
assert S.Zero.replace(Wild('x'), 1) == 1
# let the user override the default decision of False
assert S.Zero.replace(Wild('x'), 1, exact=True) == 0
def test_find():
expr = (x + y + 2 + sin(3*x))
assert expr.find(lambda u: u.is_Integer) == {S(2), S(3)}
assert expr.find(lambda u: u.is_Symbol) == {x, y}
assert expr.find(lambda u: u.is_Integer, group=True) == {S(2): 1, S(3): 1}
assert expr.find(lambda u: u.is_Symbol, group=True) == {x: 2, y: 1}
assert expr.find(Integer) == {S(2), S(3)}
assert expr.find(Symbol) == {x, y}
assert expr.find(Integer, group=True) == {S(2): 1, S(3): 1}
assert expr.find(Symbol, group=True) == {x: 2, y: 1}
a = Wild('a')
expr = sin(sin(x)) + sin(x) + cos(x) + x
assert expr.find(lambda u: type(u) is sin) == {sin(x), sin(sin(x))}
assert expr.find(
lambda u: type(u) is sin, group=True) == {sin(x): 2, sin(sin(x)): 1}
assert expr.find(sin(a)) == {sin(x), sin(sin(x))}
assert expr.find(sin(a), group=True) == {sin(x): 2, sin(sin(x)): 1}
assert expr.find(sin) == {sin(x), sin(sin(x))}
assert expr.find(sin, group=True) == {sin(x): 2, sin(sin(x)): 1}
def test_count():
expr = (x + y + 2 + sin(3*x))
assert expr.count(lambda u: u.is_Integer) == 2
assert expr.count(lambda u: u.is_Symbol) == 3
assert expr.count(Integer) == 2
assert expr.count(Symbol) == 3
assert expr.count(2) == 1
a = Wild('a')
assert expr.count(sin) == 1
assert expr.count(sin(a)) == 1
assert expr.count(lambda u: type(u) is sin) == 1
assert f(x).count(f(x)) == 1
assert f(x).diff(x).count(f(x)) == 1
assert f(x).diff(x).count(x) == 2
def test_has_basics():
p = Wild('p')
assert sin(x).has(x)
assert sin(x).has(sin)
assert not sin(x).has(y)
assert not sin(x).has(cos)
assert f(x).has(x)
assert f(x).has(f)
assert not f(x).has(y)
assert not f(x).has(g)
assert f(x).diff(x).has(x)
assert f(x).diff(x).has(f)
assert f(x).diff(x).has(Derivative)
assert not f(x).diff(x).has(y)
assert not f(x).diff(x).has(g)
assert not f(x).diff(x).has(sin)
assert (x**2).has(Symbol)
assert not (x**2).has(Wild)
assert (2*p).has(Wild)
assert not x.has()
def test_has_multiple():
f = x**2*y + sin(2**t + log(z))
assert f.has(x)
assert f.has(y)
assert f.has(z)
assert f.has(t)
assert not f.has(u)
assert f.has(x, y, z, t)
assert f.has(x, y, z, t, u)
i = Integer(4400)
assert not i.has(x)
assert (i*x**i).has(x)
assert not (i*y**i).has(x)
assert (i*y**i).has(x, y)
assert not (i*y**i).has(x, z)
def test_has_piecewise():
f = (x*y + 3/y)**(3 + 2)
p = Piecewise((g(x), x < -1), (1, x <= 1), (f, True))
assert p.has(x)
assert p.has(y)
assert not p.has(z)
assert p.has(1)
assert p.has(3)
assert not p.has(4)
assert p.has(f)
assert p.has(g)
assert not p.has(h)
def test_has_iterative():
A, B, C = symbols('A,B,C', commutative=False)
f = x*gamma(x)*sin(x)*exp(x*y)*A*B*C*cos(x*A*B)
assert f.has(x)
assert f.has(x*y)
assert f.has(x*sin(x))
assert not f.has(x*sin(y))
assert f.has(x*A)
assert f.has(x*A*B)
assert not f.has(x*A*C)
assert f.has(x*A*B*C)
assert not f.has(x*A*C*B)
assert f.has(x*sin(x)*A*B*C)
assert not f.has(x*sin(x)*A*C*B)
assert not f.has(x*sin(y)*A*B*C)
assert f.has(x*gamma(x))
assert not f.has(x + sin(x))
assert (x & y & z).has(x & z)
def test_has_integrals():
f = Integral(x**2 + sin(x*y*z), (x, 0, x + y + z))
assert f.has(x + y)
assert f.has(x + z)
assert f.has(y + z)
assert f.has(x*y)
assert f.has(x*z)
assert f.has(y*z)
assert not f.has(2*x + y)
assert not f.has(2*x*y)
def test_has_tuple():
assert Tuple(x, y).has(x)
assert not Tuple(x, y).has(z)
assert Tuple(f(x), g(x)).has(x)
assert not Tuple(f(x), g(x)).has(y)
assert Tuple(f(x), g(x)).has(f)
assert Tuple(f(x), g(x)).has(f(x))
# XXX to be deprecated
#assert not Tuple(f, g).has(x)
#assert Tuple(f, g).has(f)
#assert not Tuple(f, g).has(h)
assert Tuple(True).has(True)
assert Tuple(True).has(S.true)
assert not Tuple(True).has(1)
def test_has_units():
from sympy.physics.units import m, s
assert (x*m/s).has(x)
assert (x*m/s).has(y, z) is False
def test_has_polys():
poly = Poly(x**2 + x*y*sin(z), x, y, t)
assert poly.has(x)
assert poly.has(x, y, z)
assert poly.has(x, y, z, t)
def test_has_physics():
assert FockState((x, y)).has(x)
def test_as_poly_as_expr():
f = x**2 + 2*x*y
assert f.as_poly().as_expr() == f
assert f.as_poly(x, y).as_expr() == f
assert (f + sin(x)).as_poly(x, y) is None
p = Poly(f, x, y)
assert p.as_poly() == p
# https://github.com/sympy/sympy/issues/20610
assert S(2).as_poly() is None
assert sqrt(2).as_poly(extension=True) is None
raises(AttributeError, lambda: Tuple(x, x).as_poly(x))
raises(AttributeError, lambda: Tuple(x ** 2, x, y).as_poly(x))
def test_nonzero():
assert bool(S.Zero) is False
assert bool(S.One) is True
assert bool(x) is True
assert bool(x + y) is True
assert bool(x - x) is False
assert bool(x*y) is True
assert bool(x*1) is True
assert bool(x*0) is False
def test_is_number():
assert Float(3.14).is_number is True
assert Integer(737).is_number is True
assert Rational(3, 2).is_number is True
assert Rational(8).is_number is True
assert x.is_number is False
assert (2*x).is_number is False
assert (x + y).is_number is False
assert log(2).is_number is True
assert log(x).is_number is False
assert (2 + log(2)).is_number is True
assert (8 + log(2)).is_number is True
assert (2 + log(x)).is_number is False
assert (8 + log(2) + x).is_number is False
assert (1 + x**2/x - x).is_number is True
assert Tuple(Integer(1)).is_number is False
assert Add(2, x).is_number is False
assert Mul(3, 4).is_number is True
assert Pow(log(2), 2).is_number is True
assert oo.is_number is True
g = WildFunction('g')
assert g.is_number is False
assert (2*g).is_number is False
assert (x**2).subs(x, 3).is_number is True
# test extensibility of .is_number
# on subinstances of Basic
class A(Basic):
pass
a = A()
assert a.is_number is False
def test_as_coeff_add():
assert S(2).as_coeff_add() == (2, ())
assert S(3.0).as_coeff_add() == (0, (S(3.0),))
assert S(-3.0).as_coeff_add() == (0, (S(-3.0),))
assert x.as_coeff_add() == (0, (x,))
assert (x - 1).as_coeff_add() == (-1, (x,))
assert (x + 1).as_coeff_add() == (1, (x,))
assert (x + 2).as_coeff_add() == (2, (x,))
assert (x + y).as_coeff_add(y) == (x, (y,))
assert (3*x).as_coeff_add(y) == (3*x, ())
# don't do expansion
e = (x + y)**2
assert e.as_coeff_add(y) == (0, (e,))
def test_as_coeff_mul():
assert S(2).as_coeff_mul() == (2, ())
assert S(3.0).as_coeff_mul() == (1, (S(3.0),))
assert S(-3.0).as_coeff_mul() == (-1, (S(3.0),))
assert S(-3.0).as_coeff_mul(rational=False) == (-S(3.0), ())
assert x.as_coeff_mul() == (1, (x,))
assert (-x).as_coeff_mul() == (-1, (x,))
assert (2*x).as_coeff_mul() == (2, (x,))
assert (x*y).as_coeff_mul(y) == (x, (y,))
assert (3 + x).as_coeff_mul() == (1, (3 + x,))
assert (3 + x).as_coeff_mul(y) == (3 + x, ())
# don't do expansion
e = exp(x + y)
assert e.as_coeff_mul(y) == (1, (e,))
e = 2**(x + y)
assert e.as_coeff_mul(y) == (1, (e,))
assert (1.1*x).as_coeff_mul(rational=False) == (1.1, (x,))
assert (1.1*x).as_coeff_mul() == (1, (1.1, x))
assert (-oo*x).as_coeff_mul(rational=True) == (-1, (oo, x))
def test_as_coeff_exponent():
assert (3*x**4).as_coeff_exponent(x) == (3, 4)
assert (2*x**3).as_coeff_exponent(x) == (2, 3)
assert (4*x**2).as_coeff_exponent(x) == (4, 2)
assert (6*x**1).as_coeff_exponent(x) == (6, 1)
assert (3*x**0).as_coeff_exponent(x) == (3, 0)
assert (2*x**0).as_coeff_exponent(x) == (2, 0)
assert (1*x**0).as_coeff_exponent(x) == (1, 0)
assert (0*x**0).as_coeff_exponent(x) == (0, 0)
assert (-1*x**0).as_coeff_exponent(x) == (-1, 0)
assert (-2*x**0).as_coeff_exponent(x) == (-2, 0)
assert (2*x**3 + pi*x**3).as_coeff_exponent(x) == (2 + pi, 3)
assert (x*log(2)/(2*x + pi*x)).as_coeff_exponent(x) == \
(log(2)/(2 + pi), 0)
# issue 4784
D = Derivative
fx = D(f(x), x)
assert fx.as_coeff_exponent(f(x)) == (fx, 0)
def test_extractions():
for base in (2, S.Exp1):
assert Pow(base**x, 3, evaluate=False
).extract_multiplicatively(base**x) == base**(2*x)
assert (base**(5*x)).extract_multiplicatively(
base**(3*x)) == base**(2*x)
assert ((x*y)**3).extract_multiplicatively(x**2 * y) == x*y**2
assert ((x*y)**3).extract_multiplicatively(x**4 * y) is None
assert (2*x).extract_multiplicatively(2) == x
assert (2*x).extract_multiplicatively(3) is None
assert (2*x).extract_multiplicatively(-1) is None
assert (S.Half*x).extract_multiplicatively(3) == x/6
assert (sqrt(x)).extract_multiplicatively(x) is None
assert (sqrt(x)).extract_multiplicatively(1/x) is None
assert x.extract_multiplicatively(-x) is None
assert (-2 - 4*I).extract_multiplicatively(-2) == 1 + 2*I
assert (-2 - 4*I).extract_multiplicatively(3) is None
assert (-2*x - 4*y - 8).extract_multiplicatively(-2) == x + 2*y + 4
assert (-2*x*y - 4*x**2*y).extract_multiplicatively(-2*y) == 2*x**2 + x
assert (2*x*y + 4*x**2*y).extract_multiplicatively(2*y) == 2*x**2 + x
assert (-4*y**2*x).extract_multiplicatively(-3*y) is None
assert (2*x).extract_multiplicatively(1) == 2*x
assert (-oo).extract_multiplicatively(5) is -oo
assert (oo).extract_multiplicatively(5) is oo
assert ((x*y)**3).extract_additively(1) is None
assert (x + 1).extract_additively(x) == 1
assert (x + 1).extract_additively(2*x) is None
assert (x + 1).extract_additively(-x) is None
assert (-x + 1).extract_additively(2*x) is None
assert (2*x + 3).extract_additively(x) == x + 3
assert (2*x + 3).extract_additively(2) == 2*x + 1
assert (2*x + 3).extract_additively(3) == 2*x
assert (2*x + 3).extract_additively(-2) is None
assert (2*x + 3).extract_additively(3*x) is None
assert (2*x + 3).extract_additively(2*x) == 3
assert x.extract_additively(0) == x
assert S(2).extract_additively(x) is None
assert S(2.).extract_additively(2.) is S.Zero
assert S(2.).extract_additively(2) is S.Zero
assert S(2*x + 3).extract_additively(x + 1) == x + 2
assert S(2*x + 3).extract_additively(y + 1) is None
assert S(2*x - 3).extract_additively(x + 1) is None
assert S(2*x - 3).extract_additively(y + z) is None
assert ((a + 1)*x*4 + y).extract_additively(x).expand() == \
4*a*x + 3*x + y
assert ((a + 1)*x*4 + 3*y).extract_additively(x + 2*y).expand() == \
4*a*x + 3*x + y
assert (y*(x + 1)).extract_additively(x + 1) is None
assert ((y + 1)*(x + 1) + 3).extract_additively(x + 1) == \
y*(x + 1) + 3
assert ((x + y)*(x + 1) + x + y + 3).extract_additively(x + y) == \
x*(x + y) + 3
assert (x + y + 2*((x + y)*(x + 1)) + 3).extract_additively((x + y)*(x + 1)) == \
x + y + (x + 1)*(x + y) + 3
assert ((y + 1)*(x + 2*y + 1) + 3).extract_additively(y + 1) == \
(x + 2*y)*(y + 1) + 3
assert (-x - x*I).extract_additively(-x) == -I*x
# extraction does not leave artificats, now
assert (4*x*(y + 1) + y).extract_additively(x) == x*(4*y + 3) + y
n = Symbol("n", integer=True)
assert (Integer(-3)).could_extract_minus_sign() is True
assert (-n*x + x).could_extract_minus_sign() != \
(n*x - x).could_extract_minus_sign()
assert (x - y).could_extract_minus_sign() != \
(-x + y).could_extract_minus_sign()
assert (1 - x - y).could_extract_minus_sign() is True
assert (1 - x + y).could_extract_minus_sign() is False
assert ((-x - x*y)/y).could_extract_minus_sign() is False
assert ((x + x*y)/(-y)).could_extract_minus_sign() is True
assert ((x + x*y)/y).could_extract_minus_sign() is False
assert ((-x - y)/(x + y)).could_extract_minus_sign() is False
class sign_invariant(Function, Expr):
nargs = 1
def __neg__(self):
return self
foo = sign_invariant(x)
assert foo == -foo
assert foo.could_extract_minus_sign() is False
assert (x - y).could_extract_minus_sign() is False
assert (-x + y).could_extract_minus_sign() is True
assert (x - 1).could_extract_minus_sign() is False
assert (1 - x).could_extract_minus_sign() is True
assert (sqrt(2) - 1).could_extract_minus_sign() is True
assert (1 - sqrt(2)).could_extract_minus_sign() is False
# check that result is canonical
eq = (3*x + 15*y).extract_multiplicatively(3)
assert eq.args == eq.func(*eq.args).args
def test_nan_extractions():
for r in (1, 0, I, nan):
assert nan.extract_additively(r) is None
assert nan.extract_multiplicatively(r) is None
def test_coeff():
assert (x + 1).coeff(x + 1) == 1
assert (3*x).coeff(0) == 0
assert (z*(1 + x)*x**2).coeff(1 + x) == z*x**2
assert (1 + 2*x*x**(1 + x)).coeff(x*x**(1 + x)) == 2
assert (1 + 2*x**(y + z)).coeff(x**(y + z)) == 2
assert (3 + 2*x + 4*x**2).coeff(1) == 0
assert (3 + 2*x + 4*x**2).coeff(-1) == 0
assert (3 + 2*x + 4*x**2).coeff(x) == 2
assert (3 + 2*x + 4*x**2).coeff(x**2) == 4
assert (3 + 2*x + 4*x**2).coeff(x**3) == 0
assert (-x/8 + x*y).coeff(x) == Rational(-1, 8) + y
assert (-x/8 + x*y).coeff(-x) == S.One/8
assert (4*x).coeff(2*x) == 0
assert (2*x).coeff(2*x) == 1
assert (-oo*x).coeff(x*oo) == -1
assert (10*x).coeff(x, 0) == 0
assert (10*x).coeff(10*x, 0) == 0
n1, n2 = symbols('n1 n2', commutative=False)
assert (n1*n2).coeff(n1) == 1
assert (n1*n2).coeff(n2) == n1
assert (n1*n2 + x*n1).coeff(n1) == 1 # 1*n1*(n2+x)
assert (n2*n1 + x*n1).coeff(n1) == n2 + x
assert (n2*n1 + x*n1**2).coeff(n1) == n2
assert (n1**x).coeff(n1) == 0
assert (n1*n2 + n2*n1).coeff(n1) == 0
assert (2*(n1 + n2)*n2).coeff(n1 + n2, right=1) == n2
assert (2*(n1 + n2)*n2).coeff(n1 + n2, right=0) == 2
assert (2*f(x) + 3*f(x).diff(x)).coeff(f(x)) == 2
expr = z*(x + y)**2
expr2 = z*(x + y)**2 + z*(2*x + 2*y)**2
assert expr.coeff(z) == (x + y)**2
assert expr.coeff(x + y) == 0
assert expr2.coeff(z) == (x + y)**2 + (2*x + 2*y)**2
assert (x + y + 3*z).coeff(1) == x + y
assert (-x + 2*y).coeff(-1) == x
assert (x - 2*y).coeff(-1) == 2*y
assert (3 + 2*x + 4*x**2).coeff(1) == 0
assert (-x - 2*y).coeff(2) == -y
assert (x + sqrt(2)*x).coeff(sqrt(2)) == x
assert (3 + 2*x + 4*x**2).coeff(x) == 2
assert (3 + 2*x + 4*x**2).coeff(x**2) == 4
assert (3 + 2*x + 4*x**2).coeff(x**3) == 0
assert (z*(x + y)**2).coeff((x + y)**2) == z
assert (z*(x + y)**2).coeff(x + y) == 0
assert (2 + 2*x + (x + 1)*y).coeff(x + 1) == y
assert (x + 2*y + 3).coeff(1) == x
assert (x + 2*y + 3).coeff(x, 0) == 2*y + 3
assert (x**2 + 2*y + 3*x).coeff(x**2, 0) == 2*y + 3*x
assert x.coeff(0, 0) == 0
assert x.coeff(x, 0) == 0
n, m, o, l = symbols('n m o l', commutative=False)
assert n.coeff(n) == 1
assert y.coeff(n) == 0
assert (3*n).coeff(n) == 3
assert (2 + n).coeff(x*m) == 0
assert (2*x*n*m).coeff(x) == 2*n*m
assert (2 + n).coeff(x*m*n + y) == 0
assert (2*x*n*m).coeff(3*n) == 0
assert (n*m + m*n*m).coeff(n) == 1 + m
assert (n*m + m*n*m).coeff(n, right=True) == m # = (1 + m)*n*m
assert (n*m + m*n).coeff(n) == 0
assert (n*m + o*m*n).coeff(m*n) == o
assert (n*m + o*m*n).coeff(m*n, right=True) == 1
assert (n*m + n*m*n).coeff(n*m, right=True) == 1 + n # = n*m*(n + 1)
assert (x*y).coeff(z, 0) == x*y
assert (x*n + y*n + z*m).coeff(n) == x + y
assert (n*m + n*o + o*l).coeff(n, right=True) == m + o
assert (x*n*m*n + y*n*m*o + z*l).coeff(m, right=True) == x*n + y*o
assert (x*n*m*n + x*n*m*o + z*l).coeff(m, right=True) == n + o
assert (x*n*m*n + x*n*m*o + z*l).coeff(m) == x*n
def test_coeff2():
r, kappa = symbols('r, kappa')
psi = Function("psi")
g = 1/r**2 * (2*r*psi(r).diff(r, 1) + r**2 * psi(r).diff(r, 2))
g = g.expand()
assert g.coeff(psi(r).diff(r)) == 2/r
def test_coeff2_0():
r, kappa = symbols('r, kappa')
psi = Function("psi")
g = 1/r**2 * (2*r*psi(r).diff(r, 1) + r**2 * psi(r).diff(r, 2))
g = g.expand()
assert g.coeff(psi(r).diff(r, 2)) == 1
def test_coeff_expand():
expr = z*(x + y)**2
expr2 = z*(x + y)**2 + z*(2*x + 2*y)**2
assert expr.coeff(z) == (x + y)**2
assert expr2.coeff(z) == (x + y)**2 + (2*x + 2*y)**2
def test_integrate():
assert x.integrate(x) == x**2/2
assert x.integrate((x, 0, 1)) == S.Half
def test_as_base_exp():
assert x.as_base_exp() == (x, S.One)
assert (x*y*z).as_base_exp() == (x*y*z, S.One)
assert (x + y + z).as_base_exp() == (x + y + z, S.One)
assert ((x + y)**z).as_base_exp() == (x + y, z)
def test_issue_4963():
assert hasattr(Mul(x, y), "is_commutative")
assert hasattr(Mul(x, y, evaluate=False), "is_commutative")
assert hasattr(Pow(x, y), "is_commutative")
assert hasattr(Pow(x, y, evaluate=False), "is_commutative")
expr = Mul(Pow(2, 2, evaluate=False), 3, evaluate=False) + 1
assert hasattr(expr, "is_commutative")
def test_action_verbs():
assert nsimplify(1/(exp(3*pi*x/5) + 1)) == \
(1/(exp(3*pi*x/5) + 1)).nsimplify()
assert ratsimp(1/x + 1/y) == (1/x + 1/y).ratsimp()
assert trigsimp(log(x), deep=True) == (log(x)).trigsimp(deep=True)
assert radsimp(1/(2 + sqrt(2))) == (1/(2 + sqrt(2))).radsimp()
assert radsimp(1/(a + b*sqrt(c)), symbolic=False) == \
(1/(a + b*sqrt(c))).radsimp(symbolic=False)
assert powsimp(x**y*x**z*y**z, combine='all') == \
(x**y*x**z*y**z).powsimp(combine='all')
assert (x**t*y**t).powsimp(force=True) == (x*y)**t
assert simplify(x**y*x**z*y**z) == (x**y*x**z*y**z).simplify()
assert together(1/x + 1/y) == (1/x + 1/y).together()
assert collect(a*x**2 + b*x**2 + a*x - b*x + c, x) == \
(a*x**2 + b*x**2 + a*x - b*x + c).collect(x)
assert apart(y/(y + 2)/(y + 1), y) == (y/(y + 2)/(y + 1)).apart(y)
assert combsimp(y/(x + 2)/(x + 1)) == (y/(x + 2)/(x + 1)).combsimp()
assert gammasimp(gamma(x)/gamma(x-5)) == (gamma(x)/gamma(x-5)).gammasimp()
assert factor(x**2 + 5*x + 6) == (x**2 + 5*x + 6).factor()
assert refine(sqrt(x**2)) == sqrt(x**2).refine()
assert cancel((x**2 + 5*x + 6)/(x + 2)) == ((x**2 + 5*x + 6)/(x + 2)).cancel()
def test_as_powers_dict():
assert x.as_powers_dict() == {x: 1}
assert (x**y*z).as_powers_dict() == {x: y, z: 1}
assert Mul(2, 2, evaluate=False).as_powers_dict() == {S(2): S(2)}
assert (x*y).as_powers_dict()[z] == 0
assert (x + y).as_powers_dict()[z] == 0
def test_as_coefficients_dict():
check = [S.One, x, y, x*y, 1]
assert [Add(3*x, 2*x, y, 3).as_coefficients_dict()[i] for i in check] == \
[3, 5, 1, 0, 3]
assert [Add(3*x, 2*x, y, 3, evaluate=False).as_coefficients_dict()[i]
for i in check] == [3, 5, 1, 0, 3]
assert [(3*x*y).as_coefficients_dict()[i] for i in check] == \
[0, 0, 0, 3, 0]
assert [(3.0*x*y).as_coefficients_dict()[i] for i in check] == \
[0, 0, 0, 3.0, 0]
assert (3.0*x*y).as_coefficients_dict()[3.0*x*y] == 0
eq = x*(x + 1)*a + x*b + c/x
assert eq.as_coefficients_dict(x) == {x: b, 1/x: c,
x*(x + 1): a}
assert eq.expand().as_coefficients_dict(x) == {x**2: a, x: a + b, 1/x: c}
assert x.as_coefficients_dict() == {x: S.One}
def test_args_cnc():
A = symbols('A', commutative=False)
assert (x + A).args_cnc() == \
[[], [x + A]]
assert (x + a).args_cnc() == \
[[a + x], []]
assert (x*a).args_cnc() == \
[[a, x], []]
assert (x*y*A*(A + 1)).args_cnc(cset=True) == \
[{x, y}, [A, 1 + A]]
assert Mul(x, x, evaluate=False).args_cnc(cset=True, warn=False) == \
[{x}, []]
assert Mul(x, x**2, evaluate=False).args_cnc(cset=True, warn=False) == \
[{x, x**2}, []]
raises(ValueError, lambda: Mul(x, x, evaluate=False).args_cnc(cset=True))
assert Mul(x, y, x, evaluate=False).args_cnc() == \
[[x, y, x], []]
# always split -1 from leading number
assert (-1.*x).args_cnc() == [[-1, 1.0, x], []]
def test_new_rawargs():
n = Symbol('n', commutative=False)
a = x + n
assert a.is_commutative is False
assert a._new_rawargs(x).is_commutative
assert a._new_rawargs(x, y).is_commutative
assert a._new_rawargs(x, n).is_commutative is False
assert a._new_rawargs(x, y, n).is_commutative is False
m = x*n
assert m.is_commutative is False
assert m._new_rawargs(x).is_commutative
assert m._new_rawargs(n).is_commutative is False
assert m._new_rawargs(x, y).is_commutative
assert m._new_rawargs(x, n).is_commutative is False
assert m._new_rawargs(x, y, n).is_commutative is False
assert m._new_rawargs(x, n, reeval=False).is_commutative is False
assert m._new_rawargs(S.One) is S.One
def test_issue_5226():
assert Add(evaluate=False) == 0
assert Mul(evaluate=False) == 1
assert Mul(x + y, evaluate=False).is_Add
def test_free_symbols():
# free_symbols should return the free symbols of an object
assert S.One.free_symbols == set()
assert x.free_symbols == {x}
assert Integral(x, (x, 1, y)).free_symbols == {y}
assert (-Integral(x, (x, 1, y))).free_symbols == {y}
assert meter.free_symbols == set()
assert (meter**x).free_symbols == {x}
def test_has_free():
assert x.has_free(x)
assert not x.has_free(y)
assert (x + y).has_free(x)
assert (x + y).has_free(*(x, z))
assert f(x).has_free(x)
assert f(x).has_free(f(x))
assert Integral(f(x), (f(x), 1, y)).has_free(y)
assert not Integral(f(x), (f(x), 1, y)).has_free(x)
assert not Integral(f(x), (f(x), 1, y)).has_free(f(x))
# simple extraction
assert (x + 1 + y).has_free(x + 1)
assert not (x + 2 + y).has_free(x + 1)
assert (2 + 3*x*y).has_free(3*x)
raises(TypeError, lambda: x.has_free({x, y}))
s = FiniteSet(1, 2)
assert Piecewise((s, x > 3), (4, True)).has_free(s)
assert not Piecewise((1, x > 3), (4, True)).has_free(s)
# can't make set of these, but fallback will handle
raises(TypeError, lambda: x.has_free(y, []))
def test_has_xfree():
assert (x + 1).has_xfree({x})
assert ((x + 1)**2).has_xfree({x + 1})
assert not (x + y + 1).has_xfree({x + 1})
raises(TypeError, lambda: x.has_xfree(x))
raises(TypeError, lambda: x.has_xfree([x]))
def test_issue_5300():
x = Symbol('x', commutative=False)
assert x*sqrt(2)/sqrt(6) == x*sqrt(3)/3
def test_floordiv():
from sympy.functions.elementary.integers import floor
assert x // y == floor(x / y)
def test_as_coeff_Mul():
assert Integer(3).as_coeff_Mul() == (Integer(3), Integer(1))
assert Rational(3, 4).as_coeff_Mul() == (Rational(3, 4), Integer(1))
assert Float(5.0).as_coeff_Mul() == (Float(5.0), Integer(1))
assert (Integer(3)*x).as_coeff_Mul() == (Integer(3), x)
assert (Rational(3, 4)*x).as_coeff_Mul() == (Rational(3, 4), x)
assert (Float(5.0)*x).as_coeff_Mul() == (Float(5.0), x)
assert (Integer(3)*x*y).as_coeff_Mul() == (Integer(3), x*y)
assert (Rational(3, 4)*x*y).as_coeff_Mul() == (Rational(3, 4), x*y)
assert (Float(5.0)*x*y).as_coeff_Mul() == (Float(5.0), x*y)
assert (x).as_coeff_Mul() == (S.One, x)
assert (x*y).as_coeff_Mul() == (S.One, x*y)
assert (-oo*x).as_coeff_Mul(rational=True) == (-1, oo*x)
def test_as_coeff_Add():
assert Integer(3).as_coeff_Add() == (Integer(3), Integer(0))
assert Rational(3, 4).as_coeff_Add() == (Rational(3, 4), Integer(0))
assert Float(5.0).as_coeff_Add() == (Float(5.0), Integer(0))
assert (Integer(3) + x).as_coeff_Add() == (Integer(3), x)
assert (Rational(3, 4) + x).as_coeff_Add() == (Rational(3, 4), x)
assert (Float(5.0) + x).as_coeff_Add() == (Float(5.0), x)
assert (Float(5.0) + x).as_coeff_Add(rational=True) == (0, Float(5.0) + x)
assert (Integer(3) + x + y).as_coeff_Add() == (Integer(3), x + y)
assert (Rational(3, 4) + x + y).as_coeff_Add() == (Rational(3, 4), x + y)
assert (Float(5.0) + x + y).as_coeff_Add() == (Float(5.0), x + y)
assert (x).as_coeff_Add() == (S.Zero, x)
assert (x*y).as_coeff_Add() == (S.Zero, x*y)
def test_expr_sorting():
exprs = [1/x**2, 1/x, sqrt(sqrt(x)), sqrt(x), x, sqrt(x)**3, x**2]
assert sorted(exprs, key=default_sort_key) == exprs
exprs = [x, 2*x, 2*x**2, 2*x**3, x**n, 2*x**n, sin(x), sin(x)**n,
sin(x**2), cos(x), cos(x**2), tan(x)]
assert sorted(exprs, key=default_sort_key) == exprs
exprs = [x + 1, x**2 + x + 1, x**3 + x**2 + x + 1]
assert sorted(exprs, key=default_sort_key) == exprs
exprs = [S(4), x - 3*I/2, x + 3*I/2, x - 4*I + 1, x + 4*I + 1]
assert sorted(exprs, key=default_sort_key) == exprs
exprs = [f(1), f(2), f(3), f(1, 2, 3), g(1), g(2), g(3), g(1, 2, 3)]
assert sorted(exprs, key=default_sort_key) == exprs
exprs = [f(x), g(x), exp(x), sin(x), cos(x), factorial(x)]
assert sorted(exprs, key=default_sort_key) == exprs
exprs = [Tuple(x, y), Tuple(x, z), Tuple(x, y, z)]
assert sorted(exprs, key=default_sort_key) == exprs
exprs = [[3], [1, 2]]
assert sorted(exprs, key=default_sort_key) == exprs
exprs = [[1, 2], [2, 3]]
assert sorted(exprs, key=default_sort_key) == exprs
exprs = [[1, 2], [1, 2, 3]]
assert sorted(exprs, key=default_sort_key) == exprs
exprs = [{x: -y}, {x: y}]
assert sorted(exprs, key=default_sort_key) == exprs
exprs = [{1}, {1, 2}]
assert sorted(exprs, key=default_sort_key) == exprs
a, b = exprs = [Dummy('x'), Dummy('x')]
assert sorted([b, a], key=default_sort_key) == exprs
def test_as_ordered_factors():
assert x.as_ordered_factors() == [x]
assert (2*x*x**n*sin(x)*cos(x)).as_ordered_factors() \
== [Integer(2), x, x**n, sin(x), cos(x)]
args = [f(1), f(2), f(3), f(1, 2, 3), g(1), g(2), g(3), g(1, 2, 3)]
expr = Mul(*args)
assert expr.as_ordered_factors() == args
A, B = symbols('A,B', commutative=False)
assert (A*B).as_ordered_factors() == [A, B]
assert (B*A).as_ordered_factors() == [B, A]
def test_as_ordered_terms():
assert x.as_ordered_terms() == [x]
assert (sin(x)**2*cos(x) + sin(x)*cos(x)**2 + 1).as_ordered_terms() \
== [sin(x)**2*cos(x), sin(x)*cos(x)**2, 1]
args = [f(1), f(2), f(3), f(1, 2, 3), g(1), g(2), g(3), g(1, 2, 3)]
expr = Add(*args)
assert expr.as_ordered_terms() == args
assert (1 + 4*sqrt(3)*pi*x).as_ordered_terms() == [4*pi*x*sqrt(3), 1]
assert ( 2 + 3*I).as_ordered_terms() == [2, 3*I]
assert (-2 + 3*I).as_ordered_terms() == [-2, 3*I]
assert ( 2 - 3*I).as_ordered_terms() == [2, -3*I]
assert (-2 - 3*I).as_ordered_terms() == [-2, -3*I]
assert ( 4 + 3*I).as_ordered_terms() == [4, 3*I]
assert (-4 + 3*I).as_ordered_terms() == [-4, 3*I]
assert ( 4 - 3*I).as_ordered_terms() == [4, -3*I]
assert (-4 - 3*I).as_ordered_terms() == [-4, -3*I]
e = x**2*y**2 + x*y**4 + y + 2
assert e.as_ordered_terms(order="lex") == [x**2*y**2, x*y**4, y, 2]
assert e.as_ordered_terms(order="grlex") == [x*y**4, x**2*y**2, y, 2]
assert e.as_ordered_terms(order="rev-lex") == [2, y, x*y**4, x**2*y**2]
assert e.as_ordered_terms(order="rev-grlex") == [2, y, x**2*y**2, x*y**4]
k = symbols('k')
assert k.as_ordered_terms(data=True) == ([(k, ((1.0, 0.0), (1,), ()))], [k])
def test_sort_key_atomic_expr():
from sympy.physics.units import m, s
assert sorted([-m, s], key=lambda arg: arg.sort_key()) == [-m, s]
def test_eval_interval():
assert exp(x)._eval_interval(*Tuple(x, 0, 1)) == exp(1) - exp(0)
# issue 4199
a = x/y
raises(NotImplementedError, lambda: a._eval_interval(x, S.Zero, oo)._eval_interval(y, oo, S.Zero))
raises(NotImplementedError, lambda: a._eval_interval(x, S.Zero, oo)._eval_interval(y, S.Zero, oo))
a = x - y
raises(NotImplementedError, lambda: a._eval_interval(x, S.One, oo)._eval_interval(y, oo, S.One))
raises(ValueError, lambda: x._eval_interval(x, None, None))
a = -y*Heaviside(x - y)
assert a._eval_interval(x, -oo, oo) == -y
assert a._eval_interval(x, oo, -oo) == y
def test_eval_interval_zoo():
# Test that limit is used when zoo is returned
assert Si(1/x)._eval_interval(x, S.Zero, S.One) == -pi/2 + Si(1)
def test_primitive():
assert (3*(x + 1)**2).primitive() == (3, (x + 1)**2)
assert (6*x + 2).primitive() == (2, 3*x + 1)
assert (x/2 + 3).primitive() == (S.Half, x + 6)
eq = (6*x + 2)*(x/2 + 3)
assert eq.primitive()[0] == 1
eq = (2 + 2*x)**2
assert eq.primitive()[0] == 1
assert (4.0*x).primitive() == (1, 4.0*x)
assert (4.0*x + y/2).primitive() == (S.Half, 8.0*x + y)
assert (-2*x).primitive() == (2, -x)
assert Add(5*z/7, 0.5*x, 3*y/2, evaluate=False).primitive() == \
(S.One/14, 7.0*x + 21*y + 10*z)
for i in [S.Infinity, S.NegativeInfinity, S.ComplexInfinity]:
assert (i + x/3).primitive() == \
(S.One/3, i + x)
assert (S.Infinity + 2*x/3 + 4*y/7).primitive() == \
(S.One/21, 14*x + 12*y + oo)
assert S.Zero.primitive() == (S.One, S.Zero)
def test_issue_5843():
a = 1 + x
assert (2*a).extract_multiplicatively(a) == 2
assert (4*a).extract_multiplicatively(2*a) == 2
assert ((3*a)*(2*a)).extract_multiplicatively(a) == 6*a
def test_is_constant():
from sympy.solvers.solvers import checksol
assert Sum(x, (x, 1, 10)).is_constant() is True
assert Sum(x, (x, 1, n)).is_constant() is False
assert Sum(x, (x, 1, n)).is_constant(y) is True
assert Sum(x, (x, 1, n)).is_constant(n) is False
assert Sum(x, (x, 1, n)).is_constant(x) is True
eq = a*cos(x)**2 + a*sin(x)**2 - a
assert eq.is_constant() is True
assert eq.subs({x: pi, a: 2}) == eq.subs({x: pi, a: 3}) == 0
assert x.is_constant() is False
assert x.is_constant(y) is True
assert log(x/y).is_constant() is False
assert checksol(x, x, Sum(x, (x, 1, n))) is False
assert checksol(x, x, Sum(x, (x, 1, n))) is False
assert f(1).is_constant
assert checksol(x, x, f(x)) is False
assert Pow(x, S.Zero, evaluate=False).is_constant() is True # == 1
assert Pow(S.Zero, x, evaluate=False).is_constant() is False # == 0 or 1
assert (2**x).is_constant() is False
assert Pow(S(2), S(3), evaluate=False).is_constant() is True
z1, z2 = symbols('z1 z2', zero=True)
assert (z1 + 2*z2).is_constant() is True
assert meter.is_constant() is True
assert (3*meter).is_constant() is True
assert (x*meter).is_constant() is False
def test_equals():
assert (-3 - sqrt(5) + (-sqrt(10)/2 - sqrt(2)/2)**2).equals(0)
assert (x**2 - 1).equals((x + 1)*(x - 1))
assert (cos(x)**2 + sin(x)**2).equals(1)
assert (a*cos(x)**2 + a*sin(x)**2).equals(a)
r = sqrt(2)
assert (-1/(r + r*x) + 1/r/(1 + x)).equals(0)
assert factorial(x + 1).equals((x + 1)*factorial(x))
assert sqrt(3).equals(2*sqrt(3)) is False
assert (sqrt(5)*sqrt(3)).equals(sqrt(3)) is False
assert (sqrt(5) + sqrt(3)).equals(0) is False
assert (sqrt(5) + pi).equals(0) is False
assert meter.equals(0) is False
assert (3*meter**2).equals(0) is False
eq = -(-1)**(S(3)/4)*6**(S.One/4) + (-6)**(S.One/4)*I
if eq != 0: # if canonicalization makes this zero, skip the test
assert eq.equals(0)
assert sqrt(x).equals(0) is False
# from integrate(x*sqrt(1 + 2*x), x);
# diff is zero only when assumptions allow
i = 2*sqrt(2)*x**(S(5)/2)*(1 + 1/(2*x))**(S(5)/2)/5 + \
2*sqrt(2)*x**(S(3)/2)*(1 + 1/(2*x))**(S(5)/2)/(-6 - 3/x)
ans = sqrt(2*x + 1)*(6*x**2 + x - 1)/15
diff = i - ans
assert diff.equals(0) is None # should be False, but previously this was False due to wrong intermediate result
assert diff.subs(x, Rational(-1, 2)/2) == 7*sqrt(2)/120
# there are regions for x for which the expression is True, for
# example, when x < -1/2 or x > 0 the expression is zero
p = Symbol('p', positive=True)
assert diff.subs(x, p).equals(0) is True
assert diff.subs(x, -1).equals(0) is True
# prove via minimal_polynomial or self-consistency
eq = sqrt(1 + sqrt(3)) + sqrt(3 + 3*sqrt(3)) - sqrt(10 + 6*sqrt(3))
assert eq.equals(0)
q = 3**Rational(1, 3) + 3
p = expand(q**3)**Rational(1, 3)
assert (p - q).equals(0)
# issue 6829
# eq = q*x + q/4 + x**4 + x**3 + 2*x**2 - S.One/3
# z = eq.subs(x, solve(eq, x)[0])
q = symbols('q')
z = (q*(-sqrt(-2*(-(q - S(7)/8)**S(2)/8 - S(2197)/13824)**(S.One/3) -
S(13)/12)/2 - sqrt((2*q - S(7)/4)/sqrt(-2*(-(q - S(7)/8)**S(2)/8 -
S(2197)/13824)**(S.One/3) - S(13)/12) + 2*(-(q - S(7)/8)**S(2)/8 -
S(2197)/13824)**(S.One/3) - S(13)/6)/2 - S.One/4) + q/4 + (-sqrt(-2*(-(q
- S(7)/8)**S(2)/8 - S(2197)/13824)**(S.One/3) - S(13)/12)/2 - sqrt((2*q
- S(7)/4)/sqrt(-2*(-(q - S(7)/8)**S(2)/8 - S(2197)/13824)**(S.One/3) -
S(13)/12) + 2*(-(q - S(7)/8)**S(2)/8 - S(2197)/13824)**(S.One/3) -
S(13)/6)/2 - S.One/4)**4 + (-sqrt(-2*(-(q - S(7)/8)**S(2)/8 -
S(2197)/13824)**(S.One/3) - S(13)/12)/2 - sqrt((2*q -
S(7)/4)/sqrt(-2*(-(q - S(7)/8)**S(2)/8 - S(2197)/13824)**(S.One/3) -
S(13)/12) + 2*(-(q - S(7)/8)**S(2)/8 - S(2197)/13824)**(S.One/3) -
S(13)/6)/2 - S.One/4)**3 + 2*(-sqrt(-2*(-(q - S(7)/8)**S(2)/8 -
S(2197)/13824)**(S.One/3) - S(13)/12)/2 - sqrt((2*q -
S(7)/4)/sqrt(-2*(-(q - S(7)/8)**S(2)/8 - S(2197)/13824)**(S.One/3) -
S(13)/12) + 2*(-(q - S(7)/8)**S(2)/8 - S(2197)/13824)**(S.One/3) -
S(13)/6)/2 - S.One/4)**2 - Rational(1, 3))
assert z.equals(0)
def test_random():
from sympy.functions.combinatorial.numbers import lucas
from sympy.simplify.simplify import posify
assert posify(x)[0]._random() is not None
assert lucas(n)._random(2, -2, 0, -1, 1) is None
# issue 8662
assert Piecewise((Max(x, y), z))._random() is None
def test_round():
assert str(Float('0.1249999').round(2)) == '0.12'
d20 = 12345678901234567890
ans = S(d20).round(2)
assert ans.is_Integer and ans == d20
ans = S(d20).round(-2)
assert ans.is_Integer and ans == 12345678901234567900
assert str(S('1/7').round(4)) == '0.1429'
assert str(S('.[12345]').round(4)) == '0.1235'
assert str(S('.1349').round(2)) == '0.13'
n = S(12345)
ans = n.round()
assert ans.is_Integer
assert ans == n
ans = n.round(1)
assert ans.is_Integer
assert ans == n
ans = n.round(4)
assert ans.is_Integer
assert ans == n
assert n.round(-1) == 12340
r = Float(str(n)).round(-4)
assert r == 10000.0
assert n.round(-5) == 0
assert str((pi + sqrt(2)).round(2)) == '4.56'
assert (10*(pi + sqrt(2))).round(-1) == 50.0
raises(TypeError, lambda: round(x + 2, 2))
assert str(S(2.3).round(1)) == '2.3'
# rounding in SymPy (as in Decimal) should be
# exact for the given precision; we check here
# that when a 5 follows the last digit that
# the rounded digit will be even.
for i in range(-99, 100):
# construct a decimal that ends in 5, e.g. 123 -> 0.1235
s = str(abs(i))
p = len(s) # we are going to round to the last digit of i
n = '0.%s5' % s # put a 5 after i's digits
j = p + 2 # 2 for '0.'
if i < 0: # 1 for '-'
j += 1
n = '-' + n
v = str(Float(n).round(p))[:j] # pertinent digits
if v.endswith('.'):
continue # it ends with 0 which is even
L = int(v[-1]) # last digit
assert L % 2 == 0, (n, '->', v)
assert (Float(.3, 3) + 2*pi).round() == 7
assert (Float(.3, 3) + 2*pi*100).round() == 629
assert (pi + 2*E*I).round() == 3 + 5*I
# don't let request for extra precision give more than
# what is known (in this case, only 3 digits)
assert str((Float(.03, 3) + 2*pi/100).round(5)) == '0.0928'
assert str((Float(.03, 3) + 2*pi/100).round(4)) == '0.0928'
assert S.Zero.round() == 0
a = (Add(1, Float('1.' + '9'*27, ''), evaluate=0))
assert a.round(10) == Float('3.000000000000000000000000000', '')
assert a.round(25) == Float('3.000000000000000000000000000', '')
assert a.round(26) == Float('3.000000000000000000000000000', '')
assert a.round(27) == Float('2.999999999999999999999999999', '')
assert a.round(30) == Float('2.999999999999999999999999999', '')
# XXX: Should round set the precision of the result?
# The previous version of the tests above is this but they only pass
# because Floats with unequal precision compare equal:
#
# assert a.round(10) == Float('3.0000000000', '')
# assert a.round(25) == Float('3.0000000000000000000000000', '')
# assert a.round(26) == Float('3.00000000000000000000000000', '')
# assert a.round(27) == Float('2.999999999999999999999999999', '')
# assert a.round(30) == Float('2.999999999999999999999999999', '')
raises(TypeError, lambda: x.round())
raises(TypeError, lambda: f(1).round())
# exact magnitude of 10
assert str(S.One.round()) == '1'
assert str(S(100).round()) == '100'
# applied to real and imaginary portions
assert (2*pi + E*I).round() == 6 + 3*I
assert (2*pi + I/10).round() == 6
assert (pi/10 + 2*I).round() == 2*I
# the lhs re and im parts are Float with dps of 2
# and those on the right have dps of 15 so they won't compare
# equal unless we use string or compare components (which will
# then coerce the floats to the same precision) or re-create
# the floats
assert str((pi/10 + E*I).round(2)) == '0.31 + 2.72*I'
assert str((pi/10 + E*I).round(2).as_real_imag()) == '(0.31, 2.72)'
assert str((pi/10 + E*I).round(2)) == '0.31 + 2.72*I'
# issue 6914
assert (I**(I + 3)).round(3) == Float('-0.208', '')*I
# issue 8720
assert S(-123.6).round() == -124
assert S(-1.5).round() == -2
assert S(-100.5).round() == -100
assert S(-1.5 - 10.5*I).round() == -2 - 10*I
# issue 7961
assert str(S(0.006).round(2)) == '0.01'
assert str(S(0.00106).round(4)) == '0.0011'
# issue 8147
assert S.NaN.round() is S.NaN
assert S.Infinity.round() is S.Infinity
assert S.NegativeInfinity.round() is S.NegativeInfinity
assert S.ComplexInfinity.round() is S.ComplexInfinity
# check that types match
for i in range(2):
fi = float(i)
# 2 args
assert all(type(round(i, p)) is int for p in (-1, 0, 1))
assert all(S(i).round(p).is_Integer for p in (-1, 0, 1))
assert all(type(round(fi, p)) is float for p in (-1, 0, 1))
assert all(S(fi).round(p).is_Float for p in (-1, 0, 1))
# 1 arg (p is None)
assert type(round(i)) is int
assert S(i).round().is_Integer
assert type(round(fi)) is int
assert S(fi).round().is_Integer
def test_held_expression_UnevaluatedExpr():
x = symbols("x")
he = UnevaluatedExpr(1/x)
e1 = x*he
assert isinstance(e1, Mul)
assert e1.args == (x, he)
assert e1.doit() == 1
assert UnevaluatedExpr(Derivative(x, x)).doit(deep=False
) == Derivative(x, x)
assert UnevaluatedExpr(Derivative(x, x)).doit() == 1
xx = Mul(x, x, evaluate=False)
assert xx != x**2
ue2 = UnevaluatedExpr(xx)
assert isinstance(ue2, UnevaluatedExpr)
assert ue2.args == (xx,)
assert ue2.doit() == x**2
assert ue2.doit(deep=False) == xx
x2 = UnevaluatedExpr(2)*2
assert type(x2) is Mul
assert x2.args == (2, UnevaluatedExpr(2))
def test_round_exception_nostr():
# Don't use the string form of the expression in the round exception, as
# it's too slow
s = Symbol('bad')
try:
s.round()
except TypeError as e:
assert 'bad' not in str(e)
else:
# Did not raise
raise AssertionError("Did not raise")
def test_extract_branch_factor():
assert exp_polar(2.0*I*pi).extract_branch_factor() == (1, 1)
def test_identity_removal():
assert Add.make_args(x + 0) == (x,)
assert Mul.make_args(x*1) == (x,)
def test_float_0():
assert Float(0.0) + 1 == Float(1.0)
@XFAIL
def test_float_0_fail():
assert Float(0.0)*x == Float(0.0)
assert (x + Float(0.0)).is_Add
def test_issue_6325():
ans = (b**2 + z**2 - (b*(a + b*t) + z*(c + t*z))**2/(
(a + b*t)**2 + (c + t*z)**2))/sqrt((a + b*t)**2 + (c + t*z)**2)
e = sqrt((a + b*t)**2 + (c + z*t)**2)
assert diff(e, t, 2) == ans
assert e.diff(t, 2) == ans
assert diff(e, t, 2, simplify=False) != ans
def test_issue_7426():
f1 = a % c
f2 = x % z
assert f1.equals(f2) is None
def test_issue_11122():
x = Symbol('x', extended_positive=False)
assert unchanged(Gt, x, 0) # (x > 0)
# (x > 0) should remain unevaluated after PR #16956
x = Symbol('x', positive=False, real=True)
assert (x > 0) is S.false
def test_issue_10651():
x = Symbol('x', real=True)
e1 = (-1 + x)/(1 - x)
e3 = (4*x**2 - 4)/((1 - x)*(1 + x))
e4 = 1/(cos(x)**2) - (tan(x))**2
x = Symbol('x', positive=True)
e5 = (1 + x)/x
assert e1.is_constant() is None
assert e3.is_constant() is None
assert e4.is_constant() is None
assert e5.is_constant() is False
def test_issue_10161():
x = symbols('x', real=True)
assert x*abs(x)*abs(x) == x**3
def test_issue_10755():
x = symbols('x')
raises(TypeError, lambda: int(log(x)))
raises(TypeError, lambda: log(x).round(2))
def test_issue_11877():
x = symbols('x')
assert integrate(log(S.Half - x), (x, 0, S.Half)) == Rational(-1, 2) -log(2)/2
def test_normal():
x = symbols('x')
e = Mul(S.Half, 1 + x, evaluate=False)
assert e.normal() == e
def test_expr():
x = symbols('x')
raises(TypeError, lambda: tan(x).series(x, 2, oo, "+"))
def test_ExprBuilder():
eb = ExprBuilder(Mul)
eb.args.extend([x, x])
assert eb.build() == x**2
def test_issue_22020():
from sympy.parsing.sympy_parser import parse_expr
x = parse_expr("log((2*V/3-V)/C)/-(R+r)*C")
y = parse_expr("log((2*V/3-V)/C)/-(R+r)*2")
assert x.equals(y) is False
def test_non_string_equality():
# Expressions should not compare equal to strings
x = symbols('x')
one = sympify(1)
assert (x == 'x') is False
assert (x != 'x') is True
assert (one == '1') is False
assert (one != '1') is True
assert (x + 1 == 'x + 1') is False
assert (x + 1 != 'x + 1') is True
# Make sure == doesn't try to convert the resulting expression to a string
# (e.g., by calling sympify() instead of _sympify())
class BadRepr:
def __repr__(self):
raise RuntimeError
assert (x == BadRepr()) is False
assert (x != BadRepr()) is True
def test_21494():
from sympy.testing.pytest import warns_deprecated_sympy
with warns_deprecated_sympy():
assert x.expr_free_symbols == {x}
with warns_deprecated_sympy():
assert Basic().expr_free_symbols == set()
with warns_deprecated_sympy():
assert S(2).expr_free_symbols == {S(2)}
with warns_deprecated_sympy():
assert Indexed("A", x).expr_free_symbols == {Indexed("A", x)}
with warns_deprecated_sympy():
assert Subs(x, x, 0).expr_free_symbols == set()
def test_Expr__eq__iterable_handling():
assert x != range(3)
def test_format():
assert '{:1.2f}'.format(S.Zero) == '0.00'
assert '{:+3.0f}'.format(S(3)) == ' +3'
assert '{:23.20f}'.format(pi) == ' 3.14159265358979323846'
assert '{:50.48f}'.format(exp(sin(1))) == '2.319776824715853173956590377503266813254904772376'
Back to Directory
File Manager