HDU 5296 Annoying problem
这是按照题解公式写的代码，LCA 用倍增法求；除了题解中的操作细节和公式以外，其余部分都是最基础的 LCA 写法。
// whn6325689
// Mr.Phoebe
// http://blog.csdn.net/u013007900
#include <algorithm>
#include <iostream>
#include <iomanip>
#include <cstring>
#include <climits>
#include <complex>
#include <fstream>
#include <cassert>
#include <cstdio>
#include <bitset>
#include <vector>
#include <deque>
#include <queue>
#include <stack>
#include <ctime>
#include <set>
#include <map>
#include <cmath>
#include <functional>
#include <numeric>
// Requests a large stack for the recursive DFS. NOTE: this pragma only takes
// effect with MSVC-style toolchains; GCC ignores it, so deep recursion there
// still uses the default stack size.
#pragma comment(linker, "/STACK:1024000000,1024000000")
using namespace std;
#define eps 1e-9
#define PI acos(-1.0)
#define INF 0x3f3f3f3f
// Parenthesised so that uses such as LLINF - 1 or LLINF / 2 group correctly.
#define LLINF (1LL<<62)
#define speed std::ios::sync_with_stdio(false);
typedef long long ll;
typedef unsigned long long ull;
typedef long double ld;
typedef pair<ll, ll> pll;
typedef complex<ld> point;
typedef pair<int, int> pii;
typedef pair<pii, int> piii;
typedef vector<int> vi;
#define CLR(x,y) memset(x,y,sizeof(x))
#define CPY(x,y) memcpy(x,y,sizeof(x))
#define clr(a,x,size) memset(a,x,sizeof(a[0])*(size))
#define cpy(a,x,size) memcpy(a,x,sizeof(a[0])*(size))
#define mp(x,y) make_pair(x,y)
#define pb(x) push_back(x)
// Lowest set bit of x. Every use of the argument is parenthesised so that
// expression arguments such as lowbit(a|b) are evaluated as a whole.
#define lowbit(x) ((x)&(-(x)))
// Midpoint of [x, y] without computing x + y. Fully parenthesised so that
// expression arguments such as MID(l+1, r) expand correctly.
#define MID(x,y) ((x)+(((y)-(x))>>1))
#define ls (idx<<1)
#define rs (idx<<1|1)
#define lson ls,l,mid
#define rson rs,mid+1,r
#define root 1,1,n
// Reads the next (optionally negative) decimal integer from stdin into n.
// Any non-digit, non-'-' characters before the number are skipped, and the
// character right after the number is consumed.
// Returns false, leaving n unchanged, if EOF is reached before a digit or '-'.
template<class T>
inline bool read(T &n)
{
    T x = 0, tmp = 1;
    // int, not char: getchar() returns an int so that EOF is distinguishable
    // from every byte value. Stored in a char, the EOF tests below never fire
    // when char is unsigned (infinite loop at end of input), and fire
    // spuriously on byte 0xFF when char is signed.
    int c = getchar();
    while((c < '0' || c > '9') && c != '-' && c != EOF) c = getchar();
    if(c == EOF) return false;
    if(c == '-') c = getchar(), tmp = -1;
    while(c >= '0' && c <= '9') x *= 10, x += (c - '0'), c = getchar();
    n = x*tmp;
    return true;
}
// Writes the integer n to stdout in decimal, with a leading '-' when it is
// negative. No trailing newline is written.
template <class T>
inline void write(T n)
{
    // Digits are taken from the remainders of n itself rather than from -n:
    // negating the most negative value of a signed type (e.g. LLONG_MIN)
    // overflows, which is undefined behaviour. Integer division truncates
    // toward zero, so for negative n each remainder lies in [-9, 0].
    char digits[40];
    int len = 0;
    const bool negative = n < 0;
    do
    {
        const int d = static_cast<int>(n % 10);
        digits[len++] = static_cast<char>('0' + (d < 0 ? -d : d));
        n /= 10;
    } while(n != 0);
    if(negative) putchar('-');
    while(len--) putchar(digits[len]);
}
//-----------------------------------
// Upper bound on the number of vertices (with a little slack).
const int MAXN = 100010;

// DFS preorder numbering: dfn[v] is the index of v, redfn[i] the vertex with index i.
int dfn[MAXN], redfn[MAXN];
// n: vertex count, q: query count, cnt: preorder counter advanced by dfs().
int n, q, cnt;
// Running answer printed after each query.
long long ans;

// One directed half of an undirected weighted edge in the linked adjacency lists.
struct Edge
{
    int to;    // endpoint of the edge
    int next;  // index of the next edge leaving the same vertex (-1 ends the list)
    int c;     // edge weight
};
Edge e[MAXN << 1];       // room for both directions of every tree edge
int head[MAXN], tot;     // head[u]: first edge index of u; tot: number of edges stored

// Per-vertex tree data filled in by dfs(): distance from the root, depth,
// and the binary-lifting table p[v][k] = 2^k-th ancestor of v.
int dis[MAXN], dep[MAXN];
int p[MAXN][22];

// Preorder indices of the currently chosen vertices, kept in ascending (s)
// and descending (ss) order so that upper_bound finds the next larger and
// the next smaller index respectively.
std::set<int> s;
std::set<int>::iterator it1;
std::set<int, std::greater<int> > ss;
std::set<int, std::greater<int> >::iterator it2;
// Resets all per-test-case state: clears the adjacency lists and both ordered
// sets, zeroes the edge count, preorder counter and answer, and prepares
// vertex 1 as the root (depth 0, distance 0, and its own parent so that
// ancestor jumps past the root stay at the root).
void init()
{
    memset(head, -1, sizeof(head));
    s.clear();
    ss.clear();
    tot = 0;
    cnt = 0;
    ans = 0;
    dis[1] = 0;
    dep[1] = 0;
    p[1][0] = 1;
}
void addedge(int u,int v,int c)
{
e[tot].to=v;
e[tot].c=c;
e[tot].next=head[u];
head[u]=tot++;
}
void dfs(int u,int fa=-1)
{
dfn[u]=++cnt;
redfn[cnt]=u;
int v;
for(int i=1;i<20;i++)
p[u][i]=p[p[u][i-1]][i-1];
for(int i=head[u];~i;i=e[i].next)
{
v=e[i].to;
if(fa==v) continue;
dep[v]=dep[u]+1;
dis[v]=dis[u]+e[i].c;
p[v][0]=u;
dfs(v,u);
}
}
// Lowest common ancestor of u and v via the binary-lifting table p.
int LCA(int u,int v)
{
    // Make v the deeper vertex, then lift it to u's depth bit by bit.
    if(dep[v] < dep[u])
        std::swap(u, v);
    int diff = dep[v] - dep[u];
    for(int k = 0; diff > 0; ++k, diff >>= 1)
        if(diff & 1)
            v = p[v][k];
    if(u == v)
        return u;
    // Lift both while their 2^k-th ancestors differ; they stop just below the LCA.
    for(int k = 19; k >= 0; --k)
    {
        if(p[u][k] != p[v][k])
        {
            u = p[u][k];
            v = p[v][k];
        }
    }
    return p[u][0];
}
int main()
{
freopen("data.txt","r",stdin);
int T,cas=1;
read(T);
while(T--)
{
init();
read(n),read(q);
int op,u,v,c;
for(int i=1;i<n;i++)
{
read(u),read(v),read(c);
addedge(u,v,c);addedge(v,u,c);
}
dfs(1);
printf("Case #%d:\n",cas++);
while(q--)
{
read(op),read(c);
if(op==1)
{
if(s.empty())
{
printf("0\n");
s.insert(dfn[c]);ss.insert(dfn[c]);
continue;
}
if(s.find(dfn[c])!=s.end())
{
printf("%I64d\n",ans);
continue;
}
it1=s.upper_bound(dfn[c]);
it2=ss.upper_bound(dfn[c]);
if(it1==s.end() || it2==ss.end())
u=*s.begin(),v=*ss.begin();
else
u=*it1,v=*it2;
u=redfn[u],v=redfn[v];
ans += dis[c] - dis[LCA(u,c)] -dis[LCA(v,c)] + dis[LCA(u,v)];
s.insert(dfn[c]);
ss.insert(dfn[c]);
}
else
{
if(s.find(dfn[c])==s.end())
{
printf("%I64d\n",ans);
continue;
}
s.erase(dfn[c]);ss.erase(dfn[c]);
if(s.empty())
{
printf("0\n");
continue;
}
it1=s.upper_bound(dfn[c]);
it2=ss.upper_bound(dfn[c]);
if(it1==s.end() || it2==ss.end())
u=*s.begin(),v=*ss.begin();
else
u=*it1,v=*it2;
u=redfn[u],v=redfn[v];
ans -= dis[c] - dis[LCA(u,c)] -dis[LCA(v,c)] + dis[LCA(u,v)];
}
printf("%I64d\n",ans);
}
}
return 0;
}
这份代码同样用倍增法求 LCA，但还借鉴了类似树链剖分的做法：按 DFS 序在树上建了一棵线段树，用来统计每棵子树中已选中的点数。
#include "iostream"
#include "cstring"
#include "cstdio"
#include "vector"
// Shorthands used throughout this program.
#define PB push_back
#define MP make_pair
#define F first
#define S second
using namespace std;
const int N = 100010;                    // capacity: maximum number of vertices (with slack)
vector<pair<int, int> > e[N];            // e[x]: (neighbour, edge weight) pairs of vertex x
int weight[N];                           // weight[x]: distance from the root (vertex 1) to x
int t[N << 2];                           // segment tree over preorder positions; each node sums the marks in its range
int f[N][21];                            // f[x][i]: 2^i-th ancestor of x, or -1 if it does not exist
int pos[N], lpos[N], rpos[N], level[N];  // pos[k]: vertex with preorder index k; [lpos[x], rpos[x]]: preorder range of x's subtree; level[x]: depth
int total;                               // preorder counter advanced by dfs()
// Zeroes every segment-tree node of the subtree rooted at idx, which covers [l, r].
void build(int l, int r, int idx)
{
    t[idx] = 0;
    if(l < r)
    {
        const int mid = (l + r) / 2;
        build(l, mid, 2 * idx);
        build(mid + 1, r, 2 * idx + 1);
    }
}
// Recursive preorder DFS from x (parent pa, -1 for the root; depth dep).
// Records depth, the ancestor table row f[x], the preorder position and
// subtree range of x, and the root distance of each child.
void dfs(int x, int pa, int dep)
{
    level[x] = dep;
    f[x][0] = pa;
    // f[x][k+1] = f[f[x][k]][k]; stop as soon as an ancestor is missing (-1).
    int anc = pa;
    for(int k = 0; anc != -1; ++k)
    {
        const int up = f[anc][k];
        if(up == -1)
            break;
        f[x][k + 1] = up;
        anc = up;
    }
    ++total;
    lpos[x] = total;
    pos[total] = x;
    for(std::size_t j = 0; j < e[x].size(); ++j)
    {
        const int to = e[x][j].first;
        if(to == pa)
            continue;
        weight[to] = weight[x] + e[x][j].second;
        dfs(to, x, dep + 1);
    }
    // Every vertex of x's subtree has now received a preorder index.
    rpos[x] = total;
}
// Adds v at position id within the segment-tree node idx covering [l, r],
// then recomputes the sums along the path back up.
void update(int l, int r, int id, int v, int idx)
{
    if(l == r)
    {
        t[idx] += v;
        return;
    }
    const int mid = (l + r) / 2;
    if(id <= mid)
        update(l, mid, id, v, 2 * idx);
    else
        update(mid + 1, r, id, v, 2 * idx + 1);
    t[idx] = t[2 * idx] + t[2 * idx + 1];
}
// Sum of positions [L, R] within node idx covering [l, r].
// Callers in this file always pass L <= R with [L, R] inside [1, n].
int query(int l, int r, int L, int R, int idx)
{
    if(L <= l && r <= R)
        return t[idx];
    const int mid = (l + r) / 2;
    if(L > mid)
        return query(mid + 1, r, L, R, 2 * idx + 1);
    if(R <= mid)
        return query(l, mid, L, R, 2 * idx);
    const int leftPart = query(l, mid, L, mid, 2 * idx);
    const int rightPart = query(mid + 1, r, mid + 1, R, 2 * idx + 1);
    return leftPart + rightPart;
}
// Returns the ancestor of x that lies `step` levels above it (x itself for step 0),
// decomposing step into powers of two. Callers keep step <= level[x]; larger
// steps walk past the root, where f entries are -1.
int jump(int x, int step)
{
    for(int bit = 0; step != 0; ++bit, step >>= 1)
        if(step & 1)
            x = f[x][bit];
    return x;
}
// Vertex count of the current test case; also the right end of the segment tree's range [1, n].
int n;
int findlca(int x, int num)
{
int l = 0, r = level[x];
while(l < r){
int m = (l + r) / 2;
int u = jump(x, m);
if(query(1, n, lpos[u], rpos[u], 1) >= num){
r = m;
}else{
l = m + 1;
}
}
return jump(x, l);
}
int find_element(int x)
{
int l = lpos[x], r = rpos[x];
while(l < r){
int m = (l + r) / 2;
int u = query(1, n, l, m, 1);
if(u >= 1){
r = m;
}else{
l = m + 1;
}
}
return pos[l];
}
// Number of marked vertices in the subtree of x.
int getnum(int x)
{
    const int first = lpos[x];
    const int last = rpos[x];
    return query(1, n, first, last, 1);
}
int vis[N];   // vis[v] == 1 iff vertex v is currently marked
// Reads each test case and answers queries "1 y" (mark y) and "2 y" (unmark y).
// After each query it prints ans, which the update rules below maintain as the
// total edge weight of the smallest subtree connecting all marked vertices.
int main(void)
{
    int num_tests, g = 0;
    scanf("%d", &num_tests);
    while(num_tests --){
        total = 0;
        printf("Case #%d:\n", ++ g);
        int q, x, y, w;
        scanf("%d %d", &n, &q);
        for(int i = 1; i <= n; ++ i){
            e[i].clear();
        }
        for(int i = 1; i < n; ++ i){
            scanf("%d %d %d",&x, &y, &w);
            e[x].PB(MP(y, w));
            e[y].PB(MP(x, w));
        }
        build(1, n, 1);
        // Only vertices 1..n are used in this test case, so reset just those rows
        // instead of the whole ~8 MB ancestor table (and all of vis) every time.
        memset(f, -1, sizeof(f[0]) * (n + 1));
        memset(vis, 0, sizeof(vis[0]) * (n + 1));
        weight[1] = 0;
        dfs(1, -1, 0);
        int num = 0;           // number of marked vertices
        long long ans = 0;     // long long, as in the other solution, to avoid overflow
        for(int i = 1; i <= q; ++ i){
            scanf("%d %d", &x, &y);
            if(x == 1){
                if(vis[y]){
                    cout << ans << '\n';
                    continue;
                }
                vis[y] = 1;
                num ++;
                if(num <= 1){
                    ans = 0;
                }else{
                    // y is not in the segment tree yet. u is the lowest ancestor of y
                    // whose subtree already holds a marked vertex; v is the first
                    // marked vertex (in preorder) below u.
                    int u = findlca(y, 1);
                    int v = find_element(u);
                    if(getnum(u) == num - 1){
                        // Every previously marked vertex lies under u: connect y to the
                        // top of the current subtree via u.
                        if(getnum(u) == getnum(v)){
                            // All marked vertices lie under v, so v is that top.
                            ans += weight[y] - weight[u];
                            ans += weight[v] - weight[u];
                        }else{
                            int top = findlca(v, num - 1);
                            if(top == u){
                                ans += weight[y] - weight[u];
                            }else{
                                ans += weight[y] - weight[u];
                                ans += weight[top] - weight[u];
                            }
                        }
                    }else{
                        // Marked vertices exist both inside and outside u's subtree,
                        // so u is already connected: only the path y -> u is new.
                        ans += weight[y] - weight[u];
                    }
                }
                update(1, n, lpos[y], 1, 1);
            }else{
                if(!vis[y]){
                    cout << ans << '\n';
                    continue;
                }
                vis[y] = 0;
                num --;
                // Remove y first, then undo exactly what inserting it would add.
                update(1, n, lpos[y], -1, 1);
                if(num <= 1){
                    ans = 0;
                }else{
                    int u = findlca(y, 1);
                    int v = find_element(u);
                    if(getnum(u) == num){
                        int target = findlca(v, num);
                        ans -= weight[y] - weight[u];
                        ans -= weight[target] - weight[u];
                    }else{
                        ans -= weight[y] - weight[u];
                    }
                }
            }
            cout << ans << '\n';
        }
    }
    return 0;
}
还没有评论,来说两句吧...