Python Essentials for AI Engineers · Lesson 15 of 36
Lambda Functions and When to Use Them
What is a Lambda?
A lambda is a small anonymous function written as a single expression. It produces the same kind of function object as a def statement, but without a name of its own, without a docstring, and with a body limited to that one expression.
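You can check this in a few lines. The sketch confirms that both forms produce ordinary function objects and that only def gives the function a real name. It also shows that a lambda, being an expression, can sit directly inside a data structure. The unit_converters table is a hypothetical example:

```python
import types


def square(x):
    return x ** 2


# lambda and def both produce ordinary function objects
print(isinstance(square, types.FunctionType))            # True
print(isinstance(lambda x: x ** 2, types.FunctionType))  # True

# Only def gives the function a name of its own
print(square.__name__)              # square
print((lambda x: x ** 2).__name__)  # <lambda>

# Because lambda is an expression, it can sit directly inside a data structure
unit_converters = {  # hypothetical lookup table
    "mg_to_g": lambda mg: mg / 1000,
    "g_to_mg": lambda g: g * 1000,
}
print(unit_converters["mg_to_g"](500))  # 0.5
```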
Python
# def version
def square(x):
    return x ** 2
# lambda version: same behavior
square = lambda x: x ** 2
print(square(5)) # 25
# Multi-argument lambda
add = lambda a, b: a + b
print(add(3, 4)) # 7
# Lambda with default value
greet = lambda name, greeting="Hello": f"{greeting}, {name}"
print(greet("Pharmacist")) # "Hello, Pharmacist"
print(greet("Pharmacist", "Good morning")) # "Good morning, Pharmacist"
Syntax
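lambda parameters: expression
- Only one expression: no statements, no if/else blocks (a conditional expression such as x if cond else y is OK), no assignment statements
- Always returns the value of that expression (implicit return; writing return inside a lambda is a syntax error)
- Can have zero or more parameters, including default values, *args, and **kwargs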
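You can check these rules directly. The sketch below uses defaults, *args, and **kwargs in lambdas. It then uses the built-in compile() to show that a statement such as return, or an assignment, is rejected when the code is parsed. The helper name parses_as_expression is hypothetical:

```python
# Parameters work as they do in def: defaults, *args, and **kwargs
scale = lambda x, factor=2: x * factor
total = lambda *values: sum(values)
label = lambda drug, **info: drug + "".join(f" {k}={v}" for k, v in info.items())

print(scale(3), scale(3, factor=10))               # 6 30
print(total(1, 2, 3))                              # 6
print(label("warfarin", dose_mg=5, route="oral"))  # warfarin dose_mg=5 route=oral


def parses_as_expression(source: str) -> bool:
    """Return True if `source` is a syntactically valid single Python expression."""
    try:
        compile(source, "<example>", "eval")
        return True
    except SyntaxError:
        return False


print(parses_as_expression("lambda x: x * 2"))                            # True
print(parses_as_expression("lambda s: 'pass' if s >= 0.7 else 'fail'"))  # True: conditional expression
print(parses_as_expression("lambda x: return x"))                         # False: return is a statement
print(parses_as_expression("lambda x: y = x"))                            # False: assignment statement
```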
Python
import time
# Zero parameters
get_timestamp = lambda: time.time()
# One parameter
double = lambda x: x * 2
# Multiple parameters
weighted_avg = lambda a, b, w: a * w + b * (1 - w)
# With ternary expression
classify = lambda score: "pass" if score >= 0.7 else "fail"
print(classify(0.85)) # "pass"
print(classify(0.60)) # "fail"
Common Uses
Sorting with a Key
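The key= argument works the same way in sorted(), list.sort(), min(), and max(). The function is called once per element, and the returned values are compared instead of the elements themselves. For a mixed-direction sort, a common trick is to negate the numeric field. In this self-contained sketch, the atorvastatin entry is a hypothetical addition that creates a tie on dose:

```python
drugs = [
    {"name": "warfarin", "dose_mg": 5},
    {"name": "metformin", "dose_mg": 500},
    {"name": "aspirin", "dose_mg": 81},
    {"name": "lisinopril", "dose_mg": 10},
    {"name": "atorvastatin", "dose_mg": 10},  # hypothetical entry to create a tie
]

# min() and max() accept the same key= callable as sorted()
lowest = min(drugs, key=lambda d: d["dose_mg"])
highest = max(drugs, key=lambda d: d["dose_mg"])
print(lowest["name"], highest["name"])  # warfarin metformin

# Dose descending, then name ascending: negate the number, keep the string as-is
by_dose_desc = sorted(drugs, key=lambda d: (-d["dose_mg"], d["name"]))
print([d["name"] for d in by_dose_desc])
# ['metformin', 'aspirin', 'atorvastatin', 'lisinopril', 'warfarin']
```

Negation only works for numeric fields. To reverse a string field, sort in two stable passes or use reverse=True on the whole key.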
Python
drugs = [
{"name": "warfarin", "dose_mg": 5},
{"name": "metformin", "dose_mg": 500},
{"name": "aspirin", "dose_mg": 81},
{"name": "lisinopril", "dose_mg": 10},
]
# Sort by dose (ascending)
sorted_by_dose = sorted(drugs, key=lambda d: d["dose_mg"])
# [warfarin 5, lisinopril 10, aspirin 81, metformin 500]
# Sort by name length (descending)
sorted_by_name_length = sorted(drugs, key=lambda d: len(d["name"]), reverse=True)
# Sort by multiple fields: first by dose, then by name
sorted_multi = sorted(drugs, key=lambda d: (d["dose_mg"], d["name"]))
map() and filter()
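In Python 3, map() and filter() return lazy iterators rather than lists, which is why these examples wrap them in list(). map() can also take several iterables and pass one element from each to the function. This is a short sketch, and the weights are made-up values:

```python
scores = [0.92, 0.68, 0.85]
weights = [0.5, 0.3, 0.2]  # hypothetical weights, one per score

# map() with two iterables calls the lambda with one item from each, in step
weighted = map(lambda s, w: s * w, scores, weights)
print(type(weighted).__name__)  # map (nothing has been computed yet)

# Values are produced on demand, and an iterator can be consumed only once
values = list(weighted)
print(len(values))     # 3
print(list(weighted))  # [] (already exhausted)

# filter() is lazy too: next() pulls one matching element at a time
high = filter(lambda s: s > 0.8, scores)
print(next(high))      # 0.92
```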
Python
scores = [0.92, 0.68, 0.85, 0.45, 0.77]
# map: apply function to each element
rounded = list(map(lambda s: round(s, 1), scores)) # [0.9, 0.7, 0.8, 0.5, 0.8] (0.85 is stored as 0.8499..., so it rounds down)
# filter: keep elements where function returns True
passing = list(filter(lambda s: s >= 0.7, scores)) # [0.92, 0.85, 0.77]
# Note: list comprehensions are usually more readable
rounded = [round(s, 1) for s in scores]
passing = [s for s in scores if s >= 0.7]
Lambda vs def: When to Use Each
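Tracebacks give a practical reason to prefer def for anything non-trivial. When an error is raised inside a lambda, the traceback labels that frame <lambda>, while a def function keeps its own name. The sketch below extracts that frame name with the standard traceback module. The records, dose_key, and failing_function_name are hypothetical:

```python
import traceback

records = [{"name": "warfarin", "dose_mg": 5}, {"name": "aspirin"}]  # second record lacks dose_mg


def dose_key(drug: dict) -> int:
    return drug["dose_mg"]


def failing_function_name(key) -> str:
    """Sort `records` with `key` and return the name of the function that raised."""
    try:
        sorted(records, key=key)
    except KeyError as exc:
        # The last frame in the traceback is the function where the error occurred
        return traceback.extract_tb(exc.__traceback__)[-1].name
    return ""


print(failing_function_name(lambda d: d["dose_mg"]))  # <lambda>
print(failing_function_name(dose_key))                # dose_key
```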
Python
# Use lambda: short, one-expression transformations used inline once
drugs.sort(key=lambda d: d["dose_mg"])
# Use def: anything with logic, multiple steps, or reuse
def sort_key(drug: dict) -> tuple:
    """Sort by dose, then alphabetically."""
    return (drug["dose_mg"], drug["name"])
drugs.sort(key=sort_key) # Reusable, testable, documented
# Use def if:
# - You need a docstring
# - The body has more than one expression
# - You'll call it in more than one place
# - A reader would benefit from a name
# Use lambda if:
# - It's a one-off key function for sort/min/max
# - It's used inline and the logic fits in one expression
# - You're building a small transform for map/filter
Lambdas in LangChain (RunnableLambda)
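Conceptually, a pipeline step only needs to wrap a callable and overload the | operator. The toy Step class here is a hypothetical sketch of that idea for illustration only. It is not how LangChain implements RunnableLambda, whose Runnable interface also provides methods such as batch() and stream():

```python
from typing import Any, Callable


class Step:
    """Toy pipeline step wrapping a one-argument callable (hypothetical, for illustration)."""

    def __init__(self, func: Callable[[Any], Any]):
        self.func = func

    def invoke(self, value: Any) -> Any:
        return self.func(value)

    def __or__(self, other: "Step") -> "Step":
        # a | b returns a new step that runs a, then feeds its output into b
        return Step(lambda value: other.invoke(self.invoke(value)))


trim = Step(lambda text: text.strip())
upper = Step(lambda text: text.upper())
pipeline = trim | upper
print(pipeline.invoke("  warfarin aspirin interaction  "))  # WARFARIN ASPIRIN INTERACTION
```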
LangChain's RunnableLambda wraps a plain Python function so it can be used as a step in an LCEL chain and composed with other steps using the | operator:
Python
from langchain_core.runnables import RunnableLambda
# Wrap a lambda as a Runnable step in an LCEL chain
uppercase_step = RunnableLambda(lambda text: text.upper())
trim_step = RunnableLambda(lambda text: text.strip())
# Pipeline: trim → uppercase
pipeline = trim_step | uppercase_step
print(pipeline.invoke(" warfarin aspirin interaction "))
# "WARFARIN ASPIRIN INTERACTION"
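# More useful: transforming structured data between chain steps
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from langchain_core.output_parsers import StrOutputParser
# Format retrieved docs for the prompt
format_docs = RunnableLambda(
    lambda docs: "\n\n".join(f"[{i+1}] {d.page_content}" for i, d in enumerate(docs))
)
prompt = ChatPromptTemplate.from_template(
    "Answer using only this context:\n{context}\n\nQuestion: {question}"
)
model = ChatOpenAI(model="gpt-4o", temperature=0)
# retriever is assumed to exist already (for example, vectorstore.as_retriever())
rag_chain = (
    # The dict sends the same input to both branches; LCEL wraps the plain
    # lambda in a RunnableLambda automatically, so it passes the question through
    {"context": retriever | format_docs, "question": lambda x: x}
    | prompt
    | model
    | StrOutputParser()
)
print(rag_chain.invoke("Does aspirin interact with warfarin?"))
Closures: Lambdas Capturing Outer Variables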
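A closure stores a reference to the enclosing variable, not a snapshot of its value. The lambda therefore sees whatever value the variable holds at the moment it is called, and this late binding is the root of the loop gotcha in this section. In this minimal sketch, make_dose_checker and the 10 mg limit are hypothetical:

```python
def make_dose_checker():
    max_dose = 10                          # mg (hypothetical limit)
    check = lambda dose: dose <= max_dose  # refers to the variable max_dose, not the value 10
    results = [check(8)]                   # max_dose is 10 at this point
    max_dose = 5                           # rebinding the captured variable...
    results.append(check(8))               # ...changes what the lambda sees
    return results, check


results, check = make_dose_checker()
print(results)  # [True, False]

# The captured variable lives in a closure cell attached to the function object
print(check.__closure__[0].cell_contents)  # 5
```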
Lambdas capture variables from their enclosing scope (closure):
Python
def make_threshold_filter(threshold: float):
    """Return a filter function for a specific threshold."""
    return lambda score: score >= threshold # Captures 'threshold'
passing_70 = make_threshold_filter(0.70)
passing_90 = make_threshold_filter(0.90)
scores = [0.65, 0.72, 0.85, 0.91, 0.55]
print(list(filter(passing_70, scores))) # [0.72, 0.85, 0.91]
print(list(filter(passing_90, scores))) # [0.91]
# Common closure gotcha: loop variable capture
# This is a bug:
filters = [lambda x: x > i for i in [1, 2, 3]] # All lambdas capture the SAME i
print([f(3) for f in filters]) # [False, False, False]: each lambda looks up i when called, after the loop left i=3
# Fix: capture the value at lambda creation time
filters = [lambda x, i=i: x > i for i in [1, 2, 3]] # Default arg freezes value
print([f(3) for f in filters]) # [True, True, False]
functools Utilities (Related to Lambdas)
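The partial example in this section calls a live model and needs an OpenAI API key, so here is an offline sketch of the same two tools. partial() can pre-fill keyword arguments as well as positional ones, and the resulting object records what was pre-filled. reduce() with an initializer returns that initializer for an empty sequence. The format_dose helper is hypothetical:

```python
import operator
from functools import partial, reduce


def format_dose(value: float, unit: str, precision: int = 1) -> str:
    """Hypothetical helper: render a dose with a unit."""
    return f"{value:.{precision}f} {unit}"


# Pre-fill a keyword argument; the remaining arguments are supplied at call time
format_mg = partial(format_dose, unit="mg")
print(format_mg(2.5))                               # 2.5 mg
print(format_mg(0.125, precision=3))                # 0.125 mg
print(format_mg.func.__name__, format_mg.keywords)  # format_dose {'unit': 'mg'}

# reduce: a lambda and operator.mul compute the same running product here
scores = [0.92, 0.85, 0.77, 0.91]
print(reduce(lambda acc, x: acc * x, scores, 1.0) == reduce(operator.mul, scores, 1.0))  # True
print(reduce(operator.mul, [], 1.0))  # 1.0 (the initializer is returned for empty input)
```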
Python
from functools import partial, reduce
from langchain_openai import ChatOpenAI
# partial: pre-fill arguments to create a new callable
model = ChatOpenAI(model="gpt-4o", temperature=0)
# Suppose invoke_model(m, prompt) is a function whose first argument you want to fix in place
def invoke_model(m, prompt: str) -> str:
    return m.invoke(prompt).content
clinical_model = partial(invoke_model, model) # First arg pre-filled
print(clinical_model("What is warfarin?")) # Same as invoke_model(model, "What is warfarin?")
# reduce: fold a sequence into a single value
scores = [0.92, 0.85, 0.77, 0.91]
product = reduce(lambda acc, x: acc * x, scores, 1.0)
print(product) # 0.92 * 0.85 * 0.77 * 0.91 ≈ 0.548
# For a product, math.prod(scores) (Python 3.8+) is clearer; sum/max/min or a plain for loop cover most other folds
When Not to Use Lambda
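The standard library's operator module already provides callables for the most common lambda patterns. itemgetter() stands in for lambda d: d[key], and attrgetter() stands in for lambda obj: obj.attr. This is a sketch, and the Drug dataclass is hypothetical:

```python
from dataclasses import dataclass
from operator import attrgetter, itemgetter

drugs = [
    {"name": "warfarin", "dose_mg": 5},
    {"name": "metformin", "dose_mg": 500},
    {"name": "aspirin", "dose_mg": 81},
]

# itemgetter("dose_mg") behaves like lambda d: d["dose_mg"]
by_dose = sorted(drugs, key=itemgetter("dose_mg"))
print([d["name"] for d in by_dose])  # ['warfarin', 'aspirin', 'metformin']

# With several keys it returns a tuple, like lambda d: (d["dose_mg"], d["name"])
print(itemgetter("dose_mg", "name")(drugs[0]))  # (5, 'warfarin')


@dataclass
class Drug:  # hypothetical record type
    name: str
    dose_mg: int


records = [Drug("warfarin", 5), Drug("aspirin", 81)]
# attrgetter("dose_mg") behaves like lambda r: r.dose_mg
print(max(records, key=attrgetter("dose_mg")).name)  # aspirin
```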
Python
# Don't wrap an existing function or method in a lambda
# Instead of:
names = ["warfarin", "aspirin", "METFORMIN"]
lower = list(map(lambda s: s.lower(), names))
# Use:
lower = list(map(str.lower, names)) # str.lower is already a callable
# Don't assign lambdas to names; use def instead (the earlier examples do it only to show syntax)
# Instead of:
normalize = lambda x, lo, hi: (x - lo) / (hi - lo) # PEP 8 discourages this
# Use:
def normalize(x: float, lo: float, hi: float) -> float:
    return (x - lo) / (hi - lo)
# Don't use lambda for complex logic
# Instead of:
classify = lambda s, h, l: "high" if s > h else ("low" if s < l else "normal")
# Use:
def classify(score: float, high: float = 0.9, low: float = 0.5) -> str:
    if score > high:
        return "high"
    if score < low:
        return "low"
    return "normal"