#!/usr/bin/python

# This file is part of vdr-webvideo-plugin.
#
# Copyright 2009 Antti Ajanki <antti.ajanki@iki.fi>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import socket
import cStringIO
import sys
import cmd
import mimetypes
import urllib
import select
import os.path
import subprocess
import time
import re
import libxml2
from optparse import OptionParser
from ConfigParser import RawConfigParser

version = '0.1.5'

# Built-in defaults. DEFAULT_CONFIG is copied and then updated from the
# config files; DEFAULT_SERVER/DEFAULT_PORT are the command-line defaults.
DEFAULT_SERVER = '127.0.0.1'
DEFAULT_PORT = 2357
# Player command templates tried in order; "%s" receives the stream URL.
DEFAULT_PLAYERS = [
    'vlc --play-and-exit "%s"',
    'totem "%s"',
    'mplayer "%s"',
    'xine "%s"',
]
DEFAULT_CONFIG = dict(daemonpath='/usr/bin/webvid',
                      daemonargs='-d -l/tmp/webvid.log')

mimetypes.init()
# Flash video types are common but often absent from the system MIME
# tables, so register both spellings for the .flv extension.
mimetypes.add_type('video/flv', '.flv')
mimetypes.add_type('video/x-flv', '.flv')

def getContentUnicode(node):
    # node.getContent() returns an UTF-8 encoded sequence of bytes (a
    # string). Convert it to a unicode object.
    return unicode(node.getContent(), 'UTF-8')

class WVClient:
    """Client for the WVTP protocol spoken by the webvid daemon.

    Keeps one persistent TCP connection to the daemon (self.sock), a
    read-ahead buffer of bytes that were received but not yet consumed
    (self.buffer) and a browsing history of Menu objects.
    """
    DEFAULT_URL_PRIORITY = 100

    def __init__(self, addr, port, daemonpath, daemonargs, streamplayers):
        """Create a client; no connection is opened until the first request.

        addr, port -- address of the daemon.
        daemonpath, daemonargs -- executable and argument list used by
            connect() to start the daemon when a connection to
            localhost is refused.
        streamplayers -- player command templates tried in order by
            play_stream().
        """
        self.addr = addr
        self.port = port
        self.daemonpath = daemonpath
        self.daemonargs = daemonargs
        self.streamplayers = streamplayers
        self.buffer = ''
        self.history = []
        self.history_pointer = 0
        self.sock = None
        # Progress indicator state; reset by write_body().
        self.download_start = 0
        self.progress_len = 0
        self.progress_samples = []

    def _drop_socket(self):
        """Close the connection (ignoring errors) and forget it.

        The next request then opens a fresh connection. Buffered,
        unconsumed data belongs to the dropped connection and is
        discarded as well.
        """
        if self.sock is not None:
            try:
                self.sock.close()
            except socket.error:
                pass
        self.sock = None
        self.buffer = ''

    def read_headers(self):
        """Read the status line and headers of a response.

        Returns (responsecode, responsetext, headers), where headers
        maps lower-cased header names to stripped values. responsecode
        is -1 if no parseable status line was received. Bytes received
        after the blank line that ends the headers are left in
        self.buffer for read()/write_body().
        """
        expectresponseline = True
        responsecode = -1
        responsetext = ''
        headers = {}
        buf = ''
        while True:
            try:
                # FIXME: non-blocking mode + select() is used while
                # waiting for the end of the headers. The socket is
                # switched back to blocking mode afterwards, also when
                # an error interrupts the read.
                self.sock.setblocking(0)
                try:
                    received = ''
                    while received.find('\r\n\r\n') == -1:
                        select.select([self.sock.fileno()], [], [])
                        chunk = self.sock.recv(4096)
                        if len(chunk) == 0:
                            # The peer closed the connection.
                            break
                        received += chunk
                finally:
                    self.sock.setblocking(1)
            except socket.error, e:
                print 'error while reading headers',  e
                break
            if len(received) == 0:
                break

            buf = buf + received
            lf = buf.find('\r\n')
            while (lf != -1) and (lf != 0):
                line = buf[:lf]
                buf = buf[lf+2:]
                if expectresponseline:
                    # The first line contains the response code
                    try:
                        protover, code, responsetext = line.split(' ', 2)
                        responsecode = int(code)
                    except ValueError:
                        pass
                    expectresponseline = False
                else:
                    # Split only on the first colon: header values
                    # (e.g. URLs) may contain colons themselves.
                    try:
                        name, value = line.split(':', 1)
                        headers[name.strip().lower()] = value.strip()
                    except ValueError:
                        pass
                lf = buf.find('\r\n')
            if lf == 0:
                # headers end with an empty line
                buf = buf[2:]
                break
        self.buffer = buf
        return (responsecode, responsetext, headers)

    def write_body(self, dest, contentlen, progress_stream=None):
        """Read contentlen bytes of response body and write them to dest.

        dest is a file-like object. If progress_stream is given, a
        progress line is redrawn on it when more than a second has
        passed since the previous update, and once more at the end. If
        the connection closes early the socket is dropped so that the
        next request reconnects. Returns the number of bytes actually
        written.
        """
        last_update = time.time()
        self.download_start = last_update
        self.progress_len = 0
        self.progress_samples = [(0, last_update)]

        remaining = contentlen
        while remaining > 0:
            chunk = self.read(min(remaining, 4096))
            if not chunk:
                # The connection was closed; reconnect on next request.
                self._drop_socket()
                break
            dest.write(chunk)
            remaining -= len(chunk)

            if progress_stream is not None:
                now = time.time()
                if now - last_update > 1:
                    self.update_progress_indicator(remaining, contentlen,
                                                   now, progress_stream)
                    last_update = now

        if progress_stream is not None:
            self.update_progress_indicator(remaining, contentlen,
                                           time.time(), progress_stream)
            progress_stream.write('\n')
        return contentlen - remaining

    def read(self, bytes):
        """Return up to `bytes` bytes of response data.

        Data already in self.buffer is returned first without touching
        the socket; otherwise a single recv() is performed. Returns ''
        when there is no connection, the peer closed it, or recv()
        failed (the error is printed).
        """
        if len(self.buffer) > 0:
            ret = self.buffer[:bytes]
            self.buffer = self.buffer[bytes:]
            return ret
        if self.sock is None:
            return ''
        try:
            return self.sock.recv(bytes)
        except socket.error, e:
            print e
            return ''

    def _try_connect(self):
        """Open a new TCP connection to (self.addr, self.port).

        Returns 0 on success. On failure the reason is printed,
        self.sock is left as None (so the next request tries again)
        and the error code from the socket exception is returned (-1
        if the exception carried none).
        """
        self._drop_socket()
        sock = socket.socket(socket.AF_INET)
        try:
            sock.connect((self.addr, self.port))
        except socket.error, e:
            # Do not keep (or leak) the unconnected socket.
            sock.close()
            if len(e.args) >= 2:
                code, msg = e.args[0], e.args[1]
            else:
                code, msg = -1, str(e)
            print 'Error connecting to the server:', msg
            return code
        self.sock = sock
        return 0

    def connect(self):
        """Connect to the daemon, starting it first if necessary.

        If the connection is refused and the daemon address is local,
        daemonpath is run with daemonargs and the connection is retried
        once after half a second. Returns True when connected.
        """
        import errno
        self.buffer = ''
        err = self._try_connect()
        if err == 0:
            return True
        if err != errno.ECONNREFUSED:
            return False
        # Maybe the server is not running? Try connecting again after
        # starting the server. The server can only be started when
        # connecting to localhost.
        if self.addr not in ('127.0.0.1', 'localhost'):
            return False
        print 'Trying to start the server'
        daemoncmd = [self.daemonpath] + list(self.daemonargs)
        try:
            subprocess.call(daemoncmd)
        except OSError, e:
            print 'Failed to start the server: ' + (e.strerror or str(e))
            return False
        time.sleep(0.5)
        return self._try_connect() == 0

    def close(self):
        """Send a CLOSE request (best effort) and close the connection."""
        if self.sock is not None:
            try:
                self.sock.sendall('CLOSE X WVTP/1.0\r\n\r\n')
            except socket.error:
                # ignore errors while closing
                pass
        self._drop_socket()

    def parse_page(self, page):
        """Parse a <wvmenu> XML document into a Menu.

        Returns None if page is None, is not well-formed XML, or its
        root element is not <wvmenu>.
        """
        if page is None:
            return None
        try:
            doc = libxml2.parseDoc(page)
        except libxml2.parserError:
            return None

        try:
            root = doc.getRootElement()
            if (root is None) or (root.name != 'wvmenu'):
                return None
            # The same list object is handed to every button, so each
            # button submits all textfields and itemlists of the page.
            queryitems = []
            menu = Menu()
            node = root.children
            while node:
                if node.name == 'title':
                    menu.title = getContentUnicode(node)
                elif node.name == 'link':
                    menu.add(self.parse_link(node))
                elif node.name == 'textfield':
                    mi = self.parse_textfield(node)
                    menu.add(mi)
                    queryitems.append(mi)
                elif node.name == 'itemlist':
                    mi = self.parse_itemlist(node)
                    menu.add(mi)
                    queryitems.append(mi)
                elif node.name == 'textarea':
                    menu.add(self.parse_textarea(node))
                elif node.name == 'button':
                    menu.add(self.parse_button(node, queryitems))
                node = node.next
            return menu
        finally:
            # Free the libxml2 document on every path, including the
            # early returns above.
            doc.freeDoc()

    def parse_link(self, node):
        """Build a MenuItemLink from a <link> node (label, ref, object)."""
        label = ''
        ref = None
        obj = None
        child = node.children
        while child:
            if child.name == 'label':
                label = getContentUnicode(child)
            elif child.name == 'ref':
                ref = getContentUnicode(child)
            elif child.name == 'object':
                obj = getContentUnicode(child)
            child = child.next
        return MenuItemLink(label, ref, obj)

    def parse_textfield(self, node):
        """Build a MenuItemTextField from a <textfield id=...> node."""
        label = ''
        fieldid = node.prop('id')
        child = node.children
        while child:
            if child.name == 'label':
                label = getContentUnicode(child)
            child = child.next
        return MenuItemTextField(label, fieldid)

    def parse_textarea(self, node):
        """Build a MenuItemTextArea from a <textarea> node."""
        label = ''
        child = node.children
        while child:
            if child.name == 'label':
                label = getContentUnicode(child)
            child = child.next
        return MenuItemTextArea(label)

    def parse_itemlist(self, node):
        """Build a MenuItemList from an <itemlist id=...> node.

        Each <item value=...> child contributes a display name and its
        value attribute (None when the attribute is missing).
        """
        label = ''
        listid = node.prop('id')
        items = []
        values = []
        child = node.children
        while child:
            if child.name == 'label':
                label = getContentUnicode(child)
            elif child.name == 'item':
                items.append(getContentUnicode(child))
                values.append(child.prop('value'))
            child = child.next
        return MenuItemList(label, listid, items, values)

    def parse_button(self, node, queryitems):
        """Build a MenuItemSubmitButton from a <button> node.

        queryitems are the menu items whose values the button submits.
        """
        label = ''
        submission = None
        child = node.children
        while child:
            if child.name == 'label':
                label = getContentUnicode(child)
            elif child.name == 'submission':
                submission = getContentUnicode(child)
            child = child.next
        return MenuItemSubmitButton(label, submission, queryitems)

    def _parse_priority(self, value):
        """Convert a priority attribute string to an int.

        Missing or non-integer values give DEFAULT_URL_PRIORITY, so
        that all priorities compare numerically.
        """
        if value is None:
            return self.DEFAULT_URL_PRIORITY
        try:
            return int(value)
        except ValueError:
            return self.DEFAULT_URL_PRIORITY

    def _sort_urls(self, urls_and_priorities):
        """Return the URLs of (priority, url) pairs, highest priority
        first; URLs with equal priority keep their original order."""
        ordered = sorted(urls_and_priorities, key=lambda pair: pair[0],
                         reverse=True)
        return [url for priority, url in ordered]

    def parse_mediaurl(self, xml):
        """Parse a mediaurl document.

        Returns (title, urls) where urls is a list ordered by
        descending priority (possibly empty), or (None, None) if the
        document can not be parsed. A missing or empty title is
        reported as '???'.
        """
        try:
            doc = libxml2.parseDoc(xml)
        except libxml2.parserError:
            print 'Failed to parse mediaurl'
            return None, None

        title = '???'
        urls_and_priorities = []
        try:
            root = doc.getRootElement()
            if root is None:
                return None, None
            node = root.children
            while node:
                if node.name == 'title':
                    title = getContentUnicode(node)
                elif node.name == 'url':
                    priority = self._parse_priority(node.prop('priority'))
                    urls_and_priorities.append((priority,
                                                getContentUnicode(node)))
                node = node.next
        finally:
            doc.freeDoc()
        if not title:
            title = '???'
        return title, self._sort_urls(urls_and_priorities)

    def get_headers_raw(self, ref, verb='GET'):
        """Send a `verb ref` request and read the response headers.

        Connects first if needed. Returns (status, statustext, headers);
        status is -1 (and headers None) if connecting or sending failed.
        On a send failure or an unparseable response the connection is
        dropped so that the next request reconnects.
        """
        if (self.sock is None) and (not self.connect()):
            return (-1, 'Can\'t connect to server', None)
        request = '%s %s WVTP/1.0\r\n\r\n' % (verb, ref)
        if isinstance(request, unicode):
            # refs parsed from XML are unicode; send them as UTF-8
            # instead of relying on the implicit ASCII conversion.
            request = request.encode('utf-8')
        try:
            self.sock.sendall(request)
        except socket.error, e:
            self._drop_socket()
            return (-1, 'Error in sendall: ' + str(e), None)
        responsecode, responsetext, headers = self.read_headers()
        if responsecode == -1:
            # No valid response (e.g. the daemon closed the connection);
            # the stream state is unknown, so start over next time.
            self._drop_socket()
        return (responsecode, responsetext, headers)

    def _content_length(self, headers):
        """Return the Content-Length header as a non-negative int
        (0 when missing or not a number)."""
        try:
            length = int(headers.get('content-length', 0))
        except (TypeError, ValueError):
            return 0
        return max(length, 0)

    def get_raw(self, ref, verb='GET'):
        """Perform a request and read the whole response body.

        Returns (status, statustext, headers, body); headers and body
        are None unless status is 200. A body shorter than the
        announced Content-Length turns status into -1.
        """
        status, statusphrase, headers = self.get_headers_raw(ref, verb)
        if status != 200:
            # NOTE(review): if the daemon sends a body with non-200
            # responses it is left unread here -- confirm the protocol.
            return (status, statusphrase, None, None)
        body = cStringIO.StringIO()
        contentlen = self._content_length(headers)
        r = self.write_body(body, contentlen)
        if r != contentlen:
            status = -1
            statusphrase = 'Length of body (%d) differs from expected (%d)' % (r, contentlen)
        return (status, statusphrase, headers, body.getvalue())

    def get(self, ref):
        """GET ref and parse it as a menu; returns (status, msg, menu)."""
        status, statusmsg, headers, body = self.get_raw(ref)
        return (status, statusmsg, self.parse_page(body))

    def guess_extension(self, mimetype, url):
        """Guess a file extension (with leading dot) for a media file.

        Uses the MIME type (ignoring parameters such as "; charset=")
        unless it is unknown or text/plain; then the extension of the
        last path component of url is used, or '' if there is none.
        """
        basetype = mimetype.split(';', 1)[0].strip()
        ext = mimetypes.guess_extension(basetype)
        if (ext is None) or (basetype == 'text/plain'):
            # This function is only called for video files. Try to
            # extract the extension from url because text/plain is
            # clearly wrong.
            path = url.split('?', 1)[0].split('#', 1)[0]
            if '://' in path:
                # Skip "scheme://host" so that a dot in the host name
                # is not mistaken for an extension.
                rest = path.split('://', 1)[1]
                if '/' in rest:
                    path = rest.split('/', 1)[1]
                else:
                    path = ''
            # Only the last component: a dot in a directory name must
            # not produce an "extension" containing '/'.
            lastpart = path.rsplit('/', 1)[-1]
            i = lastpart.rfind('.')
            if i == -1:
                ext = ''
            else:
                ext = lastpart[i:]
        return ext

    def download(self, obj):
        """Download media object obj into a new file in the current directory.

        The alternative URLs are tried in priority order; the first
        that answers DOWNLOAD with status 200 is saved under a name
        derived from the title. Returns True if a file was written.
        """
        status, statusmsg, headers, body = self.get_raw(obj)
        if body is None:
            return False
        title, urls = self.parse_mediaurl(body)
        if not urls:
            print 'No URLs!'
            return False

        for url in urls:
            status, statusmsg, headers = self.get_headers_raw(url, 'DOWNLOAD')
            if status != 200:
                continue
            ext = self.guess_extension(headers.get('content-type', ''), url)
            filename = self.next_available_file_name(self.safe_filename(title), ext)
            contentlen = self._content_length(headers)
            f = None
            try:
                try:
                    # Binary mode: the payload is media data.
                    f = open(filename, 'wb')
                    r = self.write_body(f, contentlen, sys.stdout)
                finally:
                    if f is not None:
                        f.close()
            except IOError, e:
                print 'IOError while writing to %s: %s' % (filename, e.strerror or e)
                # The rest of the body is still unread on the socket;
                # drop the connection so the next request starts clean.
                self._drop_socket()
                return False
            print 'Saved to %s' % filename
            if r != contentlen:
                print 'Warning: the size of the file (%d) differs from expected (%d)' % (r, contentlen)
            return True

        print 'No valid url found'
        return False

    def safe_filename(self, name):
        """Sanitize a filename. No paths (replace '/' -> '!') and no
        names starting with a dot. A name that becomes empty is
        replaced by '???'."""
        safe = name.replace('/', '!').lstrip('.')
        if not safe:
            # e.g. a title of '...' would otherwise yield a hidden file
            # named after the extension alone.
            return '???'
        return safe

    def next_available_file_name(self, basename, ext):
        """Return basename+ext, or basename-N+ext with the smallest N >= 1
        that does not exist yet."""
        if not os.path.exists(basename + ext):
            return basename + ext
        i = 1
        while os.path.exists('%s-%d%s' % (basename, i, ext)):
            i += 1
        return '%s-%d%s' % (basename, i, ext)

    def _escape_for_double_quotes(self, text):
        """Backslash-escape the characters that are special inside a
        double-quoted POSIX shell string: \\ " $ and `."""
        escaped = text
        # Backslash first so the escapes added below are not doubled.
        for ch in ('\\', '"', '$', '`'):
            escaped = escaped.replace(ch, '\\' + ch)
        return escaped

    def play_stream(self, obj):
        """Resolve media object obj to a stream URL and play it.

        Returns True once a player from self.streamplayers exits with
        a return code <= 0, False if no URL or no player worked.
        """
        status, statusmsg, headers, body = self.get_raw(obj)
        if body is None:
            return False
        title, urls = self.parse_mediaurl(body)
        if not urls:
            print 'No URLs!'
            return False

        # The mediaurl object contains sorted list of alternative URLs
        # for the same content. Test them to find out the first
        # accessible one. STREAM command returns error if url points
        # to a non-existing file and returns URL to a video file if
        # url is an playlist.
        streamurl = None
        for url in urls:
            status, statusmsg, headers, body = self.get_raw(url, 'STREAM')
            if body is None:
                continue
            title, mediaurls = self.parse_mediaurl(body)
            if not mediaurls:
                # None (parse error) or an empty list of URLs.
                continue
            streamurl = mediaurls[0]
            break

        if streamurl is None:
            return False

        # The URL comes from the network and the command runs through
        # the shell. Escape it for a double-quoted context, which is how
        # the default templates embed "%s"; templates without "%s" get
        # the URL appended in double quotes.
        quotedurl = self._escape_for_double_quotes(streamurl)

        # Found url, now find a working media player
        for pl in self.streamplayers:
            if '%s' not in pl:
                playcmd = pl + ' "' + quotedurl + '"'
            else:
                try:
                    playcmd = pl % quotedurl
                except (TypeError, ValueError):
                    print 'Can\'t substitute URL in', pl
                    continue

            try:
                print 'Trying player: ' + playcmd
                retcode = subprocess.call(playcmd, shell=True)
            except OSError, e:
                print 'Execution failed:', e
                continue
            if retcode > 0:
                print 'Player failed with returncode', retcode
            else:
                return True

        return False

    def get_current_menu(self):
        """Return the menu at the history pointer, or None."""
        if (self.history_pointer >= 0) and \
               (self.history_pointer < len(self.history)):
            return self.history[self.history_pointer]
        else:
            return None

    def history_add(self, menu):
        """Append menu after the current position, discarding any
        forward history. None is ignored."""
        if menu is not None:
            self.history = self.history[:(self.history_pointer+1)]
            self.history.append(menu)
            self.history_pointer = len(self.history)-1

    def history_back(self):
        """Step one menu back (if possible) and return the current menu."""
        if self.history_pointer > 0:
            self.history_pointer -= 1
        return self.get_current_menu()

    def history_forward(self):
        """Step one menu forward (if possible) and return the current menu."""
        if self.history_pointer < len(self.history)-1:
            self.history_pointer += 1
        return self.get_current_menu()

    def update_progress_indicator(self, remaining, totalbytes, now, stream):
        """Redraw the one-line download progress report on stream.

        remaining -- bytes still to read; totalbytes -- expected total;
        now -- current time.time() value. The transfer rate is
        estimated from the last 10 (bytes, time) samples.
        """
        if now <= self.download_start:
            now = self.download_start + 1

        done = totalbytes - remaining
        if totalbytes > 0:
            percentage = float(done) / totalbytes * 100.0
        else:
            # Nothing to transfer: report the download as complete.
            percentage = 100.0

        # Estimate bytes per second rate from the last 10 samples
        self.progress_samples.append((done, now))
        if len(self.progress_samples) > 10:
            self.progress_samples.pop(0)

        done_old, time_old = self.progress_samples[0]
        if now > time_old:
            rate = float(done - done_old) / (now - time_old)
        else:
            rate = 0
        if rate > 0:
            time_left = self.pretty_time(remaining / rate)
        else:
            time_left = '???'

        progress = '%3.f %% of %s (%.1f kB/s) %s remaining' % \
                   (percentage, self.pretty_bytes(totalbytes),
                    rate/1024.0, time_left)

        # Blank out the previous line before writing the new one.
        stream.write('\r')
        stream.write(' '*self.progress_len)
        stream.write('\r')
        stream.write(progress)
        stream.flush()

        self.progress_len = len(progress)

    def pretty_bytes(self, bytes):
        """Format a byte count as B, kB or MB."""
        if bytes < 1100:
            return '%d B' % bytes
        elif bytes < 1024*1024:
            return '%.1f kB' % (float(bytes)/1024)
        else:
            return '%.1f MB' % (float(bytes)/1024/1024)

    def pretty_time(self, sec):
        """Format a duration in seconds (rounded) as 's', 'min s' or
        'hours min'."""
        sec = int(round(sec))
        if sec < 60:
            return '%d s' % sec
        minutes, seconds = divmod(sec, 60)
        if sec < 60*60:
            return '%d min %d s' % (minutes, seconds)
        hours, minutes = divmod(minutes, 60)
        return '%d hours %d min' % (hours, minutes)
        

class Menu:
    def __init__(self):
        self.title = None
        self.items = []

    def __str__(self):
        s = u''
        if self.title:
            s = self.title + '\n' + '='*len(self.title) + '\n'
        for i,it in enumerate(self.items):
            s += u'%d. %s\n' % (i+1, unicode(it))
        return s

    def __getitem__(self, i):
        return self.items[i]

    def __len__(self):
        return len(self.items)

    def add(self, menuitem):
        self.items.append(menuitem)


class MenuItemLink:
    def __init__(self, label, ref, obj):
        self.label = unicode(label)
        self.ref = ref
        self.obj = obj

    def __str__(self):
        res = self.label
        if not self.obj:
            res = '[' + res + ']'
        return res

    def activate(self):
        return self.ref


class MenuItemTextField:
    def __init__(self, label, id):
        self.label = unicode(label)
        self.id = id
        self.value = u''

    def __str__(self):
        return u'%s: %s' % (self.label, self.value)

    def get_query_string(self):
        return '%s=%s' % (self.id, urllib.quote_plus(self.value.encode('utf-8')))

    def activate(self):
        self.value = unicode(raw_input('%s> ' % self.label), sys.stdin.encoding)
        return None


class MenuItemTextArea:
    def __init__(self, label):
        self.label = unicode(label)

    def __str__(self):
        return self.label

    def activate(self):
        return None


class MenuItemList:
    def __init__(self, label, id, items, values):
        self.label = unicode(label)
        self.id = id
        assert len(items) == len(values)
        self.items = items
        self.values = values
        self.current = 0

    def __str__(self):
        itemstrings = [self.label + ':']
        for i,k in enumerate(self.items):
            if i == self.current:
                itemstrings.append('<' + k + '>')
            else:
                itemstrings.append(k)
        return u' '.join(itemstrings)

    def get_query_string(self):
        if (self.current >= 0) and (self.current < len(self.items)) \
               and (self.values[self.current] != ''):
            return '%s=%s' % (self.id, self.values[self.current])
        else:
            return None

    def activate(self):
        tmp = raw_input('Select item (1-%d)> ' % len(self.items))
        try:
            i = int(tmp)
            if (i < 1) or (i > len(self.items)):
                raise ValueError
            self.current = i-1
        except ValueError:
            print 'Must be an integer in the range 1 - %d' % len(self.items)
        return None


class MenuItemSubmitButton:
    def __init__(self, label, baseurl, subitems):
        self.label = unicode(label)
        self.baseurl = baseurl
        self.subitems = subitems

    def __str__(self):
        return '[' + self.label + ']'

    def activate(self):
        substrings = [x.get_query_string() for x in self.subitems if x.get_query_string() is not None]
        return self.baseurl + '?' + '&'.join(substrings)


class WVShell(cmd.Cmd):
    def __init__(self, client, completekey='tab', stdin=None, stdout=None):
        cmd.Cmd.__init__(self, completekey, stdin, stdout)
        self.prompt = '> '
        self.client = client

    def preloop(self):
        self.stdout.write('wvclient %s starting\n' % version)
        self.do_menu(None)

    def precmd(self, arg):
        try:
            int(arg)
            return 'select ' + arg
        except ValueError:
            return arg

    def onecmd(self, c):
        try:
            return cmd.Cmd.onecmd(self, c)
        except Exception:
            import traceback
            print 'Exception occured while handling command "' + c + '"'
            print traceback.format_exc()
            return False

    def display_menu(self, menu):
        if menu is not None:
            self.stdout.write(unicode(menu).encode(self.stdout.encoding, 'replace'))
    
    def _get_numbered_item(self, arg):
        menu = self.client.get_current_menu()
        try:
            v = int(arg)-1
            if (menu is None) or (v < 0) or (v >= len(menu)):
                raise ValueError
        except ValueError:
            self.stdout.write('Invalid selection: %s\n' % arg)
            return None
        return menu[v]
        
    def do_select(self, arg):
        """select x
Select the link whose index is x.
        """
        menuitem = self._get_numbered_item(arg)
        if menuitem is None:
            return False
        ref = menuitem.activate()
        if ref is not None:
            status, statusmsg, menu = self.client.get(ref)
            if menu is not None:
                self.client.history_add(menu)
            else:
                self.stdout.write('Error: %d %s\n' % (status, statusmsg))
        else:
            menu = self.client.get_current_menu()
        self.display_menu(menu)
        return False

    def do_download(self, arg):
        """download x
Download media object whose index is x to a file. Downloadable items
are the ones without brackets.
        """
        menuitem = self._get_numbered_item(arg)
        if menuitem is None:
            return False
        elif hasattr(menuitem, 'obj') and menuitem.obj is not None:
            self.client.download(menuitem.obj)
        else:
            self.stdout.write('Not a media object\n')
        return False

    def do_stream(self, arg):
        """stream x
Play the media file whose index is x. Media objects are the ones
without brackets.
        """
        menuitem = self._get_numbered_item(arg)
        if menuitem is None:
            return False
        elif hasattr(menuitem, 'obj') and menuitem.obj is not None:
            self.client.play_stream(menuitem.obj)
        else:
            self.stdout.write('Not a media object\n')
        return False

    def do_display(self, arg):
        """Redisplay the current menu."""
        if not arg:
            self.display_menu(self.client.get_current_menu())
        else:
            self.stdout.write('Unknown parameter %s\n' % arg)
        return False

    def do_menu(self, arg):
        """Get back to the main menu."""
        status, statusmsg, menu = self.client.get('/mainmenu')
        if menu is not None:
            self.client.history_add(menu)
            self.display_menu(menu)
        else:
            self.stdout.write('Error: %d %s\n' % (status, statusmsg))
        return False

    def do_back(self, arg):
        """Go to the previous menu in the history."""
        menu = self.client.history_back()
        self.display_menu(menu)
        return False

    def do_forward(self, arg):
        """Go to the next menu in the history."""
        menu = self.client.history_forward()
        self.display_menu(menu)
        return False

    def do_quit(self, arg):
        """Quit the program."""
        self.client.close()
        return True

    def do_EOF(self, arg):
        """Quit the program."""
        self.client.close()
        return True


def load_config(options):
    """Load options from config files."""
    cp = RawConfigParser()
    readfiles = cp.read(['/etc/webvi.conf', os.path.expanduser('~/.webvi')])
    if cp.has_section('webvi'):
        for opt, val in cp.items('webvi'):
            options[opt] = val
    return options

def parse_command_line(options):
    """Store the connection settings from the command line in options.

    options['host'] comes from -s/--server (default DEFAULT_SERVER) and
    options['port'] from -p/--port (default DEFAULT_PORT); existing
    values are always replaced. Returns options.
    """
    parser = OptionParser()
    parser.add_option('-s', '--server', type='string', dest='host',
                      help='connect to SERVER', metavar='SERVER',
                      default=DEFAULT_SERVER)
    parser.add_option('-p', '--port', type='int', dest='port',
                      help='connect to PORT', metavar='PORT',
                      default=DEFAULT_PORT)
    # Positional arguments are accepted but not used.
    cmdlineopt = parser.parse_args()[0]
    options.update(host=cmdlineopt.host, port=cmdlineopt.port)
    return options

def player_list(options):
    """Return a sorted list of player commands extracted from options
    dictionary."""
    # Load streamplayer items from the config file and sort them
    # according to priority.
    players = []
    for opt,val in options.iteritems():
        m = re.match(r'streamplayer([1-9])$', opt)
        if m is not None:
            players.append((int(m.group(1)), val))

    players.sort()
    ret = []
    for p,playcmd in players:
        ret.append(playcmd)

    # If the config file did not define any players use the default
    # players
    if not ret:
        ret = list(DEFAULT_PLAYERS)

    return ret

def main():
    """Read configuration and command line, then run the interactive shell."""
    import shlex
    options = load_config(dict(DEFAULT_CONFIG))
    options = parse_command_line(options)

    # shlex.split honours shell-style quoting and ignores repeated
    # spaces; split(' ') would pass empty arguments to the daemon.
    daemonargs = shlex.split(options['daemonargs'])
    client = WVClient(options['host'], options['port'],
                      options['daemonpath'],
                      daemonargs,
                      player_list(options))
    shell = WVShell(client)
    try:
        shell.cmdloop()
    except KeyboardInterrupt:
        # Leave cleanly on Ctrl-C instead of printing a traceback.
        client.close()
        print

if __name__ == '__main__':
    main()
