Python正确重载运算符的方法示例详解


Posted in Python onAugust 27, 2017

前言

说到运算符重载相信大家都不陌生,运算符重载的作用是让用户定义的对象使用中缀运算符(如 + 和 |)或一元运算符(如 - 和 ~)。说得宽泛一些,在 Python 中,函数调用(())、属性访问(.)和元素访问 / 切片([])也是运算符。

我们为 Vector 类简略实现了几个运算符。__add__ 和 __mul__ 方法是为了展示如何使用特殊方法重载运算符,不过有些小问题被我们忽视了。此外,我们定义的Vector2d.__eq__ 方法认为 Vector(3, 4) == [3, 4] 是真的(True),这可能并不合理。下面来一起看看详细的介绍吧。

运算符重载基础

在某些圈子中,运算符重载的名声并不好。这个语言特性可能(已经)被滥用,让程序员困惑,导致缺陷和意料之外的性能瓶颈。但是,如果使用得当,API 会变得好用,代码会变得易于阅读。Python 施加了一些限制,做好了灵活性、可用性和安全性方面的平衡:

  • 不能重载内置类型的运算符
  • 不能新建运算符,只能重载现有的
  • 某些运算符不能重载——is、and、or 和 not(不过位运算符
  • &、| 和 ~ 可以)

前面的博文已经为 Vector 定义了一个中缀运算符,即 ==,这个运算符由__eq__ 方法支持。我们将改进 __eq__ 方法的实现,更好地处理不是Vector 实例的操作数。然而,在运算符重载方面,众多比较运算符(==、!=、>、<、>=、<=)是特例,因此我们首先将在 Vector 中重载四个算术运算符:一元运算符 - 和 +,以及中缀运算符 + 和 *。

一元运算符

-(__neg__)

一元取负算术运算符。如果 x 是 -2,那么 -x == 2。

+(__pos__)

一元取正算术运算符。通常,x == +x,但也有一些例外。如果好奇,请阅读“x 和 +x 何时不相等”附注栏。

~(__invert__)

对整数按位取反,定义为 ~x == -(x+1)。如果 x 是 2,那么 ~x== -3。

支持一元运算符很简单,只需实现相应的特殊方法。这些特殊方法只有一个参数,self。然后,使用符合所在类的逻辑实现。不过,要遵守运算符的一个基本规则:始终返回一个新对象。也就是说,不能修改self,要创建并返回合适类型的新实例。

对 - 和 + 来说,结果可能是与 self 同属一类的实例。多数时候,+ 最好返回 self 的副本。abs(...) 的结果应该是一个标量。但是对 ~ 来说,很难说什么结果是合理的,因为可能不是处理整数的位,例如在ORM 中,SQL WHERE 子句应该返回反集。

def __abs__(self):
  return math.sqrt(sum(x * x for x in self))

 def __neg__(self):
  return Vector(-x for x in self)   #为了计算 -v,构建一个新 Vector 实例,把 self 的每个分量都取反

 def __pos__(self):
  return Vector(self)      #为了计算 +v,构建一个新 Vector 实例,传入 self 的各个分量

x 和 +x 何时不相等

每个人都觉得 x == +x,而且在 Python 中,几乎所有情况下都是这样。但是,我在标准库中找到两例 x != +x 的情况。

第一例与 decimal.Decimal 类有关。如果 x 是 Decimal 实例,在算术运算的上下文中创建,然后在不同的上下文中计算 +x,那么 x!= +x。例如,x 所在的上下文使用某个精度,而计算 +x 时,精度变了,例如下面的 ?

算术运算上下文的精度变化可能导致 x 不等于 +x

>>> import decimal
>>> ctx = decimal.getcontext()







#获取当前全局算术运算符的上下文引用
>>> ctx.prec = 40











  #把算术运算上下文的精度设为40
>>> one_third = decimal.Decimal('1') / decimal.Decimal('3') #使用当前精度计算1/3
>>> one_third
Decimal('0.3333333333333333333333333333333333333333')

 #查看结果,小数点后的40个数字
>>> one_third == +one_third









#one_third = +one_thied返回TRUE
True
>>> ctx.prec = 28












#把精度降为28
>>> one_third == +one_third









#one_third = +one_thied返回FalseFalse >>> +one_third Decimal('0.3333333333333333333333333333')
 #查看+one_third,小术后的28位数字

虽然每个 +one_third 表达式都会使用 one_third 的值创建一个新 Decimal 实例,但是会使用当前算术运算上下文的精度。

x != +x 的第二例在 collections.Counter 的文档中(https://docs.python.org/3/library/collections.html#collections.Counter)。类实现了几个算术运算符,例如中缀运算符 +,作用是把两个Counter 实例的计数器加在一起。然而,从实用角度出发,Counter 相加时,负值和零值计数会从结果中剔除。而一元运算符 + 等同于加上一个空 Counter,因此它产生一个新的Counter 且仅保留大于零的计数器。

?  一元运算符 + 得到一个新 Counter 实例,但是没有零值和负值计数器

>>> from collections import Counter
>>> ct = Counter('abracadabra')
>>> ct['r'] = -3
>>> ct['d'] = 0
>>> ct
Counter({'a': 5, 'r': -3, 'b': 2, 'c': 1, 'd': 0})
>>> +ct
Counter({'a': 5, 'b': 2, 'c': 1})

重载向量加法运算符+

两个欧几里得向量加在一起得到的是一个新向量,它的各个分量是两个向量中相应的分量之和。比如说:

>>> v1 = Vector([3, 4, 5])
>>> v2 = Vector([6, 7, 8])
>>> v1 + v2
Vector([9.0, 11.0, 13.0])
>>> v1 + v2 == Vector([3+6, 4+7, 5+8])
True

确定这些基本的要求之后,__add__ 方法的实现短小精悍,? 如下

def __add__(self, other):
  pairs = itertools.zip_longest(self, other, fillvalue=0.0)   #生成一个元祖,a来自self,b来自other,如果两个长度不够,通过fillvalue设置的补全值自动补全短的
  return Vector(a + b for a, b in pairs)        #使用生成器表达式计算pairs中的各个元素的和

还可以把Vector 加到元组或任何生成数字的可迭代对象上:

# 在Vector类中定义 

 def __add__(self, other):
  pairs = itertools.zip_longest(self, other, fillvalue=0.0)   #生成一个元祖,a来自self,b来自other,如果两个长度不够,通过fillvalue设置的补全值自动补全短的
  return Vector(a + b for a, b in pairs)        #使用生成器表达式计算pairs中的各个元素的和

 def __radd__(self, other):            #会直接委托给__add__
  return self + other

__radd__ 通常就这么简单:直接调用适当的运算符,在这里就是委托__add__。任何可交换的运算符都能这么做。处理数字和向量时,+ 可以交换,但是拼接序列时不行。

重载标量乘法运算符*

Vector([1, 2, 3]) * x 是什么意思?如果 x 是数字,就是计算标量积(scalar product),结果是一个新 Vector 实例,各个分量都会乘以x——这也叫元素级乘法(elementwise multiplication)。

>>> v1 = Vector([1, 2, 3])
>>> v1 * 10
Vector([10.0, 20.0, 30.0])
>>> 11 * v1
Vector([11.0, 22.0, 33.0])

涉及 Vector 操作数的积还有一种,叫两个向量的点积(dotproduct);如果把一个向量看作 1×N 矩阵,把另一个向量看作 N×1 矩阵,那么就是矩阵乘法。NumPy 等库目前的做法是,不重载这两种意义的 *,只用 * 计算标量积。例如,在 NumPy 中,点积使用numpy.dot() 函数计算。

回到标量积的话题。我们依然先实现最简可用的 __mul__ 和 __rmul__方法:

def __mul__(self, scalar):
  if isinstance(scalar, numbers.Real):
   return Vector(n * scalar for n in self)
  else:
   return NotImplemented

 def __rmul__(self, scalar):
  return self * scalar

这两个方法确实可用,但是提供不兼容的操作数时会出问题。scalar参数的值要是数字,与浮点数相乘得到的积是另一个浮点数(因为Vector 类在内部使用浮点数数组)。因此,不能使用复数,但可以是int、bool(int 的子类),甚至 fractions.Fraction 实例等标量。

提供了点积所需的 @ 记号(例如,a @ b 是 a 和 b 的点积)。@ 运算符由特殊方法 __matmul__、__rmatmul__ 和__imatmul__ 提供支持,名称取自“matrix multiplication”(矩阵乘法)

>>> va = Vector([1, 2, 3])
>>> vz = Vector([5, 6, 7])
>>> va @ vz == 38.0 # 1*5 + 2*6 + 3*7
True
>>> [10, 20, 30] @ vz
380.0
>>> va @ 3
Traceback (most recent call last):
...
TypeError: unsupported operand type(s) for @: 'Vector' and 'int'

下面是相应特殊方法的代码:

>>> va = Vector([1, 2, 3])
>>> vz = Vector([5, 6, 7])
>>> va @ vz == 38.0 # 1*5 + 2*6 + 3*7
True
>>> [10, 20, 30] @ vz
380.0
>>> va @ 3
Traceback (most recent call last):
...
TypeError: unsupported operand type(s) for @: 'Vector' and 'int'

众多比较运算符

Python 解释器对众多比较运算符(==、!=、>、<、>=、<=)的处理与前文类似,不过在两个方面有重大区别。

  • 正向和反向调用使用的是同一系列方法。例如,对 == 来说,正向和反向调用都是 __eq__ 方法,只是把参数对调了;而正向的 __gt__ 方法调用的是反向的 __lt__方法,并把参数对调。
  • 对 == 和 != 来说,如果反向调用失败,Python 会比较对象的 ID,而不抛出 TypeError。

众多比较运算符:正向方法返回NotImplemented的话,调用反向方法

分组   中缀运算符   正向方法调用   反向方法调用   后备机制
  相等性   a == b   a.__eq__(b)   b.__eq__(a)   返回 id(a) == id(b)
    a != b   a.__ne__(b)   b.__ne__(a)   返回 not (a == b)
  排序   a > b   a.__gt__(b)   b.__lt__(a)   抛出 TypeError
    a   a.__lt__(b)   b.__gt__(a)   抛出 TypeError
    a >= b   a.__ge__(b)   b.__le__(a)   抛出 TypeError
    a   a.__le__(b)   b.__ge__(a)   抛出T ypeError

看下面的?

from array import array
import reprlib
import math
import numbers
import functools
import operator
import itertools


class Vector:
 typecode = 'd'

 def __init__(self, components):
  self._components = array(self.typecode, components)

 def __iter__(self):
  return iter(self._components)

 def __repr__(self):
  components = reprlib.repr(self._components)
  components = components[components.find('['):-1]
  return 'Vector({})'.format(components)

 def __str__(self):
  return str(tuple(self))

 def __bytes__(self):
  return (bytes([ord(self.typecode)]) + bytes(self._components))

 def __eq__(self, other):
  return (len(self) == len(other) and all(a == b for a, b in zip(self, other)))

 def __hash__(self):
  hashes = map(hash, self._components)
  return functools.reduce(operator.xor, hashes, 0)

 def __add__(self, other):
  pairs = itertools.zip_longest(self, other, fillvalue=0.0)   #生成一个元祖,a来自self,b来自other,如果两个长度不够,通过fillvalue设置的补全值自动补全短的
  return Vector(a + b for a, b in pairs)        #使用生成器表达式计算pairs中的各个元素的和

 def __radd__(self, other):            #会直接委托给__add__
  return self + other

 def __mul__(self, scalar):
  if isinstance(scalar, numbers.Real):
   return Vector(n * scalar for n in self)
  else:
   return NotImplemented

 def __rmul__(self, scalar):
  return self * scalar

 def __matmul__(self, other):
  try:
   return sum(a * b for a, b in zip(self, other))
  except TypeError:
   return NotImplemented

 def __rmatmul__(self, other):
  return self @ other

 def __abs__(self):
  return math.sqrt(sum(x * x for x in self))

 def __neg__(self):
  return Vector(-x for x in self)   #为了计算 -v,构建一个新 Vector 实例,把 self 的每个分量都取反

 def __pos__(self):
  return Vector(self)       #为了计算 +v,构建一个新 Vector 实例,传入 self 的各个分量

 def __bool__(self):
  return bool(abs(self))

 def __len__(self):
  return len(self._components)

 def __getitem__(self, index):
  cls = type(self)

  if isinstance(index, slice):
   return cls(self._components[index])
  elif isinstance(index, numbers.Integral):
   return self._components[index]
  else:
   msg = '{.__name__} indices must be integers'
   raise TypeError(msg.format(cls))

 shorcut_names = 'xyzt'

 def __getattr__(self, name):
  cls = type(self)

  if len(name) == 1:
   pos = cls.shorcut_names.find(name)
   if 0 <= pos < len(self._components):
    return self._components[pos]
  msg = '{.__name__!r} object has no attribute {!r}'
  raise AttributeError(msg.format(cls, name))

 def angle(self, n):
  r = math.sqrt(sum(x * x for x in self[n:]))
  a = math.atan2(r, self[n-1])
  if (n == len(self) - 1 ) and (self[-1] < 0):
   return math.pi * 2 - a
  else:
   return a

 def angles(self):
  return (self.angle(n) for n in range(1, len(self)))

 def __format__(self, fmt_spec=''):
  if fmt_spec.endswith('h'):
   fmt_spec = fmt_spec[:-1]
   coords = itertools.chain([abs(self)], self.angles())
   outer_fmt = '<{}>'
  else:
   coords = self
   outer_fmt = '({})'
  components = (format(c, fmt_spec) for c in coords)
  return outer_fmt.format(', '.join(components))

 @classmethod
 def frombytes(cls, octets):
  typecode = chr(octets[0])
  memv = memoryview(octets[1:]).cast(typecode)
  return cls(memv)

va = Vector([1.0, 2.0, 3.0])
vb = Vector(range(1, 4))
print('va == vb:', va == vb)     #两个具有相同数值分量的 Vector 实例是相等的
t3 = (1, 2, 3)
print('va == t3:', va == t3)

print('[1, 2] == (1, 2):', [1, 2] == (1, 2))

上面代码执行返回的结果为:

va == vb: True
va == t3: True
[1, 2] == (1, 2): False

从 Python 自身来找线索,我们发现 [1,2] == (1, 2) 的结果是False。因此,我们要保守一点,做些类型检查。如果第二个操作数是Vector 实例(或者 Vector 子类的实例),那么就使用 __eq__ 方法的当前逻辑。否则,返回 NotImplemented,让 Python 处理。

? vector_v8.py:改进 Vector 类的 __eq__ 方法

def __eq__(self, other):
   if isinstance(other, Vector):          #判断对比的是否和Vector同属一个实例
    return (len(self) == len(other) and all(a == b for a, b in zip(self, other)))
   else:
    return NotImplemented           #否则,返回NotImplemented

改进以后的代码执行结果:

>>> va = Vector([1.0, 2.0, 3.0])
>>> vb = Vector(range(1, 4))
>>> va == vb 
True
>>> t3 = (1, 2, 3)
>>> va == t3
False

增量赋值运算符

Vector 类已经支持增量赋值运算符 += 和 *= 了,示例如下

?  增量赋值不会修改不可变目标,而是新建实例,然后重新绑定

>>> v1 = Vector([1, 2, 3])
>>> v1_alias = v1 




# 复制一份,供后面审查Vector([1, 2, 3])对象
>>> id(v1) 







# 记住一开始绑定给v1的Vector实例的ID
>>> v1 += Vector([4, 5, 6]) 


# 增量加法运算
>>> v1 








 # 结果与预期相符
Vector([5.0, 7.0, 9.0])
>>> id(v1) 







# 但是创建了新的Vector实例
>>> v1_alias 






 # 审查v1_alias,确认原来的Vector实例没被修改
Vector([1.0, 2.0, 3.0])
>>> v1 *= 11 






 # 增量乘法运算
>>> v1 








# 同样,结果与预期相符,但是创建了新的Vector实例
Vector([55.0, 77.0, 99.0])
>>> id(v1)

完整代码:

from array import array
import reprlib
import math
import numbers
import functools
import operator
import itertools


class Vector:
 typecode = 'd'

 def __init__(self, components):
  self._components = array(self.typecode, components)

 def __iter__(self):
  return iter(self._components)

 def __repr__(self):
  components = reprlib.repr(self._components)
  components = components[components.find('['):-1]
  return 'Vector({})'.format(components)

 def __str__(self):
  return str(tuple(self))

 def __bytes__(self):
  return (bytes([ord(self.typecode)]) + bytes(self._components))

 def __eq__(self, other):
  if isinstance(other, Vector):          
   return (len(self) == len(other) and all(a == b for a, b in zip(self, other)))
  else:
   return NotImplemented          

 def __hash__(self):
  hashes = map(hash, self._components)
  return functools.reduce(operator.xor, hashes, 0)

 def __add__(self, other):
  pairs = itertools.zip_longest(self, other, fillvalue=0.0)   
  return Vector(a + b for a, b in pairs)        

 def __radd__(self, other):            
  return self + other

 def __mul__(self, scalar):
  if isinstance(scalar, numbers.Real):
   return Vector(n * scalar for n in self)
  else:
   return NotImplemented

 def __rmul__(self, scalar):
  return self * scalar

 def __matmul__(self, other):
  try:
   return sum(a * b for a, b in zip(self, other))
  except TypeError:
   return NotImplemented

 def __rmatmul__(self, other):
  return self @ other

 def __abs__(self):
  return math.sqrt(sum(x * x for x in self))

 def __neg__(self):
  return Vector(-x for x in self)   

 def __pos__(self):
  return Vector(self)       

 def __bool__(self):
  return bool(abs(self))

 def __len__(self):
  return len(self._components)

 def __getitem__(self, index):
  cls = type(self)

  if isinstance(index, slice):
   return cls(self._components[index])
  elif isinstance(index, numbers.Integral):
   return self._components[index]
  else:
   msg = '{.__name__} indices must be integers'
   raise TypeError(msg.format(cls))

 shorcut_names = 'xyzt'

 def __getattr__(self, name):
  cls = type(self)

  if len(name) == 1:
   pos = cls.shorcut_names.find(name)
   if 0 <= pos < len(self._components):
    return self._components[pos]
  msg = '{.__name__!r} object has no attribute {!r}'
  raise AttributeError(msg.format(cls, name))

 def angle(self, n):
  r = math.sqrt(sum(x * x for x in self[n:]))
  a = math.atan2(r, self[n-1])
  if (n == len(self) - 1 ) and (self[-1] < 0):
   return math.pi * 2 - a
  else:
   return a

 def angles(self):
  return (self.angle(n) for n in range(1, len(self)))

 def __format__(self, fmt_spec=''):
  if fmt_spec.endswith('h'):
   fmt_spec = fmt_spec[:-1]
   coords = itertools.chain([abs(self)], self.angles())
   outer_fmt = '<{}>'
  else:
   coords = self
   outer_fmt = '({})'
  components = (format(c, fmt_spec) for c in coords)
  return outer_fmt.format(', '.join(components))

 @classmethod
 def frombytes(cls, octets):
  typecode = chr(octets[0])
  memv = memoryview(octets[1:]).cast(typecode)
  return cls(memv)

总结

以上就是这篇文章的全部内容了,希望本文的内容对大家的学习或者工作能带来一定的帮助,如果有疑问大家可以留言交流,谢谢大家对三水点靠木的支持。

Python 相关文章推荐
python线程、进程和协程详解
Jul 19 Python
python分割列表(list)的方法示例
May 07 Python
django如何连接已存在数据的数据库
Aug 14 Python
Python Scapy随心所欲研究TCP协议栈
Nov 20 Python
python3使用QQ邮箱发送邮件
May 20 Python
django admin组件使用方法详解
Jul 19 Python
python或C++读取指定文件夹下的所有图片
Aug 31 Python
基于Python实现剪切板实时监控方法解析
Sep 11 Python
python通过SSH登陆linux并操作的实现
Oct 10 Python
解决Python中导入自己写的类,被划红线,但不影响执行的问题
Jul 13 Python
python第三方网页解析器 lxml 扩展库与 xpath 的使用方法
Apr 06 Python
详解Django的MVT设计模式
Apr 29 Python
深入学习Python中的上下文管理器与else块
Aug 27 #Python
python利用MethodType绑定方法到类示例代码
Aug 27 #Python
Python中使用haystack实现django全文检索搜索引擎功能
Aug 26 #Python
python读取excel表格生成erlang数据
Aug 26 #Python
使用Python实现简单的服务器功能
Aug 25 #Python
详解Python实现多进程异步事件驱动引擎
Aug 25 #Python
python基础while循环及if判断的实例讲解
Aug 25 #Python
You might like
php递归方法实现无限分类实例代码
2014/02/28 PHP
PHP错误Warning: Cannot modify header information - headers already sent by解决方法
2014/09/27 PHP
PHP从FLV文件获取视频预览图的方法
2015/03/12 PHP
php中动态调用函数的方法
2015/03/16 PHP
PHP处理数组和XML之间的互相转换
2016/06/02 PHP
Extjs实现进度条的两种便捷方式
2013/09/26 Javascript
Jquery幻灯片特效代码分享--打开页面随机选择切换方式(3)
2015/08/15 Javascript
nodejs实现bigpipe异步加载页面方案
2016/01/26 NodeJs
基于javascript html5实现多文件上传
2016/03/03 Javascript
jquery实现ajax加载超时提示的方法
2016/07/23 Javascript
基于BootStrap实现局部刷新分页实例代码
2016/08/08 Javascript
用nodejs搭建websocket服务器
2017/01/23 NodeJs
vue中用动态组件实现选项卡切换效果
2017/03/25 Javascript
微信小程序实现动态改变view标签宽度和高度的方法【附demo源码下载】
2017/12/05 Javascript
js原生实现移动端手指滑动轮播图效果的示例
2018/01/02 Javascript
微信实现自动跳转到用其他浏览器打开指定APP下载
2019/02/15 Javascript
JavaScript中的垃圾回收与内存泄漏示例详解
2019/05/02 Javascript
JS实现的检验身份证格式并输出出生日期,年龄,性别,出生地示例
2019/05/17 Javascript
微信小程序开发(三):返回上一级页面并刷新操作示例【页面栈】
2020/06/01 Javascript
JavaScript中使用Spread运算符的八种方法总结
2020/06/18 Javascript
python访问纯真IP数据库的代码
2011/05/19 Python
python基于隐马尔可夫模型实现中文拼音输入
2016/04/01 Python
numpy中的高维数组转置实例
2018/04/17 Python
python3使用SMTP发送简单文本邮件
2018/06/19 Python
Scrapy框架爬取Boss直聘网Python职位信息的源码
2019/02/22 Python
Flask使用Pyecharts在单个页面展示多个图表的方法
2019/08/05 Python
Python实现隐马尔可夫模型的前向后向算法的示例代码
2019/12/31 Python
基于Tensorflow批量数据的输入实现方式
2020/02/05 Python
Python tkinter之Bind(绑定事件)的使用示例
2021/02/05 Python
基于HTML5的WebSocket的实例代码
2018/08/15 HTML / CSS
汽车运用工程专业毕业生推荐信
2013/12/25 职场文书
回门宴答谢词
2014/01/13 职场文书
餐厅感恩节活动策划方案
2014/10/11 职场文书
2014学习十八届四中全会精神思想汇报范文
2014/10/23 职场文书
2015年国庆节演讲稿范文
2015/07/30 职场文书
MySQL 常见的数据表设计误区汇总
2021/06/07 MySQL