##// END OF EJS Templates
rust-chg: leverage impl trait at argument position...
Yuya Nishihara -
r45183:61fda2db default
parent child Browse files
Show More
@@ -1,136 +1,128 b''
1 1 // Copyright 2018 Yuya Nishihara <yuya@tcha.org>
2 2 //
3 3 // This software may be used and distributed according to the terms of the
4 4 // GNU General Public License version 2 or any later version.
5 5
6 6 //! cHg extensions to command server client.
7 7
8 8 use bytes::{BufMut, Bytes, BytesMut};
9 9 use std::ffi::OsStr;
10 10 use std::io;
11 11 use std::mem;
12 12 use std::os::unix::ffi::OsStrExt;
13 13 use std::os::unix::io::AsRawFd;
14 14 use std::path::Path;
15 15 use tokio_hglib::protocol::{OneShotQuery, OneShotRequest};
16 16 use tokio_hglib::{Client, Connection};
17 17
18 18 use crate::attachio::AttachIo;
19 19 use crate::message::{self, Instruction};
20 20 use crate::runcommand::ChgRunCommand;
21 21 use crate::uihandler::SystemHandler;
22 22
23 23 pub trait ChgClientExt<C>
24 24 where
25 25 C: Connection + AsRawFd,
26 26 {
27 27 /// Attaches the client file descriptors to the server.
28 28 fn attach_io<I, O, E>(self, stdin: I, stdout: O, stderr: E) -> AttachIo<C, I, O, E>
29 29 where
30 30 I: AsRawFd,
31 31 O: AsRawFd,
32 32 E: AsRawFd;
33 33
34 34 /// Changes the working directory of the server.
35 fn set_current_dir<P>(self, dir: P) -> OneShotRequest<C>
36 where
37 P: AsRef<Path>;
35 fn set_current_dir(self, dir: impl AsRef<Path>) -> OneShotRequest<C>;
38 36
39 37 /// Updates the environment variables of the server.
40 fn set_env_vars_os<I, P>(self, vars: I) -> OneShotRequest<C>
41 where
42 I: IntoIterator<Item = (P, P)>,
43 P: AsRef<OsStr>;
38 fn set_env_vars_os(
39 self,
40 vars: impl IntoIterator<Item = (impl AsRef<OsStr>, impl AsRef<OsStr>)>,
41 ) -> OneShotRequest<C>;
44 42
45 43 /// Changes the process title of the server.
46 fn set_process_name<P>(self, name: P) -> OneShotRequest<C>
47 where
48 P: AsRef<OsStr>;
44 fn set_process_name(self, name: impl AsRef<OsStr>) -> OneShotRequest<C>;
49 45
50 46 /// Changes the umask of the server process.
51 47 fn set_umask(self, mask: u32) -> OneShotRequest<C>;
52 48
53 49 /// Runs the specified Mercurial command with cHg extension.
54 fn run_command_chg<I, P, H>(self, handler: H, args: I) -> ChgRunCommand<C, H>
50 fn run_command_chg<H>(
51 self,
52 handler: H,
53 args: impl IntoIterator<Item = impl AsRef<OsStr>>,
54 ) -> ChgRunCommand<C, H>
55 55 where
56 I: IntoIterator<Item = P>,
57 P: AsRef<OsStr>,
58 56 H: SystemHandler;
59 57
60 58 /// Validates if the server can run Mercurial commands with the expected
61 59 /// configuration.
62 60 ///
63 61 /// The `args` should contain early command arguments such as `--config`
64 62 /// and `-R`.
65 63 ///
66 64 /// Client-side environment must be sent prior to this request, by
67 65 /// `set_current_dir()` and `set_env_vars_os()`.
68 fn validate<I, P>(self, args: I) -> OneShotQuery<C, fn(Bytes) -> io::Result<Vec<Instruction>>>
69 where
70 I: IntoIterator<Item = P>,
71 P: AsRef<OsStr>;
66 fn validate(
67 self,
68 args: impl IntoIterator<Item = impl AsRef<OsStr>>,
69 ) -> OneShotQuery<C, fn(Bytes) -> io::Result<Vec<Instruction>>>;
72 70 }
73 71
74 72 impl<C> ChgClientExt<C> for Client<C>
75 73 where
76 74 C: Connection + AsRawFd,
77 75 {
78 76 fn attach_io<I, O, E>(self, stdin: I, stdout: O, stderr: E) -> AttachIo<C, I, O, E>
79 77 where
80 78 I: AsRawFd,
81 79 O: AsRawFd,
82 80 E: AsRawFd,
83 81 {
84 82 AttachIo::with_client(self, stdin, stdout, Some(stderr))
85 83 }
86 84
87 fn set_current_dir<P>(self, dir: P) -> OneShotRequest<C>
88 where
89 P: AsRef<Path>,
90 {
85 fn set_current_dir(self, dir: impl AsRef<Path>) -> OneShotRequest<C> {
91 86 OneShotRequest::start_with_args(self, b"chdir", dir.as_ref().as_os_str().as_bytes())
92 87 }
93 88
94 fn set_env_vars_os<I, P>(self, vars: I) -> OneShotRequest<C>
95 where
96 I: IntoIterator<Item = (P, P)>,
97 P: AsRef<OsStr>,
98 {
89 fn set_env_vars_os(
90 self,
91 vars: impl IntoIterator<Item = (impl AsRef<OsStr>, impl AsRef<OsStr>)>,
92 ) -> OneShotRequest<C> {
99 93 OneShotRequest::start_with_args(self, b"setenv", message::pack_env_vars_os(vars))
100 94 }
101 95
102 fn set_process_name<P>(self, name: P) -> OneShotRequest<C>
103 where
104 P: AsRef<OsStr>,
105 {
96 fn set_process_name(self, name: impl AsRef<OsStr>) -> OneShotRequest<C> {
106 97 OneShotRequest::start_with_args(self, b"setprocname", name.as_ref().as_bytes())
107 98 }
108 99
109 100 fn set_umask(self, mask: u32) -> OneShotRequest<C> {
110 101 let mut args = BytesMut::with_capacity(mem::size_of_val(&mask));
111 102 args.put_u32_be(mask);
112 103 OneShotRequest::start_with_args(self, b"setumask2", args)
113 104 }
114 105
115 fn run_command_chg<I, P, H>(self, handler: H, args: I) -> ChgRunCommand<C, H>
106 fn run_command_chg<H>(
107 self,
108 handler: H,
109 args: impl IntoIterator<Item = impl AsRef<OsStr>>,
110 ) -> ChgRunCommand<C, H>
116 111 where
117 I: IntoIterator<Item = P>,
118 P: AsRef<OsStr>,
119 112 H: SystemHandler,
120 113 {
121 114 ChgRunCommand::with_client(self, handler, message::pack_args_os(args))
122 115 }
123 116
124 fn validate<I, P>(self, args: I) -> OneShotQuery<C, fn(Bytes) -> io::Result<Vec<Instruction>>>
125 where
126 I: IntoIterator<Item = P>,
127 P: AsRef<OsStr>,
128 {
117 fn validate(
118 self,
119 args: impl IntoIterator<Item = impl AsRef<OsStr>>,
120 ) -> OneShotQuery<C, fn(Bytes) -> io::Result<Vec<Instruction>>> {
129 121 OneShotQuery::start_with_args(
130 122 self,
131 123 b"validate",
132 124 message::pack_args_os(args),
133 125 message::parse_instructions,
134 126 )
135 127 }
136 128 }
@@ -1,501 +1,490 b''
1 1 // Copyright 2011, 2018 Yuya Nishihara <yuya@tcha.org>
2 2 //
3 3 // This software may be used and distributed according to the terms of the
4 4 // GNU General Public License version 2 or any later version.
5 5
6 6 //! Utility for locating command-server process.
7 7
8 8 use futures::future::{self, Either, Loop};
9 9 use log::debug;
10 10 use std::env;
11 11 use std::ffi::{OsStr, OsString};
12 12 use std::fs::{self, DirBuilder};
13 13 use std::io;
14 14 use std::os::unix::ffi::{OsStrExt, OsStringExt};
15 15 use std::os::unix::fs::{DirBuilderExt, MetadataExt};
16 16 use std::path::{Path, PathBuf};
17 17 use std::process::{self, Command};
18 18 use std::time::Duration;
19 19 use tokio::prelude::*;
20 20 use tokio_hglib::UnixClient;
21 21 use tokio_process::{Child, CommandExt};
22 22 use tokio_timer;
23 23
24 24 use crate::clientext::ChgClientExt;
25 25 use crate::message::{Instruction, ServerSpec};
26 26 use crate::procutil;
27 27
28 28 const REQUIRED_SERVER_CAPABILITIES: &[&str] = &[
29 29 "attachio",
30 30 "chdir",
31 31 "runcommand",
32 32 "setenv",
33 33 "setumask2",
34 34 "validate",
35 35 ];
36 36
37 37 /// Helper to connect to and spawn a server process.
38 38 #[derive(Clone, Debug)]
39 39 pub struct Locator {
40 40 hg_command: OsString,
41 41 hg_early_args: Vec<OsString>,
42 42 current_dir: PathBuf,
43 43 env_vars: Vec<(OsString, OsString)>,
44 44 process_id: u32,
45 45 base_sock_path: PathBuf,
46 46 redirect_sock_path: Option<PathBuf>,
47 47 timeout: Duration,
48 48 }
49 49
50 50 impl Locator {
51 51 /// Creates locator capturing the current process environment.
52 52 ///
53 53 /// If no `$CHGSOCKNAME` is specified, the socket directory will be
54 54 /// created as necessary.
55 55 pub fn prepare_from_env() -> io::Result<Locator> {
56 56 Ok(Locator {
57 57 hg_command: default_hg_command(),
58 58 hg_early_args: Vec::new(),
59 59 current_dir: env::current_dir()?,
60 60 env_vars: env::vars_os().collect(),
61 61 process_id: process::id(),
62 62 base_sock_path: prepare_server_socket_path()?,
63 63 redirect_sock_path: None,
64 64 timeout: default_timeout(),
65 65 })
66 66 }
67 67
68 68 /// Temporary socket path for this client process.
69 69 fn temp_sock_path(&self) -> PathBuf {
70 70 let src = self.base_sock_path.as_os_str().as_bytes();
71 71 let mut buf = Vec::with_capacity(src.len() + 6); // "{src}.{pid}".len()
72 72 buf.extend_from_slice(src);
73 73 buf.extend_from_slice(format!(".{}", self.process_id).as_bytes());
74 74 OsString::from_vec(buf).into()
75 75 }
76 76
77 77 /// Specifies the arguments to be passed to the server at start.
78 pub fn set_early_args<I, P>(&mut self, args: I)
79 where
80 I: IntoIterator<Item = P>,
81 P: AsRef<OsStr>,
82 {
78 pub fn set_early_args(&mut self, args: impl IntoIterator<Item = impl AsRef<OsStr>>) {
83 79 self.hg_early_args = args.into_iter().map(|a| a.as_ref().to_owned()).collect();
84 80 }
85 81
86 82 /// Connects to the server.
87 83 ///
88 84 /// The server process will be spawned if not running.
89 85 pub fn connect(self) -> impl Future<Item = (Self, UnixClient), Error = io::Error> {
90 86 future::loop_fn((self, 0), |(loc, cnt)| {
91 87 if cnt < 10 {
92 88 let fut = loc
93 89 .try_connect()
94 90 .and_then(|(loc, client)| {
95 91 client
96 92 .validate(&loc.hg_early_args)
97 93 .map(|(client, instructions)| (loc, client, instructions))
98 94 })
99 95 .and_then(move |(loc, client, instructions)| {
100 96 loc.run_instructions(client, instructions, cnt)
101 97 });
102 98 Either::A(fut)
103 99 } else {
104 100 let msg = format!(
105 101 concat!(
106 102 "too many redirections.\n",
107 103 "Please make sure {:?} is not a wrapper which ",
108 104 "changes sensitive environment variables ",
109 105 "before executing hg. If you have to use a ",
110 106 "wrapper, wrap chg instead of hg.",
111 107 ),
112 108 loc.hg_command
113 109 );
114 110 Either::B(future::err(io::Error::new(io::ErrorKind::Other, msg)))
115 111 }
116 112 })
117 113 }
118 114
119 115 /// Runs instructions received from the server.
120 116 fn run_instructions(
121 117 mut self,
122 118 client: UnixClient,
123 119 instructions: Vec<Instruction>,
124 120 cnt: usize,
125 121 ) -> io::Result<Loop<(Self, UnixClient), (Self, usize)>> {
126 122 let mut reconnect = false;
127 123 for inst in instructions {
128 124 debug!("instruction: {:?}", inst);
129 125 match inst {
130 126 Instruction::Exit(_) => {
131 127 // Just returns the current connection to run the
132 128 // unparsable command and report the error
133 129 return Ok(Loop::Break((self, client)));
134 130 }
135 131 Instruction::Reconnect => {
136 132 reconnect = true;
137 133 }
138 134 Instruction::Redirect(path) => {
139 135 if path.parent() != self.base_sock_path.parent() {
140 136 let msg = format!(
141 137 "insecure redirect instruction from server: {}",
142 138 path.display()
143 139 );
144 140 return Err(io::Error::new(io::ErrorKind::InvalidData, msg));
145 141 }
146 142 self.redirect_sock_path = Some(path);
147 143 reconnect = true;
148 144 }
149 145 Instruction::Unlink(path) => {
150 146 if path.parent() != self.base_sock_path.parent() {
151 147 let msg = format!(
152 148 "insecure unlink instruction from server: {}",
153 149 path.display()
154 150 );
155 151 return Err(io::Error::new(io::ErrorKind::InvalidData, msg));
156 152 }
157 153 fs::remove_file(path).unwrap_or(()); // may race
158 154 }
159 155 }
160 156 }
161 157
162 158 if reconnect {
163 159 Ok(Loop::Continue((self, cnt + 1)))
164 160 } else {
165 161 Ok(Loop::Break((self, client)))
166 162 }
167 163 }
168 164
169 165 /// Tries to connect to the existing server, or spawns new if not running.
170 166 fn try_connect(self) -> impl Future<Item = (Self, UnixClient), Error = io::Error> {
171 167 let sock_path = self
172 168 .redirect_sock_path
173 169 .as_ref()
174 170 .unwrap_or(&self.base_sock_path)
175 171 .clone();
176 172 debug!("try connect to {}", sock_path.display());
177 173 UnixClient::connect(sock_path)
178 174 .then(|res| {
179 175 match res {
180 176 Ok(client) => Either::A(future::ok((self, client))),
181 177 Err(_) => {
182 178 // Prevent us from being re-connected to the outdated
183 179 // master server: We were told by the server to redirect
184 180 // to redirect_sock_path, which didn't work. We do not
185 181 // want to connect to the same master server again
186 182 // because it would probably tell us the same thing.
187 183 if self.redirect_sock_path.is_some() {
188 184 fs::remove_file(&self.base_sock_path).unwrap_or(());
189 185 // may race
190 186 }
191 187 Either::B(self.spawn_connect())
192 188 }
193 189 }
194 190 })
195 191 .and_then(|(loc, client)| {
196 192 check_server_capabilities(client.server_spec())?;
197 193 Ok((loc, client))
198 194 })
199 195 .and_then(|(loc, client)| {
200 196 // It's purely optional, and the server might not support this command.
201 197 if client.server_spec().capabilities.contains("setprocname") {
202 198 let fut = client
203 199 .set_process_name(format!("chg[worker/{}]", loc.process_id))
204 200 .map(|client| (loc, client));
205 201 Either::A(fut)
206 202 } else {
207 203 Either::B(future::ok((loc, client)))
208 204 }
209 205 })
210 206 .and_then(|(loc, client)| {
211 207 client
212 208 .set_current_dir(&loc.current_dir)
213 209 .map(|client| (loc, client))
214 210 })
215 211 .and_then(|(loc, client)| {
216 212 client
217 213 .set_env_vars_os(loc.env_vars.iter().cloned())
218 214 .map(|client| (loc, client))
219 215 })
220 216 }
221 217
222 218 /// Spawns new server process and connects to it.
223 219 ///
224 220 /// The server will be spawned at the current working directory, then
225 221 /// chdir to "/", so that the server will load configs from the target
226 222 /// repository.
227 223 fn spawn_connect(self) -> impl Future<Item = (Self, UnixClient), Error = io::Error> {
228 224 let sock_path = self.temp_sock_path();
229 225 debug!("start cmdserver at {}", sock_path.display());
230 226 Command::new(&self.hg_command)
231 227 .arg("serve")
232 228 .arg("--cmdserver")
233 229 .arg("chgunix")
234 230 .arg("--address")
235 231 .arg(&sock_path)
236 232 .arg("--daemon-postexec")
237 233 .arg("chdir:/")
238 234 .args(&self.hg_early_args)
239 235 .current_dir(&self.current_dir)
240 236 .env_clear()
241 237 .envs(self.env_vars.iter().cloned())
242 238 .env("CHGINTERNALMARK", "")
243 239 .spawn_async()
244 240 .into_future()
245 241 .and_then(|server| self.connect_spawned(server, sock_path))
246 242 .and_then(|(loc, client, sock_path)| {
247 243 debug!(
248 244 "rename {} to {}",
249 245 sock_path.display(),
250 246 loc.base_sock_path.display()
251 247 );
252 248 fs::rename(&sock_path, &loc.base_sock_path)?;
253 249 Ok((loc, client))
254 250 })
255 251 }
256 252
257 253 /// Tries to connect to the just spawned server repeatedly until timeout
258 254 /// exceeded.
259 255 fn connect_spawned(
260 256 self,
261 257 server: Child,
262 258 sock_path: PathBuf,
263 259 ) -> impl Future<Item = (Self, UnixClient, PathBuf), Error = io::Error> {
264 260 debug!("try connect to {} repeatedly", sock_path.display());
265 261 let connect = future::loop_fn(sock_path, |sock_path| {
266 262 UnixClient::connect(sock_path.clone()).then(|res| {
267 263 match res {
268 264 Ok(client) => Either::A(future::ok(Loop::Break((client, sock_path)))),
269 265 Err(_) => {
270 266 // try again with slight delay
271 267 let fut = tokio_timer::sleep(Duration::from_millis(10))
272 268 .map(|()| Loop::Continue(sock_path))
273 269 .map_err(|err| io::Error::new(io::ErrorKind::Other, err));
274 270 Either::B(fut)
275 271 }
276 272 }
277 273 })
278 274 });
279 275
280 276 // waits for either connection established or server failed to start
281 277 connect
282 278 .select2(server)
283 279 .map_err(|res| res.split().0)
284 280 .timeout(self.timeout)
285 281 .map_err(|err| {
286 282 err.into_inner().unwrap_or_else(|| {
287 283 io::Error::new(
288 284 io::ErrorKind::TimedOut,
289 285 "timed out while connecting to server",
290 286 )
291 287 })
292 288 })
293 289 .and_then(|res| {
294 290 match res {
295 291 Either::A(((client, sock_path), server)) => {
296 292 server.forget(); // continue to run in background
297 293 Ok((self, client, sock_path))
298 294 }
299 295 Either::B((st, _)) => Err(io::Error::new(
300 296 io::ErrorKind::Other,
301 297 format!("server exited too early: {}", st),
302 298 )),
303 299 }
304 300 })
305 301 }
306 302 }
307 303
308 304 /// Determines the server socket to connect to.
309 305 ///
310 306 /// If no `$CHGSOCKNAME` is specified, the socket directory will be created
311 307 /// as necessary.
312 308 fn prepare_server_socket_path() -> io::Result<PathBuf> {
313 309 if let Some(s) = env::var_os("CHGSOCKNAME") {
314 310 Ok(PathBuf::from(s))
315 311 } else {
316 312 let mut path = default_server_socket_dir();
317 313 create_secure_dir(&path)?;
318 314 path.push("server");
319 315 Ok(path)
320 316 }
321 317 }
322 318
323 319 /// Determines the default server socket path as follows.
324 320 ///
325 321 /// 1. `$XDG_RUNTIME_DIR/chg`
326 322 /// 2. `$TMPDIR/chg$UID`
327 323 /// 3. `/tmp/chg$UID`
328 324 pub fn default_server_socket_dir() -> PathBuf {
329 325 // XDG_RUNTIME_DIR should be ignored if it has an insufficient permission.
330 326 // https://standards.freedesktop.org/basedir-spec/basedir-spec-latest.html
331 327 if let Some(Ok(s)) = env::var_os("XDG_RUNTIME_DIR").map(check_secure_dir) {
332 328 let mut path = PathBuf::from(s);
333 329 path.push("chg");
334 330 path
335 331 } else {
336 332 let mut path = env::temp_dir();
337 333 path.push(format!("chg{}", procutil::get_effective_uid()));
338 334 path
339 335 }
340 336 }
341 337
342 338 /// Determines the default hg command.
343 339 pub fn default_hg_command() -> OsString {
344 340 // TODO: maybe allow embedding the path at compile time (or load from hgrc)
345 341 env::var_os("CHGHG")
346 342 .or(env::var_os("HG"))
347 343 .unwrap_or(OsStr::new("hg").to_owned())
348 344 }
349 345
350 346 fn default_timeout() -> Duration {
351 347 let secs = env::var("CHGTIMEOUT")
352 348 .ok()
353 349 .and_then(|s| s.parse().ok())
354 350 .unwrap_or(60);
355 351 Duration::from_secs(secs)
356 352 }
357 353
358 354 /// Creates a directory which the other users cannot access to.
359 355 ///
360 356 /// If the directory already exists, tests its permission.
361 fn create_secure_dir<P>(path: P) -> io::Result<()>
362 where
363 P: AsRef<Path>,
364 {
357 fn create_secure_dir(path: impl AsRef<Path>) -> io::Result<()> {
365 358 DirBuilder::new()
366 359 .mode(0o700)
367 360 .create(path.as_ref())
368 361 .or_else(|err| {
369 362 if err.kind() == io::ErrorKind::AlreadyExists {
370 363 check_secure_dir(path).map(|_| ())
371 364 } else {
372 365 Err(err)
373 366 }
374 367 })
375 368 }
376 369
377 370 fn check_secure_dir<P>(path: P) -> io::Result<P>
378 371 where
379 372 P: AsRef<Path>,
380 373 {
381 374 let a = fs::symlink_metadata(path.as_ref())?;
382 375 if a.is_dir() && a.uid() == procutil::get_effective_uid() && (a.mode() & 0o777) == 0o700 {
383 376 Ok(path)
384 377 } else {
385 378 Err(io::Error::new(io::ErrorKind::Other, "insecure directory"))
386 379 }
387 380 }
388 381
389 382 fn check_server_capabilities(spec: &ServerSpec) -> io::Result<()> {
390 383 let unsupported: Vec<_> = REQUIRED_SERVER_CAPABILITIES
391 384 .iter()
392 385 .cloned()
393 386 .filter(|&s| !spec.capabilities.contains(s))
394 387 .collect();
395 388 if unsupported.is_empty() {
396 389 Ok(())
397 390 } else {
398 391 let msg = format!(
399 392 "insufficient server capabilities: {}",
400 393 unsupported.join(", ")
401 394 );
402 395 Err(io::Error::new(io::ErrorKind::Other, msg))
403 396 }
404 397 }
405 398
406 399 /// Collects arguments which need to be passed to the server at start.
407 pub fn collect_early_args<I, P>(args: I) -> Vec<OsString>
408 where
409 I: IntoIterator<Item = P>,
410 P: AsRef<OsStr>,
411 {
400 pub fn collect_early_args(args: impl IntoIterator<Item = impl AsRef<OsStr>>) -> Vec<OsString> {
412 401 let mut args_iter = args.into_iter();
413 402 let mut early_args = Vec::new();
414 403 while let Some(arg) = args_iter.next() {
415 404 let argb = arg.as_ref().as_bytes();
416 405 if argb == b"--" {
417 406 break;
418 407 } else if argb.starts_with(b"--") {
419 408 let mut split = argb[2..].splitn(2, |&c| c == b'=');
420 409 match split.next().unwrap() {
421 410 b"traceback" => {
422 411 if split.next().is_none() {
423 412 early_args.push(arg.as_ref().to_owned());
424 413 }
425 414 }
426 415 b"config" | b"cwd" | b"repo" | b"repository" => {
427 416 if split.next().is_some() {
428 417 // --<flag>=<val>
429 418 early_args.push(arg.as_ref().to_owned());
430 419 } else {
431 420 // --<flag> <val>
432 421 args_iter.next().map(|val| {
433 422 early_args.push(arg.as_ref().to_owned());
434 423 early_args.push(val.as_ref().to_owned());
435 424 });
436 425 }
437 426 }
438 427 _ => {}
439 428 }
440 429 } else if argb.starts_with(b"-R") {
441 430 if argb.len() > 2 {
442 431 // -R<val>
443 432 early_args.push(arg.as_ref().to_owned());
444 433 } else {
445 434 // -R <val>
446 435 args_iter.next().map(|val| {
447 436 early_args.push(arg.as_ref().to_owned());
448 437 early_args.push(val.as_ref().to_owned());
449 438 });
450 439 }
451 440 }
452 441 }
453 442
454 443 early_args
455 444 }
456 445
457 446 #[cfg(test)]
458 447 mod tests {
459 448 use super::*;
460 449
461 450 #[test]
462 451 fn collect_early_args_some() {
463 452 assert!(collect_early_args(&[] as &[&OsStr]).is_empty());
464 453 assert!(collect_early_args(&["log"]).is_empty());
465 454 assert_eq!(
466 455 collect_early_args(&["log", "-Ra", "foo"]),
467 456 os_string_vec_from(&[b"-Ra"])
468 457 );
469 458 assert_eq!(
470 459 collect_early_args(&["log", "-R", "repo", "", "--traceback", "a"]),
471 460 os_string_vec_from(&[b"-R", b"repo", b"--traceback"])
472 461 );
473 462 assert_eq!(
474 463 collect_early_args(&["log", "--config", "diff.git=1", "-q"]),
475 464 os_string_vec_from(&[b"--config", b"diff.git=1"])
476 465 );
477 466 assert_eq!(
478 467 collect_early_args(&["--cwd=..", "--repository", "r", "log"]),
479 468 os_string_vec_from(&[b"--cwd=..", b"--repository", b"r"])
480 469 );
481 470 assert_eq!(
482 471 collect_early_args(&["log", "--repo=r", "--repos", "a"]),
483 472 os_string_vec_from(&[b"--repo=r"])
484 473 );
485 474 }
486 475
487 476 #[test]
488 477 fn collect_early_args_orphaned() {
489 478 assert!(collect_early_args(&["log", "-R"]).is_empty());
490 479 assert!(collect_early_args(&["log", "--config"]).is_empty());
491 480 }
492 481
493 482 #[test]
494 483 fn collect_early_args_unwanted_value() {
495 484 assert!(collect_early_args(&["log", "--traceback="]).is_empty());
496 485 }
497 486
498 487 fn os_string_vec_from(v: &[&[u8]]) -> Vec<OsString> {
499 488 v.iter().map(|s| OsStr::from_bytes(s).to_owned()).collect()
500 489 }
501 490 }
@@ -1,317 +1,309 b''
1 1 // Copyright 2018 Yuya Nishihara <yuya@tcha.org>
2 2 //
3 3 // This software may be used and distributed according to the terms of the
4 4 // GNU General Public License version 2 or any later version.
5 5
6 6 //! Utility for parsing and building command-server messages.
7 7
8 8 use bytes::{BufMut, Bytes, BytesMut};
9 9 use std::error;
10 10 use std::ffi::{OsStr, OsString};
11 11 use std::io;
12 12 use std::os::unix::ffi::OsStrExt;
13 13 use std::path::PathBuf;
14 14
15 15 pub use tokio_hglib::message::*; // re-exports
16 16
17 17 /// Shell command type requested by the server.
18 18 #[derive(Clone, Copy, Debug, Eq, PartialEq)]
19 19 pub enum CommandType {
20 20 /// Pager should be spawned.
21 21 Pager,
22 22 /// Shell command should be executed to send back the result code.
23 23 System,
24 24 }
25 25
26 26 /// Shell command requested by the server.
27 27 #[derive(Clone, Debug, Eq, PartialEq)]
28 28 pub struct CommandSpec {
29 29 pub command: OsString,
30 30 pub current_dir: OsString,
31 31 pub envs: Vec<(OsString, OsString)>,
32 32 }
33 33
34 34 /// Parses "S" channel request into command type and spec.
35 35 pub fn parse_command_spec(data: Bytes) -> io::Result<(CommandType, CommandSpec)> {
36 36 let mut split = data.split(|&c| c == b'\0');
37 37 let ctype = parse_command_type(split.next().ok_or(new_parse_error("missing type"))?)?;
38 38 let command = split.next().ok_or(new_parse_error("missing command"))?;
39 39 let current_dir = split.next().ok_or(new_parse_error("missing current dir"))?;
40 40
41 41 let mut envs = Vec::new();
42 42 for l in split {
43 43 let mut s = l.splitn(2, |&c| c == b'=');
44 44 let k = s.next().unwrap();
45 45 let v = s.next().ok_or(new_parse_error("malformed env"))?;
46 46 envs.push((
47 47 OsStr::from_bytes(k).to_owned(),
48 48 OsStr::from_bytes(v).to_owned(),
49 49 ));
50 50 }
51 51
52 52 let spec = CommandSpec {
53 53 command: OsStr::from_bytes(command).to_owned(),
54 54 current_dir: OsStr::from_bytes(current_dir).to_owned(),
55 55 envs: envs,
56 56 };
57 57 Ok((ctype, spec))
58 58 }
59 59
60 60 fn parse_command_type(value: &[u8]) -> io::Result<CommandType> {
61 61 match value {
62 62 b"pager" => Ok(CommandType::Pager),
63 63 b"system" => Ok(CommandType::System),
64 64 _ => Err(new_parse_error(format!(
65 65 "unknown command type: {}",
66 66 decode_latin1(value)
67 67 ))),
68 68 }
69 69 }
70 70
71 71 /// Client-side instruction requested by the server.
72 72 #[derive(Clone, Debug, Eq, PartialEq)]
73 73 pub enum Instruction {
74 74 Exit(i32),
75 75 Reconnect,
76 76 Redirect(PathBuf),
77 77 Unlink(PathBuf),
78 78 }
79 79
80 80 /// Parses validation result into instructions.
81 81 pub fn parse_instructions(data: Bytes) -> io::Result<Vec<Instruction>> {
82 82 let mut instructions = Vec::new();
83 83 for l in data.split(|&c| c == b'\0') {
84 84 if l.is_empty() {
85 85 continue;
86 86 }
87 87 let mut s = l.splitn(2, |&c| c == b' ');
88 88 let inst = match (s.next().unwrap(), s.next()) {
89 89 (b"exit", Some(arg)) => decode_latin1(arg)
90 90 .parse()
91 91 .map(Instruction::Exit)
92 92 .map_err(|_| new_parse_error(format!("invalid exit code: {:?}", arg)))?,
93 93 (b"reconnect", None) => Instruction::Reconnect,
94 94 (b"redirect", Some(arg)) => {
95 95 Instruction::Redirect(OsStr::from_bytes(arg).to_owned().into())
96 96 }
97 97 (b"unlink", Some(arg)) => Instruction::Unlink(OsStr::from_bytes(arg).to_owned().into()),
98 98 _ => {
99 99 return Err(new_parse_error(format!("unknown command: {:?}", l)));
100 100 }
101 101 };
102 102 instructions.push(inst);
103 103 }
104 104 Ok(instructions)
105 105 }
106 106
107 107 // allocate large buffer as environment variables can be quite long
108 108 const INITIAL_PACKED_ENV_VARS_CAPACITY: usize = 4096;
109 109
110 110 /// Packs environment variables of platform encoding into bytes.
111 111 ///
112 112 /// # Panics
113 113 ///
114 114 /// Panics if key or value contains `\0` character, or key contains '='
115 115 /// character.
116 pub fn pack_env_vars_os<I, P>(vars: I) -> Bytes
117 where
118 I: IntoIterator<Item = (P, P)>,
119 P: AsRef<OsStr>,
120 {
116 pub fn pack_env_vars_os(
117 vars: impl IntoIterator<Item = (impl AsRef<OsStr>, impl AsRef<OsStr>)>,
118 ) -> Bytes {
121 119 let mut vars_iter = vars.into_iter();
122 120 if let Some((k, v)) = vars_iter.next() {
123 121 let mut dst = BytesMut::with_capacity(INITIAL_PACKED_ENV_VARS_CAPACITY);
124 122 pack_env_into(&mut dst, k.as_ref(), v.as_ref());
125 123 for (k, v) in vars_iter {
126 124 dst.reserve(1);
127 125 dst.put_u8(b'\0');
128 126 pack_env_into(&mut dst, k.as_ref(), v.as_ref());
129 127 }
130 128 dst.freeze()
131 129 } else {
132 130 Bytes::new()
133 131 }
134 132 }
135 133
136 134 fn pack_env_into(dst: &mut BytesMut, k: &OsStr, v: &OsStr) {
137 135 assert!(!k.as_bytes().contains(&0), "key shouldn't contain NUL");
138 136 assert!(!k.as_bytes().contains(&b'='), "key shouldn't contain '='");
139 137 assert!(!v.as_bytes().contains(&0), "value shouldn't contain NUL");
140 138 dst.reserve(k.as_bytes().len() + 1 + v.as_bytes().len());
141 139 dst.put_slice(k.as_bytes());
142 140 dst.put_u8(b'=');
143 141 dst.put_slice(v.as_bytes());
144 142 }
145 143
146 fn decode_latin1<S>(s: S) -> String
147 where
148 S: AsRef<[u8]>,
149 {
144 fn decode_latin1(s: impl AsRef<[u8]>) -> String {
150 145 s.as_ref().iter().map(|&c| c as char).collect()
151 146 }
152 147
153 fn new_parse_error<E>(error: E) -> io::Error
154 where
155 E: Into<Box<dyn error::Error + Send + Sync>>,
156 {
148 fn new_parse_error(error: impl Into<Box<dyn error::Error + Send + Sync>>) -> io::Error {
157 149 io::Error::new(io::ErrorKind::InvalidData, error)
158 150 }
159 151
160 152 #[cfg(test)]
161 153 mod tests {
162 154 use super::*;
163 155 use std::os::unix::ffi::OsStringExt;
164 156 use std::panic;
165 157
166 158 #[test]
167 159 fn parse_command_spec_good() {
168 160 let src = [
169 161 b"pager".as_ref(),
170 162 b"less -FRX".as_ref(),
171 163 b"/tmp".as_ref(),
172 164 b"LANG=C".as_ref(),
173 165 b"HGPLAIN=".as_ref(),
174 166 ]
175 167 .join(&0);
176 168 let spec = CommandSpec {
177 169 command: os_string_from(b"less -FRX"),
178 170 current_dir: os_string_from(b"/tmp"),
179 171 envs: vec![
180 172 (os_string_from(b"LANG"), os_string_from(b"C")),
181 173 (os_string_from(b"HGPLAIN"), os_string_from(b"")),
182 174 ],
183 175 };
184 176 assert_eq!(
185 177 parse_command_spec(Bytes::from(src)).unwrap(),
186 178 (CommandType::Pager, spec)
187 179 );
188 180 }
189 181
190 182 #[test]
191 183 fn parse_command_spec_too_short() {
192 184 assert!(parse_command_spec(Bytes::from_static(b"")).is_err());
193 185 assert!(parse_command_spec(Bytes::from_static(b"pager")).is_err());
194 186 assert!(parse_command_spec(Bytes::from_static(b"pager\0less")).is_err());
195 187 }
196 188
197 189 #[test]
198 190 fn parse_command_spec_malformed_env() {
199 191 assert!(parse_command_spec(Bytes::from_static(b"pager\0less\0/tmp\0HOME")).is_err());
200 192 }
201 193
202 194 #[test]
203 195 fn parse_command_spec_unknown_type() {
204 196 assert!(parse_command_spec(Bytes::from_static(b"paper\0less")).is_err());
205 197 }
206 198
207 199 #[test]
208 200 fn parse_instructions_good() {
209 201 let src = [
210 202 b"exit 123".as_ref(),
211 203 b"reconnect".as_ref(),
212 204 b"redirect /whatever".as_ref(),
213 205 b"unlink /someother".as_ref(),
214 206 ]
215 207 .join(&0);
216 208 let insts = vec![
217 209 Instruction::Exit(123),
218 210 Instruction::Reconnect,
219 211 Instruction::Redirect(path_buf_from(b"/whatever")),
220 212 Instruction::Unlink(path_buf_from(b"/someother")),
221 213 ];
222 214 assert_eq!(parse_instructions(Bytes::from(src)).unwrap(), insts);
223 215 }
224 216
225 217 #[test]
226 218 fn parse_instructions_empty() {
227 219 assert_eq!(parse_instructions(Bytes::new()).unwrap(), vec![]);
228 220 assert_eq!(
229 221 parse_instructions(Bytes::from_static(b"\0")).unwrap(),
230 222 vec![]
231 223 );
232 224 }
233 225
234 226 #[test]
235 227 fn parse_instructions_malformed_exit_code() {
236 228 assert!(parse_instructions(Bytes::from_static(b"exit foo")).is_err());
237 229 }
238 230
239 231 #[test]
240 232 fn parse_instructions_missing_argument() {
241 233 assert!(parse_instructions(Bytes::from_static(b"exit")).is_err());
242 234 assert!(parse_instructions(Bytes::from_static(b"redirect")).is_err());
243 235 assert!(parse_instructions(Bytes::from_static(b"unlink")).is_err());
244 236 }
245 237
246 238 #[test]
247 239 fn parse_instructions_unknown_command() {
248 240 assert!(parse_instructions(Bytes::from_static(b"quit 0")).is_err());
249 241 }
250 242
251 243 #[test]
252 244 fn pack_env_vars_os_good() {
253 245 assert_eq!(
254 246 pack_env_vars_os(vec![] as Vec<(OsString, OsString)>),
255 247 Bytes::new()
256 248 );
257 249 assert_eq!(
258 250 pack_env_vars_os(vec![os_string_pair_from(b"FOO", b"bar")]),
259 251 Bytes::from_static(b"FOO=bar")
260 252 );
261 253 assert_eq!(
262 254 pack_env_vars_os(vec![
263 255 os_string_pair_from(b"FOO", b""),
264 256 os_string_pair_from(b"BAR", b"baz")
265 257 ]),
266 258 Bytes::from_static(b"FOO=\0BAR=baz")
267 259 );
268 260 }
269 261
270 262 #[test]
271 263 fn pack_env_vars_os_large_key() {
272 264 let mut buf = vec![b'A'; INITIAL_PACKED_ENV_VARS_CAPACITY];
273 265 let envs = vec![os_string_pair_from(&buf, b"")];
274 266 buf.push(b'=');
275 267 assert_eq!(pack_env_vars_os(envs), Bytes::from(buf));
276 268 }
277 269
278 270 #[test]
279 271 fn pack_env_vars_os_large_value() {
280 272 let mut buf = vec![b'A', b'='];
281 273 buf.resize(INITIAL_PACKED_ENV_VARS_CAPACITY + 1, b'a');
282 274 let envs = vec![os_string_pair_from(&buf[..1], &buf[2..])];
283 275 assert_eq!(pack_env_vars_os(envs), Bytes::from(buf));
284 276 }
285 277
286 278 #[test]
287 279 fn pack_env_vars_os_nul_eq() {
288 280 assert!(panic::catch_unwind(|| {
289 281 pack_env_vars_os(vec![os_string_pair_from(b"\0", b"")])
290 282 })
291 283 .is_err());
292 284 assert!(panic::catch_unwind(|| {
293 285 pack_env_vars_os(vec![os_string_pair_from(b"FOO", b"\0bar")])
294 286 })
295 287 .is_err());
296 288 assert!(panic::catch_unwind(|| {
297 289 pack_env_vars_os(vec![os_string_pair_from(b"FO=", b"bar")])
298 290 })
299 291 .is_err());
300 292 assert_eq!(
301 293 pack_env_vars_os(vec![os_string_pair_from(b"FOO", b"=ba")]),
302 294 Bytes::from_static(b"FOO==ba")
303 295 );
304 296 }
305 297
306 298 fn os_string_from(s: &[u8]) -> OsString {
307 299 OsString::from_vec(s.to_vec())
308 300 }
309 301
310 302 fn os_string_pair_from(k: &[u8], v: &[u8]) -> (OsString, OsString) {
311 303 (os_string_from(k), os_string_from(v))
312 304 }
313 305
314 306 fn path_buf_from(s: &[u8]) -> PathBuf {
315 307 os_string_from(s).into()
316 308 }
317 309 }
General Comments 0
You need to be logged in to leave comments. Login now