Index: trunk/ippToPsps/jython/batch.py
===================================================================
--- trunk/ippToPsps/jython/batch.py	(revision 31322)
+++ trunk/ippToPsps/jython/batch.py	(revision 31348)
@@ -12,4 +12,5 @@
 
 from datastore import Datastore
+from scratchdb import ScratchDb
 from gpc1db import Gpc1Db
 from ipptopspsdb import IppToPspsDb
@@ -23,6 +24,4 @@
 '''
 class Batch(object):
-
-    driverName="com.mysql.jdbc.Driver"
 
     '''
@@ -45,34 +44,21 @@
         self.survey = survey
 
+        # TODO
+        self.tablesToExport = []
+
         # open config
         doc = ElementTree(file="config.xml")
 
-        # set up JDBC connection to local Db
-        dbName = doc.find("localdatabase/name").text
-        dbHost = doc.find("localdatabase/host").text
-        dbUser = doc.find("localdatabase/user").text
-        dbPass = doc.find("localdatabase/password").text
-        self.localUrl = "jdbc:mysql://"+dbHost+"/"+dbName+"?user="+dbUser+"&password="+dbPass
-        self.localCon = DriverManager.getConnection(self.localUrl)
-        self.localStmt = self.localCon.createStatement()
-
         # create Gpc1Db object
         self.gpc1Db = Gpc1Db(self.logger)
+        self.ippToPspsDb = IppToPspsDb(logger)
+        self.scratchDb = ScratchDb(logger)
 
         if self.survey != "":
-
-            # get survey ID from init table
-            sql = "SELECT surveyID FROM Survey WHERE name = '" + self.survey + "'"
-            try:
-                rs = self.localStmt.executeQuery(sql)  
-                rs.first()
-                self.surveyID = rs.getInt(1)
-            except:
-                self.logger.exception("No survey ID found for this survey: '" + self.survey + "'")
-                self.surveyID = -1;
+            self.surveyID = self.scratchDb.getSurveyID(self.survey)
    
             # get dvo info from config
-            dvoName = doc.find("dvo_"+survey+"/name").text
-            self.dvoLocation = doc.find("dvo_"+survey+"/location").text
+            dvoName = doc.find("dvo_" + self.survey + "/name").text
+            self.dvoLocation = doc.find("dvo_" + self.survey + "/location").text
         else:
             dvoName = ""
@@ -83,9 +69,8 @@
         self.datastore = Datastore(self.logger)
 
-        # create IppToPspsDb object and create a new batch
-        self.ippToPspsDb = IppToPspsDb(logger)
-        self.batchID = self.ippToPspsDb.createNewBatch(66, 
-                survey, 
+        # create a new batch
+        self.batchID = self.ippToPspsDb.createNewBatch( 
                 self.getPspsBatchType(), 
+                survey,
                 dvoName, 
                 self.datastore.product)
@@ -97,5 +82,4 @@
         if not os.path.exists(self.localOutPath): os.makedirs(self.localOutPath)
 
-
         # store today's date
         now = datetime.datetime.now();
@@ -109,24 +93,5 @@
 
         # create DVO table
-        self.createDvoTables()
-
-
-    '''
-    Gets photcode (aka photoCalID from dvo table)
-    '''
-    def getPhotoCalID(self):
-
-        photcode = -1
-
-        sql = "SELECT photcode FROM dvoMeta"
-        try:
-            rs = self.localStmt.executeQuery(sql)  
-            rs.first()
-            photcode = rs.getInt(1)
-        except:
-            self.logger.exception("Unable to get photcode from dvo table")
-
-
-        return photcode
+        self.scratchDb.createDvoTables()
 
     '''
@@ -136,26 +101,27 @@
 
         self.logger.debug("Batch destructor")
-        self.localStmt.close()
-        self.localCon.close()
+
+
+    '''
+    Returns the value from this dictinary or else NULL
+    '''
+    def safeDictionaryAccess(self, header, key):
+
+         if key in header: return header[key]
+         else: return "NULL"
 
     '''
     Finds and reads a header extension
     '''
-    def findAndReadFITSHeader(self, name):
-
-        self.logger.info("Searching for header extension: '" + name + "'...")
-
-        file = open(self.inputFitsPath, 'r')
-
-        index = 0
+    def findAndReadFITSHeader(self, name, file):
+
         found = False
+        
         while True:
-            
-            file.seek(index, 0)
+           
+            index = file.tell()
 
             record = file.read(80)
             if not record: break;
-
-            header = {}
 
             # quit when we reach 'END'
@@ -165,12 +131,12 @@
                 if header['EXTNAME'] == name:
                     found = True
-                    break 
+                    file.seek(index + 2880, 0)
+                    break
+
+            file.seek(index + 2880, 0)
             
-            index = index + 2880
-
         if found != True: self.logger.error("...could not find extension '" + name + "'")
         else: self.logger.info("...read header at '" + name + "' and found " + str(len(header)) + " header cards") 
 
-        # TODO close file?
         return header
 
@@ -257,6 +223,6 @@
     def publishToDatastore(self):
 
-        self.datastore.publish(self.batchName, self.subDir, self.tarballFile, "tgz")
-        # TODO update ippToPsps Db here
+        if self.datastore.publish(self.batchName, self.subDir, self.tarballFile, "tgz"):
+            self.ippToPspsDb.updateLoadedToDatastore(self.batchID, 1)
 
     '''
@@ -264,4 +230,6 @@
     '''
     def getBatchFriendlySurveyType(self):
+
+        return "SCR" # TODO
 
         try:
@@ -286,48 +254,26 @@
         else: self.logger.error("Don't know this batch type: " + self.survey)
 
-
-    '''
-    Prints a log message with the current time
-    '''
-    def log(self, msg):
-
-        print datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + " | " + msg
-
-    '''
-    Sets min and max obj ID using the provided table
-    '''
-    def setMinMaxObjID(self, table):
-
-        sql = "SELECT MIN(objID), MAX(objID) FROM " + table
-        print 
-        rs = self.localStmt.executeQuery(sql)
-        rs.first()
-        self.minObjID = rs.getLong(1)
-        self.maxObjID = rs.getLong(2)
-
-    '''
-    Drops a table
-    '''
-    def dropTable(self, table):
-
-        sql = "DROP TABLE " + table
-        try: self.localStmt.execute(sql)
-        except: return
-
-    '''
-    Updates a table with surveyID
-    '''
-    def updateSurveyID(self, table):
-
-        sql = "UPDATE " + table + "  SET surveyID=%d" % self.surveyID
-        self.localStmt.execute(sql)
-
-    '''
-    Updates a table with filterID grabbed from Filter init table
-    '''
-    def updateFilterID(self, table):
-
-        sql = "UPDATE "+table+" AS a, Filter AS b SET a.filterID=b.filterID WHERE b.filterType = '" + self.filter + "'"
-        self.localStmt.execute(sql)
+    '''
+    Sets min and max obj ID using the provided table, or list of tables
+    '''
+    def setMinMaxObjID(self, tables):
+
+        first = True
+        for table in tables:
+
+            sql = "SELECT MIN(objID), MAX(objID) FROM " + table
+            rs = self.scratchDb.stmt.executeQuery(sql)
+            rs.first()
+
+            if first:
+                self.minObjID = rs.getLong(1)
+                self.maxObjID = rs.getLong(2)
+            else:
+                if rs.getLong(1) < self.minObjID: self.minObjID = rs.getLong(1)
+                if rs.getLong(2) > self.maxObjID: self.maxObjID = rs.getLong(2)
+
+            first = False
+
+        self.ippToPspsDb.updateMinMaxObjID(self.batchID, self.minObjID, self.maxObjID)
 
     '''
@@ -351,4 +297,5 @@
                param = match.group(2)
                value = match.group(3).strip()
+               if value == "NaN": value = "NULL"
                header[param] = value
 
@@ -365,5 +312,6 @@
          for table in self.pspsTables:
              self.logger.info("Creating PSPS table: " + table.name)
-             table.write(self.localUrl + '#' + table.name)
+             table.write(self.scratchDb.url + '#' + table.name)
+             self.tablesToExport.append(table.name)
 
          self.indexPspsTables()
@@ -375,19 +323,4 @@
         self.logger.warn("indexIppTables not implemented")
 
-
-    '''
-    Adds an index to the supplied table and column
-    '''
-    def createIndex(self, table, column):
-
-        self.logger.info("Creating index on column '"+column+"' for table '"+table+"'")
-
-        sql = "CREATE INDEX "+table+"_index ON "+table+" ("+column+")"
-        try:
-            self.localStmt.execute(sql)
-        except: pass
-            #self.logger.warn("Index already in place on '" + column + "' for table '" + table + "'")
-
-
     '''
     Subclass should implement this to index PSPS tables
@@ -420,8 +353,7 @@
 
           try:
-              table.write(self.localUrl + '#' + table.name)
+              table.write(self.scratchDb.url + '#' + table.name)
           except:
               self.logger.exception("   Problem writing table '" + table.name + "' to the database")
-
           count = count + 1
 
@@ -439,13 +371,13 @@
 
         self.logger.info("    Selecting database tables")
-        for table in self.pspsTables:
+        for table in self.tablesToExport:
 
            # check for an empty table
-           if self.getRowCount(table.name) < 1: continue
+           if self.scratchDb.getRowCount(table) < 1: continue
 
            # get everything from table
-           _table = stilts.tread(self.localUrl + '#SELECT * FROM ' + table.name)
-
-           self.logger.info("   Replacing NULLs with weird PSPS -999 constant for " + table.name)
+           _table = stilts.tread(self.scratchDb.url + '#SELECT * FROM ' + table)
+
+           self.logger.info("   Replacing NULLs with weird PSPS -999 constant for " + table)
 
            # replace nulls and empty fields with weird PSPS -999 pseudo-null
@@ -453,5 +385,5 @@
 
            # change table names
-           _table = stilts.tpipe(_table, cmd='tablename ' + table.name)
+           _table = stilts.tpipe(_table, cmd='tablename ' + table)
            _tables.append(_table)
 
@@ -459,65 +391,5 @@
         stilts.twrites(_tables, self.outputFitsPath, fmt='fits')
         self.logger.info("    ...done")
-
-
-    '''
-    Returns a list of column names for this table
-    '''
-    def getColumnNames(self, tableName):
-
-       sql = "SHOW COLUMNS FROM " + tableName
-       rs = self.localStmt.executeQuery(sql)
-       columns = []
-       while (rs.next()): columns.append(rs.getString(1))
-       rs.close()
-       
-       return columns
-
-    '''
-    Replaces all NULL values in the provided table with the prvoded substitute 
-    '''
-    def replaceNulls(self, tableName, sub):
-
-       # get list of columns
-       columns = self.getColumnNames(tableName)
-
-       # now loop through all columns and replace all NULLs with sub
-       for column in columns:
-          
-          sql = "UPDATE " + tableName + " SET " + column + " = " + sub + " WHERE " + column + " IS NULL"
-          self.localStmt.execute(sql)
-
-
-    '''
-    Searches a table and reports the columns that are either partially or completely populated with NULLs
-    '''
-    def reportNulls(self, tableName, showPartials):
-
-       # first, count rows
-       sql = "SELECT COUNT(*) FROM " + tableName
-       rs = self.localStmt.executeQuery(sql)
-       rs.first()
-       numRows = rs.getInt(1)
-
-       # get list of columns
-       columns = self.getColumnNames(tableName)
-
-       print "+----------------------+---------------+"
-       print "|  %25s           |" % tableName
-       print "+----------------------+---------------+"
-
-       # now see which columns are full of NULLS, with are partially NULL
-       for column in columns:
-          
-          sql = "SELECT COUNT(*) FROM " + tableName + " WHERE " + column + " IS NULL"
-          rs = self.localStmt.executeQuery(sql)
-          rs.first()
-          if rs.getInt(1) == numRows:
-              print "| %20s | all NULL      |" % column
-          elif showPartials and rs.getInt(1) > 0:
-              print "| %20s | partial NULL  |" % column
-       rs.close()
-       print "+----------------------+---------------+"
-
+        self.ippToPspsDb.updateProcessed(self.batchID, 1)
 
     '''
@@ -527,5 +399,5 @@
 
         for table in self.pspsTables:
-            self.reportNulls(table.name, showPartials)
+            self.scratchDb.reportNulls(table.name, showPartials)
 
     '''
@@ -536,5 +408,5 @@
         self.logger.info("Replacing all NULL values in PSPS tables with '" + sub + "'...")
         for table in self.pspsTables:
-            self.replaceNulls(table.name, sub)
+            self.scratchDb.replaceNulls(table.name, sub)
         self.logger.info("...done")
 
@@ -546,66 +418,10 @@
 
     '''
-    Updates provided table with DVO IDs from DVO table
-    '''
-    def updateDvoIDs(self, table):
-
-        self.logger.info("Not implemented in base-class")
-
-    '''
-    Creates a table for for ID matching
-    '''
-    def createDvoTables(self):
-
-        self.logger.info("Creating DVO meta and detection tables")
-        sql = "DROP TABLE dvoMeta"
-        try: self.localStmt.execute(sql)
-        except: pass
-        
-        sql = "DROP TABLE dvoDetection"
-        try: self.localStmt.execute(sql)
-        except: pass
-
-        sql = "CREATE TABLE dvoMeta ( \
-               flags INT, \
-               photcode INT \
-               )"
-
-        try: self.localStmt.execute(sql)
-        except: 
-            self.logger.error("Unable to create DVO meta-data database tablei")
-
-        sql = "CREATE TABLE dvoDetection ( \
-               ippDetectID BIGINT PRIMARY KEY, \
-               detectID BIGINT, \
-               ippObjID BIGINT, \
-               objID BIGINT \
-               )"
-
-        try: self.localStmt.execute(sql)
-        except: 
-            self.logger.error("Unable to create DVO detection database table")
-
-    '''
-    Returns a row count for this table
-    '''
-    def getRowCount(self, table):
-
-        sql = "SELECT COUNT(*) FROM " + table
-        try:
-            rs = self.localStmt.executeQuery(sql)  
-            rs.first()
-            return rs.getInt(1)
-        except:
-            self.logger.exception("Could not count rows for table: '" + table + "'")
-            return -1
-
-
-    '''
     Calls DVO program to 'query' DVO database and populate results to local MySQL Db table
     '''
-    def getIDsFromDVO(self, sourceID, imageID):
+    def getIDsFromDVO(self):
 
         # TODO path to DVO prog hardcoded temporarily
-        cmd = "../src/dvo %s %s %s" % (self.dvoLocation, sourceID, imageID)
+        cmd = "../src/dvograbber " + self.dvoLocation
         self.logger.info("Running: '" + cmd + "'...")
         p = Popen(cmd, shell=True, stdout=PIPE)
@@ -614,5 +430,5 @@
         self.logger.info("...done")
 
-        if self.getRowCount("dvoDetection") < 1:
+        if self.scratchDb.getRowCount("dvoDetection") < 1:
             self.logger.error("No DVO IDs found")
             return False
