Remove use of *args from GitClient constructor; this is only likely to
[jelmer/dulwich.git] / dulwich / contrib / paramiko.py
1 # paramako.py -- paramiko implementation of the Dulwich SSHVendor interface
2 # Copyright (C) 2013 Aaron O'Mullan <aaron.omullan@friendco.de>
3 #
4 # This program is free software; you can redistribute it and/or
5 # modify it under the terms of the GNU General Public License
6 # as published by the Free Software Foundation; either version 2
7 # or (at your option) a later version of the License.
8 #
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 # GNU General Public License for more details.
13 #
14 # You should have received a copy of the GNU General Public License
15 # along with this program; if not, write to the Free Software
16 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
17 # MA  02110-1301, USA.
18
19 """Paramiko SSH support for Dulwich.
20
21 To use this implementation as the SSH implementation in Dulwich, override
22 the dulwich.client.get_ssh_vendor attribute:
23
24   >>> from dulwich import client as _mod_client
25   >>> from dulwich.contrib.paramiko import ParamikoSSHVendor
26   >>> _mod_client.get_ssh_vendor = ParamikoSSHVendor
27
28 This implementation is experimental and does not have any tests.
29 """
30
31 import paramiko
32 import subprocess
33 import threading
34
35 class _ParamikoWrapper(object):
36     STDERR_READ_N = 2048  # 2k
37
38     def __init__(self, client, channel, progress_stderr=None):
39         self.client = client
40         self.channel = channel
41         self.progress_stderr = progress_stderr
42         self.should_monitor = bool(progress_stderr) or True
43         self.monitor_thread = None
44         self.stderr = b''
45
46         # Channel must block
47         self.channel.setblocking(True)
48
49         # Start
50         if self.should_monitor:
51             self.monitor_thread = threading.Thread(
52                 target=self.monitor_stderr)
53             self.monitor_thread.start()
54
55     def monitor_stderr(self):
56         while self.should_monitor:
57             # Block and read
58             data = self.read_stderr(self.STDERR_READ_N)
59
60             # Socket closed
61             if not data:
62                 self.should_monitor = False
63                 break
64
65             # Emit data
66             if self.progress_stderr:
67                 self.progress_stderr(data)
68
69             # Append to buffer
70             self.stderr += data
71
72     def stop_monitoring(self):
73         # Stop StdErr thread
74         if self.should_monitor:
75             self.should_monitor = False
76             self.monitor_thread.join()
77
78             # Get left over data
79             data = self.channel.in_stderr_buffer.empty()
80             self.stderr += data
81
82     def can_read(self):
83         return self.channel.recv_ready()
84
85     def write(self, data):
86         return self.channel.sendall(data)
87
88     def read_stderr(self, n):
89         return self.channel.recv_stderr(n)
90
91     def read(self, n=None):
92         data = self.channel.recv(n)
93         data_len = len(data)
94
95         # Closed socket
96         if not data:
97             return
98
99         # Read more if needed
100         if n and data_len < n:
101             diff_len = n - data_len
102             return data + self.read(diff_len)
103         return data
104
105     def close(self):
106         self.channel.close()
107         self.stop_monitoring()
108
109
110 class ParamikoSSHVendor(object):
111
112     def __init__(self):
113         self.ssh_kwargs = {}
114
115     def run_command(self, host, command, username=None, port=None,
116                     progress_stderr=None):
117         if (type(command) is not list or
118             not all([isinstance(b, bytes) for b in command])):
119             raise TypeError(command)
120         # Paramiko needs an explicit port. None is not valid
121         if port is None:
122             port = 22
123
124         client = paramiko.SSHClient()
125
126         policy = paramiko.client.MissingHostKeyPolicy()
127         client.set_missing_host_key_policy(policy)
128         client.connect(host, username=username, port=port,
129                        **self.ssh_kwargs)
130
131         # Open SSH session
132         channel = client.get_transport().open_session()
133
134         # Run commands
135         channel.exec_command(subprocess.list2cmdline(command))
136
137         return _ParamikoWrapper(
138             client, channel, progress_stderr=progress_stderr)