from __future__ import generators
import ned_commandLine as commandline
import ned_disorderHelper as disorderHelper
import ned_basic as basic

import os, urllib, sys,copy,operator,math,random
from xml.dom import minidom

class pfamHelper():
	def __init__(self,pfam_directory=""):
		"""Load the utility settings and prepare an empty helper.

		The options are read with CommandLine.loadIniFile from
		"../settings/utilities.ini", resolved relative to the real location
		of this source file.

		pfam_directory -- when not the empty string, replaces the "pfam_dir"
		                  option that was loaded from the ini file.
		"""
		settings_file = os.path.join(os.path.dirname(os.path.realpath(__file__)),"../settings/utilities.ini")
		self.options = commandline.CommandLine().loadIniFile(settings_file)

		# No disorder scorer yet; calculateDomainDisorder() assigns one.
		self.disScorer = None

		if pfam_directory != "":
			self.options["pfam_dir"] = pfam_directory
		
		
	def parsePfamXML(self,id,dataType="domain",refetch=False):
		"""Parse the locally cached Pfam protein XML of one accession.

		The document is read from <pfam_dir>/pfam_xml/<id>.xml.  Before it is
		parsed it is (re)downloaded through getPfamAnnotation when the file is
		missing, when basic.fileAgeDays reports more than 120 days, or when
		refetch is true.  The parsed result replaces self.data.

		id       -- accession used for the file name and the download.
		dataType -- "domain" returns self.data["domains"], "sequence" returns
		            self.data["Sequence"], any other value returns self.data.
		refetch  -- force a fresh download before parsing.

		Layout of self.data after the call:
		"domains"  -- {match id: {"data": match attributes,
		              "hits": [attributes of each <location>, ...]}}
		"Sequence" -- data of the child node of the first <sequence> element
		              (the last child wins if there are several).
		The attributes of every <entry> element are copied to the top level.

		Exceptions raised while walking the document are printed and
		re-raised.  dataType="sequence" raises KeyError when the document
		holds no <sequence> element.
		"""
		self.data = {"domains":{}}
		file_path = os.path.join(self.options["pfam_dir"] , "pfam_xml/" , id + ".xml")

		# Refresh the cache BEFORE parsing, so that a forced or age-triggered
		# download is the document that is actually read below.
		if refetch or not os.path.exists(file_path) or basic.fileAgeDays(file_path) > 120:
			self.getPfamAnnotation(id,force="T")

		xmldoc = minidom.parse(file_path)
		xmldoc.normalize()

		try:
			matches = xmldoc.getElementsByTagName('matches')
			if len(matches) > 0:
				for node in matches[0].childNodes:
					# Skip whitespace/text nodes and attribute-less matches.
					if node.nodeName != "match" or not node.hasAttributes():
						continue

					domainData = {}
					for key in node.attributes.keys():
						attribute = node.attributes[key]
						domainData[str(attribute.name)] = str(attribute.value)

					domainName = domainData["id"]
					if domainName == "":
						continue

					# The first match seen for a domain id provides its "data".
					if domainName not in self.data["domains"]:
						self.data["domains"][domainName] = {"data":domainData,"hits":[]}

					for locationNode in node.getElementsByTagName('location'):
						# One fresh dict per <location>: a shared dict would make
						# every hit of a multi-location match alias the last one.
						hitData = {}
						for key in locationNode.attributes.keys():
							attribute = locationNode.attributes[key]
							hitData[str(attribute.name)] = str(attribute.value)

						self.data["domains"][domainName]["hits"].append(hitData)
		except Exception as e:
			print("E %s" % (e,))
			raise

		try:
			for node in xmldoc.getElementsByTagName('entry'):
				for key in node.attributes.keys():
					attribute = node.attributes[key]
					self.data[str(attribute.name)] = str(attribute.value)

			sequences = xmldoc.getElementsByTagName('sequence')
			if len(sequences) > 0:
				for node in sequences[0].childNodes:
					self.data["Sequence"] = str(node.data)
		except Exception as e:
			print("E %s" % (e,))
			raise

		if dataType == "domain":
			return self.data["domains"]
		elif dataType == "sequence":
			return self.data["Sequence"]
		else:
			return self.data
			
			
	def calculateDomainDisorder(self,disorder=None):
		self.disScorer = disorder
		
		if disorder == None:
			self.disScorer = disorderHelper.disorderScorer()
			self.disScorer.disorderFromSequence(self.data["Sequence"])
		
		for domain in self.data["domains"]:
			for i in range(0,len(self.data["domains"][domain]["hits"])):
				
				disorder = sum(self.disScorer.data['ResidueDisorder'][int(self.data["domains"][domain]["hits"][i]["start"]):int(self.data["domains"][domain]["hits"][i]["end"])])/(int(self.data["domains"][domain]["hits"][i]["end"]) - int(self.data["domains"][domain]["hits"][i]["start"]))
				self.data["domains"][domain]["hits"][i]["disorder"] = disorder
		
	
				
	def parsePfamDomainXML(self,domain):
		self.getPfamDomainAnnotation(domain)
		
		xmldoc = minidom.parse(self.options["pfam_dir"] + "pfam_xml_domain/" + domain + '.xml')
		xmldoc.normalize()
		
		self.data = {}
		
		try:
			if len(xmldoc.getElementsByTagName('entry')) > 0:
				self.data['id'] = xmldoc.getElementsByTagName('entry')[0].getAttribute('id')
		except:
			print "Error",e
			raise
			
		try:
			domainDict = {}
			if len(xmldoc.getElementsByTagName('comment')) > 0:#[0].toxml()
				for node in xmldoc.getElementsByTagName('comment')[0].childNodes[1:2]:
				
					self.data['comment'] = str(node.data.strip())
					return str(node.data.strip())
				
		except Exception,e:
			print "Error",e
			raise
			
	def getPfamAnnotation(self,id,force="F"):
		dir_path = os.path.join(self.options["pfam_dir"] , "pfam_xml/")
		file_path = os.path.join(dir_path  , id + ".xml") #,os.path.join(self.options["pfam_dir"],"/Pfam_xml/",id + ".xml")
		
		if not os.path.exists(dir_path):
			os.mkdir(dir_path)
		
			
		if id + ".xml" in os.listdir(dir_path) and force == "F":
			pass
		else:
			url = "http://pfam.sanger.ac.uk/protein?entry=" + id + "&output=xml"
				
			opener = urllib.FancyURLopener()
			f = opener.open(url)
					
			open(file_path,"w").write(f.read())
			print id + " Pfam annotation downloaded"
	
	def getPfamDomainAnnotation(self,domain):
		if os.path.exists(self.options["pfam_dir"] + "pfam_xml_domain/"):
			pass
		else:
			os.mkdir(self.options["pfam_dir"] + "pfam_xml_domain/")
			
		if domain + ".xml" in os.listdir(self.options["pfam_dir"] + "pfam_xml_domain/"):
			pass
		else:
			url = "http://pfam.sanger.ac.uk/family?entry=" + domain + "&output=xml"
				
			opener = urllib.FancyURLopener()
			f = opener.open(url)
							
			open(self.options["pfam_dir"] + "pfam_xml_domain/" + domain + ".xml","w").write(f.read())
			print domain + " Pfam annotation downloaded"
	
	def tablify(self,output="string"):
		domainTablified = {"domain":[],"start":[],"end":[],"disorder":[]}
		
		for domain in self.data["domains"]:
			for i in range(0,len(self.data["domains"][domain]["hits"])):
				domainTablified["domain"].append(domain)
				for key in domainTablified:
					if key != "domain":
						if output == "string":
							value = str(self.data["domains"][domain]["hits"][i][key])
						else:
							value = self.data["domains"][domain]["hits"][i][key]
							
						domainTablified[key].append(value)
		
		for key in domainTablified:
			domainTablified[key] = ",".join(domainTablified[key])
			
		self.data["tablified"] = domainTablified
	
		
	def clusterMatrix(self,matrix,clustered=[]):
		maxTmp = [0,"",""]
		
		for protein1 in matrix["order"]:
			
			maxSim = max(matrix[protein1])
			indexProtein1 = matrix["order"].index(protein1)
			indexMax = matrix[protein1].index(max(matrix[protein1]))
			
			#print maxTmp,indexProtein1,indexMax,maxSim
			
			if maxSim > maxTmp[0] and indexProtein1 !=  indexMax:
				maxTmp = [max(matrix[protein1]),protein1, matrix["order"][matrix[protein1].index(max(matrix[protein1]))]]
		
		if maxTmp[0] != 0:
			
			newCluster = [min(matrix[maxTmp[1]][x],matrix[maxTmp[2]][x]) for x in range(0,len(matrix[maxTmp[1]]))]
			
			removeList = [matrix["order"].index(maxTmp[1]),matrix["order"].index(maxTmp[2])]
			removeList.sort()
			removeList.reverse()
			
			for remove in removeList:
				for protein1 in matrix["order"]:
					del matrix[protein1][remove]
				
				del newCluster[remove]
				del matrix["order"][remove]
				
			
			matrix["order"].append(",".join(maxTmp[1:]))
			matrix[",".join(maxTmp[1:])] = newCluster + [0]
			
			
			for i in range(0,len(newCluster)):
				matrix[matrix["order"][i]].append(newCluster[i])
			
			self.clusterMatrix(matrix)
			
		return matrix
		
		
	def clusterDomains(self,domainData):
		
		ordered = domainData.keys()
		ordered.sort()
		matrix = {"order":ordered}
		
		for protein1 in ordered:
			matrix[protein1] = []
			
			for protein2 in ordered:
				
				testTmp = domainData[protein1].keys() + domainData[protein2].keys()
				
				if protein2 == protein1:
					sim = 0
				else:
					try:
						sim = float(len(filter(lambda x: testTmp.count(x) > 1 , testTmp)))/((len(domainData[protein1].keys()) + len(domainData[protein2].keys())))
					except:
						sim = 0
					
				#print domainData[protein1].keys() ,domainData[protein2].keys(), sim,filter(lambda x: testTmp.count(x) > 1 , testTmp)
				matrix[protein1].append(sim)
		
		return self.clusterMatrix(matrix)
	


	def drawArchitecture(self,ids,pdf_name=None):
		"""Draw the Pfam domain architecture of several proteins into a PDF.

		ids      -- sequence of accessions; each one is loaded with
		            parsePfamXML.  Accessions that cannot be loaded are
		            reported and left out of the drawing.
		pdf_name -- output path without the ".pdf" suffix.  Defaults to
		            sys.argv[2], the name the command-line entry point uses.

		The proteins are ordered with clusterDomains so that proteins sharing
		domains end up next to each other.  Each protein is drawn as a thin
		bar as long as its sequence, labelled with the "id" value parsed from
		its XML; every domain hit is drawn on top as a box starting at the
		hit's "start" and spanning end - start, in a random colour and shape
		chosen once per domain name.  A new page is started when the
		drawing position drops below the bottom padding.

		The layout values are written into self.options (page size 1800 x
		2500, paddings, row distance, font size).  Requires reportlab.
		"""
		from reportlab.pdfgen import canvas

		if pdf_name is None:
			pdf_name = sys.argv[2]

		colourDict = {}
		shapeType = ["round","square"]

		self.options["CANVAS_PADDING_HEIGHT"] = 50
		self.options["CANVAS_PADDING_WIDTH"] = 150
		self.options["dy"] = 35
		self.options["fontSize"] = 10

		domainData = {}
		mapping = {}
		lengths = {}
		maxLength = 0

		for protein_id in ids:
			print(protein_id)

			try:
				# One parse yields the domains, the sequence and the entry id.
				proteinData = self.parsePfamXML(protein_id,"")
				domains = proteinData["domains"]
				sequence_length = len(proteinData["Sequence"])
				entry_name = proteinData["id"]
			except Exception as e:
				# Best effort: one bad accession must not abort the drawing,
				# but the reason is reported instead of being swallowed.
				print("Error loading %s: %s" % (protein_id,e))
				continue

			# Registered only when everything is available, so no protein
			# reaches the drawing stage half-loaded.
			domainData[protein_id] = domains
			lengths[protein_id] = sequence_length
			mapping[protein_id] = entry_name

			if sequence_length > maxLength:
				maxLength = sequence_length

		clusters = self.clusterDomains(domainData)

		# Cluster names are comma-joined accessions; flatten them in order.
		orderList = []
		for cluster_name in clusters["order"]:
			orderList.extend(cluster_name.split(","))

		self.options["CANVAS_HEIGHT"] = 2500
		self.options["CANVAS_WIDTH"] = 1800

		pdf = canvas.Canvas(pdf_name + ".pdf" ,pagesize=(self.options["CANVAS_WIDTH"],self.options["CANVAS_HEIGHT"]))
		pdf.setFont("Helvetica-Bold", self.options["fontSize"])

		x = self.options["CANVAS_PADDING_WIDTH"]
		y = self.options["CANVAS_HEIGHT"] - self.options["CANVAS_PADDING_HEIGHT"]

		#############
		# draw background: alternating 100-wide stripes with position labels
		#############

		for i in range(0,maxLength,100):
			if (i//100)%2 == 0:
				pdf.setFillColor("black",alpha=0.03)
			else:
				pdf.setFillColor("white")

			pdf.rect(i + self.options["CANVAS_PADDING_WIDTH"],y - 10,100,len(ids)*self.options["dy"],stroke=0,fill=1)

			pdf.setFillColor("black",alpha=1)
			pdf.drawCentredString(i + self.options["CANVAS_PADDING_WIDTH"] , self.options["CANVAS_PADDING_HEIGHT"] - 20, str(i))
			pdf.drawCentredString(i + self.options["CANVAS_PADDING_WIDTH"] , self.options["CANVAS_PADDING_HEIGHT"] + len(ids)*self.options["dy"] - 5, str(i))

		#############

		last = []
		for protein_id in orderList:
			try:
				# Sorted list of domain names, one entry per hit; an extra
				# half row separates proteins whose list differs.
				lastTmp = []
				for domain in domainData[protein_id]:
					if domain != "len":
						for instance in domainData[protein_id][domain]["hits"]:
							lastTmp.append(domain)

				lastTmp.sort()
				if last != lastTmp:
					y -= self.options["dy"]//2

				last = lastTmp

				pdf.setFillColor("black",alpha=1)
				pdf.drawRightString( x - 10, y, mapping[protein_id])

				pdf.setFillColor("black",alpha=0.2)
				pdf.setStrokeColorRGB(0,0,0,alpha=0.1)

				pdf.roundRect(x,y,lengths[protein_id],2,1,fill=1)

				for domain in domainData[protein_id]:
					if domain not in colourDict:
						# Colour and shape are drawn once per domain name and
						# reused for every later occurrence.
						colour = (random.random()*0.8,random.random(),random.random())
						colourDict[domain] = {"colour":colour,"type":shapeType[int(random.random()*2)]}

					shape = colourDict[domain]["type"]
					colour = colourDict[domain]["colour"]

					if domain != "len":
						for instance in domainData[protein_id][domain]["hits"]:
							start = int(instance["start"])
							length = int(instance["end"]) - start

							alphaValue = 0.99

							if shape == "round":
								pdf.setStrokeColorRGB(colour[0],colour[1],colour[2],alpha=0.5)
								pdf.setFillColorRGB(colour[0],colour[1],colour[2],alpha=alphaValue)

								pdf.roundRect(x + start,y - 4,length,11,6,fill=1)
							else:
								pdf.saveState()
								pdf.translate(x + start,y - 4)

								pdf.setStrokeColorRGB(colour[0],colour[1],colour[2],alpha=0.5)
								pdf.setFillColorRGB(colour[0],colour[1],colour[2],alpha=alphaValue)
								pdf.rect(0,0,length,11,fill=1)

								pdf.restoreState()

							pdf.setFillColor("black",alpha=0.8)
							pdf.drawString( x + start + 3,y - 2, domain)

				y -= self.options["dy"]//2

				if y < self.options["CANVAS_PADDING_HEIGHT"]:
					y = self.options["CANVAS_HEIGHT"] - self.options["CANVAS_PADDING_HEIGHT"]
					pdf.showPage()

			except Exception as e:
				# Keep drawing the remaining proteins, but report the cause.
				print("Error with %s: %s" % (protein_id,e))

		pdf.save()
		
		
if __name__ == "__main__":
	# Command-line entry point: expects a comma separated accession list
	# and the name of the PDF to write (read by drawArchitecture itself).
	if len(sys.argv) == 3:
		accessions = sys.argv[1].split(",")
		pfamHelper().drawArchitecture(accessions)
	else:
		print("usage: ned_pfam uniprot_accession_list pdf_name")