""" 
   Copyright (C) 2001 PimenTech SARL (http://www.pimentech.net)

   This library is free software; you can redistribute it and/or
   modify it under the terms of the GNU Library General Public License as
   published by the Free Software Foundation; either version 2 of the
   License, or (at your option) any later version.

   This library is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
   Library General Public License for more details.

   You should have received a copy of the GNU Library General Public
   License along with this library; see the file COPYING.LIB.  If not,
   write to the Free Software Foundation, Inc., 59 Temple Place - Suite 330,
   Boston, MA 02111-1307, USA.  
"""

import os
import sys

from log import *
from pg import DB
from string import *
from object import *
from sqlcommon import *

class DbHandler(Object):
	"""Thin wrapper around a PyGreSQL ``pg.DB`` connection.

	The connection is opened on construction. Connects, disconnects and
	failed statements are reported through ``self.log``. Every statement
	is echoed to ``self.stderr``, which is a real stream when ``debug`` is
	true and the null device otherwise. Database errors are logged and
	signalled with ``None`` return values rather than raised.
	"""

	# NOTE(review): the ``log`` default is evaluated once, when the class
	# body runs, so every instance created without an explicit ``log``
	# shares that single LOG object.
	def __init__(self, DBNAME, DBUSER, DBPWD, DBHOST = 'localhost', DBPORT = '5432', log = LOG('DbHandler.log'), name = 'DbHandler', debug = None):
		"""Store the connection parameters and open the connection.

		DBPORT may be given as a string or an int; it goes through int()
		when connecting. When ``debug`` is true, statements are echoed to
		``stderr``; otherwise they are written to the null device.
		"""
		Object.__init__(self, name)
		if debug:
			# presumably sys.stderr re-exported by one of the star imports -- confirm
			self.stderr = stderr
		else:
			# os.devnull is the portable spelling of '/dev/null'
			self.stderr = open(os.devnull, 'w')
		self.log = log

		self.DBNAME = DBNAME
		self.DBUSER = DBUSER
		self.DBPWD = DBPWD
		self.DBHOST = DBHOST
		self.DBPORT = DBPORT
		self.db = None
		self.open()

	def open(self):
		"""(Re)connect to the database, closing any current connection first.

		On success the session datestyle is set to European/SQL and the
		connection is logged. On failure the error is logged, not raised.
		"""
		if self.db:
			self.close()

		try:
			self.db = DB(self.DBNAME, self.DBHOST, int(self.DBPORT), '', '', self.DBUSER, self.DBPWD)
			self.query("set datestyle to 'European' ; set datestyle to 'SQL'")
			self.log.message("connect database %s@%s with user %s" % (self.DBNAME, self.DBHOST, self.DBUSER))
		except Exception:
			self.log.error('cannot connect %s' % (sys.exc_info()[1],))

	def close(self):
		"""Close the current connection, if any, and log the disconnect.

		If closing fails, the error is logged and ``self.db`` stays set.
		"""
		if not self.db:
			return

		try:
			self.db.close()
			self.db = None
			self.log.message("disconnect from database %s@%s with user %s" % (self.DBNAME, self.DBHOST, self.DBUSER))
		except Exception:
			self.log.error('cannot close %s' % (sys.exc_info()[1],))

	def query(self, src):
		"""Run the SQL statement ``src`` and return its result.

		Returns the value produced by ``self.db.query``. When that value is
		falsy, 1 is returned instead, so callers can test success with a
		plain truth check. This is presumably the case for statements such
		as UPDATE that yield no result object; confirm against the PyGreSQL
		version in use. Any error is logged as a warning and None is
		returned.
		"""
		try:
			self.stderr.write("trying : %s\n" % src)
			result = self.db.query(src)
		except Exception:
			# Use %s formatting: concatenating the exception object onto a
			# str raised TypeError here and masked the real failure.
			self.log.warning('cannot %s :%s' % (src, sys.exc_info()[1]))
			return None
		if not result:  # e.g. an update statement
			result = 1
		return result

	def get_uid(self, sequenceName = 'object_uid_seq'):
		"""Return the next value of sequence ``sequenceName``, or None.

		``sequenceName`` is interpolated directly into the SQL text, so it
		must be a trusted identifier.
		"""
		result = self.query("select nextval('%s'::text)" % sequenceName)
		try:
			return result.getresult()[0][0]
		except Exception:
			# query() handed back None (error) or 1 (no result object),
			# e.g. because the enclosing transaction has already failed.
			self.log.warning('unable to get new uid from sequence %s' % sequenceName)
			return None

	def uid_insert(self, src, sequenceName = 'object_uid_seq', uidName = 'uid'):
		"""Run the INSERT statement ``src`` with a freshly allocated uid.

		The statement must have the basic form
		``insert into t (a, b) values (1, 2)``: an explicit column list
		followed by a value list. It is rewritten to
		``insert into t (uid, a, b) values ('<uid>', 1, 2)`` by splicing at
		the first two '(' characters.

		Returns the new uid on success. Returns None when ``src`` is not an
		INSERT, its shape is not understood, no uid could be allocated, or
		the query failed.
		"""
		if not self.is_insert_query(src):
			self.log.warning('not an insert query givin up in uid_insert')
			return None

		parts = src.split('(')
		# Three pieces are needed: the prefix, the column list and the value
		# list. With fewer (e.g. no column list) the splice would produce
		# broken SQL. Check before consuming a sequence value.
		if len(parts) < 3:
			self.log.warning('don t know how to split query '+src)
			return None

		uid = self.get_uid(sequenceName)
		if uid is None:  # a uid of 0 is still a valid value
			self.log.warning('cannot '+src)
			return None

		src = "%s(%s,%s('%s',%s" % (parts[0], uidName, parts[1], uid, '('.join(parts[2:]))

		if self.query(src):
			return uid
		return None

	def is_insert_query(self, src):
		"""Return 1 if the first word of ``src`` is INSERT (any case), else 0.

		Empty or non-string input yields 0.
		"""
		try:
			words = src.split()
			if words and words[0].upper() == 'INSERT':
				return 1
		except (AttributeError, TypeError):
			return 0
		return 0
