259 lines
8.7 KiB
Python
259 lines
8.7 KiB
Python
import pytest
|
|
from prisma.models import User
|
|
|
|
from autogpt_server.blocks.basic import FindInDictionaryBlock, StoreValueBlock
|
|
from autogpt_server.blocks.maths import CalculatorBlock, Operation
|
|
from autogpt_server.data import execution, graph
|
|
from autogpt_server.executor import ExecutionManager
|
|
from autogpt_server.server import AgentServer
|
|
from autogpt_server.usecases.sample import create_test_graph, create_test_user
|
|
from autogpt_server.util.test import SpinTestServer, wait_execution
|
|
|
|
|
|
async def execute_graph(
    agent_server: AgentServer,
    test_manager: ExecutionManager,
    test_graph: graph.Graph,
    test_user: User,
    input_data: dict,
    num_execs: int = 4,
) -> str:
    """Submit a run of ``test_graph`` and wait on it via ``wait_execution``.

    Args:
        agent_server: Server used to start the graph run.
        test_manager: Execution manager handed through to ``wait_execution``.
        test_graph: Graph to run; only its ``id`` is read here.
        test_user: Owner of the run; only its ``id`` is read here.
        input_data: Input payload forwarded to ``AgentServer.execute_graph``.
        num_execs: Node-execution count passed to ``wait_execution``.

    Returns:
        The ``"id"`` field of the server's response, i.e. the graph execution id.

    Raises:
        AssertionError: If ``wait_execution`` returns a falsy value.
    """
    graph_id = test_graph.id
    user_id = test_user.id

    # Start a new run; the response is indexed by "id" to get the run's id.
    run = await agent_server.execute_graph(graph_id, input_data, user_id)
    run_id = run["id"]

    # NOTE(review): wait_execution presumably polls until `num_execs` node
    # executions of this run are done (truthy) or gives up (falsy) — confirm in
    # autogpt_server.util.test.
    finished = await wait_execution(
        test_manager, user_id, graph_id, run_id, num_execs
    )
    assert finished
    return run_id
|
|
|
|
|
|
async def assert_sample_graph_executions(
    agent_server: AgentServer,
    test_graph: graph.Graph,
    test_user: User,
    graph_exec_id: str,
):
    """Assert the node-execution results of one run of the sample graph.

    The run must consist of exactly four node executions, in this order:

    1-2. ``test_graph.nodes[0]`` and ``nodes[1]`` (in either order), the two
         input nodes producing "Hello" and "World".
    3.   ``nodes[2]`` (FillTextTemplateBlock), rendering "Hello, World!!!".
    4.   ``nodes[3]`` (PrintToConsoleBlock), printing that text.

    NOTE(review): assumes ``test_graph`` was built by ``create_test_graph`` and
    run with ``{"input_1": "Hello", "input_2": "World"}`` — confirm when reusing.

    Args:
        agent_server: Server queried for the run's node execution results.
        test_graph: The executed graph; its ``id`` and node ids are checked.
        test_user: Owner of the run; its ``id`` scopes the query.
        graph_exec_id: Id of the graph execution to inspect.

    Raises:
        AssertionError: If any execution deviates from the expectation above.
    """
    executions = await agent_server.get_graph_run_node_execution_results(
        test_graph.id, graph_exec_id, test_user.id
    )
    # Fail with a readable message instead of an IndexError further down.
    assert (
        len(executions) == 4
    ), f"expected 4 node executions, got {len(executions)}"

    completed = execution.ExecutionStatus.COMPLETED
    output_list = [{"result": ["Hello"]}, {"result": ["World"]}]
    input_list = [
        {
            "name": "input_1",
            "description": "First input value",
            "placeholder_values": [],
            "limit_to_placeholder_values": False,
            "value": "Hello",
        },
        {
            "name": "input_2",
            "description": "Second input value",
            "placeholder_values": [],
            "limit_to_placeholder_values": False,
            "value": "World",
        },
    ]
    input_node_ids = [test_graph.nodes[0].id, test_graph.nodes[1].id]

    # Input nodes: either may be recorded first, so check each one's
    # membership, then require that the two executions belong to different
    # nodes (otherwise one input node never ran).
    # NOTE(review): these were previously labelled StoreValueBlock, but their
    # recorded input (name/description/value) and "result" output pin look
    # like a graph input block — confirm against create_test_graph.
    for node_exec in executions[:2]:
        assert node_exec.status == completed
        assert node_exec.graph_exec_id == graph_exec_id
        assert node_exec.output_data in output_list
        assert node_exec.input_data in input_list
        assert node_exec.node_id in input_node_ids
    assert executions[0].node_id != executions[1].node_id

    # FillTextTemplateBlock: renders "{a}, {b}{c}" from the collected values.
    template_exec = executions[2]
    assert template_exec.status == completed
    assert template_exec.graph_exec_id == graph_exec_id
    assert template_exec.output_data == {"output": ["Hello, World!!!"]}
    assert template_exec.input_data == {
        "format": "{a}, {b}{c}",
        "values": {"a": "Hello", "b": "World", "c": "!!!"},
        "values_#_a": "Hello",
        "values_#_b": "World",
        "values_#_c": "!!!",
    }
    assert template_exec.node_id == test_graph.nodes[2].id

    # PrintToConsoleBlock: receives the rendered text.
    print_exec = executions[3]
    assert print_exec.status == completed
    assert print_exec.graph_exec_id == graph_exec_id
    assert print_exec.output_data == {"status": ["printed"]}
    assert print_exec.input_data == {"text": "Hello, World!!!"}
    assert print_exec.node_id == test_graph.nodes[3].id
|
|
|
|
|
|
@pytest.mark.asyncio(scope="session")
async def test_agent_execution(server: SpinTestServer):
    """Run the sample graph end to end and validate every node execution.

    NOTE(review): the value returned by ``graph.create_graph`` is discarded, so
    the locally built graph object is what gets executed and inspected.
    """
    sample = create_test_graph()
    owner = await create_test_user()
    await graph.create_graph(sample, user_id=owner.id)

    inputs = dict(input_1="Hello", input_2="World")
    run_id = await execute_graph(
        server.agent_server,
        server.exec_manager,
        sample,
        owner,
        inputs,
        num_execs=4,
    )
    await assert_sample_graph_executions(server.agent_server, sample, owner, run_id)
|
|
|
|
|
|
@pytest.mark.asyncio(scope="session")
async def test_input_pin_always_waited(server: SpinTestServer):
    """
    A linked input pin must always be waited for before its node executes,
    even when that pin also has a default value: the default must be ignored.

    Test scenario:
        StoreValueBlock1
                \\ input
                 >------- FindInDictionaryBlock | input_default: key: "", input: {}
                // key
        StoreValueBlock2
    """
    dict_source = graph.Node(
        block_id=StoreValueBlock().id,
        input_default={"input": {"key1": "value1", "key2": "value2"}},
    )
    key_source = graph.Node(
        block_id=StoreValueBlock().id,
        input_default={"input": "key2"},
    )
    finder = graph.Node(
        block_id=FindInDictionaryBlock().id,
        input_default={"key": "", "input": {}},
    )

    def wire(source: graph.Node, pin: str) -> graph.Link:
        # Every edge in this graph runs from a store node's "output" pin
        # into one of the finder's inputs.
        return graph.Link(
            source_id=source.id,
            sink_id=finder.id,
            source_name="output",
            sink_name=pin,
        )

    draft = graph.Graph(
        name="TestGraph",
        description="Test graph",
        nodes=[dict_source, key_source, finder],
        links=[wire(dict_source, "input"), wire(key_source, "key")],
    )
    owner = await create_test_user()
    saved = await graph.create_graph(draft, user_id=owner.id)

    run_id = await execute_graph(
        server.agent_server, server.exec_manager, saved, owner, {}, num_execs=3
    )
    results = await server.agent_server.get_graph_run_node_execution_results(
        saved.id, run_id, owner.id
    )
    assert len(results) == 3

    # The third execution is expected to be the finder. Having waited for both
    # linked pins, it looks up "key2" in {"key1": "value1", "key2": "value2"}
    # instead of using its empty defaults.
    finder_run = results[2]
    assert finder_run.status == execution.ExecutionStatus.COMPLETED
    assert finder_run.output_data == {"output": ["value2"]}
|
|
|
|
|
|
@pytest.mark.asyncio(scope="session")
async def test_static_input_link_on_graph(server: SpinTestServer):
    """
    Exercise a static (reusable) input link.

    Test scenario:
        *StoreValueBlock1*===a=========\\
        *StoreValueBlock2*===a=====\\  ||
        *StoreValueBlock3*===a===*MathBlock*====b / static====*StoreValueBlock5*
        *StoreValueBlock4*=========================================//

    Three values arrive on the MathBlock's ``a`` pin and wait there. Later a
    value is produced on pin ``b`` through a static link, and that single
    value completes all three pending executions.
    """
    a_sources = [
        graph.Node(block_id=StoreValueBlock().id, input_default={"input": 4})
        for _ in range(3)
    ]
    b_source = graph.Node(block_id=StoreValueBlock().id, input_default={"input": 5})
    relay = graph.Node(block_id=StoreValueBlock().id)
    adder = graph.Node(
        block_id=CalculatorBlock().id,
        input_default={"operation": Operation.ADD.value},
    )

    # Each of the three `a` producers feeds the adder directly.
    links = [
        graph.Link(
            source_id=src.id,
            sink_id=adder.id,
            source_name="output",
            sink_name="a",
        )
        for src in a_sources
    ]
    # The `b` value travels through a relay store node...
    links.append(
        graph.Link(
            source_id=b_source.id,
            sink_id=relay.id,
            source_name="output",
            sink_name="input",
        )
    )
    # ...and reaches the adder over the static link under test.
    links.append(
        graph.Link(
            source_id=relay.id,
            sink_id=adder.id,
            source_name="output",
            sink_name="b",
            is_static=True,
        )
    )

    draft = graph.Graph(
        name="TestGraph",
        description="Test graph",
        nodes=[*a_sources, b_source, relay, adder],
        links=links,
    )
    owner = await create_test_user()
    saved = await graph.create_graph(draft, user_id=owner.id)

    # 5 store-node executions plus 3 adder executions.
    run_id = await execute_graph(
        server.agent_server, server.exec_manager, saved, owner, {}, num_execs=8
    )
    results = await server.agent_server.get_graph_run_node_execution_results(
        saved.id, run_id, owner.id
    )
    assert len(results) == 8

    # The last 3 executions will be a+b=4+5=9
    for adder_run in results[-3:]:
        assert adder_run.status == execution.ExecutionStatus.COMPLETED
        assert adder_run.output_data == {"result": [9]}
|