# Source code for lrs.core.precision

"""Precision tracking for Active Inference agents."""

from dataclasses import dataclass, field
from typing import Dict


@dataclass
class PrecisionParameters:
    """
    Precision parameters using a Beta distribution.

    Precision γ = α/(α+β) represents confidence in predictions.

    Args:
        alpha: Success parameter (default: 1.0)
        beta: Failure parameter (default: 1.0)
        gain_learning_rate: Learning rate for successes (default: 0.1)
        loss_learning_rate: Learning rate for failures (default: 0.2)
        adaptation_threshold: Threshold below which adaptation triggers (default: 0.4)

    Example:
        >>> precision = PrecisionParameters()
        >>> precision.value  # maximum uncertainty
        0.5
        >>> precision.update(prediction_error=0.1)  # success
        >>> round(precision.value, 2)  # slight increase
        0.52
    """

    alpha: float = 1.0
    beta: float = 1.0
    gain_learning_rate: float = 0.1
    loss_learning_rate: float = 0.2
    adaptation_threshold: float = 0.4

    @property
    def value(self) -> float:
        """
        Get current precision value γ = α/(α+β).

        Returns:
            Precision in [0, 1] when alpha and beta are positive.
        """
        return self.alpha / (self.alpha + self.beta)

    @property
    def variance(self) -> float:
        """
        Get variance of the Beta(alpha, beta) distribution.

        Returns:
            Variance αβ / ((α+β)² (α+β+1)) of the precision estimate.
        """
        a = self.alpha
        b = self.beta
        return (a * b) / ((a + b) ** 2 * (a + b + 1))

    def update(self, prediction_error: float) -> None:
        """
        Update precision based on a prediction error.

        First, the error δ is clamped to [0, 1]. It is then split into two
        signals:

        - A success signal (1 - δ), scaled by ``gain_learning_rate`` and
          added to α.
        - A failure signal δ, scaled by ``loss_learning_rate`` and added
          to β.

        With the default rates, a full failure raises β twice as much as a
        full success raises α.

        Args:
            prediction_error: Prediction error δ, expected in [0, 1].
                Values outside that range are clamped, so neither α nor β
                is ever decreased by an update.

        Example:
            >>> precision = PrecisionParameters()
            >>> precision.update(0.1)  # success
            >>> precision.update(0.9)  # failure
        """
        # Clamp: an out-of-range error would make one of the increments
        # negative, which could drive alpha/beta to zero or below and push
        # the precision value outside [0, 1] (or divide by zero).
        error = min(max(prediction_error, 0.0), 1.0)

        # Asymmetric updates: success feeds alpha, failure feeds beta.
        self.alpha += self.gain_learning_rate * (1.0 - error)
        self.beta += self.loss_learning_rate * error

    def should_adapt(self) -> bool:
        """
        Check if precision is below the adaptation threshold.

        Returns:
            True if adaptation should be triggered.
        """
        return self.value < self.adaptation_threshold

    def reset(self) -> None:
        """Reset to the uniform prior Beta(1, 1)."""
        self.alpha = 1.0
        self.beta = 1.0

    def get_all(self) -> Dict[str, float]:
        """
        Get all precision statistics.

        Returns:
            Dictionary with keys "value", "alpha", "beta" and "variance".
        """
        return {
            "value": self.value,
            "alpha": self.alpha,
            "beta": self.beta,
            "variance": self.variance,
        }

    # Backward compatibility properties
    @property
    def learning_rate_gain(self) -> float:
        """Alias for gain_learning_rate."""
        return self.gain_learning_rate

    @property
    def learning_rate_loss(self) -> float:
        """Alias for loss_learning_rate."""
        return self.loss_learning_rate

    @property
    def threshold(self) -> float:
        """Alias for adaptation_threshold."""
        return self.adaptation_threshold

    def get_all_values(self) -> Dict[str, float]:
        """Alias for get_all()."""
        return self.get_all()
@dataclass
class HierarchicalPrecision:
    """
    Track precision separately at three levels of abstraction.

    Levels, from most to least abstract:
      - abstract: long-term goals and strategies
      - planning: policy sequences
      - execution: individual tool calls

    An error strictly greater than ``propagation_threshold`` is also
    applied, scaled by ``attenuation_factor``, to the level directly above
    the one being updated:

    - execution → planning
    - planning → abstract

    Propagation moves at most one level per update.

    Example:
        >>> hp = HierarchicalPrecision()
        >>> hp.update('execution', 0.9)  # high error
        >>> hp.execution < 0.5  # decreased
        True
        >>> hp.planning < 0.5  # also decreased via propagation
        True
    """

    _abstract: PrecisionParameters = field(default_factory=PrecisionParameters)
    _planning: PrecisionParameters = field(default_factory=PrecisionParameters)
    _execution: PrecisionParameters = field(default_factory=PrecisionParameters)
    # Errors strictly greater than this value also update the parent level.
    propagation_threshold: float = 0.7
    # Multiplier applied to an error before it is passed to the parent level.
    attenuation_factor: float = 0.5

    # Read-only float views of each level's current precision.
    @property
    def abstract(self) -> float:
        """Current precision value of the abstract level."""
        return self._abstract.value

    @property
    def planning(self) -> float:
        """Current precision value of the planning level."""
        return self._planning.value

    @property
    def execution(self) -> float:
        """Current precision value of the execution level."""
        return self._execution.value

    def get_level(self, level: str) -> PrecisionParameters:
        """
        Return the PrecisionParameters object backing one level.

        Args:
            level: One of 'abstract', 'planning', or 'execution'.

        Returns:
            The PrecisionParameters instance for that level (not a copy).

        Raises:
            ValueError: If ``level`` is not one of the three level names.

        Example:
            >>> hp = HierarchicalPrecision()
            >>> hp.get_level('execution').value
            0.5
        """
        candidates = (
            ("abstract", self._abstract),
            ("planning", self._planning),
            ("execution", self._execution),
        )
        for name, params in candidates:
            if level == name:
                return params
        raise ValueError(
            f"Invalid level: {level}. Must be 'abstract', 'planning', or 'execution'"
        )

    def update(self, level: str, prediction_error: float) -> None:
        """
        Update one level and, for large errors, its parent level.

        The named level is always updated with ``prediction_error``. The
        error may also be propagated one level up:

        - Condition: the error is strictly greater than
          ``propagation_threshold`` (0.7 by default).
        - Amount: ``prediction_error * attenuation_factor``.
        - Target: the parent level (execution → planning,
          planning → abstract).

        The abstract level has no parent.

        Args:
            level: Level to update ('abstract', 'planning', 'execution').
            prediction_error: Prediction error δ ∈ [0, 1].

        Raises:
            ValueError: If ``level`` is not a valid level name.
        """
        target = self.get_level(level)
        target.update(prediction_error)

        # Guard clause: only errors strictly above the threshold travel up.
        if not prediction_error > self.propagation_threshold:
            return

        scaled_error = prediction_error * self.attenuation_factor
        if level == "execution":
            self._planning.update(scaled_error)
        elif level == "planning":
            self._abstract.update(scaled_error)

    def get_all_values(self) -> Dict[str, float]:
        """
        Return the current precision value of every level.

        Returns:
            Dictionary mapping 'abstract', 'planning' and 'execution' to
            their precision values.
        """
        levels = (
            ("abstract", self._abstract),
            ("planning", self._planning),
            ("execution", self._execution),
        )
        return {name: params.value for name, params in levels}

    def get_all(self) -> Dict[str, float]:
        """Alias for get_all_values()."""
        return self.get_all_values()

    def reset(self) -> None:
        """Reset every level back to its uniform prior."""
        for params in (self._abstract, self._planning, self._execution):
            params.reset()

    def should_adapt(self, level: str = "execution") -> bool:
        """
        Report whether the given level's precision calls for adaptation.

        Args:
            level: Level to check (default: 'execution').

        Returns:
            True if that level's precision is below its adaptation threshold.

        Raises:
            ValueError: If ``level`` is not a valid level name.
        """
        return self.get_level(level).should_adapt()
def beta_mean(alpha: float, beta: float) -> float:
    """
    Return the mean of a Beta(alpha, beta) distribution.

    Args:
        alpha: First shape parameter (α).
        beta: Second shape parameter (β).

    Returns:
        The mean α/(α+β).
    """
    total = alpha + beta
    return alpha / total
def beta_variance(alpha: float, beta: float) -> float:
    """
    Return the variance of a Beta(alpha, beta) distribution.

    Args:
        alpha: First shape parameter (α).
        beta: Second shape parameter (β).

    Returns:
        The variance αβ / ((α+β)² (α+β+1)).
    """
    total = alpha + beta
    numerator = alpha * beta
    denominator = total**2 * (total + 1)
    return numerator / denominator