11.8. Functional Currying
functools.partial()
functools.partialmethod()
One of the most commonly used functions in the functools module is partial(), which creates a new callable from an existing function with some of its arguments already fixed. This is useful when you need to call a function repeatedly with the same arguments and don't want to keep typing them out.
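To make the mechanism concrete, here is a simplified pure-Python sketch modeled on the "roughly equivalent" code shown in the Python documentation for functools.partial. The name my_partial is hypothetical; in CPython the real partial is a class implemented in C, not a closure.

```python
def my_partial(func, /, *args, **keywords):
    """Simplified sketch of functools.partial (not the real implementation)."""
    def newfunc(*fargs, **fkeywords):
        # Keywords given at call time override the frozen ones
        newkeywords = {**keywords, **fkeywords}
        # Frozen positional arguments come first, then the call-time ones
        return func(*args, *fargs, **newkeywords)
    # Expose the same introspection attributes as functools.partial
    newfunc.func = func
    newfunc.args = args
    newfunc.keywords = keywords
    return newfunc


def add(a, b):
    return a + b


add10 = my_partial(add, b=10)
print(add10(1))       # 11
print(add10(1, b=5))  # 6
```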
11.8.1. Setup
>>> from functools import partial
11.8.2. Problem
>>> def add(a, b):
...     return a + b
>>>
>>> data = (1, 2, 3, 4, 5)
>>> result = map(lambda x: add(x, 10), data)
>>>
>>> tuple(result)
(11, 12, 13, 14, 15)
11.8.3. Solution
>>> def add(a, b):
...     return a + b
>>>
>>> data = (1, 2, 3, 4, 5)
>>> add10 = partial(add, b=10)
>>> result = map(add10, data)
>>>
>>> tuple(result)
(11, 12, 13, 14, 15)
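Because partial(add, b=10) freezes b as a keyword argument, the frozen value can still be overridden by name at call time, but not by passing b positionally. The sketch below inspects the func, args and keywords attributes of the partial object (all part of the documented functools.partial API) and shows both behaviors.

```python
from functools import partial


def add(a, b):
    return a + b


add10 = partial(add, b=10)

# A partial object keeps references to the wrapped function and the frozen arguments
print(add10.func is add)  # True
print(add10.args)         # ()
print(add10.keywords)     # {'b': 10}

# A keyword given at call time overrides the frozen one
print(add10(1, b=100))    # 101

# Passing b positionally as well gives add() two values for b, which is an error
try:
    add10(1, 2)
except TypeError as exc:
    print('TypeError:', exc)
```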
11.8.4. Partial
Creates an alias of a function with some of its arguments already filled in
Useful when you need to pass a function together with its arguments to, for example, map() or filter()
>>> from functools import partial
>>>
>>>
>>> basetwo = partial(int, base=2)
>>> basetwo.__doc__ = 'Convert base 2 string to an int.'
>>> basetwo('10010')
18
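The same idea works with filter(). Below is a minimal sketch using a hypothetical predicate longer_than(limit, word). Because partial() fills positional arguments from the left, the parameter to freeze (limit) is placed first in the signature.

```python
from functools import partial


def longer_than(limit, word):
    """Hypothetical predicate: True if word has more than limit characters."""
    return len(word) > limit


words = ('a', 'abc', 'abcdef', 'ab')

# partial(longer_than, 2) behaves like: lambda word: longer_than(2, word)
result = filter(partial(longer_than, 2), words)

print(tuple(result))  # ('abc', 'abcdef')
```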
11.8.5. Partialmethod
>>> from functools import partialmethod
>>>
>>>
>>> class Cell(object):
...     def __init__(self):
...         self._alive = False
...
...     @property
...     def alive(self):
...         return self._alive
...
...     def set_state(self, state):
...         self._alive = bool(state)
...
...     set_alive = partialmethod(set_state, True)
...     set_dead = partialmethod(set_state, False)
>>>
>>>
>>> c = Cell()
>>>
>>> c.alive
False
>>>
>>> c.set_alive()
>>> c.alive
True
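partialmethod() is designed as a descriptor: when it is looked up on an instance, the instance is passed as self before the frozen arguments, exactly like a regular method. It can freeze keyword arguments as well as positional ones. The Logger class below is a hypothetical illustration of the keyword form.

```python
from functools import partialmethod


class Logger:
    """Hypothetical class used only to illustrate partialmethod."""

    def __init__(self):
        self.records = []

    def log(self, message, level='INFO'):
        self.records.append((level, message))

    # Freeze the level keyword; self and message are still supplied at call time
    debug = partialmethod(log, level='DEBUG')
    error = partialmethod(log, level='ERROR')


logger = Logger()
logger.debug('starting')
logger.error('failed')
logger.log('done')

print(logger.records)
# [('DEBUG', 'starting'), ('ERROR', 'failed'), ('INFO', 'done')]
```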
11.8.6. Case Study
pandas.read_csv()
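We often call read_csv() with the same arguments (separator, encoding, decimal mark, thousands separator) for every file. The stub below reproduces the signature of read_csv() from an older pandas release; some of its parameters, such as squeeze and error_bad_lines, have since been removed:
>>> def read_csv(filepath_or_buffer, sep=',', delimiter=None, header='infer',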
... names=None, index_col=None, usecols=None, squeeze=False,
... prefix=None, mangle_dupe_cols=True, dtype=None, engine=None,
... converters=None, true_values=None, false_values=None,
... skipinitialspace=False, skiprows=None, nrows=None,
... na_values=None, keep_default_na=True, na_filter=True,
... verbose=False, skip_blank_lines=True, parse_dates=False,
... infer_datetime_format=False, keep_date_col=False,
... date_parser=None, dayfirst=False, iterator=False,
... chunksize=None, compression='infer', thousands=None,
... decimal=b'.', lineterminator=None, quotechar='"',
... quoting=0, escapechar=None, comment=None, encoding=None,
... dialect=None, tupleize_cols=None, error_bad_lines=True,
... warn_bad_lines=True, skipfooter=0, doublequote=True,
... delim_whitespace=False, low_memory=True, memory_map=False,
... float_precision=None): ...
Using partial() to create a new function myread_csv() with some arguments already set:
>>> from functools import partial
>>>
>>>
>>> myread_csv = partial(read_csv, sep=';', encoding='utf-8', decimal=',', thousands=' ')
>>>
>>> a = myread_csv('myfile1.csv')
>>> b = myread_csv('myfile2.csv')
>>> c = myread_csv('myfile3.csv')
This is roughly equivalent to writing a wrapper function by hand:
>>> def myread_csv2(filepath_or_buffer, sep=';', encoding='utf-8', decimal=',', thousands=' ', **kwargs):
...     return read_csv(filepath_or_buffer, sep=sep, encoding=encoding,
...                     decimal=decimal, thousands=thousands, **kwargs)
>>>
>>> a = myread_csv2('myfile1.csv')
>>> b = myread_csv2('myfile2.csv')
>>> c = myread_csv2('myfile3.csv')
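To check what myread_csv() actually passes on without installing pandas, the sketch below swaps in a hypothetical fake_read_csv() that accepts a small subset of the real parameters and simply returns the arguments it received. It shows that the frozen arguments are filled in automatically and that any argument that was not frozen is forwarded unchanged.

```python
from functools import partial


def fake_read_csv(filepath_or_buffer, sep=',', encoding=None, decimal='.',
                  thousands=None, header='infer'):
    """Hypothetical stand-in for pandas.read_csv(): returns the arguments it got."""
    return {'path': filepath_or_buffer, 'sep': sep, 'encoding': encoding,
            'decimal': decimal, 'thousands': thousands, 'header': header}


myread_csv = partial(fake_read_csv, sep=';', encoding='utf-8', decimal=',', thousands=' ')

# Frozen arguments are filled in automatically
print(myread_csv('myfile1.csv'))

# Arguments that were not frozen (here: header) are forwarded unchanged
print(myread_csv('myfile2.csv', header=None)['header'])  # None
```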
11.8.7. Use Case - 1
We want to round each number to two decimal places.
Problem: map() calls round() with a single argument, so every number is rounded to an integer:
>>> data = (1.1111, 2.2222, 3.3333, 4.4444)
>>> result = map(round, data)
>>>
>>> print(tuple(result))
(1, 2, 3, 4)
Function:
>>> def round2(x):
...     return round(x, ndigits=2)
>>>
>>> data = (1.1111, 2.2222, 3.3333, 4.4444)
>>> result = map(round2, data)
>>>
>>> print(tuple(result))
(1.11, 2.22, 3.33, 4.44)
Lambda:
>>> data = (1.1111, 2.2222, 3.3333, 4.4444)
>>> result = map(lambda x: round(x, ndigits=2), data)
>>>
>>> print(tuple(result))
(1.11, 2.22, 3.33, 4.44)
Partial:
>>> from functools import partial
>>>
>>> round2 = partial(round, ndigits=2)
>>> result = map(round2, data)
>>>
>>> print(tuple(result))
(1.11, 2.22, 3.33, 4.44)
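All three approaches give the same result here, but they are not interchangeable everywhere. One practical difference: a partial object wrapping a module-level or built-in function can be pickled (for example, to send it to a worker process with multiprocessing), while a lambda cannot be pickled by the standard pickle module. A minimal sketch:

```python
import pickle
from functools import partial

round2 = partial(round, ndigits=2)

# The partial object survives a pickle round trip
restored = pickle.loads(pickle.dumps(round2))
print(restored(3.14159))  # 3.14

# A lambda has no importable name, so pickling it fails
lambda_picklable = True
try:
    pickle.dumps(lambda x: round(x, ndigits=2))
except (pickle.PicklingError, AttributeError):
    lambda_picklable = False
print(lambda_picklable)  # False
```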
11.8.8. Use Case - 2
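>>> from functools import partial
>>> import pandas as pd
>>>
>>> # df is assumed to be a DataFrame with columns temperature, humidity, co2 and noise.
>>> # pd.DataFrame.plot accessed on the class is an accessor, not a plotting function,
>>> # so bind the partial to the accessor of a concrete DataFrame (df.plot) instead.
>>> plot = partial(df.plot, kind='line', xlabel='time', ylabel='value', title='value in time')
>>>
>>> plot(y='temperature')
>>> plot(y='humidity')
>>> plot(y='co2')
>>> plot(y='noise')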