182 lines
5.4 KiB
Python
182 lines
5.4 KiB
Python
from __future__ import annotations
|
|
from pprint import pprint
|
|
from typing import *
|
|
from aocpy import BaseChallenge
|
|
from dataclasses import dataclass
|
|
import re
|
|
from math import inf
|
|
import itertools
|
|
import copy
|
|
|
|
|
|
@dataclass
|
|
class Node:
|
|
name: str
|
|
value: int
|
|
edges: List[Node]
|
|
|
|
def __init__(
|
|
self, name: str, value: int = 0, edges: Optional[List[Tuple[Node, int]]] = None
|
|
):
|
|
self.name = name
|
|
self.value = value
|
|
self.edges = [] if edges is None else edges
|
|
|
|
|
|
def shortest_path(nodes: Dict[str, Node], begin: str, end: str) -> List[str]:
|
|
priorities: Dict[str, Tuple[Union[int, float], Optional[str]]] = {
|
|
node: (inf, None) for node in nodes
|
|
}
|
|
visited: Dict[str, None] = {begin: None}
|
|
|
|
cursor = begin
|
|
while True:
|
|
if cursor == end:
|
|
break
|
|
|
|
n = nodes[cursor]
|
|
visited[cursor] = None
|
|
|
|
length_to_current = priorities[cursor][0]
|
|
if length_to_current == inf:
|
|
length_to_current = 0
|
|
|
|
# every neighbour that's not been visited already
|
|
for neighbour in n.edges:
|
|
if neighbour.name in visited:
|
|
continue
|
|
if priorities[neighbour.name][0] > length_to_current + 1:
|
|
priorities[neighbour.name] = (length_to_current + 1, cursor)
|
|
|
|
# work out next item
|
|
min_priority = (inf, None)
|
|
for node_name in priorities:
|
|
if node_name not in visited and priorities[node_name][0] < min_priority[0]:
|
|
min_priority = (priorities[node_name][0], node_name)
|
|
|
|
cursor = min_priority[1]
|
|
|
|
route: List[str] = []
|
|
while priorities[cursor][1] is not None:
|
|
route.insert(0, cursor)
|
|
cursor = priorities[cursor][1]
|
|
|
|
return route[:-1]
|
|
|
|
|
|
parse_re = re.compile(
|
|
r"Valve ([A-Z]+) has flow rate=(\d+); tunnels? leads? to valves? ((?:[A-Z]+,? ?)+)"
|
|
)
|
|
|
|
|
|
def parse(
|
|
instr: str,
|
|
) -> Tuple[Dict[str, Node], List[str], Dict[Tuple[str, str], List[str]]]:
|
|
sp = [parse_re.match(line).groups() for line in instr.strip().splitlines()]
|
|
|
|
nodes: Dict[str, Node] = {}
|
|
unjammed_nodes: List[str] = []
|
|
for (valve_name, flow_rate_str, _) in sp:
|
|
flow_rate = int(flow_rate_str)
|
|
nodes[valve_name] = Node(valve_name, flow_rate)
|
|
if flow_rate != 0:
|
|
unjammed_nodes.append(valve_name)
|
|
|
|
for (valve_name, _, further_nodes_str) in sp:
|
|
n = nodes[valve_name]
|
|
for connected_node_name in further_nodes_str.split(", "):
|
|
n.edges.append(nodes[connected_node_name])
|
|
|
|
# work out a matrix of the shortest paths between two nodes
|
|
shortest_paths: Dict[Tuple[str, str], List[str]] = {}
|
|
for start_node in nodes:
|
|
|
|
if nodes[start_node].value == 0 and start_node != "AA":
|
|
continue
|
|
|
|
for end_node in nodes:
|
|
if end_node == start_node or nodes[end_node].value == 0:
|
|
continue
|
|
|
|
path = shortest_path(nodes, start_node, end_node)
|
|
pl = len(path) + 1
|
|
shortest_paths[(start_node, end_node)] = pl
|
|
shortest_paths[(end_node, start_node)] = pl
|
|
|
|
return nodes, unjammed_nodes, shortest_paths
|
|
|
|
|
|
def permutations(
|
|
current_node: str,
|
|
nodes_remaining: List[str],
|
|
shortest_paths: Dict[Tuple[str, str], List[str]],
|
|
path: List[str],
|
|
cost_remaining: int,
|
|
) -> Generator[List[str]]:
|
|
for next_node in nodes_remaining:
|
|
cost = shortest_paths[(current_node, next_node)]
|
|
if cost < cost_remaining:
|
|
nr = copy.copy(nodes_remaining)
|
|
nr.remove(next_node)
|
|
yield from permutations(
|
|
next_node, nr, shortest_paths, path + [next_node], cost_remaining - cost
|
|
)
|
|
yield path
|
|
|
|
|
|
def calc_vented(
|
|
nodes: Dict[str, Node],
|
|
shortest_paths: Dict[Tuple[str, str], List[str]],
|
|
visit_order: List[str],
|
|
time_remaining: int,
|
|
) -> int:
|
|
current = "AA"
|
|
pressure = 0
|
|
|
|
for node_name in visit_order:
|
|
path_length = shortest_paths[(current, node_name)]
|
|
time_remaining -= path_length + 1
|
|
pressure += nodes[node_name].value * time_remaining
|
|
current = node_name
|
|
|
|
return pressure
|
|
|
|
|
|
class Challenge(BaseChallenge):
|
|
@staticmethod
|
|
def one(instr: str) -> int:
|
|
nodes, unjammed_nodes, shortest_paths = parse(instr)
|
|
|
|
max_pressure = 0
|
|
for visit_order in permutations("AA", unjammed_nodes, shortest_paths, [], 30):
|
|
pressure = calc_vented(nodes, shortest_paths, visit_order, 30)
|
|
|
|
if pressure > max_pressure:
|
|
max_pressure = pressure
|
|
|
|
return max_pressure
|
|
|
|
@staticmethod
|
|
def two(instr: str) -> int:
|
|
nodes, unjammed_nodes, shortest_paths = parse(instr)
|
|
|
|
pressures = [
|
|
(calc_vented(nodes, shortest_paths, visit_order, 26), visit_order)
|
|
for visit_order in permutations(
|
|
"AA", unjammed_nodes, shortest_paths, [], 26
|
|
)
|
|
]
|
|
|
|
max_pressure = 0
|
|
for i, (pressure_a, order_a) in enumerate(
|
|
sorted(pressures, reverse=True, key=lambda x: x[0])
|
|
):
|
|
if pressure_a * 2 < max_pressure:
|
|
break
|
|
for (pressure_b, order_b) in pressures[i + 1 :]:
|
|
if len(set(order_a).intersection(order_b)) == 0:
|
|
pressure = pressure_a + pressure_b
|
|
if pressure > max_pressure:
|
|
max_pressure = pressure
|
|
|
|
return max_pressure
|