#!/usr/bin/python
#
# sparqlprottests.py - SPARQL Protocol Test
#

import os, sys, string, getopt
import RDF, SimpleRDF, Vocabulary
import sparqlclient
import rdfdiff

from SimpleRDF import Resource
from xml.dom import minidom

# Shared Redland storage that load_tests() builds its test model on;
# assigned by init(), which this module calls at import time.
storage = None

class ProtocolTest(SimpleRDF.Resource):
    """A single SPARQL Protocol test entry read from a test manifest.

    Accessors read properties of this test's node in the RDF model using the
    proto-tests test-manifest vocabulary declared below.
    """

    NS = "http://www.w3.org/2001/sw/DataAccess/proto-tests/test-manifest#"
    accepttype = RDF.Node(uri_string = NS + "acceptType")
    compliantresult = RDF.Node(uri_string = NS + "compliantResult")
    comment = RDF.Node(uri_string = NS + "comment")
    dataset = RDF.Node(uri_string = NS + "dataSet")
    defaultgraph = RDF.Node(uri_string = NS + "defaultGraph")
    entries = RDF.Node(uri_string = NS + "entries")
    graphdata = RDF.Node(uri_string = NS + "graphData")
    graphname = RDF.Node(uri_string = NS + "graphName")
    name = RDF.Node(uri_string = NS + "name")
    namedgraph = RDF.Node(uri_string = NS + "namedGraph")
    preferredresult = RDF.Node(uri_string = NS + "preferredResult")
    query = RDF.Node(uri_string = NS + "query")
    result = RDF.Node(uri_string = NS + "result")
    resultcode = RDF.Node(uri_string = NS + "resultCode")
    resultcontenttype = RDF.Node(uri_string = NS + "resultContentType")
    servicedataset = RDF.Node(uri_string = NS + "serviceDataSet")
    querydataset = RDF.Node(uri_string = NS + "queryDataSet")
    service = RDF.Node(uri_string = NS + "service")

    # Lower-cased keys accepted by __getitem__, mapped to their predicates.
    props = { "name" : name,
              "servicedataset" : servicedataset,
              "comment" : comment,
              "query" : query,
              "accepttype" : accepttype,
              "preferredresult" : preferredresult,
              "querydataset" : querydataset,
              "service" : service }

    def __init__(self, model, node, dir, svc_hostport):
        """Wrap test *node* of *model*.

        dir          -- directory associated with the test's manifest.
        svc_hostport -- service endpoint override, or None to use the
                        test's own ``service`` property.
        """
        SimpleRDF.Resource.__init__(self, model, node)
        self.dir = dir
        self.svc_hostport = svc_hostport

    def get_service(self):
        """Return the endpoint to query.

        The svc_hostport override wins.  Otherwise the URI of the ``service``
        property is returned if it is a resource, else None.
        """
        if self.svc_hostport is not None:
            return self.svc_hostport
        service = Resource.getproperty(self, ProtocolTest.service)
        if service and service.is_resource():
            return str(service.uri)
        return None

    def get_query(self, translate=True):
        """Return the text of the test's query file, or None if it has none.

        A leading ``file:`` scheme is stripped from the query URI before
        opening.  When *translate* is true, graph names from the test's
        queryDataSet are rewritten to their graphData locations.
        """
        query_node = self["query"]
        if not query_node:
            return None
        path = str(query_node.uri)
        if path.startswith("file:"):
            path = path[5:]
        f = open(path, "rU")
        try:
            query = f.read()
        finally:
            # Close the handle even when read() fails.
            f.close()

        if translate:
            query = self._translate_query(query)
        return query

    def _translate_query(self, query):
        """Replace queryDataSet graph-name URIs in *query* with their data."""
        qds = self.get_querydataset()
        if not qds:
            return query
        mappings = []
        for graph in qds.get_defaultgraphs() + qds.get_namedgraphs():
            graph_name = graph["name"]
            graph_data = graph["data"]
            # A graph missing either half cannot be mapped; skip it instead
            # of crashing on None.
            if graph_name is None or graph_data is None:
                continue
            # graph["data"] is already a string for resource data; coerce
            # anything else so str.replace receives a string.
            mappings.append((str(graph_name.uri), str(graph_data)))
        # Longest source URI first, so a URI that is a prefix of another
        # cannot rewrite part of the longer URI before it is matched.
        mappings.sort(key=lambda m: len(m[0]), reverse=True)
        for source, target in mappings:
            query = query.replace(source, target)
        return query

    def get_accept_types(self):
        """Return the distinct acceptType values, in first-seen order."""
        accepttypes = []
        for s in Resource.listproperties(self, ProtocolTest.accepttype):
            accepttype = str(s.object)
            if accepttype not in accepttypes:
                accepttypes.append(accepttype)
        return accepttypes

    def get_results(self):
        """Return acceptable results: the preferred one (if any) first,
        followed by all compliant results."""
        preferred = self.get_preferred_result()
        compliant = self.get_compliant_results()
        if preferred:
            return [preferred] + compliant
        return compliant

    def _single_object(self, predicate, wrapper, what):
        """Return wrapper(model, object) for the sole value of *predicate*.

        Returns None when the property is absent.  Raises Exception
        ("Only one <what> is allowed.") when it occurs more than once.
        """
        results = [wrapper(self.model, s.object)
                   for s in Resource.listproperties(self, predicate)]
        if len(results) > 1:
            raise Exception("Only one %s is allowed." % what)
        if results:
            return results[0]
        return None

    def get_querydataset(self):
        """Return the test's queryDataSet as a DataSet, or None."""
        return self._single_object(ProtocolTest.querydataset, DataSet,
                                   "querydataset")

    def get_dataset(self):
        """Return the test's dataSet as a DataSet, or None."""
        return self._single_object(ProtocolTest.dataset, DataSet, "dataset")

    def get_preferred_result(self):
        """Return the test's preferredResult as a Result, or None."""
        return self._single_object(ProtocolTest.preferredresult, Result,
                                   "preferred result")

    def get_compliant_results(self):
        """Return every compliantResult as a list of Result objects."""
        return [Result(self.model, s.object)
                for s in Resource.listproperties(self,
                                                 ProtocolTest.compliantresult)]

    def __getitem__(self, key):
        """Look up a property by case-insensitive name from ``props``.

        Returns None for names not in ``props``.
        """
        prop = ProtocolTest.props.get(key.lower())
        if prop is not None:
            return Resource.getproperty(self, prop)
        return None
 

class InvalidTestError(Exception):
    """Raised when a test entry in a manifest is malformed.

    Attributes:
        manifest -- where the invalid test was found.
        test     -- the offending test.
        msg      -- description of the problem.
    """

    def __init__(self, manifest, test, msg):
        # Initialise the base class so ``args`` and repr() carry the
        # details (the original left ``args`` empty).
        Exception.__init__(self, manifest, test, msg)
        self.msg = msg
        self.test = test
        self.manifest = manifest

    def __str__(self):
        return "Invalid test: '%s'\nfound at: '%s'\nerror: %s" % (self.test, self.manifest, self.msg)

class Graph(SimpleRDF.Resource):
    """One graph of a DataSet: a graphName and a graphData location."""

    def __getitem__(self, key):
        """Return the graph's "name" node or "data" value; None otherwise.

        For "data", a resource node is returned as its URI string; any
        other value (or None when absent) is returned unchanged.
        """
        if key == "name":
            return Resource.getproperty(self, ProtocolTest.graphname)
        if key == "data":
            data = Resource.getproperty(self, ProtocolTest.graphdata)
            if data and data.is_resource():
                return str(data.uri)
            return data
        return None

    def load(self):
        """Check that a ``file:`` graphData location exists on disk.

        Only an existence check is performed; non-``file:`` URIs are not
        checked.

        Raises:
            Exception: if the graph has no graphData, or if its ``file:``
                path does not exist.
        """
        data = Resource.getproperty(self, ProtocolTest.graphdata)
        if not data:
            # Previously this surfaced as AttributeError on None.uri.
            raise Exception("Graph has no graphData")
        uri = str(data.uri)
        if uri.startswith("file:"):
            path = os.path.normcase(uri[5:])
            if not os.path.exists(path):
                raise Exception("Graph could not find file: %s" % path)
    

class DataSet(SimpleRDF.Resource):
    """The default and named graphs a protocol test refers to."""

    def _graphs_for(self, predicate):
        # Wrap each object of `predicate` on this node as a Graph, in the
        # order the statements are listed.
        return [Graph(self.model, st.object)
                for st in SimpleRDF.Resource.listproperties(self, predicate)]

    def get_defaultgraphs(self):
        """Return the defaultGraph entries as a list of Graph objects."""
        return self._graphs_for(ProtocolTest.defaultgraph)

    def get_namedgraphs(self):
        """Return the namedGraph entries as a list of Graph objects."""
        return self._graphs_for(ProtocolTest.namedgraph)

class Result(SimpleRDF.Resource):
    """An expected result of a protocol test (result file, code, type)."""

    # Lower-cased keys accepted by __getitem__, mapped to their predicates.
    props = { "result" : ProtocolTest.result,
              "resultcode" : ProtocolTest.resultcode,
              "resultcontenttype" : ProtocolTest.resultcontenttype }

    def __getitem__(self, key):
        """Look up a property by case-insensitive name; None if unknown."""
        prop = Result.props.get(key.lower())
        if prop is not None:
            return Resource.getproperty(self, prop)
        return None

    def get_result_as_model(self):
        """Load the result document as an RDF model, or return None."""
        result_node = self["result"]
        if result_node:
            return SimpleRDF.load_model(str(result_node.uri))
        return None

    def get_result(self):
        """Return the text of the result file.

        Returns None when there is no ``result`` property or its URI does
        not use the ``file:`` scheme.
        """
        result_node = self["result"]
        # Guard first: the original called .find() on None here.
        if not result_node:
            return None
        path = str(result_node.uri)
        if not path.startswith("file:"):
            return None
        f = open(path[5:], "rU")
        try:
            return f.read()
        finally:
            f.close()
 
 
def init():
    """Create the module-wide in-memory Redland storage (with contexts).

    Raises:
        Exception: if the storage could not be created.
    """
    global storage
    storage = RDF.MemoryStorage(mem_name="proto-tests",
                                options_string="contexts='yes'")
    if storage is None:
        # Raising a bare string is a TypeError on Python >= 2.6, which would
        # hide this message; raise a real exception instead.
        raise Exception("new RDF.Storage failed")

def _manifest_dir(context, default):
    """Return the directory of the manifest a statement was loaded from.

    *context* is the statement's context node, which load_tests() sets to the
    ``file:`` URI of the manifest.  *default* is returned when the directory
    cannot be determined from it.
    """
    if context is None or not context.is_resource():
        return default
    uri = str(context.uri)
    if uri.startswith("file:"):
        return os.path.dirname(uri[5:])
    return default


def load_tests(basedir, svc_hostport=None):
    """Load ProtocolTest objects from every ``<basedir>/<dir>/manifest.ttl``.

    Each non-empty manifest is added to a model on the shared module
    storage, using the manifest's ``file:`` URI as the statement context.
    Every object of a manifest ``entries`` property is read as an RDF list
    of tests.  Each test is given the directory of the manifest it came from.

    svc_hostport -- optional endpoint override passed to every test.
    """
    testmodel = RDF.Model(storage)

    for subdir in sorted(os.listdir(basedir)):
        testdir = os.path.join(basedir, subdir)
        if not os.path.isdir(testdir):
            continue
        manifest = os.path.join(testdir, "manifest.ttl")
        if not os.path.isfile(manifest):
            continue
        model = SimpleRDF.load_model(manifest)
        if not len(model) > 0:
            continue
        context = RDF.Node(RDF.Uri("file:" + manifest))
        testmodel.add_statements(model.as_stream(), context)

    entries = []
    qs = RDF.Statement(subject=None,
                       predicate=ProtocolTest.entries,
                       object=None)
    for (statement, context) in testmodel.find_statements_context(qs):
        # Use the directory of the manifest this entry list came from; the
        # original reused the last directory scanned for every test.
        testdir = _manifest_dir(context, basedir)
        entries.extend([ProtocolTest(testmodel, t, testdir, svc_hostport)
                        for t in SimpleRDF.from_rdf_list(testmodel,
                                                         statement.object)])

    return entries

def sparql2graph(xml):
    """Convert a SPARQL XML results document into a result-set RDF model.

    *xml* is handed to minidom.parse (a filename or file object).  The
    returned model, on a fresh in-memory storage, contains one
    rs:ResultSet node with:
      - an rs:resultVariable for each <variable name=...>;
      - an rs:solution for each <result>, whose rs:binding nodes carry
        rs:variable (the binding name) and rs:value (a URI node for <uri>,
        a blank node for <bnode>, the text for <literal>);
      - an rs:boolean for each <boolean> element.
    """
    doc = minidom.parse(xml)

    mem_storage = RDF.MemoryStorage()
    model = RDF.Model(mem_storage)
    add = model.add_statement

    RDF_NS = "http://www.w3.org/1999/02/22-rdf-syntax-ns#"
    RES_NS = "http://www.w3.org/2001/sw/DataAccess/tests/result-set#"

    def rs(local_name):
        # Node for a term of the result-set vocabulary.
        return RDF.Node(uri_string=RES_NS + local_name)

    def element_children(parent):
        # Only the element children, skipping whitespace/text nodes.
        return [c for c in parent.childNodes if c.nodeType == c.ELEMENT_NODE]

    rdf_type = RDF.Node(uri_string=RDF_NS + "type")
    rs_result_variable = rs("resultVariable")
    rs_binding = rs("binding")
    rs_variable = rs("variable")
    rs_value = rs("value")
    rs_solution = rs("solution")
    rs_boolean = rs("boolean")

    # Turns the text of a value element into the object of rs:value;
    # elements with any other local name are ignored.
    value_builders = {
        'uri': lambda text: RDF.Node(uri_string=text),
        'bnode': lambda text: RDF.Node(blank=text),
        'literal': lambda text: text,
    }

    result_set = RDF.Node()
    add(RDF.Statement(result_set, rdf_type, rs("ResultSet")))

    for var_el in doc.getElementsByTagName('variable'):
        add(RDF.Statement(result_set, rs_result_variable,
                          var_el.getAttribute("name")))

    for result_el in doc.getElementsByTagName('result'):
        solution = RDF.Node()
        for binding_el in element_children(result_el):
            binding = RDF.Node()
            add(RDF.Statement(binding, rs_variable,
                              binding_el.getAttribute('name')))
            for value_el in element_children(binding_el):
                build = value_builders.get(value_el.localName)
                if build is None:
                    continue
                add(RDF.Statement(binding, rs_value,
                                  build(_getnodedata(value_el))))
            add(RDF.Statement(solution, rs_binding, binding))
        add(RDF.Statement(result_set, rs_solution, solution))

    for bool_el in doc.getElementsByTagName('boolean'):
        add(RDF.Statement(result_set, rs_boolean, _getnodedata(bool_el)))

    return model

def _getnodedata(node):
    uri = []
    if len(node.childNodes) > 0:
        for text in [ e for e in node.childNodes if e.nodeType == e.TEXT_NODE ]:
            uri.append(text.data)
    return str("".join(uri))

# Create the shared storage at import time so load_tests() can build on it.
init()
