import random
# Modulus: a random integer that is exactly 1024 bits long. Setting the top
# bit guarantees n >= 2**1023 > 3*B, so every value in [2B, 3B) is a valid
# residue modulo n (plain getrandbits(1024) may return a shorter number).
n = random.getrandbits(1024) | (1 << 1023)

# Size parameter of the simplified padding: a value y is "properly padded"
# exactly when 2B <= y < 3B (see padCheck).
B = 2**1008
B2 = 2*B
B3 = 3*B

# This is the hidden message we wish to find: 2B plus 1008 random bits,
# i.e. a uniformly random value in [2B, 3B).
m = B2 + random.getrandbits(1008)

# Number of padding-oracle queries made so far (incremented by padCheck).
ocalls = 0
def ceildiv(x, y):
    """Return ceil(x / y) using exact integer arithmetic.

    The floor quotient is bumped by one whenever the division leaves a
    non-zero remainder.
    """
    quotient, remainder = divmod(x, y)
    if remainder:
        return quotient + 1
    return quotient
def floordiv(x, y):
    """Return floor(x / y) using exact integer arithmetic."""
    quotient, _ = divmod(x, y)
    return quotient
def padCheck(y):
    """Padding oracle: report whether y lies in the half-open range [B2, B3).

    Every call is counted in the module-level ``ocalls`` so the attack can
    report how many oracle queries it needed.
    """
    global ocalls
    ocalls += 1
    return B2 <= y < B3
#The usual PKCS#1 padding does more checks (there must be at least 8 padding bytes, etc.)
#This is a simpler version, which also increases the likelihood of hitting a properly padded string
The function `findS(smin)` starts just above `smin` (the first candidate tried is `smin + 1`) and keeps trying successive values of `s` until it finds one for which the oracle accepts, i.e. `padCheck(m*s mod n)` returns `True`.
def findS(smin):
    """Return the smallest s > smin for which the oracle accepts (m * s) % n.

    The search is a linear scan: smin + 1, smin + 2, ... Each candidate
    costs one padCheck call (and therefore one tick of ``ocalls``). The
    function only reads the module-level ``m`` and ``n``; it does not
    terminate if no such s exists.
    """
    si = smin + 1
    while not padCheck((m * si) % n):
        si += 1
    return si
Once we have found an $s$, and if we are assuming that $a \leq m \leq b$, then we know that $2B \leq ms - rn \leq 3B-1$ for some positive integer $r$. This gives the bound
$$
\frac{bs - 2B}{n} \geq \frac{ms - 2B}{n} \geq r \geq \frac{ms-3B+1}{n} \geq \frac{as - 3B+1}{n}
$$
The function `findrRanges(s, a, b)` returns the lower and upper bounds on $r$ given by the above formula, rounded inward to integers.
def findrRanges(s, a, b):
    """Return (rmin, rmax), the integer bounds on r for multiplier s.

    Given a <= m <= b and 2B <= m*s - r*n <= 3B - 1, r must satisfy
    (a*s - 3B + 1) / n <= r <= (b*s - 2B) / n. The lower bound is rounded
    up and the upper bound rounded down.
    """
    # ceil(x / n) written as -floor(-x / n), with x = a*s - B3 + 1.
    lower = -((B3 - 1 - a * s) // n)
    upper = (b * s - B2) // n
    return lower, upper
In Bleichenbacher's original paper, he performs an optimisation in the setting where we have a single interval that is not the initial interval $[2B, 3B-1]$. In this case, you choose an integer $r \geq \frac{2(bs - 2B)}{n}$, and look for an $s$ in the range $$ \frac{2B + rn}{b} \leq s \leq \frac{3B-1 +rn}{a}. $$ Use this $s$ and proceed.
The idea is that this results in a new interval that's no more than half the original interval. You can see the details in Bleichenbacher's paper.
def findS_opt(s, a, b):
    """Find the next multiplier when m is known to lie in one interval [a, b].

    Starting from the smallest integer r with r >= 2*(b*s - 2B)/n, scan
    every s' with (2B + r*n)/b <= s' <= (3B - 1 + r*n)/a and return the
    first one the oracle accepts; if none is accepted, move on to r + 1.

    Parameters:
        s: the previously used multiplier.
        a, b: inclusive bounds of the single remaining interval for m.

    Returns:
        A multiplier s' such that padCheck((s' * m) % n) is True.

    Each candidate costs one padCheck call. The function only reads the
    module-level ``m``, ``n``, ``B2`` and ``B3``.
    """
    # The bound is r >= 2*(b*s - 2B)/n, so round UP to the first valid r.
    r = ceildiv(2 * (b * s - B2), n)
    while True:
        lowest = ceildiv(B2 + r * n, b)
        highest = floordiv(B3 - 1 + r * n, a)
        for si in range(lowest, highest + 1):
            if padCheck((si * m) % n):
                return si
        # No candidate for this r conformed (or the range was empty).
        r += 1
If we already knew that $a \leq m \leq b$ and that $2B \leq ms-rn \leq 3B-1$, then we get that $$ \frac{2B + rn}{s} \leq m \leq \frac{3B + rn -1}{s}, $$ and we can intersect this with the old interval $[a,b]$. The code below repeatedly finds an $s$ and updates the set of intervals, until it ends up with a single interval containing exactly one value ($a = b$), which is the message.
def Attack():
    """Recover the hidden message m using only the padCheck oracle.

    The set of candidate intervals for m starts as {(2B, 3B - 1)}. Each
    round picks a new multiplier si (findS while the set holds the initial
    interval or more than one interval, findS_opt once a single narrowed
    interval remains), then intersects every interval with the ranges
    implied by 2B <= m*si - r*n <= 3B - 1 for each feasible r.

    Returns:
        The single value left when the set collapses to one interval (a, a).

    Raises:
        SystemExit: with code -1 if a round leaves no candidate interval.

    Side effects:
        Resets the module-level ``ocalls`` to 0 at the start and prints the
        number of oracle calls on success.
    """
    global ocalls
    ocalls = 0

    initial = (B2, B3 - 1)
    intervals = {initial}
    # findS increments before testing, so the first candidate is ceil(n / 3B).
    si = ceildiv(n, B3) - 1

    while True:
        # First find an si
        if len(intervals) > 1 or initial in intervals:
            si = findS(si)
        else:
            # Use the optimised search in the single-interval special case.
            a, b = next(iter(intervals))
            si = findS_opt(si, a, b)

        # Update the intervals
        narrowed = set()
        for a, b in intervals:
            rmin, rmax = findrRanges(si, a, b)
            for r in range(rmin, rmax + 1):
                newa = max(a, ceildiv(B2 + r * n, si))
                newb = min(b, floordiv(B3 - 1 + r * n, si))
                if newa <= newb:
                    narrowed.add((newa, newb))

        if not narrowed:
            print("Something went wrong!")
            # Same effect as exit(-1), without relying on the site-provided
            # interactive helper.
            raise SystemExit(-1)
        intervals = narrowed

        if len(intervals) == 1:
            a, b = next(iter(intervals))
            if a == b:
                print("Oracle calls:", ocalls)
                return a
# Run the attack and time it with the standard library. This replaces the
# IPython-only `%%time` cell magic, which is a syntax error in plain Python.
import time

_wall_start = time.perf_counter()
_cpu_start = time.process_time()
guess = Attack()
print("CPU time: %.3f s" % (time.process_time() - _cpu_start))
print("Wall time: %.3f s" % (time.perf_counter() - _wall_start))

# Check the recovered value against the hidden message.
if m == guess:
    print("Found m: ", m)
else:
    print("Attack returned a value different from m")