我们本章会讨论:

  • Python 如何处理终追运算符中不同类型的操作数
  • 使用鸭子类型或显式类型检查处理不同类型的操作数
  • 中缀运算符如何表明自己无法处理的操作数
  • 众多比较运算符(如 ==,>,<= 等等)的特殊行为
  • 增量赋值运算符(如 += )的默认处理方式和重载方式

运算符重载基础

在某些圈子里,运算符重载名声不太好,因为总被滥用.Python 加了一些限制,做好了灵活性,可用性和安全性的平衡

  • 不能重载内置运算符
  • 不能新建运算符,只能重载现有的
  • 某些运算符不能重载 -- is,and,or 和 not(不过位运算 &,| 和 ~ 可以)

一元运算符

- (__neg__) 一元取负运算符,如果 x 是 -2, -x == 2

+ (__pos__) 一元取正运算符,通常 x == +x,但也有一些例外

~ (__invert__) 对整数按位去饭,定义 ~x == -(x + 1),如果 x 是 2, ~x == -3

支持一元操作符只需要实现相应的特殊方法,这些方法只有一个 self 参数,然后使用符合所在类的逻辑实现。不过,要遵守运算符的一个基本规则:始终返回一个新对象。也就是不能修改 self

对于 -+ 来说,结果可能是与 self 属于同一类的实例,多数的时候, + 最好返回 self 的副本。abs(...) 的结果应该是一个标量,但是对于 ~ 来说,很难说明什么结果是合理的,因为可能处理的不是整数,例如 ORM 中,SQL WHERE 子句应该返回反集

我们将把 -+ 运算符添加到第 10 章的例子中:


In [2]:
from array import array
import reprlib
import math
import numbers
import functools
import operator
import itertools

class Vector:
    typecode = 'd'
    
    def __init__(self, components):
        self._components = array(self.typecode, components) 
    
    def __iter__(self):
        return iter(self._components)
    
    def __repr__(self):
        components = reprlib.repr(self._components) 
        components = components[components.find('['):-1] 
        return 'Vector({})'.format(components)
    
    def __str__(self):
        return str(tuple(self))
    
    def __bytes__(self):
        return (bytes([ord(self.typecode)]) + 
                bytes(self._components))     
    
    def __eq__(self, other):
        return len(self) == len(other) and all(a == b for a, b in zip(self, other))
    
    def __abs__(self):
        return math.sqrt(sum(x * x for x in self)) 
    
    def __bool__(self):
        return bool(abs(self))
    
    @classmethod
    def frombytes(cls, octets):
        typecode = chr(octets[0])
        memv = memoryview(octets[1:]).cast(typecode)
        return cls(memv) 
    
    # 上面都一样
    def __len__(self):
        return len(self._components)

    def __getitem__(self, index):
        cls = type(self)                # 获取实例所属的类
        if isinstance(index, slice):
            return cls(self._components[index])
        elif isinstance(index, numbers.Integral): # index 是 int 或其他整数类型
            return self._components[index]
        else:
            msg = '{cls.__name__} indices must be integers'
            raise TypeError(msg.format(cls=cls))       

    shortcut_names = 'xyzt'

    def __getattr__(self, name):
        cls = type(self)

        if len(name) == 1:
            pos = cls.shortcut_names.find(name)
            if 0 <= pos < len(self._components):
                return self._components[pos]
        msg = '{.__name__!r} object has no attribute {!r}'
        raise AttributeError(msg.format(cls, name))
        
    def __setattr__(self, name, value):
        cls = type(self)
        if len(name) == 1:
            if name in cls.shortcut_names:
                error = 'readonly attribute {attr_name!}'
            elif name.islower():
                error = "can't set attributes 'a' to 'z' in {cls_name!r}"
        else:
            error = ''
        if error:
            msg = error.format(cls_name = cls.__name__, attr_name = name) # 这个方法好,无论错误是哪个,都可以给定值
            raise AttributeError(msg)
        super().__setattr__(name, value) # 默认情况,调用超类的 __setattr__ 方法,提供标准行为
        
    def __hash__(self):
        hashs = (hash(x) for x in self._components) # 注意这里是生成器表达式,不是列表推导式,可以节省内存
        return functools.reduce(operator.xor, hashs)

    def angle(self, n):
        r = math.sqrt(sum(x * x for x in self[n:]))
        a = math.atan2(r, self[n-1])
        if (n == len(self) - 1) and (self[-1] < 0):
            return math.pi * 2 - a
        else:
            return a
    
    def angles(self):
        return (self.angle(n) for n in range(1, len(self)))
        
    def __format__(self, fmt_spec=''):
        if fmt_spec.endswith('h'):
            fmt_spec = fmt_spec[:-1]
            coords = itertools.chain([abs(self)], # 使用 chain 函数生成生成器表达式,无缝迭代向量的模和各个角坐标
                                     self.angles())
            outer_fmt = '<{}>' # 球面坐标
            
        else:
            coords = self
            outer_fmt = '({})' # 笛卡尔坐标
        components = (format(c, fmt_spec) for c in coords)
        return outer_fmt.format(', '.join(components))
    
    
    def __neg__(self):
        return Vector(-x for x in self)
    
    def __pos__(self):
        return Vector(self)

因为 Vector 实例是可迭代对象,而且 Vector.__init__ 的参数是可迭代对象,所以我们的 __neg___pos__ 的实现短小精悍

我们不打算实现 __invert__ 方法,因此如果用户想计算 ~v,Python 会抛出 TypeError。

x 和 +x 何时不相等

在 Python 中几乎所有情况下 x == +x,但是在标准库中找到两例 x != +x 的情况

decimal.Decimal 类,如果 x 是 Decimal 类的实例,在 算数运算的上下文中创建,然后在不同的上下文中计算 +x,那么 x != +x。例如,x 所在的上下文用某个精度,而计算 +x 时,精度变了,如下所示:


In [3]:
import decimal
ctx = decimal.getcontext()
ctx.prec = 40                # 精度设为 40
one_third = decimal.Decimal('1') / decimal.Decimal('3')
one_third


Out[3]:
Decimal('0.3333333333333333333333333333333333333333')

In [4]:
one_third == +one_third


Out[4]:
True

In [5]:
ctx.prec = 28  #精度设为 28
one_third == +one_third


Out[5]:
False

In [6]:
+one_third


Out[6]:
Decimal('0.3333333333333333333333333333')

虽然每个 +one_third 表达式都会使用 one_third 的值创建一个新的 Decimal 实例,但是会使用当前算数运算符上下文的精度

第二例在 collections.Counter 的文档中。类实现了几个算数运算符,例如中缀运算符 +,作用是把两个 Counter 实例的计数器加在一起,然而,从使用角度出发,Counter 相加时,负值和零值计数会从结果中剔除,而一元运算符 + 等同于加上一个空的 Counter,因此产生一个新的 Counter 且仅保留大于 0 的计数器。


In [7]:
from collections import Counter

ct = Counter('abracadabra')
ct


Out[7]:
Counter({'a': 5, 'b': 2, 'c': 1, 'd': 1, 'r': 2})

In [8]:
ct['r'] = -3
ct['d'] = 0
ct


Out[8]:
Counter({'a': 5, 'b': 2, 'c': 1, 'd': 0, 'r': -3})

In [9]:
+ct


Out[9]:
Counter({'a': 5, 'b': 2, 'c': 1})

重载向量加法运算符 +

我们要为向量实现不定长的向量加法


In [12]:
from array import array
import reprlib
import math
import numbers
import functools
import operator
import itertools

class Vector:
    typecode = 'd'
    
    def __init__(self, components):
        self._components = array(self.typecode, components) 
    
    def __iter__(self):
        return iter(self._components)
    
    def __repr__(self):
        components = reprlib.repr(self._components) 
        components = components[components.find('['):-1] 
        return 'Vector({})'.format(components)
    
    def __str__(self):
        return str(tuple(self))
    
    def __bytes__(self):
        return (bytes([ord(self.typecode)]) + 
                bytes(self._components))     
    
    def __eq__(self, other):
        return len(self) == len(other) and all(a == b for a, b in zip(self, other))
    
    def __abs__(self):
        return math.sqrt(sum(x * x for x in self)) 
    
    def __bool__(self):
        return bool(abs(self))
    
    @classmethod
    def frombytes(cls, octets):
        typecode = chr(octets[0])
        memv = memoryview(octets[1:]).cast(typecode)
        return cls(memv) 
    
    # 上面都一样
    def __len__(self):
        return len(self._components)

    def __getitem__(self, index):
        cls = type(self)                # 获取实例所属的类
        if isinstance(index, slice):
            return cls(self._components[index])
        elif isinstance(index, numbers.Integral): # index 是 int 或其他整数类型
            return self._components[index]
        else:
            msg = '{cls.__name__} indices must be integers'
            raise TypeError(msg.format(cls=cls))       

    shortcut_names = 'xyzt'

    def __getattr__(self, name):
        cls = type(self)

        if len(name) == 1:
            pos = cls.shortcut_names.find(name)
            if 0 <= pos < len(self._components):
                return self._components[pos]
        msg = '{.__name__!r} object has no attribute {!r}'
        raise AttributeError(msg.format(cls, name))
        
    def __setattr__(self, name, value):
        cls = type(self)
        if len(name) == 1:
            if name in cls.shortcut_names:
                error = 'readonly attribute {attr_name!}'
            elif name.islower():
                error = "can't set attributes 'a' to 'z' in {cls_name!r}"
        else:
            error = ''
        if error:
            msg = error.format(cls_name = cls.__name__, attr_name = name) # 这个方法好,无论错误是哪个,都可以给定值
            raise AttributeError(msg)
        super().__setattr__(name, value) # 默认情况,调用超类的 __setattr__ 方法,提供标准行为
        
    def __hash__(self):
        hashs = (hash(x) for x in self._components) # 注意这里是生成器表达式,不是列表推导式,可以节省内存
        return functools.reduce(operator.xor, hashs)

    def angle(self, n):
        r = math.sqrt(sum(x * x for x in self[n:]))
        a = math.atan2(r, self[n-1])
        if (n == len(self) - 1) and (self[-1] < 0):
            return math.pi * 2 - a
        else:
            return a
    
    def angles(self):
        return (self.angle(n) for n in range(1, len(self)))
        
    def __format__(self, fmt_spec=''):
        if fmt_spec.endswith('h'):
            fmt_spec = fmt_spec[:-1]
            coords = itertools.chain([abs(self)], # 使用 chain 函数生成生成器表达式,无缝迭代向量的模和各个角坐标
                                     self.angles())
            outer_fmt = '<{}>' # 球面坐标
            
        else:
            coords = self
            outer_fmt = '({})' # 笛卡尔坐标
        components = (format(c, fmt_spec) for c in coords)
        return outer_fmt.format(', '.join(components))
    
    
    def __neg__(self):
        return Vector(-x for x in self)
    
    def __pos__(self):
        return Vector(self)
    
    
    def __add__(self, other):
        pairs = itertools.zip_longest(self, other, fillvalue=0.0)
        return Vector(a + b for a, b in pairs)

pairs 是个生成器,会生成 (a, b) 形式的元组,其中 a 来自 self, b 来自 other,如果 a 和 b 的长度不同,使用 fillvalue 填充较短的可迭代对象

实现一元运算符和中缀运算符时候不要修改操作数,只有增量赋值表达式可能会修改第一个操作数。

现在我们的加法也支持 Vector 之外的对象


In [14]:
v1 = Vector([3, 4, 5])
v1 + [10, 20, 30]


Out[14]:
Vector([13.0, 24.0, 35.0])

In [16]:
v2 = Vector([1, 2])
v1 + v2


Out[16]:
Vector([4.0, 6.0, 5.0])

zip_longest(...) 能处理任何可迭代对象,而且构建新 Vector 实例的生成器表达式仅仅是把 zip_longest(...) 生成的值对相加(a + b),因此可以使用任何生成数字元素的可迭代对象

然而,如果对调操作数,混合类型的加法就会失败:


In [17]:
v1 = Vector([3, 4, 5])
(10, 20, 30) + v1


---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-17-2a88335a76c7> in <module>()
      1 v1 = Vector([3, 4, 5])
----> 2 (10, 20, 30) + v1

TypeError: can only concatenate tuple (not "Vector") to tuple

为了支持涉及不同类型的运算,Python 为中缀运算符特殊方法提供了特殊的分派机制,对于表达式 a + b 来说,会执行下面操作:

  1. 如果 a 有 __add__ 方法,而且返回值不是 NotImplemented,调用 a.__add__(b) 方法,返回结果

  2. 如果 a 没有 __add__ 方法, 或返回值是 NotImplemented,调用 b.__radd__(a) 返回结果

  3. 如果 b 没有 __radd__ 方法,或者返回为 NotImplemented,抛出 TypeError。( r 的含义是 reflected 或 reverse)

所以我们为了让混合类型加法可以正确运算,要实现 Vector.__radd__ 方法,这是一种后备机制

别把 NotImplemented 和 NotImplementedError 搞混了,前者是特殊的单例值,如果中缀运算符特殊方法不能处理给定的操作数,要把它反回给解释器,而后者是一种一场,抽象类中占位方法将它抛出(raise),提醒子类必须覆盖

最简单可用的 __radd__ 实现如下:


In [18]:
def __radd__(self, other):
    return self + other

前面的 Vector 对象的加法对象如果不可迭代,__add__ 就无法处理,而且提供的错误消息不是很有用


In [19]:
v1 + 1


---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-19-76c8c91eaf2e> in <module>()
----> 1 v1 + 1

<ipython-input-12-820ea4db6ef9> in __add__(self, other)
    120 
    121     def __add__(self, other):
--> 122         pairs = itertools.zip_longest(self, other, fillvalue=0.0)
    123         return Vector(a + b for a, b in pairs)

TypeError: zip_longest argument #2 must support iteration

In [20]:
v1 + 'ABC'


---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-20-30e5276cebc2> in <module>()
----> 1 v1 + 'ABC'

<ipython-input-12-820ea4db6ef9> in __add__(self, other)
    121     def __add__(self, other):
    122         pairs = itertools.zip_longest(self, other, fillvalue=0.0)
--> 123         return Vector(a + b for a, b in pairs)

<ipython-input-12-820ea4db6ef9> in __init__(self, components)
     11 
     12     def __init__(self, components):
---> 13         self._components = array(self.typecode, components)
     14 
     15     def __iter__(self):

<ipython-input-12-820ea4db6ef9> in <genexpr>(.0)
    121     def __add__(self, other):
    122         pairs = itertools.zip_longest(self, other, fillvalue=0.0)
--> 123         return Vector(a + b for a, b in pairs)

TypeError: unsupported operand type(s) for +: 'float' and 'str'

上面揭露的问题比晦涩难懂的错误消息更严重,如果由于类型不兼容导致特殊方法无法返回有效结果,应该返回 NoteImplemented,而不是抛出 TypeError,返回 NotImplemented 时,另一个操作数所属类型还有机会执行运算,Python 会尝试调用反向方法

为了遵守鸭子类型精神,我们不能测试 other 操作数类型,或者它的元素类型,我们要捕获异常,返回 NotImplemented。如果解释器还没有反转操作数,那么它将尝试去做,如果反向方法返回 NoteImplemented,那么 Python 会抛出 TypeError,并返回一个标准的错误消息,下面是 Vector 的加法特殊方法最终版:


In [21]:
from array import array
import reprlib
import math
import numbers
import functools
import operator
import itertools

class Vector:
    typecode = 'd'
    
    def __init__(self, components):
        self._components = array(self.typecode, components) 
    
    def __iter__(self):
        return iter(self._components)
    
    def __repr__(self):
        components = reprlib.repr(self._components) 
        components = components[components.find('['):-1] 
        return 'Vector({})'.format(components)
    
    def __str__(self):
        return str(tuple(self))
    
    def __bytes__(self):
        return (bytes([ord(self.typecode)]) + 
                bytes(self._components))     
    
    def __eq__(self, other):
        return len(self) == len(other) and all(a == b for a, b in zip(self, other))
    
    def __abs__(self):
        return math.sqrt(sum(x * x for x in self)) 
    
    def __bool__(self):
        return bool(abs(self))
    
    @classmethod
    def frombytes(cls, octets):
        typecode = chr(octets[0])
        memv = memoryview(octets[1:]).cast(typecode)
        return cls(memv) 
    
    # 上面都一样
    def __len__(self):
        return len(self._components)

    def __getitem__(self, index):
        cls = type(self)                # 获取实例所属的类
        if isinstance(index, slice):
            return cls(self._components[index])
        elif isinstance(index, numbers.Integral): # index 是 int 或其他整数类型
            return self._components[index]
        else:
            msg = '{cls.__name__} indices must be integers'
            raise TypeError(msg.format(cls=cls))       

    shortcut_names = 'xyzt'

    def __getattr__(self, name):
        cls = type(self)

        if len(name) == 1:
            pos = cls.shortcut_names.find(name)
            if 0 <= pos < len(self._components):
                return self._components[pos]
        msg = '{.__name__!r} object has no attribute {!r}'
        raise AttributeError(msg.format(cls, name))
        
    def __setattr__(self, name, value):
        cls = type(self)
        if len(name) == 1:
            if name in cls.shortcut_names:
                error = 'readonly attribute {attr_name!}'
            elif name.islower():
                error = "can't set attributes 'a' to 'z' in {cls_name!r}"
        else:
            error = ''
        if error:
            msg = error.format(cls_name = cls.__name__, attr_name = name) # 这个方法好,无论错误是哪个,都可以给定值
            raise AttributeError(msg)
        super().__setattr__(name, value) # 默认情况,调用超类的 __setattr__ 方法,提供标准行为
        
    def __hash__(self):
        hashs = (hash(x) for x in self._components) # 注意这里是生成器表达式,不是列表推导式,可以节省内存
        return functools.reduce(operator.xor, hashs)

    def angle(self, n):
        r = math.sqrt(sum(x * x for x in self[n:]))
        a = math.atan2(r, self[n-1])
        if (n == len(self) - 1) and (self[-1] < 0):
            return math.pi * 2 - a
        else:
            return a
    
    def angles(self):
        return (self.angle(n) for n in range(1, len(self)))
        
    def __format__(self, fmt_spec=''):
        if fmt_spec.endswith('h'):
            fmt_spec = fmt_spec[:-1]
            coords = itertools.chain([abs(self)], # 使用 chain 函数生成生成器表达式,无缝迭代向量的模和各个角坐标
                                     self.angles())
            outer_fmt = '<{}>' # 球面坐标
            
        else:
            coords = self
            outer_fmt = '({})' # 笛卡尔坐标
        components = (format(c, fmt_spec) for c in coords)
        return outer_fmt.format(', '.join(components))
    
    
    def __neg__(self):
        return Vector(-x for x in self)
    
    def __pos__(self):
        return Vector(self)
    
    
    def __add__(self, other):
        try:
            pairs = itertools.zip_longest(self, other, fillvalue=0.0)
            return Vector(a + b for a, b in pairs)
        except TypeError:
            return NotImplemented
    
    def __radd__(self, other):
        return self + other

重载标量乘法运算符

Vector([1, 2, 3]) x 是啥意思?如果 x 是数字,就是标量乘积,结果是一个新的 Vector 实例,还有一种是两个向量的点积。也就是 1 `N 和 N*` 1 的两个矩阵的乘法。

Numpy 库目前做法是,不重载这两种意义的 *,只用 * 计算标量乘积,例如 Numpy 中,点积使用 numpy.dot() 函数计算

从 Python 3.5 开始,@ 符号可以用作中缀点积运算符

标量积中,我们依然先实现最简单可用的 __mul____rmul__ 方法:


In [1]:
def __mul__(self, scalar):
    return Vector(n * scalar for n in self)

def __rmul__(self, scalar):
    return self * scalar

这两个方法确实可用,但是提供不兼容操作数时候会出问题。scalar 参数的值要是个数字,与浮点数相乘得到的的积是另一个浮点数(因为 Vector 类内部使用的是浮点数数组)。因此,不能使用复数,但是可以是 int,bool(int 的子类),甚至是 fractions.Fraction 实例等标量。

我们可以像上面那样采用鸭子类型技术,抛出 TypeError。但是这个问题有个更易于理解的方式,而且也更合理:白鹅类型。我们将使用 isinstance() 检查 scalar 类型,但是不硬编码具体的类型,而是检查 numbers.Real 抽象基类。这个抽象基类包含了我们所需要的全部类型,而且还支持以后声明为 numbers.Real 抽象基类的真实子类或虚拟子类的数值类型。下面展示了白鹅类型的实际应用 -- 显式检查抽象类型

我们在第 16 章说过,decimal.Decimal 没有把自己注册为 numbers.Real 的虚拟子类,因此,Vector 不会处理 decimal.Decimal 数字

增加 * 运算符方法:


In [2]:
from array import array
import reprlib
import math
import numbers
import functools
import operator
import itertools

class Vector:
    typecode = 'd'
    
    def __init__(self, components):
        self._components = array(self.typecode, components) 
    
    def __iter__(self):
        return iter(self._components)
    
    def __repr__(self):
        components = reprlib.repr(self._components) 
        components = components[components.find('['):-1] 
        return 'Vector({})'.format(components)
    
    def __str__(self):
        return str(tuple(self))
    
    def __bytes__(self):
        return (bytes([ord(self.typecode)]) + 
                bytes(self._components))     
    
    def __eq__(self, other):
        return len(self) == len(other) and all(a == b for a, b in zip(self, other))
    
    def __abs__(self):
        return math.sqrt(sum(x * x for x in self)) 
    
    def __bool__(self):
        return bool(abs(self))
    
    @classmethod
    def frombytes(cls, octets):
        typecode = chr(octets[0])
        memv = memoryview(octets[1:]).cast(typecode)
        return cls(memv) 
    
    # 上面都一样
    def __len__(self):
        return len(self._components)

    def __getitem__(self, index):
        cls = type(self)                # 获取实例所属的类
        if isinstance(index, slice):
            return cls(self._components[index])
        elif isinstance(index, numbers.Integral): # index 是 int 或其他整数类型
            return self._components[index]
        else:
            msg = '{cls.__name__} indices must be integers'
            raise TypeError(msg.format(cls=cls))       

    shortcut_names = 'xyzt'

    def __getattr__(self, name):
        cls = type(self)

        if len(name) == 1:
            pos = cls.shortcut_names.find(name)
            if 0 <= pos < len(self._components):
                return self._components[pos]
        msg = '{.__name__!r} object has no attribute {!r}'
        raise AttributeError(msg.format(cls, name))
        
    def __setattr__(self, name, value):
        cls = type(self)
        if len(name) == 1:
            if name in cls.shortcut_names:
                error = 'readonly attribute {attr_name!}'
            elif name.islower():
                error = "can't set attributes 'a' to 'z' in {cls_name!r}"
        else:
            error = ''
        if error:
            msg = error.format(cls_name = cls.__name__, attr_name = name) # 这个方法好,无论错误是哪个,都可以给定值
            raise AttributeError(msg)
        super().__setattr__(name, value) # 默认情况,调用超类的 __setattr__ 方法,提供标准行为
        
    def __hash__(self):
        hashs = (hash(x) for x in self._components) # 注意这里是生成器表达式,不是列表推导式,可以节省内存
        return functools.reduce(operator.xor, hashs)

    def angle(self, n):
        r = math.sqrt(sum(x * x for x in self[n:]))
        a = math.atan2(r, self[n-1])
        if (n == len(self) - 1) and (self[-1] < 0):
            return math.pi * 2 - a
        else:
            return a
    
    def angles(self):
        return (self.angle(n) for n in range(1, len(self)))
        
    def __format__(self, fmt_spec=''):
        if fmt_spec.endswith('h'):
            fmt_spec = fmt_spec[:-1]
            coords = itertools.chain([abs(self)], # 使用 chain 函数生成生成器表达式,无缝迭代向量的模和各个角坐标
                                     self.angles())
            outer_fmt = '<{}>' # 球面坐标
            
        else:
            coords = self
            outer_fmt = '({})' # 笛卡尔坐标
        components = (format(c, fmt_spec) for c in coords)
        return outer_fmt.format(', '.join(components))
    
    
    def __neg__(self):
        return Vector(-x for x in self)
    
    def __pos__(self):
        return Vector(self)
    
    
    def __add__(self, other):
        try:
            pairs = itertools.zip_longest(self, other, fillvalue=0.0)
            return Vector(a + b for a, b in pairs)
        except TypeError:
            return NotImplemented
    
    def __radd__(self, other):
        return self + other
    
    def __mul__(self, scalar):
        if isinstance(scalar, numbers.Real):
            return Vector(n * scalar for n in self)
        else:
            return NotImplemented
        
    def __rmul__(self, scalar):
        return self * scalar

In [3]:
v1 = Vector([1.0, 2.0, 3.0])
14 * v1


Out[3]:
Vector([14.0, 28.0, 42.0])

In [4]:
v1 * True


Out[4]:
Vector([1.0, 2.0, 3.0])

In [5]:
from fractions import Fraction
v1 * Fraction(1, 3)


Out[5]:
Vector([0.3333333333333333, 0.6666666666666666, 1.0])

众多比较运算符

Python 解释器对众多比较符(==,!=,>,<,>=,<=)的处理与前文类似,不过在两个方面有重大区别

  • 正向和反向调用使用的是同一系列方法,对于 == 来说,正向和反向调用都是 __eq__ 方法,只是把参数对调了,而正向的 __gt__ 方法调用的是反向的 __lt__ 方法,并把参数对掉

  • 对于 == 和 != 来说,如果反向调用失败,Python 会比较对象的 ID,而不抛出 TypeError

Python 2 之后比较运算符后备机制都变了,对于 __ne__,现在 Python 3 返回结果是对 __eq__ 结果取反,对于排序比较运算符,Python 3 抛出 TypeError,并把错误消息设为 “unorderable types: int() < tuple()‘。在 Python 2 中,这些比较的结果很怪异,会考虑对象类型和 ID,并且无规律可循。然而,比较整数和元组确实没有意义,因此此时抛出 TypeError 是这门语言的一大进步

了解这些规则之后,我们来分析并改进 Vector.__eq__ 方法的行为


In [7]:
va = Vector([1.0, 2.0, 3.0])
vb = Vector(range(1, 4))
va == vb


Out[7]:
True

In [8]:
t3 = (1, 2, 3)
va == t3


Out[8]:
True

Vector 和 元组比较的结果可能不太理想,作者的观点是结果应该由应用上下文决定。不过,”Python 之禅“作者说: 如果存在多重可能,不要猜测

Python 中 [1, 2] == (1, 2) 结果是 False,所以我们也要在 __eq__ 中做类型检查:


In [25]:
from array import array
import reprlib
import math
import numbers
import functools
import operator
import itertools

class Vector:
    typecode = 'd'
    
    def __init__(self, components):
        self._components = array(self.typecode, components) 
    
    def __iter__(self):
        return iter(self._components)
    
    def __repr__(self):
        components = reprlib.repr(self._components) 
        components = components[components.find('['):-1] 
        return 'Vector({})'.format(components)
    
    def __str__(self):
        return str(tuple(self))
    
    def __bytes__(self):
        return (bytes([ord(self.typecode)]) + 
                bytes(self._components))     
    
    def __abs__(self):
        return math.sqrt(sum(x * x for x in self)) 
    
    def __bool__(self):
        return bool(abs(self))
    
    @classmethod
    def frombytes(cls, octets):
        typecode = chr(octets[0])
        memv = memoryview(octets[1:]).cast(typecode)
        return cls(memv) 
    
    # 上面都一样
    def __len__(self):
        return len(self._components)

    def __getitem__(self, index):
        cls = type(self)                # 获取实例所属的类
        if isinstance(index, slice):
            return cls(self._components[index])
        elif isinstance(index, numbers.Integral): # index 是 int 或其他整数类型
            return self._components[index]
        else:
            msg = '{cls.__name__} indices must be integers'
            raise TypeError(msg.format(cls=cls))       

    shortcut_names = 'xyzt'

    def __getattr__(self, name):
        cls = type(self)

        if len(name) == 1:
            pos = cls.shortcut_names.find(name)
            if 0 <= pos < len(self._components):
                return self._components[pos]
        msg = '{.__name__!r} object has no attribute {!r}'
        raise AttributeError(msg.format(cls, name))
        
    def __setattr__(self, name, value):
        cls = type(self)
        if len(name) == 1:
            if name in cls.shortcut_names:
                error = 'readonly attribute {attr_name!}'
            elif name.islower():
                error = "can't set attributes 'a' to 'z' in {cls_name!r}"
        else:
            error = ''
        if error:
            msg = error.format(cls_name = cls.__name__, attr_name = name) # 这个方法好,无论错误是哪个,都可以给定值
            raise AttributeError(msg)
        super().__setattr__(name, value) # 默认情况,调用超类的 __setattr__ 方法,提供标准行为
        
    def __hash__(self):
        hashs = (hash(x) for x in self._components) # 注意这里是生成器表达式,不是列表推导式,可以节省内存
        return functools.reduce(operator.xor, hashs)

    def angle(self, n):
        r = math.sqrt(sum(x * x for x in self[n:]))
        a = math.atan2(r, self[n-1])
        if (n == len(self) - 1) and (self[-1] < 0):
            return math.pi * 2 - a
        else:
            return a
    
    def angles(self):
        return (self.angle(n) for n in range(1, len(self)))
        
    def __format__(self, fmt_spec=''):
        if fmt_spec.endswith('h'):
            fmt_spec = fmt_spec[:-1]
            coords = itertools.chain([abs(self)], # 使用 chain 函数生成生成器表达式,无缝迭代向量的模和各个角坐标
                                     self.angles())
            outer_fmt = '<{}>' # 球面坐标
            
        else:
            coords = self
            outer_fmt = '({})' # 笛卡尔坐标
        components = (format(c, fmt_spec) for c in coords)
        return outer_fmt.format(', '.join(components))
    
    
    def __neg__(self):
        return Vector(-x for x in self)
    
    def __pos__(self):
        return Vector(self)
    
    
    def __add__(self, other):
        try:
            pairs = itertools.zip_longest(self, other, fillvalue=0.0)
            return Vector(a + b for a, b in pairs)
        except TypeError:
            return NotImplemented
    
    def __radd__(self, other):
        return self + other
    
    def __mul__(self, scalar):
        if isinstance(scalar, numbers.Real):
            return Vector(n * scalar for n in self)
        else:
            return NotImplemented
        
    def __rmul__(self, scalar):
        return self * scalar
    
    def __eq__(self, other):
        if isinstance(other, Vector):
            print('.......')
            return (len(self) == len(other) and all(a == b for a, b in zip(self, other)))
        else:
            return NotImplemented

In [26]:
t3 = (1, 2, 3)
va = Vector([1.0, 2.0, 3.0])
va == t3


Out[26]:
False

上面首先调用 Vector.__eq__(va, t3)

经过上面函数确认,t3 不是 Vector 实例,因此返回 NotImplemented

Python 得到 NotImplemented,尝试调用 tuple.__eq__(t3, va)

tuple 不知道 Vector 是什么,因此返回 NotImplemented

对于 == 来说,如果反向调用返回 NoteImplemented,Python 会比较对象 ID,进行最后一搏。

对于 != 来说,我们不应实现它,因为从 object 继承的 __ne__ 方法的后备行为满足了我们的需求,定义了 __eq__ 方法,而且它不返回 NotImplemented,__ne__ 会对 __eq__ 返回结果取反

也就是说,使用 != 运算符比较的结果是一致的:


In [27]:
va != t3


Out[27]:
True

__ne__ 运作方式与下面类似:


In [28]:
def __ne__(self, other):
    eq_result = self == other
    if eq_result is NotImplemented:
        return NotImplemented
    else:
        return not eq_result

可以看到,Python 3 中 __ne__ 对我们来说够用了,一般不用重载。

增量赋值运算符

现在我们的类已经支持增量运算符 += 和 *= 了:


In [36]:
v1 = Vector([1, 2, 3])
v1_alias = v1
id(v1)


Out[36]:
140383358524216

In [37]:
v1 += Vector([4, 5, 6])
v1


Out[37]:
Vector([5.0, 7.0, 9.0])

In [38]:
id(v1)


Out[38]:
140383358525168

In [39]:
v1_alias


Out[39]:
Vector([1.0, 2.0, 3.0])

In [40]:
v1 *= 11
v1


Out[40]:
Vector([55.0, 77.0, 99.0])

In [41]:
id(v1)


Out[41]:
140383358523376

这里的增量运算符只是语法糖,a += b 的行为和 a = a + b 一样,对于不可便类型来说,这是预期行为,而且,如果定义了 __add__ 方法的话,不用写额外的代码 += 就能使用

然而,如果实现了就地运算符方法,例如 __iadd__,计算 a + b 的结果时会调用就地运算符方法,这种运算符名称表明,他们会就地修改左操作数,不会创建新对象

不可变类型,例如 Vector 类,一定不能实现就地特殊方法,这是明显的事实,不过还是值得提出来

为了展示如何实现就地运算符,我们将扩展 11 章的 BingoCage 类,实现 __add____iadd__ 方法


In [60]:
import abc

class Tombola(abc.ABC):
    @abc.abstractmethod
    def load(self, iterable):
        '''从可迭代对象中添加元素'''
        
    @abc.abstractmethod # 抽象方法使用此标记
    def pick(self):
        '''随机删除元素,然后将其返回
           如果实例为空,这个方法抛出 LookupError
        '''
        
    def loaded(self):
        '''如果至少有一个元素,返回 True,否则返回 False'''
        return bool(self.inspect()) # 抽象基类中的具体方法只能依赖抽象基类定义的接口(即只能使用抽象基类的其他具体方法,抽象方法或特性)
    
    def inspect(self):
        '''返回一个有序元组,由当前元素构成'''
        items = []
        while 1:  # 我们不知道具体子类如何存储元素,为了得到 inspect 结果,不断调用 pick 方法,把 Tombola 清空
            try:
                items.append(self.pick())
            except LookupError:
                break
        self.load(items)  # 再加回去元素
        return tuple(sorted(items))
    
    
import random

class BingoCage(Tombola):
    
    def __init__(self, items):
        self._randomizer = random.SystemRandom()
        self._items = []
        self.load(items)
        
    def load(self, items):
        self._items.extend(items)
        self._randomizer.shuffle(self._items)
        
    def pick(self):
        try:
            return self._items.pop()
        except IndexError:
            raise LookupError('pick from empty BingoCage')
            
    def __call__(self):
        self.pick()


# ==== add

class AddableBingoCage(BingoCage):
    def __add__(self, other): # __add__ 方法的第二个操作数只能是 Tombola 实例
        if isinstance(other, Tombola): # other 是 Tombola 实例,获取元素
            return AddableBingoCage(self.inspect() + other.inspect())
        else:
            return NotImplemented
        
    def __iadd__(self, other):
        if isinstance(other, Tombola): 
            other_iterable = other.inspect()
        else:
            try:
                other_iterable = iter(other) # 否则创建迭代器
            except TypeError:
                self_cls = type(self).__name__
                msg = 'right operand in += must be {!r} or an iterable'
                raise TypeError(msg.format(self_cls))
        self.load(other_iterable)
        return self  # 非常重要,增量赋值特殊方法必须返回 self

最后,还有一点要注意,从设计上来看,AddableBingoCage 不用定义 __radd__ 方法,因为不需要。如果右操作数是相同类型,那么正向方法 __add__ 会处理,因此 Python 计算 a + b 时,如果 a 是 AddableBingoCage 实例,而 b 不是,那么返回 NotImplemented,那么 Python 最好放这i,抛出 TypeError,因为无法处理 b

一般来说,如果终追运算符的正向方法(如 __mul__)之处理与 self 同一类型的操作数,那么就无需实现反向方法,因为按照定义,反向方法是为了处理不同类型的操作数

最后我们看看效果:


In [61]:
vowels = 'AEIOU'
globe = AddableBingoCage(vowels)
globe.inspect()


Out[61]:
('A', 'E', 'I', 'O', 'U')

In [62]:
globe.pick() in vowels


Out[62]:
True

In [63]:
len(globe.inspect())


Out[63]:
4

In [64]:
globe2 = AddableBingoCage('XYZ')
globe3 = globe + globe2
len(globe3.inspect())


Out[64]:
7

In [65]:
void = globe + [10, 20]


---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-65-af411a756e83> in <module>()
----> 1 void = globe + [10, 20]

TypeError: unsupported operand type(s) for +: 'AddableBingoCage' and 'list'

In [66]:
globe_orig = globe
len(globe.inspect())


Out[66]:
4

In [67]:
globe += globe2
len(globe.inspect())


Out[67]:
7

In [68]:
globe += ['M', 'N']
len(globe.inspect())


Out[68]:
9

In [69]:
globe is globe_orig


Out[69]:
True

In [70]:
globe += 1


---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-60-842c09972972> in __iadd__(self, other)
     66             try:
---> 67                 other_iterable = iter(other)
     68             except TypeError:

TypeError: 'int' object is not iterable

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
<ipython-input-70-f0e1c411856a> in <module>()
----> 1 globe += 1

<ipython-input-60-842c09972972> in __iadd__(self, other)
     69                 self_cls = type(self).__name__
     70                 msg = 'right operand in += must be {!r} or an iterable'
---> 71                 raise TypeError(msg.format(self_cls))
     72         self.load(other_iterable)
     73         return self

TypeError: right operand in += must be 'AddableBingoCage' or an iterable

In [ ]: