139 lines
3.4 KiB
Python
139 lines
3.4 KiB
Python
import sys
|
|
import gridutil.grid as gu
|
|
import gridutil.coord as cu
|
|
from collections.abc import Callable, Iterator
|
|
from collections import namedtuple
|
|
from queue import PriorityQueue
|
|
|
|
|
|
def parse(instr: str) -> gu.Grid:
|
|
g = gu.parse(instr)
|
|
return {k: int(g[k]) for k in g}
|
|
|
|
|
|
State = namedtuple("State", ["pos", "steps_taken", "direction"])
|
|
|
|
|
|
def djikstra(
|
|
start: State,
|
|
end: Callable[[State], bool],
|
|
neighbours: Callable[[State], Iterator[State, None, None]],
|
|
cost: Callable[[State, State], int],
|
|
) -> list[State]:
|
|
dq = PriorityQueue()
|
|
dq.put((0, start))
|
|
d = {start: 0}
|
|
p = {}
|
|
|
|
endState = None
|
|
while endState is None:
|
|
w = dq.get()[1]
|
|
|
|
for u in neighbours(w):
|
|
c = d[w] + cost(w, u)
|
|
if u not in d or c < d[u]:
|
|
d[u] = c
|
|
dq.put((c, u))
|
|
p[u] = w
|
|
|
|
if end(u):
|
|
endState = u
|
|
break
|
|
|
|
res = [endState]
|
|
cursor = p[endState]
|
|
while cursor is not None:
|
|
res.append(cursor)
|
|
cursor = p.get(cursor)
|
|
return list(reversed(res))
|
|
|
|
|
|
def one(instr: str):
|
|
grid = parse(instr)
|
|
end = (gu.get_max_x(grid), gu.get_max_y(grid))
|
|
|
|
def neighbours(node: State) -> Iterator[State, None, None]:
|
|
for direction in cu.Direction:
|
|
if (
|
|
direction == node.direction and node.steps_taken == 3
|
|
) or node.direction == direction.opposite():
|
|
continue
|
|
|
|
nc = cu.add(node.pos, direction.delta())
|
|
if nc not in grid:
|
|
continue
|
|
|
|
yield State(
|
|
nc,
|
|
(node.steps_taken + 1) if direction == node.direction else 1,
|
|
direction,
|
|
)
|
|
|
|
path = djikstra(
|
|
State((0, 0), 0, cu.Direction.Right),
|
|
lambda x: x.pos == end,
|
|
neighbours,
|
|
lambda _, x: grid[x.pos],
|
|
)[1:]
|
|
|
|
acc = 0
|
|
for node in path:
|
|
acc += grid[node.pos]
|
|
return acc
|
|
|
|
|
|
def two(instr: str):
|
|
grid = parse(instr)
|
|
end = (gu.get_max_x(grid), gu.get_max_y(grid))
|
|
|
|
for x in range(end[0] - 3, end[0]):
|
|
for y in range(end[1] - 3, end[1]):
|
|
del grid[(x, y)]
|
|
|
|
def neighbours(node: State) -> Iterator[State, None, None]:
|
|
for direction in cu.Direction:
|
|
if 0 < node.steps_taken < 4 and direction != node.direction:
|
|
continue
|
|
|
|
if (
|
|
direction == node.direction and node.steps_taken == 10
|
|
) or node.direction == direction.opposite():
|
|
continue
|
|
|
|
nc = cu.add(node.pos, direction.delta())
|
|
if nc not in grid:
|
|
continue
|
|
|
|
yield State(
|
|
nc,
|
|
(node.steps_taken + 1) if direction == node.direction else 1,
|
|
direction,
|
|
)
|
|
|
|
path = djikstra(
|
|
State((0, 0), 0, cu.Direction.Right),
|
|
lambda x: x.pos == end,
|
|
neighbours,
|
|
lambda _, x: grid[x.pos],
|
|
)[1:]
|
|
|
|
acc = 0
|
|
for node in path:
|
|
acc += grid[node.pos]
|
|
return acc
|
|
|
|
|
|
def _debug(*args, **kwargs):
|
|
kwargs["file"] = sys.stderr
|
|
print(*args, **kwargs)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
if len(sys.argv) < 2 or sys.argv[1] not in ["1", "2"]:
|
|
print("Missing day argument", file=sys.stderr)
|
|
sys.exit(1)
|
|
inp = sys.stdin.read().strip()
|
|
if sys.argv[1] == "1":
|
|
print(one(inp))
|
|
else:
|
|
print(two(inp))
|