728x90
https://www.acmicpc.net/problem/13511
13511번: 트리와 쿼리 2
N개의 정점으로 이루어진 트리(무방향 사이클이 없는 연결 그래프)가 있다. 정점은 1번부터 N번까지 번호가 매겨져 있고, 간선은 1번부터 N-1번까지 번호가 매겨져 있다. 아래의 두 쿼리를 수행하
www.acmicpc.net
우선 최소 공통 조상을 구한 뒤, 쿼리의 종류에 따라 출력해주면 되는 문제입니다. LCA 를 구한 뒤 쿼리의 종류와 구하는 인덱스에 따라 분기를 나누어 출력해주면 풀 수 있습니다.
from collections import deque
from math import log2
import sys
input = sys.stdin.readline
def solution(N, M, edge_list, query_list):
# 최소 공통 조상 구해서 다시 거슬러 올라가자 / 과정 모두 저장하면 메모리 초과
# 임시 트리 기록 / tree[i] = [[i 노드의 자식, 비용] ... ]
tree = [[] for _ in range(N+1)]
for node_1, node_2, cost in edge_list:
tree[node_1].append([node_2, cost])
tree[node_2].append([node_1, cost])
# 트리 정렬 및 깊이 탐색
tree, depth_list = sort_tree_and_set_depth(tree, N)
# 희소배열 생성
sparse_table = set_sparse_table(tree, N)
# 쿼리 순회
for query in query_list:
# 쿼리 상태
state = query[0]
# 노드
node_1 = query[1]
node_2 = query[2]
# 깊이 맞추기
eq_node_1, eq_node_2, cost = equailze_depth(node_1, node_2, depth_list, sparse_table)
# eq_node_1 이 eq_node_2 면 lca 는 eq_node_1 == eq_node_2
if eq_node_1 == eq_node_2:
lca = eq_node_1
else:
# 순회
for n in range(int(log2(N)), -1, -1):
# 큰 수부터 올라가보며 같은 노드면 패스
if sparse_table[n][eq_node_1][0] != sparse_table[n][eq_node_2][0]:
# 코스트 추가
cost += sparse_table[n][eq_node_1][1]+sparse_table[n][eq_node_2][1]
# 노드 이동
eq_node_1 = sparse_table[n][eq_node_1][0]
eq_node_2 = sparse_table[n][eq_node_2][0]
# 순회가 끝나면 하나 더 거슬러 올라가기
# 코스트 추가
cost += sparse_table[0][eq_node_1][1]+sparse_table[0][eq_node_2][1]
# lca
lca = sparse_table[0][eq_node_1][0]
# state 가 1 이면
if state == 1:
print(cost)
continue
# 아니면
else:
# node와 lca 의 거리
dist_node_1 = depth_list[node_1] - depth_list[lca] + 1
dist_node_2 = depth_list[node_2] - depth_list[lca] + 1
# 순회
k = query[3]
# dist_node_1 보다 k 가 크면
if k > dist_node_1:
# node_2 에서 거슬러 올라감
# node_2 에서부터 k 의 위치
dist_k = dist_node_1 + dist_node_2 - 1 - (k-1)
# 2진법
bin_dist_k = bin(dist_k-1)[2:]
for idx, n in enumerate(reversed(bin_dist_k)):
if n == '1':
node_2 = sparse_table[idx][node_2][0]
print(node_2)
continue
# dist_node_1 보다 k 가 작으면
elif k < dist_node_1:
# node_1 에서 거슬러 올라감
# 2진법
bin_dist_k = bin(k-1)[2:]
for idx, n in enumerate(reversed(bin_dist_k)):
if n == '1':
node_1 = sparse_table[idx][node_1][0]
print(node_1)
continue
# 같으면
else:
print(lca)
continue
# 트리 정렬 및 깊이 탐색 함수
def sort_tree_and_set_depth(tree, N):
# 큐 / 1을 루트로 임의로 지정
queue = deque([[1, 0]])
# 깊이 리스트
depth_list = [0] + [-1 for _ in range(N)]
# 정렬된 트리 / tree[i] = [i 의 부모, 비용]
sorted_tree = [[] for _ in range(N+1)]
sorted_tree[1] = [1, 0]
# 트리 탐색
while queue:
# 현재 노드, 깊이
now_node, depth = queue.popleft()
# 현재 노드 깊이 기록
depth_list[now_node] = depth
# 현재 노드와 연결되어 있고 깊이가 지정되지 않은 노드 탐색
for node, cost in tree[now_node]:
if depth_list[node] < 0:
# 큐에 추가
queue.append([node, depth+1])
# 트리에 추가
sorted_tree[node] = [now_node, cost]
return sorted_tree, depth_list
# 희소 배열 생성 함수
def set_sparse_table(tree, N):
# 희소 배열 최대 행 개수
max_row = int(log2(N))+1
# 희소 배열 / sparse_table[i][j] = [j 노드의 2^i 번째 조상, 조상 노드까지의 비용]
sparse_table = [[[]] + [[tree[i][0], tree[i][1]] for i in range(1, N+1)] for _ in range(max_row)]
# 희소 배열 최신화
for i in range(1, max_row):
for j in range(2, N+1):
sparse_table[i][j] = [
# 조상
sparse_table[i-1][sparse_table[i-1][j][0]][0],
# 비용
sparse_table[i-1][j][1]+sparse_table[i-1][sparse_table[i-1][j][0]][1],
]
return sparse_table
# 깊이 맞추기 함수
def equailze_depth(node_1, node_2, depth_list, sparse_table):
# 노드들의 깊이
depth_1 = depth_list[node_1]
depth_2 = depth_list[node_2]
# 이동 비용
move_cost = 0
# 깊이 상태, 1, 2 중에 뭐가 깊은지
depth_state = 0
# 깊이 차이
sub_depth = 0
# depth_1 이 더 크면
if depth_1 > depth_2:
# 깊이 상태
depth_state = 1
# 깊이 차이
sub_depth = depth_1 - depth_2
# 깊이 차이 2진수
bin_sub_depth = bin(sub_depth)[2:]
# 순회하며 노드 이동
for idx, n in enumerate(reversed(bin_sub_depth)):
# n 이 1 이면
if n == '1':
# 이동 비용 합
move_cost += sparse_table[idx][node_1][1]
# 노드 이동
node_1 = sparse_table[idx][node_1][0]
# depth_2 가 더 크면
elif depth_1 < depth_2:
# 깊이 상태
depth_state = 2
# 깊이 차이
sub_depth = depth_2 - depth_1
# 깊이 차이 2진수
bin_sub_depth = bin(sub_depth)[2:]
# 순회하며 노드 이동
for idx, n in enumerate(reversed(bin_sub_depth)):
# n 이 1 이면
if n == '1':
# 이동 비용 합
move_cost += sparse_table[idx][node_2][1]
# 노드 이동
node_2 = sparse_table[idx][node_2][0]
return node_1, node_2, move_cost
N = int(input())
edge_list = [list(map(int, input().split())) for _ in range(N-1)]
M = int(input())
query_list = [list(map(int,input().strip().split())) for _ in range(M)]
solution(N, M, edge_list, query_list)
728x90
'Coding Test > BaekJoon_Python' 카테고리의 다른 글
백준 4196 <도미노> Python (1) | 2024.01.03 |
---|---|
백준 2150 <Strongly Connected Component> Python (1) | 2024.01.03 |
백준 3176 <도로 네트워크> Python (1) | 2023.12.30 |
백준 11438 <LCA 2> Python (0) | 2023.12.29 |
백준 17435 <합성함수와 쿼리> Python (1) | 2023.12.28 |