Move ParamikoSSHVendor to dulwich.contrib.paramiko_vendor, to avoid import errors.
[jelmer/dulwich.git] / dulwich / contrib / paramiko_vendor.py
1 # paramiko_vendor.py -- paramiko implementation of the SSHVendor interface
2 # Copyright (C) 2013 Aaron O'Mullan <aaron.omullan@friendco.de>
3 #
4 # This program is free software; you can redistribute it and/or
5 # modify it under the terms of the GNU General Public License
6 # as published by the Free Software Foundation; either version 2
7 # or (at your option) a later version of the License.
8 #
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 # GNU General Public License for more details.
13 #
14 # You should have received a copy of the GNU General Public License
15 # along with this program; if not, write to the Free Software
16 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
17 # MA  02110-1301, USA.
18
19 """Paramiko SSH support for Dulwich.
20
21 To use this implementation as the SSH implementation in Dulwich, override
22 the dulwich.client.get_ssh_vendor attribute:
23
24   >>> from dulwich import client as _mod_client
25   >>> from dulwich.contrib.paramiko_vendor import ParamikoSSHVendor
26   >>> _mod_client.get_ssh_vendor = ParamikoSSHVendor
27
28 This implementation is experimental and does not have any tests.
29 """
30
31 import paramiko
32 import paramiko.client
33 import subprocess
34 import threading
35
36 class _ParamikoWrapper(object):
37     STDERR_READ_N = 2048  # 2k
38
39     def __init__(self, client, channel, progress_stderr=None):
40         self.client = client
41         self.channel = channel
42         self.progress_stderr = progress_stderr
43         self.should_monitor = bool(progress_stderr) or True
44         self.monitor_thread = None
45         self.stderr = b''
46
47         # Channel must block
48         self.channel.setblocking(True)
49
50         # Start
51         if self.should_monitor:
52             self.monitor_thread = threading.Thread(
53                 target=self.monitor_stderr)
54             self.monitor_thread.start()
55
56     def monitor_stderr(self):
57         while self.should_monitor:
58             # Block and read
59             data = self.read_stderr(self.STDERR_READ_N)
60
61             # Socket closed
62             if not data:
63                 self.should_monitor = False
64                 break
65
66             # Emit data
67             if self.progress_stderr:
68                 self.progress_stderr(data)
69
70             # Append to buffer
71             self.stderr += data
72
73     def stop_monitoring(self):
74         # Stop StdErr thread
75         if self.should_monitor:
76             self.should_monitor = False
77             self.monitor_thread.join()
78
79             # Get left over data
80             data = self.channel.in_stderr_buffer.empty()
81             self.stderr += data
82
83     def can_read(self):
84         return self.channel.recv_ready()
85
86     def write(self, data):
87         return self.channel.sendall(data)
88
89     def read_stderr(self, n):
90         return self.channel.recv_stderr(n)
91
92     def read(self, n=None):
93         data = self.channel.recv(n)
94         data_len = len(data)
95
96         # Closed socket
97         if not data:
98             return
99
100         # Read more if needed
101         if n and data_len < n:
102             diff_len = n - data_len
103             return data + self.read(diff_len)
104         return data
105
106     def close(self):
107         self.channel.close()
108         self.stop_monitoring()
109
110
111 class ParamikoSSHVendor(object):
112
113     def __init__(self):
114         self.ssh_kwargs = {}
115
116     def run_command(self, host, command, username=None, port=None,
117                     progress_stderr=None):
118         if (type(command) is not list or
119             not all([isinstance(b, bytes) for b in command])):
120             raise TypeError(command)
121         # Paramiko needs an explicit port. None is not valid
122         if port is None:
123             port = 22
124
125         client = paramiko.SSHClient()
126
127         policy = paramiko.client.MissingHostKeyPolicy()
128         client.set_missing_host_key_policy(policy)
129         client.connect(host, username=username, port=port,
130                        **self.ssh_kwargs)
131
132         # Open SSH session
133         channel = client.get_transport().open_session()
134
135         # Run commands
136         channel.exec_command(subprocess.list2cmdline(command))
137
138         return _ParamikoWrapper(
139             client, channel, progress_stderr=progress_stderr)