Source code for lrs.inference.llm_policy_generator

"""LLM-based policy generation for Active Inference."""

import json
from dataclasses import dataclass, field
from typing import List, Dict, Any, Optional
from enum import Enum
from unittest.mock import Mock, MagicMock

from pydantic import BaseModel, Field, field_validator
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import SystemMessage, HumanMessage

from lrs.core.lens import ToolLens
from lrs.core.registry import ToolRegistry
from lrs.core.precision import PrecisionParameters
from lrs.inference.prompts import MetaCognitivePrompter, StrategyMode, PromptContext


[docs] class PolicyProposal(BaseModel): """A single policy proposal with metadata.""" tool_sequence: List[str] = Field(description="Ordered list of tool names to execute") reasoning: str = Field(description="Explanation of why this policy might work") estimated_success_prob: float = Field(ge=0.0, le=1.0, description="Estimated probability of success") estimated_info_gain: float = Field(ge=0.0, le=1.0, description="Expected information gain") strategy: str = Field(description="Strategy type: exploitation, exploration, or balanced") failure_modes: List[str] = Field(default_factory=list, description="Potential failure scenarios")
[docs] @field_validator('strategy') @classmethod def validate_strategy(cls, v: str) -> str: valid = ['exploitation', 'exploration', 'balanced'] if v not in valid: raise ValueError(f"Strategy must be one of {valid}") return v
[docs] class PolicyProposalSet(BaseModel): """Complete set of policy proposals with metadata.""" proposals: List[PolicyProposal] current_uncertainty: float = Field(ge=0.0, le=1.0) known_unknowns: List[str] = Field(default_factory=list)
[docs] class LLMPolicyGenerator: """ Generates policy proposals using an LLM with Active Inference principles. The generator uses meta-cognitive prompting to produce diverse policies that balance exploration and exploitation based on precision parameters. """
[docs] def __init__( self, llm: BaseChatModel, registry: ToolRegistry, prompter: Optional[MetaCognitivePrompter] = None ): """ Initialize the policy generator. Args: llm: Language model for generating proposals registry: Tool registry for available actions prompter: Optional custom prompter (creates default if None) """ self.llm = llm self.registry = registry self.prompter = prompter or MetaCognitivePrompter()
[docs] def generate_proposals( self, context: Dict[str, Any], precision: PrecisionParameters, num_proposals: int = 3 ) -> List[Dict[str, Any]]: """ Generate policy proposals based on current context and precision. Args: context: Current state, goal, and history precision: Precision parameters guiding exploration/exploitation num_proposals: Number of proposals to generate Returns: List of policy dictionaries with tools and metadata """ # Generate prompt based on precision prompt_context = PromptContext( precision=precision.value, available_tools=[tool.name for tool in self.registry.tools], goal=context.get('goal', 'Complete the task'), state=context.get('state', {}), recent_errors=context.get('recent_errors', []), tool_history=context.get('tool_history', []) ) prompt = self.prompter.generate_prompt(prompt_context) # Call LLM messages = [ SystemMessage(content=prompt), HumanMessage(content=f"Generate {num_proposals} policy proposals.") ] response = self.llm.invoke(messages) # Parse and validate response try: # Extract JSON from response content = response.content if isinstance(content, str): # Handle markdown code blocks if '```json' in content: content = content.split('```json')[1].split('```')[0].strip() elif '```' in content: content = content.split('```')[1].split('```')[0].strip() proposal_set = PolicyProposalSet.model_validate_json(content) except Exception as e: # Fallback to simple proposals if parsing fails print(f"Warning: Failed to parse LLM response: {e}") return self._create_fallback_proposals(num_proposals) # Convert to policy dictionaries policies = [] for proposal in proposal_set.proposals: # Get actual tool objects tools = [] for tool_name in proposal.tool_sequence: tool = self.registry.get_tool(tool_name) if tool: tools.append(tool) if tools: # Only include if we found valid tools policies.append({ 'tools': tools, 'reasoning': proposal.reasoning, 'estimated_success': proposal.estimated_success_prob, 'estimated_info_gain': proposal.estimated_info_gain, 'strategy': proposal.strategy, 'failure_modes': proposal.failure_modes }) return policies[:num_proposals]
def _create_fallback_proposals(self, num_proposals: int) -> List[Dict[str, Any]]: """Create simple fallback proposals when LLM parsing fails.""" proposals = [] tools = list(self.registry.tools)[:num_proposals] for i, tool in enumerate(tools): proposals.append({ 'tools': [tool], 'reasoning': f'Fallback proposal using {tool.name}', 'estimated_success': 0.5, 'estimated_info_gain': 0.5, 'strategy': 'balanced', 'failure_modes': ['Unknown - fallback proposal'] }) return proposals
[docs] def create_mock_generator(num_proposals: int = 3) -> LLMPolicyGenerator: """ Create a mock policy generator for testing. Args: num_proposals: Number of proposals the mock should generate Returns: Generator that produces simple test proposals. """ # 1. Create a valid JSON response that the mock LLM will return. # This response must conform to the PolicyProposalSet schema. proposals_data = [] tool_names = [] for i in range(num_proposals): tool_name = f"mock_tool_{i}" tool_names.append(tool_name) proposals_data.append({ "tool_sequence": [tool_name], "reasoning": f"Reasoning for using {tool_name}", "estimated_success_prob": 0.85, "estimated_info_gain": 0.6, "strategy": "balanced", "failure_modes": ["It might fail if the input is wrong."] }) response_data = { "proposals": proposals_data, "current_uncertainty": 0.3, "known_unknowns": ["The exact format of the API response."] } # The response content must be a JSON string json_response = json.dumps(response_data) # The response content must be a JSON string json_response = json.dumps(response_data) # 2. Configure the mock LLM to return the JSON response. mock_llm = MagicMock() mock_response = MagicMock() mock_response.content = json_response mock_llm.invoke.return_value = mock_response # 3. Configure the mock ToolRegistry. mock_registry = MagicMock() # The generate_proposals method needs `registry.get_tool` to be callable # and to return a tool object for the names in our mock response. # It also needs `registry.tools` to be iterable for prompt generation. # Create mock tools. Using MagicMock is fine for this purpose. mock_tools = {} for name in tool_names: tool = MagicMock() tool.name = name mock_tools[name] = tool # 3. Configure the mock ToolRegistry. mock_registry = MagicMock() # The generate_proposals method needs `registry.get_tool` to be callable # and to return a tool object for the names in our mock response. # It also needs `registry.tools` to be iterable for prompt generation. # Create mock tools. Using MagicMock is fine for this purpose. mock_tools = {} for name in tool_names: tool = MagicMock() tool.name = name mock_tools[name] = tool mock_registry.get_tool.side_effect = lambda name: mock_tools.get(name) mock_registry.tools = list(mock_tools.values()) return LLMPolicyGenerator(llm=mock_llm, registry=mock_registry)