python/graph: rework shorten_vertex_names to not need edges
[amitay/samba.git] / python / samba / graph.py
1 # -*- coding: utf-8 -*-
2 # Graph topology utilities and dot file generation
3 #
4 # Copyright (C) Andrew Bartlett 2018.
5 #
6 # Written by Douglas Bagnall <douglas.bagnall@catalyst.net.nz>
7 #
8 # This program is free software; you can redistribute it and/or modify
9 # it under the terms of the GNU General Public License as published by
10 # the Free Software Foundation; either version 3 of the License, or
11 # (at your option) any later version.
12 #
13 # This program is distributed in the hope that it will be useful,
14 # but WITHOUT ANY WARRANTY; without even the implied warranty of
15 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16 # GNU General Public License for more details.
17 #
18 # You should have received a copy of the GNU General Public License
19 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
20
21 from __future__ import print_function
22 from __future__ import division
23 from samba import colour
24 import sys
25 from itertools import cycle, groupby
26
27 FONT_SIZE = 10
28
29
30 def reformat_graph_label(s):
31     """Break DNs over multiple lines, for better shaped and arguably more
32     readable nodes. We try to split after commas, and if necessary
33     after hyphens or failing that in arbitrary places."""
34     if len(s) < 12:
35         return s
36
37     s = s.replace(',', ',\n')
38     pieces = []
39     for p in s.split('\n'):
40         while len(p) > 20:
41             if '-' in p[2:20]:
42                 q, p = p.split('-', 1)
43             else:
44                 n = len(p) // 12
45                 b = len(p) // n
46                 q, p = p[:b], p[b:]
47             pieces.append(q + '-')
48         if p:
49             pieces.append(p)
50
51     return '\\n'.join(pieces)
52
53
54 def quote_graph_label(s, reformat=False):
55     """Escape a string as graphvis requires."""
56     # escaping inside quotes is simple in dot, because only " is escaped.
57     # there is no need to count backslashes in sequences like \\\\"
58     s = s.replace('"', '\"')
59     if reformat:
60         s = reformat_graph_label(s)
61     return "%s" % s
62
63
64 def shorten_vertex_names(vertices, suffix=',...', aggressive=False):
65     """Replace the common suffix (in practice, the base DN) of a number of
66     vertices with a short string (default ",..."). If this seems
67     pointless because the replaced string is very short or the results
68     seem strange, the original vertices are retained.
69
70     :param vertices: a sequence of vertices to shorten
71     :param suffix: the replacement string [",..."]
72     :param aggressive: replace certain common non-suffix strings
73
74     :return: tuple of (rename map, replacements)
75
76     The rename map is a dictionary mapping the old vertex names to
77     their shortened versions. If no changes are made, replacements
78     will be empty.
79     """
80     vmap = dict((v, v) for v in vertices)
81     replacements = []
82
83     if len(vmap) > 1:
84         # walk backwards along all the strings until we meet a character
85         # that is not shared by all.
86         i = -1
87         vlist = vmap.values()
88         try:
89             while True:
90                 c = set(x[i] for x in vlist)
91                 if len(c) > 1 or c == {'*'}:
92                     break
93                 i -= 1
94         except IndexError:
95             # We have indexed beyond the start of a string, which should
96             # only happen if one node is a strict suffix of all others.
97             return vmap, replacements
98
99         # add one to get to the last unanimous character.
100         i += 1
101
102         # now, we actually really want to split on a comma. So we walk
103         # back to a comma.
104         x = vlist[0]
105         while i < len(x) and x[i] != ',':
106             i += 1
107
108         if i >= -len(suffix):
109             # there is nothing to gain here
110             return vmap, replacements
111
112         replacements.append((suffix, x[i:]))
113
114         for k, v in vmap.items():
115             vmap[k] = v[:i] + suffix
116
117     if aggressive:
118         # Remove known common annoying strings
119         for v in vmap.values():
120             if ',CN=Servers,' not in v:
121                 break
122         else:
123             vmap = dict((k, v.replace(',CN=Servers,', ',**,', 1))
124                        for k, v in vmap.items())
125             replacements.append(('**', 'CN=Servers'))
126
127         for v in vmap.values():
128             if not v.startswith('CN=NTDS Settings,'):
129                 break
130         else:
131             vmap = dict((k, v.replace('CN=NTDS Settings,', '*,', 1))
132                        for k, v in vmap.items())
133             replacements.append(('*', 'CN=NTDS Settings'))
134
135     return vmap, replacements
136
137
138
139
140
141
142 def compile_graph_key(key_items, nodes_above=[], elisions=None,
143                       prefix='key_', width=2):
144     """Generate a dot file snippet that acts as a legend for a graph.
145
146     :param key_items: sequence of items (is_vertex, style, label)
147     :param nodes_above: list of vertices (pushes key into right position)
148     :param elision: tuple (short, full) indicating suffix replacement
149     :param prefix: string used to generate key node names ["key_"]
150     :param width: default width of node lines
151
152     Each item in key_items is a tuple of (is_vertex, style, label).
153     is_vertex is a boolean indicating whether the item is a vertex
154     (True) or edge (False). Style is a dot style string for the edge
155     or vertex. label is the text associated with the key item.
156     """
157     edge_lines = []
158     edge_names = []
159     vertex_lines = []
160     vertex_names = []
161     order_lines = []
162     for i, item in enumerate(key_items):
163         is_vertex, style, label = item
164         tag = '%s%d_' % (prefix, i)
165         label = quote_graph_label(label)
166         name = '%s_label' % tag
167
168         if is_vertex:
169             order_lines.append(name)
170             vertex_names.append(name)
171             vertex_lines.append('%s[label="%s"; %s]' %
172                                 (name, label, style))
173         else:
174             edge_names.append(name)
175             e1 = '%se1' % tag
176             e2 = '%se2' % tag
177             order_lines.append(name)
178             edge_lines.append('subgraph cluster_%s {' % tag)
179             edge_lines.append('%s[label=src; color="#000000"; group="%s_g"]' %
180                               (e1, tag))
181             edge_lines.append('%s[label=dest; color="#000000"; group="%s_g"]' %
182                               (e2, tag))
183             edge_lines.append('%s -> %s [constraint = false; %s]' % (e1, e2,
184                                                                      style))
185             edge_lines.append(('%s[shape=plaintext; style=solid; width=%f; '
186                                'label="%s\\r"]') %
187                               (name, width, label))
188             edge_lines.append('}')
189
190     elision_str = ''
191     if elisions:
192         for i, elision in enumerate(reversed(elisions)):
193             order_lines.append('elision%d' % i)
194             short, long = elision
195             if short[0] == ',' and long[0] == ',':
196                 short = short[1:]
197                 long = long[1:]
198             elision_str += ('\nelision%d[shape=plaintext; style=solid; '
199                             'label="\“%s”  means  “%s”\\r"]\n'
200                             % ((i, short, long)))
201
202     above_lines = []
203     if order_lines:
204         for n in nodes_above:
205             above_lines.append('"%s" -> %s [style=invis]' %
206                                (n, order_lines[0]))
207
208     s = ('subgraph cluster_key {\n'
209          'label="Key";\n'
210          'subgraph cluster_key_nodes {\n'
211          'label="";\n'
212          'color = "invis";\n'
213          '%s\n'
214          '}\n'
215          'subgraph cluster_key_edges {\n'
216          'label="";\n'
217          'color = "invis";\n'
218          '%s\n'
219          '{%s}\n'
220          '}\n'
221          '%s\n'
222          '}\n'
223          '%s\n'
224          '%s [style=invis; weight=9]'
225          '\n'
226          % (';\n'.join(vertex_lines),
227             '\n'.join(edge_lines),
228             ' '.join(edge_names),
229             elision_str,
230             ';\n'.join(above_lines),
231             ' -> '.join(order_lines),
232          ))
233
234     return s
235
236
237 def dot_graph(vertices, edges,
238               directed=False,
239               title=None,
240               reformat_labels=True,
241               vertex_colors=None,
242               edge_colors=None,
243               edge_labels=None,
244               vertex_styles=None,
245               edge_styles=None,
246               graph_name=None,
247               shorten_names=False,
248               key_items=None,
249               vertex_clusters=None):
250     """Generate a Graphviz representation of a list of vertices and edges.
251
252     :param vertices: list of vertex names (optional).
253     :param edges:    list of (vertex, vertex) pairs
254     :param directed: bool: whether the graph is directed
255     :param title: optional title for the graph
256     :param reformat_labels: whether to wrap long vertex labels
257     :param vertex_colors: if not None, a sequence of colours for the vertices
258     :param edge_colors: if not None, colours for the edges
259     :param edge_labels: if not None, labels for the edges
260     :param vertex_styles: if not None, DOT style strings for vertices
261     :param edge_styles: if not None, DOT style strings for edges
262     :param graph_name: if not None, name of graph
263     :param shorten_names: if True, remove common DN suffixes
264     :param key: (is_vertex, style, description) tuples
265     :param vertex_clusters: list of subgraph cluster names
266
267     Colour, style, and label lists must be the same length as the
268     corresponding list of edges or vertices (or None).
269
270     Colours can be HTML RGB strings ("#FF0000") or common names
271     ("red"), or some other formats you don't want to think about.
272
273     If `vertices` is None, only the vertices mentioned in the edges
274     are shown, and their appearance can be modified using the
275     vertex_colors and vertex_styles arguments. Vertices appearing in
276     the edges but not in the `vertices` list will be shown but their
277     styles can not be modified.
278     """
279     out = []
280     write = out.append
281
282     if vertices is None:
283         vertices = set(x[0] for x in edges) | set(x[1] for x in edges)
284
285     if shorten_names:
286         vlist = list(set(x[0] for x in edges) |
287                      set(x[1] for x in edges) |
288                      set(vertices))
289         vmap, elisions = shorten_vertex_names(vlist)
290         vertices = [vmap[x] for x in vertices]
291         edges = [(vmap[a], vmap[b]) for a, b in edges]
292
293     else:
294         elisions = None
295
296     if graph_name is None:
297         graph_name = 'A_samba_tool_production'
298
299     if directed:
300         graph_type = 'digraph'
301         connector = '->'
302     else:
303         graph_type = 'graph'
304         connector = '--'
305
306     write('/* generated by samba */')
307     write('%s %s {' % (graph_type, graph_name))
308     if title is not None:
309         write('label="%s";' % (title,))
310     write('fontsize=%s;\n' % (FONT_SIZE))
311     write('node[fontname=Helvetica; fontsize=%s];\n' % (FONT_SIZE))
312
313     prev_cluster = None
314     cluster_n = 0
315     quoted_vertices = []
316     for i, v in enumerate(vertices):
317         v = quote_graph_label(v, reformat_labels)
318         quoted_vertices.append(v)
319         attrs = []
320         if vertex_clusters and vertex_clusters[i]:
321             cluster = vertex_clusters[i]
322             if cluster != prev_cluster:
323                 if prev_cluster is not None:
324                     write("}")
325                 prev_cluster = cluster
326                 n = quote_graph_label(cluster)
327                 if cluster:
328                     write('subgraph cluster_%d {' % cluster_n)
329                     cluster_n += 1
330                     write('style = "rounded,dotted";')
331                     write('node [style="filled"; fillcolor=white];')
332                     write('label = "%s";' % n)
333
334         if vertex_styles and vertex_styles[i]:
335             attrs.append(vertex_styles[i])
336         if vertex_colors and vertex_colors[i]:
337             attrs.append('color="%s"' % quote_graph_label(vertex_colors[i]))
338         if attrs:
339             write('"%s" [%s];' % (v, ', '.join(attrs)))
340         else:
341             write('"%s";' % (v,))
342
343     if prev_cluster:
344         write("}")
345
346     for i, edge in enumerate(edges):
347         a, b = edge
348         if a is None:
349             a = "Missing source value"
350         if b is None:
351             b = "Missing destination value"
352
353         a = quote_graph_label(a, reformat_labels)
354         b = quote_graph_label(b, reformat_labels)
355
356         attrs = []
357         if edge_labels:
358             label = quote_graph_label(edge_labels[i])
359             attrs.append('label="%s"' % label)
360         if edge_colors:
361             attrs.append('color="%s"' % quote_graph_label(edge_colors[i]))
362         if edge_styles:
363             attrs.append(edge_styles[i])  # no quoting
364         if attrs:
365             write('"%s" %s "%s" [%s];' % (a, connector, b, ', '.join(attrs)))
366         else:
367             write('"%s" %s "%s";' % (a, connector, b))
368
369     if key_items:
370         key = compile_graph_key(key_items, nodes_above=quoted_vertices,
371                                 elisions=elisions)
372         write(key)
373
374     write('}\n')
375     return '\n'.join(out)
376
377
378 COLOUR_SETS = {
379     'ansi': {
380         'alternate rows': (colour.DARK_WHITE, colour.BLACK),
381         'disconnected': colour.RED,
382         'connected': colour.GREEN,
383         'transitive': colour.DARK_YELLOW,
384         'header': colour.UNDERLINE,
385         'reset': colour.C_NORMAL,
386     },
387     'ansi-heatmap': {
388         'alternate rows': (colour.DARK_WHITE, colour.BLACK),
389         'disconnected': colour.REV_RED,
390         'connected': colour.REV_GREEN,
391         'transitive': colour.REV_DARK_YELLOW,
392         'header': colour.UNDERLINE,
393         'reset': colour.C_NORMAL,
394     },
395     'xterm-256color': {
396         'alternate rows': (colour.xterm_256_colour(39),
397                            colour.xterm_256_colour(45)),
398         #'alternate rows': (colour.xterm_256_colour(246),
399         #                   colour.xterm_256_colour(247)),
400         'disconnected': colour.xterm_256_colour(124, bg=True),
401         'connected': colour.xterm_256_colour(112),
402         'transitive': colour.xterm_256_colour(214),
403         'transitive scale': (colour.xterm_256_colour(190),
404                              colour.xterm_256_colour(184),
405                              colour.xterm_256_colour(220),
406                              colour.xterm_256_colour(214),
407                              colour.xterm_256_colour(208),
408         ),
409         'header': colour.UNDERLINE,
410         'reset': colour.C_NORMAL,
411     },
412     'xterm-256color-heatmap': {
413         'alternate rows': (colour.xterm_256_colour(171),
414                            colour.xterm_256_colour(207)),
415         #'alternate rows': (colour.xterm_256_colour(246),
416         #                    colour.xterm_256_colour(247)),
417         'disconnected': colour.xterm_256_colour(124, bg=True),
418         'connected': colour.xterm_256_colour(112, bg=True),
419         'transitive': colour.xterm_256_colour(214, bg=True),
420         'transitive scale': (colour.xterm_256_colour(190, bg=True),
421                              colour.xterm_256_colour(184, bg=True),
422                              colour.xterm_256_colour(220, bg=True),
423                              colour.xterm_256_colour(214, bg=True),
424                              colour.xterm_256_colour(208, bg=True),
425         ),
426         'header': colour.UNDERLINE,
427         'reset': colour.C_NORMAL,
428     },
429     None: {
430         'alternate rows': ('',),
431         'disconnected': '',
432         'connected': '',
433         'transitive': '',
434         'header': '',
435         'reset': '',
436     }
437 }
438
439 CHARSETS = {
440     'utf8': {
441         'vertical': '│',
442         'horizontal': '─',
443         'corner': '╭',
444         #'diagonal': '╲',
445         'diagonal': '·',
446         #'missing': '🕱',
447         'missing': '-',
448         'right_arrow': '←',
449     },
450     'ascii': {
451         'vertical': '|',
452         'horizontal': '-',
453         'corner': ',',
454         'diagonal': '0',
455         'missing': '-',
456         'right_arrow': '<-',
457     }
458 }
459
460
461 def find_transitive_distance(vertices, edges):
462     all_vertices = (set(vertices) |
463                     set(e[0] for e in edges) |
464                     set(e[1] for e in edges))
465
466     if all_vertices != set(vertices):
467         print("there are unknown vertices: %s" %
468               (all_vertices - set(vertices)),
469               file=sys.stderr)
470
471     # with n vertices, we are always less than n hops away from
472     # anywhere else.
473     inf = len(all_vertices)
474     distances = {}
475     for v in all_vertices:
476         distances[v] = {v: 0}
477
478     for src, dest in edges:
479         distances[src][dest] = distances[src].get(dest, 1)
480
481     # This algorithm (and implementation) seems very suboptimal.
482     # potentially O(n^4), though n is smallish.
483     for i in range(inf):
484         changed = False
485         new_distances = {}
486         for v, d in distances.items():
487             new_d = d.copy()
488             new_distances[v] = new_d
489             for dest, cost in d.items():
490                 for leaf, cost2 in distances[dest].items():
491                     new_cost = cost + cost2
492                     old_cost = d.get(leaf, inf)
493                     if new_cost < old_cost:
494                         new_d[leaf] = new_cost
495                         changed = True
496
497         distances = new_distances
498         if not changed:
499             break
500
501     # filter out unwanted vertices and infinite links
502     answer = {}
503     for v in vertices:
504         answer[v] = {}
505         for v2 in vertices:
506             a = distances[v].get(v2, inf)
507             if a < inf:
508                 answer[v][v2] = a
509
510     return answer
511
512
513 def get_transitive_colourer(colours, n_vertices):
514     if 'transitive scale' in colours:
515         scale = colours['transitive scale']
516         m = len(scale)
517         n = 1 + int(n_vertices ** 0.5)
518
519         def f(link):
520             return scale[min(link * m // n, m - 1)]
521
522     else:
523         def f(link):
524             return colours['transitive']
525
526     return f
527
528
529 def distance_matrix(vertices, edges,
530                     utf8=False,
531                     colour=None,
532                     shorten_names=False,
533                     generate_key=False,
534                     grouping_function=None,
535                     row_comments=None):
536     lines = []
537     write = lines.append
538
539     charset = CHARSETS['utf8' if utf8 else 'ascii']
540     vertical = charset['vertical']
541     horizontal = charset['horizontal']
542     corner = charset['corner']
543     diagonal = charset['diagonal']
544     missing = charset['missing']
545     right_arrow = charset['right_arrow']
546
547     colours = COLOUR_SETS[colour]
548
549     colour_cycle = cycle(colours.get('alternate rows', ('',)))
550
551     if vertices is None:
552         vertices = sorted(set(x[0] for x in edges) | set(x[1] for x in edges))
553
554     if grouping_function is not None:
555         # we sort and colour according to the grouping function
556         # which can be used to e.g. alternate colours by site.
557         vertices = sorted(vertices, key=grouping_function)
558         colour_list = []
559         for k, v in groupby(vertices, key=grouping_function):
560             c = next(colour_cycle)
561             colour_list.extend(c for x in v)
562     else:
563         colour_list = [next(colour_cycle) for v in vertices]
564
565     if shorten_names:
566         vlist = list(set(x[0] for x in edges) |
567                      set(x[1] for x in edges) |
568                      set(vertices))
569         vmap, replacements = shorten_vertex_names(vlist, '+',
570                                                   aggressive=True)
571         vertices = [vmap[x] for x in vertices]
572         edges = [(vmap[a], vmap[b]) for a, b in edges]
573
574
575     vlen = max(6, max(len(v) for v in vertices))
576
577     # first, the key for the columns
578     c_header = colours.get('header', '')
579     c_disconn = colours.get('disconnected', '')
580     c_conn = colours.get('connected', '')
581     c_reset = colours.get('reset', '')
582
583     colour_transitive = get_transitive_colourer(colours, len(vertices))
584
585     vspace = ' ' * vlen
586     verticals = ''
587     write("%*s %s  %sdestination%s" % (vlen, '',
588                                        ' ' * len(vertices),
589                                        c_header,
590                                        c_reset))
591     for i, v in enumerate(vertices):
592         j = len(vertices) - i
593         c = colour_list[i]
594         if j == 1:
595             start = '%s%ssource%s' % (vspace[:-6], c_header, c_reset)
596         else:
597             start = vspace
598         write('%s %s%s%s%s%s %s%s' % (start,
599                                       verticals,
600                                       c_reset,
601                                       c,
602                                       corner,
603                                       horizontal * j,
604                                       v,
605                                       c_reset
606         ))
607         verticals += c + vertical
608
609     connections = find_transitive_distance(vertices, edges)
610
611     for i, v in enumerate(vertices):
612         c = colour_list[i]
613         links = connections[v]
614         row = []
615         for v2 in vertices:
616             link = links.get(v2)
617             if link is None:
618                 row.append('%s%s' % (c_disconn, missing))
619                 continue
620             if link == 0:
621                 row.append('%s%s%s%s' % (c_reset, c, diagonal, c_reset))
622             elif link == 1:
623                 row.append('%s1%s' % (c_conn, c_reset))
624             else:
625                 ct = colour_transitive(link)
626                 if link > 9:
627                     link = '+'
628                 row.append('%s%s%s' % (ct, link, c_reset))
629
630         if row_comments is not None and row_comments[i]:
631             row.append('%s %s %s' % (c_reset, right_arrow, row_comments[i]))
632
633         write('%s%*s%s %s%s' % (c, vlen, v, c_reset,
634                                 ''.join(row), c_reset))
635
636     example_c = next(colour_cycle)
637     if shorten_names:
638         write('')
639         for substitute, original in reversed(replacements):
640             write("'%s%s%s' stands for '%s%s%s'" % (example_c,
641                                                     substitute,
642                                                     c_reset,
643                                                     example_c,
644                                                     original,
645                                                     c_reset))
646     if generate_key:
647         write('')
648         write("Data can get from %ssource%s to %sdestination%s in the "
649               "indicated number of steps." % (c_header, c_reset,
650                                               c_header, c_reset))
651         write("%s%s%s means zero steps (it is the same DC)" %
652               (example_c, diagonal, c_reset))
653         write("%s1%s means a direct link" % (c_conn, c_reset))
654         write("%s2%s means a transitive link involving two steps "
655               "(i.e. one intermediate DC)" %
656               (colour_transitive(2), c_reset))
657         write("%s%s%s means there is no connection, even through other DCs" %
658               (c_disconn, missing, c_reset))
659
660     return '\n'.join(lines)