adventOfCode/challenges/2022/16-proboscideaVolcanium/py/__init__.py
AKP 6f282e5761
Code formatting
Signed-off-by: AKP <tom@tdpain.net>
2022-12-20 17:51:35 +00:00

182 lines
5.4 KiB
Python

from __future__ import annotations
from pprint import pprint
from typing import *
from aocpy import BaseChallenge
from dataclasses import dataclass
import re
from math import inf
import itertools
import copy
@dataclass
class Node:
    """A valve in the tunnel network.

    ``edges`` holds the neighbouring ``Node`` objects themselves (``parse``
    appends nodes to it); tunnels carry no weight of their own and
    ``shortest_path`` counts each one as a single step.
    """

    name: str  # valve identifier as it appears in the input, e.g. "AA"
    value: int  # flow rate; calc_vented multiplies it by minutes remaining
    edges: List[Node]  # directly connected valves

    def __init__(
        self, name: str, value: int = 0, edges: Optional[List[Node]] = None
    ) -> None:
        # A hand-written __init__ is kept by @dataclass; using None as the
        # default gives every node its own fresh edge list instead of a
        # shared mutable default.
        self.name = name
        self.value = value
        self.edges = [] if edges is None else edges
def shortest_path(nodes: Dict[str, Node], begin: str, end: str) -> List[str]:
    """Find one shortest route between two valves.

    Runs Dijkstra's algorithm with every tunnel weighted 1, choosing the next
    node to settle by a linear scan over ``nodes`` (ties go to the
    earliest-inserted name).

    :param nodes: every valve in the network, keyed by name.
    :param begin: name of the valve to start from.
    :param end: name of the valve to reach.
    :return: the names of the valves strictly *between* ``begin`` and ``end``
        on the route, in travel order. Neither endpoint is included, so for
        ``begin != end`` the route is ``len(result) + 1`` tunnels long; the
        result is empty for adjacent valves and when ``begin == end``.
    :raises ValueError: if ``end`` cannot be reached from ``begin``.
    """
    # name -> (best known distance from begin, predecessor on that route)
    priorities: Dict[str, Tuple[Union[int, float], Optional[str]]] = {
        node: (inf, None) for node in nodes
    }
    priorities[begin] = (0, None)
    visited: Set[str] = set()
    cursor: Optional[str] = begin
    while cursor != end:
        if cursor is None:
            # Every reachable valve has been settled without meeting `end`.
            raise ValueError(f"no path from {begin!r} to {end!r}")
        visited.add(cursor)
        length_to_current = priorities[cursor][0]
        # Relax every neighbour that has not been settled yet.
        for neighbour in nodes[cursor].edges:
            if neighbour.name in visited:
                continue
            if priorities[neighbour.name][0] > length_to_current + 1:
                priorities[neighbour.name] = (length_to_current + 1, cursor)
        # Settle the closest unvisited node next; stays None if none is
        # reachable.
        best_length: Union[int, float] = inf
        cursor = None
        for node_name, (length, _) in priorities.items():
            if node_name not in visited and length < best_length:
                best_length, cursor = length, node_name

    # Walk predecessors back from `end`; `begin` is the only node without one.
    route: List[str] = []
    while priorities[cursor][1] is not None:
        route.append(cursor)
        cursor = priorities[cursor][1]
    route.reverse()
    # `route` runs from the valve after `begin` up to and including `end`;
    # drop `end` so only the intermediate valves remain.
    return route[:-1]
# Matches one line of puzzle input, for example
#   Valve BB has flow rate=13; tunnels lead to valves CC, AA
# and captures (valve name, flow rate, ", "-separated neighbour names).
# The optional trailing "s" on each word also accepts the singular phrasing
#   Valve HH has flow rate=22; tunnel leads to valve GG
parse_re = re.compile(
    r"Valve ([A-Z]+) has flow rate=(\d+); "
    r"tunnels? leads? to valves? ((?:[A-Z]+,? ?)+)"
)
def parse(
    instr: str,
) -> Tuple[Dict[str, Node], List[str], Dict[Tuple[str, str], int]]:
    """Parse the puzzle input into the valve graph and a distance table.

    :param instr: raw puzzle input, one ``Valve ...`` line per valve.
    :return: ``(nodes, unjammed_nodes, shortest_paths)`` where

        - ``nodes`` maps every valve name to its ``Node``, edges wired up;
        - ``unjammed_nodes`` lists the valves with a non-zero flow rate, in
          input order;
        - ``shortest_paths`` maps an ordered pair of distinct valves, both
          drawn from the non-zero valves plus ``"AA"`` (when present), to the
          number of tunnels on a shortest route between them.
    :raises ValueError: if a line does not match ``parse_re``.
    """
    parsed: List[Tuple[str, ...]] = []
    for line in instr.strip().splitlines():
        match = parse_re.match(line)
        if match is None:
            raise ValueError(f"unrecognised valve description: {line!r}")
        parsed.append(match.groups())

    nodes: Dict[str, Node] = {}
    unjammed_nodes: List[str] = []
    for valve_name, flow_rate_str, _ in parsed:
        flow_rate = int(flow_rate_str)
        nodes[valve_name] = Node(valve_name, flow_rate)
        if flow_rate != 0:
            unjammed_nodes.append(valve_name)

    # Edges are wired in a second pass so tunnels may name valves that are
    # described later in the input.
    for valve_name, _, further_nodes_str in parsed:
        nodes[valve_name].edges.extend(
            nodes[connected_node_name]
            for connected_node_name in further_nodes_str.split(", ")
        )

    # Distances are only needed from the start valve "AA" and between valves
    # worth opening; zero-flow valves are merely walked through.
    shortest_paths: Dict[Tuple[str, str], int] = {}
    for start_node in nodes:
        if nodes[start_node].value == 0 and start_node != "AA":
            continue
        for end_node in nodes:
            if end_node == start_node or nodes[end_node].value == 0:
                continue
            if (start_node, end_node) in shortest_paths:
                # Already filled in as the mirror of (end_node, start_node).
                continue
            # `shortest_path` returns only the intermediate valves, so the
            # route is one tunnel longer than that list.
            distance = len(shortest_path(nodes, start_node, end_node)) + 1
            # Assumes tunnels can be walked both ways, so the distance is
            # symmetric and both directions are recorded at once.
            shortest_paths[(start_node, end_node)] = distance
            shortest_paths[(end_node, start_node)] = distance
    return nodes, unjammed_nodes, shortest_paths
def permutations(
    current_node: str,
    nodes_remaining: List[str],
    shortest_paths: Dict[Tuple[str, str], int],
    path: List[str],
    cost_remaining: int,
) -> Iterator[List[str]]:
    """Yield every valve-opening order that fits in the time available.

    An order is extended with a valve from ``nodes_remaining`` only if, after
    walking there and spending one minute opening it, at least one minute is
    still left, i.e. only while that valve would actually release pressure.
    This matches the accounting in ``calc_vented``. Each feasible order is
    yielded after all of its extensions, so the ``path`` passed in is always
    the last item produced.

    :param current_node: valve the walker currently stands at.
    :param nodes_remaining: valves not yet opened; never mutated.
    :param shortest_paths: tunnel count between valve pairs, as built by
        ``parse``.
    :param path: the order built so far; yielded as-is, never mutated.
    :param cost_remaining: minutes left on the clock.
    :return: a generator of visit orders (lists of valve names).
    """
    for next_node in nodes_remaining:
        # Walking to the valve, plus one minute to open it.
        cost = shortest_paths[(current_node, next_node)] + 1
        if cost < cost_remaining:
            rest = list(nodes_remaining)
            rest.remove(next_node)
            yield from permutations(
                next_node,
                rest,
                shortest_paths,
                path + [next_node],
                cost_remaining - cost,
            )
    yield path
def calc_vented(
    nodes: Dict[str, Node],
    shortest_paths: Dict[Tuple[str, str], int],
    visit_order: List[str],
    time_remaining: int,
    start: str = "AA",
) -> int:
    """Total pressure released by opening valves in ``visit_order``.

    The walker starts at ``start``. For each valve in turn, it walks the
    shortest route there and spends one further minute opening it. The valve
    then contributes its ``value`` for every minute left on the clock.

    :param nodes: valves keyed by name (only ``.value`` is read).
    :param shortest_paths: tunnel count between valve pairs, as built by
        ``parse``.
    :param visit_order: valve names in the order they are opened.
    :param time_remaining: minutes on the clock at the start.
    :param start: valve the walker starts at (defaults to ``"AA"``).
    :return: the total pressure released. The order is not checked for
        feasibility: a valve opened once the clock has run out contributes a
        non-positive amount.
    """
    current = start
    pressure = 0
    for node_name in visit_order:
        # Walk there, then one minute to turn the valve.
        time_remaining -= shortest_paths[(current, node_name)] + 1
        pressure += nodes[node_name].value * time_remaining
        current = node_name
    return pressure
class Challenge(BaseChallenge):
    """Solutions for both parts of the valve-opening puzzle."""

    @staticmethod
    def one(instr: str) -> int:
        """Best pressure a single walker can release from "AA" in 30 minutes."""
        nodes, unjammed_nodes, shortest_paths = parse(instr)
        return max(
            (
                calc_vented(nodes, shortest_paths, visit_order, 30)
                for visit_order in permutations(
                    "AA", unjammed_nodes, shortest_paths, [], 30
                )
            ),
            default=0,
        )

    @staticmethod
    def two(instr: str) -> int:
        """Best pressure two walkers with 26 minutes each can release together.

        Each walker follows one visit order from ``permutations``; a pair of
        orders is allowed only if they open disjoint sets of valves.
        """
        nodes, unjammed_nodes, shortest_paths = parse(instr)
        # (pressure, valves opened) for every plan one walker could follow,
        # best plan first.
        plans = sorted(
            (
                (
                    calc_vented(nodes, shortest_paths, visit_order, 26),
                    frozenset(visit_order),
                )
                for visit_order in permutations(
                    "AA", unjammed_nodes, shortest_paths, [], 26
                )
            ),
            key=lambda plan: plan[0],
            reverse=True,
        )
        max_pressure = 0
        for i, (pressure_a, valves_a) in enumerate(plans):
            # Partners come from later in this same sorted list, so each is
            # worth at most pressure_a. Once even 2 * pressure_a cannot beat
            # the best pair, no later plan can either.
            if pressure_a * 2 <= max_pressure:
                break
            for pressure_b, valves_b in plans[i + 1 :]:
                # Remaining partners are worth no more than this one.
                if pressure_a + pressure_b <= max_pressure:
                    break
                if valves_a.isdisjoint(valves_b):
                    max_pressure = pressure_a + pressure_b
        return max_pressure