#!/usr/bin/python

# extract lines from Google 1000 book sample

import glob
import os
import re
import shutil
import string
import subprocess
import sys
from xml.sax import saxexts, saxlib, saxutils

from PIL import Image

# Command-line help text; the leading "..." stands for the program name.
# The third argument is called output_pattern everywhere (it is applied
# with the % operator to a (pageno, lineno) tuple).
usage = """
... hocr image_pattern output_pattern

Process Google 1000 books volumes and prepare line or word images
for alignment using OCRopus.

Run ocroscript align-... Volume_0000/0000/0000.{png,txt}

Arguments:

    hocr: hocr source file

    image_pattern: either a glob pattern that results in a list
        of image files in order, or @filename for a file containing
        a list of image files in order; DON'T FORGET TO QUOTE THIS

    output_pattern: output images are of the form
        output_pattern%(pageno,lineno)

Environment Variables:

    element="ocr_line": which element to extract; ocrx_word and
        ocr_cinfo are also useful

    regex=".": the text for any transcription must match this pattern

    dict=None: a dictionary; if provided, all the words in any line
        that's output by the program must occur in the dictionary

    min_len=20: minimum length of text for which lines are output

    max_len=50: maximum length of text for which lines are output

    max_lines=1000000: maximum number of lines output

    pad=2: pad the bounding box by this many pixels prior to extraction

    output_format=png: format for line image files
"""

# Exactly three arguments are required after the script name.
if len(sys.argv) != 4:
    sys.stderr.write(usage)
    # Keep the diagnostic on stderr, next to the usage text.
    sys.stderr.write("args: %s\n" % (sys.argv,))
    sys.exit(1)

exe, hocr, image_pattern, output_pattern = sys.argv

if image_pattern.startswith("@"):
    # "@filename": the file holds one image path per line.
    list_stream = open(image_pattern[1:])
    try:
        image_list = list_stream.readlines()
    finally:
        list_stream.close()
    # Strip only the line terminator (the last line may not have one) and
    # drop blank lines, which would otherwise become empty image paths.
    image_list = [s.rstrip("\r\n") for s in image_list]
    image_list = [s for s in image_list if s]
else:
    image_list = glob.glob(image_pattern)

# Pages are matched to images by position, and glob.glob() returns names in
# arbitrary order, so both kinds of list are sorted before use.
image_list.sort()

if not os.path.exists(hocr):
    sys.stderr.write(hocr + ": not found\n")
    sys.exit(1)

# Run-time settings come from environment variables; each one falls back
# to the default given here when the variable is not set.
element = os.getenv("element", "ocr_line")
regex = os.getenv("regex", ".")
min_len = int(os.getenv("min_len", "20"))
max_len = int(os.getenv("max_len", "50"))
max_lines = int(os.getenv("max_lines", "1000000"))
pad = int(os.getenv("pad", "2"))
output_format = os.getenv("output_format", "png")
dictfile = os.getenv("dict")

# Optional word list.  When a dictionary file is named, its
# whitespace-separated words become the lower-cased keys of `dict`;
# otherwise `dict` stays None.
dict = None
if dictfile:
    stream = open(dictfile, "r")
    contents = stream.read()
    stream.close()
    words = contents.split()
    dict = {}
    for word in words:
        dict[word.lower()] = 1

def check_dict(dict,s):
    """Tell whether every word of `s` is present in `dict`.

    dict: mapping whose keys are lower-cased words, or None/empty to
        disable the check.
    s: text to test; it is split into words on runs of non-word characters.

    Returns 1 when the check is disabled or every non-empty word, lower-cased,
    maps to a true value in `dict`; returns 0 at the first word that does not.
    """
    if not dict:
        return 1
    for token in re.split(r'\W+', s):
        # Splitting yields empty strings at the edges of the text; skip them.
        if token != "" and not dict.get(token.lower()):
            return 0
    return 1

def write_string(file,text):
    """Write `text` to the file at path `file`, encoded as UTF-8.

    file: path of the file to create or overwrite.
    text: string to write; it is encoded with text.encode("utf-8").
    """
    # The payload is already-encoded bytes, so the file is opened in binary
    # mode (no newline translation), and the handle is closed even when
    # write() raises.
    stream = open(file,"wb")
    try:
        stream.write(text.encode("utf-8"))
    finally:
        stream.close()

def get_prop(title,name):
    """Look up one property in a ';'-separated property string.

    title: string of properties separated by ';', each of the form
        "key value ..." (key and value separated by whitespace).
    name: the property key to look for.

    Returns the text following the key of the first property whose key
    equals `name` ("" when that property has a key but no value), or None
    when no property has that key.
    """
    for prop in title.split(';'):
        fields = prop.split(None,1)
        # Empty segments (empty title, trailing ';') have no key at all;
        # they are skipped instead of failing on tuple unpacking.
        if not fields or fields[0]!=name:
            continue
        if len(fields)==2:
            return fields[1]
        return ""
    return None

class docHandler(saxlib.DocumentHandler):
    """SAX handler that crops matching hOCR elements out of their page images.

    Every element whose class attribute is "ocr_page" advances to the next
    entry of the module-level image_list and opens it.  Every element whose
    class equals `self.element` has its character data collected; when that
    element closes and the text passes the length, regex and dictionary
    filters, the padded bounding box is cropped from the page image and
    written together with ".txt" and ".bbox" files under
    output_pattern%(pageno,lineno).
    """

    def __init__(self):
        self.element = element
        self.regex = regex

    def startDocument(self):
        self.total = 0          # number of items written so far
        self.pageno = -1        # index of the current page in image_list
        self.lineno = -1        # index of the current element on the page
        self.page = None        # path of the current page image
        self.image = None       # opened image of the current page
        self.bbox = None        # "x0 y0 x1 y1" of the element being read
        self.text = None        # collected text, None outside an element
        self.depth = 0          # current element nesting depth
        self.start = -1         # depth at which the current element opened
        self.copied = {}        # not read anywhere in this class

    def endDocument(self):
        pass

    def startElement(self,name,attrs):
        self.depth += 1
        cls = attrs.get("class","")
        if cls=="ocr_page":
            self.lineno = -1
            self.pageno += 1
            if self.pageno>=len(image_list):
                raise IndexError("hOCR page %d has no image: only %d image file(s) given"%
                                 (self.pageno,len(image_list)))
            self.page = image_list[self.pageno]
            self.image = Image.open(self.page)
        if cls==self.element:
            self.lineno += 1
            props = attrs.get("title","")
            self.bbox = get_prop(props,"bbox")
            self.start = self.depth
            self.text = u""

    def endElement(self,name):
        if self.depth==self.start:
            if self._accepts(self.text):
                self._write_item()
                self.total += 1
                if self.total>=max_lines: sys.exit(0)
            self.text = None
            self.start = -1
        self.depth -= 1

    def characters(self,str,start,end):
        # SAX1 passes (data, start, length): the third argument is a length,
        # not an end index, so the slice must be taken relative to `start`.
        if self.text is not None:
            self.text += str[start:start+end]

    def _accepts(self,text):
        """Return a true value when `text` passes all output filters."""
        return (len(text)>=min_len and
                len(text)<=max_len and
                re.match(self.regex,text) and
                check_dict(dict,text))

    def _write_item(self):
        """Crop the current element and write its image, text and bbox files."""
        if self.image is None:
            raise ValueError("%s element found outside of any ocr_page"%self.element)
        if self.bbox is None:
            raise ValueError("%s element on %s has no bbox property"%
                             (self.element,self.page))
        # Same line the original print statement produced: page, bbox, text.
        sys.stdout.write(" ".join([str(self.page),str(self.bbox),
                                   self.text.encode("utf-8")])+"\n")
        w,h = self.image.size
        x0,y0,x1,y1 = [int(s) for s in self.bbox.split()]
        # Explicit check instead of assert, which is stripped under -O.
        if not (y0<y1 and x0<x1 and x1<=w and y1<=h):
            raise ValueError("bad bbox %r for %s (image size %dx%d)"%
                             (self.bbox,self.page,w,h))
        x0 = max(0,x0-pad)
        y0 = max(0,y0-pad)
        x1 = min(w,x1+pad)
        y1 = min(h,y1+pad)
        limage = self.image.crop((x0,y0,x1,y1))
        base = output_pattern%(self.pageno,self.lineno)
        basedir = os.path.dirname(base)
        # The pattern may have no directory part (basedir is "") or several
        # missing levels, so create the whole chain and only when needed.
        if basedir and not os.path.isdir(basedir):
            os.makedirs(basedir)
        limage.save(base+"."+output_format)
        write_string(base+".txt",self.text)
        write_string(base+".bbox",self.bbox)

# Convert the hOCR file to XHTML with "tidy" and feed the result to the SAX
# parser.  tidy is started without a shell, so a path containing spaces or
# shell metacharacters is passed through unchanged; its error output still
# goes to /tmp/tidy_errs.
parser = saxexts.make_parser()
parser.setDocumentHandler(docHandler())
tidy_input = open(hocr,"rb")
tidy_errors = open("/tmp/tidy_errs","wb")
try:
    tidy = subprocess.Popen(["tidy","-q","-wrap","9999","-asxhtml"],
                            stdin=tidy_input,
                            stdout=subprocess.PIPE,
                            stderr=tidy_errors)
    stream = tidy.stdout
    try:
        parser.parseFile(stream)
    finally:
        # Also runs when the handler stops early through sys.exit():
        # close the pipe first so tidy cannot block on a full pipe,
        # then reap the child process.
        stream.close()
        tidy.wait()
finally:
    tidy_input.close()
    tidy_errors.close()
