Python正确重载运算符的方法示例详解


Posted in Python onAugust 27, 2017

前言

说到运算符重载相信大家都不陌生,运算符重载的作用是让用户定义的对象使用中缀运算符(如 + 和 |)或一元运算符(如 - 和 ~)。说得宽泛一些,在 Python 中,函数调用(())、属性访问(.)和元素访问 / 切片([])也是运算符。

我们为 Vector 类简略实现了几个运算符。__add__ 和 __mul__ 方法是为了展示如何使用特殊方法重载运算符,不过有些小问题被我们忽视了。此外,我们定义的Vector2d.__eq__ 方法认为 Vector(3, 4) == [3, 4] 是真的(True),这可能并不合理。下面来一起看看详细的介绍吧。

运算符重载基础

在某些圈子中,运算符重载的名声并不好。这个语言特性可能(已经)被滥用,让程序员困惑,导致缺陷和意料之外的性能瓶颈。但是,如果使用得当,API 会变得好用,代码会变得易于阅读。Python 施加了一些限制,做好了灵活性、可用性和安全性方面的平衡:

  • 不能重载内置类型的运算符
  • 不能新建运算符,只能重载现有的
  • 某些运算符不能重载——is、and、or 和 not(不过位运算符
  • &、| 和 ~ 可以)

前面的博文已经为 Vector 定义了一个中缀运算符,即 ==,这个运算符由__eq__ 方法支持。我们将改进 __eq__ 方法的实现,更好地处理不是Vector 实例的操作数。然而,在运算符重载方面,众多比较运算符(==、!=、>、<、>=、<=)是特例,因此我们首先将在 Vector 中重载四个算术运算符:一元运算符 - 和 +,以及中缀运算符 + 和 *。

一元运算符

-(__neg__)

一元取负算术运算符。如果 x 是 -2,那么 -x == 2。

+(__pos__)

一元取正算术运算符。通常,x == +x,但也有一些例外。如果好奇,请阅读“x 和 +x 何时不相等”附注栏。

~(__invert__)

对整数按位取反,定义为 ~x == -(x+1)。如果 x 是 2,那么 ~x== -3。

支持一元运算符很简单,只需实现相应的特殊方法。这些特殊方法只有一个参数,self。然后,使用符合所在类的逻辑实现。不过,要遵守运算符的一个基本规则:始终返回一个新对象。也就是说,不能修改self,要创建并返回合适类型的新实例。

对 - 和 + 来说,结果可能是与 self 同属一类的实例。多数时候,+ 最好返回 self 的副本。abs(...) 的结果应该是一个标量。但是对 ~ 来说,很难说什么结果是合理的,因为可能不是处理整数的位,例如在ORM 中,SQL WHERE 子句应该返回反集。

def __abs__(self):
  return math.sqrt(sum(x * x for x in self))

 def __neg__(self):
  return Vector(-x for x in self)   #为了计算 -v,构建一个新 Vector 实例,把 self 的每个分量都取反

 def __pos__(self):
  return Vector(self)      #为了计算 +v,构建一个新 Vector 实例,传入 self 的各个分量

x 和 +x 何时不相等

每个人都觉得 x == +x,而且在 Python 中,几乎所有情况下都是这样。但是,我在标准库中找到两例 x != +x 的情况。

第一例与 decimal.Decimal 类有关。如果 x 是 Decimal 实例,在算术运算的上下文中创建,然后在不同的上下文中计算 +x,那么 x!= +x。例如,x 所在的上下文使用某个精度,而计算 +x 时,精度变了,例如下面的 ?

算术运算上下文的精度变化可能导致 x 不等于 +x

>>> import decimal
>>> ctx = decimal.getcontext()







#获取当前全局算术运算符的上下文引用
>>> ctx.prec = 40











  #把算术运算上下文的精度设为40
>>> one_third = decimal.Decimal('1') / decimal.Decimal('3') #使用当前精度计算1/3
>>> one_third
Decimal('0.3333333333333333333333333333333333333333')

 #查看结果,小数点后的40个数字
>>> one_third == +one_third









#one_third = +one_thied返回TRUE
True
>>> ctx.prec = 28












#把精度降为28
>>> one_third == +one_third









#one_third = +one_thied返回FalseFalse >>> +one_third Decimal('0.3333333333333333333333333333')
 #查看+one_third,小术后的28位数字

虽然每个 +one_third 表达式都会使用 one_third 的值创建一个新 Decimal 实例,但是会使用当前算术运算上下文的精度。

x != +x 的第二例在 collections.Counter 的文档中(https://docs.python.org/3/library/collections.html#collections.Counter)。类实现了几个算术运算符,例如中缀运算符 +,作用是把两个Counter 实例的计数器加在一起。然而,从实用角度出发,Counter 相加时,负值和零值计数会从结果中剔除。而一元运算符 + 等同于加上一个空 Counter,因此它产生一个新的Counter 且仅保留大于零的计数器。

?  一元运算符 + 得到一个新 Counter 实例,但是没有零值和负值计数器

>>> from collections import Counter
>>> ct = Counter('abracadabra')
>>> ct['r'] = -3
>>> ct['d'] = 0
>>> ct
Counter({'a': 5, 'r': -3, 'b': 2, 'c': 1, 'd': 0})
>>> +ct
Counter({'a': 5, 'b': 2, 'c': 1})

重载向量加法运算符+

两个欧几里得向量加在一起得到的是一个新向量,它的各个分量是两个向量中相应的分量之和。比如说:

>>> v1 = Vector([3, 4, 5])
>>> v2 = Vector([6, 7, 8])
>>> v1 + v2
Vector([9.0, 11.0, 13.0])
>>> v1 + v2 == Vector([3+6, 4+7, 5+8])
True

确定这些基本的要求之后,__add__ 方法的实现短小精悍,? 如下

def __add__(self, other):
  pairs = itertools.zip_longest(self, other, fillvalue=0.0)   #生成一个元祖,a来自self,b来自other,如果两个长度不够,通过fillvalue设置的补全值自动补全短的
  return Vector(a + b for a, b in pairs)        #使用生成器表达式计算pairs中的各个元素的和

还可以把Vector 加到元组或任何生成数字的可迭代对象上:

# 在Vector类中定义 

 def __add__(self, other):
  pairs = itertools.zip_longest(self, other, fillvalue=0.0)   #生成一个元祖,a来自self,b来自other,如果两个长度不够,通过fillvalue设置的补全值自动补全短的
  return Vector(a + b for a, b in pairs)        #使用生成器表达式计算pairs中的各个元素的和

 def __radd__(self, other):            #会直接委托给__add__
  return self + other

__radd__ 通常就这么简单:直接调用适当的运算符,在这里就是委托__add__。任何可交换的运算符都能这么做。处理数字和向量时,+ 可以交换,但是拼接序列时不行。

重载标量乘法运算符*

Vector([1, 2, 3]) * x 是什么意思?如果 x 是数字,就是计算标量积(scalar product),结果是一个新 Vector 实例,各个分量都会乘以x——这也叫元素级乘法(elementwise multiplication)。

>>> v1 = Vector([1, 2, 3])
>>> v1 * 10
Vector([10.0, 20.0, 30.0])
>>> 11 * v1
Vector([11.0, 22.0, 33.0])

涉及 Vector 操作数的积还有一种,叫两个向量的点积(dotproduct);如果把一个向量看作 1×N 矩阵,把另一个向量看作 N×1 矩阵,那么就是矩阵乘法。NumPy 等库目前的做法是,不重载这两种意义的 *,只用 * 计算标量积。例如,在 NumPy 中,点积使用numpy.dot() 函数计算。

回到标量积的话题。我们依然先实现最简可用的 __mul__ 和 __rmul__方法:

def __mul__(self, scalar):
  if isinstance(scalar, numbers.Real):
   return Vector(n * scalar for n in self)
  else:
   return NotImplemented

 def __rmul__(self, scalar):
  return self * scalar

这两个方法确实可用,但是提供不兼容的操作数时会出问题。scalar参数的值要是数字,与浮点数相乘得到的积是另一个浮点数(因为Vector 类在内部使用浮点数数组)。因此,不能使用复数,但可以是int、bool(int 的子类),甚至 fractions.Fraction 实例等标量。

提供了点积所需的 @ 记号(例如,a @ b 是 a 和 b 的点积)。@ 运算符由特殊方法 __matmul__、__rmatmul__ 和__imatmul__ 提供支持,名称取自“matrix multiplication”(矩阵乘法)

>>> va = Vector([1, 2, 3])
>>> vz = Vector([5, 6, 7])
>>> va @ vz == 38.0 # 1*5 + 2*6 + 3*7
True
>>> [10, 20, 30] @ vz
380.0
>>> va @ 3
Traceback (most recent call last):
...
TypeError: unsupported operand type(s) for @: 'Vector' and 'int'

下面是相应特殊方法的代码:

>>> va = Vector([1, 2, 3])
>>> vz = Vector([5, 6, 7])
>>> va @ vz == 38.0 # 1*5 + 2*6 + 3*7
True
>>> [10, 20, 30] @ vz
380.0
>>> va @ 3
Traceback (most recent call last):
...
TypeError: unsupported operand type(s) for @: 'Vector' and 'int'

众多比较运算符

Python 解释器对众多比较运算符(==、!=、>、<、>=、<=)的处理与前文类似,不过在两个方面有重大区别。

  • 正向和反向调用使用的是同一系列方法。例如,对 == 来说,正向和反向调用都是 __eq__ 方法,只是把参数对调了;而正向的 __gt__ 方法调用的是反向的 __lt__方法,并把参数对调。
  • 对 == 和 != 来说,如果反向调用失败,Python 会比较对象的 ID,而不抛出 TypeError。

众多比较运算符:正向方法返回NotImplemented的话,调用反向方法

分组   中缀运算符   正向方法调用   反向方法调用   后备机制
  相等性   a == b   a.__eq__(b)   b.__eq__(a)   返回 id(a) == id(b)
    a != b   a.__ne__(b)   b.__ne__(a)   返回 not (a == b)
  排序   a > b   a.__gt__(b)   b.__lt__(a)   抛出 TypeError
    a   a.__lt__(b)   b.__gt__(a)   抛出 TypeError
    a >= b   a.__ge__(b)   b.__le__(a)   抛出 TypeError
    a   a.__le__(b)   b.__ge__(a)   抛出T ypeError

看下面的?

from array import array
import reprlib
import math
import numbers
import functools
import operator
import itertools


class Vector:
 typecode = 'd'

 def __init__(self, components):
  self._components = array(self.typecode, components)

 def __iter__(self):
  return iter(self._components)

 def __repr__(self):
  components = reprlib.repr(self._components)
  components = components[components.find('['):-1]
  return 'Vector({})'.format(components)

 def __str__(self):
  return str(tuple(self))

 def __bytes__(self):
  return (bytes([ord(self.typecode)]) + bytes(self._components))

 def __eq__(self, other):
  return (len(self) == len(other) and all(a == b for a, b in zip(self, other)))

 def __hash__(self):
  hashes = map(hash, self._components)
  return functools.reduce(operator.xor, hashes, 0)

 def __add__(self, other):
  pairs = itertools.zip_longest(self, other, fillvalue=0.0)   #生成一个元祖,a来自self,b来自other,如果两个长度不够,通过fillvalue设置的补全值自动补全短的
  return Vector(a + b for a, b in pairs)        #使用生成器表达式计算pairs中的各个元素的和

 def __radd__(self, other):            #会直接委托给__add__
  return self + other

 def __mul__(self, scalar):
  if isinstance(scalar, numbers.Real):
   return Vector(n * scalar for n in self)
  else:
   return NotImplemented

 def __rmul__(self, scalar):
  return self * scalar

 def __matmul__(self, other):
  try:
   return sum(a * b for a, b in zip(self, other))
  except TypeError:
   return NotImplemented

 def __rmatmul__(self, other):
  return self @ other

 def __abs__(self):
  return math.sqrt(sum(x * x for x in self))

 def __neg__(self):
  return Vector(-x for x in self)   #为了计算 -v,构建一个新 Vector 实例,把 self 的每个分量都取反

 def __pos__(self):
  return Vector(self)       #为了计算 +v,构建一个新 Vector 实例,传入 self 的各个分量

 def __bool__(self):
  return bool(abs(self))

 def __len__(self):
  return len(self._components)

 def __getitem__(self, index):
  cls = type(self)

  if isinstance(index, slice):
   return cls(self._components[index])
  elif isinstance(index, numbers.Integral):
   return self._components[index]
  else:
   msg = '{.__name__} indices must be integers'
   raise TypeError(msg.format(cls))

 shorcut_names = 'xyzt'

 def __getattr__(self, name):
  cls = type(self)

  if len(name) == 1:
   pos = cls.shorcut_names.find(name)
   if 0 <= pos < len(self._components):
    return self._components[pos]
  msg = '{.__name__!r} object has no attribute {!r}'
  raise AttributeError(msg.format(cls, name))

 def angle(self, n):
  r = math.sqrt(sum(x * x for x in self[n:]))
  a = math.atan2(r, self[n-1])
  if (n == len(self) - 1 ) and (self[-1] < 0):
   return math.pi * 2 - a
  else:
   return a

 def angles(self):
  return (self.angle(n) for n in range(1, len(self)))

 def __format__(self, fmt_spec=''):
  if fmt_spec.endswith('h'):
   fmt_spec = fmt_spec[:-1]
   coords = itertools.chain([abs(self)], self.angles())
   outer_fmt = '<{}>'
  else:
   coords = self
   outer_fmt = '({})'
  components = (format(c, fmt_spec) for c in coords)
  return outer_fmt.format(', '.join(components))

 @classmethod
 def frombytes(cls, octets):
  typecode = chr(octets[0])
  memv = memoryview(octets[1:]).cast(typecode)
  return cls(memv)

va = Vector([1.0, 2.0, 3.0])
vb = Vector(range(1, 4))
print('va == vb:', va == vb)     #两个具有相同数值分量的 Vector 实例是相等的
t3 = (1, 2, 3)
print('va == t3:', va == t3)

print('[1, 2] == (1, 2):', [1, 2] == (1, 2))

上面代码执行返回的结果为:

va == vb: True
va == t3: True
[1, 2] == (1, 2): False

从 Python 自身来找线索,我们发现 [1,2] == (1, 2) 的结果是False。因此,我们要保守一点,做些类型检查。如果第二个操作数是Vector 实例(或者 Vector 子类的实例),那么就使用 __eq__ 方法的当前逻辑。否则,返回 NotImplemented,让 Python 处理。

? vector_v8.py:改进 Vector 类的 __eq__ 方法

def __eq__(self, other):
   if isinstance(other, Vector):          #判断对比的是否和Vector同属一个实例
    return (len(self) == len(other) and all(a == b for a, b in zip(self, other)))
   else:
    return NotImplemented           #否则,返回NotImplemented

改进以后的代码执行结果:

>>> va = Vector([1.0, 2.0, 3.0])
>>> vb = Vector(range(1, 4))
>>> va == vb 
True
>>> t3 = (1, 2, 3)
>>> va == t3
False

增量赋值运算符

Vector 类已经支持增量赋值运算符 += 和 *= 了,示例如下

?  增量赋值不会修改不可变目标,而是新建实例,然后重新绑定

>>> v1 = Vector([1, 2, 3])
>>> v1_alias = v1 




# 复制一份,供后面审查Vector([1, 2, 3])对象
>>> id(v1) 







# 记住一开始绑定给v1的Vector实例的ID
>>> v1 += Vector([4, 5, 6]) 


# 增量加法运算
>>> v1 








 # 结果与预期相符
Vector([5.0, 7.0, 9.0])
>>> id(v1) 







# 但是创建了新的Vector实例
>>> v1_alias 






 # 审查v1_alias,确认原来的Vector实例没被修改
Vector([1.0, 2.0, 3.0])
>>> v1 *= 11 






 # 增量乘法运算
>>> v1 








# 同样,结果与预期相符,但是创建了新的Vector实例
Vector([55.0, 77.0, 99.0])
>>> id(v1)

完整代码:

from array import array
import reprlib
import math
import numbers
import functools
import operator
import itertools


class Vector:
 typecode = 'd'

 def __init__(self, components):
  self._components = array(self.typecode, components)

 def __iter__(self):
  return iter(self._components)

 def __repr__(self):
  components = reprlib.repr(self._components)
  components = components[components.find('['):-1]
  return 'Vector({})'.format(components)

 def __str__(self):
  return str(tuple(self))

 def __bytes__(self):
  return (bytes([ord(self.typecode)]) + bytes(self._components))

 def __eq__(self, other):
  if isinstance(other, Vector):          
   return (len(self) == len(other) and all(a == b for a, b in zip(self, other)))
  else:
   return NotImplemented          

 def __hash__(self):
  hashes = map(hash, self._components)
  return functools.reduce(operator.xor, hashes, 0)

 def __add__(self, other):
  pairs = itertools.zip_longest(self, other, fillvalue=0.0)   
  return Vector(a + b for a, b in pairs)        

 def __radd__(self, other):            
  return self + other

 def __mul__(self, scalar):
  if isinstance(scalar, numbers.Real):
   return Vector(n * scalar for n in self)
  else:
   return NotImplemented

 def __rmul__(self, scalar):
  return self * scalar

 def __matmul__(self, other):
  try:
   return sum(a * b for a, b in zip(self, other))
  except TypeError:
   return NotImplemented

 def __rmatmul__(self, other):
  return self @ other

 def __abs__(self):
  return math.sqrt(sum(x * x for x in self))

 def __neg__(self):
  return Vector(-x for x in self)   

 def __pos__(self):
  return Vector(self)       

 def __bool__(self):
  return bool(abs(self))

 def __len__(self):
  return len(self._components)

 def __getitem__(self, index):
  cls = type(self)

  if isinstance(index, slice):
   return cls(self._components[index])
  elif isinstance(index, numbers.Integral):
   return self._components[index]
  else:
   msg = '{.__name__} indices must be integers'
   raise TypeError(msg.format(cls))

 shorcut_names = 'xyzt'

 def __getattr__(self, name):
  cls = type(self)

  if len(name) == 1:
   pos = cls.shorcut_names.find(name)
   if 0 <= pos < len(self._components):
    return self._components[pos]
  msg = '{.__name__!r} object has no attribute {!r}'
  raise AttributeError(msg.format(cls, name))

 def angle(self, n):
  r = math.sqrt(sum(x * x for x in self[n:]))
  a = math.atan2(r, self[n-1])
  if (n == len(self) - 1 ) and (self[-1] < 0):
   return math.pi * 2 - a
  else:
   return a

 def angles(self):
  return (self.angle(n) for n in range(1, len(self)))

 def __format__(self, fmt_spec=''):
  if fmt_spec.endswith('h'):
   fmt_spec = fmt_spec[:-1]
   coords = itertools.chain([abs(self)], self.angles())
   outer_fmt = '<{}>'
  else:
   coords = self
   outer_fmt = '({})'
  components = (format(c, fmt_spec) for c in coords)
  return outer_fmt.format(', '.join(components))

 @classmethod
 def frombytes(cls, octets):
  typecode = chr(octets[0])
  memv = memoryview(octets[1:]).cast(typecode)
  return cls(memv)

总结

以上就是这篇文章的全部内容了,希望本文的内容对大家的学习或者工作能带来一定的帮助,如果有疑问大家可以留言交流,谢谢大家对三水点靠木的支持。

Python 相关文章推荐
在Python中使用__slots__方法的详细教程
Apr 28 Python
利用python发送和接收邮件
Sep 27 Python
python操作 hbase 数据的方法
Dec 18 Python
Python类的动态修改的实例方法
Mar 24 Python
python编程培训 python培训靠谱吗
Jan 17 Python
有关Python的22个编程技巧
Aug 29 Python
Python实现FTP文件传输的实例
Jul 07 Python
python爬虫开发之使用python爬虫库requests,urllib与今日头条搜索功能爬取搜索内容实例
Mar 10 Python
PyQt5+Pycharm安装和配置图文教程详解
Mar 24 Python
keras实现基于孪生网络的图片相似度计算方式
Jun 11 Python
python要安装在哪个盘
Jun 15 Python
pandas抽取行列数据的几种方法
Dec 13 Python
深入学习Python中的上下文管理器与else块
Aug 27 #Python
python利用MethodType绑定方法到类示例代码
Aug 27 #Python
Python中使用haystack实现django全文检索搜索引擎功能
Aug 26 #Python
python读取excel表格生成erlang数据
Aug 26 #Python
使用Python实现简单的服务器功能
Aug 25 #Python
详解Python实现多进程异步事件驱动引擎
Aug 25 #Python
python基础while循环及if判断的实例讲解
Aug 25 #Python
You might like
用PHP开发GUI
2006/10/09 PHP
PHP和XSS跨站攻击的防范
2007/04/17 PHP
smarty 缓存控制前的页面静态化原理
2013/03/15 PHP
PHP删除HTMl标签的三种解决方法
2013/06/30 PHP
php实现指定字符串中查找子字符串的方法
2015/03/17 PHP
PDO的安全处理与事物处理方法
2016/10/31 PHP
php获取手机端的号码以及ip地址实例代码
2018/09/12 PHP
Chrome Form多次提交表单问题的解决方法
2011/05/09 Javascript
jQuery学习笔记 操作jQuery对象 属性处理
2012/09/19 Javascript
jquery获取tr中控件值并操作tr实现思路
2013/03/27 Javascript
jQuery的控件及事件(输入控件及回车事件)使用示例
2013/07/25 Javascript
[JSF]使用DataModel处理表行事件的实例代码
2013/08/05 Javascript
jQuery制作仿腾讯web qq用户体验桌面
2013/08/20 Javascript
jQuery操作JSON的CRUD用法实例
2015/02/25 Javascript
javascript中Date()函数在各浏览器中的显示效果
2015/06/18 Javascript
使用Javascript监控前端相关数据的代码
2016/10/27 Javascript
移动端点击态处理的三种实现方式
2017/01/12 Javascript
JavaScript脚本语言是什么_动力节点Java学院整理
2017/06/26 Javascript
vue任意关系组件通信与跨组件监听状态vue-communication
2020/10/18 Javascript
python自动化测试之从命令行运行测试用例with verbosity
2014/09/28 Python
自己编程中遇到的Python错误和解决方法汇总整理
2015/06/03 Python
Python编程生成随机用户名及密码的方法示例
2017/05/05 Python
django框架如何集成celery进行开发
2017/05/24 Python
Python中正则表达式的用法总结
2019/02/22 Python
python实现树的深度优先遍历与广度优先遍历详解
2019/10/26 Python
python:动态路由的Flask程序代码
2019/11/22 Python
pyqt5实现井字棋的示例代码
2020/12/07 Python
HTML5学习心得总结(推荐)
2016/07/08 HTML / CSS
THE OUTNET英国官网:国际设计师品牌折扣网站
2016/08/14 全球购物
Expedia加拿大官方网站:加拿大最大的在线旅游提供商
2017/12/31 全球购物
Reformation官网:美国女装品牌
2018/09/14 全球购物
Foot Locker英国官网:美国知名运动产品零售商
2019/02/21 全球购物
毕业生教师求职信
2013/10/20 职场文书
库房主管岗位职责
2013/12/31 职场文书
项目考察欢迎辞
2014/01/17 职场文书
交通安全责任书范本
2014/07/24 职场文书