diff --git a/src/actions/interactive_session.py b/src/actions/interactive_session.py index 6f652bdfb3781547b111bfcee5954a34a7354e7f..1110a4482d34130c7d799ae8916442b792a29825 100644 --- a/src/actions/interactive_session.py +++ b/src/actions/interactive_session.py @@ -16,10 +16,11 @@ from aiogram.types import ( from src.models.crud.server_crud import fn_get_server from src.models.server import ServerPermissions from src.settings import settings +from src.utils import chunks from src.utils.debug_mode import debug_message from src.utils.decorators import bot_action from src.utils.server_utils import get_server_by_alias -from src.utils.ssh import session_manager, run_ssh_command +from src.utils.ssh import session_manager logger = logging.getLogger() @@ -108,4 +109,10 @@ async def run_interactive_command(message: Message, state: FSMContext): f'Session id: {current_state["session_id"]}\n' ) result = await session_manager.run_command(current_state['session_id'], message.text) - await message.reply(f'```\n{result if result else "Nothing to show"}```', parse_mode=ParseMode.MARKDOWN) + if not result: + await message.reply(f'```\nNothing to show```', parse_mode=ParseMode.MARKDOWN) + return + + results = chunks(result, 4095) + for res in results: + await message.reply(f'```\n{res}```', parse_mode=ParseMode.MARKDOWN) diff --git a/src/utils/ssh.py b/src/utils/ssh.py index 3679296862ed84e49d4cfa51041c43f94e69fc3e..ff8a2c9107fc0c813b5bfeb43bc0ac63bb4ac42c 100644 --- a/src/utils/ssh.py +++ b/src/utils/ssh.py @@ -1,82 +1,42 @@ +import asyncio import logging -import re import uuid -import paramiko -from paramiko import SSHClient, ChannelFile +import asyncssh from src.models import Server logger = logging.getLogger() -async def open_ssh_session(server: Server) -> (SSHClient, ChannelFile, ChannelFile): - client = SSHClient() - client.load_system_host_keys() - client.set_missing_host_key_policy(paramiko.client.AutoAddPolicy) - client.connect( - hostname=server.server_address, - port=server.server_port, - username='root' - ) - channel = client.invoke_shell() - stdin = channel.makefile('wb') - stdout = channel.makefile('r') - - return client, stdin, stdout - - -async def run_interactive_command(command: str, client: SSHClient, std_in: ChannelFile, std_out: ChannelFile) -> str: - cmd = command.strip('\n') - std_in.write(cmd + '\n') - finish = 'EXIT STATUS: ' - echo_cmd = 'echo {} $?'.format(finish) - std_in.write(echo_cmd + '\n') - std_in.flush() - - sh_out = [] - sh_err = [] - for line in std_out: - if str(line).startswith(cmd) or str(line).startswith(echo_cmd): - # up for now filled with shell junk from stdin - shout = [] - elif str(line).startswith(finish): - # our finish command ends with the exit status - exit_status = int(str(line).rsplit(maxsplit=1)[1]) - if exit_status: - # stderr is combined with stdout. - # thus, swap sherr with shout in a case of failure. - sh_err = sh_out - sh_out = [] - break - else: - # get rid of 'coloring and formatting' special characters - sh_out.append(re.compile(r'(\x9B|\x1B\[)[0-?]*[ -/]*[@-~]').sub('', line). - replace('\b', '').replace('\r', '')) - - # first and last lines of shout/sherr contain a prompt - if sh_out and echo_cmd in sh_out[-1]: - sh_out.pop() - if sh_out and cmd in sh_out[0]: - sh_out.pop(0) - if sh_err and echo_cmd in sh_err[-1]: - sh_err.pop() - if sh_err and cmd in sh_err[0]: - sh_err.pop(0) - if sh_err: - return '\n'.join(sh_err) - return '\n'.join(sh_out) +async def open_ssh_session(server: Server) -> asyncssh.SSHClientProcess: + connection = await asyncssh.connect(server.server_address, server.server_port) + process = await connection.create_process('/bin/bash') + return process + + +async def run_interactive_command(command: str, + process: asyncssh.SSHClientProcess, + timeout=0.5) -> str: + process.stdin.write(command + '\n') + res = [] + try: + line = await asyncio.wait_for(process.stdout.readline(), timeout) + res.append(line) + while line: + logger.debug(line) + res.append(await asyncio.wait_for(process.stdout.readline(), timeout)) + except asyncio.exceptions.TimeoutError as e: + logger.exception(e) + return '\n'.join(res).strip() + return '\n'.join(res).strip() async def run_ssh_command(server: Server, command: str) -> str: - session = await open_ssh_session(server) - client = session[0] - stdin, stdout, stderr = client.exec_command(command) - err = stderr.read() - if err: - raise Exception(err.decode('utf-8')) - out = stdout.read() - return out.decode('utf-8') + process = await open_ssh_session(server) + res = await run_interactive_command(command, process) + process.close() + return res class SessionManager(object): @@ -90,10 +50,10 @@ class SessionManager(object): return rand_uuid async def run_command(self, connection_id: str, command: str): - return await run_interactive_command(command, *self.__connections[connection_id]) + return await run_interactive_command(command, self.__connections[connection_id]) def close(self, connection_id: str): - self.__connections[connection_id][0].close() + self.__connections[connection_id].close() session_manager = SessionManager()