
Weights Division (hard version)

Easy and hard versions are actually different problems, so we advise you to read both statements carefully.

You are given a weighted rooted tree, vertex 1 is the root of this tree. Also, each edge has its own cost.

A tree is a connected graph without cycles. A rooted tree has a special vertex called the root. The parent of a vertex v is the last vertex different from v on the path from the root to v. Children of vertex v are all vertices for which v is the parent. A vertex is a leaf if it has no children. A weighted tree is a tree in which each edge has some weight.

The weight of the path is the sum of edges weights on this path. The weight of the path from the vertex to itself is 0.

You can make a sequence of zero or more moves. On each move, you select an edge and divide its weight by 2, rounding down. More formally, during one move you choose some edge $i$ and set $w_i := \lfloor w_i / 2 \rfloor$.

Each edge $i$ has an associated cost $c_i$, which is either 1 or 2 coins. Each move with edge $i$ costs $c_i$ coins.

Your task is to find the minimum total cost to make the sum of weights of paths from the root to each leaf at most $S$. In other words, if $w(i, j)$ is the weight of the path from vertex $i$ to vertex $j$, then you have to make $\sum_{v \in leaves} w(root, v) \le S$, where $leaves$ is the list of all leaves.

You have to answer t independent test cases.


Input

The first line of the input contains one integer $t$ ($1 \le t \le 2 \cdot 10^4$) — the number of test cases. Then $t$ test cases follow.

The first line of each test case contains two integers $n$ and $S$ ($2 \le n \le 10^5$; $1 \le S \le 10^{16}$) — the number of vertices in the tree and the maximum possible sum of weights you have to obtain. The next $n-1$ lines describe the edges of the tree. Edge $i$ is described by four integers $v_i$, $u_i$, $w_i$ and $c_i$ ($1 \le v_i, u_i \le n$; $1 \le w_i \le 10^6$; $1 \le c_i \le 2$), where $v_i$ and $u_i$ are the vertices edge $i$ connects, $w_i$ is the weight of this edge and $c_i$ is the cost of this edge.

It is guaranteed that the sum of $n$ does not exceed $10^5$ ($\sum n \le 10^5$).


Output

For each test case, print the answer: the minimum total cost required to make the sum of weights of paths from the root to each leaf at most $S$.


Example

Input:

4
4 18
2 1 9 2
3 2 4 1
4 1 1 2
3 20
2 1 8 1
3 1 7 2
5 50
1 3 100 1
1 5 10 2
2 3 123 2
5 4 55 1
2 100
1 2 409 2

Output:

0
0
11
6








之后就是更暴力但有效的做法：用两个优先队列（代价为 $1$ 的边一个、代价为 $2$ 的边一个）分别模拟一遍操作，每次操作最“赚”的边，直到总和满足 $\le s$ 或者所有边都被操作成 $0$，并将每次减掉的值做一个前缀和。因为我们的总操作必定是由操作一堆代价为 $1$ 的边和代价为 $2$ 的边组合起来的，而我们求出的前缀和在对应的代价里都是最优解，所以最优的操作就是将两个前缀和的一部分（这一部分可以为空）组合起来，我们用两个指针不断取 $\min$ 就能找到最终答案。



#include <algorithm>
#include <cstdio>
#include <queue>
#include <vector>
// 64-bit integer type (was `#define L long long`; a typedef is scoped and type-checked).
typedef long long L;
using namespace std;
const int M=1e5+5;  // max number of vertices (+ slack)
// Max halvings recorded per cost class: w <= 10^6 < 2^20, so each of the
// < 10^5 edges yields at most 20 non-zero halvings.
const int N=2e6+5;

// Adjacency-list entry: neighbour, edge weight, edge cost (1 or 2).
struct sd{int to,wei,cst;};

// One edge viewed as a halving candidate.
struct edg{
    int tim;  // number of leaves below the edge (times its weight is counted in tot)
    int wei;  // current weight of the edge
    L val;    // reduction of tot if the edge is halved once more: tim*(wei - wei/2)
    // Orders a priority_queue as a max-heap on val (most profitable halving on top).
    bool operator<(const edg &x)const{return val<x.val;}
};

vector<sd> mmp[M];             // adjacency lists of the tree
priority_queue<edg> dui1,dui2; // halving candidates for cost-1 / cost-2 edges
int t,n,ans=1e7,t1,t2,p1,p2;   // ans starts above any reachable cost (<= 3*2e6)
L tot,s;                       // current sum of root-to-leaf weights; limit S
L pre1[N],pre2[N];             // prefix sums of best gains per cost class
bool vis[M];
// Reset all per-test-case state so the next test case starts clean.
void re()
{
    // ac() only wrote pre1[0..t1] and pre2[0..t2]; zero just those entries.
    for(int i=0;i<=t1;++i)pre1[i]=0;
    for(int i=0;i<=t2;++i)pre2[i]=0;
    ans=1e7;
    tot=0;
    t1=t2=0;
    for(int i=1;i<=n;++i){
        vis[i]=0;
        mmp[i].clear();
    }
    while(!dui1.empty())dui1.pop();
    while(!dui2.empty())dui2.pop();
}
void in()
{int a,b,c,d;scanf("%d%I64d",&n,&s);for(int i=1;i<n;++i){scanf("%d%d%d%d",&a,&b,&c,&d);mmp[a].push_back((sd){b,c,d});mmp[b].push_back((sd){a,c,d});}
int dfs(int v)
{L tmp,re=0;int to,wei;edg add;vis[v]=1;if(mmp[v].size()==1&&v!=1)return 1;for(int i=mmp[v].size()-1;i>=0;--i){if(vis[mmp[v][i].to])continue;to=mmp[v][i].to,wei=mmp[v][i].wei;tmp=dfs(to);add.tim=tmp,add.wei=wei,add.val=tmp*(wei-wei/2);mmp[v][i].cst==1?dui1.push(add):dui2.push(add);tot+=tmp*wei;re+=tmp;}return re;
void ac()
{dfs(1);if(tot<=s){printf("0\n");return;}edg tmp;for(;!dui1.empty();){tmp=dui1.top();++t1,pre1[t1]=tmp.val+pre1[t1-1];if(tot-pre1[t1]<=s)break;tmp.wei>>=1,tmp.val=tmp.tim*1ll*(tmp.wei-tmp.wei/2);dui1.pop();if(tmp.val)dui1.push(tmp);}for(;!dui2.empty();){tmp=dui2.top();++t2,pre2[t2]=tmp.val+pre2[t2-1];if(tot-pre2[t2]<=s)break;tmp.wei>>=1,tmp.val=tmp.tim*1ll*(tmp.wei-tmp.wei/2);dui2.pop();if(tmp.val)dui2.push(tmp);}int p=0;for(int i=t1;i>=0;--i){for(;tot-pre1[i]-pre2[p]>s&&p<=t2;++p);if(tot-pre1[i]-pre2[p]<=s)ans=min(ans,i+p*2);}printf("%d\n",ans);
// Read t test cases; solve each one and reset global state before the next.
int main()
{
    scanf("%d",&t);
    for(int i=1;i<=t;++i){
        in();
        ac();
        re();
    }
    return 0;
}
#include <algorithm>
#include <cstdio>
#include <queue>
#include <vector>
// 64-bit integer type (was `#define L long long`; a typedef is scoped and type-checked).
typedef long long L;
using namespace std;
const int M=1e5+5;  // max number of vertices (+ slack)

// Adjacency-list entry: neighbour, edge weight, edge cost (1 or 2).
struct sd{int to,wei,cst;};

// One edge viewed as a halving candidate.
struct edg{
    int tim;  // number of leaves below the edge (times its weight is counted in tot)
    int wei;  // current weight of the edge
    L val;    // reduction of tot if the edge is halved once more: tim*(wei - wei/2)
    // Orders a priority_queue as a max-heap on val (most profitable halving on top).
    bool operator<(const edg &x)const{return val<x.val;}
};

vector<sd> mmp[M];             // adjacency lists of the tree
priority_queue<edg> dui1,dui2; // halving candidates for cost-1 / cost-2 edges
int t,n,ans;                   // ans: coins spent so far in the current test
L tot,s;                       // current sum of root-to-leaf weights; limit S
bool vis[M];
// Reset all per-test-case state so the next test case starts clean.
void re()
{
    ans=0;
    tot=0;
    for(int i=1;i<=n;++i){
        vis[i]=0;
        mmp[i].clear();
    }
    while(!dui1.empty())dui1.pop();
    while(!dui2.empty())dui2.pop();
}
void in()
{int a,b,c,d;scanf("%d%I64d",&n,&s);for(int i=1;i<n;++i){scanf("%d%d%d%d",&a,&b,&c,&d);mmp[a].push_back((sd){b,c,d});mmp[b].push_back((sd){a,c,d});}
int dfs(int v)
{L tmp,re=0;int to,wei;edg add;vis[v]=1;if(mmp[v].size()==1&&v!=1)return 1;for(int i=mmp[v].size()-1;i>=0;--i){if(vis[mmp[v][i].to])continue;to=mmp[v][i].to,wei=mmp[v][i].wei;tmp=dfs(to);add.tim=tmp,add.wei=wei,add.val=tmp*(wei-wei/2);mmp[v][i].cst==1?dui1.push(add):dui2.push(add);tot+=tmp*wei;re+=tmp;}return re;
// Apply one halving move to candidate x: lower tot by its gain x.val,
// charge a coins (the edge's cost) to ans, halve the weight, and recompute
// the gain of the next halving.
void div(edg &x,int a)
{
    tot-=x.val;
    ans+=a;
    x.wei>>=1;
    x.val=x.tim*1ll*(x.wei-x.wei/2);
}
void ac()
{dfs(1);L val,last=s+1;int flag;//return;for(edg tmp11,tmp12,tmp2;tot>s;){flag=1;//   printf("tot:%I64d\n",tot);if(dui1.empty()){tmp2=dui2.top();//    tot-=tmp2.val;div(tmp2,2),dui2.pop(),dui2.push(tmp2);continue;}tmp11=dui1.top();if(tot-tmp11.val<=s){ans+=1;break;}if(dui1.size()<2&&!dui2.empty())flag=3;else if(!dui2.empty()){val=tmp11.tim*1ll*(tmp11.wei/2-tmp11.wei/2/2);dui1.pop();tmp12=dui1.top();tmp2=dui2.top();if(val>tmp12.val)flag=2;if(tmp2.val>tmp11.val+max(val,tmp12.val))flag=3;}else{dui1.pop(),div(tmp11,1),dui1.push(tmp11);continue;}//  dec=max(tmp2.val,max(val+tmp11.val,tmp12.val+tmp11.val));//  tot-=dec;//    printf("fuck\n");//   printf("????flag:%d\n",flag);if(flag==3){// printf("flag:%d\n",flag);dui1.push(tmp11),div(tmp2,2),dui2.pop(),dui2.push(tmp2);}else if(flag==2){//   printf("flag:%d\n",flag);div(tmp11,1),last=tmp11.val,div(tmp11,1),dui1.push(tmp11);}else{//  printf("flag:%d\n",flag);div(tmp11,1),div(tmp12,1),last=tmp12.val,dui1.pop(),dui1.push(tmp11),dui1.push(tmp12);}//   ans+=2;}if(tot+last<=s)--ans;printf("%d\n",ans);
// Read t test cases; solve each one and reset global state before the next.
int main()
{
    scanf("%d",&t);
    for(int i=1;i<=t;++i){
        in();
        ac();
        re();
    }
    return 0;
}

