# mdz/pytorch/drl4vrp/1_scripts/0_infer.py
#
# 93 lines
# 3.3 KiB
# Python

# This script loads the actor weights from a checkpoint and runs forward inference on a TSP test set.
import sys
sys.path.append(R"../0_drl4vrp")
import os
import argparse
import numpy as np
import torch
from torch.utils.data import DataLoader
import env
from model import DRL4TSP
from tasks import tsp
from tasks.tsp import TSPDataset
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
# To force CPU execution, assign torch.device('cpu') here instead.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
def validate(data_loader, actor, reward_fn, render_fn=None, save_dir='.',
             num_plot=5):
    """Run the actor over a data loader and return the average reward.

    Args:
        data_loader: Iterable yielding ``(static, dynamic, x0)`` batches.
        actor: Model exposing ``eval()`` and ``forward(static, dynamic, x0)``.
            The first element returned by ``forward`` is used as the tour
            indices; the second is ignored.
        reward_fn: Callable ``reward_fn(static, tour_indices)`` returning a
            tensor; its mean is recorded as the reward of the batch.
        render_fn: Optional callable ``render_fn(static, tour_indices, path)``
            invoked for the first ``num_plot`` batches.
        save_dir: Directory handed to ``render_fn`` via the output path. It
            is created (including parents) if it does not exist.
        num_plot: Number of leading batches passed to ``render_fn``.

    Returns:
        The mean of the per-batch mean rewards, or NaN when the loader
        yields no batches.
    """
    actor.eval()

    # Create the output directory once, up front. ``exist_ok=True`` replaces
    # the race-prone "check, then create" that was also repeated per batch.
    os.makedirs(save_dir, exist_ok=True)

    rewards = []
    for batch_idx, batch in enumerate(data_loader):
        print('*' * 40, batch_idx, '*' * 40)

        static, dynamic, x0 = batch
        static = static.to(device)
        dynamic = dynamic.to(device)
        # An empty x0 is treated as "no initial input" and passed as None.
        x0 = x0.to(device) if len(x0) > 0 else None

        with torch.no_grad():
            tour_indices, _ = actor.forward(static, dynamic, x0)
        print('idx =', tour_indices)

        # NOTE: each batch contributes its own mean, so batches of unequal
        # size carry equal weight in the final average.
        rewards.append(reward_fn(static, tour_indices).mean().item())

        # Save a visualization for the first ``num_plot`` batches.
        if render_fn is not None and batch_idx < num_plot:
            path = os.path.join(save_dir, 'batch_' + str(batch_idx) + '.gif')
            render_fn(static, tour_indices, path)

    if not rewards:
        # np.mean([]) would also give NaN, but only via a RuntimeWarning.
        return float('nan')
    return np.mean(rewards)
if __name__ == '__main__':
    # Command-line options as (flag, add_argument keyword arguments), listed
    # in the order they are registered with the parser.
    cli_options = (
        ('--seed', {'default': 12345, 'type': int}),
        ('--checkpoint', {'default': "../weights/tsp20"}),
        ('--task', {'default': 'tsp'}),
        ('--nodes', {'dest': 'num_nodes', 'default': 20, 'type': int}),
        ('--batch_size', {'default': 1, 'type': int}),
        ('--hidden', {'dest': 'hidden_size', 'default': 128, 'type': int}),
        ('--dropout', {'default': 0.1, 'type': float}),
        ('--layers', {'dest': 'num_layers', 'default': 1, 'type': int}),
        ('--test-size', {'default': 100, 'type': int}),
    )
    parser = argparse.ArgumentParser(description='Combinatorial Optimization')
    for flag, options in cli_options:
        parser.add_argument(flag, **options)
    args = parser.parse_args()
    print(args)

    # Goals from paper:
    #   TSP20:  3.97
    #   TSP50:  6.08
    #   TSP100: 8.44

    STATIC_SIZE = 2  # (x, y)
    DYNAMIC_SIZE = 1  # dummy for compatibility
    update_fn = None

    # Build the actor, then move it to the selected device.
    actor = DRL4TSP(
        STATIC_SIZE,
        DYNAMIC_SIZE,
        args.hidden_size,
        update_fn,
        tsp.update_mask,
        args.num_layers,
        args.dropout,
    )
    actor = actor.to(device)

    # Restore weights from <checkpoint>/actor.pt unless the checkpoint
    # argument is empty.
    if args.checkpoint:
        weights_file = os.path.join(args.checkpoint, 'actor.pt')
        state_dict = torch.load(weights_file, map_location=device)
        actor.load_state_dict(state_dict)

    # The test set uses a seed offset from the base seed by 2.
    test_dir = 'test_infer'
    test_data = TSPDataset(args.num_nodes, args.test_size, args.seed + 2)
    test_loader = DataLoader(
        test_data,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=0,
    )

    out = validate(
        test_loader,
        actor,
        reward_fn=env.reward,
        render_fn=env.render,
        save_dir=test_dir,
        num_plot=5,
    )
    print('Results save at: ', test_dir)
    print('Average tour length: ', out)