糯米

TI DaVinci, gstreamer, ffmpeg
随笔 - 167, 文章 - 0, 评论 - 47, 引用 - 0
数据加载中……

POJ 1987 Distance Statistics 牛题 树的分治

这题很牛逼,是楼教主的《男人七题》的其中一道。
求:一棵树内最短距离小于K的点对数量
后来看了解题报告,原来树也是可以分治的。

分:
选取一条边,将一棵树分成两半,这两半的节点数要尽量相等。
首先,统计个每个节点的下面有多少个节点
然后,就知道每个边切断之后的情况了。选择最优的即可。

治:
分成两半之后,统计经过该条边的点对线段中,长度小于K的数目。

Amber 大牛论文的精辟描述如下:
 

Divide and Conquer.

Each iteration, we should choose an edge (u, v) and divide the tree into two parts disjoined by the edge. Due to avoid from degenerating, that partition edge should be chosen to divide two parts as equally as possible. Then we should merge two parts and count the valid pairs between them. It can be implemented by two sorted list that denotes the distances between u and the posterities of u and the distances between u and the posterities of v respectively. And like merge sort, use two scan line l, r in two list and maintain the property d(u, l) + d(v, r) <= k.


可见这位大牛的英文水平实在牛逼,英文说得比中文说得还清楚,赞一个。

按照这个思路,很费劲地写出了代码。还好,在1987上面还是勉强上榜啦!250ms那个就是我啦,哈哈。
但是在楼教主的题目1741 上面还是 TLE了。

后来找了一份能过1741的代码,在http://hi.baidu.com/shingray/blog/item/221362b079afc55d082302f0.html
一个大牛的博客上~
发现它的思路不是选择一条边来把树分成两份。
而是选择一个点来把树分成数份,然后计算经过该点的线段数目。
这样速度就快了,大牛的代码在1741上面只跑了170多ms。
将这份代码放到1987上面,也能跑到260ms。
所以这种方法还是很牛逼的!


我的垃圾代码(POJ 1987):
#include <stdio.h>
#include 
<stdlib.h>

#define MAX_VETXS 65536*2
#define MAX_EDGES (MAX_VETXS - 1)

#if 0
#define dbp printf
#else
#define dbp()
#endif

struct edge_node {
    
int w, i;
    
struct edge_node *next, *prev;
}
;

struct edge_node edges[MAX_EDGES], map[MAX_VETXS];
int edges_cnt;
int N, K, ans;

int cmp(const void *a, const void *b)
{
    
return *(int *)a - *(int *)b;
}


inline 
int max(int a, int b)
{
    
return a > b ? a : b;
}


#define list_foreach(_head, _t)    \
    
for (_t = (_head)->next; _t != _head; _t = (_t)->next)

inline 
void list_init(struct edge_node *t)
{
    t
->next = t->prev = t;
}


inline 
void list_add(struct edge_node *head, struct edge_node *t)
{
    head
->prev->next = t;
    t
->prev = head->prev;
    head
->prev = t;
    t
->next = head;
}


inline 
void list_del(struct edge_node *t)
{
    t
->prev->next = t->next;
    t
->next->prev = t->prev;
}


inline 
void list_rev(struct edge_node *t)
{
    t
->next->prev = t;
    t
->prev->next = t;
}


inline 
void edge_add(int a, int b, int w)
{
    
struct edge_node *= &edges[edges_cnt++];

    t
->= b;
    t
->= w;
    list_add(
&map[a], t);
}


struct part_info {
    
int u, v, e, cnt_v;
}
;

inline 
void divide(int i, int *arr, int *len, int cnt, struct part_info *pi)
{
    
static struct {
        
int i, e, depth, cnt, stat, root;
    }
 stk[MAX_VETXS], *sp, *top;
    
static int vis[MAX_VETXS], tm, best, val;
    
int *orig = arr;
    
struct edge_node *e;
    
    best 
= cnt;
    tm
++;
    top 
= stk + 1;
    top
->= i;
    top
->depth = top->cnt = top->stat = top->root = 0;
    vis[i] 
= tm;
    
while (top > stk) {
        sp 
= top;
        
if (sp->stat) {
            stk[sp
->root].cnt += sp->cnt;
            
if (arr && sp->depth <= K)
                
*arr++ = sp->depth;
            val 
= max(sp->cnt, cnt - sp->cnt);
            
if (val < best) {
                best 
= val;
                pi
->= stk[sp->root].i;
                pi
->= sp->i;
                pi
->= sp->e;
                pi
->cnt_v = sp->cnt;
            }

            top
--;
            
continue;
        }

        sp
->stat++;
        list_foreach(
&map[sp->i], e) {
            
if (vis[e->i] == tm)
                
continue;
            vis[e
->i] = tm;
            top
++;
            top
->= e->i;
            top
->= e - edges;
            top
->depth = sp->depth + e->w;
            top
->cnt = 1;
            top
->stat = 0;
            top
->root = sp - stk;
        }

    }


    
if (len)
        
*len = arr - orig;
}


void conquer(struct part_info *pi, int cnt)
{
    
struct part_info pl, pr;
    
static int arr_l[MAX_VETXS], arr_r[MAX_VETXS], len_l, len_r, l, r;

    
if (cnt <= 1)
        
return ;

    list_del(
&edges[pi->e]);
    list_del(
&edges[pi->^ 1]);
    
    divide(pi
->u, arr_l, &len_l, cnt - pi->cnt_v, &pl);
    divide(pi
->v, arr_r, &len_r, pi->cnt_v, &pr);
    
    qsort(arr_l, len_l, 
sizeof(arr_l[0]), cmp);
    qsort(arr_r, len_r, 
sizeof(arr_r[0]), cmp);

    r 
= len_r - 1;
    
for (l = 0; l < len_l; l++{
        
while (r >= 0 && arr_l[l] + arr_r[r] + edges[pi->e].w > K)
            r
--;
        ans 
+= r + 1;
    }

    
    conquer(
&pl, cnt - pi->cnt_v);
    conquer(
&pr, pi->cnt_v);
    
    list_rev(
&edges[pi->e]);
    list_rev(
&edges[pi->^ 1]);
}


inline 
void solve_v2()
{
    
struct part_info pi;

    divide(
1, NULL, NULL, N, &pi);
    conquer(
&pi, N);
}


int main()
{
    
int i, a, b, w, m;
    
char str[16];

    freopen(
"e:\\test\\in.txt""r", stdin);

    scanf(
"%d%d"&N, &m);
    edges_cnt 
= 0;
    
for (i = 1; i <= N; i++)
        list_init(
&map[i]);
    
for (i = 0; i < m; i++{
        scanf(
"%d%d%d%s"&a, &b, &w, str);
        edge_add(a, b, w);
        edge_add(b, a, w);
    }

    scanf(
"%d"&K);
    ans 
= 0;
    solve_v2();
    printf(
"%d\n", ans);

    
return 0;
}



大牛的代码(POJ 1987):
#include <algorithm>
#include 
<cstdio>
#include 
<cstring>
#include 
<limits>
#include 
<queue>
#include 
<vector>
using namespace std;

const int MAX_N = 65536*2;
bool flag[MAX_N];
int k, n, ret, v[MAX_N];
queue
<pair<intint> > q;

struct edge{int v, w; edge *next; } *e[MAX_N], data[MAX_N*2-2], *it;
void insert(int u, int v, int w)
{
   
*it = (edge){v, w, e[u]}; e[u] = it++;
   
*it = (edge){u, w, e[v]}; e[v] = it++;
}


int count(int *first, int *last)
{
   
int ret = 0;
   sort(first, last
--);
   
while (first < last)
       
if (*first+*last <= k) ret += last-first++;
       
else --last;
   
return ret;
}


int best_size, center;
int centerOfGravity(int root, int pred)
{
   
int max_sub = 0, size = 1;
   
for (edge *it = e[root]; it; it = it->next)
       
if (it->!= pred && flag[it->v])
       
{
           
int t = centerOfGravity(it->v, root);
           size 
+= t;
           
if (t > max_sub) max_sub = t;
       }

   
if (q.front().second-q.front().first-max_sub > max_sub)
       max_sub 
= q.front().second-q.front().first-max_sub;
   
if (max_sub < best_size)
       best_size 
= max_sub, center = root;
   
return size;
}


int dists[MAX_N], len;
void find(int root, int pred, int dist)
{
   v[len] 
= root;
   dists[len
++= dist;
   
int last = len;
   
for (edge *it = e[root]; it; it = it->next)
       
if (it->!= pred && flag[it->v])
       
{
           find(it
->v, root, dist+it->w);
           
if (pred == -1)
           
{
               q.push(make_pair(last, len));
               ret 
-= count(dists+last, dists+len);
               last 
= len;
           }

       }

}


int main()
{
    
int m;
    
char str[16];
   scanf(
"%d%d"&n, &m);
   
{
       it 
= data;
       memset(e, 
0sizeof(e[0])*n);
       
for (int i = n; --i; )
       
{
           
int u, v, w;
           scanf(
"%d%d%d%s"&u, &v, &w, str);
           
--u; --v;
           insert(u, v, w);
       }

       scanf(
"%d"&k);

       ret 
= 0;
       
for (int i = 0; i < n; ++i)
           v[i] 
= i;
       
for (q.push(make_pair(0, n)); !q.empty(); q.pop())
       
{
           
if (q.front().first == q.front().second-1continue;
           
for (int i = q.front().first; i < q.front().second; ++i)
               flag[v[i]] 
= true;

           best_size 
= numeric_limits<int>::max();
           centerOfGravity(v[q.front().first], 
-1);

           len 
= q.front().first;
           find(center, 
-10);
           ret 
+= count(dists+q.front().first, dists+q.front().second);

           
for (int i = q.front().first; i < q.front().second; ++i)
               flag[v[i]] 
= false;
       }

       printf(
"%d\n", ret);
   }

}

posted on 2010-04-25 22:30 糯米 阅读(743) 评论(0)  编辑 收藏 引用 所属分类: POJ


只有注册用户登录后才能发表评论。
网站导航: 博客园   IT新闻   BlogJava   知识库   博问   管理