with 语句和上下文管理器

# create/aquire some resource
...
try:
    # do something with the resource
    ...
finally:
    # destroy/release the resource
    ...

处理文件,线程,数据库,网络编程等等资源的时候,我们经常需要使用上面这样的代码形式,以确保资源的正常使用和释放。

好在Python 提供了 with 语句帮我们自动进行这样的处理,例如之前在打开文件时我们使用:

In [1]:
with open('my_file', 'w') as fp:
    # do stuff with fp
    data = fp.write("Hello world")

这等效于下面的代码,但是要更简便:

In [2]:
fp = open('my_file', 'w')
try:
    # do stuff with f
    data = fp.write("Hello world")
finally:
    fp.close()

上下文管理器

其基本用法如下:

with <expression>:
    <block>

<expression> 执行的结果应当返回一个实现了上下文管理器的对象,即实现这样两个方法,__enter____exit__

In [3]:
print fp.__enter__
print fp.__exit__
<built-in method __enter__ of file object at 0x0000000003C1D540>
<built-in method __exit__ of file object at 0x0000000003C1D540>

__enter__ 方法在 <block> 执行前执行,而 __exit__<block> 执行结束后执行:

比如可以这样定义一个简单的上下文管理器:

In [4]:
class ContextManager(object):
    
    def __enter__(self):
        print "Entering"
    
    def __exit__(self, exc_type, exc_value, traceback):
        print "Exiting"

使用 with 语句执行:

In [5]:
with ContextManager():
    print "  Inside the with statement"
Entering
  Inside the with statement
Exiting

即使 <block> 中执行的内容出错,__exit__ 也会被执行:

In [6]:
with ContextManager():
    print 1/0
Entering
Exiting
---------------------------------------------------------------------------
ZeroDivisionError                         Traceback (most recent call last)
<ipython-input-6-b509c97cb388> in <module>()
      1 with ContextManager():
----> 2     print 1/0

ZeroDivisionError: integer division or modulo by zero

__enter__ 的返回值

如果在 __enter__ 方法下添加了返回值,那么我们可以使用 as 把这个返回值传给某个参数:

In [7]:
class ContextManager(object):
    
    def __enter__(self):
        print "Entering"
        return "my value"
    
    def __exit__(self, exc_type, exc_value, traceback):
        print "Exiting"

__enter__ 返回的值传给 value 变量:

In [8]:
with ContextManager() as value:
    print value
Entering
my value
Exiting

一个通常的做法是将 __enter__ 的返回值设为这个上下文管理器对象本身,文件对象就是这样做的:

In [9]:
fp = open('my_file', 'r')
print fp.__enter__()
fp.close()
<open file 'my_file', mode 'r' at 0x0000000003B63030>
In [10]:
import os
os.remove('my_file')

实现方法非常简单:

In [11]:
class ContextManager(object):
    
    def __enter__(self):
        print "Entering"
        return self
    
    def __exit__(self, exc_type, exc_value, traceback):
        print "Exiting"
In [12]:
with ContextManager() as value:
    print value
Entering
<__main__.ContextManager object at 0x0000000003D48828>
Exiting

错误处理

上下文管理器对象将错误处理交给 __exit__ 进行,可以将错误类型,错误值和 traceback 等内容作为参数传递给 __exit__ 函数:

In [13]:
class ContextManager(object):
    
    def __enter__(self):
        print "Entering"
    
    def __exit__(self, exc_type, exc_value, traceback):
        print "Exiting"
        if exc_type is not None:
            print "  Exception:", exc_value

如果没有错误,这些值都将是 None, 当有错误发生的时候:

In [14]:
with ContextManager():
    print 1/0
Entering
Exiting
  Exception: integer division or modulo by zero
---------------------------------------------------------------------------
ZeroDivisionError                         Traceback (most recent call last)
<ipython-input-14-b509c97cb388> in <module>()
      1 with ContextManager():
----> 2     print 1/0

ZeroDivisionError: integer division or modulo by zero

在这个例子中,我们只是简单的显示了错误的值,并没有对错误进行处理,所以错误被向上抛出了,如果不想让错误抛出,只需要将 __exit__ 的返回值设为 True

In [15]:
class ContextManager(object):
    
    def __enter__(self):
        print "Entering"
    
    def __exit__(self, exc_type, exc_value, traceback):
        print "Exiting"
        if exc_type is not None:
            print " Exception suppresed:", exc_value
            return True
In [16]:
with ContextManager():
    print 1/0
Entering
Exiting
 Exception suppresed: integer division or modulo by zero

在这种情况下,错误就不会被向上抛出。

数据库的例子

对于数据库的 transaction 来说,如果没有错误,我们就将其 commit 进行保存,如果有错误,那么我们将其回滚到上一次成功的状态。

In [17]:
class Transaction(object):
    
    def __init__(self, connection):
        self.connection = connection
    
    def __enter__(self):
        return self.connection.cursor()
    
    def __exit__(self, exc_type, exc_value, traceback):
        if exc_value is None:
            # transaction was OK, so commit
            self.connection.commit()
        else:
            # transaction had a problem, so rollback
            self.connection.rollback()

建立一个数据库,保存一个地址表:

In [18]:
import sqlite3 as db
connection = db.connect(":memory:")

with Transaction(connection) as cursor:
    cursor.execute("""CREATE TABLE IF NOT EXISTS addresses (
        address_id INTEGER PRIMARY KEY,
        street_address TEXT,
        city TEXT,
        state TEXT,
        country TEXT,
        postal_code TEXT
    )""")

插入数据:

In [19]:
with Transaction(connection) as cursor:
    cursor.executemany("""INSERT OR REPLACE INTO addresses VALUES (?, ?, ?, ?, ?, ?)""", [
        (0, '515 Congress Ave', 'Austin', 'Texas', 'USA', '78701'),
        (1, '245 Park Avenue', 'New York', 'New York', 'USA', '10167'),
        (2, '21 J.J. Thompson Ave.', 'Cambridge', None, 'UK', 'CB3 0FA'),
        (3, 'Supreme Business Park', 'Hiranandani Gardens, Powai, Mumbai', 'Maharashtra', 'India', '400076'),
    ])

假设插入数据之后出现了问题:

In [20]:
with Transaction(connection) as cursor:
    cursor.execute("""INSERT OR REPLACE INTO addresses VALUES (?, ?, ?, ?, ?, ?)""",
        (4, '2100 Pennsylvania Ave', 'Washington', 'DC', 'USA', '78701'),
    )
    raise Exception("out of addresses")
---------------------------------------------------------------------------
Exception                                 Traceback (most recent call last)
<ipython-input-20-ed8abdd56558> in <module>()
      3         (4, '2100 Pennsylvania Ave', 'Washington', 'DC', 'USA', '78701'),
      4     )
----> 5     raise Exception("out of addresses")

Exception: out of addresses

那么最新的一次插入将不会被保存,而是返回上一次 commit 成功的状态:

In [21]:
cursor.execute("SELECT * FROM addresses")
for row in cursor:
    print row
(0, u'515 Congress Ave', u'Austin', u'Texas', u'USA', u'78701')
(1, u'245 Park Avenue', u'New York', u'New York', u'USA', u'10167')
(2, u'21 J.J. Thompson Ave.', u'Cambridge', None, u'UK', u'CB3 0FA')
(3, u'Supreme Business Park', u'Hiranandani Gardens, Powai, Mumbai', u'Maharashtra', u'India', u'400076')

contextlib 模块

很多的上下文管理器有很多相似的地方,为了防止写入很多重复的模式,可以使用 contextlib 模块来进行处理。

最简单的处理方式是使用 closing 函数确保对象的 close() 方法始终被调用:

In [23]:
from contextlib import closing
import urllib

with closing(urllib.urlopen('http://www.baidu.com')) as url:
    html = url.read()

print html[:100]
<!DOCTYPE html><!--STATUS OK--><html><head><meta http-equiv="content-type" content="text/html;charse

另一个有用的方法是使用修饰符 @contextlib

In [24]:
from contextlib import contextmanager

@contextmanager
def my_contextmanager():
    print "Enter"
    yield
    print "Exit"

with my_contextmanager():
    print "  Inside the with statement"
Enter
  Inside the with statement
Exit

yield 之前的部分可以看成是 __enter__ 的部分,yield 的值可以看成是 __enter__ 返回的值,yield 之后的部分可以看成是 __exit__ 的部分。

使用 yield 的值:

In [25]:
@contextmanager
def my_contextmanager():
    print "Enter"
    yield "my value"
    print "Exit"
    
with my_contextmanager() as value:
    print value
Enter
my value
Exit

错误处理可以用 try 块来完成:

In [26]:
@contextmanager
def my_contextmanager():
    print "Enter"
    try:
        yield
    except Exception as exc:
        print "   Error:", exc
    finally:
        print "Exit"
In [27]:
with my_contextmanager():
    print 1/0
Enter
   Error: integer division or modulo by zero
Exit

对于之前的数据库 transaction 我们可以这样定义:

In [28]:
@contextmanager
def transaction(connection):
    cursor = connection.cursor()
    try:
        yield cursor
    except:
        connection.rollback()
        raise
    else:
        connection.commit()