跳转至

WQS 二分

引入

本文介绍利用 WQS 二分优化动态规划问题的方法。在不同的文章中,它也常称作带权二分、凸优化 DP、凸完全单调性 DP、Lagrange 乘子法等,在国外也称作 Aliens Trick。它最早由王钦石在《浅析一类二分方法》一文中总结。

WQS 二分通常用于解决这样一类优化问题:它们带有数量限制,直接求解代价较高;但一旦去除这一限制,问题本身就变得容易得多。

比如,假设要解决的问题是,要从 个物品中选取 个,并最优化某个较复杂的目标函数。如果设从前 个物品中选取 个,目标函数的最优值为 ,那么原问题的答案就是 。这类问题中,状态转移方程通常是二维的。直接实现该状态转移方程,时间复杂度是 的,难以接受。

进一步假设,没有数量限制的最优化问题容易解决。但是,选取到的最优数量未必满足原问题的数量限制。假设选取的物品过多。那么,就可以考虑在选取物品时,为每个选取到的物品都附加一个固定大小的惩罚 (即「带权二分」中的「权」),仍然解没有数量限制的最优化问题。根据 的取值不同,选取到的最优数量也会有所不同;而且,随着 的变化,选取到的最优数量也是单调变化的。所以,可以通过二分,找到 使得选取到的最优数量恰为 。假设此时目标函数的最优值为 ,那么,只要消除额外附加的惩罚造成的价值损失,就能得到原问题的答案 。假设单次求解附加惩罚的问题的复杂度是 的,那么,算法的整体复杂度也就降低到了 ,其中, 是二分 需要的次数。

这就是 WQS 二分的基本想法。但是,这一想法能够行得通,前提是 关于 是凸的。否则,可能不存在使得最优数量恰为 的附加惩罚 。这也是这种 DP 优化方法常常称为「凸优化 DP」或「凸完全单调性 DP」的原因。

传统方法

设非空集合 为(有限的)决策空间, 为目标函数,且另有函数 用于施加限制。需要求解的问题,可以看作是计算如下最优化问题的价值函数 在某处的取值:

比如,对于前文提到的限制数量的问题, 可以理解为所有物品集合的子集族, 是单个子集, 是单个子集的价值函数, 是子集 中的元素个数。当然, 并非只能是数量限制,后文提供了更为广泛的限制条件的例子。

约定

为了行文方便,本文仅讨论最小化目标函数的问题。最大化目标函数的问题与之相仿,只是需要将本文中的(下)凸函数相应地替换成凹函数(或称上凸函数)。或者,可以通过添加负号,将最大化目标函数的问题,转化为最小化它的相反数的问题。

几何直观

因为算法竞赛中遇到的大多数问题都是组合优化问题,决策空间 通常没有良好的结构,所以,可以转而考察集合

传统方法能够解决的主要是 的情形,即只有一个限制的情形。下图提供了此时点集 的一种可能的图示。

图中的红点和蓝点是 中所有可能的选择投影在平面 上得到的集合 。于是,原问题所要求的就是横坐标为 的那些点中,纵坐标的最小值 。当 变动时,所有这样的点 就构成了图中的红点的集合。

为了求得点 的纵坐标,可以考虑用斜率为 的直线去切集合 。如图所示,当直线的斜率选取得恰当时,经过点 的那条直线,是所有经过集合 中的点且斜率为 的直线中,截距 最小的。将这一最小值记为

那么,因为 同样位于该直线上,就可以得到原问题的解

假设对于所有合理范围的 ,上述函数 都是容易求解的。这在算法竞赛中常常是成立的,因为它去掉了原问题中的限制条件。那么,现在面临的最为重要的两个问题,就是

  1. 是否存在这样的直线斜率 ,使得它的截距最小值恰好取得在点 处,以及
  2. 如果存在,如何找到这样的斜率

第一个问题相对容易解决。因为当直线斜率 发生变化时,所有这些直线切出的集合(即它们对应的上半平面的交)必然是一个凸集。因此,这些直线能够经过某个点,当且仅当这个点在该凸集的下凸壳上。这等价于说,函数 凸函数

第二个问题则更为精细。因为所求点的横坐标已经知道是 ,所以,一个自然的思路是,计算 时,顺便求出限制函数 在当前最优解 处的取值。比如,在前文提到的例子中,求解带惩罚的问题时,可以记录带惩罚的目标函数取得最优解时,选取的物品数量。然后,将 与所期望的 进行比较,并相应调整下次计算时的 的取值。这就是最为传统的 WQS 二分的方法。

总结一下,传统 WQS 二分的基本流程如下:

  1. 初始时,选取一个 的合理的区间;
  2. 在当前的区间中选择一个
  3. 求解带惩罚的问题 ,并记录它的最优解 的取值
  4. 如果 ,就得到原问题的最优价值 ,直接结束算法;
  5. 否则,根据 的大小关系,调整 的区间,并回到步骤 2。

这一基本流程已经足以解决一些问题,但并不完善。接下来,本文将讨论对这一基本流程的改进。

共线情形的处理

在应用基本流程时,首先遇到的问题就是共线情形无法正确处理。

如果点集 的下凸壳上有三个及以上的红点共线,那么在上述基本流程中,可能无法正确地判断 的大小关系。比如,设共线的三个红点的横坐标分别为 ,且它们共线的直线的斜率为 。那么,要正确求解 ,就必须保证算法终止时,最后计算的问题是 ,因为 是唯一一个最小化截距时能够经过点 的直线的斜率。但是,因为在求解 的过程中,记录的 可能是 中的任意一个。如果记录到的 不等于 ,那么算法将错误地继续运行,并向着背离 的方向调整 的区间,最终将得到错误的结果。

为了解决共线的情形,一种处理方法是在记录最优解 对应的 时,总是使之尽可能大(或尽可能小)。同时,将二分中的终止条件从寻找恰好满足 改为寻找满足 (或 )的最小(或最大)的 。在上一段的例子中,这相当于计算问题 时,输出的 。这就保证了算法终止时,最后计算的问题是 。实现这一方法时,需要注意最后输出的不是 而是 ,因为记录的 未必等于实际的限制

另一种处理方法是实数二分。如果问题涉及的数字都是整数,显然,WQS 二分中的斜率也是整数。在二分中引入实数,是为了保证错误地排除正确选项 时,可以通过小数部分调整回来,最终逼近正确答案 。例如,在上面的例子中,如果计算问题 时,记录的 ,小于所希望的 ,那么,算法就会转而考虑区间 ,其中, 所在区间的右端点。对于整数的情形,这一区间实际应该写作 ,这就排除了在后续算法中接近正确答案 的可能。但是,实数二分时,考虑的区间仍然是 ,而且,对于该区间中的 ,求解 时记录的 总是不小于 ,从而严格大于 的。因此,随着算法继续进行,会不断地舍去右半区间,从而,最终得到的 的范围可以保证在 附近。当然,因为已经知道所求的斜率是一个整数,实数二分终止时的精度不必太高,只要能保证二分的区间中只包含一个整数即可,这一整数就是要寻找的

正确地处理共线情形后,WQS 二分足以解决绝大多数算法竞赛会遇到的 WQS 二分的问题。但是,这一方法仍然存在一些不足之处:它无法处理 难以记录的情形,也无法处理高维 WQS 二分中多个点共面的情形。本文将进一步考察最优化问题 的性质,并提出更为一般的处理方法。

对偶方法

本节介绍一种 WQS 二分的实现方法,它只要求对于所有 ,可以高效地计算

的取值,且原问题的最优价值 是关于 的凸函数1。用一句话概括,本节将证明原问题的价值函数 就等于它的对偶问题的最优价值

而对偶问题的目标函数是关于 的凹函数,从而是单峰函数,可以通过 三分法黄金分割法 高效地求解,复杂度仍然是 的。这就完全地解决了传统 WQS 二分方法中记录 的值可能会出现的问题,同时,允许将 WQS 二分的思想应用至高维的情形。

另外,本节还说明, 的范围可以通过 求得,而无需在求解 时额外记录。例如,对于 且问题只涉及整数的情形,可以证明 的取值范围恰为

这实际上也对于不得不采取前文所述二分流程的题目,提供了又一种解决共线问题的方法。

接下来,本节将用凸分析的理论证明这些结论成立。至于这些方法的具体应用,可以参考 例题 一节。

Lagrange 对偶

考虑用 Lagrange 乘子法 解决该问题。引入 Lagrange 乘子 ,那么,Lagrangian 可以写作

因为只要 有一个分量非零,就可以让相应的 的分量趋于(正或负)无穷,所以有

这说明,原问题可以写作

交换两次最值操作,就得到它的 对偶问题

马上要说明的是,在 是关于 的凸函数的条件下,强对偶(strong duality)成立,即

凸共轭

为了说明强对偶成立,需要引入凸共轭的概念。

凸共轭

对于函数 ,它的 凸共轭(convex conjugate),或称 Legendre–Fenchel 变换(Legendre–Fenchel transformation),是指函数

从变量 的角度看, 是一系列线性函数的上确界,所以,必然是 上的凸函数。

超平面的「斜率向量」和「截距」

本文所讨论的向量空间 中的超平面的方程都具有形式

也就是说,本文不会涉及平行于 轴的超平面。为表述方便,本文并不严谨地将 称为该超平面的「斜率向量」, 称为该超平面的「截距」。将这一超平面的方程写成更标准的形式,就是

它的一个法向量是 。因此,所谓的斜率向量其实是将超平面的法向量归一化使得它的最后一个分量等于 时,所得到的法向量的前 个分量。

几何直观上,函数 的凸共轭描述的是,对于所有斜率向量为 且与函数 的上境图

相交的超平面,截距 的最小值就是 。换句话说,函数 总在超平面 上方,且与该平面切于点 ;当然,可能存在其余的切点。这样的超平面,称为 处的 支撑超平面(supporting hyperplane)。函数 的一个支撑超平面的截距由它的斜率向量唯一确定,凸共轭就提供了这个从斜率向量到截距的映射。

在集合 上最小化 就等价于在集合 上最小化

因此,有

这说明 是关于 的凹函数。进而,有

也就是说,对偶问题的价值函数 是原问题的价值函数 的双重凸共轭,也称为 双共轭(biconjugate)。

所以,问题转化为:什么样的函数 满足它的双共轭就等于它自身?这一问题的答案由如下定理给出:

定理(Fenchel–Moreau)

对于函数 ,它的双共轭等于它自身,即 ,当且仅当以下三个条件之一满足:

  1. 是正常凸函数且 下半连续
  2. ,或
证明

一个函数是正常的(proper),当且仅当它从不取得 的值,且不永远取得 的值。

对于非正常函数的情形,可以验证 互为共轭。除此之外,只要 在任何一点处取到 ,必然有 。所以,满足 的非正常函数只有这两种情形。下面的讨论仅限于正常函数。对于正常函数,下半连续且凸的条件等价于它的上境图是闭凸集。

这一条件的必要性是容易的。因为 的凸共轭,作为一系列线性函数的上确界,它的上境图必然是一系列闭凸集的交集,所以必然是闭凸集。这就说明,满足 的正常函数必然是下半连续且凸的。

反过来,这些条件也是充分的。和其他强对偶定理的证明一样,证明可以分为两步。

第一步,说明弱对偶成立,即 。由凸共轭的定义可知,对于所有 ,都有

这就说明,对于所有 ,同样有

对不等式右侧中的 取上确界,就有

第二步,利用 超平面分离定理 说明 。假设不然,存在 使得 成立。因为 的上境图 是闭凸集,而且单点集 是紧凸集,所以,根据超平面分离定理,存在 使得对于所有 和所有 都有

成立。因为 可以选得任意大,所以必然有 。这又可以分为两种情形。

首先,讨论 的情形。此时,将不等式的各部分都同除以 ,并设 ,就得到

对所有 ,令 ,就都成立

故而,对不等号右侧的 取上确界,有

进而,有

这一矛盾说明 的情形并不成立。

最后,讨论 的情形。事实上,将要说明的是,可以通过微扰,将它转化为 的情形。任取 ,根据凸共轭的定义可知,对于任何 都有

因此,对于任意 ,都有

同时,因为 ,所以,对于充分小的 ,又有

因此,如果取 ,那么,就有

这就又回到了前一种情形,仍然会导致矛盾。

这一矛盾说明,并不存在满足 的点 。故而,总有

结合这两步证明的结果,就得到 成立。

因此,强对偶成立,当且仅当 是关于 的凸函数2

次梯度

上一节说明了,带惩罚的问题的价值函数 是原问题的价值函数 的凸共轭的相反数。因为凸共轭的定义实际上是一个含参数的最优化问题,所以它也成立类似 包络定理 的结论。但是,因为凸函数并非处处可微的,所以需要首先将导数的定义推广到凸函数的情形。这就引出了次梯度的概念。

次梯度

对于凸函数 ,如果向量 满足对于任何 ,都有

那么,就称 处的一个 次梯度(subgradient)。函数 处的全体次梯度的集合称为它在该处的 次微分(subdifferential),记作

几何直观上,凸函数 处的次微分,就是它在该处的所有支撑超平面的斜率向量的集合。对于一维的情形,次微分

其中, 分别是函数 处的左右导数。进一步地,对于整数集上的凸函数 延拓而来的 ,它在整数点 处的左右导数就是左右两侧的一阶差分:

显然,凸函数 在点 处可微,当且仅当它在该处的次微分 是单点集。

因为凸共轭提供了从支撑超平面的斜率向量到它的截距的映射,所以,利用凸共轭,可以判断一个斜率向量 是否是凸函数 在给定点 处的一个次梯度。

定理(凸共轭与次梯度)

对于正常凸函数 和任意 ,都有

进而,如果 还是下半连续的,那么这两个条件都等价于

证明

按照次梯度的定义,,当且仅当

这等价于

这又等价于

但是,依据凸共轭的定义,总是有

因此,前一式中的大于等于号实际上等价于等号,也就等价于下式

这就完成了第一部分的证明。

对于 是下半连续的正常凸函数的情形,依 Fenchel–Moreau 定理,有 。因此,这两个条件等价于

再次应用第一部分的结论,它们也就等价于

这一结论说明,如果 ,那么凸共轭 处的次微分 ,恰好就是斜率向量为 的支撑超平面与上境图 的交点的 分量的集合。

推论

对于下半连续的正常凸函数 和任意 ,都有

证明

下面,证明第二个等式。第一个等式的证明与之类似。

按照凸共轭的定义,有

所以,

当且仅当 ,而这一等式成立,又当且仅当 。这就证明了两个集合是相等的。

应用到本文的场景中,这一结论说明,求解问题

时,限制函数 在最优决策集合上的取值恰为 。对于 且问题只涉及整数的情形,这一集合就是区间

对于连续的整数 ,这些区间首尾相接,所以,如果用于二分,只需要计算一侧的端点即可。

凸性证明

应用 WQS 二分的前提条件是价值函数的凸性。在算法竞赛中,可以通过打表、感性理解等方式猜测凸性成立。但是,严格地证明凸性成立,往往并不容易。本节结合如下经典题目,介绍算法竞赛中常见的证明凸性的方法。

种树问题

个树坑,要种 棵树。树不能栽种于相邻的两个坑。给定长度为 的序列 ,表示在每个坑种树的收益,收益可正可负。求种完这 棵树最大可能的收益和。

简言之,就是在长度为 的链上,求解大小为 的最大权独立集的问题。

这些方法粗略地可以分为四类:

  • 归约为凸优化问题(包括 线性规划 等)的价值函数对参数的凸性,这包括建立 费用流 模型等方法;
  • 利用状态转移方程也可以归纳地证明凸性成立,过程中可能会用到一些 保持凸性的变换
  • 对于区间分拆类型的问题,可以验证每段区间的成本函数满足 四边形不等式
  • 最后,对于特殊的问题,也可以通过交换论证直接说明凸性成立。

这些证明方法本身往往都同该问题的某种解法联系在一起。

归约为含参凸优化

考虑如下形式的含参凸优化问题:

其中,目标函数 对于每个 都是关于 的凸函数,而可行域 上的集合值函数,且对于每个 ,集合 都是凸集。这些条件保证了对于任意参数 ,这都是一个凸优化问题。

定理

假设上述含参凸优化问题满足如下条件:

  1. 目标函数 是关于 的凸函数;
  2. 可行域映射 的图像 是凸集。

如果对于任意 ,都有 ,那么,价值函数 是关于 的正常凸函数。

证明

对于任意 ,需要证明

如果 ,那么不等式的右侧就是 ,不等式必然成立。否则, 都是有限值。对于任意 ,都存在 使得 成立。利用映射 的图像的凸性可知

也就是说, 是参数为 的最优化问题的一个可行解。利用最优化条件和目标函数的凸性可知

因为 的选取是任意的,令 ,就有

因此,价值函数 的凸性成立。

算法竞赛中,最为常见的凸优化问题就是线性规划问题。

推论

。考虑如下含参线性规划问题:

那么,价值函数 是关于 的凸函数。

无论是不等式约束,还是等式约束,线性规划的价值函数都是约束条件参数的凸函数。

很多图论问题都可以写成线性规划问题的形式:

  • 网络流问题:最大流、最小割、最小费用流;
  • 无负环的最短路问题;
  • 二分图的最大(权)匹配、最小点覆盖等问题;
  • 一般图的最大(权)匹配问题;
  • 最小生成树问题3

因此,这些问题的价值函数都是这些问题的参数的凸(凹)函数。

整数约束

利用图论模型为实际问题建模时,通常有隐含的整数限制,例如一条边只能选或不选、流量只能是整数等。因此,它们只能转化为整数线性规划(integer linear programming, ILP)问题而非线性规划(LP)问题。因为 ILP 问题并非凸优化问题,所以它的价值函数未必是该问题的参数的凸函数。将一个 ILP 问题中的整数约束松弛掉后就得到一个 LP 问题,但后者未必存在满足整数约束的最优解。因此,松弛整数约束后得到的 LP 的最优价值有可能严格优于相应的 ILP 问题,两者未必等价。

上文列举的那些图论问题,都可以写成一个 LP 问题而不需要施加整数约束;但对于其他的一些问题,例如一般图的最大独立集问题等,整数约束则是必要的。另外,即使一个图论问题可以写成 LP 的形式,在该问题引入额外的线性约束条件后,仍然可能会破坏相应的 ILP 问题和 LP 问题的等价性,从而这个带约束的图论问题不再能够写成线性规划的形式。

例如,在费用流的语境下,有如下常见结论:

推论

最小费用流模型 中,最小费用 是流量 的凸函数。

证明

设有向图 ,边 的容量为 ,单位流量的费用为 ,,源点和汇点分别为 。记决策变量为 ,其中, 为边 的流量。那么,最小费用流可以写成如下线性规划问题:

因此,最小费用 是参数 的凸函数。

算法竞赛中很多问题都可以归约为网络流等图论问题,从而都可以通过类似的方式建立价值函数的凸性。

利用这一方法,可以得到种树问题的第一个凸性证明:

凸性证明一

种树问题的最大收益实际上可以通过如下最大费用最大流模型得出:

  • 从源点 出发,向结点 ,连接一条容量为 、费用为 的边;
  • 从结点 出发,向每个奇数结点 连接一条容量为 、费用为 的边;
  • 从每个偶数结点 出发,向汇点 连接一条容量为 、费用为 的边;
  • 对于每个 ,从 中的奇数结点出发,向偶数结点连接一条容量为 、费用为 的边。

最终答案就是求得的最大费用。将这一图论模型转换为相应的线性规划问题(具体见上文推论的证明),那么,总流量 将出现在表示边 的流量限制的不等式中。由推论可知,最大费用 是流量 的凹函数。

利用该费用流模型,可以通过模拟费用流或 反悔贪心 的方法在 的复杂度内解决该问题。

利用状态转移方程

尽管状态转移方程无法提供有效的计算方式,但是,它常常可以用于证明状态函数 对于参数 具有凸性。具体地说,就是将 这一函数视为 处的状态,就可以将关于 的状态转移方程看作是关于 的递推关系,从而可以归纳地证明每个 都是凸函数。这类证明凸性的方法在 Slope Trick 优化 DP 的场景中更为常见,该页面也讨论了常见的保持凸性的变换。

这一方法同样可以用于证明种树问题的凸性:

凸性证明二

为前 个坑种 棵树时的最大收益。考察如下状态转移方程:

将这一状态转移方程看作是函数 的递推关系式。因为最值符号内涉及两个不同的函数,这并不能表达为卷积上确界的形式。但是,仍然可以通过归纳的方法证明函数 是凹函数。

实际上,需要归纳地证明如下两点:

  • 关于 递减;
  • 关于 递增。

归纳起点是平凡的。假设它们对于所有 及之前的自然数都成立,现在证明它对 也成立。直接验证即可。

首先,由归纳假设,有

关于 递减。所以,有

关于 递减,且

关于 递增。这就完成了归纳。

进而,有

关于 递减。这就说明 是关于 的凹函数,从而,价值函数 是关于 的凹函数。

这个证明的一个副产品是,对于任意 ,都存在 使得

这说明,可以通过平衡树直接维护序列 ,复杂度是 的。但是,好处是可以处理任意种树间隔的一般情形,且一次性地获得了所有 的值。

四边形不等式

算法竞赛中,另一类常见的成立凸性的问题是 区间分拆问题。该页面证明了,如果单个区间的成本函数满足四边形不等式,那么限制区间个数的区间分拆问题的最小成本是区间个数的凸函数。该页面同样提供了一些判断某个函数 是否满足四边形不等式的方法。最为直接的方法就是计算它的二阶混合差分:

函数 满足四边形不等式,当且仅当 非正。直观上,满足四边形不等式的函数通常意味着,区间向两侧扩大——即左端点向左移动和右端点向右移动——具有某种协同效应。

种树问题同样可以看作是一个区间分拆问题,可以通过验证四边形不等式进行证明。

凸性证明三

在种树的收益序列前添加一个 ,它可以是任何值。然后,种树问题就等价于将序列 分成 段,且每段的收益函数为

的区间分拆问题。也就是说,每一段的收益是除去第一棵树外,其余树的收益的最大值——这就保证了间隔种树。

因为这是最大化问题,需要验证「交叉大于包含」,即对于任意 ,都成立

代入收益函数的表达式,并设

则要证明的不等式可以写作

注意到不等号左侧的两项 中较大的那个就等于 ,而它们中较小的那个总是不小于 ,因此该不等式成立。

将种树问题转化为区间分拆问题后,只需要用 ST 表等方式预处理区间最值,可以单次 地计算单个区间的成本,就可以套用区间分拆问题的算法在 的时间复杂度内解决该问题。该方法同样可以处理任意种树间隔的问题。

交换论证

组合优化问题中,证明价值函数的凸性常常会用到交换论证(exchange argument)。具体地说,就是从参数为 的问题的最优解出发,通过交换部分元素,构造出参数为 且价值不超过 的可行解,从而利用 的最优性来证明凸性的论证方法。相较于凸优化的情形,组合优化问题中并不存在自然地构造两个解的「中间形态」的方法,因此,交换论证的应用通常具有一定的技巧性。

「边际成本递增」并不一定导致凸性

组合优化问题中,目标函数常常具有一些「边际成本递增」的性质,但是这并不必然导致凸性。一个典型的例子是 [IOI 2005] Riv 河流,该问题的链上版本是满足四边形不等式因而具有凸性的,但是树上版本存在凸性不成立的例子。

用于刻画「边际成本递增」的一个常见性质是函数的超模性(supermodularity)。对于有限集合 的子集族 上的函数 ,如果它满足以下两条等价性质之一:

  1. (交叉小于包含)对于任何子集 ,都有 成立;
  2. (边际成本递增)对于任何子集 以及 ,都有 成立;

就称函数 超模的(supermodular)。但是,超模函数作为目标函数的最优化问题中,价值函数

未必 的凸函数。究其原因,就是从子集大小分别为 的最优解,一般来说是无法构造出子集大小为 且满足前述价值函数大小关系的可行解的。

交换论证提供了种树问题凸性的又一种证明方式。

凸性证明四

利用交换论证。设种 棵树和种 棵树的最优方案分别由 给出,其中,取值为 表示该坑位种了一棵树,取值为 表示该坑位没有种树。定义序列 满足

这个序列标记了两个种树方案的差异。这一序列中取值为 的位置表示该坑位要么在两个方案中都种了一棵树,要么在两个方案中都没有种树;而取值为 的位置则分别表示只在方案 或只在方案 中,该坑位种了一棵树。因为任何方案中都不能在相邻的坑位种树,所以有如下观察:

  • 连续的非零子段中, 的取值必然在 之间相互交错的;
  • 极大的连续非零子段左右两侧的 ,必然表示在两个方案中都没有种树。

因此,如果某个极大的连续非零子段中, 的和恰好为 ,也就是说,在该段坑位中,方案 比方案 多种了一棵树,那么就可以在该段内交换两个方案的种树位置。这样,就得到了两个各种 棵树的可行方案。因为没有改变两个方案中总的种树的位置和数量,只是将它们重新分配,所以总的收益不变,仍然是 。但是,这两个种 棵树的方案未必是最优的,因此,它们各自的收益不会超过 。这就证明了

也就是说, 是关于 的凹函数。

现在,只剩下一个问题,就是和恰好为 的极大连续非零子段是否存在。因为是若干个交错的 相加,一个连续非零子段的和只能是 。又因为所有这些极大连续非零子段的和等于 ,所以,一定存在至少两个和恰好为 的极大连续非零子段。这就完成了证明。

例题

本节介绍几个不同场景下应用 WQS 二分方法的例题。

模板题目

Luogu P1484 种树

个树坑,至多 棵树。树不能栽种于相邻的两个坑。给定长度为 的序列 ,表示在每个坑种树的收益,收益可正可负。求种完这 棵树最大可能的收益和。

解答

与前文讨论的种树问题稍有不同,本题要求至多种 棵树,而非恰好种 棵树。仍然用 表示前文讨论的种树问题的价值函数,本题的答案实际上是 。因为 是凹函数,也就是一个单峰函数,本题的答案相当于只保留 上升至峰顶的部分,然后函数会一直留在峰顶;这相当于仅仅保留切线斜率非负的部分。因此,本题与前文讨论的题目的唯一差别,就是 WQS 二分时,初始的斜率范围为 而非

应用 WQS 二分的方法移除数量限制后,问题转化为计算链上最大权独立集,只是将原本的收益 替换为了 。这是经典的动态规划题目。可以设 为第 个坑位选择种树()或不种树()时,前 个树坑的子问题的最大收益。由此,可以写出状态转移方程为

初始条件为 ,最终答案为 。单次计算复杂度为 ,整体时间复杂度为 ,其中,

参考实现如下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
#include <algorithm>
#include <cstring>
#include <iostream>
#include <tuple>
#include <vector>

int main() {
  int n, m;
  std::cin >> n >> m;
  std::vector<int> a(n + 1);
  for (int i = 1; i <= n; ++i) std::cin >> a[i];
  // Calculate h(k) = max_x f(x) - k * g(x).
  // Meanwhile, obtain the maximum value g(x) of the optimizer x.
  auto calc = [&](int k) -> std::pair<long long, int> {
    long long dp[2] = {0, -0x3f3f3f3f3f3f3f3f};
    int opt[2] = {0, 0};
    for (int i = 1; i <= n; ++i) {
      long long tmp_dp[2];
      int tmp_opt[2];
      if (dp[0] > dp[1]) {
        tmp_dp[0] = dp[0];
        tmp_opt[0] = opt[0];
      } else if (dp[1] > dp[0]) {
        tmp_dp[0] = dp[1];
        tmp_opt[0] = opt[1];
      } else {
        tmp_dp[0] = dp[0];
        tmp_opt[0] = std::max(opt[0], opt[1]);
      }
      tmp_dp[1] = dp[0] + a[i] - k;
      tmp_opt[1] = opt[0] + 1;
      std::memcpy(dp, tmp_dp, sizeof(dp));
      std::memcpy(opt, tmp_opt, sizeof(opt));
    }
    long long val;
    int opt_m;
    if (dp[0] > dp[1]) {
      val = dp[0];
      opt_m = opt[0];
    } else if (dp[1] > dp[0]) {
      val = dp[1];
      opt_m = opt[1];
    } else {
      val = dp[0];
      opt_m = std::max(opt[0], opt[1]);
    }
    return {val, opt_m};
  };
  // WQS binary search.
  long long val, tar_val;
  int opt_m, tar_k;
  std::tie(val, opt_m) = calc(0);
  if (opt_m <= m) {
    // Have already reached the peak.
    tar_k = 0;
    tar_val = val;
  } else {
    // Find the maximum k such that g(x) >= m.
    int ll = 0, rr = 1000000;
    while (ll <= rr) {
      int mm = (ll + rr) / 2;
      std::tie(val, opt_m) = calc(mm);
      if (opt_m >= m) {
        tar_k = mm;
        tar_val = val;
        ll = mm + 1;
      } else {
        rr = mm - 1;
      }
    }
  }
  long long res = tar_val + (long long)tar_k * m;
  std::cout << res << std::endl;
  return 0;
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
#include <algorithm>
#include <iostream>
#include <tuple>
#include <type_traits>
#include <vector>

// Golden section search on integer domain (unimodal function)
template <typename T, typename F>
typename std::enable_if<
    std::is_integral<T>::value,
    std::pair<T, decltype(std::declval<F>()(std::declval<T>()))>>::type
golden_section_search(T ll, T rr, F func) {
  constexpr long double phi = 0.618033988749894848204586834L;
  T ml = ll + static_cast<T>((rr - ll) * (1 - phi));
  T mr = ll + static_cast<T>((rr - ll) * phi);
  auto fl = func(ml), fr = func(mr);
  while (ml < mr) {
    if (fl > fr) {
      rr = mr;
      mr = ml;
      fr = fl;
      ml = ll + static_cast<T>((rr - ll) * (1 - phi));
      fl = func(ml);
    } else {
      ll = ml;
      ml = mr;
      fl = fr;
      mr = ll + static_cast<T>((rr - ll) * phi);
      fr = func(mr);
    }
  }
  T best_x = ll;
  auto best_val = func(ll);
  for (T i = ll + 1; i <= rr; ++i) {
    auto val = func(i);
    if (val > best_val) {
      best_val = val;
      best_x = i;
    }
  }
  return {best_x, best_val};
}

int main() {
  int n, m;
  std::cin >> n >> m;
  std::vector<int> a(n + 1);
  for (int i = 1; i <= n; ++i) std::cin >> a[i];
  // Calculate h(k) = max_x f(x) + k * g(x).
  auto calc = [&](int k) -> long long {
    long long dp[2] = {0, -0x3f3f3f3f3f3f3f3f};
    for (int i = 1; i <= n; ++i) {
      std::tie(dp[0], dp[1]) =
          std::make_pair(std::max(dp[0], dp[1]), dp[0] + a[i] + k);
    }
    return std::max(dp[0], dp[1]);
  };
  // Solve the dual problem to find v(m).
  // Implemented as a minimization problem by adding negative signs.
  // Only consider tangent lines of negative slopes to ignore the part
  //     of the curve after the peak.
  auto res = -golden_section_search(-1000000, 0, [&](int k) -> long long {
                return -calc(k) + (long long)k * m;
              }).second;
  std::cout << res << std::endl;
  return 0;
}
Luogu P2619 [国家集训队] Tree I

给定一张带权无向连通图,每条边是黑色或白色。求恰有 条白边的生成树的最小权。

解答

首先,通过交换论证可以证明 是凸函数。不妨假设所有边的边权各不相同:那些存在两边边权相同的情形,可以通过微扰使之变为边权各不相同的情形;然后只要令微扰的幅度趋近于零,就可以证明函数的凸性在极限情形——也就是存在两边边权相同的情形——仍然成立。证明的关键在于如下引理:4

引理

是无向连通图 的两个生成树。对于任意 ,都存在至少一条边 ,使得 都是图 的生成树。

证明

,且 是树 中连接 的唯一一条路径。因为 是图 中唯一的环路,所以删掉 中的任何一条边 都可以使得 是一棵生成树。与此同时,图 是有两个连通分量的森林,它们的顶点集分别记作 ,所以,只要选择边 使得 连通了 ,就能保证 是一棵生成树。这样的边 总是存在的,因为 分别属于 ,而 连接了 。而且,,因为图 中, 并不是连通的。这就完成了证明。

是白边数量分别为 的最小生成树。设 的一条白边,对它应用上述引理可知,存在边 ,使得 都是生成树。因为只是交换了一对边,所以,树 和树 的边权和仍然是 。进而,分两种情形讨论:

  • 如果 是一条黑边,那么, 中的白边数量都是 。它们各自的边权和都不会小于 。这就证明了 ,故而 关于 是凸的。
  • 如果 是一条白边,那么, 的白边数量分别是 ,所以它们的边权和分别不小于 。但是,上面已经说明,它们的边权和加在一起就等于 。这说明, 的边权和就等于 。将 比较可知, 的边权必然相等。这与假设矛盾,所以该情形并不成立。

这就证明了 的凸函数。

建立了函数 的凸性后,就可以用 WQS 二分解决该问题。移除数量限制并将每条白边的权重都减去 ,并求解最小生成树问题。为此,可以应用 Kruskal 算法。利用并查集维护连通性,算法的复杂度就是 ,其中, 分别为边数和顶点数, 为反 Ackerman 函数。复杂度的主要部分 是给边排序的复杂度,在本题中可以进一步优化。虽然在 WQS 二分的过程中需要多次计算最小生成树,但是每次只有白边的边权会整体加减一个数。所以,可以在预处理时给白边、黑边分别排序,然后每次计算最小生成树时,只需要将调整完权重后的白边和黑边归并到一起就可以了。这样,整体复杂度就降低到了 ,其中, 为边权的取值范围的长度。

参考实现如下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
#include <algorithm>
#include <array>
#include <iostream>
#include <numeric>
#include <tuple>
#include <vector>

class DisjointSet {
  std::vector<int> fa, sz;

  int find(int x) { return fa[x] == x ? x : fa[x] = find(fa[x]); }

 public:
  DisjointSet(int n) : fa(n), sz(n, 1) { std::iota(fa.begin(), fa.end(), 0); }

  bool unite(int x, int y) {
    x = find(x);
    y = find(y);
    if (x == y) return false;
    if (sz[x] < sz[y]) std::swap(x, y);
    fa[y] = x;
    sz[x] += sz[y];
    return true;
  }
};

int main() {
  int V, E, m;
  std::cin >> V >> E >> m;
  std::array<std::vector<std::array<int, 3>>, 2> edges;
  edges[0].reserve(E);  // white edges.
  edges[1].reserve(E);  // black edges.
  for (int i = 0; i < E; ++i) {
    int u, v, c, w;
    std::cin >> u >> v >> w >> c;
    edges[c].push_back({u, v, w});
  }
  // Sort edges.
  std::sort(edges[0].begin(), edges[0].end(),
            [&](const auto& lhs, const auto& rhs) -> bool {
              return lhs[2] < rhs[2];
            });
  std::sort(edges[1].begin(), edges[1].end(),
            [&](const auto& lhs, const auto& rhs) -> bool {
              return lhs[2] < rhs[2];
            });
  // Calculate h(k) = min_x f(x) - k * g(x) by Kruskal algorithm.
  // Use white edges first, whenever possible.
  auto calc = [&](int k) -> std::pair<int, int> {
    int res = 0, cnt = 0;
    DisjointSet djs(V);
    int i[2] = {};
    while (i[0] < edges[0].size() && i[1] < edges[1].size()) {
      int c = edges[0][i[0]][2] - k > edges[1][i[1]][2];
      if (djs.unite(edges[c][i[c]][0], edges[c][i[c]][1])) {
        res += edges[c][i[c]][2] - (c ? 0 : k);
        cnt += c ? 0 : 1;
      }
      ++i[c];
    }
    while (i[0] < edges[0].size()) {
      if (djs.unite(edges[0][i[0]][0], edges[0][i[0]][1])) {
        res += edges[0][i[0]][2] - k;
        ++cnt;
      }
      ++i[0];
    }
    while (i[1] < edges[1].size()) {
      if (djs.unite(edges[1][i[1]][0], edges[1][i[1]][1])) {
        res += edges[1][i[1]][2];
      }
      ++i[1];
    }
    return {res, cnt};
  };
  // WQS binary search.
  // Find the minimum k such that g(x) >= m.
  int val, opt_m, tar_val, tar_k;
  int ll = -100, rr = 100;
  while (ll <= rr) {
    int mm = ll + (rr - ll) / 2;
    std::tie(val, opt_m) = calc(mm);
    if (opt_m >= m) {
      tar_val = val;
      tar_k = mm;
      rr = mm - 1;
    } else {
      ll = mm + 1;
    }
  }
  int res = tar_val + tar_k * m;
  std::cout << res << std::endl;
  return 0;
}
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
#include <algorithm>
#include <array>
#include <iostream>
#include <numeric>
#include <tuple>
#include <type_traits>
#include <vector>

// Golden section search on integer domain (unimodal function)
template <typename T, typename F>
typename std::enable_if<
    std::is_integral<T>::value,
    std::pair<T, decltype(std::declval<F>()(std::declval<T>()))>>::type
golden_section_search(T ll, T rr, F func) {
  constexpr long double phi = 0.618033988749894848204586834L;
  T ml = ll + static_cast<T>((rr - ll) * (1 - phi));
  T mr = ll + static_cast<T>((rr - ll) * phi);
  auto fl = func(ml), fr = func(mr);
  while (ml < mr) {
    if (fl > fr) {
      rr = mr;
      mr = ml;
      fr = fl;
      ml = ll + static_cast<T>((rr - ll) * (1 - phi));
      fl = func(ml);
    } else {
      ll = ml;
      ml = mr;
      fl = fr;
      mr = ll + static_cast<T>((rr - ll) * phi);
      fr = func(mr);
    }
  }
  T best_x = ll;
  auto best_val = func(ll);
  for (T i = ll + 1; i <= rr; ++i) {
    auto val = func(i);
    if (val > best_val) {
      best_val = val;
      best_x = i;
    }
  }
  return {best_x, best_val};
}

class DisjointSet {
  std::vector<int> fa, sz;

  int find(int x) { return fa[x] == x ? x : fa[x] = find(fa[x]); }

 public:
  DisjointSet(int n) : fa(n), sz(n, 1) { std::iota(fa.begin(), fa.end(), 0); }

  bool unite(int x, int y) {
    x = find(x);
    y = find(y);
    if (x == y) return false;
    if (sz[x] < sz[y]) std::swap(x, y);
    fa[y] = x;
    sz[x] += sz[y];
    return true;
  }
};

int main() {
  int V, E, m;
  std::cin >> V >> E >> m;
  std::array<std::vector<std::array<int, 3>>, 2> edges;
  edges[0].reserve(E);  // white edges.
  edges[1].reserve(E);  // black edges.
  for (int i = 0; i < E; ++i) {
    int u, v, c, w;
    std::cin >> u >> v >> w >> c;
    edges[c].push_back({u, v, w});
  }
  // Sort edges.
  std::sort(edges[0].begin(), edges[0].end(),
            [&](const auto& lhs, const auto& rhs) -> bool {
              return lhs[2] < rhs[2];
            });
  std::sort(edges[1].begin(), edges[1].end(),
            [&](const auto& lhs, const auto& rhs) -> bool {
              return lhs[2] < rhs[2];
            });
  // Calculate h(k) = min_x f(x) - k * g(x) by Kruskal algorithm.
  auto calc = [&](int k) -> int {
    int res = 0;
    DisjointSet djs(V);
    int i[2] = {};
    while (i[0] < edges[0].size() && i[1] < edges[1].size()) {
      int c = edges[0][i[0]][2] - k > edges[1][i[1]][2];
      if (djs.unite(edges[c][i[c]][0], edges[c][i[c]][1])) {
        res += edges[c][i[c]][2] - (c ? 0 : k);
      }
      ++i[c];
    }
    while (i[0] < edges[0].size()) {
      if (djs.unite(edges[0][i[0]][0], edges[0][i[0]][1])) {
        res += edges[0][i[0]][2] - k;
      }
      ++i[0];
    }
    while (i[1] < edges[1].size()) {
      if (djs.unite(edges[1][i[1]][0], edges[1][i[1]][1])) {
        res += edges[1][i[1]][2];
      }
      ++i[1];
    }
    return res;
  };
  // Solve the dual problem to find v(m).
  auto res = golden_section_search(-100, 100, [&](int k) -> int {
               return calc(k) + k * m;
             }).second;
  std::cout << res << std::endl;
  return 0;
}

区间分拆问题

Luogu P6246 [IOI 2000] 邮局 加强版 加强版

给定长度为 且递增的正整数序列 表示一条高速公路旁的 个村庄的位置,需要修建 个邮局。邮局位置的选择,需要最小化所有村庄与其最近邮局的距离之和。求这个最小值。

解答

这是典型的 区间分拆问题。二分队列的实现细节请参考该页面。

每个邮局服务离它最近的村庄,那么,这些村庄必然是高速公路旁连续的若干个村庄。所以,修建 个邮局,就相当于将所有村庄划分为连续的 段,并为每一段村庄修建一个成本最低的邮局。众所周知,邮局应当修建在村庄位置的中位数的位置。由此,可以写出区间 的成本函数为

它满足四边形不等式,因为它的二阶混合差分非正:

这说明,可以通过二分队列结合 WQS 二分的方法在 的复杂度内求解该问题。

参考实现如下:

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
#include <cmath>
#include <deque>
#include <iostream>
#include <tuple>
#include <type_traits>
#include <vector>

// Monotone decision DP.
// This solves f(i) = min f(j-1) + w(j,i) s.t. 1 <= j <= i.
// Also records the minimal optimal decision j for each f(i).
template <typename W>
std::pair<std::vector<decltype(std::declval<W>()(0, 0))>, std::vector<int>>
monotone_decision_opt_dp(int n, W ww) {
  using ValueT = decltype(std::declval<W>()(0, 0));
  std::vector<ValueT> f(n + 1);
  std::vector<int> opt(n + 1), lt(n + 1), rt(n + 1);
  std::deque<int> dq;
  auto w = [&](int j, int i) -> ValueT { return ww(j, i) + f[j - 1]; };
  for (int j = 1; j <= n; ++j) {
    if (!dq.empty() && rt[dq.front()] < j) dq.pop_front();
    if (!dq.empty()) lt[dq.front()] = j;
    while (!dq.empty() && w(j, lt[dq.back()]) < w(dq.back(), lt[dq.back()])) {
      dq.pop_back();
    }
    if (dq.empty()) {
      lt[j] = j;
      rt[j] = n;
      dq.emplace_back(j);
    } else if (w(j, rt[dq.back()]) >= w(dq.back(), rt[dq.back()])) {
      if (rt[dq.back()] < n) {
        lt[j] = rt[dq.back()] + 1;
        rt[j] = n;
        dq.emplace_back(j);
      }
    } else {
      int ll = lt[dq.back()], rr = rt[dq.back()], i = rr;
      while (ll <= rr) {
        int mm = (ll + rr) / 2;
        if (w(j, mm) < w(dq.back(), mm)) {
          i = mm;
          rr = mm - 1;
        } else {
          ll = mm + 1;
        }
      }
      rt[dq.back()] = i - 1;
      lt[j] = i;
      rt[j] = n;
      dq.emplace_back(j);
    }
    f[j] = w(dq.front(), j);
    opt[j] = dq.front();
  }
  return {f, opt};
}

int main() {
  int n, m;
  std::cin >> n >> m;
  std::vector<long long> a(n + 1), ps(n + 1);
  for (int i = 1; i <= n; ++i) {
    std::cin >> a[i];
    ps[i] = ps[i - 1] + a[i];
  }
  // Cost function for interval [l,r].
  auto w = [&](int j, int i) -> long long {
    int mm = j + (i - j) / 2;
    return ps[i] + ps[j - 1] - 2 * ps[mm] + (2 * mm - j - i + 1) * a[mm];
  };
  // Calculate h(k) = min_x f(x) - k * g(x).
  // Also record the minimum of optimal number of segments.
  auto calc = [&](long long k) -> std::pair<long long, int> {
    auto res = monotone_decision_opt_dp(
        n, [&](int j, int i) -> long long { return w(j, i) - k; });
    auto val = res.first[n];
    int cnt = 0;
    for (int i = n; i; i = res.second[i] - 1) {
      ++cnt;
    }
    return {val, cnt};
  };
  // WQS binary search.
  // Find the largest k such that g(x) <= m.
  long long val, tar_val;
  int opt_m, tar_k;
  long long ll = -(1LL << 32), rr = 0;
  while (ll <= rr) {
    long long mm = ll + (rr - ll) / 2;
    std::tie(val, opt_m) = calc(mm);
    if (opt_m <= m) {
      tar_val = val;
      tar_k = mm;
      ll = mm + 1;
    } else {
      rr = mm - 1;
    }
  }
  long long res = tar_val + (long long)tar_k * m;
  std::cout << res << std::endl;
  return 0;
}
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
#include <cmath>
#include <deque>
#include <iostream>
#include <type_traits>
#include <utility>
#include <vector>

// Monotone decision DP.
// This solves f(i) = min f(j-1) + w(j,i) s.t. 1 <= j <= i.
// Also records the minimal optimal decision j for each f(i).
template <typename W>
std::pair<std::vector<decltype(std::declval<W>()(0, 0))>, std::vector<int>>
monotone_decision_opt_dp(int n, W ww) {
  using ValueT = decltype(std::declval<W>()(0, 0));
  std::vector<ValueT> f(n + 1);
  std::vector<int> opt(n + 1), lt(n + 1), rt(n + 1);
  std::deque<int> dq;
  auto w = [&](int j, int i) -> ValueT { return ww(j, i) + f[j - 1]; };
  for (int j = 1; j <= n; ++j) {
    if (!dq.empty() && rt[dq.front()] < j) dq.pop_front();
    if (!dq.empty()) lt[dq.front()] = j;
    while (!dq.empty() && w(j, lt[dq.back()]) < w(dq.back(), lt[dq.back()])) {
      dq.pop_back();
    }
    if (dq.empty()) {
      lt[j] = j;
      rt[j] = n;
      dq.emplace_back(j);
    } else if (w(j, rt[dq.back()]) >= w(dq.back(), rt[dq.back()])) {
      if (rt[dq.back()] < n) {
        lt[j] = rt[dq.back()] + 1;
        rt[j] = n;
        dq.emplace_back(j);
      }
    } else {
      int ll = lt[dq.back()], rr = rt[dq.back()], i = rr;
      while (ll <= rr) {
        int mm = (ll + rr) / 2;
        if (w(j, mm) < w(dq.back(), mm)) {
          i = mm;
          rr = mm - 1;
        } else {
          ll = mm + 1;
        }
      }
      rt[dq.back()] = i - 1;
      lt[j] = i;
      rt[j] = n;
      dq.emplace_back(j);
    }
    f[j] = w(dq.front(), j);
    opt[j] = dq.front();
  }
  return {f, opt};
}

// Golden section search on integer domain (unimodal function)
template <typename T, typename F>
typename std::enable_if<
    std::is_integral<T>::value,
    std::pair<T, decltype(std::declval<F>()(std::declval<T>()))>>::type
golden_section_search(T ll, T rr, F func) {
  constexpr long double phi = 0.618033988749894848204586834L;
  T ml = ll + static_cast<T>((rr - ll) * (1 - phi));
  T mr = ll + static_cast<T>((rr - ll) * phi);
  auto fl = func(ml), fr = func(mr);
  while (ml < mr) {
    if (fl > fr) {
      rr = mr;
      mr = ml;
      fr = fl;
      ml = ll + static_cast<T>((rr - ll) * (1 - phi));
      fl = func(ml);
    } else {
      ll = ml;
      ml = mr;
      fl = fr;
      mr = ll + static_cast<T>((rr - ll) * phi);
      fr = func(mr);
    }
  }
  T best_x = ll;
  auto best_val = func(ll);
  for (T i = ll + 1; i <= rr; ++i) {
    auto val = func(i);
    if (val > best_val) {
      best_val = val;
      best_x = i;
    }
  }
  return {best_x, best_val};
}

int main() {
  int n, m;
  std::cin >> n >> m;
  std::vector<long long> a(n + 1), ps(n + 1);
  for (int i = 1; i <= n; ++i) {
    std::cin >> a[i];
    ps[i] = ps[i - 1] + a[i];
  }
  // Cost function for interval [l,r].
  auto w = [&](int j, int i) -> long long {
    int mm = j + (i - j) / 2;
    return ps[i] + ps[j - 1] - 2 * ps[mm] + (2 * mm - j - i + 1) * a[mm];
  };
  // Calculate h(k) = min_x f(x) - k * g(x).
  auto solve = [&](long long k) -> long long {
    return monotone_decision_opt_dp(
               n, [&](int j, int i) -> long long { return w(j, i) - k; })
               .first[n] +
           k * m;
  };
  // Solve the dual problem to find v(m).
  auto res = golden_section_search(-(1LL << 32), 0LL, solve).second;
  std::cout << res << std::endl;
  return 0;
}

二维的限制条件

Codeforces 739 E. Gosha is hunting

只神奇宝贝,序列 分别表示用宝贝球和超级球抓到第 只神奇宝贝的概率。可以向一个神奇宝贝扔一个宝贝球,或者扔一个超级球,或者两种球各扔一个,或者什么球都不扔。现有 个宝贝球和 个神奇球,需要合理分配,并同时扔出。求抓到的神奇宝贝的期望数量的最大值。单次抓捕成功与否,与其它抓捕的结果无关。

更一般地,可以抽象为如下问题:

给定长度为 的三个正实数序列 ,而且,对所有 都有 成立。求最优的下标集合 满足 且最大化

解答

原问题可以看作是这个更一般的问题在

时的特殊情形。因此,只需要讨论更一般的问题的解决方案就可以了。

表示该问题的价值函数,需要证明它是关于 的凹函数。考虑如下费用流模型:

  • 从源点 出发,分别向结点 连一条边,容量分别为 ,费用均为
  • 对于所有 ,分别从结点 向结点 连一条边,容量均为 ,费用分别为
  • 对于所有 ,从结点 出发向汇点 连两条边,容量均为 ,费用分别为

问题的答案就是该费用流模型的最大费用最大流。条件 保证了当流经结点 的流量为 时,会优先选择费用为 的那条边流出。将这个费用流模型写成线性规划问题,那么, 就会分别出现在表示边 和边 的流量限制的不等式中。因此, 确实是 的凹函数。

为了应用 WQS 二分,需要考虑移除数量限制后的最优化问题。设 分别为将一个下标放入集合 时获得的额外奖励。没有数量限制后,关于每个下标的决策都是独立的,因此,有

原问题的答案就由

给出。总的时间复杂度为 ,其中, 为对单个维度二分的次数。

抓捕神奇宝贝问题的参考实现如下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#include <algorithm>
#include <iomanip>
#include <iostream>
#include <tuple>
#include <type_traits>
#include <vector>

// Golden section search on floating-point domain (unimodal function)
template <typename T, typename F>
typename std::enable_if<
    std::is_floating_point<T>::value,
    std::pair<T, decltype(std::declval<F>()(std::declval<T>()))>>::type
golden_section_search(T ll, T rr, F func, T eps) {
  constexpr long double phi = 0.618033988749894848204586834L;
  T ml = ll + (rr - ll) * (1 - phi);
  T mr = ll + (rr - ll) * phi;
  auto fl = func(ml), fr = func(mr);
  while ((rr - ll) > eps) {
    if (fl > fr) {
      rr = mr;
      mr = ml;
      fr = fl;
      ml = ll + (rr - ll) * (1 - phi);
      fl = func(ml);
    } else {
      ll = ml;
      ml = mr;
      fl = fr;
      mr = ll + (rr - ll) * phi;
      fr = func(mr);
    }
  }
  T mid = (ll + rr) / 2;
  return {mid, func(mid)};
}

int main() {
  int n, m1, m2;
  std::cin >> n >> m1 >> m2;
  std::vector<long double> p(n + 1), q(n + 1);
  for (int i = 1; i <= n; ++i) std::cin >> p[i];
  for (int i = 1; i <= n; ++i) std::cin >> q[i];
  // Calculate h(k1,k2).
  auto solve = [&](long double k1, long double k2) -> long double {
    long double res = 0;
    for (int i = 1; i <= n; ++i) {
      res += std::max(
          {0.0l, p[i] + k1, q[i] + k2, p[i] + q[i] - p[i] * q[i] + k1 + k2});
    }
    return res;
  };
  // Solve the dual problem to find v(m1,m2).
  // Implemented as a minimization problem by adding negative signs.
  auto res = -golden_section_search(
                  -1.0l, 0.0l,
                  [&](long double k1) -> long double {
                    return golden_section_search(
                               -1.0l, 0.0l,
                               [&](long double k2) -> long double {
                                 return -solve(k1, k2) + k2 * m2;
                               },
                               1e-8l)
                               .second +
                           k1 * m1;
                  },
                  1e-8l)
                  .second;
  std::cout << std::fixed << std::setprecision(10) << res << std::endl;
  return 0;
}

更广泛的限制条件

Codeforces 1661 F. Teleporters

条线段,它们的长度由序列 给出。可以将它们任意切割为若干条整数长度的线段,目标是最小化所有线段长度平方的总和。求至少需要切割多少次,才能使这个平方和降到不超过

解答

设将长度为 的线段切割 次能得到的最小平方和为 。由于均值不等式,当两数之和一定时,两数之差越小,两数的平方和也越小。所以,切割之后得到的线段长度越均匀,总的长度平方和也就越小。但是由于整数约束的存在,最均匀的情形就是得到了 条长度为 的线段和 条长度为 的线段。因此,有如下表达式:

第二步的等号成立,是因为 当且仅当

可以证明,函数 是关于 的凸函数。为此,需要将它延拓到 的情形。当 时,有

这是斜率为 的直线。因此, 是分段线性函数,且斜率随着 的增加而增加。这就说明了 是凸函数,它限制在整点上当然也是凸函数5

利用 ,可以将所有线段总共切割 次得到的最小平方和写作如下最优化问题的价值函数:

这是若干个凸函数的 卷积下确界,所以也是凸函数。如果题目要求的是 ,那么,可以使用与之前的例题一致的方法求解,时间复杂度为 ;但是,本题求的是满足 的最小的 。利用 WQS 二分计算 再对 二分的方法是行不通的,它的复杂度达到了 。就本题而言,有如下两种处理方法。

方法一:仍然二分斜率 ,但是二分的依据是对 上下界的估计。

传统的 WQS 二分的方法中,对于给定的斜率 ,可以计算出相应的最优的 的取值范围。因为这些 共线,所以这相当于确定了 的取值范围。因此,可以直接二分斜率 。得到斜率 之后,可以利用直线方程

计算出最小的 。整体复杂度为

为了确定 的取值范围,需要确定 的取值范围。一种做法是,在计算 时记录相应的最大最优解,利用它可以计算相应的 的下界;另一种做法是,利用 得到相应的 的上界,进而得到相应的 的下界。参考实现中,采用的是第二种做法,它不依赖于问题的具体结构,无需特别处理。

方法二:改写最优化问题,使得对偶问题的价值函数就是本问题的解。

本问题可以直接看作是如下最优化问题:

本文的分析仍然适用于这一问题。故而,可以利用它的对偶问题求解所要求的

整体算法复杂度仍然是 的。

参考代码如下:

代码仅做示意,为通过原题数据范围,需要 128 位整数,并调整二分初始区间为

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
#include <algorithm>
#include <iostream>
#include <tuple>
#include <type_traits>
#include <vector>

// Golden section search on integer domain (unimodal function)
template <typename T, typename F>
typename std::enable_if<
    std::is_integral<T>::value,
    std::pair<T, decltype(std::declval<F>()(std::declval<T>()))>>::type
golden_section_search(T ll, T rr, F func) {
  constexpr long double phi = 0.618033988749894848204586834L;
  T ml = ll + static_cast<T>((rr - ll) * (1 - phi));
  T mr = ll + static_cast<T>((rr - ll) * phi);
  auto fl = func(ml), fr = func(mr);
  while (ml < mr) {
    if (fl > fr) {
      rr = mr;
      mr = ml;
      fr = fl;
      ml = ll + static_cast<T>((rr - ll) * (1 - phi));
      fl = func(ml);
    } else {
      ll = ml;
      ml = mr;
      fl = fr;
      mr = ll + static_cast<T>((rr - ll) * phi);
      fr = func(mr);
    }
  }
  T best_x = ll;
  auto best_val = func(ll);
  for (T i = ll + 1; i <= rr; ++i) {
    auto val = func(i);
    if (val > best_val) {
      best_val = val;
      best_x = i;
    }
  }
  return {best_x, best_val};
}

int main() {
  int n;
  std::cin >> n;
  std::vector<int> a(n + 1);
  for (int i = 1; i <= n; ++i) std::cin >> a[i];
  for (int i = n; i >= 1; --i) a[i] -= a[i - 1];
  long long v;
  std::cin >> v;
  // Cost of adding M more teleporters to a segment of length LEN.
  auto f = [&](int len, int m) -> long long {
    long long rem = len % (m + 1);
    int q = len / (m + 1);
    return (m + 1 - rem) * q * q + rem * (q + 1) * (q + 1);
  };
  // Calculate h(k) = min_x f(x) - k * g(x).
  auto calc = [&](long long k) -> long long {
    long long res = 0;
    for (int i = 1; i <= n; ++i) {
      res += -golden_section_search(0, a[i], [&](int m) -> long long {
                return -f(a[i], m) + m * k;
              }).second;
    }
    return res;
  };
  // Find the smallest k such that h(k) + k * m <= v.
  long long ll = -(1LL << 30), rr = 0, ti = 0;
  while (ll <= rr) {
    auto mm = ll + (rr - ll) / 2;
    auto fi = calc(mm);
    auto ub = fi - calc(mm + 1);
    if (fi + ub * mm <= v) {
      ti = mm;
      rr = mm - 1;
    } else {
      ll = mm + 1;
    }
  }
  std::cout << (int)((calc(ti) - v - 1 - ti) / (-ti)) << std::endl;
  return 0;
}

代码仅做示意,由于浮点数精度问题无法通过原题数据范围。

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
#include <algorithm>
#include <cmath>
#include <iostream>
#include <tuple>
#include <type_traits>
#include <vector>

// Golden section search on integer domain (unimodal function)
template <typename T, typename F>
typename std::enable_if<
    std::is_integral<T>::value,
    std::pair<T, decltype(std::declval<F>()(std::declval<T>()))>>::type
golden_section_search(T ll, T rr, F func) {
  constexpr long double phi = 0.618033988749894848204586834L;
  T ml = ll + static_cast<T>((rr - ll) * (1 - phi));
  T mr = ll + static_cast<T>((rr - ll) * phi);
  auto fl = func(ml), fr = func(mr);
  while (ml < mr) {
    if (fl > fr) {
      rr = mr;
      mr = ml;
      fr = fl;
      ml = ll + static_cast<T>((rr - ll) * (1 - phi));
      fl = func(ml);
    } else {
      ll = ml;
      ml = mr;
      fl = fr;
      mr = ll + static_cast<T>((rr - ll) * phi);
      fr = func(mr);
    }
  }
  T best_x = ll;
  auto best_val = func(ll);
  for (T i = ll + 1; i <= rr; ++i) {
    auto val = func(i);
    if (val > best_val) {
      best_val = val;
      best_x = i;
    }
  }
  return {best_x, best_val};
}

// Golden section search on floating-point domain (unimodal function)
template <typename T, typename F>
typename std::enable_if<
    std::is_floating_point<T>::value,
    std::pair<T, decltype(std::declval<F>()(std::declval<T>()))>>::type
golden_section_search(T ll, T rr, F func, T eps) {
  constexpr long double phi = 0.618033988749894848204586834L;
  T ml = ll + (rr - ll) * (1 - phi);
  T mr = ll + (rr - ll) * phi;
  auto fl = func(ml), fr = func(mr);
  while ((rr - ll) > eps) {
    if (fl > fr) {
      rr = mr;
      mr = ml;
      fr = fl;
      ml = ll + (rr - ll) * (1 - phi);
      fl = func(ml);
    } else {
      ll = ml;
      ml = mr;
      fl = fr;
      mr = ll + (rr - ll) * phi;
      fr = func(mr);
    }
  }
  T mid = (ll + rr) / 2;
  return {mid, func(mid)};
}

int main() {
  int n;
  std::cin >> n;
  std::vector<int> a(n + 1);
  for (int i = 1; i <= n; ++i) std::cin >> a[i];
  for (int i = n; i >= 1; --i) a[i] -= a[i - 1];
  long long v;
  std::cin >> v;
  // Cost of adding M more teleporters to a segment of length LEN.
  auto f = [&](int len, int m) -> long long {
    long long rem = len % (m + 1);
    int q = len / (m + 1);
    return (m + 1 - rem) * q * q + rem * (q + 1) * (q + 1);
  };
  // Calculate h(k) = min_x f(x) - k * g(x).
  auto calc = [&](long double k) -> long double {
    long double res = 0;
    for (int i = 1; i <= n; ++i) {
      res += -golden_section_search(0, a[i], [&](int m) -> long double {
                return -m + k * f(a[i], m);
              }).second;
    }
    return res;
  };
  // Solve the dual problem.
  auto res =
      golden_section_search(
          -1.0l, 0.0l,
          [&](long double k) -> long double { return calc(k) + k * v; }, 1e-12l)
          .second;
  std::cout << (int)ceill(res) << std::endl;
  return 0;
}

习题

最后,列举一些可以通过 WQS 二分解决的题目,以供练习:

参考资料与注释


  1. 实际问题中, 可能只能取到 中的有限多个格点。此处实际需要的条件是,原问题的解 可以延拓为 上的凸函数 ,也就是说 可凸延拓的(convex-extensible)。为行文方便,正文中仍然用 表示延拓后的函数。几何直观上,这相当于说点集 全部都位于它们的凸包的下凸壳上。对于一维的情形,这一条件利用代数语言 很容易刻画;但是,对于高维的情形,这稍微有些复杂,这份讲义 中提供了一些简单的充分条件。 

  2. 定理中提供的条件看似比凸函数更强一些,但是,对于算法竞赛能够遇到的情形,特别是 为有限集合时,仅强调凸函数就已经足够。由离散集合上的正常凸函数 延拓而来的函数 必然是下半连续的凸函数,因为有限多个点的凸包必然是闭凸包,而所谓下半连续的凸函数,就等价于它的上境图是闭凸包。至于正常凸函数中的「正常」一词,只要 在至少一个点处取得有限值且是凸函数,就可以保证。 

  3. 最小生成树问题有两种常见的写成线性规划问题的 方法:子回路消除模型(subtour-elimination formulation)和基于割集的模型(cut-based formulation)。只有前一种建模方式能够保证得到的线性规划问题和原问题是等价的。 

  4. 这一引理对于一般的 拟阵 也是成立的。它称为 对称基交换性质(symmetric base-exchange property),相关资料可以参考 Wikipedia 页面。因此,本题关于凸性的结论可以推广到一般的拟阵上。 

  5. 当然, 和它限制在整点上得到的函数的凸包并不相同,因为 可能存在非整数位置处的极点。这说明,并不能将定义域为实数的 直接用于本题的最优化问题中。