2023년 8월 18일 알고리즘 문제풀이
문제 1 백준 15652
1차 시도
나의 생각
처음엔 조합(combination)을 이용해야하나 생각했다. 하지만 문제가 수를 구하는게 아니라 모든 결과를 출력해야 하기 때문에 조합은 의미가 없을 것이라 생각했다. 어떻게 해야할까 고민하다가 모든 경우를 탐색해야하니 dfs로 풀면 좋겠다는 생각을 했다. 실마리를 찾으니 푸는건 쉬웠다.
결과
정답
코드
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import sys
sys.setrecursionlimit(10**6)
n, m = map(int, sys.stdin.readline().split())
arr = [i for i in range(n+1)]
def dfs(x, tmp):
if len(tmp) == m:
print(*tmp)
return
for i in range(x,n+1):
tmp.append(i)
dfs(i,tmp)
tmp.pop()
for i in range(1, n+1):
tmp = []
tmp.append(i)
dfs(i, tmp)
문제 2 백준 2357
1차 시도
나의 생각
각 범위만큼의 원소들만 최소 힙에 넣어 최댓값과 최솟값을 구해냈다. 하지만 시간초과가 떴다. 딱히 원인을 분석하지 못하겠다.
결과
오답
코드
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import sys
from heapq import heappush, heappop
n, m = map(int, sys.stdin.readline().split())
nums = []
for _ in range(n):
nums.append(int(sys.stdin.readline()))
arr = []
for _ in range(m):
a, b = map(int, sys.stdin.readline().split())
arr.append([a-1, b-1])
for t in arr:
start, end = t
nums_min = []
nums_max = []
for i in range(start, end+1):
number = nums[i]
heappush(nums_min, number)
heappush(nums_max, -1*number)
print(heappop(nums_min), -1*heappop(nums_max), sep=' ')
2차 시도
나의 생각
세그먼트 트리를 이용하는 문제였다. 난 세그먼트 트리를 처음 들어봤다…
세그먼트 트리란 여러가지 연산을 효율적으로 지원하는 이진 트리 기반의 자료구조이다. 배열의 특정 구간에 대한 정보를 빠르게 알아내기 위해 사용되는데, 특정 구간의 합, 최소 값, 최대 값 등에 많이 쓰인다. 특징은 다음과 같다.
- 이진 트리 구조 : 배열의 각 원소는 트리의 리프 노드에, 원소들의 구간 정보는 트리의 내부 노드에 저장
- 높이 : 높이는 배열의 길이가 n일 때, O(logn)
- 구간 병합 : 두 자식 노드의 정보를 병합하여 부모 노드의 정보 구성. 합을 구할 땐 두 자식 노드의 합이 부모노드에 저장
- 연산 : 특정 구간에 대한 정보를 O(logn)시간 내에 얻을 수 있음
- 업데이트 연산 : 특정 원소를 변경할 때 그 원소를 포함하는 모든 구간의 정보를 O(logn) 시간 내에 업데이트
- 생성 비용 : 초기 트리 생성 비용은 O(nlogn)
정의는 위와 같다. 조금 더 알아듣기 쉽게 풀어서 쓰겠다. 우선 이진 트리 구조이므로 자식이 최대 2개인 트리이다. 많아봤자 왼쪽 자식, 오른쪽 자식이고 한명만 있거나 없을 수 도 있다. 각 정보를 노드의 내부에 저장한다고 하는데 이는 문제에서 요구하는 것을 저장하면 된다는 의미이다. 이번 문제에서는 구간의 최댓값과 최솟값이므로 그 둘을 저장하면 된다. 정확히는 저장 시켜야한다.
왼쪽 자식이 최솟값을 3, 최댓값이 10이고 오른쪽 자식이 최솟값이 1, 최댓값이 9라고 가정하자. 그럴 때 부모 노드는 두 자식의 최솟값들 중 최솟값과 최댓값들 중 최댓값을 각각 최솟값과 최댓값으로 가지면 된다. 따라서 부모 노드는 최솟값이 1, 최댓값이 10이다.
그렇다면 이제 최솟값과 최댓값에 관한 세그먼트 트리를 만들어보겠다.
build_tree
는 입력값들의 배열인 nums
로 tree
라는 세그먼트 트리를 만드는 것이다. start
와 end
를 양 쪽 끝으로 parameter를 받은 뒤 가운데를 mid
로 한다. 가운데를 기준으로 양쪽으로 쪼개며 재귀를 통해 트리를 구축한다. 또한 노드의 값은 원소가 2개인 배열인데, 0번 index는 최솟값, 1번 index는 최댓값을 나타낸다.
query_tree
는 트리에서 구간 a부터 b까지 를 탐색하는 것이다. 즉 우리가 원하는 구간에서 최솟값과 최댓값을 생성된 세그먼트 트리를 통해 찾아내는 함수이다. start
와 end
는 이 트리의 처음과 끝이므로 우리가 찾고자 하는 a,b의 범위가 밖이면 찾을 수 없으므로 무한대를 최댓값, 0을 최솟값으로 반환해버린다. 반대로 처음과 끝에 포함된다면 현재 노드가 최솟값과 최댓값의 정보를 내부에 가지고 있으므로 현재 노드를 반환한다. 아직 완벽히 포함되지 않는다면 mid
를 기준으로 둘로 나누어 재귀를 통해 탐색한다.
start
와 end
가 같다면 범위에 들어오는 숫자가 한 개란 뜻이다. 다시 말해 리프노드임을 의미한다. 또한 부모 노드는 자식 노드의 정보로 구성되어 있으므로 현재 노드의 최솟값 최댓값은 두 자식노드의 최솟값들 중 하나, 최댓값들 중 하나로 설정한다.
문제에서 a번쨰, b번쨰라고 하였으므로 index로 변환하기 위해 입력값을 배열에 담을때 1을 빼주었다. 또한 트리의 노드 갯수를 정확히 파악할 수 없으므로 대략적으로 2배와 4배 사이이기 때문에 tree
의 노드 갯수를 4*n 으로 설정하였다.
중간의 index가 2idx+1, 2idx+2 인 이유는 세그먼트 트리를 배열로 표현할 때 현재 노드의 index가 ‘idx’라면 왼쪽 자식은 2idx+1 , 오른쪽 자식은 2idx+2 로 표현하는 규칙 덕분이다. 세그먼트 트리를 배열로 표현할 때 사용되는 인덱싱 규칙이다.
결과
이해
코드
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import sys
sys.setrecursionlimit(10**9)
n, m = map(int, sys.stdin.readline().split())
nums = []
for _ in range(n):
nums.append(int(sys.stdin.readline()))
arr = []
for _ in range(m):
a, b = map(int, sys.stdin.readline().split())
arr.append((a-1, b-1))
tree = [[0, 0] for _ in range(4*n)]
def build_tree(idx, start, end):
if start == end:
tree[idx] = (nums[start], nums[start])
else:
mid = (start + end) // 2
left = build_tree(2*idx+1, start, mid)
right = build_tree(2*idx+2, mid+1, end)
tree[idx] = min(left[0], right[0]), max(left[1], right[1])
return tree[idx]
def query_tree(idx, start, end, a, b):
if end < a or b < start:
return (1e9, 0)
if a <= start and end <= b:
return tree[idx]
mid = (start + end) // 2
left = query_tree(2*idx+1, start, mid, a, b)
right = query_tree(2*idx+2, mid+1, end, a, b)
return min(left[0], right[0]), max(left[1], right[1])
build_tree(0, 0, n-1)
for a, b in arr:
min_val, max_val = query_tree(0, 0, n-1, a, b)
print(min_val, max_val)