# coding: utf-8

# A toy demonstration of Bleichenbacher's padding-oracle attack against a
# simplified PKCS#1 v1.5 conformance check.

# In[1]:

import random


# In[2]:

# A random modulus of exactly 1024 bits (128 bytes). Forcing the top bit
# guarantees n > 3B (B is defined below), so every "conforming" value
# 2B <= y < 3B is a valid residue mod n and the hidden message m is below n.
n = random.getrandbits(1024) | (1 << 1023)


# In[3]:

# This is the hidden message we wish to find. 2 << 1008 == 2 * 2**1008, and the
# low 1008 bits are random, so 2B <= m <= 3B - 1: m itself is "properly padded".
m = (2 << 1008) + random.getrandbits(1008)


# In[4]:

B = 2**1008
B2 = 2 * B
B3 = 3 * B
ocalls = 0  # number of padding-oracle queries made so far


# In[5]:

def ceildiv(x, y):
    """Return ceil(x / y) using exact integer arithmetic (no float rounding)."""
    return -(-x // y)


def floordiv(x, y):
    """Return floor(x / y) using exact integer arithmetic."""
    return x // y


def padCheck(y):
    """Padding oracle: return True iff 2B <= y < 3B.

    Every call increments the global query counter ``ocalls``.
    """
    global ocalls
    ocalls += 1
    return B2 <= y < B3

# The usual PKCS#1 v1.5 padding check does more (e.g. there must be at least
# 8 padding bytes). This is a simpler version, which also increases the
# likelihood of hitting a properly padded value.


# ## Finding an s
#
# `findS(smin)` tries s = smin+1, smin+2, ... in order and returns the first s
# for which `padCheck(m*s mod n)` accepts.

# In[6]:

def findS(smin):
    """Return the smallest s > smin such that padCheck((m*s) % n) is True."""
    si = smin
    while True:
        si += 1
        if padCheck((m * si) % n):
            return si


# ## Figuring out the range of r's
#
# Once we have found an $s$, and if we are assuming that $a \leq m \leq b$,
# then we know that $2B \leq ms - rn \leq 3B-1$ for some integer $r$.
# This gives the bound
# $$
# \frac{bs - 2B}{n} \geq \frac{ms - 2B}{n} \geq r \geq \frac{ms-3B+1}{n} \geq \frac{as - 3B+1}{n}
# $$
# The function `findrRanges(s,a,b)` returns the lower and upper bound based on
# the above formula.

# In[7]:

def findrRanges(s, a, b):
    """Return (rmin, rmax), the integer range of r allowed by the bound above."""
    rmin = ceildiv(a * s - B3 + 1, n)
    rmax = floordiv(b * s - B2, n)
    return (rmin, rmax)


# ## An optimisation in case of a single interval
#
# In [Bleichenbacher's original paper](http://archiv.infsec.ethz.ch/education/fs08/secsem/bleichenbacher98.pdf),
# he performs an optimisation when there is a single interval $[a, b]$ that
# is not the initial interval $[2B, 3B-1]$. In this case, starting from the
# previous multiplier $s$, choose the smallest integer
# $r \geq \frac{2(bs - 2B)}{n}$ and look for an $s$ in the range
# $$
# \frac{2B + rn}{b} \leq s \leq \frac{3B-1 +rn}{a}.
# $$
# If no $s$ in that range is accepted, increase $r$ by one and try again.
# Use the first accepted $s$ and proceed.
#
# The idea is that this results in a new interval that is about half the size
# of the original interval. You can see the details in
# [Bleichenbacher's paper](http://archiv.infsec.ethz.ch/education/fs08/secsem/bleichenbacher98.pdf).

# In[8]:

def findS_opt(s, a, b):
    """Find a conforming multiplier when m is known to lie in one interval [a, b].

    Args:
        s: the previous conforming multiplier.
        a, b: bounds of the single remaining interval (a <= m <= b).

    Returns:
        A multiplier si such that padCheck((si*m) % n) is True.
    """
    r = ceildiv(2 * (b * s - B2), n)
    while True:
        lo = ceildiv(B2 + r * n, b)
        hi = floordiv(B3 - 1 + r * n, a)
        for si in range(lo, hi + 1):
            if padCheck((si * m) % n):
                return si
        r += 1  # no candidate for this r was accepted; try the next one


# # The attack
#
# If we already know that $a \leq m \leq b$ and that $2B \leq ms-rn \leq 3B-1$,
# then we get that
# $$
# \frac{2B + rn}{s} \leq m \leq \frac{3B - 1 + rn}{s},
# $$
# and we can intersect this with the old interval $[a,b]$. The code below
# repeatedly finds an $s$ and updates the intervals, until it ends up with a
# single interval containing exactly one value.

# In[9]:

def Attack():
    """Recover the hidden message ``m`` using only ``padCheck`` as an oracle.

    Resets ``ocalls``, prints the number of oracle calls on success.

    Returns:
        The recovered integer (the single value left in the final interval).

    Raises:
        RuntimeError: if no candidate interval survives an update step.
    """
    global ocalls
    si = ceildiv(n, B3) - 1  # findS tests si+1 first, i.e. starts at ceil(n/3B)
    intervals = {(B2, B3 - 1)}
    ocalls = 0
    while True:
        # First find an si.
        if len(intervals) > 1 or (B2, B3 - 1) in intervals:
            si = findS(si)
        else:
            # Single, already narrowed interval: use the optimised search.
            (a, b), = intervals
            si = findS_opt(si, a, b)

        # Update the intervals.
        narrowed = set()
        for a, b in intervals:
            rmin, rmax = findrRanges(si, a, b)
            for r in range(rmin, rmax + 1):
                newa = max(a, ceildiv(B2 + r * n, si))
                newb = min(b, floordiv(B3 - 1 + r * n, si))
                if newa <= newb:
                    narrowed.add((newa, newb))
        if not narrowed:
            raise RuntimeError("Something went wrong!")
        intervals = narrowed

        if len(intervals) == 1:
            (a, b), = intervals
            if a == b:
                print("Oracle calls:", ocalls)
                return a


# In[10]:

if __name__ == "__main__":
    guess = Attack()
    if m == guess:
        print("Found m: ", m)
    else:
        print("Recovered value does not match m!")