#!/usr/bin/python

#-----------------------------------------------------------------------------
#
#    pydtd2xsd.py
#    
#    A DTD to XML Schema converter. See documentation below for invocation.
#    
#    Created:     2004/07/01
#    RCS-ID:      $Id: pydtd2xsd.py,v 1.2 2004/01/12 01:04:18 oliverh Exp $
#
#    Copyright (C) 2004 by Oliver M. Haynold, Evanston, Ill., United States
#
#    This program is free software; you can redistribute it and/or modify
#    it under the terms of the GNU General Public License as published by
#    the Free Software Foundation; either version 2 of the License, or
#    (at your option) any later version.
#
#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.
#
#    You should have received a copy of the GNU General Public License
#    along with this program; if not, write to the Free Software
#    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
#
#-----------------------------------------------------------------------------

# -- Constant declarations --

# Version components; VERSION renders them as "MAJOR.MINOR-RELEASE".
VERSION_MAJOR = 0
VERSION_MINOR = 1
VERSION_RELEASE = 2
VERSION = "%d.%d-%d" % (VERSION_MAJOR, VERSION_MINOR, VERSION_RELEASE)

# Banner printed on start-up and before import-failure messages.
STARTUP_STR = (
    "pydtd2xsd %s copyright (c) 2004 by Oliver M. Haynold\n"
    "pydtd2xsd comes with ABSOLUTELY NO WARRANTY and is licensed under\n"
    "the GNU General Public License version 2 or later."
) % VERSION

# Namespace URI of XML Schema; every generated xsd:* node lives in it.
NS_XS = 'http://www.w3.org/2001/XMLSchema'

# -- imports --

import sys
try:
    import string
    import os.path
    from copy import copy, deepcopy
    from xml.parsers.xmlproc.xmldtd import load_dtd
    import Ft.Xml.Domlette
    DomImplementation = Ft.Xml.Domlette.implementation
    import Ft.Xml.Lib.Print
except ImportError, e:
    print STARTUP_STR
    print "It seems you're lacking a Python module necessary to use this program."
    print e
    sys.exit(1)


#------------------------------------------------------------------
# This is a very ugly way to circumvent an apparent bug in xmlproc.
# It seems to work, but I don't know what exactly is going wrong,
# so don't be surprised if it breaks something. Without it,
# one-character entity declarations raise an error 3001.
# There is a discussion of a related problem in
# http://mail.python.org/pipermail/xml-sig/1999-June/001330.html
# so it would seem that this is a non-trivial problem.
# Let's hope that this workaround works.

import xml.parsers.xmlproc.xmlutils

def new_flush(self):
        """Parses any remnants of data in the last block.

        Replacement for xmlproc's EntityParser.flush (installed below).
        It differs from the code it replaces, as marked by the 'omh'
        comments, in how an OutOfDataException raised by the final
        do_parse() call is handled: instead of reporting error 3001 the
        parse position is rewound to where it was before the call.
        """
        # Convert whatever is still waiting in the undecoded buffer and
        # append it to the decoded data; a decoding failure is reported
        # as error 3048. The undecoded buffer is emptied either way.
        if self.encoded_data:
            try:
                new_data = self.charset_converter(self.encoded_data)
                self.data = self.data + new_data
                self.datasize = len(self.data)
            except UnicodeError,e:
                self.report_error(3048, e)
            self.encoded_data = ""
        # Skip the final parse when pos + 1 already equals datasize.
        if not self.pos+1==self.datasize:
            self.final=1
            # Remember the position so it can be restored if the parser
            # runs out of data after having advanced.
            pos=self.pos
            try:
                self.do_parse()
            except xml.parsers.xmlproc.xmlutils.OutOfDataException:
                if pos!=self.pos:
                    # new line introduced by omh
                    self.pos = pos
                    # line commented out by omh
                    #self.report_error(3001)

# Monkey-patch: every EntityParser now uses the lenient flush above.
xml.parsers.xmlproc.xmlutils.EntityParser.flush = new_flush

# end xmlproc bug workaround
#------------------------------------------------------------------

# -- Class Declarations

class XMLSchemaGenerator:
    """Convert a Python structure representing a DTD into XML Schema.

    One schema document is built per namespace prefix registered in
    ``ns_uris``; ``write_results`` writes each document to the file
    registered in ``filepaths``.  The input of ``parse_dtdrepr`` is a
    dictionary mapping (possibly prefixed) element names to entries of
    the form ``{"content_model": ..., "attributes": {...}}``.
    """

    # DTD attribute types with a direct XML Schema counterpart.  Any
    # other type is treated as an enumeration of allowed values.
    attrtype_mappings = { "CDATA" : "xsd:string",
                          "NMTOKEN" : "xsd:NMTOKEN",
                          "NMTOKENS" : "xsd:NMTOKENS",
                          "ID": "xsd:ID",
                          "IDREF": "xsd:IDREF",
                          "IDREFS": "xsd:IDREFS",
                          "ENTITY": "xsd:ENTITY",
                          "ENTITIES": "xsd:ENTITIES" }

    # DTD occurrence indicator -> (minOccurs, maxOccurs).  None means the
    # attribute is left out of the generated node.
    _occurrences = { "" : (None, None),
                     "*" : ("0", "unbounded"),
                     "?" : ("0", None),
                     "+" : (None, "unbounded") }

    def __init__(self):
        # prefix (or None) -> namespace URI (or None for no namespace)
        self.ns_uris = {}
        # prefix (or None) -> path of the schema file to write
        self.filepaths = {}
        # When true, imports of the namespaceless schema carry an
        # explicit namespace="" attribute.
        self.nonsworkaround = None
        # Both are (re)built by parse_dtdrepr.
        self.doms = {}
        self.globalattrs = {}

    def parse_dtdrepr(self, dr):
        """Build one schema DOM per registered prefix from ``dr``.

        Every document gets its targetNamespace, namespace declarations
        for all registered prefixes and an xsd:import for each of the
        other schemas; then one xsd:element is added per entry of ``dr``.

        Raises IndexError if an element uses a prefix that is not in
        ``ns_uris`` (or contains more than one colon).
        """
        self.doms = {}
        self.globalattrs = {}
        for prefix1 in self.ns_uris.keys():
            dom = DomImplementation.createDocument(NS_XS, "xsd:schema", None)
            root = dom.documentElement
            if self.ns_uris[prefix1]:
                root.setAttributeNS(None, "targetNamespace",
                                    self.ns_uris[prefix1])
            self.doms[prefix1] = dom
            # Local names of the global attributes already declared in
            # this schema, so each is emitted only once.
            self.globalattrs[prefix1] = []
            for prefix2 in self.ns_uris.keys():
                uri2 = self.ns_uris[prefix2]
                if uri2 is not None:
                    # NOTE(review): for a None prefix with a URI this
                    # yields "xmlns:None"; a default-namespace "xmlns"
                    # declaration is probably intended -- confirm.
                    root.setAttributeNS(None, "xmlns:%s" % prefix2, uri2)
                if prefix1 != prefix2:
                    importnode = dom.createElementNS(NS_XS, "xsd:import")
                    if uri2 is not None:
                        importnode.setAttributeNS(None, "namespace", uri2)
                    elif self.nonsworkaround:
                        importnode.setAttributeNS(None, "namespace", "")
                    importnode.setAttributeNS(None, "schemaLocation",
                                              self.filepaths[prefix2])
                    root.appendChild(importnode)
        for ename in dr.keys():
            prefix, lname = self._split_name(ename)
            if prefix not in self.ns_uris.keys():
                raise IndexError("Unknown prefix '%s'." % prefix)
            self._make_elemnode(lname, dr[ename], self.doms[prefix])

    def write_results(self):
        """Pretty-print every schema DOM to its registered file."""
        for prefix in self.filepaths.keys():
            f = open(self.filepaths[prefix], "w")
            try:
                Ft.Xml.Lib.Print.PrettyPrint(self.doms[prefix], f)
            finally:
                # Close explicitly so the output is flushed even when
                # printing fails part-way through.
                f.close()

    def _make_elemnode(self, lname, edef, dom):
        """Append a global xsd:element named ``lname`` to ``dom``."""
        node = dom.createElementNS(NS_XS, "xsd:element")
        node.setAttributeNS(None, "name", lname)
        node.appendChild(self._make_content_model(edef["content_model"],
                                                  edef["attributes"], dom))
        dom.documentElement.appendChild(node)

    def _make_attr_node(self, attrname, attrdef, dom, ref, toplev):
        """Return an xsd:attribute node for one DTD attribute.

        ``attrdef`` holds the keys "decl", "type" and "default".  With a
        true ``ref`` the node refers to a global attribute and carries no
        type; with a true ``toplev`` the declaration keyword and default
        are not processed.

        Raises ValueError for an unknown declaration keyword.
        """
        decl = attrdef["decl"]
        attr_type = attrdef["type"]
        default = attrdef["default"]
        node = dom.createElementNS(NS_XS, "xsd:attribute")
        if ref:
            node.setAttributeNS(None, "ref", attrname)
        else:
            node.setAttributeNS(None, "name", attrname)
        # Process decl and default
        if not toplev:
            if decl == "#DEFAULT":
                node.setAttributeNS(None, "default", default)
            elif decl == "#IMPLIED":
                node.setAttributeNS(None, "use", "optional")
            elif decl == "#REQUIRED":
                node.setAttributeNS(None, "use", "required")
            elif decl == "#FIXED":
                node.setAttributeNS(None, "fixed", default)
            else:
                raise ValueError("Unknow attribute declaration '%s'." % decl)
        # Process type
        if not ref:
            mapped = None
            # Enumerations arrive as sequences, which cannot be used as
            # dictionary keys; only look up scalar type names.
            if not isinstance(attr_type, (list, tuple)):
                mapped = self.attrtype_mappings.get(attr_type)
            if mapped is not None:
                node.setAttributeNS(None, "type", mapped)
            else:
                # It's an enumeration
                stypenode = dom.createElementNS(NS_XS, "xsd:simpleType")
                srestrnode = dom.createElementNS(NS_XS, "xsd:restriction")
                stypenode.appendChild(srestrnode)
                srestrnode.setAttributeNS(None, "base", "xsd:string")
                for val in attr_type:
                    enumnode = dom.createElementNS(NS_XS, "xsd:enumeration")
                    enumnode.setAttributeNS(None, "value", val)
                    srestrnode.appendChild(enumnode)
                # Attach the anonymous type; without this the enumeration
                # would be built and then thrown away.
                node.appendChild(stypenode)
        return node

    def _make_content_model(self, content_model, attrs, dom):
        """Return the xsd:complexType for one element declaration.

        A ``content_model`` of None becomes mixed content allowing any
        elements.  Unprefixed attributes are declared locally; prefixed
        ones (except prefixes starting with "xml") are declared once as
        global attributes in their own schema and referenced from here.
        """
        res = dom.createElementNS(NS_XS, "xsd:complexType")
        if content_model is None:
            res.setAttributeNS(None, "mixed" ,"true")
            listelem = dom.createElementNS(NS_XS, "xsd:sequence")
            res.appendChild(listelem)
            anyelem = dom.createElementNS(NS_XS, "xsd:any")
            listelem.appendChild(anyelem)
            anyelem.setAttributeNS(None, "minOccurs", "0")
            anyelem.setAttributeNS(None, "maxOccurs", "unbounded")
            anyelem.setAttributeNS(None, "processContents", "lax")
        else:
            listelem, mixed = self._make_list_decl(content_model, dom)
            res.appendChild(listelem)
            if mixed:
                res.setAttributeNS(None, "mixed" ,"true")
        for attr in attrs.keys():
            prefix, localname = self._split_name(attr)
            if not prefix:
                res.appendChild(self._make_attr_node(attr, attrs[attr], dom,
                                                     None, None))
            elif prefix[:3] != "xml":
                if localname not in self.globalattrs[prefix]:
                    self.globalattrs[prefix].append(localname)
                    self.doms[prefix].documentElement.appendChild(
                        self._make_attr_node(localname, attrs[attr],
                                             self.doms[prefix], None, 1))
                res.appendChild(self._make_attr_node(attr, attrs[attr], dom,
                                                     1, 1))
        return res

    def _set_occurs(self, node, modifier):
        """Set minOccurs/maxOccurs on ``node`` for a DTD indicator.

        Raises KeyError for an indicator other than "", "?", "*" or "+".
        """
        min_occurs, max_occurs = self._occurrences[modifier]
        if min_occurs:
            node.setAttributeNS(None, "minOccurs", min_occurs)
        if max_occurs:
            node.setAttributeNS(None, "maxOccurs", max_occurs)

    def _make_list_decl(self, content_model, dom):
        """Translate a content-model group into xsd:choice/xsd:sequence.

        ``content_model`` is a ``(separator, items, repeat)`` triple; each
        item is either a nested triple or a ``(name, modifier)`` pair.
        Returns ``(node, mixed)`` where ``mixed`` is true if "#PCDATA"
        occurs directly in this group.

        Raises ValueError for an unknown separator.
        """
        mixed = None
        separator, elist, repeat = content_model
        if separator not in ["|", ",", ""]:
            raise ValueError("Unknow list type '%s'." % separator)
        if separator == ",":
            listelem = dom.createElementNS(NS_XS, "xsd:sequence")
        else:
            listelem = dom.createElementNS(NS_XS, "xsd:choice")
        self._set_occurs(listelem, repeat)
        for subel in elist:
            if len(subel) == 3:
                # Nested group; its own mixed flag is not propagated.
                nnode, garbage = self._make_list_decl(subel, dom)
            else:
                name, modifier = subel
                if name == "#PCDATA":
                    mixed = 1
                    nnode = None
                else:
                    nnode = dom.createElementNS(NS_XS, "xsd:element")
                    nnode.setAttributeNS(None, "ref", name)
                    self._set_occurs(nnode, modifier)
            # Compare against None explicitly rather than relying on the
            # truth value of a DOM node.
            if nnode is not None:
                listelem.appendChild(nnode)
        return listelem, mixed

    def _split_name(self, name):
        """Split ``name`` into ``(prefix, localname)``.

        The prefix is None when ``name`` contains no colon.  Raises
        IndexError when it contains more than one.
        """
        parts = name.split(":")
        if len(parts) == 1:
            return None, parts[0]
        elif len(parts) == 2:
            return parts[0], parts[1]
        else:
            raise IndexError("More than one colon in an element name.")

class DTDReprManager:
    """Load a DTD as plain Python data, optionally cached on disk.

    Attributes:
      dtdFilename  -- path of the DTD to parse.
      reprFilename -- path of the cached representation; derived from
                      ``dtdFilename`` when unset and ``useRepr`` is true.
      useRepr      -- true to read and write the cached representation.
    """

    def __init__(self):
        self.reprFilename = None
        self.dtdFilename = None
        self.useRepr = None

    def getDtdRepr(self):
        """Return the simplified DTD representation.

        With ``useRepr`` set, an existing representation file is
        evaluated and returned without parsing the DTD.  Otherwise the
        DTD is parsed and, if ``useRepr`` is set, the result is written
        to the representation file; a failure to write is reported on
        standard output and otherwise ignored.

        Raises ValueError when the representation file name has to be
        derived but the DTD file name is missing or does not end in
        ".dtd"/".DTD".
        """
        if self.useRepr:
            if not self.reprFilename:
                # Treat a missing DTD file name like one with the wrong
                # extension instead of failing on None[-4:].
                dtd_name = self.dtdFilename or ""
                if dtd_name[-4:] in [".dtd", ".DTD"]:
                    self.reprFilename = dtd_name[:-4] + ".prp"
                else:
                    raise ValueError(
                        "In order to autogenerate representation file " +
                        "name, the DTD file name needs to end in '.dtd'")
            if os.path.exists(self.reprFilename):
                # SECURITY: eval() executes arbitrary Python code found
                # in the representation file.  Only use representation
                # files from a trusted source.
                f = open(self.reprFilename, "r")
                try:
                    return eval(f.read())
                finally:
                    f.close()
        dtd = load_dtd(self.dtdFilename)
        sr = self.make_simple_repr(dtd)
        if self.useRepr:
            # Caching is best effort: the conversion goes on even when
            # the representation file cannot be written.
            try:
                f = open(self.reprFilename, "w")
                try:
                    f.write(repr(sr))
                finally:
                    f.close()
            except (IOError, OSError):
                print("Could not write representation file '%s'." %
                      self.reprFilename)
        return sr

    def make_simple_repr(self, dtd):
        """Reduce a parsed DTD object to nested dictionaries.

        The result maps each element name to a dictionary with the keys
        "content_model" (the element's ``content_model_structure``) and
        "attributes" (attribute name -> dictionary with the keys "type",
        "decl" and "default").
        """
        res = {}
        for ename in dtd.get_elements():
            elem = dtd.get_elem(ename)
            content_model = elem.content_model_structure
            attributes = {}
            for aname in elem.get_attr_list():
                attr = elem.get_attr(aname)
                attributes[aname] = { "type": attr.get_type(),
                                      "decl": attr.get_decl(),
                                      "default": attr.get_default() }
            res[ename] = { "content_model": content_model,
                           "attributes": attributes }
        return res
    
# The class docstring below doubles as the usage text printed by run(),
# so it is user-facing output rather than ordinary documentation.
class CommandLineManager:
    """
Command line options:
pydtd2xsd.py [--dtd-file FILE] [--repr-file FILE]
    (--ns-uri (PREFIX | None) (URI | None))*
    (--ns-schema-file (PREFIX | None) FILE)*
     [--no-ns-workaround] [--use-repr-file]

 --dtd-file FILE: Load the DTD from FILE.
 --repr-file FILE: Try to load the DTD from the compiled DTD in FILE.
        If it doesn't exist, try to create FILE. 
        CAUTION: WITH THIS OPTION ARBITRARY PYTHON CODE FROM FILE WILL
        BE EXECUTED. USE ONLY IF YOU UNDERSTAND THE SECURITY IMPLICATIONS.
 --use-repr-file: Use the standard (.prp) file name for the compiled DTD.
        This greatly improves speed for repeated invocatio, but SEE THE
        CAUTIONARY NOTE ABOVE.
 --no-ns-workaround: For namespaceless elements write 
        <xsd:import namespace='' schemaLocation='FILE'/> instead of
        leaving the namespace attribute away. This is necessary for some
        programs like Syntext Serna.
 --ns-uri (PREFIX | None) (URI | None): Associate the PREFIX with its
        corresponding namespace URI. For both you may use the special
        value 'None' (without quotes) which means no prefix and no
        namespace URI association, respectively.
 --ns-schema-file (PREFIX | None) FILE: Write the schema for the namespace
        associated with PREFIX to FILE. PREFIX may also be 'None', see
        above.
"""

    # Number of values that must follow each recognised option.
    _option_arity = { "--dtd-file": 1,
                      "--repr-file": 1,
                      "--ns-uri": 2,
                      "--ns-schema-file": 2,
                      "--no-ns-workaround": 0,
                      "--use-repr-file": 0 }

    def run(self):
        """Print the banner, then the usage text or run the conversion.

        The usage text is shown when no command line arguments are given.
        """
        print(STARTUP_STR)
        if len(sys.argv) <= 1:
            print(self.__doc__)
        else:
            self.run1()

    def run1(self):
        """Parse the arguments, load the DTD and write the schemas."""
        self.sgen = XMLSchemaGenerator()
        self.repm = DTDReprManager()
        self.parseArguments()
        dr = self.repm.getDtdRepr()
        self.sgen.parse_dtdrepr(dr)
        self.sgen.write_results()

    def _none_or(self, value):
        """Map the literal string "None" to None; pass others through."""
        if value == "None":
            return None
        return value

    def parseArguments(self, argv=None):
        """Apply command line options to ``self.repm`` and ``self.sgen``.

        ``argv`` is the list of arguments without the program name; it
        defaults to ``sys.argv[1:]``.  Options are applied in order, so
        those preceding a faulty one have already taken effect when the
        error is raised.

        Raises ValueError for an unknown option or for an option that is
        not followed by the number of values it requires.
        """
        if argv is None:
            argv = sys.argv[1:]
        args = list(argv)
        while args:
            option = args[0]
            if option not in self._option_arity:
                raise ValueError("Unknown argument '%s'." % option)
            arity = self._option_arity[option]
            values = args[1:1 + arity]
            if len(values) < arity:
                raise ValueError("Argument '%s' requires %d value(s)." %
                                 (option, arity))
            args = args[1 + arity:]
            if option == "--dtd-file":
                self.repm.dtdFilename = values[0]
            elif option == "--repr-file":
                self.repm.reprFilename = values[0]
                self.repm.useRepr = 1
            elif option == "--ns-uri":
                prefix = self._none_or(values[0])
                self.sgen.ns_uris[prefix] = self._none_or(values[1])
            elif option == "--ns-schema-file":
                # Only the prefix may be 'None'; the second value is a
                # file name and is kept verbatim.
                prefix = self._none_or(values[0])
                self.sgen.filepaths[prefix] = values[1]
            elif option == "--no-ns-workaround":
                self.sgen.nonsworkaround = 1
            else:
                # "--use-repr-file": the only option left in the table.
                self.repm.useRepr = 1

# -- main

# Runs unconditionally at module level: executing (or importing) this
# file prints the banner and processes sys.argv right away.
CommandLineManager().run()
