##// END OF EJS Templates
protocol: move most ssh responses to returns
Matt Mackall -
r11580:69248b5a default
parent child Browse files
Show More
@@ -1,256 +1,258 b''
1 # sshserver.py - ssh protocol server support for mercurial
1 # sshserver.py - ssh protocol server support for mercurial
2 #
2 #
3 # Copyright 2005-2007 Matt Mackall <mpm@selenic.com>
3 # Copyright 2005-2007 Matt Mackall <mpm@selenic.com>
4 # Copyright 2006 Vadim Gelfer <vadim.gelfer@gmail.com>
4 # Copyright 2006 Vadim Gelfer <vadim.gelfer@gmail.com>
5 #
5 #
6 # This software may be used and distributed according to the terms of the
6 # This software may be used and distributed according to the terms of the
7 # GNU General Public License version 2 or any later version.
7 # GNU General Public License version 2 or any later version.
8
8
9 from i18n import _
9 from i18n import _
10 from node import bin, hex
10 from node import bin, hex
11 import streamclone, util, hook, pushkey
11 import streamclone, util, hook, pushkey
12 import os, sys, tempfile, urllib, copy
12 import os, sys, tempfile, urllib, copy
13
13
14 class sshserver(object):
14 class sshserver(object):
15
15
16 caps = 'unbundle lookup changegroupsubset branchmap pushkey'.split()
16 caps = 'unbundle lookup changegroupsubset branchmap pushkey'.split()
17
17
18 def __init__(self, ui, repo):
18 def __init__(self, ui, repo):
19 self.ui = ui
19 self.ui = ui
20 self.repo = repo
20 self.repo = repo
21 self.lock = None
21 self.lock = None
22 self.fin = sys.stdin
22 self.fin = sys.stdin
23 self.fout = sys.stdout
23 self.fout = sys.stdout
24
24
25 hook.redirect(True)
25 hook.redirect(True)
26 sys.stdout = sys.stderr
26 sys.stdout = sys.stderr
27
27
28 # Prevent insertion/deletion of CRs
28 # Prevent insertion/deletion of CRs
29 util.set_binary(self.fin)
29 util.set_binary(self.fin)
30 util.set_binary(self.fout)
30 util.set_binary(self.fout)
31
31
32 def getargs(self, args):
32 def getargs(self, args):
33 data = {}
33 data = {}
34 keys = args.split()
34 keys = args.split()
35 count = len(keys)
35 count = len(keys)
36 for n in xrange(len(keys)):
36 for n in xrange(len(keys)):
37 argline = self.fin.readline()[:-1]
37 argline = self.fin.readline()[:-1]
38 arg, l = argline.split()
38 arg, l = argline.split()
39 val = self.fin.read(int(l))
39 val = self.fin.read(int(l))
40 if arg not in keys:
40 if arg not in keys:
41 raise util.Abort("unexpected parameter %r" % arg)
41 raise util.Abort("unexpected parameter %r" % arg)
42 if arg == '*':
42 if arg == '*':
43 star = {}
43 star = {}
44 for n in xrange(int(l)):
44 for n in xrange(int(l)):
45 arg, l = argline.split()
45 arg, l = argline.split()
46 val = self.fin.read(int(l))
46 val = self.fin.read(int(l))
47 star[arg] = val
47 star[arg] = val
48 data['*'] = star
48 data['*'] = star
49 else:
49 else:
50 data[arg] = val
50 data[arg] = val
51 return [data[k] for k in keys]
51 return [data[k] for k in keys]
52
52
53 def getarg(self, name):
53 def getarg(self, name):
54 return self.getargs(name)[0]
54 return self.getargs(name)[0]
55
55
56 def respond(self, v):
56 def respond(self, v):
57 self.fout.write("%d\n" % len(v))
57 self.fout.write("%d\n" % len(v))
58 self.fout.write(v)
58 self.fout.write(v)
59 self.fout.flush()
59 self.fout.flush()
60
60
61 def serve_forever(self):
61 def serve_forever(self):
62 try:
62 try:
63 while self.serve_one():
63 while self.serve_one():
64 pass
64 pass
65 finally:
65 finally:
66 if self.lock is not None:
66 if self.lock is not None:
67 self.lock.release()
67 self.lock.release()
68 sys.exit(0)
68 sys.exit(0)
69
69
70 def serve_one(self):
70 def serve_one(self):
71 cmd = self.fin.readline()[:-1]
71 cmd = self.fin.readline()[:-1]
72 if cmd:
72 if cmd:
73 impl = getattr(self, 'do_' + cmd, None)
73 impl = getattr(self, 'do_' + cmd, None)
74 if impl:
74 if impl:
75 impl()
75 r = impl()
76 if r is not None:
77 self.respond(r)
76 else: self.respond("")
78 else: self.respond("")
77 return cmd != ''
79 return cmd != ''
78
80
79 def do_lookup(self):
81 def do_lookup(self):
80 key = self.getarg('key')
82 key = self.getarg('key')
81 try:
83 try:
82 r = hex(self.repo.lookup(key))
84 r = hex(self.repo.lookup(key))
83 success = 1
85 success = 1
84 except Exception, inst:
86 except Exception, inst:
85 r = str(inst)
87 r = str(inst)
86 success = 0
88 success = 0
87 self.respond("%s %s\n" % (success, r))
89 return "%s %s\n" % (success, r)
88
90
89 def do_branchmap(self):
91 def do_branchmap(self):
90 branchmap = self.repo.branchmap()
92 branchmap = self.repo.branchmap()
91 heads = []
93 heads = []
92 for branch, nodes in branchmap.iteritems():
94 for branch, nodes in branchmap.iteritems():
93 branchname = urllib.quote(branch)
95 branchname = urllib.quote(branch)
94 branchnodes = [hex(node) for node in nodes]
96 branchnodes = [hex(node) for node in nodes]
95 heads.append('%s %s' % (branchname, ' '.join(branchnodes)))
97 heads.append('%s %s' % (branchname, ' '.join(branchnodes)))
96 self.respond('\n'.join(heads))
98 return '\n'.join(heads)
97
99
98 def do_heads(self):
100 def do_heads(self):
99 h = self.repo.heads()
101 h = self.repo.heads()
100 self.respond(" ".join(map(hex, h)) + "\n")
102 return " ".join(map(hex, h)) + "\n"
101
103
102 def do_hello(self):
104 def do_hello(self):
103 '''the hello command returns a set of lines describing various
105 '''the hello command returns a set of lines describing various
104 interesting things about the server, in an RFC822-like format.
106 interesting things about the server, in an RFC822-like format.
105 Currently the only one defined is "capabilities", which
107 Currently the only one defined is "capabilities", which
106 consists of a line in the form:
108 consists of a line in the form:
107
109
108 capabilities: space separated list of tokens
110 capabilities: space separated list of tokens
109 '''
111 '''
110 caps = copy.copy(self.caps)
112 caps = copy.copy(self.caps)
111 if streamclone.allowed(self.repo.ui):
113 if streamclone.allowed(self.repo.ui):
112 caps.append('stream=%d' % self.repo.changelog.version)
114 caps.append('stream=%d' % self.repo.changelog.version)
113 self.respond("capabilities: %s\n" % (' '.join(caps),))
115 return "capabilities: %s\n" % (' '.join(caps),)
114
116
115 def do_lock(self):
117 def do_lock(self):
116 '''DEPRECATED - allowing remote client to lock repo is not safe'''
118 '''DEPRECATED - allowing remote client to lock repo is not safe'''
117
119
118 self.lock = self.repo.lock()
120 self.lock = self.repo.lock()
119 self.respond("")
121 return ""
120
122
121 def do_unlock(self):
123 def do_unlock(self):
122 '''DEPRECATED'''
124 '''DEPRECATED'''
123
125
124 if self.lock:
126 if self.lock:
125 self.lock.release()
127 self.lock.release()
126 self.lock = None
128 self.lock = None
127 self.respond("")
129 return ""
128
130
129 def do_branches(self):
131 def do_branches(self):
130 nodes = self.getarg('nodes')
132 nodes = self.getarg('nodes')
131 nodes = map(bin, nodes.split(" "))
133 nodes = map(bin, nodes.split(" "))
132 r = []
134 r = []
133 for b in self.repo.branches(nodes):
135 for b in self.repo.branches(nodes):
134 r.append(" ".join(map(hex, b)) + "\n")
136 r.append(" ".join(map(hex, b)) + "\n")
135 self.respond("".join(r))
137 return "".join(r)
136
138
137 def do_between(self):
139 def do_between(self):
138 pairs = self.getarg('pairs')
140 pairs = self.getarg('pairs')
139 pairs = [map(bin, p.split("-")) for p in pairs.split(" ")]
141 pairs = [map(bin, p.split("-")) for p in pairs.split(" ")]
140 r = []
142 r = []
141 for b in self.repo.between(pairs):
143 for b in self.repo.between(pairs):
142 r.append(" ".join(map(hex, b)) + "\n")
144 r.append(" ".join(map(hex, b)) + "\n")
143 self.respond("".join(r))
145 return "".join(r)
144
146
145 def do_changegroup(self):
147 def do_changegroup(self):
146 nodes = []
148 nodes = []
147 roots = self.getarg('roots')
149 roots = self.getarg('roots')
148 nodes = map(bin, roots.split(" "))
150 nodes = map(bin, roots.split(" "))
149
151
150 cg = self.repo.changegroup(nodes, 'serve')
152 cg = self.repo.changegroup(nodes, 'serve')
151 while True:
153 while True:
152 d = cg.read(4096)
154 d = cg.read(4096)
153 if not d:
155 if not d:
154 break
156 break
155 self.fout.write(d)
157 self.fout.write(d)
156
158
157 self.fout.flush()
159 self.fout.flush()
158
160
159 def do_changegroupsubset(self):
161 def do_changegroupsubset(self):
160 bases, heads = self.getargs('bases heads')
162 bases, heads = self.getargs('bases heads')
161 bases = [bin(n) for n in bases.split(' ')]
163 bases = [bin(n) for n in bases.split(' ')]
162 heads = [bin(n) for n in heads.split(' ')]
164 heads = [bin(n) for n in heads.split(' ')]
163
165
164 cg = self.repo.changegroupsubset(bases, heads, 'serve')
166 cg = self.repo.changegroupsubset(bases, heads, 'serve')
165 while True:
167 while True:
166 d = cg.read(4096)
168 d = cg.read(4096)
167 if not d:
169 if not d:
168 break
170 break
169 self.fout.write(d)
171 self.fout.write(d)
170
172
171 self.fout.flush()
173 self.fout.flush()
172
174
173 def do_addchangegroup(self):
175 def do_addchangegroup(self):
174 '''DEPRECATED'''
176 '''DEPRECATED'''
175
177
176 if not self.lock:
178 if not self.lock:
177 self.respond("not locked")
179 self.respond("not locked")
178 return
180 return
179
181
180 self.respond("")
182 self.respond("")
181 r = self.repo.addchangegroup(self.fin, 'serve', self.client_url(),
183 r = self.repo.addchangegroup(self.fin, 'serve', self.client_url(),
182 lock=self.lock)
184 lock=self.lock)
183 self.respond(str(r))
185 return str(r)
184
186
185 def client_url(self):
187 def client_url(self):
186 client = os.environ.get('SSH_CLIENT', '').split(' ', 1)[0]
188 client = os.environ.get('SSH_CLIENT', '').split(' ', 1)[0]
187 return 'remote:ssh:' + client
189 return 'remote:ssh:' + client
188
190
189 def do_unbundle(self):
191 def do_unbundle(self):
190 their_heads = self.getarg('heads').split()
192 their_heads = self.getarg('heads').split()
191
193
192 def check_heads():
194 def check_heads():
193 heads = map(hex, self.repo.heads())
195 heads = map(hex, self.repo.heads())
194 return their_heads == [hex('force')] or their_heads == heads
196 return their_heads == [hex('force')] or their_heads == heads
195
197
196 # fail early if possible
198 # fail early if possible
197 if not check_heads():
199 if not check_heads():
198 self.respond(_('unsynced changes'))
200 self.respond(_('unsynced changes'))
199 return
201 return
200
202
201 self.respond('')
203 self.respond('')
202
204
203 # write bundle data to temporary file because it can be big
205 # write bundle data to temporary file because it can be big
204 fd, tempname = tempfile.mkstemp(prefix='hg-unbundle-')
206 fd, tempname = tempfile.mkstemp(prefix='hg-unbundle-')
205 fp = os.fdopen(fd, 'wb+')
207 fp = os.fdopen(fd, 'wb+')
206 try:
208 try:
207 count = int(self.fin.readline())
209 count = int(self.fin.readline())
208 while count:
210 while count:
209 fp.write(self.fin.read(count))
211 fp.write(self.fin.read(count))
210 count = int(self.fin.readline())
212 count = int(self.fin.readline())
211
213
212 was_locked = self.lock is not None
214 was_locked = self.lock is not None
213 if not was_locked:
215 if not was_locked:
214 self.lock = self.repo.lock()
216 self.lock = self.repo.lock()
215 try:
217 try:
216 if not check_heads():
218 if not check_heads():
217 # someone else committed/pushed/unbundled while we
219 # someone else committed/pushed/unbundled while we
218 # were transferring data
220 # were transferring data
219 self.respond(_('unsynced changes'))
221 self.respond(_('unsynced changes'))
220 return
222 return
221 self.respond('')
223 self.respond('')
222
224
223 # push can proceed
225 # push can proceed
224
226
225 fp.seek(0)
227 fp.seek(0)
226 r = self.repo.addchangegroup(fp, 'serve', self.client_url(),
228 r = self.repo.addchangegroup(fp, 'serve', self.client_url(),
227 lock=self.lock)
229 lock=self.lock)
228 self.respond(str(r))
230 self.respond(str(r))
229 finally:
231 finally:
230 if not was_locked:
232 if not was_locked:
231 self.lock.release()
233 self.lock.release()
232 self.lock = None
234 self.lock = None
233 finally:
235 finally:
234 fp.close()
236 fp.close()
235 os.unlink(tempname)
237 os.unlink(tempname)
236
238
237 def do_stream_out(self):
239 def do_stream_out(self):
238 try:
240 try:
239 for chunk in streamclone.stream_out(self.repo):
241 for chunk in streamclone.stream_out(self.repo):
240 self.fout.write(chunk)
242 self.fout.write(chunk)
241 self.fout.flush()
243 self.fout.flush()
242 except streamclone.StreamException, inst:
244 except streamclone.StreamException, inst:
243 self.fout.write(str(inst))
245 self.fout.write(str(inst))
244 self.fout.flush()
246 self.fout.flush()
245
247
246 def do_pushkey(self):
248 def do_pushkey(self):
247 namespace, key, old, new = self.getargs('namespace key old new')
249 namespace, key, old, new = self.getargs('namespace key old new')
248 r = pushkey.push(self.repo, namespace, key, old, new)
250 r = pushkey.push(self.repo, namespace, key, old, new)
249 self.respond('%s\n' % int(r))
251 return '%s\n' % int(r)
250
252
251 def do_listkeys(self):
253 def do_listkeys(self):
252 namespace = self.getarg('namespace')
254 namespace = self.getarg('namespace')
253 d = pushkey.list(self.repo, namespace).items()
255 d = pushkey.list(self.repo, namespace).items()
254 t = '\n'.join(['%s\t%s' % (k.encode('string-escape'),
256 t = '\n'.join(['%s\t%s' % (k.encode('string-escape'),
255 v.encode('string-escape')) for k, v in d])
257 v.encode('string-escape')) for k, v in d])
256 self.respond(t)
258 return t
General Comments 0
You need to be logged in to leave comments. Login now