斜率優化
見 https://drive.google.com/file/d/1w4Lnxy5OuNN1rJ8nz9nBqakPGhS40g6B/view
介紹
斜率最佳化是一類利用決策單調性來最佳化DP轉移的最佳化方式。因為其原理中一部分形似斜率,故名為斜率最佳化。能用斜率優化的題目的轉移式一般來說如下:
\[\large dp(i)=\max\limits_{0\le j < i} \{ a(j) \times f(i) + b(j) \}\]
轉移 \(dp(i)\) 時想像成二維平面上有一堆直線 \(y = a(j) \times x + b(j)\) 。要找到這些直線和 \(x = f(i)\) 的所有交點中,\(y\) 座標最大的數值。
觀察這些直線可以發現,這些直線所形成的下凸包,會是轉移答案的位置。有些線段(虛線)不在下凸包的上,可以從轉移名單上淘汰。
維護這個凸包,就可以直接查詢在 \(x = f(i)\) ,最大的值是多少。所以我們要解決的問題就是
如何快速查詢凸包
加入新直線後如何維護這個凸包
以下,我們針對斜率與查詢的單調性分別討論幾種情況的處理方式。
斜率與查詢單調
加入新直線
\(L_1\) 表示當前斜率次大的直線,\(L_2\) 表示當前斜率次大的直線,\(L_3\) 表示當前要加入的直線,有 \(L_1.a \le L_2.a \le L_3.a\) 。
可以觀察到若 \(L_2,L_3\) 的交點(紅色)在 \(L_1,L_2\) 的交點(藍色)的左側,\(L_2\) 將會被刪掉。實作上使用一個 deque 按照斜率小到大儲存在凸包上的直線,加入新直線時查看 deque 尾端直線是否會被淘汰。
如果上面的圖片還是無法理解,這裡是網路上的動圖 。將式子列出來後,我們就可以寫出 check\((L_1,L_2,L_3)\) 的代碼 :
code
bool check ( Line l1 , Line l2 , Line l3 ) {
return ( l2 . b - l1 . b ) * ( l2 . a - l3 . a ) >= ( l3 . b - l2 . b ) * ( l1 . a - l2 . a );
}
查詢 x = f(i)
我們從 deque 的 front 每次看最前面的兩條線,若發現代入斜率大(藍色)的會比代入斜率小(紅色)的還大代表要往右,否則左邊的就是答案
因為詢問的位置 \(x\) 只會越來越大,因此被淘汰的直線必不會是後面的詢問的答案,所以我們可以 pop_front 直到屬於找到當前代入 \(x=f(i)\) 最大的那條線
總結
維護一個 deque<pair<int, int>>
代表直線的 (a, b)。先找 x = f(i) 的答案 dp(i),所以我們一直去判斷最前面的值是否為最大值 (和第二個比較),若不是則持續 pop_front。然後即可求出當前新直線的 a, b。然後我們就要加入這條新直線,用 check\((L_1,L_2,L_3)\) 一直去判斷尾端直線是否要被 pop_back。架構等價於單調隊列,轉移總複雜度均攤為 \(O(n)\) 。實作見下面 CSES Monster Game I 的 code。
當斜率一樣時,會發生什麼事?
若斜率相同時,b 較大的應該要留下來,b 較小的需被淘汰。因此我下面的寫法當新的線加入後,會先特判斜率是否與 deque 的尾端相同,是的話就只留下 b 大的。但在 CSES - Monster Game I 這題很多人是沒有特判的,而是直接使用以下代碼來看是否淘汰尾端:
bool check ( Line l1 , Line l2 , Line l3 ) {
return ( l3 . b - l2 . b ) * ( l1 . a - l2 . a ) <= ( l2 . b - l1 . b ) * ( l2 . a - l3 . a );
}
這個式子是將我們的不等式移向得到的,但注意到任兩條直線斜率相同時,我們未移向的不等式的分母會是 0,那移向後是否還具有正確性?
我們就分三種情況討論 \(L_1, L_2\) 斜率相同,\(L_2, L_3\) 斜率相同,\(L_1, L_2, L_3\) 斜率相同。前兩種情況我們會發現,移向過後的式子其中一側變為 0,所以關鍵是在另一側,也就是看兩條線的節距是正是負。以 \(L_2, L_3\) 斜率相同來說,右側會變為 0,所以只剩下 (l3.b - l2.b) * (l1.a - l2.a) <= 0,又 (l1.a - l2.a) 必為負數,所以不等式的成立一切取決於 (l3.b - l2.b) 的正負性。而三條線斜率都相同情況我們的式子一定都是 0 <= 0,所以會成立,這時若 l1.b, l2.b, l3.b 非遞增的話,維護就會出現問題。但在 CSES - Monster Game I 內,b = dp(j),恰好 dp(j) 又是單調遞增的(這是因為 s 遞增),因此才會 AC,若今天 b 沒有限制的話就會出問題。
總結來說,把斜率相同的情況特判掉是比較安全的處理方式。
題目
CSES - Monster Game I
給 \(n\) 個怪獸,你必須打敗第 \(n\) 隻怪獸才能贏。打敗第 \(i\) 隻怪獸會花 \(s_i\times f_j\) 的時間,其中 \(j\) 為你上次打敗的怪獸的編號,如果沒有上一隻,則 \(f_j=x\) 。最少花多少時間可以贏
\(n\le 2\times 10^5,1\le x\le 10^6,1 \le s_1 \le \dots \le s_n \le 10^6,x \ge f_1 \ge \dots \ge f_n \ge 1\)
思路
\[dp(i)=\min \limits_{0\le j < i} \{f_j\times s_i + dp(j) \}\]
我們可以將式子改成 :
\[dp(i)=\max \limits_{0\le j < i} \{(-f_j)\times s_i + dp(j) \}\]
這樣就變成上面標準斜率優化的轉移式了,\(a=-f_j,b=dp(j)\) ,最後的答案記得是 \(-dp[n]\)
code
#include <bits/stdc++.h>
#define int long long
#define pii pair<int, int>
#define pb push_back
using namespace std ;
int n ;
int dp [ 200005 ], f [ 200005 ], s [ 200005 ];
struct Line {
int a , b ;
int get_value ( int x ) {
return a * x + b ;
}
};
bool check ( Line l1 , Line l2 , Line l3 ) {
return ( l2 . b - l1 . b ) * ( l2 . a - l3 . a ) >= ( l3 . b - l2 . b ) * ( l1 . a - l2 . a );
}
void solve () {
deque < Line > dq ;
dq . pb ({ f [ 0 ], dp [ 0 ]});
for ( int i = 1 ; i <= n ; i ++ ) {
// 刪掉過期的直線
while ( dq . size () >= 2 && dq [ 0 ]. get_value ( s [ i ]) < dq [ 1 ]. get_value ( s [ i ])) {
dq . pop_front ();
}
dp [ i ] = dq [ 0 ]. get_value ( s [ i ]);
Line l = { f [ i ], dp [ i ]};
// 加入新的直線, 並且檢查會不會把舊的直線給蓋住
while ( dq . size () >= 2 && check ( dq [ dq . size () - 2 ], dq . back (), l )) {
dq . pop_back ();
}
dq . push_back ( l );
}
}
signed main () {
ios :: sync_with_stdio ( 0 );
cin . tie ( 0 );
cin >> n >> f [ 0 ];
for ( int i = 1 ; i <= n ; i ++ )
cin >> s [ i ];
for ( int i = 1 ; i <= n ; i ++ )
cin >> f [ i ];
for ( int i = 0 ; i <= n ; i ++ )
f [ i ] = - f [ i ];
solve ();
cout << - dp [ n ] << '\n' ;
}
TOI 2022 二模 pD. rectangle
給你 \(N\) 個矩形,第 \(i\) 個個矩形有參數高 \(H_i\) 、寬 \(W_i\) 、亮度 \(D_i\) 、成本參數 \(C_i ,F_i\) 。
其中 \(H\) 是單調遞增的,覆蓋 \([L ,R]\) 的矩形需要成本
\[C_R \frac{\sum_{i = L}^{R}D_iW_i}{\sum_{i = L}^{R}W_i}\sum_{i = L}^{R}W_i + F_R\]
問最小覆蓋成本。
\(1 \leq N \leq 2 \times 10^5, 1 \leq W_i, H_i \leq 10^6, -16 \leq D_i \leq 16 , -10^6 \leq C_i , F_i \leq 10^6\)
吐鈔機 2
有 \(N\) 台機器,其中第 \(i\) 台會在時間 \(D_i\) 拍賣,價格為 \(P_i\) ,此機器每天會生產 \(G_i\) 元,且將這台機器轉賣可得 \(R_i\) 元,一個時間只能擁有一台機器,你一開始有 \(C_i\) 元,求在第 \(D\) 天後你最多可以賺進多少錢
\(N\le 10^5,D\le 10^9\)
變化 : 斜率會過期
若 \(L_2,L_3\) 的交點在 \(L_1,L_2\) 的交點的左側,且 \(L_1\) 過期右界在 \(L_2,L_3\) 的右側,\(L_2\) 將會被刪掉
變化 : 缺少查詢單調
二分搜 \(x=f(i)\)
code
struct Line { // ax+b;
int a ;
int b ;
int get_value ( int x ) const {
return a * x + b ;
}
};
long double intersection_x ( Line f , Line g ) {
// a1 x + b1 = a2 x + b2
// x = (b2-b1) / (a1-a2)
return 1.0 * ( g . b - f . b ) / ( f . a - g . a );
}
struct LineContainer {
vector < Line > lines ;
bool ok ( Line f , Line g , Line h ) {
// 判斷 f, g 的交點是否在 f,h 的交點左邊
return intersection_x ( f , g ) < intersection_x ( g , h );
}
void insert ( Line l ) {
// 假設 insert 的斜率遞增
int m = lines . size ();
if ( m >= 1 && lines [ m - 1 ]. a == l . a ) {
if ( lines [ m - 1 ]. b >= l . b ) return ;
lines . pop_back ();
m -- ;
}
while ( m >= 2 && ! ok ( lines [ m - 2 ], lines [ m - 1 ], l )) {
lines . pop_back ();
m -- ;
}
lines . push_back ( l );
}
int get_max ( int x ) {
// 找到第一條直線 lines[i]
// lines[i] 和 lines[i+1] 的交點大於等於 x
int m = lines . size ();
if ( m == 1 ) {
return lines [ 0 ]. get_value ( x );
}
int l = 0 , r = m - 2 ;
while ( l != r ) {
int mid = ( l + r ) / 2 ;
int p = intersection_x ( lines [ mid ], lines [ mid + 1 ]);
if ( p >= x ) {
r = mid ;
} else {
l = mid + 1 ;
}
}
return max ( lines [ r ]. get_value ( x ), lines [ m - 1 ]. get_value ( x ));
}
};
不具單調性
動態凸包
加入新直線
用一個 set 維護當前在凸包上的直線,按照斜率由小到大儲存,當要新增一條新的直線時,先直接放入 set 內,和位於該直線前後的直線用 check 判斷需不需要被 pop 掉。每條直線最多進去和出來 set 一次,每次花費 \(O(\log n)\) 的時間維護,轉移總複雜度為 \(O(n\log n)\) 。
查詢 x = f(i)
直接二分搜 \(x=f(i)\)
code
struct Line { // ax+b;
int a ;
int b ;
int get_value ( int x ) const {
return a * x + b ;
}
};
long double intersection_x ( Line f , Line g ) {
// a1 x + b1 = a2 x + b2
// x = (b2-b1) / (a1-a2)
return 1.0 * ( g . b - f . b ) / ( f . a - g . a );
}
struct LineContainer {
vector < Line > lines ;
bool ok ( Line f , Line g , Line h ) {
// 判斷 f, g 的交點是否在 f,h 的交點左邊
return intersection_x ( f , g ) < intersection_x ( g , h );
}
void insert ( Line l ) {
// 假設 insert 的斜率遞增
int m = lines . size ();
if ( m >= 1 && lines [ m - 1 ]. a == l . a ) {
if ( lines [ m - 1 ]. b >= l . b ) return ;
lines . pop_back ();
m -- ;
}
while ( m >= 2 && ! ok ( lines [ m - 2 ], lines [ m - 1 ], l )) {
lines . pop_back ();
m -- ;
}
lines . push_back ( l );
}
int get_max ( int x ) {
// 找到第一條直線 lines[i]
// lines[i] 和 lines[i+1] 的交點大於等於 x
int m = lines . size ();
if ( m == 1 ) {
return lines [ 0 ]. get_value ( x );
}
int l = 0 , r = m - 2 ;
while ( l != r ) {
int mid = ( l + r ) / 2 ;
int p = intersection_x ( lines [ mid ], lines [ mid + 1 ]);
if ( p >= x ) {
r = mid ;
} else {
l = mid + 1 ;
}
}
return max ( lines [ r ]. get_value ( x ), lines [ m - 1 ]. get_value ( x ));
}
};
題目
CSES - Monster Game II
給 \(n\) 個怪獸,你必須打敗第 \(n\) 隻怪獸才能贏。打敗第 \(i\) 隻怪獸會花 \(s_i\times f_j\) 的時間,其中 \(j\) 為你上次打敗的怪獸的編號,如果沒有上一隻,則 \(f_j=x\) 。最少花多少時間可以贏
\(n\le 2\times 10^5,1\le x\le 10^6,1\le s_i,f_i\le 10^6\)
code
#include <algorithm>
#include <iostream>
#include <vector>
using namespace std ;
#define int long long
struct Line {
int a , b ;
int operator ()( int x ) const {
return a * x + b ;
}
};
struct LineContainer {
static constexpr int LIMIT = 1e6 ;
static constexpr int SIZE = LIMIT * 4 ;
static const int INF = 1e18 ;
vector < int > lo = vector < int > ( SIZE );
vector < int > hi = vector < int > ( SIZE );
vector < Line > seg = vector < Line > ( SIZE , { 0 , INF });
void build ( int i = 1 , int l = 1 , int r = LIMIT ) {
lo [ i ] = l ;
hi [ i ] = r ;
if ( l == r ) return ;
int mid = ( l + r ) / 2 ;
build ( 2 * i , l , mid );
build ( 2 * i + 1 , mid + 1 , r );
}
void insert ( Line L , int i = 1 ) {
int l = lo [ i ], r = hi [ i ];
if ( l == r ) {
if ( L ( l ) < seg [ i ]( l )) seg [ i ] = L ;
return ;
}
int mid = ( l + r ) / 2 ;
if ( seg [ i ]. a < L . a ) swap ( seg [ i ], L );
if ( seg [ i ]( mid ) > L ( mid )) {
swap ( seg [ i ], L );
insert ( L , 2 * i );
} else {
insert ( L , 2 * i + 1 );
}
}
int query ( int x , int i = 1 ) {
int l = lo [ i ], r = hi [ i ];
if ( l == r ) return seg [ i ]( x );
int mid = ( l + r ) / 2 ;
if ( x <= mid ) {
return min ( seg [ i ]( x ), query ( x , 2 * i ));
} else {
return min ( seg [ i ]( x ), query ( x , 2 * i + 1 ));
}
}
};
int solve ( int n , int x , const vector < int > & s , const vector < int > & f ) {
LineContainer ds ;
ds . build ();
ds . insert ({ x , 0 });
for ( int i = 0 ; i < n - 1 ; i ++ ) {
int v = ds . query ( s [ i ]);
ds . insert ({ f [ i ], v });
}
return ds . query ( s [ n - 1 ]);
}
signed main () {
cin . tie ( 0 );
cin . sync_with_stdio ( 0 );
int n , X ;
cin >> n >> X ;
vector < int > s ( n ), f ( n );
for ( int i = 0 ; i < n ; i ++ ) cin >> s [ i ];
for ( int i = 0 ; i < n ; i ++ ) cin >> f [ i ];
int ans = solve ( n , X , s , f );
cout << ans << '\n' ;
}
洛谷 P4097. 模板】李超线段树 / [HEOI2013] Segment
有 \(n\) 個石頭,第 \(i\) 個石頭的高度是 \(h_i\) ,目標從第 1 個石頭跳到第 n 個石頭。若現在位於第 \(j\) 個石頭,可以跳到任何一個編號大於 \(j\) 的石頭 \(i\) ,但需要花費 \((h_i-h_j)^2 + c\) 能量。請問最少需要花費少能量才能到達石頭 \(n\) ?