import sys, csv, pprint, time
import cx_Oracle
from processing import Pool

# Field delimiter used when writing the .dat dump files and reported as
# 'fieldseparator' by DB.metadata(): the single ASCII SOH control character.
SEPARATOR = chr(0x01)

class DB(object):
    """Thin wrapper around one cx_Oracle connection and a shared cursor.

    Offers helpers that query the Oracle data dictionary (``all_tables``,
    ``dba_tab_partitions``, ``dba_tab_subpartitions``) and describe a table's
    columns. Every helper reuses ``self.cursor``.
    """

    def __init__(self, connection_string):
        """Connect with ``cx_Oracle.connect(connection_string)`` and open a cursor."""
        self.connection = cx_Oracle.connect(connection_string)
        self.cursor = self.connection.cursor()

    def close(self):
        """Close the shared cursor, then the connection."""
        self.cursor.close()
        self.connection.close()

    def list_tables(self, filter=None):
        """Return ``(table_name, owner)`` rows from ``all_tables``.

        :param filter: optional substring. Only tables whose name matches
            ``LIKE '%<filter>%'`` are returned. ``%`` and ``_`` inside
            *filter* still act as LIKE wildcards.
        """
        sql = 'select table_name, owner from all_tables'
        if filter:
            # Bind the pattern rather than pasting it into the SQL text, so a
            # quote in *filter* cannot break or inject into the statement.
            self.cursor.execute(sql + ' where table_name like :tab_pattern',
                                {'tab_pattern': '%' + filter + '%'})
        else:
            self.cursor.execute(sql)
        return self.cursor.fetchall()

    def schema(self, tablename):
        """Return ``cursor.description`` for the columns of *tablename*.

        The query uses the always-false predicate ``1=2``. Only column
        metadata is produced; no rows match.
        """
        # Identifiers cannot be bind variables, so the table name is spliced
        # in verbatim. Callers must pass a trusted name.
        self.cursor.execute('select * from %s where 1=2' % (tablename, ))
        return self.cursor.description

    def show_schema(self, tablename):
        """Pretty-print ``schema(tablename)`` to stdout."""
        pp = pprint.PrettyPrinter()
        pp.pprint(self.schema(tablename))

    def partitions(self, tablename, owner=None):
        """Return partition rows of *tablename* from ``dba_tab_partitions``.

        Each row is ``(PARTITION_NAME, COMPOSITE, SUBPARTITION_COUNT,
        NUM_ROWS, HIGH_VALUE)``.

        :param owner: when given, also filter on ``table_owner``.
        """
        sql = ("select PARTITION_NAME, COMPOSITE, SUBPARTITION_COUNT, NUM_ROWS, HIGH_VALUE"
               " from dba_tab_partitions where table_name = :tab_name")
        binds = {'tab_name': tablename}
        if owner:
            sql += " and table_owner = :tab_owner"
            binds['tab_owner'] = owner
        self.cursor.execute(sql, binds)
        return self.cursor.fetchall()

    def subpartitions(self, tablename, partition=None, owner=None):
        """Return ``(SUBPARTITION_NAME, PARTITION_NAME)`` rows for *tablename*.

        Rows come from ``dba_tab_subpartitions``.

        :param partition: when given, only subpartitions of this partition.
        :param owner: when given, also filter on ``table_owner``.
        """
        sql = ("select SUBPARTITION_NAME, PARTITION_NAME from dba_tab_subpartitions"
               " where table_name = :tab_name")
        binds = {'tab_name': tablename}
        if owner:
            sql += " and table_owner = :tab_owner"
            binds['tab_owner'] = owner
        if partition:
            sql += " and PARTITION_NAME = :part_name"
            binds['part_name'] = partition
        self.cursor.execute(sql, binds)
        return self.cursor.fetchall()

    def metadata(self, tablename, partition=None, owner=None):
        """Describe *tablename* as a dict with keys 'fields', 'types', 'meta'.

        - ``fields``: column names in table order (``description[i][0]``).
        - ``types``: each column's DB-API type code (``description[i][1]``).
        - ``meta``: ``{'partitions': number of rows from partitions(),
          'compression': 'none', 'fieldseparator': SEPARATOR}``.

        :param partition: accepted for API compatibility; currently unused.
        :param owner: forwarded to ``partitions()``.
        """
        description = self.schema(tablename)
        meta = {
            'partitions': len(self.partitions(tablename, owner)),
            'compression': 'none',
            'fieldseparator': SEPARATOR,
        }
        return {
            'fields': [column[0] for column in description],
            'types': [column[1] for column in description],
            'meta': meta,
        }

def dump_part(connection_string, table, fields, subpart=None, outdir='out'):
    """Dump *table* (or one of its subpartitions) to a SEPARATOR-delimited file.

    Opens a dedicated DB connection so the call can run in a pool worker.
    It selects *fields* from *table*, restricted to ``subpartition(subpart)``
    when *subpart* is given, and writes every row to
    ``<outdir>/<table>_<fname>.dat``.

    :param connection_string: passed to ``DB``.
    :param table: table name; spliced verbatim into the SQL.
    :param fields: column names to select, in output order.
    :param subpart: optional subpartition name; when falsy, the whole table
        is dumped.
    :param outdir: directory the ``.dat`` file is written into.
    :returns: ``(outfile, table, fname)``. *fname* is *subpart* when given,
        else *table*.
    :raises Exception: a failure while querying or writing is printed as
        ``EXCEPTION: ...`` and then re-raised, so a pool callback is never
        invoked for a dump that did not complete. A partially written file
        may remain on disk.
    """
    sys.stderr.write("Dumping %s subpartition %s, fields: [%s]\n" % (table, subpart, ', '.join(fields)))
    start = time.time()

    # Resolve the output names before any I/O so they are defined on every path.
    fname = subpart or table
    outfile = outdir + "/" + table + "_" + fname + ".dat"

    # Identifiers cannot be bind variables. Table, field and subpartition
    # names are spliced into the SQL verbatim and must come from a trusted source.
    sql = 'select /*+ parallel(%s) */ %s from %s' % (table, ','.join(fields), table)
    if subpart:
        sql += ' subpartition(%s)' % (subpart, )

    dc = DB(connection_string)
    dc.cursor.arraysize = 5000
    try:
        dc.cursor.execute(sql)
        # Python 2's csv module expects a binary-mode file object.
        with open(outfile, "wb") as out:
            # NOTE(review): with QUOTE_NONE and no escapechar, csv raises
            # csv.Error if a value contains the delimiter, the quote character
            # or a line break. Confirm the dumped data cannot contain these.
            writer = csv.writer(out, delimiter=SEPARATOR, quoting=csv.QUOTE_NONE)
            # Iterate the cursor instead of calling fetchall(). Rows are
            # fetched in batches of cursor.arraysize, so the whole
            # (sub)partition is never held in memory at once.
            writer.writerows(dc.cursor)
    except Exception as e:
        print("EXCEPTION: %s" % (e, ))
        raise
    finally:
        dc.cursor.close()
        dc.connection.close()

    end = time.time()
    sys.stderr.write("Dumped %s %s in %.0f seconds\n" % (table, subpart, end - start))
    return (outfile, table, fname)

class DBDumper(object):
    """Dump one Oracle table in parallel, with one pool task per subpartition.

    Each task runs ``dump_part`` in a ``Pool`` worker and writes its own
    ``.dat`` file into *outdir*.
    """

    def __init__(self, connection_string, table, fields, owner=None, partition=None, outdir='out', concurrency=4):
        """
        :param connection_string: used for the lookup connection and by every task.
        :param table: table to dump.
        :param fields: column names to select.
        :param owner: optional owner used to filter the subpartition lookup.
        :param partition: optional partition; only its subpartitions are dumped.
        :param outdir: output directory for the ``.dat`` files.
        :param concurrency: number of worker processes in the pool.
        """
        self.connection_string = connection_string
        self.table = table
        self.fields = fields
        self.owner = owner
        self.partition = partition
        self.outdir = outdir
        self.concurrency = concurrency

    def dump(self, callback=None):
        """Dump every subpartition of the table, or the whole table if it has none.

        Blocks until all tasks finish. A failed task is reported on stderr
        rather than raised, so one bad subpartition does not stop the rest.

        :param callback: forwarded to ``apply_async``; it receives the return
            value of each task that completes successfully.
        """
        d = DB(self.connection_string)
        try:
            subpartitions = d.subpartitions(self.table, self.partition, self.owner)
        finally:
            # The lookup connection is not needed afterwards; each dump_part
            # task opens its own connection.
            d.cursor.close()
            d.connection.close()

        partnames = [r[0] for r in subpartitions]
        print("Dumping %s, %d subpartitions." % (self.table, len(partnames)))

        # With no subpartitions, dump the whole table as a single task (subpart=None).
        targets = partnames or [None]

        pool = Pool(processes=self.concurrency)
        pending = []
        for part in targets:
            result = pool.apply_async(dump_part,
                                      args=(self.connection_string, self.table, self.fields, part, self.outdir),
                                      callback=callback)
            pending.append((part, result))
        pool.close()
        pool.join()

        # apply_async stores a worker's exception only in its result object.
        # Retrieve it so failures are not lost silently.
        for part, result in pending:
            try:
                result.get()
            except Exception as e:
                sys.stderr.write("Failed to dump %s subpartition %s: %s\n" % (self.table, part, e))
        print("DONE")
