问题描述
Say you have an array for which the ith element is the price of a given stock on day i.
Design an algorithm to find the maximum profit. You may complete at most k transactions.
Example
Given prices = [4,4,6,1,1,4,2,5], and k = 2, return 6.
简而言之,就是从prices数组中,取出最多2*k个数,相邻的每对(比如第1个数和第2个数,第3个数和第4个数),其总和要最大。
Note
You may not engage in multiple transactions at the same time (i.e., you must sell the stock before you buy again).
题意分析
这是一个DP (Dynamic Programming)问题。
解法1
t(n, k):表示前n天,进行最多k次交易,所能获得的最大收益。并且最后一次交易发生在第n天
那么有:
t(n, k) = max { t(m, k-1) + prices[n] - prices_min(m+1, n) }, 0 <= m <= n-1
t(n, k) = max { t(n, k), t(n, k-1) } // 只最多进行 k-1 次交易
其中 prices_min(m+1, n) 表示第m+1天到第n天之间的最低股票价格
这样的话,可以对m从n-1变换到0,进行扫描一次,并且更新prices_min,这样的话,要求得到t(n, k) 的复杂度为 o(n),因此总的时间复杂度为 O(k * n^2)
解法2
解法一的复杂度还是有点大,当k趋近于n的时候,复杂度可以达到 O(n^3),不太可以接受。因此我们需要一个更好的dp方法。
g(n, k)表示前n天,进行k次交易,所获得最大收益
l(n, k) 表示前n天,进行k次交易,且最后一天进行了交易(卖出股票),所获得最大收益
g(n, k) = max{ g(n-1, k), l(n, k) } // 前n天最大收益,要不最后一次交易在前n-1天内,要不发生在第n天
l(n, k) = max{ g(m, k-1) + prices[n] - prices[m+1] }, 0 <= m <= n-1
l(n, k) = max{ l(n, k), l(n, k-1) }
在计算 l(n, k) 的时候,需要遍历 0 ~ n-1,如果暴力的话,复杂度还是 O(n),最后的复杂度就是O(k * n^2)
那么这里有个优化算法:
设max_diff 为 max { g(m, k-1) - prices[m+1] }, 0 <= m <= n-1。在计算 l(0, k) ~ l(n, k) 的时候,不断更新max_diff,那么这样的话,其实就把计算 l(n,k) 的平摊复杂度降了下来,复杂度为 O(1)。那么总的时间复杂度为 O(kn)。
具体代码:
int maxProfit(int k, vector<int> &prices) {
// write your code here
int n = prices.size();
if (k > n/2) k = n/2;
if (n <= 1) return 0;
int *l = new int[n+1];
int *g = new int[n+1];
for (int i = 1; i <= k; i++) {
if (i == 1) {
int current_min = prices[0], max_profit = 0;
for (int j = 0; j <= n; j++) {
if (j == 0) { l[j] = g[j] = 0; continue; }
current_min = min(current_min, prices[j-1]);
max_profit = max(max_profit, prices[j-1]- current_min);
l[j] = prices[j-1] - current_min;
g[j] = max_profit;
}
// print(l, n+1); print(g, n+1); printf("---\n");
} else {
int max_diff = g[0] - prices[0];
l[0] = g[0] = 0;
for (int j = 1; j <= n; j++) {
max_diff = max(max_diff, g[j-1] - prices[j-1]);
// printf("j:%d, max_diff:%d\n ", j, max_diff);
l[j] = max(l[j], max_diff + prices[j-1]);
}
// cout << endl;
// update g
for (int j = 1; j <= n; j++) {
g[j] = max(g[j-1], l[j]);
}
// print(l, n+1); print(g, n+1); printf("---\n");
}
}
return g[n];
}