##// END OF EJS Templates
protocol: move most ssh responses to returns
Matt Mackall -
r11580:69248b5a default
parent child Browse files
Show More
@@ -1,256 +1,258 b''
1 1 # sshserver.py - ssh protocol server support for mercurial
2 2 #
3 3 # Copyright 2005-2007 Matt Mackall <mpm@selenic.com>
4 4 # Copyright 2006 Vadim Gelfer <vadim.gelfer@gmail.com>
5 5 #
6 6 # This software may be used and distributed according to the terms of the
7 7 # GNU General Public License version 2 or any later version.
8 8
9 9 from i18n import _
10 10 from node import bin, hex
11 11 import streamclone, util, hook, pushkey
12 12 import os, sys, tempfile, urllib, copy
13 13
14 14 class sshserver(object):
15 15
16 16 caps = 'unbundle lookup changegroupsubset branchmap pushkey'.split()
17 17
18 18 def __init__(self, ui, repo):
19 19 self.ui = ui
20 20 self.repo = repo
21 21 self.lock = None
22 22 self.fin = sys.stdin
23 23 self.fout = sys.stdout
24 24
25 25 hook.redirect(True)
26 26 sys.stdout = sys.stderr
27 27
28 28 # Prevent insertion/deletion of CRs
29 29 util.set_binary(self.fin)
30 30 util.set_binary(self.fout)
31 31
32 32 def getargs(self, args):
33 33 data = {}
34 34 keys = args.split()
35 35 count = len(keys)
36 36 for n in xrange(len(keys)):
37 37 argline = self.fin.readline()[:-1]
38 38 arg, l = argline.split()
39 39 val = self.fin.read(int(l))
40 40 if arg not in keys:
41 41 raise util.Abort("unexpected parameter %r" % arg)
42 42 if arg == '*':
43 43 star = {}
44 44 for n in xrange(int(l)):
45 45 arg, l = argline.split()
46 46 val = self.fin.read(int(l))
47 47 star[arg] = val
48 48 data['*'] = star
49 49 else:
50 50 data[arg] = val
51 51 return [data[k] for k in keys]
52 52
53 53 def getarg(self, name):
54 54 return self.getargs(name)[0]
55 55
56 56 def respond(self, v):
57 57 self.fout.write("%d\n" % len(v))
58 58 self.fout.write(v)
59 59 self.fout.flush()
60 60
61 61 def serve_forever(self):
62 62 try:
63 63 while self.serve_one():
64 64 pass
65 65 finally:
66 66 if self.lock is not None:
67 67 self.lock.release()
68 68 sys.exit(0)
69 69
70 70 def serve_one(self):
71 71 cmd = self.fin.readline()[:-1]
72 72 if cmd:
73 73 impl = getattr(self, 'do_' + cmd, None)
74 74 if impl:
75 impl()
75 r = impl()
76 if r is not None:
77 self.respond(r)
76 78 else: self.respond("")
77 79 return cmd != ''
78 80
79 81 def do_lookup(self):
80 82 key = self.getarg('key')
81 83 try:
82 84 r = hex(self.repo.lookup(key))
83 85 success = 1
84 86 except Exception, inst:
85 87 r = str(inst)
86 88 success = 0
87 self.respond("%s %s\n" % (success, r))
89 return "%s %s\n" % (success, r)
88 90
89 91 def do_branchmap(self):
90 92 branchmap = self.repo.branchmap()
91 93 heads = []
92 94 for branch, nodes in branchmap.iteritems():
93 95 branchname = urllib.quote(branch)
94 96 branchnodes = [hex(node) for node in nodes]
95 97 heads.append('%s %s' % (branchname, ' '.join(branchnodes)))
96 self.respond('\n'.join(heads))
98 return '\n'.join(heads)
97 99
98 100 def do_heads(self):
99 101 h = self.repo.heads()
100 self.respond(" ".join(map(hex, h)) + "\n")
102 return " ".join(map(hex, h)) + "\n"
101 103
102 104 def do_hello(self):
103 105 '''the hello command returns a set of lines describing various
104 106 interesting things about the server, in an RFC822-like format.
105 107 Currently the only one defined is "capabilities", which
106 108 consists of a line in the form:
107 109
108 110 capabilities: space separated list of tokens
109 111 '''
110 112 caps = copy.copy(self.caps)
111 113 if streamclone.allowed(self.repo.ui):
112 114 caps.append('stream=%d' % self.repo.changelog.version)
113 self.respond("capabilities: %s\n" % (' '.join(caps),))
115 return "capabilities: %s\n" % (' '.join(caps),)
114 116
115 117 def do_lock(self):
116 118 '''DEPRECATED - allowing remote client to lock repo is not safe'''
117 119
118 120 self.lock = self.repo.lock()
119 self.respond("")
121 return ""
120 122
121 123 def do_unlock(self):
122 124 '''DEPRECATED'''
123 125
124 126 if self.lock:
125 127 self.lock.release()
126 128 self.lock = None
127 self.respond("")
129 return ""
128 130
129 131 def do_branches(self):
130 132 nodes = self.getarg('nodes')
131 133 nodes = map(bin, nodes.split(" "))
132 134 r = []
133 135 for b in self.repo.branches(nodes):
134 136 r.append(" ".join(map(hex, b)) + "\n")
135 self.respond("".join(r))
137 return "".join(r)
136 138
137 139 def do_between(self):
138 140 pairs = self.getarg('pairs')
139 141 pairs = [map(bin, p.split("-")) for p in pairs.split(" ")]
140 142 r = []
141 143 for b in self.repo.between(pairs):
142 144 r.append(" ".join(map(hex, b)) + "\n")
143 self.respond("".join(r))
145 return "".join(r)
144 146
145 147 def do_changegroup(self):
146 148 nodes = []
147 149 roots = self.getarg('roots')
148 150 nodes = map(bin, roots.split(" "))
149 151
150 152 cg = self.repo.changegroup(nodes, 'serve')
151 153 while True:
152 154 d = cg.read(4096)
153 155 if not d:
154 156 break
155 157 self.fout.write(d)
156 158
157 159 self.fout.flush()
158 160
159 161 def do_changegroupsubset(self):
160 162 bases, heads = self.getargs('bases heads')
161 163 bases = [bin(n) for n in bases.split(' ')]
162 164 heads = [bin(n) for n in heads.split(' ')]
163 165
164 166 cg = self.repo.changegroupsubset(bases, heads, 'serve')
165 167 while True:
166 168 d = cg.read(4096)
167 169 if not d:
168 170 break
169 171 self.fout.write(d)
170 172
171 173 self.fout.flush()
172 174
173 175 def do_addchangegroup(self):
174 176 '''DEPRECATED'''
175 177
176 178 if not self.lock:
177 179 self.respond("not locked")
178 180 return
179 181
180 182 self.respond("")
181 183 r = self.repo.addchangegroup(self.fin, 'serve', self.client_url(),
182 184 lock=self.lock)
183 self.respond(str(r))
185 return str(r)
184 186
185 187 def client_url(self):
186 188 client = os.environ.get('SSH_CLIENT', '').split(' ', 1)[0]
187 189 return 'remote:ssh:' + client
188 190
189 191 def do_unbundle(self):
190 192 their_heads = self.getarg('heads').split()
191 193
192 194 def check_heads():
193 195 heads = map(hex, self.repo.heads())
194 196 return their_heads == [hex('force')] or their_heads == heads
195 197
196 198 # fail early if possible
197 199 if not check_heads():
198 200 self.respond(_('unsynced changes'))
199 201 return
200 202
201 203 self.respond('')
202 204
203 205 # write bundle data to temporary file because it can be big
204 206 fd, tempname = tempfile.mkstemp(prefix='hg-unbundle-')
205 207 fp = os.fdopen(fd, 'wb+')
206 208 try:
207 209 count = int(self.fin.readline())
208 210 while count:
209 211 fp.write(self.fin.read(count))
210 212 count = int(self.fin.readline())
211 213
212 214 was_locked = self.lock is not None
213 215 if not was_locked:
214 216 self.lock = self.repo.lock()
215 217 try:
216 218 if not check_heads():
217 219 # someone else committed/pushed/unbundled while we
218 220 # were transferring data
219 221 self.respond(_('unsynced changes'))
220 222 return
221 223 self.respond('')
222 224
223 225 # push can proceed
224 226
225 227 fp.seek(0)
226 228 r = self.repo.addchangegroup(fp, 'serve', self.client_url(),
227 229 lock=self.lock)
228 230 self.respond(str(r))
229 231 finally:
230 232 if not was_locked:
231 233 self.lock.release()
232 234 self.lock = None
233 235 finally:
234 236 fp.close()
235 237 os.unlink(tempname)
236 238
237 239 def do_stream_out(self):
238 240 try:
239 241 for chunk in streamclone.stream_out(self.repo):
240 242 self.fout.write(chunk)
241 243 self.fout.flush()
242 244 except streamclone.StreamException, inst:
243 245 self.fout.write(str(inst))
244 246 self.fout.flush()
245 247
246 248 def do_pushkey(self):
247 249 namespace, key, old, new = self.getargs('namespace key old new')
248 250 r = pushkey.push(self.repo, namespace, key, old, new)
249 self.respond('%s\n' % int(r))
251 return '%s\n' % int(r)
250 252
251 253 def do_listkeys(self):
252 254 namespace = self.getarg('namespace')
253 255 d = pushkey.list(self.repo, namespace).items()
254 256 t = '\n'.join(['%s\t%s' % (k.encode('string-escape'),
255 257 v.encode('string-escape')) for k, v in d])
256 self.respond(t)
258 return t
General Comments 0
You need to be logged in to leave comments. Login now