11.8. Functional Currying

  • functools.partial()

  • functools.partialmethod()

One of the most commonly used functions in the functools module is partial(). It creates a new callable from an existing function with some of that function's arguments already fixed. Strictly speaking this is partial application, a close relative of currying. It is useful when you repeatedly call a function with the same arguments and don't want to keep typing them out, or when an API such as map() expects a function that takes a single argument.
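To make the mechanism concrete, here is a minimal pure-Python sketch of what partial() does, modelled loosely on the "roughly equivalent" code in the Python documentation. The name my_partial is hypothetical; the real functools.partial is implemented in C and additionally supports a readable repr and pickling.

```python
def my_partial(func, /, *args, **keywords):
    """Return a callable with some positional and keyword arguments pre-filled."""
    def newfunc(*fargs, **fkeywords):
        # Keywords given at call time take precedence over the pre-filled ones
        newkeywords = {**keywords, **fkeywords}
        # Pre-filled positional arguments come first, call-time ones after
        return func(*args, *fargs, **newkeywords)

    # Store the pieces, mirroring the attributes of functools.partial
    newfunc.func = func
    newfunc.args = args
    newfunc.keywords = keywords
    return newfunc


def add(a, b):
    return a + b


add10 = my_partial(add, b=10)
print(add10(1))        # 11
print(add10(1, b=20))  # 21
```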

11.8.1. SetUp

>>> from functools import partial

11.8.2. Problem

>>> def add(a, b):
...     return a + b
>>>
>>> data = (1, 2, 3, 4, 5)
>>> result = map(lambda x: add(x, 10), data)
>>>
>>> tuple(result)
(11, 12, 13, 14, 15)

11.8.3. Solution

>>> def add(a, b):
...     return a + b
>>>
>>> data = (1, 2, 3, 4, 5)
>>> add10 = partial(add, b=10)
>>> result = map(add10, data)
>>>
>>> tuple(result)
(11, 12, 13, 14, 15)
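The object returned by partial() is not a plain function: it records what it wraps and what it pre-filled, exposed as the func, args and keywords attributes. The sketch below inspects add10 and also shows a common pitfall: because b is preset by keyword, passing a second positional argument tries to fill b twice and raises TypeError.

```python
from functools import partial


def add(a, b):
    return a + b


add10 = partial(add, b=10)

# A partial object keeps a reference to the wrapped function
# and to the arguments it pre-filled
print(add10.func is add)  # True
print(add10.args)         # ()
print(add10.keywords)     # {'b': 10}

# Pitfall: a second positional argument would also fill b,
# so the call fails instead of ignoring the preset
try:
    add10(1, 2)
except TypeError as error:
    clash = error
    print('TypeError:', error)
else:
    clash = None
```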

11.8.4. Partial

  • Create a new callable from an existing function with some of its arguments preset

  • Useful when you need to pass a function with preset arguments to, for example, map() or filter()

>>> from functools import partial
>>>
>>>
>>> basetwo = partial(int, base=2)
>>> basetwo.__doc__ = 'Convert base 2 string to an int.'
>>> basetwo('10010')
18
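Positional arguments given to partial() are bound from the left, which matters when the wrapped function is not symmetric. A short sketch with operator.lt, operator.gt and filter(): partial(operator.lt, 0) fixes the first argument, so each call computes 0 < x and keeps the positive numbers.

```python
import operator
from functools import partial

data = (-2, -1, 0, 1, 2)

# partial(operator.lt, 0)(x) calls operator.lt(0, x), which is 0 < x
is_positive = partial(operator.lt, 0)
print(tuple(filter(is_positive, data)))  # (1, 2)

# partial(operator.gt, 0)(x) calls operator.gt(0, x), which is 0 > x
is_negative = partial(operator.gt, 0)
print(tuple(filter(is_negative, data)))  # (-2, -1)
```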

11.8.5. Partialmethod

>>> from functools import partialmethod
>>>
>>>
>>> class Cell:
...     def __init__(self):
...         self._alive = False
...
...     @property
...     def alive(self):
...         return self._alive
...
...     def set_state(self, state):
...         self._alive = bool(state)
...
...     set_alive = partialmethod(set_state, True)
...     set_dead = partialmethod(set_state, False)
>>>
>>>
>>> c = Cell()
>>>
>>> c.alive
False
>>>
>>> c.set_alive()
>>> c.alive
True
>>>
>>> c.set_dead()
>>> c.alive
False
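For comparison, here is a sketch of the same Cell class written without partialmethod, with each preset spelled out as a tiny wrapper method. partialmethod(set_state, True) behaves like the hand-written set_alive below: when accessed through an instance, self is bound first and the preset True follows.

```python
class Cell:
    def __init__(self):
        self._alive = False

    @property
    def alive(self):
        return self._alive

    def set_state(self, state):
        self._alive = bool(state)

    # Hand-written equivalents of partialmethod(set_state, True/False)
    def set_alive(self):
        return self.set_state(True)

    def set_dead(self):
        return self.set_state(False)


c = Cell()
c.set_alive()
print(c.alive)  # True
c.set_dead()
print(c.alive)  # False
```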

11.8.6. Case Study

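  • pandas.read_csv() (signature below is from an older pandas release; several of these parameters have since been deprecated or removed)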

  • We often use the same arguments for reading CSV files

>>> def read_csv(filepath_or_buffer, sep=',', delimiter=None, header='infer',
...              names=None, index_col=None, usecols=None, squeeze=False,
...              prefix=None, mangle_dupe_cols=True, dtype=None, engine=None,
...              converters=None, true_values=None, false_values=None,
...              skipinitialspace=False, skiprows=None, nrows=None,
...              na_values=None, keep_default_na=True, na_filter=True,
...              verbose=False, skip_blank_lines=True, parse_dates=False,
...              infer_datetime_format=False, keep_date_col=False,
...              date_parser=None, dayfirst=False, iterator=False,
...              chunksize=None, compression='infer', thousands=None,
...              decimal=b'.', lineterminator=None, quotechar='"',
...              quoting=0, escapechar=None, comment=None, encoding=None,
...              dialect=None, tupleize_cols=None, error_bad_lines=True,
...              warn_bad_lines=True, skipfooter=0, doublequote=True,
...              delim_whitespace=False, low_memory=True, memory_map=False,
...              float_precision=None): ...

Using partial to create a new function myread_csv with some arguments already set:

>>> from functools import partial
>>>
>>>
>>> myread_csv = partial(read_csv, sep=';', encoding='utf-8', decimal=',', thousands=' ')
>>>
>>> a = myread_csv('myfile1.csv')
>>> b = myread_csv('myfile2.csv')
>>> c = myread_csv('myfile3.csv')

This is equivalent to writing a wrapper function by hand:

>>> def myread_csv2(filepath_or_buffer, sep=';', encoding='utf-8',
...                 decimal=',', thousands=' ', **kwargs):
...     return read_csv(filepath_or_buffer, sep=sep, encoding=encoding,
...                     decimal=decimal, thousands=thousands, **kwargs)
>>>
>>> a = myread_csv2('myfile1.csv')
>>> b = myread_csv2('myfile2.csv')
>>> c = myread_csv2('myfile3.csv')
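The keyword presets in myread_csv are only defaults: a caller can still override one for a single file, and a partial can itself be wrapped in another partial to build a more specific reader. The sketch below uses a hypothetical stub read_csv that simply returns what it received, so it runs without pandas or real files; the file names are illustrative.

```python
from functools import partial


def read_csv(filepath_or_buffer, **kwargs):
    """Hypothetical stand-in for pandas.read_csv: returns what it received."""
    return filepath_or_buffer, kwargs


myread_csv = partial(read_csv, sep=';', encoding='utf-8', decimal=',', thousands=' ')

# Override one preset for a single file; the other presets still apply
path, options = myread_csv('export.csv', sep=',')
print(options)  # {'sep': ',', 'encoding': 'utf-8', 'decimal': ',', 'thousands': ' '}

# Build a more specific reader on top of the general one
myread_csv_noheader = partial(myread_csv, header=None)
path, options = myread_csv_noheader('raw.csv')
print(options['sep'], options['header'])  # ; None
```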

11.8.7. Use Case - 1

  • We want to round to two decimal places

Problem:

>>> data = (1.1111, 2.2222, 3.3333, 4.4444)
>>> result = map(round, data)
>>>
>>> print(tuple(result))
(1, 2, 3, 4)

Function:

>>> def round2(x):
...     return round(x, ndigits=2)
>>>
>>> data = (1.1111, 2.2222, 3.3333, 4.4444)
>>> result = map(round2, data)
>>>
>>> print(tuple(result))
(1.11, 2.22, 3.33, 4.44)

Lambda:

>>> data = (1.1111, 2.2222, 3.3333, 4.4444)
>>> result = map(lambda x: round(x, ndigits=2), data)
>>>
>>> print(tuple(result))
(1.11, 2.22, 3.33, 4.44)

Partial:

>>> from functools import partial
>>>
>>> round2 = partial(round, ndigits=2)
>>> result = map(round2, data)
>>>
>>> print(tuple(result))
(1.11, 2.22, 3.33, 4.44)
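One practical difference between the lambda and partial versions: a partial object has an informative repr and can be pickled (for example, to send it to worker processes), provided the wrapped function and preset arguments are themselves picklable. An equivalent lambda cannot be pickled by the standard pickle module. A short sketch:

```python
import pickle
from functools import partial

round2 = partial(round, ndigits=2)

# A partial shows what it wraps and which arguments are preset
print(repr(round2))  # functools.partial(<built-in function round>, ndigits=2)

# It survives a pickle round trip
restored = pickle.loads(pickle.dumps(round2))
print(restored(3.14159))  # 3.14

# An equivalent lambda cannot be pickled by reference
try:
    pickle.dumps(lambda x: round(x, ndigits=2))
except (pickle.PicklingError, AttributeError, TypeError):
    lambda_picklable = False
else:
    lambda_picklable = True
print(lambda_picklable)  # False
```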

11.8.8. Use Case - 2

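>>> from functools import partial
>>> import pandas as pd
>>>
>>>
>>> df = pd.DataFrame({  # hypothetical sensor readings
...     'temperature': [21.5, 22.0, 22.4],
...     'humidity': [40, 42, 41],
...     'co2': [410, 415, 420],
...     'noise': [30, 35, 33],
... })
>>>
>>> # preset the common options on the bound df.plot accessor; y selects the column
>>> plot = partial(df.plot, kind='line', xlabel='time', ylabel='value', title='value in time')
>>>
>>> plot(y='temperature')
>>> plot(y='humidity')
>>> plot(y='co2')
>>> plot(y='noise')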