How to use map, reduce and filter in Python

Last updated on March 31, 2021 by Aqsa Mustafa

In Python, functions are treated no differently from regular objects such as numbers and strings. You can assign a function to a variable and store it inside a data structure. You can pass a function to another function as one of its parameters. You can even define a function inside another function. This functional programming approach is well illustrated by three functions: map() and filter(), which are built-in functions, and reduce(), which lives in the functools module.
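To make the idea of first-class functions concrete, here is a small sketch of the four points above. The names shout, yell, transforms, apply_twice, make_multiplier, and double are hypothetical and exist only for illustration.

```python
# Hypothetical helper used only to illustrate first-class functions.
def shout(text):
    return text.upper() + "!"

# 1. Assign a function to a variable (no parentheses, so it is not called).
yell = shout

# 2. Store functions inside a data structure.
transforms = {"shout": shout, "lower": str.lower}

# 3. Pass a function to another function as a parameter.
def apply_twice(func, value):
    return func(func(value))

# 4. Define a function inside another function and return it.
def make_multiplier(factor):
    def multiply(x):
        return x * factor
    return multiply

double = make_multiplier(2)

print(yell("hello"))                # HELLO!
print(transforms["lower"]("ABC"))   # abc
print(apply_twice(double, 5))       # 20
```

map(), filter(), and reduce() all rely on point 3: they receive a function as an argument and decide when to call it.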

In this tutorial, we will see how to use map(), filter(), and reduce() in Python. While the same results can be achieved with list comprehensions or explicit for loops, these functions let us write short and concise code. For example, a problem that takes four or five lines with an explicit for loop can often be solved in one or two lines with these functions. So, without further ado, let's get started and look at each of them with examples.

Map Function

map() is a built-in Python function, so we don't need to import it from any module. It takes a function and one or more iterables, and calls the function on each item of the iterable(s). map() returns an iterator (a map object) that produces the results lazily, i.e., only when they are requested. You can convert it into a list or a tuple by passing it to list() or tuple(), respectively. Its syntax is:

map(func, iterable1, iterable2, ...)

As you can see, you can pass more than one iterable to it as well. Let's take an example to understand the working of the map() function.

Let's say we have a list of numbers, and we want to calculate the square root of each value and store the output in another list. We can do that using a for loop:

import math
list_1 = [2, 4, 5, 7, 1, 4, 5, 8]
result = []
for num in list_1:
    result.append(math.sqrt(num))

print("Square root of numbers in a list")
print(result)

Output:

Square root of numbers in a list
[1.4142135623730951, 2.0, 2.23606797749979, 2.6457513110645907, 1.0, 2.0, 2.23606797749979, 2.8284271247461903]

We can also solve the above problem using the map() function, which will be more concise. Let’s see.

import math
list_1 = [2, 4, 5, 7, 1, 4, 5, 8]
result = list(map(lambda x: math.sqrt(x), list_1))

print("Square root of numbers in a list")
print(result)

Output:

Square root of numbers in a list
[1.4142135623730951, 2.0, 2.23606797749979, 2.6457513110645907, 1.0, 2.0, 2.23606797749979, 2.8284271247461903]

In the above example, we pass two arguments to map(): an anonymous (lambda) function that takes a number and returns its square root, and list_1, the iterable. map() calls the anonymous function on each value of list_1 in turn. For example, it is first invoked with the value 2 and returns 1.4142135623730951; then it is invoked with 4, and so on. Thus, map() yields as many results as there are items in the input iterable, and it returns them as a map object. Since we want a list, we convert it using the list() function.
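Two details of this example can be seen in isolation. The lambda is not strictly needed, because math.sqrt already takes one argument, and the map object is a lazy, single-use iterator that computes each value only when it is requested. A short sketch:

```python
import math

list_1 = [2, 4, 5, 7, 1, 4, 5, 8]

# math.sqrt already accepts one argument, so it can be passed directly.
roots = map(math.sqrt, list_1)
print(roots)           # <map object at 0x...>; nothing has been computed yet

# Values are produced one at a time as the iterator is consumed.
print(next(roots))     # 1.4142135623730951

# Converting consumes the remaining items (the first one is already gone).
rest = tuple(roots)
print(len(rest))       # 7

# A map object can only be iterated once; afterwards it is exhausted.
print(list(roots))     # []
```

If you need the results more than once, convert the map object to a list or tuple first and reuse that.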

As already mentioned, we can pass any number of iterables to the map() function. Let's take an example to see how. Say we have three lists containing first, middle, and last names, and we want to combine them into a single list in which each item is a full name. With the map() function, this is a breeze. Let’s see.

first = ["Joseph", "Benjamin", "Joseph", "Eoin", "Jason"]
middle = ["Edward", "Andrew", "Charles", "Joseph", "Jonathan"]
last = ["Root", "Stokes", "Buttler", "Morgan", "Roy"]

full_names = list(map(lambda f, m, l: f"{f} {m} {l}", first, middle, last))
print("Full names of England Cricket Players")
print(full_names)

Output:

Full names of England Cricket Players
['Joseph Edward Root', 'Benjamin Andrew Stokes', 'Joseph Charles Buttler', 'Eoin Joseph Morgan', 'Jason Jonathan Roy']

In the above code, we pass three iterables (lists) to the map() function, so the anonymous function takes three parameters: a first, a middle, and a last name. map() calls it with the items at the same position in each list, and the function joins the three arguments with spaces using an f-string to build the full name. For example, the anonymous function is first called with Joseph, Edward, and Root as arguments and returns Joseph Edward Root. Then it is invoked with Benjamin, Andrew, and Stokes, and so on.
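For comparison, the same result can be obtained with a list comprehension over zip(), which also walks the lists in parallel, or by mapping str.join over the zipped tuples. This sketch simply checks that the three forms agree.

```python
first = ["Joseph", "Benjamin", "Joseph", "Eoin", "Jason"]
middle = ["Edward", "Andrew", "Charles", "Joseph", "Jonathan"]
last = ["Root", "Stokes", "Buttler", "Morgan", "Roy"]

# map() with three iterables calls the lambda with one item from each list.
with_map = list(map(lambda f, m, l: f"{f} {m} {l}", first, middle, last))

# zip() groups the i-th items together, so a comprehension can do the same.
with_zip = [f"{f} {m} {l}" for f, m, l in zip(first, middle, last)]

# " ".join accepts each (first, middle, last) tuple produced by zip().
with_join = list(map(" ".join, zip(first, middle, last)))

print(with_map == with_zip == with_join)  # True
```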

What if the iterables passed have different lengths, for example, if one has more or fewer items than another? In this case, map() stops as soon as the shortest iterable is exhausted. In other words, the output length equals the length of the shortest iterable. Let’s take an example.

# Calculate the area of each rectangle
length = [14, 25, 67, 18, 69]
width = [4, 6, 2, 55, 56, 66, 5]

area = list(map(lambda l, w: l*w, length, width))

print("Area of Rectangle")
print(area)

Output:

Area of Rectangle
[56, 150, 134, 990, 3864]

In the above example, the length and width lists contain 5 and 7 items, respectively. Since map() stops as soon as any of its iterables is exhausted, the area list contains only five elements: after five items the length list runs out, and the last two widths (66 and 5) are ignored.

Filter Function

The filter() function is also a built-in Python function. Like map(), filter() takes a function and an iterable. The function passed to filter(), however, acts as a test: it should return True or False, although any return value works because filter() only checks its truthiness. filter() passes each element of the iterable to the given function. If the function returns a true value, the item is included in the result; otherwise, filter() discards it. Like map(), filter() returns an iterator, and it produces at most as many items as the input iterable contains. Its syntax is:

filter(func, iterable)

Let's take an example where we have a list containing the ages of different people, and we want to filter out the ages below 18. Using a loop, we can solve this problem in the following way.

age = [12, 18, 20, 19, 22, 15, 17, 19]
result = []
for x in age:
    if x >= 18:
        result.append(x)

print("Age greater than or equal to 18")
print(result)

Output:

Age greater than or equal to 18
[18, 20, 19, 22, 19]

With the filter() function, we can solve it more concisely. Let's see.

age = [12, 18, 20, 19, 22, 15, 17, 19]
result = list(filter(lambda x: x >= 18, age))

print("Age greater than or equal to 18")
print(result)

Output:

Age greater than or equal to 18
[18, 20, 19, 22, 19]

In the above code, we pass two arguments to filter(): an anonymous function that takes a single argument and returns True if it is greater than or equal to 18 (and False otherwise), and the age list. As you can see in the output, values less than 18 are discarded.
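The same filtering can be written as a list comprehension, and the standard library's itertools.filterfalse() returns the complementary items, those for which the function returns a false value. A brief sketch comparing the three; the predicate name is_adult is hypothetical.

```python
from itertools import filterfalse

age = [12, 18, 20, 19, 22, 15, 17, 19]

def is_adult(x):
    # Any truthy/falsy return value works; a bool is simply the clearest.
    return x >= 18

adults = list(filter(is_adult, age))
adults_comp = [x for x in age if is_adult(x)]
minors = list(filterfalse(is_adult, age))

print(adults)        # [18, 20, 19, 22, 19]
print(adults_comp)   # [18, 20, 19, 22, 19]
print(minors)        # [12, 15, 17]
```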

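If we pass None instead of a function, filter() keeps only the items that are truthy, i.e., it removes falsy values such as 0, empty strings, and None. Since none of the ages below is falsy, the whole list comes back unchanged. Let's see.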

age = [12, 18, 20, 19, 22, 15, 17, 19]
result = list(filter(None, age))

print(result)

Output:

[12, 18, 20, 19, 22, 15, 17, 19]
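The result above is the whole list only because every age is nonzero. With None as the function, filter() drops anything whose truth value is False, as this small example with made-up values shows:

```python
values = [0, 1, "", "text", None, [], [0], 3.5, False]

# filter(None, ...) keeps items whose truth value is True.
kept = list(filter(None, values))
print(kept)   # [1, 'text', [0], 3.5]

# It behaves like filtering with bool as the test function.
print(kept == list(filter(bool, values)))   # True
```

Note that [0] is kept: a non-empty list is truthy even if its only element is falsy.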

Reduce Function

While both filter() and map() return an iterator, reduce() is a bit different: it outputs a single value. Simply put, it reduces, or folds, all values of an iterable into a single value. reduce() takes a two-argument function and an iterable. It first calls the function with the first two items, then with that result and the third item, and so on, until the iterable is exhausted.

The reduce() function also takes an optional initializer argument, which serves as a seed value. If it is passed, the function is first called with the initializer and the first element of the iterable, and the result is stored. Next, the function is called with that stored result and the second element, and the process continues in the same way through the remaining elements. Notice that, in this case, reduce() makes one more call to the function than it would without an initializer.
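To make these steps concrete, here is a simplified, pure-Python sketch of the behavior described above. The name simple_reduce and the _MISSING sentinel are hypothetical; the sketch mirrors the documented logic but is not the functools source code.

```python
from functools import reduce

_MISSING = object()  # hypothetical sentinel meaning "no initializer given"

def simple_reduce(func, iterable, initializer=_MISSING):
    it = iter(iterable)
    if initializer is _MISSING:
        try:
            # Without an initializer, the first item is the starting value.
            value = next(it)
        except StopIteration:
            raise TypeError("reduce of empty iterable with no initial value") from None
    else:
        # With an initializer, it is the starting value, so the first
        # call pairs it with the first item.
        value = initializer
    for element in it:
        # Each call combines the accumulated result with the next item.
        value = func(value, element)
    return value

numbers = [1, 5, 6, 7, 9, 10, 12, 45]
print(simple_reduce(lambda a, b: a + b, numbers))        # 95
print(simple_reduce(lambda a, b: a + b, numbers, 15))    # 110
print(simple_reduce(max, numbers) == reduce(max, numbers))  # True
```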

The syntax of the reduce() function is:

reduce(func, iterable[, initializer])

Consider the following example.

from functools import reduce

numbers = [1, 5, 6, 7, 9, 10, 12, 45]
result = reduce(lambda a, b: a+b, numbers)

print("The sum of the list items is:", result)

Output:

The sum of the list items is: 95

First, we import the reduce() function from the functools module, since it is not a built-in function. We have a numbers list and want to calculate the sum of all its values, so we use reduce(). Here, the function takes two arguments and returns their sum. It is first called with 1 and 5, which gives 6. Then it is invoked with 6 and 6 (the running result and the third item), which gives 12, and so on. The final return value is 95.
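The intermediate results from this walkthrough can be inspected with itertools.accumulate(), which applies the same two-argument function but yields every running value instead of only the last one:

```python
from functools import reduce
from itertools import accumulate

numbers = [1, 5, 6, 7, 9, 10, 12, 45]

# Each running total corresponds to one step of the reduction.
steps = list(accumulate(numbers, lambda a, b: a + b))
print(steps)   # [1, 6, 12, 19, 28, 38, 50, 95]

# The last running value equals what reduce() returns.
print(steps[-1] == reduce(lambda a, b: a + b, numbers))   # True
```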

Let's now do the same example, except we also pass 15 as an initializer value.

from functools import reduce

numbers = [1, 5, 6, 7, 9, 10, 12, 45]
result = reduce(lambda a, b: a+b, numbers, 15)

print("The sum of the list items is:", result)

Output:

The sum of the list items is: 110
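Here, 15 is the starting value, so the first call adds 15 and 1, and the result is 15 + 95 = 110. The initializer also matters for empty input: without one, reduce() raises TypeError on an empty iterable, while with one it simply returns the initializer. The sketch below also uses operator.add in place of the lambda and, for plain sums, the built-in sum() with a start value.

```python
from functools import reduce
import operator

numbers = [1, 5, 6, 7, 9, 10, 12, 45]

# operator.add is a ready-made two-argument function for +.
print(reduce(operator.add, numbers, 15))   # 110
print(sum(numbers, 15))                    # 110, the built-in equivalent for sums

# With an initializer, an empty iterable returns the initializer.
print(reduce(operator.add, [], 15))        # 15

# Without one, an empty iterable is an error.
try:
    reduce(operator.add, [])
except TypeError as err:
    print("TypeError:", err)
```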

Conclusion

That's it for map(), filter(), and reduce(). In this tutorial, we saw what these functions are, how they work, and how they differ. They are quite powerful, and as you have seen, they let you write less verbose code in fewer lines. While some people prefer list comprehensions and explicit loops for readability and speed, which one to choose often comes down to personal taste and coding style. Feel free to share your thoughts on this topic.

Please note that this article is published by Xmodulo.com under a Creative Commons Attribution-ShareAlike 3.0 Unported License. If you would like to use the whole or any part of this article, you need to cite this web page at Xmodulo.com as the original source.

Xmodulo © 2021