#!/usr/bin/python
# -*- coding: utf-8 -*-

# compute statistics about the quality of the geometric segmentation
# at the level of the given OCR element

import sys,os,codecs,string,re,getopt
import Image,ImageDraw
import xml
from BeautifulSoup import BeautifulSoup
from pylab import array,zeros,reshape

################################################################
### library
################################################################

### general utility functions

def assoc(key,list):
    # Association-list lookup: `list` is a sequence of (key, value) pairs.
    # Yields the value paired with the first entry whose key equals `key`,
    # or None when no entry matches.
    hits = (value for name,value in list if name==key)
    return next(hits,None)

### XML node processing

def get_prop(node,name):
    # Look up one property in the node's 'title' attribute.
    #
    # The title is treated as a ';'-separated list of properties, each of
    # the form "<key> <args...>".  Returns the argument string of the first
    # property whose key equals `name`, '' if that property carries no
    # arguments, and None if the property (or the title itself) is absent.
    try:
        title = node['title']
    except KeyError:
        # node has no title attribute at all
        return None
    for prop in title.split(';'):
        parts = prop.split(None,1)
        # empty segments (e.g. from a trailing ';') carry no property
        if not parts: continue
        if parts[0]==name:
            # a bare key without arguments must not abort the lookup
            if len(parts)>1: return parts[1]
            return ''
    return None
def get_bbox(node):
    # Bounding box of a node as a tuple of ints, taken from the 'bbox'
    # property of its title; None when the property is missing or empty.
    coords = get_prop(node,'bbox')
    if coords:
        return tuple(int(c) for c in coords.split())
    return None
def get_text(node):
    # Text content of a node with every run of whitespace collapsed to a
    # single space.
    s = node.string
    if s is None:
        # .string is None when the element is not a single text leaf
        # (e.g. a line that wraps its words in nested elements); gather
        # all text nodes below it instead of failing in re.sub
        s = ''.join(node.findAll(text=True))
    return re.sub(r'\s+',' ',s)

# rectangle properties

def intersect(u,v):
    # Intersection of two rectangles given as (x0,y0,x1,y1).  For disjoint
    # inputs the result has x1<x0 and/or y1<y0; it is returned as-is.
    x0,y0 = max(u[0],v[0]),max(u[1],v[1])
    x1,y1 = min(u[2],v[2]),min(u[3],v[3])
    return (x0,y0,x1,y1)
def width(u):
    # Horizontal extent of a rectangle (x0,y0,x1,y1); inverted or empty
    # rectangles count as width 0.
    extent = u[2]-u[0]
    if extent>0: return extent
    return 0
def height(u):
    # Vertical extent of a rectangle (x0,y0,x1,y1); inverted or empty
    # rectangles count as height 0.
    extent = u[3]-u[1]
    if extent>0: return extent
    return 0
def area(u):
    # Area of a rectangle (x0,y0,x1,y1); a negative extent in either
    # direction is clamped to 0, so inverted rectangles have area 0.
    w = max(0,u[2]-u[0])
    h = max(0,u[3]-u[1])
    return w*h
def overlaps(u,v):
    # True when the two rectangles share a region of non-zero area
    # (rectangles that merely touch along an edge do not overlap).
    common = intersect(u,v)
    return area(common)>0
def relative_overlap(u,v):
    # Area of the intersection of u and v divided by the area of the larger
    # of the two rectangles, as a float in [0,1].
    m = max(area(u),area(v))
    # two degenerate (zero-area) rectangles cannot overlap; avoid 0/0
    if m==0: return 0.0
    i = area(intersect(u,v))
    return float(i)/m
def erode(u,tx,ty):
    # Shrink rectangle u=(x0,y0,x1,y1) on every side: each vertical edge
    # moves inward by 2*tx+1 and each horizontal edge by 2*ty+1.
    dx,dy = 2*tx+1,2*ty+1
    return (u[0]+dx,u[1]+dy,u[2]-dx,u[3]-dy)

### text comparison

# one or more characters that are neither ASCII letters/digits nor one of . , ! ? : ;
simp_re = re.compile(r'[^a-zA-Z0-9.,!?:;]+')

def normalize(s):
    # Simplify text before comparison: every run of characters not kept by
    # simp_re becomes a single space, then surrounding whitespace is removed.
    return simp_re.sub(' ',s).strip()

### edit distance

def edit_distance(a,b,threshold=99999):
    # Levenshtein distance between sequences a and b (unit cost for
    # insertion, deletion and substitution), returned as an int.
    #
    # threshold allows giving up early: as soon as every entry of a row of
    # the dynamic-programming table is >= threshold, the final distance can
    # no longer drop below threshold, and the row minimum (a lower bound on
    # the true distance, itself >= threshold) is returned instead.
    if a==b: return 0
    n = len(b)
    # only the previous row is needed to compute the current one
    previous = list(range(n+1))
    for i in range(1,len(a)+1):
        current = [i]
        ca = a[i-1]
        for j in range(1,n+1):
            if ca==b[j-1]:
                cij = 0
            else:
                cij = 1
            current.append(min(
                previous[j] + 1,       # deletion
                current[j-1] + 1,      # insertion
                previous[j-1] + cij    # match / substitution
            ))
        # a single large cell says nothing about the result; only stop
        # when no alignment through this row can stay below the threshold
        lowest = min(current)
        if lowest>=threshold: return lowest
        previous = current
    return previous[n]

#def remove_tex(text):
#    text_file = os.popen("echo %s | detex " %(text))
#    text_plain = text_file.read()
#    text_file.close()
#    return text_plain

def remove_tex(text):
    # Placeholder for TeX stripping: no markup is removed at present, the
    # input is handed back unchanged.
    text_plain = text
    return text_plain

################################################################
### main program
################################################################

### argument parsing

if len(sys.argv)<3:
    print "usage: %s hocr-true.html hocr-actual.html"%sys.argv[0]
    sys.exit(0)

optlist,args = getopt.getopt(sys.argv[1:],"dve:o:i:")
debug = (assoc('-d',optlist)=='')
verbose = (assoc('-v',optlist)=='')
element = assoc('-e',optlist) or 'ocr_line'
significant_overlap = assoc('-o',optlist) or 0.1
significant_overlap = float(significant_overlap)

imgfile =  assoc('-i',optlist)
if(imgfile):
    im = Image.open(imgfile)
    print im.size, im.format, im.mode
    draw=ImageDraw.Draw(im)

# get pages from inputs; both files are decoded as UTF-8 and closed even
# if reading fails
gtfile = codecs.open(args[0],encoding='utf-8')
try:
    truth_doc = gtfile.read()
finally:
    gtfile.close()
ocrfile = codecs.open(args[1],encoding='utf-8')
try:
    actual_doc = ocrfile.read()
finally:
    ocrfile.close()

# parse pages using beautiful soup
gtsoup = BeautifulSoup(truth_doc,convertEntities='html')
truth_pages = gtsoup('div','ocr_page')
ocrsoup = BeautifulSoup(actual_doc,convertEntities='html')
actual_pages = ocrsoup('div','ocr_page')

# zip ground-truth and ocr result pages; the check is explicit rather than
# an assert because asserts are stripped under -O, after which zip() would
# silently drop the surplus pages
if len(truth_pages) != len(actual_pages):
    sys.stderr.write("page count mismatch: %d ground-truth pages, %d ocr pages\n"
                     % (len(truth_pages),len(actual_pages)))
    sys.exit(1)
pages = zip(truth_pages,actual_pages)

# error counters accumulated over all pages
segmentation_errors = 0
segmentation_ocr_errors = 0
ocr_errors = 0

# relative and absolute thresholds in vertical and horizontal direction:
# HTOL/VTOL are percentages, HPIX/VPIX pixel counts; the tolerance used for
# a line is the smaller of the pixel count and (100-TOL) percent of the
# line's width resp. height
HTOL=90
VTOL=80
HPIX=5
VPIX=5


used = {}

for truth,actual in pages:
    true_lines = truth('span','ocr_line')
    actual_lines = actual('span','ocr_line')
    tx=[min(HPIX,(100-HTOL)*width(get_bbox(line))/100) for line in true_lines]
    ty=[min(VPIX,(100-VTOL)*height(get_bbox(line))/100) for line in true_lines]
    for index,true_line in enumerate(true_lines):
        bbox = get_bbox(true_line)
        bbox_small = erode(bbox,tx[index],ty[index])
        candidates = [(area(intersect(get_bbox(line),bbox)),get_bbox(line),get_text(line)) for line in actual_lines]
        q = 0
        tight_overlap = False
        if candidates!=[]:
            q,actual_bbox,actual_line = max(candidates)
            actual_bbox_small = erode(actual_bbox,tx[index],ty[index])
            if(area(intersect(actual_bbox_small , bbox)) == area(actual_bbox_small) and
               area(intersect(actual_bbox , bbox_small)) == area(bbox_small)):
                tight_overlap = True

        if (tight_overlap==0) :
            if verbose:
                print "segmentation_error: area_overlap =",q*1.0/area(bbox),"true_bbox",bbox
                print "\t",get_text(true_line)
            segmentation_errors += 1
            
            if candidates!=[]:
                true_text = remove_tex(get_text(true_line))
                segmentation_ocr_errors += edit_distance(normalize(true_text),normalize(actual_line))
            else:
                segmentation_ocr_errors += len(get_text(true_line))

            if(imgfile):
                draw.rectangle(bbox,outline="#ff0000")
                if candidates!=[]:
                    draw.rectangle(actual_bbox,outline="#0000ff")
            continue
        true_text = remove_tex(get_text(true_line))
        actual_text = actual_line
        if debug:
            print "overlap",q,"true_bbox",bbox
            print "\t",true_text
            print "\t",actual_text
        error = edit_distance(normalize(true_text),normalize(actual_text))
        if verbose and error>0:
            print "ocr_error",error,"true_bbox",bbox
            print "\t",true_text
            print "\t",actual_text
        ocr_errors += error

print "segmentation_errors",segmentation_errors
print "segmentation_ocr_errors",segmentation_ocr_errors
print "ocr_errors",ocr_errors
if(imgfile):
    im.save("errors.png")
    im.show("errors.png")
