James Moger
2014-03-14 503a853acad49ac6da7f520c26b3b27942dbfec5
src/main/java/com/gitblit/transport/ssh/commands/BaseCommand.java
@@ -14,10 +14,13 @@
package com.gitblit.transport.ssh.commands;
import java.io.BufferedWriter;
import java.io.IOException;
import java.io.InputStream;
import java.io.InterruptedIOException;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicReference;
@@ -26,13 +29,14 @@
import org.apache.sshd.server.Command;
import org.apache.sshd.server.Environment;
import org.apache.sshd.server.ExitCallback;
import org.apache.sshd.server.SessionAware;
import org.apache.sshd.server.session.ServerSession;
import org.kohsuke.args4j.Argument;
import org.kohsuke.args4j.CmdLineException;
import org.kohsuke.args4j.Option;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.gitblit.transport.ssh.AbstractSshCommand;
import com.gitblit.transport.ssh.SshCommandContext;
import com.gitblit.utils.IdGenerator;
import com.gitblit.utils.WorkQueue;
@@ -41,398 +45,439 @@
import com.google.common.base.Charsets;
import com.google.common.util.concurrent.Atomics;
public abstract class BaseCommand extends AbstractSshCommand {
  private static final Logger log = LoggerFactory
      .getLogger(BaseCommand.class);
public abstract class BaseCommand implements Command, SessionAware {
  /** Text of the command line which lead up to invoking this instance. */
  private String commandName = "";
   private static final Logger log = LoggerFactory.getLogger(BaseCommand.class);
  /** Unparsed command line options. */
  private String[] argv;
   /** Ssh context */
   protected SshCommandContext ctx;
  /** Ssh context */
  protected SshCommandContext ctx;
   protected InputStream in;
  /** The task, as scheduled on a worker thread. */
  private final AtomicReference<Future<?>> task;
   protected OutputStream out;
  private final WorkQueue.Executor executor;
   protected OutputStream err;
  public BaseCommand() {
    task = Atomics.newReference();
    IdGenerator gen = new IdGenerator();
    WorkQueue w = new WorkQueue(gen);
    this.executor = w.getDefaultQueue();
  }
   protected ExitCallback exit;
  public void setContext(SshCommandContext ctx) {
   this.ctx = ctx;
  }
   protected ServerSession session;
  public void setInputStream(final InputStream in) {
    this.in = in;
  }
   /** Text of the command line which lead up to invoking this instance. */
   private String commandName = "";
  public void setOutputStream(final OutputStream out) {
    this.out = out;
  }
   /** Unparsed command line options. */
   private String[] argv;
  public void setErrorStream(final OutputStream err) {
    this.err = err;
  }
   /** The task, as scheduled on a worker thread. */
   private final AtomicReference<Future<?>> task;
  public void setExitCallback(final ExitCallback callback) {
    this.exit = callback;
  }
   private final WorkQueue.Executor executor;
  protected void provideBaseStateTo(final Command cmd) {
    if (cmd instanceof BaseCommand) {
     ((BaseCommand)cmd).setContext(ctx);
    }
    cmd.setInputStream(in);
    cmd.setOutputStream(out);
    cmd.setErrorStream(err);
    cmd.setExitCallback(exit);
  }
   public BaseCommand() {
      task = Atomics.newReference();
      IdGenerator gen = new IdGenerator();
      WorkQueue w = new WorkQueue(gen);
      this.executor = w.getDefaultQueue();
   }
  protected String getName() {
    return commandName;
  }
   @Override
   public void setSession(final ServerSession session) {
      this.session = session;
   }
  void setName(final String prefix) {
    this.commandName = prefix;
  }
   @Override
   public void destroy() {
   }
  public String[] getArguments() {
    return argv;
  }
   protected static PrintWriter toPrintWriter(final OutputStream o) {
      return new PrintWriter(new BufferedWriter(new OutputStreamWriter(o, Charsets.UTF_8)));
   }
  public void setArguments(final String[] argv) {
    this.argv = argv;
  }
   @Override
   public abstract void start(Environment env) throws IOException;
  /**
   * Parses the command line argument, injecting parsed values into fields.
   * <p>
   * This method must be explicitly invoked to cause a parse.
   *
   * @throws UnloggedFailure if the command line arguments were invalid.
   * @see Option
   * @see Argument
   */
  protected void parseCommandLine() throws UnloggedFailure {
    parseCommandLine(this);
  }
   public void setContext(SshCommandContext ctx) {
      this.ctx = ctx;
   }
  /**
   * Parses the command line argument, injecting parsed values into fields.
   * <p>
   * This method must be explicitly invoked to cause a parse.
   *
   * @param options object whose fields declare Option and Argument annotations
   *        to describe the parameters of the command. Usually {@code this}.
   * @throws UnloggedFailure if the command line arguments were invalid.
   * @see Option
   * @see Argument
   */
  protected void parseCommandLine(Object options) throws UnloggedFailure {
    final CmdLineParser clp = newCmdLineParser(options);
    try {
      clp.parseArgument(argv);
    } catch (IllegalArgumentException err) {
      if (!clp.wasHelpRequestedByOption()) {
        throw new UnloggedFailure(1, "fatal: " + err.getMessage());
      }
    } catch (CmdLineException err) {
      if (!clp.wasHelpRequestedByOption()) {
        throw new UnloggedFailure(1, "fatal: " + err.getMessage());
      }
    }
   @Override
   public void setInputStream(final InputStream in) {
      this.in = in;
   }
    if (clp.wasHelpRequestedByOption()) {
      StringWriter msg = new StringWriter();
      clp.printDetailedUsage(commandName, msg);
      msg.write(usage());
      throw new UnloggedFailure(1, msg.toString());
    }
  }
   @Override
   public void setOutputStream(final OutputStream out) {
      this.out = out;
   }
  /** Construct a new parser for this command's received command line. */
  protected CmdLineParser newCmdLineParser(Object options) {
    return new CmdLineParser(options);
  }
   @Override
   public void setErrorStream(final OutputStream err) {
      this.err = err;
   }
  protected String usage() {
    return "";
  }
   @Override
   public void setExitCallback(final ExitCallback callback) {
      this.exit = callback;
   }
  private final class TaskThunk implements CancelableRunnable {
    private final CommandRunnable thunk;
    private final String taskName;
   protected void provideBaseStateTo(final Command cmd) {
      if (cmd instanceof BaseCommand) {
         ((BaseCommand) cmd).setContext(ctx);
      }
      cmd.setInputStream(in);
      cmd.setOutputStream(out);
      cmd.setErrorStream(err);
      cmd.setExitCallback(exit);
   }
    private TaskThunk(final CommandRunnable thunk) {
      this.thunk = thunk;
   protected String getName() {
      return commandName;
   }
      StringBuilder m = new StringBuilder();
      m.append(ctx.getCommandLine());
      this.taskName = m.toString();
    }
   void setName(final String prefix) {
      this.commandName = prefix;
   }
    @Override
    public void cancel() {
      synchronized (this) {
        try {
          //onExit(/*STATUS_CANCEL*/);
        } finally {
          ctx = null;
        }
      }
    }
   public String[] getArguments() {
      return argv;
   }
    @Override
    public void run() {
      synchronized (this) {
        final Thread thisThread = Thread.currentThread();
        final String thisName = thisThread.getName();
        int rc = 0;
        try {
          thisThread.setName("SSH " + taskName);
          thunk.run();
   public void setArguments(final String[] argv) {
      this.argv = argv;
   }
          out.flush();
          err.flush();
        } catch (Throwable e) {
          try {
            out.flush();
          } catch (Throwable e2) {
          }
          try {
            err.flush();
          } catch (Throwable e2) {
          }
          rc = handleError(e);
        } finally {
          try {
            onExit(rc);
          } finally {
            thisThread.setName(thisName);
          }
        }
      }
    }
   /**
    * Parses the command line argument, injecting parsed values into fields.
    * <p>
    * This method must be explicitly invoked to cause a parse.
    *
    * @throws UnloggedFailure
    *             if the command line arguments were invalid.
    * @see Option
    * @see Argument
    */
   protected void parseCommandLine() throws UnloggedFailure {
      parseCommandLine(this);
   }
    @Override
    public String toString() {
      return taskName;
    }
  }
   /**
    * Parses the command line argument, injecting parsed values into fields.
    * <p>
    * This method must be explicitly invoked to cause a parse.
    *
    * @param options
    *            object whose fields declare Option and Argument annotations to
    *            describe the parameters of the command. Usually {@code this}.
    * @throws UnloggedFailure
    *             if the command line arguments were invalid.
    * @see Option
    * @see Argument
    */
   protected void parseCommandLine(Object options) throws UnloggedFailure {
      final CmdLineParser clp = newCmdLineParser(options);
      try {
         clp.parseArgument(argv);
      } catch (IllegalArgumentException err) {
         if (!clp.wasHelpRequestedByOption()) {
            throw new UnloggedFailure(1, "fatal: " + err.getMessage());
         }
      } catch (CmdLineException err) {
         if (!clp.wasHelpRequestedByOption()) {
            throw new UnloggedFailure(1, "fatal: " + err.getMessage());
         }
      }
  /** Runnable function which can throw an exception. */
  public static interface CommandRunnable {
    public void run() throws Exception;
  }
      if (clp.wasHelpRequestedByOption()) {
         StringWriter msg = new StringWriter();
         clp.printDetailedUsage(commandName, msg);
         msg.write(usage());
         throw new UnloggedFailure(1, msg.toString());
      }
   }
   /** Construct a new parser for this command's received command line. */
   protected CmdLineParser newCmdLineParser(Object options) {
      return new CmdLineParser(options);
   }
  /** Runnable function which can retrieve a project name related to the task */
  public static interface RepositoryCommandRunnable extends CommandRunnable {
   public String getRepository();
  }
   protected String usage() {
      return "";
   }
  /**
   * Spawn a function into its own thread.
   * <p>
   * Typically this should be invoked within {@link Command#start(Environment)},
   * such as:
   *
   * <pre>
   * startThread(new Runnable() {
   *   public void run() {
   *     runImp();
   *   }
   * });
   * </pre>
   *
   * @param thunk the runnable to execute on the thread, performing the
   *        command's logic.
   */
  protected void startThread(final Runnable thunk) {
    startThread(new CommandRunnable() {
      @Override
      public void run() throws Exception {
        thunk.run();
      }
    });
  }
   private final class TaskThunk implements CancelableRunnable {
      private final CommandRunnable thunk;
      private final String taskName;
  /**
   * Terminate this command and return a result code to the remote client.
   * <p>
   * Commands should invoke this at most once. Once invoked, the command may
   * lose access to request based resources as any callbacks previously
   * registered with {@link RequestCleanup} will fire.
   *
   * @param rc exit code for the remote client.
   */
  protected void onExit(final int rc) {
    exit.onExit(rc);
//    if (cleanup != null) {
//      cleanup.run();
//    }
  }
      private TaskThunk(final CommandRunnable thunk) {
         this.thunk = thunk;
  private int handleError(final Throwable e) {
    if ((e.getClass() == IOException.class
         && "Pipe closed".equals(e.getMessage()))
        || //
        (e.getClass() == SshException.class
         && "Already closed".equals(e.getMessage()))
        || //
        e.getClass() == InterruptedIOException.class) {
      // This is sshd telling us the client just dropped off while
      // we were waiting for a read or a write to complete. Either
      // way its not really a fatal error. Don't log it.
      //
      return 127;
    }
         StringBuilder m = new StringBuilder();
         m.append(ctx.getCommandLine());
         this.taskName = m.toString();
      }
    if (e instanceof UnloggedFailure) {
    } else {
      final StringBuilder m = new StringBuilder();
      m.append("Internal server error");
//      if (userProvider.get().isIdentifiedUser()) {
//        final IdentifiedUser u = (IdentifiedUser) userProvider.get();
//        m.append(" (user ");
//        m.append(u.getAccount().getUserName());
//        m.append(" account ");
//        m.append(u.getAccountId());
//        m.append(")");
//      }
//      m.append(" during ");
//      m.append(contextProvider.get().getCommandLine());
      log.error(m.toString(), e);
    }
      @Override
      public void cancel() {
         synchronized (this) {
            try {
               // onExit(/*STATUS_CANCEL*/);
            } finally {
               ctx = null;
            }
         }
      }
    if (e instanceof Failure) {
      final Failure f = (Failure) e;
      try {
        err.write((f.getMessage() + "\n").getBytes(Charsets.UTF_8));
        err.flush();
      } catch (IOException e2) {
      } catch (Throwable e2) {
        log.warn("Cannot send failure message to client", e2);
      }
      return f.exitCode;
      @Override
      public void run() {
         synchronized (this) {
            final Thread thisThread = Thread.currentThread();
            final String thisName = thisThread.getName();
            int rc = 0;
            try {
               thisThread.setName("SSH " + taskName);
               thunk.run();
    } else {
      try {
        err.write("fatal: internal server error\n".getBytes(Charsets.UTF_8));
        err.flush();
      } catch (IOException e2) {
      } catch (Throwable e2) {
        log.warn("Cannot send internal server error message to client", e2);
      }
      return 128;
    }
  }
               out.flush();
               err.flush();
            } catch (Throwable e) {
               try {
                  out.flush();
               } catch (Throwable e2) {
               }
               try {
                  err.flush();
               } catch (Throwable e2) {
               }
               rc = handleError(e);
            } finally {
               try {
                  onExit(rc);
               } finally {
                  thisThread.setName(thisName);
               }
            }
         }
      }
  /**
   * Spawn a function into its own thread.
   * <p>
   * Typically this should be invoked within {@link Command#start(Environment)},
   * such as:
   *
   * <pre>
   * startThread(new CommandRunnable() {
   *   public void run() throws Exception {
   *     runImp();
   *   }
   * });
   * </pre>
   * <p>
   * If the function throws an exception, it is translated to a simple message
   * for the client, a non-zero exit code, and the stack trace is logged.
   *
   * @param thunk the runnable to execute on the thread, performing the
   *        command's logic.
   */
  protected void startThread(final CommandRunnable thunk) {
    final TaskThunk tt = new TaskThunk(thunk);
      task.set(executor.submit(tt));
  }
      @Override
      public String toString() {
         return taskName;
      }
   }
  /** Thrown from {@link CommandRunnable#run()} with client message and code. */
  public static class Failure extends Exception {
    private static final long serialVersionUID = 1L;
   /** Runnable function which can throw an exception. */
   public static interface CommandRunnable {
      public void run() throws Exception;
   }
    final int exitCode;
   /** Runnable function which can retrieve a project name related to the task */
   public static interface RepositoryCommandRunnable extends CommandRunnable {
      public String getRepository();
   }
    /**
     * Create a new failure.
     *
     * @param exitCode exit code to return the client, which indicates the
     *        failure status of this command. Should be between 1 and 255,
     *        inclusive.
     * @param msg message to also send to the client's stderr.
     */
    public Failure(final int exitCode, final String msg) {
      this(exitCode, msg, null);
    }
   /**
    * Spawn a function into its own thread.
    * <p>
    * Typically this should be invoked within
    * {@link Command#start(Environment)}, such as:
    *
    * <pre>
    * startThread(new Runnable() {
    *    public void run() {
    *       runImp();
    *    }
    * });
    * </pre>
    *
    * @param thunk
    *            the runnable to execute on the thread, performing the
    *            command's logic.
    */
   protected void startThread(final Runnable thunk) {
      startThread(new CommandRunnable() {
         @Override
         public void run() throws Exception {
            thunk.run();
         }
      });
   }
    /**
     * Create a new failure.
     *
     * @param exitCode exit code to return the client, which indicates the
     *        failure status of this command. Should be between 1 and 255,
     *        inclusive.
     * @param msg message to also send to the client's stderr.
     * @param why stack trace to include in the server's log, but is not sent to
     *        the client's stderr.
     */
    public Failure(final int exitCode, final String msg, final Throwable why) {
      super(msg, why);
      this.exitCode = exitCode;
    }
  }
   /**
    * Terminate this command and return a result code to the remote client.
    * <p>
    * Commands should invoke this at most once. Once invoked, the command may
    * lose access to request based resources as any callbacks previously
    * registered with {@link RequestCleanup} will fire.
    *
    * @param rc
    *            exit code for the remote client.
    */
   protected void onExit(final int rc) {
      exit.onExit(rc);
      // if (cleanup != null) {
      // cleanup.run();
      // }
   }
  /** Thrown from {@link CommandRunnable#run()} with client message and code. */
  public static class UnloggedFailure extends Failure {
    private static final long serialVersionUID = 1L;
   private int handleError(final Throwable e) {
      if ((e.getClass() == IOException.class && "Pipe closed".equals(e.getMessage())) || //
            (e.getClass() == SshException.class && "Already closed".equals(e.getMessage())) || //
            e.getClass() == InterruptedIOException.class) {
         // This is sshd telling us the client just dropped off while
         // we were waiting for a read or a write to complete. Either
         // way its not really a fatal error. Don't log it.
         //
         return 127;
      }
    /**
     * Create a new failure.
     *
     * @param msg message to also send to the client's stderr.
     */
    public UnloggedFailure(final String msg) {
      this(1, msg);
    }
      if (e instanceof UnloggedFailure) {
      } else {
         final StringBuilder m = new StringBuilder();
         m.append("Internal server error");
         // if (userProvider.get().isIdentifiedUser()) {
         // final IdentifiedUser u = (IdentifiedUser) userProvider.get();
         // m.append(" (user ");
         // m.append(u.getAccount().getUserName());
         // m.append(" account ");
         // m.append(u.getAccountId());
         // m.append(")");
         // }
         // m.append(" during ");
         // m.append(contextProvider.get().getCommandLine());
         log.error(m.toString(), e);
      }
    /**
     * Create a new failure.
     *
     * @param exitCode exit code to return the client, which indicates the
     *        failure status of this command. Should be between 1 and 255,
     *        inclusive.
     * @param msg message to also send to the client's stderr.
     */
    public UnloggedFailure(final int exitCode, final String msg) {
      this(exitCode, msg, null);
    }
      if (e instanceof Failure) {
         final Failure f = (Failure) e;
         try {
            err.write((f.getMessage() + "\n").getBytes(Charsets.UTF_8));
            err.flush();
         } catch (IOException e2) {
         } catch (Throwable e2) {
            log.warn("Cannot send failure message to client", e2);
         }
         return f.exitCode;
    /**
     * Create a new failure.
     *
     * @param exitCode exit code to return the client, which indicates the
     *        failure status of this command. Should be between 1 and 255,
     *        inclusive.
     * @param msg message to also send to the client's stderr.
     * @param why stack trace to include in the server's log, but is not sent to
     *        the client's stderr.
     */
    public UnloggedFailure(final int exitCode, final String msg,
        final Throwable why) {
      super(exitCode, msg, why);
    }
  }
      } else {
         try {
            err.write("fatal: internal server error\n".getBytes(Charsets.UTF_8));
            err.flush();
         } catch (IOException e2) {
         } catch (Throwable e2) {
            log.warn("Cannot send internal server error message to client", e2);
         }
         return 128;
      }
   }
   /**
    * Spawn a function into its own thread.
    * <p>
    * Typically this should be invoked within
    * {@link Command#start(Environment)}, such as:
    *
    * <pre>
    * startThread(new CommandRunnable() {
    *    public void run() throws Exception {
    *       runImp();
    *    }
    * });
    * </pre>
    * <p>
    * If the function throws an exception, it is translated to a simple message
    * for the client, a non-zero exit code, and the stack trace is logged.
    *
    * @param thunk
    *            the runnable to execute on the thread, performing the
    *            command's logic.
    */
   protected void startThread(final CommandRunnable thunk) {
      final TaskThunk tt = new TaskThunk(thunk);
      task.set(executor.submit(tt));
   }
   /** Thrown from {@link CommandRunnable#run()} with client message and code. */
   public static class Failure extends Exception {
      private static final long serialVersionUID = 1L;
      final int exitCode;
      /**
       * Create a new failure.
       *
       * @param exitCode
       *            exit code to return the client, which indicates the
       *            failure status of this command. Should be between 1 and
       *            255, inclusive.
       * @param msg
       *            message to also send to the client's stderr.
       */
      public Failure(final int exitCode, final String msg) {
         this(exitCode, msg, null);
      }
      /**
       * Create a new failure.
       *
       * @param exitCode
       *            exit code to return the client, which indicates the
       *            failure status of this command. Should be between 1 and
       *            255, inclusive.
       * @param msg
       *            message to also send to the client's stderr.
       * @param why
       *            stack trace to include in the server's log, but is not
       *            sent to the client's stderr.
       */
      public Failure(final int exitCode, final String msg, final Throwable why) {
         super(msg, why);
         this.exitCode = exitCode;
      }
   }
   /** Thrown from {@link CommandRunnable#run()} with client message and code. */
   public static class UnloggedFailure extends Failure {
      private static final long serialVersionUID = 1L;
      /**
       * Create a new failure.
       *
       * @param msg
       *            message to also send to the client's stderr.
       */
      public UnloggedFailure(final String msg) {
         this(1, msg);
      }
      /**
       * Create a new failure.
       *
       * @param exitCode
       *            exit code to return the client, which indicates the
       *            failure status of this command. Should be between 1 and
       *            255, inclusive.
       * @param msg
       *            message to also send to the client's stderr.
       */
      public UnloggedFailure(final int exitCode, final String msg) {
         this(exitCode, msg, null);
      }
      /**
       * Create a new failure.
       *
       * @param exitCode
       *            exit code to return the client, which indicates the
       *            failure status of this command. Should be between 1 and
       *            255, inclusive.
       * @param msg
       *            message to also send to the client's stderr.
       * @param why
       *            stack trace to include in the server's log, but is not
       *            sent to the client's stderr.
       */
      public UnloggedFailure(final int exitCode, final String msg, final Throwable why) {
         super(exitCode, msg, why);
      }
   }
}