Segment Trees Tutorials:
4. Code Reference: https://github.com/cacophonix/SPOJ/blob/master/GSS1.cpp
C++ Code - GSS1:
#include <stdio.h>
#include <math.h>  // retained from the original listing; sizing is now integer-only
#include <vector>

/**
 * Summary of one contiguous segment of the input array.
 *
 * Two adjacent summaries can be merged in O(1) (see combine()), which is what
 * lets the segment tree answer "maximum sub-segment sum" queries.
 */
class Node {
public:
    int lMax;    // best sum of a non-empty prefix of the segment
    int rMax;    // best sum of a non-empty suffix of the segment
    int anyMax;  // best sum of any non-empty contiguous run in the segment
    int sum;     // sum of every element of the segment

    // Argument order is (prefix, any, suffix, total): `a` precedes `r`, which
    // differs from the order in which the members are declared.
    Node(int l, int a, int r, int s) : lMax(l), rMax(r), anyMax(a), sum(s) {}
};

// Implicit binary tree: the children of node i are stored at 2*i+1 and 2*i+2.
// Nodes are held by value so nothing has to be freed by hand.
std::vector<Node> segmentTree;

// Values the tree was built from; set by constructSegmentTree(). Not owned.
int *array;

/** Returns the larger of two ints. */
int max(int a, int b) {
    return (a > b) ? a : b;
}

/**
 * Merges the summaries of two adjacent segments (`left` immediately before
 * `right`) into the summary of their concatenation.
 */
static Node combine(const Node &left, const Node &right) {
    // A best prefix either stays inside `left` or swallows all of it and
    // continues into `right`; suffixes are symmetric.
    const int prefix = max(left.lMax, left.sum + right.lMax);
    const int suffix = max(left.rMax + right.sum, right.rMax);
    // The best run is in one half, or straddles the boundary.
    const int best = max(left.anyMax,
                         max(left.rMax + right.lMax, right.anyMax));
    return Node(prefix, best, suffix, left.sum + right.sum);
}

/**
 * Recursively fills segmentTree[stIndex] with the summary of
 * array[arrayLo..arrayHi] (inclusive bounds).
 *
 * Requires arrayLo <= arrayHi and segmentTree to be large enough already.
 */
void segmentTreeConstructor(int arrayLo, int arrayHi, int stIndex) {
    if (arrayLo == arrayHi) {
        const int value = array[arrayLo];
        segmentTree[stIndex] = Node(value, value, value, value);
        return;
    }

    const int mid = (arrayLo + arrayHi) / 2;
    const int leftIndex = stIndex * 2 + 1;
    const int rightIndex = stIndex * 2 + 2;

    segmentTreeConstructor(arrayLo, mid, leftIndex);
    segmentTreeConstructor(mid + 1, arrayHi, rightIndex);
    segmentTree[stIndex] = combine(segmentTree[leftIndex], segmentTree[rightIndex]);
}

/**
 * Returns the summary of array[qLow..qHi], where node `stIndex` covers
 * array[currentLo..currentHi].
 *
 * Precondition: currentLo <= qLow <= qHi <= currentHi. Violating it makes the
 * recursion run past the leaves, so callers must validate the range first.
 *
 * The result is returned by value: the previous version returned a reference
 * to a heap-allocated Node that was never released.
 */
Node find(int stIndex, int qLow, int qHi, int currentLo, int currentHi) {
    if (qLow == currentLo && qHi == currentHi) {
        return segmentTree[stIndex];
    }

    const int mid = (currentLo + currentHi) / 2;
    const int leftIndex = 2 * stIndex + 1;
    const int rightIndex = 2 * stIndex + 2;

    if (qHi <= mid) {
        return find(leftIndex, qLow, qHi, currentLo, mid);
    }
    if (qLow > mid) {
        return find(rightIndex, qLow, qHi, mid + 1, currentHi);
    }

    // The query straddles the midpoint: split it there and merge the halves.
    const Node leftResult = find(leftIndex, qLow, mid, currentLo, mid);
    const Node rightResult = find(rightIndex, mid + 1, qHi, mid + 1, currentHi);
    return combine(leftResult, rightResult);
}

/**
 * Builds the segment tree over array[0..size-1].
 *
 * The global `array` pointer is pointed at the argument so that the tree is
 * built from the data the caller actually passed in. The caller keeps
 * ownership of the buffer, which is only read during this call.
 * A null pointer or non-positive size leaves the tree empty.
 */
void constructSegmentTree(int* array, int size) {
    ::array = array;
    segmentTree.clear();
    if (array == NULL || size <= 0) {
        return;
    }

    // Smallest power of two >= size, computed with integers instead of
    // log()/pow(); twice that many slots holds every index the build uses.
    std::vector<Node>::size_type leaves = 1;
    while (leaves < static_cast<std::vector<Node>::size_type>(size)) {
        leaves *= 2;
    }
    segmentTree.assign(2 * leaves, Node(0, 0, 0, 0));

    segmentTreeConstructor(0, size - 1, 0);
}

/**
 * Reads: N, then N integers, then Q, then Q lines "l h" (1-based, inclusive).
 * Prints the maximum sub-segment sum of [l, h] for each query.
 * Returns 1 on malformed input, 0 otherwise.
 */
int main() {
    int nElements;
    if (scanf("%d", &nElements) != 1 || nElements <= 0) {
        return 1;
    }

    std::vector<int> values(nElements);
    for (int i = 0; i < nElements; ++i) {
        if (scanf("%d", &values[i]) != 1) {
            return 1;
        }
    }
    constructSegmentTree(&values[0], nElements);

    int nQueries;
    if (scanf("%d", &nQueries) != 1) {
        return 1;
    }

    while (nQueries-- > 0) {
        int l, h;
        if (scanf("%d %d", &l, &h) != 2) {
            return 1;
        }
        // find() requires a well-formed range inside the array.
        if (l < 1 || h > nElements || l > h) {
            fprintf(stderr, "query [%d, %d] is outside 1..%d; skipped\n", l, h, nElements);
            continue;
        }
        const Node result = find(0, l - 1, h - 1, 0, nElements - 1);
        printf("%d\n", result.anyMax);
    }
    return 0;
}
C++ GSS3 Code:
#include <stdio.h>
#include <math.h>  // retained from the original listing; sizing is now integer-only
#include <vector>

/**
 * Summary of one contiguous segment of the input array.
 *
 * Two adjacent summaries can be merged in O(1) (see combine()), which is what
 * lets the segment tree answer "maximum sub-segment sum" queries and absorb
 * point updates.
 */
class Node {
public:
    int lMax;    // best sum of a non-empty prefix of the segment
    int rMax;    // best sum of a non-empty suffix of the segment
    int anyMax;  // best sum of any non-empty contiguous run in the segment
    int sum;     // sum of every element of the segment

    // Argument order is (prefix, any, suffix, total): `a` precedes `r`, which
    // differs from the order in which the members are declared.
    Node(int l, int a, int r, int s) : lMax(l), rMax(r), anyMax(a), sum(s) {}

    // Debug helper: despite the name, this prints the node to stdout as
    // "[lMax rMax anyMax sum]" rather than returning a string.
    void toString() const {
        printf("[%d %d %d %d]\n", lMax, rMax, anyMax, sum);
    }
};

// Implicit binary tree: the children of node i are stored at 2*i+1 and 2*i+2.
// Nodes are held by value so nothing has to be freed by hand.
std::vector<Node> segmentTree;

// Values the tree was built from; set by constructSegmentTree(). Not owned,
// and not kept in sync by modify() -- only the tree is updated.
int *array;

/** Returns the larger of two ints. */
int max(int a, int b) {
    return (a > b) ? a : b;
}

/**
 * Merges the summaries of two adjacent segments (`left` immediately before
 * `right`) into the summary of their concatenation.
 */
static Node combine(const Node &left, const Node &right) {
    // A best prefix either stays inside `left` or swallows all of it and
    // continues into `right`; suffixes are symmetric.
    const int prefix = max(left.lMax, left.sum + right.lMax);
    const int suffix = max(left.rMax + right.sum, right.rMax);
    // The best run is in one half, or straddles the boundary.
    const int best = max(left.anyMax,
                         max(left.rMax + right.lMax, right.anyMax));
    return Node(prefix, best, suffix, left.sum + right.sum);
}

/**
 * Recursively fills segmentTree[stIndex] with the summary of
 * array[arrayLo..arrayHi] (inclusive bounds).
 *
 * Requires arrayLo <= arrayHi and segmentTree to be large enough already.
 */
void segmentTreeConstructor(int arrayLo, int arrayHi, int stIndex) {
    if (arrayLo == arrayHi) {
        const int value = array[arrayLo];
        segmentTree[stIndex] = Node(value, value, value, value);
        return;
    }

    const int mid = (arrayLo + arrayHi) / 2;
    const int leftIndex = stIndex * 2 + 1;
    const int rightIndex = stIndex * 2 + 2;

    segmentTreeConstructor(arrayLo, mid, leftIndex);
    segmentTreeConstructor(mid + 1, arrayHi, rightIndex);
    segmentTree[stIndex] = combine(segmentTree[leftIndex], segmentTree[rightIndex]);
}

/**
 * Returns the summary of array[qLow..qHi], where node `stIndex` covers
 * array[currentLo..currentHi].
 *
 * Precondition: currentLo <= qLow <= qHi <= currentHi. Violating it makes the
 * recursion run past the leaves, so callers must validate the range first.
 *
 * The result is returned by value: the previous version returned a reference
 * to a heap-allocated Node that was never released.
 */
Node find(int stIndex, int qLow, int qHi, int currentLo, int currentHi) {
    if (qLow == currentLo && qHi == currentHi) {
        return segmentTree[stIndex];
    }

    const int mid = (currentLo + currentHi) / 2;
    const int leftIndex = 2 * stIndex + 1;
    const int rightIndex = 2 * stIndex + 2;

    if (qHi <= mid) {
        return find(leftIndex, qLow, qHi, currentLo, mid);
    }
    if (qLow > mid) {
        return find(rightIndex, qLow, qHi, mid + 1, currentHi);
    }

    // The query straddles the midpoint: split it there and merge the halves.
    const Node leftResult = find(leftIndex, qLow, mid, currentLo, mid);
    const Node rightResult = find(rightIndex, mid + 1, qHi, mid + 1, currentHi);
    return combine(leftResult, rightResult);
}

/**
 * Sets element `target` (0-based) to `newVal` and refreshes every summary on
 * the path from its leaf back up to node `stIndex`, which covers
 * array[currentLo..currentHi].
 *
 * A target outside [currentLo, currentHi] is ignored; without this guard the
 * descent would never reach a matching leaf and would recurse without end.
 */
void modify(int stIndex, int target, int currentLo, int currentHi, int newVal) {
    if (target < currentLo || target > currentHi) {
        return;
    }
    if (currentLo == currentHi) {
        // The guard above guarantees this leaf is the target element.
        segmentTree[stIndex] = Node(newVal, newVal, newVal, newVal);
        return;
    }

    const int mid = (currentLo + currentHi) / 2;
    const int leftIndex = 2 * stIndex + 1;
    const int rightIndex = 2 * stIndex + 2;

    if (target <= mid) {
        modify(leftIndex, target, currentLo, mid, newVal);
    } else {
        modify(rightIndex, target, mid + 1, currentHi, newVal);
    }

    // One child changed, so this node's summary must be recomputed.
    segmentTree[stIndex] = combine(segmentTree[leftIndex], segmentTree[rightIndex]);
}

/**
 * Builds the segment tree over array[0..size-1].
 *
 * The global `array` pointer is pointed at the argument so that the tree is
 * built from the data the caller actually passed in. The caller keeps
 * ownership of the buffer, which is only read during this call.
 * A null pointer or non-positive size leaves the tree empty.
 */
void constructSegmentTree(int* array, int size) {
    ::array = array;
    segmentTree.clear();
    if (array == NULL || size <= 0) {
        return;
    }

    // Smallest power of two >= size, computed with integers instead of
    // log()/pow(); twice that many slots holds every index the build uses.
    // (The original also stored this count in an undeclared variable `n`,
    // which stopped the program from compiling; nothing read it.)
    std::vector<Node>::size_type leaves = 1;
    while (leaves < static_cast<std::vector<Node>::size_type>(size)) {
        leaves *= 2;
    }
    segmentTree.assign(2 * leaves, Node(0, 0, 0, 0));

    segmentTreeConstructor(0, size - 1, 0);
}

/**
 * Reads: N, then N integers, then Q, then Q lines "type a b".
 *   type == 1 : print the maximum sub-segment sum of [a, b] (1-based, inclusive)
 *   otherwise : set element a (1-based) to the value b
 * Returns 1 on malformed input, 0 otherwise.
 */
int main() {
    int nElements;
    if (scanf("%d", &nElements) != 1 || nElements <= 0) {
        return 1;
    }

    std::vector<int> values(nElements);
    for (int i = 0; i < nElements; ++i) {
        if (scanf("%d", &values[i]) != 1) {
            return 1;
        }
    }
    constructSegmentTree(&values[0], nElements);

    int nQueries;
    if (scanf("%d", &nQueries) != 1) {
        return 1;
    }

    while (nQueries-- > 0) {
        int type, l, h;
        if (scanf("%d %d %d", &type, &l, &h) != 3) {
            return 1;
        }

        if (type == 1) {
            // find() requires a well-formed range inside the array.
            if (l < 1 || h > nElements || l > h) {
                fprintf(stderr, "query [%d, %d] is outside 1..%d; skipped\n", l, h, nElements);
                continue;
            }
            const Node result = find(0, l - 1, h - 1, 0, nElements - 1);
            printf("%d\n", result.anyMax);
        } else {
            if (l < 1 || l > nElements) {
                fprintf(stderr, "update index %d is outside 1..%d; skipped\n", l, nElements);
                continue;
            }
            modify(0, l - 1, 0, nElements - 1, h);
        }
    }
    return 0;
}
No comments:
Post a Comment