// Binary Tree Max Path Sum
// https://leetcode.com/problems/binary-tree-maximum-path-sum/
/*
A path in a binary tree is a sequence of nodes where each pair of adjacent nodes in the sequence has an edge connecting them.
A node can only appear in the sequence at most once.
Note that the path does not need to pass through the root.
The path sum of a path is the sum of the node's values in the path.
Given the root of a binary tree, return the maximum path sum of any non-empty path.
Ex1:
Input: root = [1,2,3]
Output: 6
Explanation: The optimal path is 2 -> 1 -> 3 with a path sum of 2 + 1 + 3 = 6.
Ex2:
Input: root = [-10,9,20,null,null,15,7]
Output: 42
Explanation: The optimal path is 15 -> 20 -> 7 with a path sum of 15 + 20 + 7 = 42.
*/
#include <algorithm>
#include <cassert>
#include <climits>
#include <iostream>
#include <vector>
using namespace std;
class TreeNode {
public:
    // Builds a node holding `v` with no children attached yet.
    TreeNode(int v) : val(v), left(nullptr), right(nullptr) {}

    int val;          // value stored at this node
    TreeNode* left;   // left child, or nullptr when absent
    TreeNode* right;  // right child, or nullptr when absent
};
// Post-order worker for maxPathSum.
// Returns the largest sum of a downward path that begins at `curr` (just
// curr->val when neither child branch adds anything positive), or 0 for a
// null node. Along the way it raises `res` to the best sum of any path whose
// highest node is `curr`, optionally extending into one or both subtrees.
int helper(TreeNode* curr, int& res) {
    if (curr == nullptr) {
        return 0;
    }
    // A branch with a negative best sum is never worth extending into, so
    // treat it as contributing nothing.
    const int leftGain = max(0, helper(curr->left, res));
    const int rightGain = max(0, helper(curr->right, res));

    // Best path that bends at `curr`.
    res = max(res, curr->val + leftGain + rightGain);

    // Only one branch may continue upward to the parent.
    return curr->val + max(leftGain, rightGain);
}
// Returns the maximum path sum over all non-empty paths in the tree rooted at
// `root`, or 0 when the tree is empty.
// Time: O(n). Space: O(h) recursion depth, where h is the tree height
// (O(n) for a degenerate, list-shaped tree).
int maxPathSum(TreeNode* root) {
    if (root == nullptr) {
        return 0;
    }
    int best = INT_MIN;  // any real path sum beats this sentinel
    helper(root, best);
    return best;
}
// Post-order worker for maxPathSum2.
//   curr    - subtree root (may be nullptr).
//   res     - running best path sum; raised when a path whose highest node is
//             `curr` beats it.
//   path    - out: the best downward path ending at `curr`, ordered from its
//             deepest node up to curr->val (left empty for a null node).
//   maxPath - out: node values of the best path seen so far, listed from one
//             end of the path to the other; rewritten whenever `res` improves.
// Returns the best downward path sum starting at `curr`, or 0 for a null node.
//
// A child branch is spliced into a path only when its gain is strictly
// positive. A non-positive branch contributes 0 to the sum, so it must also
// contribute no nodes; otherwise the reported path would not add up to the
// reported sum.
int helper2(TreeNode* curr, int& res, vector<int>& path, vector<int>& maxPath) {
    path.clear();
    if (!curr) return 0;

    vector<int> leftPath, rightPath;
    const int l = max(0, helper2(curr->left, res, leftPath, maxPath));
    const int r = max(0, helper2(curr->right, res, rightPath, maxPath));

    const int currMax = curr->val + l + r;
    if (currMax > res) {
        res = currMax;
        maxPath.clear();
        // leftPath runs deepest-left -> left child, which already walks
        // toward `curr`, so it is appended as-is.
        if (l > 0) {
            maxPath.insert(maxPath.end(), leftPath.begin(), leftPath.end());
        }
        maxPath.push_back(curr->val);
        // rightPath runs deepest-right -> right child; reverse it so the
        // sequence keeps walking away from `curr`.
        if (r > 0) {
            maxPath.insert(maxPath.end(), rightPath.rbegin(), rightPath.rend());
        }
    }

    // Only the heavier positive branch can continue through the parent.
    // swap() hands over the child's path without copying it.
    if (l > r) {
        path.swap(leftPath);
    } else if (r > 0) {
        path.swap(rightPath);
    }
    path.push_back(curr->val);
    return curr->val + max(l, r);
}
// Computes the maximum path sum of the tree rooted at `root` together with
// the node values of a path achieving it.
//   root   - tree to search; may be nullptr.
//   maxSum - out: the maximum path sum, or 0 for an empty tree (the same
//            convention maxPathSum uses, instead of leaking an INT_MIN
//            sentinel to the caller).
// Returns the node values of the best path as assembled by helper2, or an
// empty vector for an empty tree.
// Time: O(n * h) worst case, where h is the tree height, because the
// per-node path bookkeeping can touch up to h values. Space: O(n).
vector<int> maxPathSum2(TreeNode* root, int& maxSum) {
    vector<int> path, maxPath;
    if (!root) {
        maxSum = 0;
        return maxPath;
    }
    maxSum = INT_MIN;  // any real path sum beats this sentinel
    helper2(root, maxSum, path, maxPath);
    return maxPath;
}
int main() {
// TreeNode* root = new TreeNode(1);
// root->left = new TreeNode(2);
// root->right = new TreeNode(3);
// // cout << maxPathSum(root) << endl;
// assert(maxPathSum(root) == 6);
// root = new TreeNode(-10);
// root->left = new TreeNode(9);
// root->right = new TreeNode(20);
// root->right->left = new TreeNode(15);
// root->right->right = new TreeNode(7);
// assert(maxPathSum(root) == 42);
TreeNode* root = new TreeNode(-10);
root->left = new TreeNode(9);
root->right = new TreeNode(20);
root->right->left = new TreeNode(15);
root->right->right = new TreeNode(7);
int maxSum = 0;
vector<int> path = maxPathSum2(root, maxSum);
cout << "Max Path Sum: " << maxSum << endl;
cout << "Path: ";
for (int val : path) {
cout << val << " ";
}
cout << endl;
return 0;
}
// Last updated