"""LLM-based policy generation for Active Inference."""
import json
from typing import List, Dict, Any, Optional
from unittest.mock import MagicMock
from pydantic import BaseModel, Field, field_validator
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import SystemMessage, HumanMessage
from lrs.core.registry import ToolRegistry
from lrs.core.precision import PrecisionParameters
from lrs.inference.prompts import MetaCognitivePrompter, PromptContext
[docs]
class PolicyProposal(BaseModel):
"""A single policy proposal with metadata."""
tool_sequence: List[str] = Field(description="Ordered list of tool names to execute")
reasoning: str = Field(description="Explanation of why this policy might work")
estimated_success_prob: float = Field(
ge=0.0, le=1.0, description="Estimated probability of success"
)
estimated_info_gain: float = Field(ge=0.0, le=1.0, description="Expected information gain")
strategy: str = Field(description="Strategy type: exploitation, exploration, or balanced")
failure_modes: List[str] = Field(
default_factory=list, description="Potential failure scenarios"
)
[docs]
@field_validator("strategy")
@classmethod
def validate_strategy(cls, v: str) -> str:
valid = ["exploitation", "exploration", "balanced"]
if v not in valid:
raise ValueError(f"Strategy must be one of {valid}")
return v
[docs]
class PolicyProposalSet(BaseModel):
"""Complete set of policy proposals with metadata."""
proposals: List[PolicyProposal]
current_uncertainty: float = Field(ge=0.0, le=1.0)
known_unknowns: List[str] = Field(default_factory=list)
[docs]
class LLMPolicyGenerator:
"""
Generates policy proposals using an LLM with Active Inference principles.
The generator uses meta-cognitive prompting to produce diverse policies
that balance exploration and exploitation based on precision parameters.
"""
[docs]
def __init__(
self,
llm: BaseChatModel,
registry: ToolRegistry,
prompter: Optional[MetaCognitivePrompter] = None,
):
"""
Initialize the policy generator.
Args:
llm: Language model for generating proposals
registry: Tool registry for available actions
prompter: Optional custom prompter (creates default if None)
"""
self.llm = llm
self.registry = registry
self.prompter = prompter or MetaCognitivePrompter()
[docs]
def generate_proposals(
self,
state: Optional[Dict[str, Any]] = None,
precision: Optional[PrecisionParameters] = None,
num_proposals: int = 3,
) -> List[Dict[str, Any]]:
"""
Generate policy proposals based on current context and precision.
Args:
state: Current state, goal, and history (deprecated, use context instead)
context: Current state, goal, and history
precision: Precision parameters guiding exploration/exploitation
num_proposals: Number of proposals to generate
Returns:
List of policy dictionaries with tools and metadata
"""
# Support both 'state' and 'context' for backward compatibility
if state is None:
context = {}
elif "state" in state or "goal" in state:
# Already in context format
context = state
else:
# Legacy format - wrap in context
context = {"state": state}
if precision is None:
precision = PrecisionParameters()
elif isinstance(precision, (int, float)):
# Convert float to PrecisionParameters object with appropriate alpha/beta
prec_value = float(precision)
# Map precision to alpha/beta values
if prec_value >= 0.5:
alpha = 2.0 * prec_value
beta = 2.0 * (1.0 - prec_value)
else:
alpha = 2.0 * prec_value
beta = 2.0 * (1.0 - prec_value)
precision = PrecisionParameters(alpha=alpha, beta=beta)
elif isinstance(precision, PrecisionParameters):
pass # Already the right type
else:
precision = PrecisionParameters()
# Generate prompt based on precision
precision_value = precision.value
prompt_context = PromptContext(
precision=precision_value,
available_tools=[
tool.name if hasattr(tool, "name") else str(tool) for tool in self.registry.tools
],
goal=str(context.get("goal", "Complete task")),
state=context.get("state", {}),
recent_errors=[
float(x) if isinstance(x, (int, float)) else 0.5
for x in context.get("recent_errors", [])
],
tool_history=[
x if isinstance(x, dict) else {"tool": str(x)}
for x in context.get("tool_history", [])
],
)
prompt = self.prompter.generate_prompt(prompt_context)
# Call LLM
messages = [
SystemMessage(content=prompt),
HumanMessage(content=f"Generate {num_proposals} policy proposals."),
]
response = self.llm.invoke(messages)
# Parse and validate response
try:
# Extract JSON from response
content = response.content
if isinstance(content, str):
# Handle markdown code blocks
if "```json" in content:
content = content.split("```json")[1].split("```")[0].strip()
elif "```" in content:
content = content.split("```")[1].split("```")[0].strip()
elif isinstance(content, list):
content = str(content)
proposal_set = PolicyProposalSet.model_validate_json(content)
except Exception as e:
# Fallback to simple proposals if parsing fails
print(f"Warning: Failed to parse LLM response: {e}")
return self._create_fallback_proposals(num_proposals)
# Convert to policy dictionaries
policies = []
for proposal in proposal_set.proposals:
# Get actual tool objects
tools = []
for tool_name in proposal.tool_sequence:
tool = self.registry.get_tool(tool_name)
if tool:
tools.append(tool)
else:
# Create a mock tool if not found
mock_tool = MagicMock()
mock_tool.name = tool_name
tools.append(mock_tool)
if tools: # Only include if we found valid tools
policies.append(
{
"tools": tools,
"reasoning": proposal.reasoning,
"estimated_success": proposal.estimated_success_prob,
"estimated_info_gain": proposal.estimated_info_gain,
"strategy": proposal.strategy,
"failure_modes": proposal.failure_modes,
}
)
return policies[:num_proposals]
def _create_fallback_proposals(self, num_proposals: int) -> List[Dict[str, Any]]:
"""Create simple fallback proposals when LLM parsing fails."""
proposals = []
tools = list(self.registry.tools)[:num_proposals]
for i, tool in enumerate(tools):
tool_name = tool.name if hasattr(tool, "name") else str(tool)
proposals.append(
{
"tools": [tool],
"reasoning": f"Fallback proposal using {tool_name}",
"estimated_success": 0.5,
"estimated_info_gain": 0.5,
"strategy": "balanced",
"failure_modes": ["Unknown - fallback proposal"],
}
)
return proposals
[docs]
def create_mock_generator(num_proposals: int = 3) -> LLMPolicyGenerator:
"""
Create a mock policy generator for testing.
Args:
num_proposals: Number of proposals the mock should generate
Returns:
Generator that produces simple test proposals.
"""
# 1. Create a valid JSON response that the mock LLM will return.
# This response must conform to the PolicyProposalSet schema.
proposals_data = []
tool_names = []
for i in range(num_proposals):
tool_name = f"mock_tool_{i}"
tool_names.append(tool_name)
proposals_data.append(
{
"tool_sequence": [tool_name],
"reasoning": f"Reasoning for using {tool_name}",
"estimated_success_prob": 0.85,
"estimated_info_gain": 0.6,
"strategy": "balanced",
"failure_modes": ["It might fail if the input is wrong."],
}
)
response_data = {
"proposals": proposals_data,
"current_uncertainty": 0.3,
"known_unknowns": ["The exact format of the API response."],
}
# The response content must be a JSON string
json_response = json.dumps(response_data)
# The response content must be a JSON string
json_response = json.dumps(response_data)
# 2. Configure the mock LLM to return the JSON response.
mock_llm = MagicMock()
mock_response = MagicMock()
mock_response.content = json_response
mock_llm.invoke.return_value = mock_response
# 3. Configure the mock ToolRegistry.
mock_registry = MagicMock()
# The generate_proposals method needs `registry.get_tool` to be callable
# and to return a tool object for the names in our mock response.
# It also needs `registry.tools` to be iterable for prompt generation.
# Create mock tools. Using MagicMock is fine for this purpose.
mock_tools = {}
for name in tool_names:
tool = MagicMock()
tool.name = name
mock_tools[name] = tool
# 3. Configure the mock ToolRegistry.
mock_registry = MagicMock()
# The generate_proposals method needs `registry.get_tool` to be callable
# and to return a tool object for the names in our mock response.
# It also needs `registry.tools` to be iterable for prompt generation.
# Create mock tools. Using MagicMock is fine for this purpose.
mock_tools = {}
for name in tool_names:
tool = MagicMock()
tool.name = name
mock_tools[name] = tool
mock_registry.get_tool.side_effect = lambda name: mock_tools.get(name)
mock_registry.tools = list(mock_tools.values())
return LLMPolicyGenerator(llm=mock_llm, registry=mock_registry)