NC14683 储物点距离
NC14683 储物点距离
题目
题目描述
一个数轴,每一个储物点会有一些东西,同时它们之间存在距离。
每次给个区间 \([l,r]\) ,查询把这个区间内所有储物点的东西运到另外一个储物点的代价是多少?
比如储物点 \(i\) 有 \(x\) 个东西,要运到储物点 \(j\) ,代价为 \(x \cdot dist( i , j )\)
\(dist\) 就是储物点间的距离。
输入描述
第一行两个数表示 \(n\) , \(m\)第二行 \(n-1\) 个数,第i个数表示第 \(i\) 个储物点与第 \(i+1\) 个储物点的距离 \(a_i\) 第三行 \(n\) 个数,表示每个储物点的东西个数 \(b_i\) 之后 \(m\) 行每行三个数 \(x\) , \(l\) , \(r\)。
表示查询要把区间 \([l,r]\) 储物点的物品全部运到储物点 \(x\) 的花费
每次查询独立
输出描述
对于每个询问输出一个数表示答案
答案对 \(1000000007\) 取模
示例1
输入
5 5 2 3 4 5 1 2 3 4 5 1 1 5 3 1 5 2 3 3 3 3 3 1 5 5
输出
125 72 9 0 70
备注
对于\(100\%\) 的数据\(n , m \leq 200000\) , \(0 \leq a_i,b_i \leq 2000000000\) 。
题解
思路
考虑三种情况 \(x \geq r\) , \(x \leq l\) , \(l < x < r\) 的花费 \(P_{l,r}\) 。
\(dist\) 的性质:\(dist(i,j) = dist(1,j)-dist(1,i)\)
-
\(x \geq r\)
\(P_{l,r} = \sum_{i=l}^{r} b_i \cdot dist(x,i) = \sum_{i=l}^r b_i \cdot (dist(1,x) - dist(1,i)) = dist(1,x) \sum_{i=l}^r b_i - \sum_{i=l}^r b_i \cdot dist(1,i)\)
-
\(x \leq l\)
\(P_{l,r} = \sum_{i=l}^{r} b_i \cdot dist(x,i) = \sum_{i=l}^r b_i \cdot (dist(1,i) - dist(1,x)) = \sum_{i=l}^r b_i \cdot dist(1,i) - dist(1,x) \sum_{i=l}^r b_i\)
-
\(l \leq x \leq r\)
\(\begin{aligned} P_{l,r} &= \sum_{i=l}^{r} b_i \cdot dist(x,i)\\ &= \sum_{i=l}^x b_i \cdot (dist(1,x) - dist(1,i)) + \sum_{i=x+1}^r b_i \cdot (dist(1,i) - dist(1,x))\\ &= dist(1,x) \sum_{i=l}^x b_i - \sum_{i=l}^x b_i \cdot dist(1,i) + \sum_{i=x+1}^r b_i \cdot dist(1,i) - dist(1,x) \sum_{i=x+1}^r b_i \end{aligned}\)
我们发现其中反复用到 \(dist(1,i)\) , \(\sum_{i=a}^b b_i\),\(\sum_{i=a}^b b_i \cdot dist(1,i)\)。
\(dist(1,i)\) 可以用 \(a_i\) 前缀和维护;\(\sum_{i=a}^b b_i\) 可以用 \(b_i\) 前缀和维护;\(\sum_{i=a}^b b_i \cdot dist(1,i)\) 可以在 \(a_i\) 前缀和的基础上在计算 \(b_i\) 前缀和的过程维护。
代码
#include <bits/stdc++.h> using namespace std; const int mod = 1000000007; long long a[200007],b[200007],ab[200007]; int main(){ std::ios::sync_with_stdio(0),cin.tie(0),cout.tie(0); int n,m; cin>>n>>m; for(int i = 2;i<=n;i++){ cin>>a[i]; a[i] = (a[i] + a[i-1])%mod; } for(int i = 1;i<=n;i++){ cin>>b[i]; ab[i] = (ab[i-1]+(a[i]%mod*b[i]%mod)%mod)%mod; b[i] = (b[i] + b[i-1])%mod; } for(int i = 0;i<m;i++){ int x,l,r; cin>>x>>l>>r; if(x<=l){ cout<<( (ab[r]-ab[l-1]+mod)%mod - (b[r]-b[l-1]+mod)%mod*a[x]%mod%mod + mod )%mod <<'\n'; } else if(x>=r){ cout<<( (b[r]-b[l-1]+mod)%mod*a[x]%mod%mod - (ab[r]-ab[l-1]+mod)%mod + mod )%mod <<'\n'; } else{ cout<<( (ab[r]-ab[x-1]+mod)%mod - (b[r]-b[x-1]+mod)%mod*a[x]%mod%mod + (b[x]-b[l-1]+mod)%mod*a[x]%mod%mod - (ab[x]-ab[l-1]+mod)%mod + 2*mod )%mod <<'\n'; } } return 0; }