Source code for ncdiff.model

import math
import os
import re
import queue
import logging

from lxml import etree
from copy import deepcopy
from ncclient import operations
from threading import Thread, current_thread
from pyang import statements
try:
    from pyang.repository import FileRepository
except ImportError:
    from pyang import FileRepository
try:
    from pyang.context import Context
except ImportError:
    from pyang import Context

from .errors import ModelError


# Logger named after this module; used for progress and error reporting.
logger = logging.getLogger(__name__)

# Shared lxml parser used by read_xml(): input is decoded as UTF-8 and
# whitespace-only text nodes between elements are discarded.
PARSER = etree.XMLParser(remove_blank_text=True, encoding='utf-8')


def write_xml(filename, element):
    '''Serialize `element` to `filename` as a pretty-printed UTF-8 document.

    The output starts with an XML declaration. Tail text of `element`
    itself is not written.

    Parameters
    ----------
    filename : `str`
        Path of the file to create or overwrite.
    element : `Element`
        Root element of the document to write.
    '''
    document = etree.ElementTree(element)
    document.write(
        filename,
        pretty_print=True,
        xml_declaration=True,
        with_tail=False,
        encoding='utf-8',
    )


def read_xml(filename):
    '''Parse an XML file and return its root element.

    This is a best-effort reader used for cache and dependency files
    (<module>.xml, dependencies.xml). Callers treat None as "not available"
    and regenerate the data, so parse failures are not raised. They are
    logged so that a corrupt file does not disappear silently.

    Parameters
    ----------
    filename : `str`
        Path of the XML file.

    Returns
    -------
    Element or None
        The root element, or None if `filename` is not a regular file or
        cannot be parsed.
    '''
    if not os.path.isfile(filename):
        return None
    try:
        return etree.parse(filename, parser=PARSER).getroot()
    except Exception as err:
        # Deliberately best-effort, but record why the file was ignored.
        logger.warning(
            "Cannot parse XML file '{}', ignoring it: {}".format(filename, err))
        return None


class Model(object):
    '''Model

    Abstraction of a YANG module. It supports str() which returns a string
    similar to the output of 'pyang -f tree'.

    Attributes
    ----------
    name : `str`
        Model name.

    prefix : `str`
        Prefix of the model.

    prefixes : `dict`
        All prefixes used in the model. Dictionary keys are prefixes, and
        values are URLs.

    url : `str`
        URL of the model.

    urls : `dict`
        All URLs used in the model. Dictionary keys are URLs, and values are
        prefixes.

    tree : `Element`
        The model tree as an Element object.

    roots : `list`
        Tags (`str`) of all top-level nodes of the model tree.

    width : `dict`
        This is used to facilitate pretty print of a model. Dictionary keys
        are nodes in the model tree, and values are indents.
    '''

    def __init__(self, tree):
        '''
        __init__ instantiates a Model instance.

        Parameters
        ----------
        tree : `Element`
            Compiled model tree. Its tag is the module name, its 'prefix'
            attribute is the module prefix, and its 'namespace' children map
            prefixes to URLs (they are removed by convert_tree()).
        '''
        self.tree = tree
        self.name = tree.tag
        ns = tree.findall('namespace')
        self.prefixes = {c.attrib['prefix']: c.text for c in ns}
        self.prefix = tree.attrib['prefix']
        self.url = self.prefixes[self.prefix]
        self.urls = {v: k for k, v in self.prefixes.items()}
        self.convert_tree()
        self.width = {}

    def __str__(self):
        return self.emit_tree(self.tree)

    @property
    def roots(self):
        return [c.tag for c in self.tree]

    def emit_tree(self, tree):
        '''emit_tree

        High-level api: Emit a string presentation of the model.

        Parameters
        ----------
        tree : `Element`
            The model.

        Returns
        -------
        str
            A string presentation of the model that is very similar to the
            output of 'pyang -f tree'
        '''
        ret = []
        ret.append('module: {}'.format(tree.tag))
        ret += self.emit_children(tree, type='other')
        rpc_lines = self.emit_children(tree, type='rpc')
        if rpc_lines:
            ret += ['', ' rpcs:'] + rpc_lines
        notification_lines = self.emit_children(tree, type='notification')
        if notification_lines:
            ret += ['', ' notifications:'] + notification_lines
        return '\n'.join(ret)

    def emit_children(self, tree, type='other'):
        '''emit_children

        High-level api: Emit a string presentation of a part of the model.

        Parameters
        ----------
        tree : `Element`
            The model.
        type : `str`
            Type of model content required. Its value can be 'other', 'rpc',
            or 'notification'.

        Returns
        -------
        list
            Lines (`str`) of a presentation that is very similar to the
            output of 'pyang -f tree'.
        '''
        def is_type(element, type):
            type_info = element.get('type')
            if type == type_info:
                return True
            if type == 'rpc' or type == 'notification':
                return False
            if type_info == 'rpc' or type_info == 'notification':
                return False
            return True

        ret = []
        for root in [i for i in tree if is_type(i, type)]:
            for i in root.iter():
                line = self.get_depth_str(i, type=type)
                name_str = self.get_name_str(i)
                room_consumed = len(name_str)
                line += name_str
                if i.get('type') == 'anyxml' or \
                   i.get('type') == 'anydata' or \
                   i.get('datatype') is not None or \
                   i.get('if-feature') is not None:
                    line += self.get_datatype_str(i, room_consumed)
                ret.append(line)
        return ret

    def get_width(self, element):
        '''get_width

        High-level api: Calculate how much indent is needed for a node.

        The result is cached per parent in self.width, so all siblings share
        one column.

        Parameters
        ----------
        element : `Element`
            A node in model tree.

        Returns
        -------
        int
            Start position from the left margin.
        '''
        parent = element.getparent()
        if parent in self.width:
            return self.width[parent]
        ret = 0
        for sibling in parent:
            w = len(self.get_name_str(sibling))
            if w > ret:
                ret = w
        # round (longest name + 3) up to a multiple of 3
        self.width[parent] = math.ceil((ret + 3) / 3.0) * 3
        return self.width[parent]

    @staticmethod
    def get_depth_str(element, type='other'):
        '''get_depth_str

        High-level api: Produce a string that represents tree hierarchy.

        Parameters
        ----------
        element : `Element`
            A node in model tree.
        type : `str`
            Type of model content required. Its value can be 'other', 'rpc',
            or 'notification'.

        Returns
        -------
        str
            A string that represents tree hierarchy.
        '''
        def following_siblings(element, type):
            if type == 'rpc' or type == 'notification':
                return [s for s in list(element.itersiblings())
                        if s.get('type') == type]
            else:
                return [
                    s for s in list(element.itersiblings())
                    if (
                        s.get('type') != 'rpc' and
                        s.get('type') != 'notification'
                    )
                ]

        ancestors = list(reversed(list(element.iterancestors())))
        ret = ' '
        for i, ancestor in enumerate(ancestors):
            if i == 1:
                # top-level nodes: only siblings of the same section count
                if following_siblings(ancestor, type):
                    ret += '| '
                else:
                    ret += ' '
            else:
                if ancestor.getnext() is None:
                    ret += ' '
                else:
                    ret += '| '
        ret += '+--'
        return ret

    @staticmethod
    def get_flags_str(element):
        '''get_flags_str

        High-level api: Produce a string that represents the type of a node.

        Parameters
        ----------
        element : `Element`
            A node in model tree.

        Returns
        -------
        str
            A string that represents the type of a node.
        '''
        type_info = element.get('type')
        if type_info == 'rpc' or type_info == 'action':
            return '-x'
        elif type_info == 'notification':
            return '-n'
        access_info = element.get('access')
        if access_info is None:
            return ''
        elif access_info == 'write':
            return '-w'
        elif access_info == 'read-write':
            return 'rw'
        elif access_info == 'read-only':
            return 'ro'
        else:
            return '--'

    def get_name_str(self, element):
        '''get_name_str

        High-level api: Produce a string that represents the name of a node.

        Parameters
        ----------
        element : `Element`
            A node in model tree.

        Returns
        -------
        str
            A string that represents the name of a node. Nodes without a
            'type' attribute are rendered like containers.
        '''
        name = self.remove_model_prefix(self.url_to_prefix(element.tag))
        flags = self.get_flags_str(element)
        type_info = element.get('type')

        if type_info == 'choice':
            if element.get('mandatory') == 'true':
                return flags + ' ({})'.format(name)
            else:
                return flags + ' ({})?'.format(name)
        elif type_info == 'case':
            return ':({})'.format(name)
        elif type_info in ('leaf', 'anyxml', 'anydata'):
            if element.get('mandatory') == 'true':
                return flags + ' {}'.format(name)
            else:
                return flags + ' {}?'.format(name)
        elif type_info == 'list':
            if element.get('key') is not None:
                return flags + ' {}* [{}]'.format(name, element.get('key'))
            else:
                return flags + ' {}*'.format(name)
        elif type_info == 'leaf-list':
            return flags + ' {}*'.format(name)
        else:
            # 'container', other keywords and untyped nodes (type_info is
            # None) share this form; previously None fell through and the
            # method returned None, breaking len() in callers.
            return flags + ' {}'.format(name)

    def get_datatype_str(self, element, length):
        '''get_datatype_str

        High-level api: Produce a string that indicates the data type of a
        node.

        Parameters
        ----------
        element : `Element`
            A node in model tree.
        length : `int`
            String length that has been consumed.

        Returns
        -------
        str
            A string that indicates the data type of a node.
        '''
        spaces = ' '*(self.get_width(element) - length)
        type_info = element.get('type')

        ret = ''
        if type_info == 'anyxml' or type_info == 'anydata':
            ret = spaces + '<{}>'.format(type_info)
        elif element.get('datatype') is not None:
            ret = spaces + element.get('datatype')
        if element.get('if-feature') is not None:
            return ret + ' {' + element.get('if-feature') + '}?'
        else:
            return ret

    def prefix_to_url(self, id):
        '''prefix_to_url

        High-level api: Convert an identifier from `prefix:tagname` notation
        to `{namespace}tagname` notation. If the identifier does not have a
        prefix, it is assumed that the whole identifier is a tag name.

        Parameters
        ----------
        id : `str`
            Identifier in `prefix:tagname` notation.

        Returns
        -------
        str
            Identifier in `{namespace}tagname` notation.

        Raises
        ------
        KeyError
            If the prefix is not in self.prefixes.
        '''
        parts = id.split(':')
        if len(parts) > 1:
            return '{' + self.prefixes[parts[0]] + '}' + parts[1]
        else:
            return '{' + self.url + '}' + id

    def url_to_prefix(self, id):
        '''url_to_prefix

        High-level api: Convert an identifier from `{namespace}tagname`
        notation to `prefix:tagname` notation. If the identifier does not
        have a namespace, it is assumed that the whole identifier is a tag
        name.

        Parameters
        ----------
        id : `str`
            Identifier in `{namespace}tagname` notation.

        Returns
        -------
        str
            Identifier in `prefix:tagname` notation.

        Raises
        ------
        KeyError
            If the namespace is not in self.urls.
        '''
        ret = re.search('^{(.+)}(.+)$', id)
        if ret:
            return self.urls[ret.group(1)] + ':' + ret.group(2)
        else:
            return id

    def remove_model_prefix(self, id):
        '''remove_model_prefix

        High-level api: If prefix is the model prefix, return tagname
        without prefix. If prefix is not the model prefix, simply return the
        identifier without modification.

        Parameters
        ----------
        id : `str`
            Identifier in `prefix:tagname` notation.

        Returns
        -------
        str
            Identifier in `prefix:tagname` notation if prefix is not the
            model prefix. Or identifier in `tagname` notation if prefix is
            the model prefix.
        '''
        # YANG prefixes may contain '.', so the prefix must be escaped
        # before it is used as a regex.
        reg_str = '^' + re.escape(self.prefix) + ':(.+)$'
        ret = re.search(reg_str, id)
        if ret:
            return ret.group(1)
        else:
            return id

    def convert_tree(self):
        '''convert_tree

        High-level api: Convert cxml tree to an internal schema tree by
        removing the 'namespace' children in place (their content has
        already been captured in self.prefixes).

        Returns
        -------
        None
            The tree is modified in place.
        '''
        for ns in self.tree.findall('namespace'):
            self.tree.remove(ns)
class DownloadWorker(Thread):
    '''DownloadWorker

    Worker thread that takes module names from `downloader.download_queue`
    and passes each one to `downloader.download()` until the queue is
    observed empty.
    '''

    def __init__(self, downloader):
        Thread.__init__(self)
        self.downloader = downloader

    def run(self):
        download_queue = self.downloader.download_queue
        while not download_queue.empty():
            try:
                module = download_queue.get(timeout=0.01)
            except queue.Empty:
                # another worker took the last item between empty() and get()
                continue
            try:
                self.downloader.download(module)
            except Exception:
                # If this thread died here, task_done() would never be called
                # and download_queue.join() would block forever.
                logger.exception(
                    "Unexpected error while downloading '{}'".format(module))
            finally:
                download_queue.task_done()
        logger.debug('Thread {} exits'.format(current_thread().name))


class ContextWorker(Thread):
    '''ContextWorker

    Worker thread that takes YANG file paths from `context.modulefile_queue`,
    parses each file into the context and records its dependencies.
    '''

    def __init__(self, context):
        Thread.__init__(self)
        self.context = context

    def run(self):
        modulefile_queue = self.context.modulefile_queue
        while not modulefile_queue.empty():
            try:
                modulefile = modulefile_queue.get(timeout=0.01)
            except queue.Empty:
                continue
            try:
                self._load_module_file(modulefile)
            except Exception:
                # Keep draining the queue: a dead worker would leave
                # modulefile_queue.join() in load_context() blocked forever.
                logger.exception(
                    "Unexpected error while loading '{}'".format(modulefile))
            finally:
                modulefile_queue.task_done()
        logger.debug('Thread {} exits'.format(current_thread().name))

    def _load_module_file(self, modulefile):
        '''Parse one YANG file into the context and record its dependencies.'''
        # Keyword names accepted by Context.add_module() differ between
        # pyang versions, so only pass the ones this version knows.
        varnames = Context.add_module.__code__.co_varnames
        with open(modulefile, 'r', encoding='utf-8') as f:
            text = f.read()
        kwargs = {
            'ref': modulefile,
            'text': text,
        }
        if 'primary_module' in varnames:
            kwargs['primary_module'] = True
        if 'format' in varnames:
            kwargs['format'] = 'yang'
        if 'in_format' in varnames:
            kwargs['in_format'] = 'yang'
        module_statement = self.context.add_module(**kwargs)
        self.context.update_dependencies(module_statement)


class CompilerContext(Context):
    '''CompilerContext

    A pyang Context that additionally tracks module dependencies in an
    Element tree (`dependencies`), persisted as dependencies.xml in the first
    repository directory.
    '''

    def __init__(self, repository):
        Context.__init__(self, repository)
        self.dependencies = None
        self.modulefile_queue = None
        # Two loader threads are used only when pyang's Statement class
        # provides prune(); otherwise a single thread.
        if 'prune' in dir(statements.Statement):
            self.num_threads = 2
        else:
            self.num_threads = 1

    def _get_latest_revision(self, modulename):
        '''Return the highest revision of `modulename` loaded, or None.'''
        latest = None
        for module_name, module_revision in self.modules:
            if module_name == modulename and (
                latest is None or module_revision > latest
            ):
                latest = module_revision
        return latest

    def get_statement(self, modulename, xpath=None):
        '''Return the module statement, or the node at absolute `xpath`.

        Returns None if the module is unknown, not validated (when `xpath`
        is given), `xpath` is not absolute, or a step cannot be resolved.
        '''
        revision = self._get_latest_revision(modulename)
        if revision is None:
            return None
        node_statement = self.modules[(modulename, revision)]
        if xpath is None:
            return node_statement

        # in order to follow the Xpath, the module is required to be validated
        if node_statement.i_is_validated is not True:
            return None

        # only absolute Xpaths such as '/a/b' are supported ('a/b' and '//a'
        # are rejected)
        xpath_list = xpath.split('/')
        if len(xpath_list) < 2:
            return None
        if xpath_list[0] != '' or xpath_list[1] == '':
            return None

        # find the node statement
        root_prefix = node_statement.i_prefix
        for n in xpath_list[1:]:
            node_statement = self.get_child(root_prefix, node_statement, n)
            if node_statement is None:
                return None
        return node_statement

    def get_child(self, root_prefix, parent, child_id):
        '''Return the first child of `parent` matching `child_id`, or None.

        `child_id` is `prefix:name` or `name`; an unprefixed name is looked
        up under `root_prefix`.
        '''
        child_id_list = child_id.split(':')
        if len(child_id_list) > 1:
            prefix, name = child_id_list[0], child_id_list[1]
        else:
            prefix, name = root_prefix, child_id_list[0]
        for c in parent.i_children:
            if c.arg == name and c.i_module.i_prefix == prefix:
                return c
        return None

    def update_dependencies(self, module_statement):
        '''Record `module_statement` in self.dependencies.

        Returns
        -------
        set
            Names of the modules it includes or imports. Empty when
            `module_statement` is None (pyang's add_module() returns None
            when a module fails to parse).
        '''
        if self.dependencies is None:
            self.dependencies = etree.Element('modules')
        if module_statement is None:
            return set()
        for m in [
            m for m in self.dependencies
            if m.attrib.get('id') == module_statement.arg
        ]:
            self.dependencies.remove(m)

        module_node = etree.SubElement(self.dependencies, 'module')
        module_node.set('id', module_statement.arg)
        module_node.set('type', module_statement.keyword)
        if module_statement.keyword == 'module':
            statement = module_statement.search_one('prefix')
            if statement is not None:
                module_node.set('prefix', statement.arg)
            statement = module_statement.search_one("namespace")
            if statement is not None:
                namespace = etree.SubElement(module_node, 'namespace')
                namespace.text = statement.arg
        if module_statement.keyword == 'submodule':
            statement = module_statement.search_one("belongs-to")
            if statement is not None:
                belongs_to = etree.SubElement(module_node, 'belongs-to')
                belongs_to.set('module', statement.arg)

        dependencies = set()
        for parent_node_name, child_node_name, attr_name in [
            ('includes', 'include', 'module'),
            ('imports', 'import', 'module'),
            ('revisions', 'revision', 'date'),
        ]:
            parent = etree.SubElement(module_node, parent_node_name)
            # named 'found' so the pyang 'statements' module is not shadowed
            found = module_statement.search(child_node_name)
            if found:
                for statement in found:
                    child = etree.SubElement(parent, child_node_name)
                    child.set(attr_name, statement.arg)
                    if child_node_name in ['include', 'import']:
                        dependencies.add(statement.arg)
        return dependencies

    def write_dependencies(self):
        dependencies_file = os.path.join(
            self.repository.dirs[0],
            'dependencies.xml',
        )
        write_xml(dependencies_file, self.dependencies)

    def read_dependencies(self):
        dependencies_file = os.path.join(
            self.repository.dirs[0],
            'dependencies.xml',
        )
        self.dependencies = read_xml(dependencies_file)

    def load_context(self):
        '''Parse every .yang file of the repository and save dependencies.'''
        self.modulefile_queue = queue.Queue()
        for filename in os.listdir(self.repository.dirs[0]):
            if filename.lower().endswith('.yang'):
                filepath = os.path.join(self.repository.dirs[0], filename)
                self.modulefile_queue.put(filepath)
        for x in range(self.num_threads):
            worker = ContextWorker(self)
            worker.daemon = True
            worker.name = 'context_worker_{}'.format(x)
            worker.start()
        self.modulefile_queue.join()
        self.write_dependencies()

    def validate_context(self):
        '''Validate the context, then prune the latest revision of each
        module when pyang supports Statement.prune().'''
        revisions = {}
        for module_name, module_revision in self.modules:
            if module_name not in revisions or \
                    revisions[module_name] < module_revision:
                revisions[module_name] = module_revision
        self.validate()
        if 'prune' in dir(statements.Statement):
            for module_name, module_revision in revisions.items():
                self.modules[(module_name, module_revision)].prune()

    def internal_reset(self):
        '''Forget all loaded modules and errors, and rescan the repository.'''
        self.modules = {}
        self.revs = {}
        self.errors = []
        for mod, rev, handle in self.repository.get_modules_and_revisions(
                self):
            if mod not in self.revs:
                self.revs[mod] = []
            revs = self.revs[mod]
            revs.append((rev, handle))
class ModelDownloader(object):
    '''ModelDownloader

    Abstraction of a Netconf schema downloader.

    Attributes
    ----------
    device : `ModelDevice`
        The Netconf device that schemas are downloaded from.

    dir_yang : `str`
        Path to yang files.

    yang_capabilities : `str`
        Path to capabilities.txt file in the folder of yang files.

    need_download : `bool`
        True if the content of capabilities.txt file disagrees with device
        capabilities exchange. False otherwise.
    '''

    def __init__(self, nc_device, folder):
        '''
        __init__ instantiates a ModelDownloader instance.
        '''
        self.device = nc_device
        self.dir_yang = os.path.abspath(folder)
        if not os.path.isdir(self.dir_yang):
            os.makedirs(self.dir_yang)
        self.yang_capabilities = os.path.join(
            self.dir_yang,
            'capabilities.txt',
        )
        repo = FileRepository(path=self.dir_yang)
        self.context = CompilerContext(repository=repo)
        self.download_queue = queue.Queue()
        self.num_threads = 2

    @property
    def need_download(self):
        if not os.path.isfile(self.yang_capabilities):
            return True
        # download_all() writes this file as UTF-8; read it the same way
        # instead of relying on the locale's default encoding.
        with open(self.yang_capabilities, 'r', encoding='utf-8') as f:
            recorded = f.read()
        expected = '\n'.join(sorted(self.device.server_capabilities))
        return recorded != expected

    def download_all(self, check_before_download=True):
        '''download_all

        High-level api: Download the schemas of all loadable models of the
        device into `dir_yang`. Every file under `dir_yang` is removed first.
        Modules that the downloaded ones include or import are downloaded
        as well. Afterwards the device capabilities are written to
        capabilities.txt and the module dependencies to dependencies.xml.

        Parameters
        ----------
        check_before_download : `bool`
            True if checking capabilities.txt file is required. When it
            matches the device capabilities, nothing is downloaded.

        Returns
        -------
        None
            Nothing returns.
        '''
        # check the content of self.yang_capabilities
        if check_before_download and not self.need_download:
            logger.info('Skip downloading as the content of {} '
                        'matches device hello message'
                        .format(self.yang_capabilities))
            return

        # clean up folder self.dir_yang
        for root, dirs, files in os.walk(self.dir_yang):
            for f in files:
                os.remove(os.path.join(root, f))

        # download all
        self.to_be_downloaded = set(self.device.models_loadable)
        self.context.dependencies = etree.Element('modules')
        for module in sorted(self.to_be_downloaded):
            self.download_queue.put(module)
        for x in range(self.num_threads):
            worker = DownloadWorker(self)
            worker.daemon = True
            worker.name = 'download_worker_{}'.format(x)
            worker.start()
        self.download_queue.join()

        # write self.yang_capabilities
        capabilities = '\n'.join(sorted(self.device.server_capabilities))
        with open(self.yang_capabilities, 'wb') as f:
            f.write(capabilities.encode('utf-8'))

        # write dependencies
        self.context.write_dependencies()

    def download(self, module):
        '''download

        High-level api: Download a module schema, parse it into the context
        and queue any included or imported modules not yet known.

        Parameters
        ----------
        module : `str`
            Module name that will be downloaded.

        Returns
        -------
        None
            Nothing returns.
        '''
        logger.debug('Downloading {}.yang...'.format(module))
        try:
            from .manager import ModelDevice
            reply = super(ModelDevice, self.device).execute(
                operations.retrieve.GetSchema, module,
            )
        except operations.rpc.RPCError:
            logger.warning("Module or submodule '{}' cannot be downloaded"
                           .format(module))
            return

        if not reply.ok:
            logger.warning("module or submodule '{}' cannot be downloaded:\n{}"
                           .format(module, reply._raw))
            return

        fname = os.path.join(self.dir_yang, module+'.yang')
        with open(fname, 'wb') as f:
            f.write(reply.data.encode('utf-8'))

        # Keyword names accepted by Context.add_module() differ between
        # pyang versions, so only pass the ones this version knows.
        varnames = Context.add_module.__code__.co_varnames
        kwargs = {
            'ref': fname,
            'text': reply.data,
        }
        if 'primary_module' in varnames:
            kwargs['primary_module'] = True
        if 'format' in varnames:
            kwargs['format'] = 'yang'
        if 'in_format' in varnames:
            kwargs['in_format'] = 'yang'
        module_statement = self.context.add_module(**kwargs)
        if module_statement is None:
            # pyang's add_module() returns None when parsing fails; there
            # are no dependencies to record or follow.
            logger.warning("Module or submodule '{}' cannot be parsed"
                           .format(module))
            return

        dependencies = self.context.update_dependencies(module_statement)
        s = dependencies - self.to_be_downloaded
        if s:
            logger.info('{} requires submodules: {}'
                        .format(module, ', '.join(s)))
            self.to_be_downloaded.update(s)
            for m in s:
                self.download_queue.put(m)
class ModelCompiler(object):
    '''ModelCompiler

    Abstraction of a YANG file compiler.

    Attributes
    ----------
    dir_yang : `str`
        Path to yang files.

    context : `CompilerContext`
        A CompilerContext object that holds the context of all modules.
        Dependency information is stored in `context.dependencies` as an
        Element object.

    module_prefixes : `dict`
        A dictionary that stores module prefixes. It is keyed by module
        names.

    module_namespaces : `dict`
        A dictionary that stores module namespaces. It is keyed by module
        names.

    identity_deps : `dict`
        A dictionary keyed by `module:identity` of a base identity; values
        are lists of `module:identity` derived from it.

    pyang_errors : `list`
        A list of tuples. Each tuple contains a pyang error.Position object,
        an error tag and a tuple of some error arguments. It is possible to
        call pyang.error.err_to_str() to print out detailed error messages.
    '''

    def __init__(self, folder):
        '''
        __init__ instantiates a ModelCompiler instance.

        Parameters
        ----------
        folder : `str`
            Directory holding the .yang files. Compiled models are cached
            there as <module>.xml.
        '''
        self.dir_yang = os.path.abspath(folder)
        self.context = None
        self.module_prefixes = {}
        self.module_namespaces = {}
        self.identity_deps = {}
        self.build_dependencies()

    @property
    def pyang_errors(self):
        if self.context is None:
            return []
        return self.context.errors

    def _read_from_cache(self, name):
        cached_name = os.path.join(self.dir_yang, name + ".xml")
        return read_xml(cached_name)

    def _write_to_cache(self, name, element):
        cached_name = os.path.join(self.dir_yang, name + ".xml")
        write_xml(cached_name, element)

    def build_dependencies(self):
        '''build_dependencies

        High-level api: Briefly compile all yang files and find out
        dependency information of all modules. A saved dependencies.xml is
        reused when present.

        Returns
        -------
        None
            Nothing returns.
        '''
        if self.context is None:
            repo = FileRepository(path=self.dir_yang)
            self.context = CompilerContext(repository=repo)
        if self.context.dependencies is None:
            self.context.read_dependencies()
        if self.context.dependencies is None:
            self.context.load_context()

    def get_dependencies(self, module):
        '''get_dependencies

        High-level api: Get dependency information of a module.

        Parameters
        ----------
        module : `str`
            Module name that is inquired about.

        Returns
        -------
        tuple
            A tuple with two elements: the set of modules `module` imports,
            and the set of modules that import or include `module`.
        '''
        if self.context is None or self.context.dependencies is None:
            self.build_dependencies()
        dependencies = self.context.dependencies

        imports = set()
        for m in dependencies.findall('./module'):
            if m.get('id') == module:
                imports.update(
                    i.get('module') for i in m.findall('./imports/import'))

        depends = set()
        for m in dependencies:
            imported = any(i.get('module') == module
                           for i in m.findall('./imports/import'))
            included = any(i.get('module') == module
                           for i in m.findall('./includes/include'))
            if imported or included:
                depends.add(m.get('id'))

        return (imports, depends)

    def compile(self, module):
        '''compile

        High-level api: Compile a module. The result is cached as
        <module>.xml in `dir_yang` and reused on later calls.

        Parameters
        ----------
        module : `str`
            Module name that is inquired about.

        Returns
        -------
        Model
            A Model object.

        Raises
        ------
        ModelError
            If the module cannot be found or parsed.
        '''
        cached_tree = self._read_from_cache(module)
        if cached_tree is not None:
            return Model(cached_tree)

        imports, depends = self.get_dependencies(module)
        required_module_set = imports | depends
        required_module_set.add(module)

        self.context.internal_reset()
        self._add_module_files(required_module_set)
        self.context.validate_context()

        vm = self.context.get_module(module)
        if vm is None:
            raise ModelError("Module '{}' cannot be found or parsed in {}"
                             .format(module, self.dir_yang))

        st = etree.Element(vm.arg)
        st.set('type', 'module')
        statement = vm.search_one('prefix')
        if statement is not None:
            st.set('prefix', statement.arg)

        for m_statement in self.context.modules.values():
            if m_statement.keyword != 'module':
                continue
            namespace = etree.SubElement(st, 'namespace')
            namespace.set('prefix', m_statement.i_prefix)
            statement = m_statement.search_one('namespace')
            if statement is not None:
                namespace.text = statement.arg
                self.module_namespaces[m_statement.i_modulename] = \
                    statement.arg
                etree.register_namespace(
                    m_statement.i_prefix, statement.arg)

            # prepare self.module_prefixes
            self.module_prefixes[m_statement.i_modulename] = \
                m_statement.i_prefix

            # prepare self.identity_deps
            self._record_identity_deps(m_statement)

        # data nodes first, then rpcs, then notifications
        for child in vm.i_children:
            if child.keyword in statements.data_definition_keywords:
                self.depict_a_schema_node(vm, st, child)
        for keyword in ('rpc', 'notification'):
            for child in vm.i_children:
                if child.keyword == keyword:
                    self.depict_a_schema_node(vm, st, child, mode=keyword)

        self._write_to_cache(module, st)
        return Model(st)

    def _add_module_files(self, module_names):
        '''Parse <name>.yang from the repository for each existing file.'''
        # Keyword names accepted by Context.add_module() differ between
        # pyang versions, so only pass the ones this version knows.
        varnames = Context.add_module.__code__.co_varnames
        repository_dir = self.context.repository.dirs[0]
        for name in module_names:
            modulefile = os.path.join(repository_dir, name + '.yang')
            if not os.path.isfile(modulefile):
                continue
            with open(modulefile, 'r', encoding='utf-8') as f:
                text = f.read()
            kwargs = {
                'ref': modulefile,
                'text': text,
            }
            if 'primary_module' in varnames:
                kwargs['primary_module'] = True
            if 'format' in varnames:
                kwargs['format'] = 'yang'
            if 'in_format' in varnames:
                kwargs['in_format'] = 'yang'
            self.context.add_module(**kwargs)

    def _record_identity_deps(self, m_statement):
        '''Add each identity of `m_statement` to self.identity_deps.

        Identities without a base get an (initially empty) entry of their
        own. Derived identities are appended to their base's list, and no
        entry is added twice.
        '''
        for idn in m_statement.i_identities.values():
            curr_idn = m_statement.arg + ':' + idn.arg
            base_idn = idn.search_one("base")
            if base_idn is None:
                # identity does not have a base
                self.identity_deps.setdefault(curr_idn, [])
                continue
            base_idns = base_idn.arg.split(':')
            if len(base_idns) > 1:
                # base is qualified by a prefix, resolved via i_prefixes,
                # whose values are (module name, revision)
                mn = m_statement.i_prefixes.get(base_idns[0])
                b_idn = base_idns[1] if mn is None \
                    else mn[0] + ':' + base_idns[1]
            else:
                # unprefixed base lives in the same module as the identity
                b_idn = m_statement.arg + ':' + base_idn.arg
            # setdefault + append: the first derived identity of a base
            # used to be dropped when the base entry did not exist yet
            derived = self.identity_deps.setdefault(b_idn, [])
            if curr_idn not in derived:
                derived.append(curr_idn)
def depict_a_schema_node(self, module, parent, child, mode=None):
    '''Append an element describing schema statement `child` to `parent`
    and recurse into `child.i_children`.

    The new element is tagged '{namespace}name', with the namespace taken
    from self.module_namespaces. It carries 'access' and 'type' plus,
    when present in the schema: 'status' (deprecated/obsolete only),
    'default', 'key', 'ordered-by', 'presence', 'mandatory', 'values',
    'is_key', 'datatype' and 'if-feature'.

    Parameters
    ----------
    module : pyang statement
        Module being compiled; passed unchanged to recursive calls.
    parent : `Element`
        Element the new node is appended to.
    child : pyang statement
        Schema statement to depict.
    mode : `str` or None
        'rpc', 'input', 'output', 'notification' or None; inherited by
        descendants and used by set_access().
    '''
    n = etree.SubElement(
        parent,
        '{' + self.module_namespaces[child.i_module.i_modulename] + '}' +
        child.arg)
    self.set_access(child, n, mode)
    n.set('type', child.keyword)

    sm = child.search_one('status')
    if sm is not None and sm.arg in ('deprecated', 'obsolete'):
        n.set('status', sm.arg)
    sm = child.search_one('default')
    if sm is not None:
        n.set('default', sm.arg)

    if child.keyword == 'list':
        sm = child.search_one('key')
        if sm is not None:
            n.set('key', sm.arg)
        sm = child.search_one('ordered-by')
        if sm is not None and sm.arg == 'user':
            n.set('ordered-by', 'user')
    elif child.keyword == 'container':
        sm = child.search_one('presence')
        if sm is not None:
            n.set('presence', 'true')
    elif child.keyword == 'choice':
        sm = child.search_one('mandatory')
        if sm is not None and sm.arg == 'true':
            n.set('mandatory', 'true')
        cases = [c.arg for c in child.search('case')]
        if cases:
            n.set('values', '|'.join(cases))
    elif child.keyword in ('leaf', 'leaf-list'):
        self.set_leaf_datatype_value(child, n)
        sm = child.search_one('mandatory')
        # list keys are always reported as mandatory
        if (sm is not None and sm.arg == 'true') or \
                hasattr(child, 'i_is_key'):
            n.set('mandatory', 'true')
        if hasattr(child, 'i_is_key'):
            n.set('is_key', 'true')
        if child.keyword == 'leaf-list':
            sm = child.search_one('ordered-by')
            if sm is not None and sm.arg == 'user':
                n.set('ordered-by', 'user')

    # if-features of the node itself, plus those of the augment that
    # introduced it (without repeating names)
    featurenames = [f.arg for f in child.search('if-feature')]
    if hasattr(child, 'i_augment'):
        featurenames.extend([
            f.arg for f in child.i_augment.search('if-feature')
            if f.arg not in featurenames
        ])
    if featurenames:
        n.set('if-feature', ' '.join(featurenames))

    if hasattr(child, 'i_children'):
        for c in child.i_children:
            if mode == 'rpc' and c.keyword in ('input', 'output'):
                # an rpc's input/output switch the mode for their subtree
                self.depict_a_schema_node(module, n, c, mode=c.keyword)
            else:
                self.depict_a_schema_node(module, n, c, mode=mode)


@staticmethod
def set_access(statement, node, mode):
    '''Set the 'access' attribute of `node`.

    'write' inside rpc/input (and for rpc / tailf-common:action
    statements); 'read-only' inside output/notification (and for
    notification statements); otherwise 'read-write' when
    `statement.i_config` is truthy, else 'read-only'.
    '''
    if (mode in ('input', 'rpc') or
            statement.keyword == 'rpc' or
            statement.keyword == ('tailf-common', 'action')):
        node.set('access', 'write')
    elif mode in ('output', 'notification') or \
            statement.keyword == 'notification':
        node.set('access', 'read-only')
    elif getattr(statement, 'i_config', False):
        node.set('access', 'read-write')
    else:
        node.set('access', 'read-only')


def set_leaf_datatype_value(self, leaf_statement, leaf_node):
    '''Set 'datatype', 'values' and 'unionmembertypes' on `leaf_node`.

    'datatype' is the type name ('' when the leaf has no type statement).
    Leafrefs get '-> path' with redundant prefixes removed, and
    identityrefs get 'identityref:<base>'. 'values' is set when
    type_values() returns something. 'unionmembertypes' is set for
    unions.
    '''
    sm = leaf_statement.search_one('type')
    if sm is None:
        datatype = ''
    elif sm.arg == 'leafref':
        p = sm.search_one('path')
        if p is None:
            datatype = sm.arg
        else:
            # Try to make the path as compact as possible.
            # Remove local prefixes, and only use prefix when
            # there is a module change in the path.
            target = []
            curprefix = leaf_statement.i_module.i_prefix
            for name in p.arg.split('/'):
                if ':' in name:
                    prefix, name = name.split(':', 1)
                else:
                    prefix = curprefix
                if prefix == curprefix:
                    target.append(name)
                else:
                    target.append(prefix + ':' + name)
                    curprefix = prefix
            datatype = '-> ' + '/'.join(target)
    elif sm.arg == 'identityref':
        idn_base = sm.search_one('base')
        # guard against a missing base instead of raising AttributeError
        datatype = sm.arg if idn_base is None \
            else sm.arg + ':' + idn_base.arg
    else:
        datatype = sm.arg
    leaf_node.set('datatype', datatype)

    values = self.type_values(sm)
    if values:
        leaf_node.set('values', values)
    # sm may be None here (no type statement), so test it first
    if sm is not None and sm.arg == 'union':
        leaf_node.set(
            'unionmembertypes',
            '|'.join(m.arg for m in sm.search('type')))


def type_values(self, type_statement):
    '''Return the '|'-separated allowed values of a type statement.

    Typedef-based types are resolved through i_typedef. booleans give
    'true|false', enumerations their enum names, unions the values of
    their members, and identityrefs the derived identities. Any other
    type (or None) gives ''.
    '''
    if type_statement is None:
        return ''
    if (type_statement.i_is_derived is False and
            type_statement.i_typedef is not None):
        return self.type_values(
            type_statement.i_typedef.search_one('type'))
    if type_statement.arg == 'boolean':
        return 'true|false'
    if type_statement.arg == 'union':
        return self.type_union_values(type_statement)
    if type_statement.arg == 'enumeration':
        return '|'.join([e.arg for e in type_statement.search('enum')])
    if type_statement.arg == 'identityref':
        return self.type_identityref_values(type_statement)
    return ''


def type_union_values(self, type_statement):
    '''Join the non-empty type_values() of every member type of a union.'''
    member_values = (self.type_values(member)
                     for member in type_statement.search('type'))
    return '|'.join(v for v in member_values if v)


def type_identityref_values(self, type_statement):
    '''Return 'prefix:identity' names of identities derived from the
    (first) base of an identityref type, joined with '|'.

    Derivation is followed transitively through self.identity_deps,
    breadth-first, so direct children come first. Each identity is
    listed once. Returns '' when there is no base, the base prefix is
    unknown, or nothing derives from it.
    '''
    base_idn = type_statement.search_one('base')
    if base_idn is None:
        return ''

    prefix, sep, name = base_idn.arg.partition(':')
    if sep:
        modulename = type_statement.i_module.i_prefixes.get(prefix)
        if modulename is None:
            return ''
        idn_key = modulename[0] + ':' + name
    else:
        idn_key = type_statement.i_module.i_modulename + ':' + base_idn.arg

    # collect descendants level by level; `seen` also stops cycles
    derived = []
    seen = {idn_key}
    frontier = [idn_key]
    while frontier:
        next_frontier = []
        for key in frontier:
            for value in self.identity_deps.get(key, []):
                if value not in seen:
                    seen.add(value)
                    derived.append(value)
                    next_frontier.append(value)
        frontier = next_frontier

    value_stmts = []
    for value in derived:
        value_module, _, value_name = value.partition(':')
        value_stmts.append(
            self.module_prefixes[value_module] + ':' + value_name)
    return '|'.join(value_stmts)
class ModelDiff(object):
    '''ModelDiff

    Abstraction of differences between two Model instances. It supports str()
    which returns a string illustrating the differences between model1 and
    model2.

    Attributes
    ----------

    model1 : `Model`
        First Model instance.

    model2 : `Model`
        Second Model instance.

    tree : `Element`
        The model difference tree as an Element object.

    added : `str`
        A string presentation of added nodes from model1 to model2, or None
        when no node was added.

    deleted : `str`
        A string presentation of deleted nodes from model1 to model2, or None
        when no node was deleted.

    modified : `str`
        A string presentation of modified nodes from model1 to model2, or
        None when no node was modified.

    width : `dict`
        This is used to facilitate pretty print of a model. Dictionary keys
        are nodes in the model tree, and values are indents.
    '''

    __str__ = Model.__str__
    emit_tree = Model.emit_tree
    get_width = Model.get_width

    # Attributes kept when a node is copied into the diff tree;
    # process_attrib() removes every other attribute.
    _KNOWN_ATTRIB = frozenset([
        'type', 'access', 'mandatory', 'presence', 'values', 'key',
        'is_key', 'prefix', 'datatype', 'if-feature', 'ordered-by',
        'default',
    ])

    def __init__(self, model1, model2):
        '''
        __init__ instantiates a ModelDiff instance.

        Raises ValueError when the root tags of the two model trees differ,
        i.e. the models describe different modules.
        '''

        self.model1 = model1
        self.model2 = model2
        self.width = {}
        if model1.tree.tag != model2.tree.tag:
            raise ValueError("cannot generate diff of different modules: "
                             "'{}' vs '{}'"
                             .format(model1.tree.tag, model2.tree.tag))
        self.tree = etree.Element(model1.tree.tag)
        # diffing a model against itself yields an empty tree
        if model1 is not model2:
            self.compare_nodes(model1.tree, model2.tree, self.tree)

    def __bool__(self):
        # True when at least one difference was recorded
        return len(self.tree) > 0

    @property
    def added(self):
        return self._emit_diff('added')

    @property
    def deleted(self):
        return self._emit_diff('deleted')

    @property
    def modified(self):
        return self._emit_diff('modified')

    def _emit_diff(self, msg):
        '''Render a copy of self.tree trimmed to nodes whose 'diff' is
        `msg`; return None when nothing of that kind remains.'''
        tree = deepcopy(self.tree)
        if self.trim(tree, msg):
            return None
        return self.emit_tree(tree)

    def compare(self, xpath):
        '''compare

        High-level api: Return a string presentation of comparison between
        the node in model1 and model2.

        Parameters
        ----------

        xpath : `str`
            XPATH to locate a node. Prefixes of both models may be used.

        Returns
        -------

        str
            A string presentation of comparison. Only the first node
            matched in each model is shown ('None' when nothing matched).
        '''

        def print_node(node_list):
            if not node_list:
                return 'None'
            node = node_list[0]
            ret = ["tag: '{}'".format(node.tag),
                   "text: {}".format(print_value(node.text)),
                   "attributes:"]
            for a, v in node.attrib.items():
                ret.append("  {} = {}".format(a, print_value(v)))
            return '\n'.join(ret)

        def print_value(value):
            if value is None:
                return 'None'
            return "'{}'".format(value)

        # model2 prefixes win when both models use the same prefix
        prefixes = dict(self.model1.prefixes)
        prefixes.update(self.model2.prefixes)
        node1_list = self.model1.tree.xpath(xpath, namespaces=prefixes)
        node2_list = self.model2.tree.xpath(xpath, namespaces=prefixes)
        return '\n'.join(['-'*21 + ' XPATH ' + '-'*21,
                          "'{}'".format(xpath),
                          '-'*20 + ' MODEL 1 ' + '-'*20,
                          print_node(node1_list),
                          '-'*20 + ' MODEL 2 ' + '-'*20,
                          print_node(node2_list),
                          '-'*49,
                          ''])

    def emit_children(self, tree, type='other'):
        '''emit_children

        High-level api: Emit a string presentation of a part of the model.

        Parameters
        ----------

        tree : `Element`
            The model.

        type : `str`
            Type of model content required. Its value can be 'other', 'rpc',
            or 'notification'. 'rpc' and 'notification' select only roots of
            that type; any other value selects roots that are neither.

        Returns
        -------

        list
            One string per node, similar to the lines of 'pyang -f tree'
            output, with the diff marker appended to changed nodes.
        '''

        def is_type(element, wanted):
            type_info = element.get('type')
            if wanted == type_info:
                return True
            if wanted in ('rpc', 'notification'):
                return False
            if type_info in ('rpc', 'notification'):
                return False
            return True

        ret = []
        for root in [i for i in tree if is_type(i, type)]:
            for i in root.iter():
                line = Model.get_depth_str(i, type=type)
                name_str = self.get_name_str(i)
                room_consumed = len(name_str)
                line += name_str
                if i.get('diff') is not None:
                    line += self.get_diff_str(i, room_consumed)
                ret.append(line)
        return ret

    def get_name_str(self, element):
        '''get_name_str

        High-level api: Produce a string that represents the name of a node.
        Added nodes are named by model2, all others by model1.

        Parameters
        ----------

        element : `Element`
            A node in model tree.

        Returns
        -------

        str
            A string that represents the name of a node.
        '''

        if element.get('diff') == 'added':
            return self.model2.get_name_str(element)
        return self.model1.get_name_str(element)

    def get_diff_str(self, element, length):
        '''get_diff_str

        High-level api: Produce a string that indicates the difference
        between two models.

        Parameters
        ----------

        element : `Element`
            A node in model tree.

        length : `int`
            String length that has been consumed.

        Returns
        -------

        str
            Padding up to self.get_width(element) followed by the node's
            'diff' attribute.
        '''

        spaces = ' '*(self.get_width(element) - length)
        return spaces + element.get('diff')

    @staticmethod
    def compare_nodes(node1, node2, ret):
        '''compare_nodes

        High-level api: Compare node1 and node2 and put the result in ret.

        Children only in node2 are copied with diff='added', children only
        in node1 with diff='deleted', and changed leaf/leaf-list children
        with diff='modified'. Other changed children are copied without a
        marker and compared recursively.

        Parameters
        ----------

        node1 : `Element`
            A node in a model tree.

        node2 : `Element`
            A node in another model tree.

        ret : `Element`
            A node in self.tree.

        Returns
        -------

        None
            Nothing returns.

        Raises
        ------

        ModelError
            If a tag is not unique among the children (see get_peer).
        '''

        for child in node2:
            peer = ModelDiff.get_peer(child.tag, node1)
            if peer is None:
                ModelDiff.copy_subtree(ret, child, 'added')
            elif ModelDiff.node_equal(peer, child):
                continue
            elif child.attrib['type'] in ['leaf-list', 'leaf']:
                ModelDiff.copy_node(ret, child, 'modified')
            else:
                # NOTE(review): a change limited to this node's own
                # attributes is not marked here; only changes in its
                # descendants get a 'diff' marker.
                ret_child = ModelDiff.copy_node(ret, child, '')
                ModelDiff.compare_nodes(peer, child, ret_child)
        for child in node1:
            peer = ModelDiff.get_peer(child.tag, node2)
            if peer is None:
                ModelDiff.copy_subtree(ret, child, 'deleted')

    @staticmethod
    def copy_subtree(ret, element, msg):
        '''copy_subtree

        High-level api: Copy element as a subtree and put it as a child of
        ret. Every node of the copy is processed by process_attrib().

        Parameters
        ----------

        element : `Element`
            A node in a model tree.

        msg : `str`
            Message to be added.

        ret : `Element`
            A node in self.tree.

        Returns
        -------

        Element
            The copied subtree, now a child of ret.
        '''

        sub_element = ModelDiff.process_attrib(deepcopy(element), msg)
        ret.append(sub_element)
        return sub_element

    @staticmethod
    def copy_node(ret, element, msg):
        '''copy_node

        High-level api: Copy element as a node without its children and put
        it as a child of ret.

        Parameters
        ----------

        element : `Element`
            A node in a model tree.

        msg : `str`
            Message to be added.

        ret : `Element`
            A node in self.tree.

        Returns
        -------

        Element
            The new child of ret.
        '''

        sub_element = etree.SubElement(ret, element.tag, attrib=element.attrib)
        ModelDiff.process_attrib(sub_element, msg)
        return sub_element

    @staticmethod
    def process_attrib(element, msg):
        '''process_attrib

        High-level api: On element and all of its descendants, delete every
        attribute that is not a known schema attribute (type, access,
        mandatory, presence, values, key, is_key, prefix, datatype,
        if-feature, ordered-by, default). Then, if msg is not empty, set
        attribute 'diff' to msg.

        Parameters
        ----------

        element : `Element`
            A node needs to be looked at.

        msg : `str`
            Message to be added in attribute 'diff'.

        Returns
        -------

        Element
            Argument 'element' is returned after processing.
        '''

        for node in element.iter():
            # snapshot the names: attributes are deleted while looping
            for attrib in list(node.attrib):
                if attrib not in ModelDiff._KNOWN_ATTRIB:
                    del node.attrib[attrib]
            if msg:
                node.attrib['diff'] = msg
        return element

    @staticmethod
    def get_peer(tag, node):
        '''get_peer

        High-level api: Find the child of node with the tag.

        Parameters
        ----------

        tag : `str`
            A tag in `{namespace}tagname` notation.

        node : `Element`
            A node to be looked at.

        Returns
        -------

        Element or None
            None if not found. An Element object when found.

        Raises
        ------

        ModelError
            If more than one child has the tag.
        '''

        peers = node.findall(tag)
        if len(peers) < 1:
            return None
        if len(peers) > 1:
            raise ModelError("not unique tag '{}'".format(tag))
        return peers[0]

    @staticmethod
    def node_equal(node1, node2):
        '''node_equal

        High-level api: Evaluate whether two nodes are equal, i.e. each is
        node_less than the other.

        Parameters
        ----------

        node1 : `Element`
            A node in a model tree.

        node2 : `Element`
            A node in another model tree.

        Returns
        -------

        bool
            True if node1 and node2 are equal.
        '''

        return ModelDiff.node_less(node1, node2) and \
            ModelDiff.node_less(node2, node1)

    @staticmethod
    def node_less(node1, node2):
        '''node_less

        Low-level api: Return True if node1 and node2 have the same tag and
        text, every attribute of node1 exists in node2 with the same value,
        and this holds recursively for every child of node1. Otherwise
        False.

        Parameters
        ----------

        node1 : `Element`
            A node in a model tree.

        node2 : `Element`
            A node in another model tree.

        Returns
        -------

        bool
            True if all descendants of node1 exist in node2, otherwise False.

        Raises
        ------

        ModelError
            If a child tag of node1 matches more than one child of node2.
        '''

        if node1.tag != node2.tag or node1.text != node2.text:
            return False
        for name, value in node1.attrib.items():
            if name not in node2.attrib or node2.attrib[name] != value:
                return False
        for child in node1:
            peers = node2.findall(child.tag)
            if len(peers) < 1:
                return False
            if len(peers) > 1:
                raise ModelError("not unique peer '{}'".format(child.tag))
            if not ModelDiff.node_less(child, peers[0]):
                return False
        return True

    @staticmethod
    def trim(parent, msg):
        '''trim

        Low-level api: Filter parent in place to one type of diff (added,
        deleted, or modified) and return True if parent has no child left.

        Children marked with a different diff are removed. Children marked
        with msg are kept whole, even when they have no children. Unmarked
        children are structural nodes on the path to a change; they are
        trimmed recursively and removed when they end up empty.

        Parameters
        ----------

        parent : `Element`
            A node in a model tree.

        msg : `str`
            A type of diff: added, deleted, or modified.

        Returns
        -------

        bool
            True if parent has no child after trimming.
        '''

        for child in list(parent):
            diff = child.get('diff')
            if diff:
                if diff != msg:
                    parent.remove(child)
            elif ModelDiff.trim(child, msg):
                parent.remove(child)
        return len(parent) == 0