AutoGPT/classic/frontend/lib/viewmodels/task_queue_viewmodel.dart

277 lines
10 KiB
Dart

import 'package:auto_gpt_flutter_client/models/benchmark/benchmark_run.dart';
import 'package:auto_gpt_flutter_client/models/benchmark/benchmark_step_request_body.dart';
import 'package:auto_gpt_flutter_client/models/benchmark/benchmark_task_request_body.dart';
import 'package:auto_gpt_flutter_client/models/benchmark/benchmark_task_status.dart';
import 'package:auto_gpt_flutter_client/models/skill_tree/skill_tree_edge.dart';
import 'package:auto_gpt_flutter_client/models/skill_tree/skill_tree_node.dart';
import 'package:auto_gpt_flutter_client/models/step.dart';
import 'package:auto_gpt_flutter_client/models/task.dart';
import 'package:auto_gpt_flutter_client/models/test_option.dart';
import 'package:auto_gpt_flutter_client/models/test_suite.dart';
import 'package:auto_gpt_flutter_client/services/benchmark_service.dart';
import 'package:auto_gpt_flutter_client/services/leaderboard_service.dart';
import 'package:auto_gpt_flutter_client/services/shared_preferences_service.dart';
import 'package:auto_gpt_flutter_client/viewmodels/chat_viewmodel.dart';
import 'package:auto_gpt_flutter_client/viewmodels/task_viewmodel.dart';
import 'package:collection/collection.dart';
import 'package:flutter/foundation.dart';
import 'package:uuid/uuid.dart';
import 'package:auto_gpt_flutter_client/utils/stack.dart';
/// View model that tracks which skill-tree nodes are selected for
/// benchmarking, runs those benchmarks one after another, and submits the
/// collected results to the leaderboard.
class TaskQueueViewModel extends ChangeNotifier {
  final BenchmarkService benchmarkService;
  final LeaderboardService leaderboardService;
  final SharedPreferencesService prefsService;

  /// Whether [runBenchmark] is currently executing.
  bool isBenchmarkRunning = false;

  /// Status of each node in the current (or most recent) benchmark run.
  Map<SkillTreeNode, BenchmarkTaskStatus> benchmarkStatusMap = {};

  /// Evaluation results gathered by the most recent [runBenchmark] call;
  /// cleared by [submitToLeaderboard] once they have been submitted.
  List<BenchmarkRun> currentBenchmarkRuns = [];

  List<SkillTreeNode>? _selectedNodeHierarchy;
  TestOption _selectedOption = TestOption.runSingleTest;

  /// The option most recently applied by
  /// [updateSelectedNodeHierarchyBasedOnOption].
  TestOption get selectedOption => _selectedOption;

  /// The ordered nodes [runBenchmark] will execute, or `null` if no
  /// selection has been computed yet.
  List<SkillTreeNode>? get selectedNodeHierarchy => _selectedNodeHierarchy;

  TaskQueueViewModel(
      this.benchmarkService, this.leaderboardService, this.prefsService);

  /// Stores [selectedOption] and recomputes the node hierarchy to benchmark.
  ///
  /// - [TestOption.runSingleTest]: only [selectedNode].
  /// - [TestOption.runTestSuiteIncludingSelectedNodeAndAncestors]:
  ///   [selectedNode] plus all of its ancestors, ancestors first.
  /// - [TestOption.runAllTestsInCategory]: every node in [nodes] reachable
  ///   from the node labelled "WriteFile", parents first.
  ///
  /// When [selectedNode] is `null` the hierarchy is cleared for every option,
  /// so a selection computed earlier is never silently reused.
  void updateSelectedNodeHierarchyBasedOnOption(
      TestOption selectedOption,
      SkillTreeNode? selectedNode,
      List<SkillTreeNode> nodes,
      List<SkillTreeEdge> edges) {
    _selectedOption = selectedOption;
    if (selectedNode == null) {
      _selectedNodeHierarchy = [];
    } else {
      switch (selectedOption) {
        case TestOption.runSingleTest:
          _selectedNodeHierarchy = [selectedNode];
          break;
        case TestOption.runTestSuiteIncludingSelectedNodeAndAncestors:
          populateSelectedNodeHierarchy(selectedNode.id, nodes, edges);
          break;
        case TestOption.runAllTestsInCategory:
          _getAllNodesInDepthFirstOrderEnsuringParents(nodes, edges);
          break;
      }
    }
    notifyListeners();
  }

  /// Selects every node reachable from the "WriteFile" root, in depth-first
  /// order, such that each node appears only after all of its parents.
  ///
  /// A node whose parents have not all been emitted yet is dropped from the
  /// stack; it is pushed again when its next parent is emitted, so nodes
  /// reachable by several paths are neither skipped nor placed before a
  /// prerequisite. Nodes with a parent that is never emitted (unreachable
  /// from the root, or part of a cycle) are left out. If no root exists the
  /// selection becomes empty.
  void _getAllNodesInDepthFirstOrderEnsuringParents(
      List<SkillTreeNode> skillTreeNodes, List<SkillTreeEdge> skillTreeEdges) {
    final ordered = <SkillTreeNode>[];
    final emitted = <String>{};
    final stack = Stack<SkillTreeNode>();

    // Identify the root node by its label.
    final root =
        skillTreeNodes.firstWhereOrNull((node) => node.label == "WriteFile");
    if (root != null) {
      stack.push(root);
    }

    while (stack.isNotEmpty) {
      final node = stack.peek();
      stack.pop();

      // A node is pushed once per emitted parent, so it can reappear on the
      // stack after it has already been emitted.
      if (emitted.contains(node.id)) continue;

      final parents =
          _getParentsOfNodeUsingEdges(node.id, skillTreeNodes, skillTreeEdges);
      // Compare against *emitted* nodes, not merely discovered ones, so that
      // every parent really precedes this node in the result.
      if (!parents.every((parent) => emitted.contains(parent.id))) {
        // Not ready: the node is pushed again when its next parent is emitted.
        continue;
      }

      emitted.add(node.id);
      ordered.add(node);
      for (final child in _getChildrenOfNodeUsingEdges(
          node.id, skillTreeNodes, skillTreeEdges)) {
        if (!emitted.contains(child.id)) stack.push(child);
      }
    }
    _selectedNodeHierarchy = ordered;
  }

  /// Returns the nodes that have an edge pointing to [nodeId].
  ///
  /// Edges whose source id is not present in [nodes] are ignored.
  List<SkillTreeNode> _getParentsOfNodeUsingEdges(
      String nodeId, List<SkillTreeNode> nodes, List<SkillTreeEdge> edges) {
    return edges
        .where((edge) => edge.to == nodeId)
        .map((edge) => nodes.firstWhereOrNull((node) => node.id == edge.from))
        .whereType<SkillTreeNode>()
        .toList();
  }

  /// Returns the nodes that [nodeId] has an edge pointing to.
  ///
  /// Edges whose target id is not present in [nodes] are ignored.
  List<SkillTreeNode> _getChildrenOfNodeUsingEdges(
      String nodeId, List<SkillTreeNode> nodes, List<SkillTreeEdge> edges) {
    return edges
        .where((edge) => edge.from == nodeId)
        .map((edge) => nodes.firstWhereOrNull((node) => node.id == edge.to))
        .whereType<SkillTreeNode>()
        .toList();
  }

  // TODO: Do we want to continue testing other branches of tree if one branch side fails benchmarking?
  /// Selects the node [startNodeId] together with all of its ancestors and
  /// notifies listeners.
  ///
  /// For an acyclic graph every ancestor is placed before its descendants.
  void populateSelectedNodeHierarchy(String startNodeId,
      List<SkillTreeNode> nodes, List<SkillTreeEdge> edges) {
    _selectedNodeHierarchy = <SkillTreeNode>[];
    recursivePopulateHierarchy(startNodeId, <String>{}, nodes, edges);
    notifyListeners();
  }

  /// Appends [nodeId]'s ancestors and then the node itself to the selected
  /// hierarchy (post-order over parent edges).
  ///
  /// [addedNodes] records ids already handled so each node is added once;
  /// ids not present in [nodes] are skipped.
  void recursivePopulateHierarchy(String nodeId, Set<String> addedNodes,
      List<SkillTreeNode> nodes, List<SkillTreeEdge> edges) {
    final currentNode = nodes.firstWhereOrNull((node) => node.id == nodeId);
    // `Set.add` returns false when the id was already present.
    if (currentNode == null || !addedNodes.add(currentNode.id)) return;

    // Visit every parent first so ancestors land earlier in the list.
    for (final parentEdge in edges.where((edge) => edge.to == currentNode.id)) {
      recursivePopulateHierarchy(parentEdge.from, addedNodes, nodes, edges);
    }
    (_selectedNodeHierarchy ??= <SkillTreeNode>[]).add(currentNode);
  }

  /// Runs the benchmark for each node of the selected hierarchy in order.
  ///
  /// For every node this creates a benchmark task and executes steps until
  /// the last one. It then triggers evaluation and records the resulting
  /// [BenchmarkRun] and status. The run stops at the first failing node. On
  /// completion (including an early stop after a failure) the collected tasks
  /// are added to [taskViewModel] as a [TestSuite]. If an exception occurs it
  /// is logged, the node being run is marked as a failure, and no suite is
  /// added. [isBenchmarkRunning] is reset in all cases.
  Future<void> runBenchmark(
      ChatViewModel chatViewModel, TaskViewModel taskViewModel) async {
    // Snapshot the selection so changing it while the run is in flight
    // cannot alter (or concurrently modify) the list being iterated.
    final hierarchy = List<SkillTreeNode>.of(
        _selectedNodeHierarchy ?? const <SkillTreeNode>[]);

    // Reset state left over from any previous run.
    benchmarkStatusMap.clear();
    currentBenchmarkRuns = [];
    final testSuite =
        TestSuite(timestamp: DateTime.now().toIso8601String(), tests: []);
    for (final node in hierarchy) {
      benchmarkStatusMap[node] = BenchmarkTaskStatus.notStarted;
    }
    isBenchmarkRunning = true;
    notifyListeners();

    SkillTreeNode? runningNode;
    try {
      for (final node in hierarchy) {
        runningNode = node;
        benchmarkStatusMap[node] = BenchmarkTaskStatus.inProgress;
        notifyListeners();

        final createdTask = await benchmarkService.createBenchmarkTask(
            BenchmarkTaskRequestBody(
                input: node.data.task, evalId: node.data.evalId));
        final task =
            Task(id: createdTask['task_id'], title: createdTask['input']);
        chatViewModel.setCurrentTaskId(task.id);

        // The first step carries the task input; subsequent steps send none.
        // NOTE(review): fetchChatsForTask() is not awaited; if it returns a
        // Future, its errors are dropped — confirm this is intended.
        Map<String, dynamic> stepResponse =
            await benchmarkService.executeBenchmarkStep(
                task.id, BenchmarkStepRequestBody(input: node.data.task));
        Step step = Step.fromMap(stepResponse);
        chatViewModel.fetchChatsForTask();
        while (!step.isLast) {
          stepResponse = await benchmarkService.executeBenchmarkStep(
              task.id, BenchmarkStepRequestBody(input: null));
          step = Step.fromMap(stepResponse);
          chatViewModel.fetchChatsForTask();
        }

        final evaluationResponse =
            await benchmarkService.triggerEvaluation(task.id);
        BenchmarkRun benchmarkRun = BenchmarkRun.fromJson(evaluationResponse);
        currentBenchmarkRuns.add(benchmarkRun);

        bool successStatus = benchmarkRun.metrics.success;
        benchmarkStatusMap[node] = successStatus
            ? BenchmarkTaskStatus.success
            : BenchmarkTaskStatus.failure;
        await Future.delayed(const Duration(seconds: 1));
        notifyListeners();
        testSuite.tests.add(task);

        if (!successStatus) {
          print(
              "Benchmark for node ${node.id} failed. Stopping all benchmarks.");
          break;
        }
      }
      taskViewModel.addTestSuite(testSuite);
    } catch (e) {
      print("Error while running benchmark: $e");
      // Don't leave the node that was executing stuck in the in-progress state.
      final failedNode = runningNode;
      if (failedNode != null &&
          benchmarkStatusMap[failedNode] == BenchmarkTaskStatus.inProgress) {
        benchmarkStatusMap[failedNode] = BenchmarkTaskStatus.failure;
      }
    } finally {
      // Always clear the flag, even if setup or an unexpected error threw.
      isBenchmarkRunning = false;
      notifyListeners();
    }
  }

  /// Tags every run in [currentBenchmarkRuns] with the given repository
  /// details and one shared run id, submits each to the leaderboard in order,
  /// then clears the list.
  ///
  /// If a submission throws, the error propagates and the list is not cleared.
  Future<void> submitToLeaderboard(
      String teamName, String repoUrl, String agentGitCommitSha) async {
    // A single UUID v4 groups all runs of this submission under one run id.
    final runId = const Uuid().v4();
    for (final run in currentBenchmarkRuns) {
      run.repositoryInfo
        ..teamName = teamName
        ..repoUrl = repoUrl
        ..agentGitCommitSha = agentGitCommitSha;
      run.runDetails.runId = runId;
      await leaderboardService.submitReport(run);
      print('Completed submission to leaderboard!');
    }
    currentBenchmarkRuns.clear();
  }
}