Learnixo
Back to blog
AI Systems · Intermediate

Dataclasses: Clean Data Containers for AI

Use Python dataclasses to define structured data without boilerplate: auto-generated __init__, __repr__, __eq__, field defaults, frozen instances, and Pydantic comparison for AI applications.

Asma Hafeez Khan · May 16, 2026 · 5 min read
Python · Dataclasses · Pydantic · Data Structures · Type Safety · AI Engineering
Share: X

The Problem Dataclasses Solve

Storing structured data in plain classes requires writing the same boilerplate every time:

Python
# Without dataclass: lots of repetitive code
class DrugInteraction:
    """Hand-written record of an interaction between two drugs.

    Spells out the constructor, repr and equality boilerplate that a plain
    class needs for four simple attributes.
    """

    def __init__(self, drug_a: str, drug_b: str, severity: str, mechanism: str):
        self.drug_a, self.drug_b = drug_a, drug_b
        self.severity, self.mechanism = severity, mechanism

    def __repr__(self) -> str:
        # Built in two halves purely to keep the line length manageable.
        head = f"DrugInteraction(drug_a={self.drug_a!r}, drug_b={self.drug_b!r}, "
        tail = f"severity={self.severity!r}, mechanism={self.mechanism!r})"
        return head + tail

    def __eq__(self, other) -> bool:
        # Let Python try the reflected comparison for unrelated types.
        if isinstance(other, DrugInteraction):
            compared = ("drug_a", "drug_b", "severity", "mechanism")
            return all(getattr(self, attr) == getattr(other, attr) for attr in compared)
        return NotImplemented

@dataclass: Eliminate the Boilerplate

Python
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class DrugInteraction:
    """Interaction between two drugs, stored as four plain string fields.

    The decorator generates __init__, __repr__ and __eq__ from the annotated
    fields below, in declaration order.
    """

    drug_a: str
    drug_b: str
    severity: str      # e.g. "Major"
    mechanism: str     # free-text description of the interaction

# @dataclass auto-generates:
# __init__(self, drug_a, drug_b, severity, mechanism)
# __repr__(self)
# __eq__(self, other)

# Build a record with keyword arguments.
interaction = DrugInteraction(
    drug_a="warfarin",
    drug_b="aspirin",
    severity="Major",
    mechanism="Additive anticoagulation and antiplatelet effects",
)

# The repr lists every field as name=value.
print(interaction)
# DrugInteraction(drug_a='warfarin', drug_b='aspirin', severity='Major', mechanism='Additive anticoagulation and antiplatelet effects')

# Positional arguments follow the field declaration order.
i2 = DrugInteraction("warfarin", "aspirin", "Major", "Additive anticoagulation and antiplatelet effects")
print(interaction == i2)   # True — __eq__ compares all fields

Default Values

Python
from dataclasses import dataclass, field

@dataclass
class RAGQueryConfig:
    """Configuration for a RAG query; only `query` is required."""

    query: str
    k: int = 4                              # Simple default
    min_score: float = 0.7                  # Simple default
    model: str = "gpt-4o"                   # Simple default
    # Mutable defaults must go through field(default_factory=...) so that every
    # instance gets its own fresh object. Never write `tags: list = []` directly:
    # @dataclass rejects a bare list/dict/set default with ValueError.
    filters: dict = field(default_factory=dict)
    tags: list = field(default_factory=list)

# Only `query` is passed; every other field falls back to its default.
config = RAGQueryConfig(query="warfarin interactions")
print(config.k)       # 4
print(config.filters) # {}  — fresh dict per instance


# Non-default fields must come BEFORE fields with defaults.
# NOTE: this class is intentionally invalid. The @dataclass decorator raises
# TypeError while the class is being defined, so nothing after it would run.
@dataclass
class WrongOrder:
    name: str = "default"  # Has default
    id: int               # No default — TypeError: non-default argument follows default argument

field() for Advanced Configuration

Python
from dataclasses import dataclass, field
import time

@dataclass
class LLMCallRecord:
    """One question/answer exchange with a model, plus bookkeeping fields.

    Demonstrates the per-field options of field(): default_factory, repr,
    init and compare.
    """

    session_id: str
    question: str
    answer: str
    model: str = "gpt-4o"

    # Mutable defaults: default_factory builds a new list/dict per instance
    tool_calls: list[str] = field(default_factory=list)
    metadata: dict = field(default_factory=dict)

    # time.time() is called when the instance is created; repr=False leaves
    # the value out of the generated __repr__ (it is still stored)
    timestamp: float = field(default_factory=time.time, repr=False)

    # Excluded from __init__ — starts at 0.0 and is set after creation
    latency_ms: float = field(default=0.0, init=False)

    # Excluded from __eq__ comparison
    internal_id: str = field(default="", compare=False)


# model, metadata, timestamp and internal_id fall back to their defaults;
# latency_ms cannot be passed here because it is declared with init=False.
record = LLMCallRecord(
    session_id="sess_001",
    question="What is warfarin?",
    answer="Warfarin is an anticoagulant...",
    tool_calls=["search_drug"],
)
print(record.timestamp)   # Set automatically by default_factory=time.time
print(record.latency_ms)  # 0.0 — default, not in __init__

Frozen Dataclasses: Immutable Records

Python
@dataclass(frozen=True)
class ModelConfig:
    """Immutable configuration — cannot be changed after creation.

    With ``frozen=True`` (and the default ``eq=True``) the decorator generates
    ``__hash__`` from the fields, so instances work as dict keys and set
    members without a hand-written ``__hash__``. Assigning to a field raises
    ``dataclasses.FrozenInstanceError``.
    """

    model: str
    temperature: float
    max_tokens: int


# Field values are fixed at construction time.
config = ModelConfig(model="gpt-4o", temperature=0.0, max_tokens=500)
# Assignment on a frozen instance is rejected:
# config.temperature = 0.3   # FrozenInstanceError!

# Instances are hashable, so one can be used as a dict key
config_cache: dict[ModelConfig, str] = {}
config_cache[config] = "cached_response"

__post_init__: Validation and Derived Fields

Python
@dataclass
class PatientRecord:
    """Patient measurements with validated inputs and a derived BMI.

    ``bmi`` and ``bmi_category`` are not constructor parameters; they are
    computed in ``__post_init__`` from ``weight_kg`` and ``height_cm``.

    Raises:
        ValueError: if ``age`` is outside 0-150, or if ``weight_kg`` or
            ``height_cm`` is not positive.
    """

    patient_id: str
    age: int
    weight_kg: float
    height_cm: float
    medications: list[str] = field(default_factory=list)

    # Computed fields (not parameters)
    bmi: float = field(init=False, repr=True)
    bmi_category: str = field(init=False, repr=False)

    def __post_init__(self):
        """Called automatically after __init__. Validates inputs, then derives the BMI fields."""
        if self.age < 0 or self.age > 150:
            raise ValueError(f"Invalid age: {self.age}")
        if self.weight_kg <= 0:
            raise ValueError(f"Invalid weight: {self.weight_kg}kg")
        # Guard the division below: a zero height would otherwise surface as
        # ZeroDivisionError, and a negative one would yield a plausible-looking BMI.
        if self.height_cm <= 0:
            raise ValueError(f"Invalid height: {self.height_cm}cm")

        height_m = self.height_cm / 100
        self.bmi = round(self.weight_kg / height_m ** 2, 1)
        self.bmi_category = self._classify_bmi(self.bmi)

    @staticmethod
    def _classify_bmi(bmi: float) -> str:
        """Map a BMI value to its category; each upper bound is exclusive."""
        if bmi < 18.5:
            return "underweight"
        elif bmi < 25:
            return "normal"
        elif bmi < 30:
            return "overweight"
        return "obese"


# 80 kg at 1.75 m: 80 / 1.75**2 ≈ 26.12, rounded to one decimal place.
patient = PatientRecord(patient_id="P001", age=67, weight_kg=80.0, height_cm=175.0)
print(patient.bmi)           # 26.1
print(patient.bmi_category)  # overweight

# An out-of-range age is rejected by __post_init__ with ValueError.
try:
    PatientRecord(patient_id="P002", age=-5, weight_kg=70.0, height_cm=170.0)
except ValueError as e:
    print(e)   # Invalid age: -5

Dataclass vs Pydantic

Both create structured data classes, but for different purposes:

Python
# Dataclass: lightweight, no runtime validation, standard library
from dataclasses import dataclass

@dataclass
class DataclassConfig:
    """Model settings as a plain dataclass; annotations are not checked at runtime."""
    model: str
    temperature: float = 0.0  # the float annotation is a hint only

# A value of the wrong type is accepted silently and stored unchanged.
config = DataclassConfig(model="gpt-4o", temperature="not_a_float")  # No error!
print(config.temperature)   # not_a_float — type hint not enforced


# Pydantic BaseModel: runtime validation, coercion, JSON support
from pydantic import BaseModel, Field, ValidationError, validator

class PydanticConfig(BaseModel):
    """Model settings as a Pydantic model, validated when an instance is created."""
    model: str
    # Defaults to 0.0; ge/le constrain the value to the range 0.0-2.0
    temperature: float = Field(default=0.0, ge=0.0, le=2.0)

# A string that cannot be parsed as a number fails validation.
try:
    config = PydanticConfig(model="gpt-4o", temperature="invalid")
except ValidationError as e:
    # Catch only the validation failure so unrelated errors are not hidden.
    print(e)   # e.g. "1 validation error for PydanticConfig ... Input should be a valid number ..."

# Pydantic coerces valid types:
config = PydanticConfig(model="gpt-4o", temperature="0.3")
print(config.temperature)   # 0.3 as float — coerced from string

# Pydantic JSON support: serialize to a JSON string, then parse and validate it back
json_str = config.model_dump_json()
config_from_json = PydanticConfig.model_validate_json(json_str)

| Feature | @dataclass | Pydantic BaseModel |
|---|---|---|
| Auto `__init__` | Yes | Yes |
| Runtime type validation | No | Yes |
| Type coercion | No | Yes |
| JSON serialization | Manual | Built-in |
| Field validation (ge, le) | No | Yes |
| LangChain tools/chains | Use Pydantic | Yes — required |
| Performance | Faster | Slightly slower |
| When to use | Internal data containers | API inputs, tool schemas, config |


Practical Dataclass Patterns in AI

Python
# 1. Retrieval results
@dataclass
class RetrievalResult:
    """Documents returned for a query, with their scores and retrieval time."""

    query: str
    documents: list = field(default_factory=list)
    scores: list[float] = field(default_factory=list)
    retrieval_time_ms: float = 0.0

    @property
    def top_doc(self):
        """Return the first document, or None when there are no documents."""
        if not self.documents:
            return None
        return self.documents[0]


# 2. Evaluation results
@dataclass
class EvaluationResult:
    """Expected vs. actual answer for a question, with three metric scores."""

    question: str
    expected_answer: str
    actual_answer: str
    faithfulness: float = 0.0
    relevance: float = 0.0
    correctness: float = 0.0

    @property
    def avg_score(self) -> float:
        """Return the unweighted mean of the three metric scores."""
        total = self.faithfulness + self.relevance + self.correctness
        return total / 3


# 3. Agent run metadata
@dataclass
class AgentRunRecord:
    session_id: str
    question: str
    answer: str = ""
    tool_calls: list[str] = field(default_factory=list)
    latency_ms: float = 0.0
    success: bool = True
    error: str | None = None
    timestamp: float = field(default_factory=__import__("time").time, repr=False)

Enjoyed this article?

Explore the AI Systems learning path for more.

Found this helpful?

Share: X

Leave a comment

Have a question, correction, or just found this helpful? Leave a note below.