雁过无痕

  C++博客 :: 首页 :: 新随笔 :: 联系 :: 聚合  :: 管理 ::

 

先看一道面试题:


长度为n的数组,由数字1n组成,其中数字a不出现,数字b出现两次,其它的数字恰好出现一次。怎样通过只读遍历一次数组,找出数字ab

 

 

由于只能遍历一次,在遍历数组arr时,算出 ab的差值,以及ab的平方差,通过解方程,即可求得ab。具体做法为:

设:

      s1 = 1 + 2 + ... + n           (= n * (n + 1) / 2)

      s2 = arr[0] + arr[1] + ... + arr[n - 1]

   

      r1 = 1 + 4 + ... + n^2          (= n * (n + 1) * (2 * n + 1) / 6)

      r2 = arr[0]^2 + arr[1]^2 + ... + arr[n - 1]^2

     

     c = a - b = s1 - s2

     d = a^2 - b^2 = r1 - r2

    显然:  a + b = (r1 - r2) / (s1 - s2)

根据a+b的值和a-b的值,很容易就可算出ab

 

算法虽然简单,但实现起来,却有一个很大问题:计算 s1s2r1r24个数时,计算过程中可能出现溢出,造成结果不准。由于最终目的是为了计算出cd,一个改进的方法是:

 c = s1 - s2 = (1 - arr[0]) + (2 - arr[1]) + ... + (n - arr[n - 1])

 d = (1 - arr[0]^2) + (4 - arr[1]^2) + ... + (n^2 - arr[n - 1]^2)

但这样的做法,并不能解决问题,n稍微大点,照样存在溢出问题。

 

那么怎样才能避免计算溢出呢?答案很简单,用模运算!每进行一次加减运算时,都取结果为原结果除以一个足够大的常数M的余数。这样加减运算中,就不会现现溢出问题。最后再由 c % Md % M,推测出cd的具体值。比如说,计算s2改为计算:

   s2 % M = ((((arr[0] % M) + arr[1]) % M + ...) % M + arr[n - 1]) %M

从表面上看,采用模运算后,计算量会增加很多。但实际上,若M取合适的值时,计算量并不会增加!!

 

先回顾下计算机基本知识:两个各N位(寄存器为N位)的二进制无符号整数ab相加,若结果溢出了,CPU会怎么处理?当然是将溢出的那一位忽略掉(可能还要设置下溢出标志),得到的结果实际上是:(a + b) mod 2^N无符号数间的算术运算,本质上就是模运算。现在的CPU采用二补数来表示负整数,本质上也是运用模运算(教科书将二补数表示的负整数简单定义为:对正整数取反后加1),这与无符号数间的运算是一致的,在实现上,比用其它方法(比如说一补数)表示负整数,要优美易实现。

32位平台下, -x mod 2^32 = 2^32 – x (x > 0)

因而-1的二进制表示就是:0xFFFFFFFF

 

了解了这些,就不会奇怪C/C++标准的规定:无符号数间的运算是模运算不会溢出;有符号数转为无符数,采用模运算后的值。(为了兼容没采用二补数的机器,无符号数转为有符号数时,若无符号数的数值超出了有符号数可表示的范围,结果是平台相关的。)

 

因而,在对32CPU平台,可以先将有符号数转为无符号数,再取M = 2 ^32。需要特别注意的是,应该采用多少位的无符号数保存计算中用到的数值,如何避免模运算可能带来的问题:

 

① 无符号数类型的选择:

ab的取值范围为:[1, n]

c % M = (a - b) % M 的取值范围为:[1, n] (a > b)   [M - n, M - 1] (a < b)

这两个范围不能重叠,而因 n < M - n 2 * n < M

M2^32的话,且 n < 2^31 可以采用32位无符号数表示c的值。

根据c % M值在哪一个范围,可以确定a > b还是a < b

由于运算过程中都是采用无符号数计算,当 a < b时,必须进行如下调整:

        c % M 调整为 (-c) % M

        d % M 调整为 (-d) % M

这样才能保证结果的正确性。

 

用公式计算所有数字的和、平方和时,可能出现的问题:

   模运算满足: (a * b) % M  = ((a % M) * (b % M)) % M

     不满足 (a / b) % M  = ((a % M) / (b % M)) % M

   在计算 (n * (n + 1) / 2) % M时, 不能写成:

      s = ((n * (n + 1)) % M / 2) % M

   而应该写成:

     if  (n % 2 == 0)   s = ((n / 2) * (n + 1)) % M

     else             s = (((n + 1) / 2) * n) % M

   或者:s = (INT((n + 1) / 2) * (n + (n + 1) % 2)) % M (其中INT(x)为取小数x的整数部份)。

 

完整代码:

 


#include 
<climits>
#include 
<cassert>

#define SMALL_ARRAY 0

struct Pair {
  
int zero;
  
int twice;
};

//32位CPU平台,长度n一定小于2^16次方时,表示一个数的平方值,可用32位无符号数类型,效率很高。
//长度n若在[2^16, 2^31]区间,就必须用到64位无符号数类型,效率较高。
//长度n若在[2^31, 2^32)时,表示 所有数的和sum,就必须改用64位无符号数类型,效率不高。  
Pair find_number(const int arr[], unsigned len)
{
  
const unsigned bits = CHAR_BIT * sizeof(unsigned);
#if SMALL_ARRAY
  
const unsigned max_len = 1u << (bits / 2u);
  typedef unsigned 
int uint;
#else
  
const unsigned max_len = 1u << (bits - 1);
  typedef unsigned 
long long uint;
#endif

  assert(arr 
&& len >= 2 && len < max_len);
  
const unsigned* const data = (const unsigned*)arr;
  unsigned sum 
= 0;
  
uint square_sum = 0;
  
for (unsigned i = 0; i < len; ++i)  {
    
const unsigned value = data[i];
    sum 
+= value;
    square_sum 
+= (uint)value * value;     //注意两个数的乘积是否会溢出  
  }
  
  
//1 + 2 + 3 +  + len = len * (len + 1) / 2
  const uint sum_all = (len + 1/ 2u * (uint)(len + (len + 1% 2u);
  
  
//1^2 + 2^2 + 3^2 +  + len^2 = len * (len + 1) * (2 * len + 1) / 6
  const unsigned len2 = 2u * len + 1;
  
const uint square_sum_all = len2 % 3u == 0 ? len2 / 3u * sum_all : sum_all / 3u * len2;
  
  unsigned difference 
= (unsigned)sum_all - sum;
  
uint square_difference = square_sum_all - square_sum;
  
const bool is_negative = difference > INT_MAX;

  
if (is_negative) {
    difference 
= -difference;
    square_difference 
= -square_difference;
  } 
   
  assert(difference 
!= 0 && square_difference % difference == 0);
  
const unsigned sum_two = square_difference / difference;
  
  assert((sum_two 
+ difference) % 2u == 0);
  
const unsigned larger  = (sum_two + difference) / 2u;
  
const unsigned smaller = (sum_two - difference) / 2u;
  
  
if (is_negative) {
    
const Pair result = { smaller, larger};
    
return result;
  }
  
const Pair result = { larger, smaller};
  
return result;
}


int main()
{

}


posted on 2012-03-18 21:16 flyinghearts 阅读(3063) 评论(1)  编辑 收藏 引用 所属分类: 算法C++

评论

# re: 避免计算过程中出现溢出的一个技巧 2012-03-19 09:03 tb
恩 预防一下   回复  更多评论
  


只有注册用户登录后才能发表评论。
网站导航: 博客园   IT新闻   BlogJava   知识库   博问   管理