##// END OF EJS Templates
sshserver: Don't try to close fp if mkstemp failed
Thomas Arendsen Hein -
r6678:1eba8e8f default
parent child Browse files
Show More
@@ -1,205 +1,207 b''
1 # sshserver.py - ssh protocol server support for mercurial
1 # sshserver.py - ssh protocol server support for mercurial
2 #
2 #
3 # Copyright 2005-2007 Matt Mackall <mpm@selenic.com>
3 # Copyright 2005-2007 Matt Mackall <mpm@selenic.com>
4 # Copyright 2006 Vadim Gelfer <vadim.gelfer@gmail.com>
4 # Copyright 2006 Vadim Gelfer <vadim.gelfer@gmail.com>
5 #
5 #
6 # This software may be used and distributed according to the terms
6 # This software may be used and distributed according to the terms
7 # of the GNU General Public License, incorporated herein by reference.
7 # of the GNU General Public License, incorporated herein by reference.
8
8
9 from i18n import _
9 from i18n import _
10 from node import bin, hex
10 from node import bin, hex
11 import os, streamclone, sys, tempfile, util, hook
11 import os, streamclone, sys, tempfile, util, hook
12
12
13 class sshserver(object):
13 class sshserver(object):
14 def __init__(self, ui, repo):
14 def __init__(self, ui, repo):
15 self.ui = ui
15 self.ui = ui
16 self.repo = repo
16 self.repo = repo
17 self.lock = None
17 self.lock = None
18 self.fin = sys.stdin
18 self.fin = sys.stdin
19 self.fout = sys.stdout
19 self.fout = sys.stdout
20
20
21 hook.redirect(True)
21 hook.redirect(True)
22 sys.stdout = sys.stderr
22 sys.stdout = sys.stderr
23
23
24 # Prevent insertion/deletion of CRs
24 # Prevent insertion/deletion of CRs
25 util.set_binary(self.fin)
25 util.set_binary(self.fin)
26 util.set_binary(self.fout)
26 util.set_binary(self.fout)
27
27
28 def getarg(self):
28 def getarg(self):
29 argline = self.fin.readline()[:-1]
29 argline = self.fin.readline()[:-1]
30 arg, l = argline.split()
30 arg, l = argline.split()
31 val = self.fin.read(int(l))
31 val = self.fin.read(int(l))
32 return arg, val
32 return arg, val
33
33
34 def respond(self, v):
34 def respond(self, v):
35 self.fout.write("%d\n" % len(v))
35 self.fout.write("%d\n" % len(v))
36 self.fout.write(v)
36 self.fout.write(v)
37 self.fout.flush()
37 self.fout.flush()
38
38
39 def serve_forever(self):
39 def serve_forever(self):
40 while self.serve_one(): pass
40 while self.serve_one(): pass
41 sys.exit(0)
41 sys.exit(0)
42
42
43 def serve_one(self):
43 def serve_one(self):
44 cmd = self.fin.readline()[:-1]
44 cmd = self.fin.readline()[:-1]
45 if cmd:
45 if cmd:
46 impl = getattr(self, 'do_' + cmd, None)
46 impl = getattr(self, 'do_' + cmd, None)
47 if impl: impl()
47 if impl: impl()
48 else: self.respond("")
48 else: self.respond("")
49 return cmd != ''
49 return cmd != ''
50
50
51 def do_lookup(self):
51 def do_lookup(self):
52 arg, key = self.getarg()
52 arg, key = self.getarg()
53 assert arg == 'key'
53 assert arg == 'key'
54 try:
54 try:
55 r = hex(self.repo.lookup(key))
55 r = hex(self.repo.lookup(key))
56 success = 1
56 success = 1
57 except Exception,inst:
57 except Exception,inst:
58 r = str(inst)
58 r = str(inst)
59 success = 0
59 success = 0
60 self.respond("%s %s\n" % (success, r))
60 self.respond("%s %s\n" % (success, r))
61
61
62 def do_heads(self):
62 def do_heads(self):
63 h = self.repo.heads()
63 h = self.repo.heads()
64 self.respond(" ".join(map(hex, h)) + "\n")
64 self.respond(" ".join(map(hex, h)) + "\n")
65
65
66 def do_hello(self):
66 def do_hello(self):
67 '''the hello command returns a set of lines describing various
67 '''the hello command returns a set of lines describing various
68 interesting things about the server, in an RFC822-like format.
68 interesting things about the server, in an RFC822-like format.
69 Currently the only one defined is "capabilities", which
69 Currently the only one defined is "capabilities", which
70 consists of a line in the form:
70 consists of a line in the form:
71
71
72 capabilities: space separated list of tokens
72 capabilities: space separated list of tokens
73 '''
73 '''
74
74
75 caps = ['unbundle', 'lookup', 'changegroupsubset']
75 caps = ['unbundle', 'lookup', 'changegroupsubset']
76 if self.ui.configbool('server', 'uncompressed'):
76 if self.ui.configbool('server', 'uncompressed'):
77 caps.append('stream=%d' % self.repo.changelog.version)
77 caps.append('stream=%d' % self.repo.changelog.version)
78 self.respond("capabilities: %s\n" % (' '.join(caps),))
78 self.respond("capabilities: %s\n" % (' '.join(caps),))
79
79
80 def do_lock(self):
80 def do_lock(self):
81 '''DEPRECATED - allowing remote client to lock repo is not safe'''
81 '''DEPRECATED - allowing remote client to lock repo is not safe'''
82
82
83 self.lock = self.repo.lock()
83 self.lock = self.repo.lock()
84 self.respond("")
84 self.respond("")
85
85
86 def do_unlock(self):
86 def do_unlock(self):
87 '''DEPRECATED'''
87 '''DEPRECATED'''
88
88
89 if self.lock:
89 if self.lock:
90 self.lock.release()
90 self.lock.release()
91 self.lock = None
91 self.lock = None
92 self.respond("")
92 self.respond("")
93
93
94 def do_branches(self):
94 def do_branches(self):
95 arg, nodes = self.getarg()
95 arg, nodes = self.getarg()
96 nodes = map(bin, nodes.split(" "))
96 nodes = map(bin, nodes.split(" "))
97 r = []
97 r = []
98 for b in self.repo.branches(nodes):
98 for b in self.repo.branches(nodes):
99 r.append(" ".join(map(hex, b)) + "\n")
99 r.append(" ".join(map(hex, b)) + "\n")
100 self.respond("".join(r))
100 self.respond("".join(r))
101
101
102 def do_between(self):
102 def do_between(self):
103 arg, pairs = self.getarg()
103 arg, pairs = self.getarg()
104 pairs = [map(bin, p.split("-")) for p in pairs.split(" ")]
104 pairs = [map(bin, p.split("-")) for p in pairs.split(" ")]
105 r = []
105 r = []
106 for b in self.repo.between(pairs):
106 for b in self.repo.between(pairs):
107 r.append(" ".join(map(hex, b)) + "\n")
107 r.append(" ".join(map(hex, b)) + "\n")
108 self.respond("".join(r))
108 self.respond("".join(r))
109
109
110 def do_changegroup(self):
110 def do_changegroup(self):
111 nodes = []
111 nodes = []
112 arg, roots = self.getarg()
112 arg, roots = self.getarg()
113 nodes = map(bin, roots.split(" "))
113 nodes = map(bin, roots.split(" "))
114
114
115 cg = self.repo.changegroup(nodes, 'serve')
115 cg = self.repo.changegroup(nodes, 'serve')
116 while True:
116 while True:
117 d = cg.read(4096)
117 d = cg.read(4096)
118 if not d:
118 if not d:
119 break
119 break
120 self.fout.write(d)
120 self.fout.write(d)
121
121
122 self.fout.flush()
122 self.fout.flush()
123
123
124 def do_changegroupsubset(self):
124 def do_changegroupsubset(self):
125 bases = []
125 bases = []
126 heads = []
126 heads = []
127 argmap = dict([self.getarg(), self.getarg()])
127 argmap = dict([self.getarg(), self.getarg()])
128 bases = [bin(n) for n in argmap['bases'].split(' ')]
128 bases = [bin(n) for n in argmap['bases'].split(' ')]
129 heads = [bin(n) for n in argmap['heads'].split(' ')]
129 heads = [bin(n) for n in argmap['heads'].split(' ')]
130
130
131 cg = self.repo.changegroupsubset(bases, heads, 'serve')
131 cg = self.repo.changegroupsubset(bases, heads, 'serve')
132 while True:
132 while True:
133 d = cg.read(4096)
133 d = cg.read(4096)
134 if not d:
134 if not d:
135 break
135 break
136 self.fout.write(d)
136 self.fout.write(d)
137
137
138 self.fout.flush()
138 self.fout.flush()
139
139
140 def do_addchangegroup(self):
140 def do_addchangegroup(self):
141 '''DEPRECATED'''
141 '''DEPRECATED'''
142
142
143 if not self.lock:
143 if not self.lock:
144 self.respond("not locked")
144 self.respond("not locked")
145 return
145 return
146
146
147 self.respond("")
147 self.respond("")
148 r = self.repo.addchangegroup(self.fin, 'serve', self.client_url())
148 r = self.repo.addchangegroup(self.fin, 'serve', self.client_url())
149 self.respond(str(r))
149 self.respond(str(r))
150
150
151 def client_url(self):
151 def client_url(self):
152 client = os.environ.get('SSH_CLIENT', '').split(' ', 1)[0]
152 client = os.environ.get('SSH_CLIENT', '').split(' ', 1)[0]
153 return 'remote:ssh:' + client
153 return 'remote:ssh:' + client
154
154
155 def do_unbundle(self):
155 def do_unbundle(self):
156 their_heads = self.getarg()[1].split()
156 their_heads = self.getarg()[1].split()
157
157
158 def check_heads():
158 def check_heads():
159 heads = map(hex, self.repo.heads())
159 heads = map(hex, self.repo.heads())
160 return their_heads == [hex('force')] or their_heads == heads
160 return their_heads == [hex('force')] or their_heads == heads
161
161
162 # fail early if possible
162 # fail early if possible
163 if not check_heads():
163 if not check_heads():
164 self.respond(_('unsynced changes'))
164 self.respond(_('unsynced changes'))
165 return
165 return
166
166
167 self.respond('')
167 self.respond('')
168
168
169 # write bundle data to temporary file because it can be big
169 # write bundle data to temporary file because it can be big
170
170 tempname = fp = None
171 try:
171 try:
172 fd, tempname = tempfile.mkstemp(prefix='hg-unbundle-')
172 fd, tempname = tempfile.mkstemp(prefix='hg-unbundle-')
173 fp = os.fdopen(fd, 'wb+')
173 fp = os.fdopen(fd, 'wb+')
174
174
175 count = int(self.fin.readline())
175 count = int(self.fin.readline())
176 while count:
176 while count:
177 fp.write(self.fin.read(count))
177 fp.write(self.fin.read(count))
178 count = int(self.fin.readline())
178 count = int(self.fin.readline())
179
179
180 was_locked = self.lock is not None
180 was_locked = self.lock is not None
181 if not was_locked:
181 if not was_locked:
182 self.lock = self.repo.lock()
182 self.lock = self.repo.lock()
183 try:
183 try:
184 if not check_heads():
184 if not check_heads():
185 # someone else committed/pushed/unbundled while we
185 # someone else committed/pushed/unbundled while we
186 # were transferring data
186 # were transferring data
187 self.respond(_('unsynced changes'))
187 self.respond(_('unsynced changes'))
188 return
188 return
189 self.respond('')
189 self.respond('')
190
190
191 # push can proceed
191 # push can proceed
192
192
193 fp.seek(0)
193 fp.seek(0)
194 r = self.repo.addchangegroup(fp, 'serve', self.client_url())
194 r = self.repo.addchangegroup(fp, 'serve', self.client_url())
195 self.respond(str(r))
195 self.respond(str(r))
196 finally:
196 finally:
197 if not was_locked:
197 if not was_locked:
198 self.lock.release()
198 self.lock.release()
199 self.lock = None
199 self.lock = None
200 finally:
200 finally:
201 fp.close()
201 if fp is not None:
202 os.unlink(tempname)
202 fp.close()
203 if tempname is not None:
204 os.unlink(tempname)
203
205
204 def do_stream_out(self):
206 def do_stream_out(self):
205 streamclone.stream_out(self.repo, self.fout)
207 streamclone.stream_out(self.repo, self.fout)
General Comments 0
You need to be logged in to leave comments. Login now