# AutoGPT/classic/original_autogpt/autogpt/agents/agent.py
#
# 314 lines
# 10 KiB
# Python

from __future__ import annotations
import inspect
import logging
from typing import TYPE_CHECKING, Any, ClassVar, Optional
import sentry_sdk
from forge.agent.base import BaseAgent, BaseAgentConfiguration, BaseAgentSettings
from forge.agent.protocols import (
AfterExecute,
AfterParse,
CommandProvider,
DirectiveProvider,
MessageProvider,
)
from forge.command.command import Command
from forge.components.action_history import (
ActionHistoryComponent,
EpisodicActionHistory,
)
from forge.components.action_history.action_history import ActionHistoryConfiguration
from forge.components.code_executor.code_executor import (
CodeExecutorComponent,
CodeExecutorConfiguration,
)
from forge.components.context.context import AgentContext, ContextComponent
from forge.components.file_manager import FileManagerComponent
from forge.components.git_operations import GitOperationsComponent
from forge.components.image_gen import ImageGeneratorComponent
from forge.components.system import SystemComponent
from forge.components.user_interaction import UserInteractionComponent
from forge.components.watchdog import WatchdogComponent
from forge.components.web import WebSearchComponent, WebSeleniumComponent
from forge.file_storage.base import FileStorage
from forge.llm.prompting.schema import ChatPrompt
from forge.llm.prompting.utils import dump_prompt
from forge.llm.providers import (
AssistantFunctionCall,
ChatMessage,
ChatModelResponse,
MultiProvider,
)
from forge.llm.providers.utils import function_specs_from_commands
from forge.models.action import (
ActionErrorResult,
ActionInterruptedByHuman,
ActionResult,
ActionSuccessResult,
)
from forge.models.config import Configurable
from forge.utils.exceptions import (
AgentException,
AgentTerminated,
CommandExecutionError,
UnknownCommandError,
)
from pydantic import Field
from .prompt_strategies.one_shot import (
OneShotAgentActionProposal,
OneShotAgentPromptStrategy,
)
if TYPE_CHECKING:
from autogpt.app.config import AppConfig
# Module-level logger named after this module; it is also handed to the
# prompt strategy constructed in Agent.__init__.
logger = logging.getLogger(__name__)
class AgentConfiguration(BaseAgentConfiguration):
    # Configuration type used by `AgentSettings.config`. It adds nothing on top
    # of BaseAgentConfiguration; it only gives this agent its own named type.
    pass
class AgentSettings(BaseAgentSettings):
    # Settings bundle consumed by `Agent`. Every field below is created through
    # a default factory, so each settings instance gets its own objects.

    # Agent configuration. The `type: ignore` presumably silences a complaint
    # about re-declaring a base-class field with a narrower type -- TODO confirm.
    config: AgentConfiguration = Field(  # type: ignore
        default_factory=AgentConfiguration
    )

    history: EpisodicActionHistory[OneShotAgentActionProposal] = Field(
        default_factory=EpisodicActionHistory[OneShotAgentActionProposal]
    )
    """(STATE) The action history of the agent."""

    # Context object; Agent.__init__ passes it to the ContextComponent.
    context: AgentContext = Field(default_factory=AgentContext)
class Agent(BaseAgent[OneShotAgentActionProposal], Configurable[AgentSettings]):
    # One-shot agent: each cycle it builds a prompt from its components, asks the
    # LLM for an action proposal, and (when asked to) executes the proposed
    # command.
    #
    # NOTE: this class intentionally has no docstring. `default_settings` below
    # reads `__doc__` while the class body executes, so adding a class docstring
    # would change the default `description`.

    default_settings: ClassVar[AgentSettings] = AgentSettings(
        name="Agent",
        description=__doc__ if __doc__ else "",
    )

    def __init__(
        self,
        settings: AgentSettings,
        llm_provider: MultiProvider,
        file_storage: FileStorage,
        app_config: AppConfig,
    ):
        """Set up the prompt strategy and all components of the agent.

        Args:
            settings: Agent settings; supplies the config, action history,
                context and agent id used below.
            llm_provider: Provider used for chat completions and token counting.
            file_storage: Storage backend passed to the file manager component.
            app_config: Application config; supplies the fast LLM name, the
                non-interactive flag, the app data dir and (later) the list of
                disabled commands.
        """
        super().__init__(settings)

        self.llm_provider = llm_provider

        # Work on a private deep copy so the strategy's shared default
        # configuration is never mutated.
        prompt_config = OneShotAgentPromptStrategy.default_configuration.model_copy(
            deep=True
        )
        prompt_config.use_functions_api = (
            settings.config.use_functions_api
            # Anthropic currently doesn't support tools + prefilling :(
            and self.llm.provider_name != "anthropic"
        )
        self.prompt_strategy = OneShotAgentPromptStrategy(prompt_config, logger)
        self.commands: list[Command] = []

        # Components
        self.system = SystemComponent()
        self.history = (
            ActionHistoryComponent(
                settings.history,
                lambda x: self.llm_provider.count_tokens(x, self.llm.name),
                llm_provider,
                ActionHistoryConfiguration(
                    llm_name=app_config.fast_llm, max_tokens=self.send_token_limit
                ),
            )
            .run_after(WatchdogComponent)
            .run_after(SystemComponent)
        )
        # The user-interaction component only exists in interactive mode.
        if not app_config.noninteractive_mode:
            self.user_interaction = UserInteractionComponent()
        self.file_manager = FileManagerComponent(file_storage, settings)
        self.code_executor = CodeExecutorComponent(
            self.file_manager.workspace,
            CodeExecutorConfiguration(
                docker_container_name=f"{settings.agent_id}_sandbox"
            ),
        )
        self.git_ops = GitOperationsComponent()
        self.image_gen = ImageGeneratorComponent(self.file_manager.workspace)
        self.web_search = WebSearchComponent()
        self.web_selenium = WebSeleniumComponent(
            llm_provider,
            app_config.app_data_dir,
        )
        self.context = ContextComponent(self.file_manager.workspace, settings.context)
        self.watchdog = WatchdogComponent(settings.config, settings.history).run_after(
            ContextComponent
        )

        self.event_history = settings.history
        self.app_config = app_config

    async def propose_action(self) -> OneShotAgentActionProposal:
        """Proposes the next action to execute, based on the task and current state.

        Collects directives, commands and messages from the component
        pipelines, builds the prompt, and asks the LLM for a proposal. The
        configured cycle counter is incremented once a proposal was obtained.

        Returns:
            The command name and arguments, if any, and the agent's thoughts.
        """
        self.reset_trace()

        # Get directives
        resources = await self.run_pipeline(DirectiveProvider.get_resources)
        constraints = await self.run_pipeline(DirectiveProvider.get_constraints)
        best_practices = await self.run_pipeline(DirectiveProvider.get_best_practices)

        # Extend a deep copy so the directives stored in the state stay untouched.
        directives = self.state.directives.model_copy(deep=True)
        directives.resources += resources
        directives.constraints += constraints
        directives.best_practices += best_practices

        # Get commands
        self.commands = await self.run_pipeline(CommandProvider.get_commands)
        self._remove_disabled_commands()

        # Get messages
        messages = await self.run_pipeline(MessageProvider.get_messages)

        # The code executor attribute may be absent; fall back to False then.
        include_os_info = (
            self.code_executor.config.execute_local_commands
            if hasattr(self, "code_executor")
            else False
        )

        prompt: ChatPrompt = self.prompt_strategy.build_prompt(
            messages=messages,
            task=self.state.task,
            ai_profile=self.state.ai_profile,
            ai_directives=directives,
            commands=function_specs_from_commands(self.commands),
            include_os_info=include_os_info,
        )

        # Skip building the prompt dump unless debug logging is actually enabled.
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug("Executing prompt:\n%s", dump_prompt(prompt))
        output = await self.complete_and_parse(prompt)
        self.config.cycle_count += 1

        return output

    async def complete_and_parse(
        self, prompt: ChatPrompt, exception: Optional[Exception] = None
    ) -> OneShotAgentActionProposal:
        """Run a chat completion for `prompt` and return the parsed proposal.

        Args:
            prompt: The prompt to send. When `exception` is given, a system
                message is appended to `prompt.messages` in place.
            exception: Optional error to report to the model as an extra
                system message of the form "Error: <exception>".

        Returns:
            The parsed action proposal, after the AfterParse pipeline has run
            on it.
        """
        if exception:
            prompt.messages.append(ChatMessage.system(f"Error: {exception}"))

        response: ChatModelResponse[
            OneShotAgentActionProposal
        ] = await self.llm_provider.create_chat_completion(
            prompt.messages,
            model_name=self.llm.name,
            completion_parser=self.prompt_strategy.parse_response_content,
            functions=prompt.functions,
            prefill_response=prompt.prefill_response,
        )
        result = response.parsed_result

        await self.run_pipeline(AfterParse.after_parse, result)

        return result

    async def execute(
        self,
        proposal: OneShotAgentActionProposal,
        user_feedback: str = "",
    ) -> ActionResult:
        """Execute the tool call of `proposal` and return its result.

        Args:
            proposal: The proposal whose `use_tool` call is executed.
            user_feedback: Accepted for interface compatibility; not used by
                this method.

        Returns:
            A success result with the command's return value, or an error
            result if the command raised an AgentException or produced an
            output whose token count exceeds a third of the send token limit.

        Raises:
            AgentTerminated: Propagated unchanged from the command.
        """
        tool = proposal.use_tool

        # Get commands
        self.commands = await self.run_pipeline(CommandProvider.get_commands)
        self._remove_disabled_commands()

        try:
            return_value = await self._execute_tool(tool)

            result = ActionSuccessResult(outputs=return_value)
        except AgentTerminated:
            # Termination is not an error result; let the caller handle it.
            raise
        except AgentException as e:
            result = ActionErrorResult.from_exception(e)
            logger.warning(f"{tool} raised an error: {e}")
            sentry_sdk.capture_exception(e)

        # Replace oversized results so they cannot blow up the next prompt.
        result_tlength = self.llm_provider.count_tokens(str(result), self.llm.name)
        if result_tlength > self.send_token_limit // 3:
            result = ActionErrorResult(
                reason=f"Command {tool.name} returned too much output. "
                "Do not execute this command again with the same arguments."
            )

        await self.run_pipeline(AfterExecute.after_execute, result)

        logger.debug("\n".join(self.trace))

        return result

    async def do_not_execute(
        self, denied_proposal: OneShotAgentActionProposal, user_feedback: str
    ) -> ActionResult:
        """Record that a proposal was denied by the user.

        Args:
            denied_proposal: The proposal that was rejected (not inspected here).
            user_feedback: Feedback stored in the resulting action result.

        Returns:
            An ActionInterruptedByHuman result carrying the feedback, after the
            AfterExecute pipeline has run on it.
        """
        result = ActionInterruptedByHuman(feedback=user_feedback)

        await self.run_pipeline(AfterExecute.after_execute, result)

        logger.debug("\n".join(self.trace))

        return result

    async def _execute_tool(self, tool_call: AssistantFunctionCall) -> Any:
        """Execute the command and return the result

        Args:
            tool_call (AssistantFunctionCall): The tool call to execute

        Returns:
            Any: Whatever the command returns; awaitable results are awaited
            first.

        Raises:
            UnknownCommandError: If no available command has the given name.
            AgentException: Re-raised unchanged when raised by the command.
            CommandExecutionError: Wraps any other exception raised by the
                command, with the original exception attached as its cause.
        """
        # Execute a native command with the same name or alias, if it exists
        command = self._get_command(tool_call.name)
        try:
            result = command(**tool_call.arguments)
            if inspect.isawaitable(result):
                return await result
            return result
        except AgentException:
            raise
        except Exception as e:
            # Chain explicitly so the original traceback is kept as the cause.
            raise CommandExecutionError(str(e)) from e

    def _get_command(self, command_name: str) -> Command:
        """Return the command registered under `command_name`.

        Commands are searched from last to first, so a command later in
        `self.commands` wins over an earlier one sharing the same name.

        Raises:
            UnknownCommandError: If no command carries that name.
        """
        for command in reversed(self.commands):
            if command_name in command.names:
                return command

        raise UnknownCommandError(
            f"Cannot execute command '{command_name}': unknown command."
        )

    def _remove_disabled_commands(self) -> None:
        """Drop every command that has any name listed in the disabled commands."""
        self.commands = [
            command
            for command in self.commands
            if not any(
                name in self.app_config.disabled_commands for name in command.names
            )
        ]

    def find_obscured_commands(self) -> list[Command]:
        """Return the commands that can never be selected by name.

        A command is obscured when every one of its names is already claimed
        by commands that come later in `self.commands`. The result keeps the
        order the commands have in `self.commands`.
        """
        seen_names: set[str] = set()
        obscured_commands = []
        for command in reversed(self.commands):
            # If all of the command's names have been seen, it's obscured
            if seen_names.issuperset(command.names):
                obscured_commands.append(command)
            else:
                seen_names.update(command.names)
        return list(reversed(obscured_commands))