# Source code for lrs.integration.langchain_adapter
"""LangChain integration for LRS-Agents."""
from typing import Dict, Any, Optional, Callable
import time
import platform
import threading
from langchain.tools import BaseTool
from lrs.core.lens import ToolLens, ExecutionResult
def _extract_input_schema(tool: BaseTool) -> Dict[str, Any]:
"""Extract input schema from LangChain tool."""
if hasattr(tool, 'args_schema') and tool.args_schema:
try:
# Try Pydantic V2 method first
return tool.args_schema.model_json_schema()
except AttributeError:
# Fall back to Pydantic V1
return tool.args_schema.schema()
return {}
def _extract_output_schema(tool: BaseTool) -> Dict[str, Any]:
"""Extract output schema from LangChain tool."""
# Most LangChain tools return strings or dicts
return {
"type": "object",
"properties": {
"result": {"type": "string"}
}
}
[docs]
class LangChainToolLens(ToolLens):
    """
    Wraps a LangChain tool as a ToolLens.

    Provides timeout handling, prediction error calculation,
    and statistics tracking for any LangChain tool.

    Args:
        tool: LangChain BaseTool to wrap
        timeout: Maximum execution time in seconds (default: 30.0)
        error_fn: Custom function to calculate prediction error

    Example:
        >>> from langchain.tools import Tool
        >>> lc_tool = Tool(name="search", func=lambda q: f"Results for {q}")
        >>> lrs_tool = LangChainToolLens(lc_tool, timeout=10.0)
        >>> result = lrs_tool.get({"query": "test"})
    """

    def __init__(
        self,
        tool: BaseTool,
        timeout: float = 30.0,
        error_fn: Optional[Callable] = None
    ):
        """Initialize LangChain tool wrapper.

        Args:
            tool: Tool to wrap; its ``name`` becomes the lens name.
            timeout: Maximum execution time in seconds for one ``get`` call.
            error_fn: Callable ``(result, output_schema) -> float`` used to
                score a successful result. Defaults to
                ``_default_error_fn``.
        """
        input_schema = _extract_input_schema(tool)
        output_schema = _extract_output_schema(tool)
        super().__init__(
            name=tool.name,
            input_schema=input_schema,
            output_schema=output_schema
        )
        self.tool = tool
        self.timeout = timeout
        self.error_fn = error_fn or self._default_error_fn

    def _default_error_fn(self, result: Any, output_schema: Dict) -> float:
        """Default prediction error calculation.

        Returns 0.9 for ``None``, 0.7 for an empty string and 0.1 for any
        other result. ``output_schema`` is accepted for signature
        compatibility with custom error functions but is not consulted.
        """
        if result is None:
            return 0.9  # High surprise for null
        elif isinstance(result, str) and len(result) == 0:
            return 0.7  # Medium surprise for empty
        else:
            return 0.1  # Low surprise for success

    def _failure(self, message: str, prediction_error: float) -> ExecutionResult:
        """Count one failure and build the matching unsuccessful result."""
        self.failure_count += 1
        return ExecutionResult(
            success=False,
            value=None,
            error=message,
            prediction_error=prediction_error
        )

    def _can_use_alarm(self) -> bool:
        """Return True when a SIGALRM-based timeout is usable for this call."""
        if platform.system() == 'Windows':
            # Windows doesn't support SIGALRM
            return False
        import signal
        if not (hasattr(signal, 'SIGALRM') and hasattr(signal, 'setitimer')):
            return False
        # Signal handlers can only be installed from the main thread; calling
        # signal.signal() anywhere else raises ValueError, which would make
        # every call from a worker thread look like a tool failure.
        return threading.current_thread() is threading.main_thread()

    def _run_with_alarm(self, state: Dict[str, Any]) -> Any:
        """Run the tool on the calling thread, interrupted by SIGALRM on timeout.

        Raises:
            TimeoutError: If the timer fires before the tool returns.
        """
        import signal

        def timeout_handler(signum, frame):
            raise TimeoutError(f"Timeout after {self.timeout}s")

        seconds = float(self.timeout)
        old_handler = signal.signal(signal.SIGALRM, timeout_handler)
        try:
            # setitimer honours fractional seconds. signal.alarm(int(timeout))
            # truncated e.g. 0.5 to 0, and alarm(0) means "no alarm at all".
            # A non-positive timeout still results in no timer being armed.
            signal.setitimer(signal.ITIMER_REAL, max(seconds, 0.0))
            return self.tool.run(**state)
        finally:
            # Always disarm the timer before restoring the previous handler so
            # a late signal cannot hit a handler that was not expecting it.
            signal.setitimer(signal.ITIMER_REAL, 0)
            signal.signal(signal.SIGALRM, old_handler)

    def _run_in_thread(self, state: Dict[str, Any]) -> Any:
        """Run the tool in a daemon thread and wait at most ``self.timeout``.

        On timeout the worker thread is not stopped; it is merely abandoned
        (it is a daemon thread, so it will not block interpreter exit).

        Raises:
            TimeoutError: If the worker is still running after the wait.
            Exception: Whatever exception the tool itself raised.
        """
        outcome: Dict[str, Any] = {}

        def target():
            try:
                outcome['result'] = self.tool.run(**state)
            except Exception as e:
                outcome['error'] = e

        thread = threading.Thread(target=target)
        thread.daemon = True
        thread.start()
        thread.join(timeout=self.timeout)
        if thread.is_alive():
            raise TimeoutError(f"Timeout after {self.timeout}s")
        # Test for presence rather than truthiness: an exception class may
        # define __bool__/__len__ and evaluate as falsy.
        if 'error' in outcome:
            raise outcome['error']
        return outcome.get('result')

    def get(self, state: Dict[str, Any]) -> ExecutionResult:
        """Execute LangChain tool with timeout.

        Increments ``call_count`` on every call and ``failure_count`` on every
        unsuccessful one (both counters are expected to be provided by the
        base class - TODO confirm).

        Args:
            state: Mapping whose items are passed to the tool as keyword
                arguments.

        Returns:
            ExecutionResult: ``success=True`` with the tool's return value and
            the score from ``error_fn``; or ``success=False`` with
            ``prediction_error`` 0.7 for a timeout and 0.9 for any other
            exception. Exceptions are never propagated to the caller.

        NOTE(review): the tool is invoked as ``self.tool.run(**state)``.
        LangChain's ``BaseTool.run`` takes the input as its first positional
        ``tool_input`` parameter, so confirm that expanding ``state`` into
        keyword arguments is really what callers need.
        """
        self.call_count += 1
        try:
            # Platform-specific timeout handling
            if self._can_use_alarm():
                tool_result = self._run_with_alarm(state)
            else:
                tool_result = self._run_in_thread(state)
            # Success
            prediction_error = self.error_fn(tool_result, self.output_schema)
            return ExecutionResult(
                success=True,
                value=tool_result,
                error=None,
                prediction_error=prediction_error
            )
        except TimeoutError as e:
            return self._failure(str(e), 0.7)
        except Exception as e:
            return self._failure(str(e), 0.9)

    def set(self, state: Dict[str, Any], obs: Any) -> Dict[str, Any]:
        """Update state with tool result.

        Returns a new dict containing every item of ``state`` plus ``obs``
        stored under ``"<tool name>_result"``; ``state`` itself is not
        modified.
        """
        return {
            **state,
            f'{self.name}_result': obs
        }
[docs]
def wrap_langchain_tool(
    tool: BaseTool,
    timeout: float = 30.0,
    error_fn: Optional[Callable] = None
) -> LangChainToolLens:
    """
    Build a LangChainToolLens around a LangChain tool.

    Thin factory that forwards its arguments unchanged to the
    LangChainToolLens constructor.

    Args:
        tool: LangChain BaseTool to wrap
        timeout: Maximum execution time in seconds
        error_fn: Optional custom error calculation function

    Returns:
        LangChainToolLens: Wrapped tool ready for LRS use

    Example:
        >>> from langchain_community.tools import DuckDuckGoSearchRun
        >>> search = wrap_langchain_tool(DuckDuckGoSearchRun(), timeout=10.0)
    """
    lens_options = {"timeout": timeout, "error_fn": error_fn}
    wrapped = LangChainToolLens(tool, **lens_options)
    return wrapped