diff --git a/evering/__main__.py b/evering/__main__.py index 7f9bc8e..9044ccf 100644 --- a/evering/__main__.py +++ b/evering/__main__.py @@ -1,48 +1,90 @@ import argparse -from pathlib import Path -from typing import Union import logging +from pathlib import Path +from typing import Any +from .config import * +from .known_files import * +from .process import * +from .util import * +from .explore import * +from .prompt import * -from .colors import * - -logging.basicConfig(level=logging.DEBUG, style="{", format="{levelname:>7}: {message}") +#logging.basicConfig(level=logging.DEBUG, style="{", format="{levelname:>7}: {message}") +logging.basicConfig(level=logging.INFO, style="{", format="{levelname:>7}: {message}") logger = logging.getLogger(__name__) +HEADER_FILE_SUFFIX = ".evering-header" + +""" +(error) -> CatastrophicError +(warning) -> log message +(skip/abort) -> LessCatastrophicError + +- Load config + - no readable config file found (error) + - config file can't be found (error) + - config file can't be opened (error) + - config file contains invalid syntax (error) + +- Load known files + - known_files can't be read (error) + - known_files contains invalid syntax (error) + - known_files contains invalid data (error) +- Locate config files + header files + - missing permissions to view folders (warning) + - header file but no corresponding file (warning) -def command_test_func(args): - logger.debug(styled("Debug", BLUE.fg, BOLD)) - logger.info(styled("Info", GREEN.fg, BOLD)) - logger.warning(styled("Warning", YELLOW.fg, BOLD)) - logger.error(styled("Error", RED.fg, BOLD)) - logger.info(styled("Test", BRIGHT_BLACK.fg, BOLD)) +- Process files +Processing files +================ +Header problems: +- header file can't be read (skip/abort) +- invalid header syntax (skip/abort) +Config file problems: +- file can't be read (skip/abort) +- file contains no lines (warning) +- invalid config file syntax (skip/abort) +- error while compiling (skip/abort) +Writing problems: +- no targets (skip/abort) +- can't write/copy to target (warning) +- can't write to known files (error) +""" +def run(args: Any) -> None: + config = Config.load_config_file(args.config_file and Path(args.config_file) or None) + known_files = KnownFiles(config.known_files) + processor = Processor(config, known_files) + config_files = find_config_files(config.config_dir) + for file_info in config_files: + try: + processor.process_file(file_info.path, file_info.header) + except LessCatastrophicError as e: + logger.error(e) + if prompt_choice("[C]ontinue to the next file or [A]bort the program?", "Ca") == "a": + raise CatastrophicError("Aborted") - - -def main(): +def main() -> None: parser = argparse.ArgumentParser() parser.add_argument("-c", "--config-file") - subparsers = parser.add_subparsers(title="commands") - - command_test = subparsers.add_parser("test") - command_test.set_defaults(func=command_test_func) - command_test.add_argument("some_file") args = parser.parse_args() - if "func" in args: - args.func(args) - else: - parser.print_help() + try: + run(args) + except CatastrophicError as e: + logger.error(e) + except ConfigurationException as e: + logger.error(e) if __name__ == "__main__": main() diff --git a/evering/colors.py b/evering/colors.py index 12bf028..4cba858 100644 --- a/evering/colors.py +++ b/evering/colors.py @@ -3,8 +3,9 @@ This module includes functions to color the console output with ANSI escape sequences. """ -from typing import Optional, Tuple from dataclasses import dataclass +from pathlib import Path +from typing import Optional, Tuple, Union __all__ = [ "CSI", "ERASE_LINE", @@ -12,6 +13,7 @@ __all__ = [ "Color", "BLACK", "RED", "GREEN", "YELLOW", "BLUE", "MAGENTA", "CYAN", "WHITE", "BRIGHT_BLACK", "BRIGHT_RED", "BRIGHT_GREEN", "BRIGHT_YELLOW", "BRIGHT_BLUE", "BRIGHT_MAGENTA", "BRIGHT_CYAN", "BRIGHT_WHITE", "style_sequence", "styled", + "style_path", "style_var", "style_error", "style_warning", ] # ANSI escape sequences @@ -61,3 +63,17 @@ def styled(text: str, *args: int) -> str: return f"{sequence}{text}{reset}" else: return text # No styling necessary + +def style_path(path: Union[str, Path]) -> str: + if isinstance(path, Path): + path = str(path) + return styled(path, BRIGHT_BLACK.fg, BOLD) + +def style_var(text: str) -> str: + return styled(repr(text), BLUE.fg) + +def style_error(text: str) -> str: + return styled(text, RED.fg, BOLD) + +def style_warning(text: str) -> str: + return styled(text, YELLOW.fg, BOLD) diff --git a/evering/config.py b/evering/config.py index c575e66..78cbd26 100644 --- a/evering/config.py +++ b/evering/config.py @@ -6,38 +6,292 @@ The result of loading a config file are the "local" variables, including the modules loaded via "import". """ +import logging from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Union -__all__ = [] +from .colors import * +from .util import * + +__all__ = [ + "DEFAULT_LOCATIONS", "DEFAULT_CONFIG_FILE", + "ConfigurationException", "Config", +] +logger = logging.getLogger(__name__) DEFAULT_LOCATIONS = [ - Path("~/.config/evering/config"), - Path("~/.evering/config"), - Path("~/.evering.conf"), + Path("~/.config/evering/config.py"), + Path("~/.evering/config.py"), + Path("~/.evering.py"), ] -class LoadConfigException(Exception): +DEFAULT_CONFIG_FILE = """ +known_files = "known_files" +config_dir = "config/" +binary = True +statement_prefix = "#" +expression_delimiters = ("{{", "}}") +""" + +class ConfigurationException(Exception): pass -def load_config(path: Path = None) -> Dict[str, Any]: - if path is not None: - return load_config_file(path) - else: - for path in DEFAULT_LOCATIONS: - try: - return load_config_file(path) - except LoadConfigException: - # Try the next default location - # TODO print a log message - pass - else: - raise LoadConfigException("no config file found in any of the default locations") +class Config: + @classmethod + def load_config_file(cls, path: Optional[Path]) -> "Config": + """ + May raise: ConfigurationException + """ -def load_config_file(path: Path) -> Dict[str, Any]: - try: - with open(path) as f: - l = {} - exec(f.read(), locals=l, globals={}) - return l - except IOException as e: - raise LoadConfigException(str(e)) + local_vars: Dict[str, Any] + + if path is None: + # Try out all default config file locations + for path in DEFAULT_LOCATIONS: + try: + local_vars = cls._load_config_file(path) + break + except (ReadFileException, ExecuteException) as e: + logger.debug(f"Could not load config from {style_path(path)}: {e}") + else: + raise ConfigurationException(style_error( + "No valid config file found in any of the default locations")) + else: + # Use the path + try: + local_vars = cls._load_config_file(path) + except (ReadFileException, ExecuteException) as e: + raise ConfigurationException( + style_error("Could not load config file from ") + + style_path(path) + f": {e}") + + return cls(local_vars) + + @staticmethod + def _load_config_file(path: Path) -> Dict[str, Any]: + """ + May raise: ReadFileException, ExecuteException + """ + + local_vars: Dict[str, Any] = {} + safer_exec(DEFAULT_CONFIG_FILE, local_vars) + + safer_exec(read_file(path), local_vars) + if not "base_dir" in local_vars: + local_vars["base_dir"] = path.parent + + logger.info(f"Loaded config from {style_path(str(path))}") + + return local_vars + + def __init__(self, local_vars: Dict[str, Any]) -> None: + """ + May raise: ConfigurationException + """ + + self.local_vars = local_vars + + def copy(self) -> "Config": + return Config(copy_local_variables(self.local_vars)) + + def _get(self, name: str, *types: type) -> Any: + """ + May raise: ConfigurationException + """ + + if not name in self.local_vars: + raise ConfigurationException( + style_error(f"Expected a variable named ") + + style_var(name)) + + value = self.local_vars[name] + + if types: + if not any(isinstance(value, t) for t in types): + raise ConfigurationException( + style_error("Expexted variable ") + style_var(name) + + style_error(" to have one of the following types:\n" + + ", ".join(t.__name__ for t in types)) + ) + + return value + + def _get_optional(self, name: str, *types: type) -> Optional[Any]: + if not name in self.local_vars: + return None + else: + return self._get(name, *types) + + def _set(self, name: str, value: Any) -> None: + self.local_vars[name] = value + + @staticmethod + def _is_pathy(elem: Any) -> bool: + return isinstance(elem, str) or isinstance(elem, Path) + + # Attributes begin here + + # Locations and paths + + @property + def base_dir(self) -> Path: + """ + The path that is the base of all other relative paths. + + Default: The directory the config file was loaded from. + """ + + return Path(self._get("base_dir", str, Path)).expanduser() + + @base_dir.setter + def base_dir(self, path: Path) -> None: + self._set("base_dir", path) + + def _interpret_path(self, path: Union[str, Path]) -> Path: + path = Path(path).expanduser() + if path.is_absolute(): + logger.debug(style_path(path) + " is absolute, no interpreting required") + return path + else: + logger.debug(style_path(path) + " is relative, interpreting as " + style_path(self.base_dir / path)) + return self.base_dir / path + + @property + def known_files(self) -> Path: + """ + The path where evering stores which files it is currently + managing. + + Default: "known_files" + """ + + return self._interpret_path(self._get("known_files", str, Path)) + + @property + def config_dir(self) -> Path: + """ + The directory containing the config files. + + Default: "config/" + """ + + return self._interpret_path(self._get("config_dir", str, Path)) + + # Parsing and compiling behavior + + @property + def binary(self) -> bool: + """ + When interpreting a separate header file: Whether the + corresponding file should not be parsed and compiled, but + instead just copied to the targets. + + Has no effect if the file has no header files. + + Default: True + """ + + return self._get("binary", bool) + + @property + def targets(self) -> List[Path]: + """ + The locations the (compiled) config file should be put + in. Must be set for all files. + + Default: not set + """ + + name = "targets" + target = self._get(name) + is_path = self._is_pathy(target) + is_list_of_paths = (isinstance(target, list) and + all(self._is_pathy(elem) for elem in target)) + + if not is_path and not is_list_of_paths: + raise ConfigurationException( + style_error("Expected variable ") + style_var(name) + + style_error(" to be either a path or a list of paths")) + + if is_path: + return [self._interpret_path(target)] + else: + return [self._interpret_path(elem) for elem in target] + + @property + def statement_prefix(self) -> str: + """ + This determines the prefix for statements like "# if", + "# elif", "# else" or "# endif". The prefix always has at + least length 1. + + Default: "#" + """ + + name = "statement_prefix" + prefix = self._get(name, str) + + if len(prefix) < 1: + raise ConfigurationException( + style_error("Expected variable ") + style_var(name) + + style_error(" to have at least length 1")) + + return prefix + + @property + def expression_delimiters(self) -> Tuple[str, str]: + """ + This determines the delimiters for expressions like + "{{ 1 + 1 }}". + + It is a tuple of the form: (, ), where both + the prefix and suffix are strings of at least length 1. + + Default: ("{{", "}}") + """ + + name = "expression_delimiters" + delimiters = self._get(name, tuple) + + if len(delimiters) != 2: + raise ConfigurationException( + style_error("Expected variable ") + style_var(name) + + style_error(" to be a tuple of length 2")) + + if len(delimiters[0]) < 1 or len(delimiters[1]) < 1: + raise ConfigurationException( + style_error("Expected both strings in variable ") + + style_var(name) + style_error( "to be of length >= 1")) + + return delimiters + + # Environment and file-specific information + + @property + def filename(self) -> str: + """ + The name of the file currently being compiled, as a string. + + Only set during compilation. + """ + + return self._get("filename", str) + + @filename.setter + def filename(self, filename: str) -> None: + self._set("filename", filename) + + @property + def target(self) -> Path: + """ + The location the file is currently being compiled for, as a + Path. + + Only set during compilation. + """ + + return self._interpret_path(self._get("target", str, Path)) + + @target.setter + def target(self, path: Path) -> None: + self._set("target", path) diff --git a/evering/explore.py b/evering/explore.py new file mode 100644 index 0000000..3a6b434 --- /dev/null +++ b/evering/explore.py @@ -0,0 +1,69 @@ +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, List, Optional +import logging + +from .util import * +from .colors import * + +__all__ = ["FileInfo", "find_config_files"] +logger = logging.getLogger(__name__) + +HEADER_FILE_SUFFIX = ".evering-header" + +@dataclass +class FileInfo: + path: Path + header: Optional[Path] = None + +def find_config_files(config_dir: Path) -> List[FileInfo]: + try: + return explore_dir(config_dir) + except OSError as e: + raise CatastrophicError(style_error("could not access config dir ") + style_path(config_dir) + f": {e}") + +def explore_dir(cur_dir: Path) -> List[FileInfo]: + if not cur_dir.is_dir(): + raise CatastrophicError(style_path(cur_dir) + style_error(" is not a directory")) + + files: Dict[Path, FileInfo] = {} + header_files: List[Path] = [] + subdirs: List[Path] = [] + + # 1. Sort all the files in this folder into their respective categories + for element in cur_dir.iterdir(): + if element.is_dir(): + logger.debug(f"Found subdir {style_path(element)}") + subdirs.append(element) + elif element.is_file(): + if element.suffix == HEADER_FILE_SUFFIX: + logger.debug(f"Found header file {style_path(element)}") + header_files.append(element) + else: + logger.debug(f"Found file {style_path(element)}") + files[element] = FileInfo(element) + else: + logger.debug(f"{style_path(element)} is neither a dir nor a file") + + # 2. Assign the header files to their respective files + for header_file in header_files: + matching_file = header_file.with_suffix("") # Remove last suffix + matching_file_info = files.get(matching_file) + + if matching_file_info is None: + logger.warning(style_warning("No corresponding file for header file ") + style_path(header_file)) + else: + logger.debug(f"Assigned header file {style_path(header_file)} to file {style_path(matching_file)}") + matching_file_info.header = header_file + + # 3. Collect the resulting FileInfos + result = list(files.values()) + + # 4. And (try to) recursively descend into all folders + for subdir in subdirs: + try: + result.extend(explore_dir(subdir)) + except OSError as e: + logger.warning(style_warning("Could not descend into folder ") + style_path(subdir) + f": {e}") + + return result diff --git a/evering/known_files.py b/evering/known_files.py new file mode 100644 index 0000000..173c855 --- /dev/null +++ b/evering/known_files.py @@ -0,0 +1,79 @@ +import json +import logging +from pathlib import Path +from typing import Dict, List, Set + +from .colors import * +from .util import * + +__all__ = ["KnownFiles"] +logger = logging.getLogger(__name__) + +class KnownFiles: + def __init__(self, path: Path) -> None: + self._path = path + self._old_known_files: Dict[Path, str] = {} + self._new_known_files: Dict[Path, str] = {} + + try: + with open(self._path) as f: + self._old_known_files = self._read_known_files(f.read()) + except FileNotFoundError as e: + logger.debug(f"File {style_path(self._path)} does not exist, " + "creating a new file on the first upcoming save") + + def _read_known_files(self, text: str) -> Dict[Path, str]: + known_files: Dict[Path, str] = {} + raw_known_files = json.loads(text) + + if not isinstance(raw_known_files, dict): + raise CatastrophicError(style_error("Root level structure is not a dictionary")) + + for path, file_hash in raw_known_files.items(): + if not isinstance(path, str): + raise CatastrophicError(style_error(f"Path {path!r} is not a string")) + if not isinstance(file_hash, str): + raise CatastrophicError(style_error(f"Hash {hash!r} at path {path!r} is not a string")) + + path = Path(path).expanduser().resolve() # normalized + known_files[path] = file_hash + + return known_files + + def update_file(self, path: Path, file_hash: str) -> None: + self._new_known_files[path.expanduser().resolve()] = file_hash + + def save_incremental(self) -> None: + to_save: Dict[str, str] = {} + for path in self._old_known_files.keys() | self._new_known_files.keys(): + if path in self._new_known_files: + to_save[str(path)] = self._new_known_files[path] + else: + to_save[str(path)] = self._old_known_files[path] + + self._save(json.dumps(to_save)) + logger.debug(f"Incremental save to {style_path(self._path)} completed") + + def find_lost_files(self) -> Set[Path]: + return set(self._old_known_files.keys() - self._new_known_files.keys()) + + def save_final(self) -> None: + to_save: Dict[str, str] = {} + + for path, file_hash in self._new_known_files.items(): + to_save[str(path)] = file_hash + + self._save(json.dumps(to_save)) + logger.debug(f"Final save to {style_path(self._path)} completed") + + def _save(self, text: str) -> None: + # Append a .tmp to the file name + path = Path(*self._path.parts[:-1], self._path.name + ".tmp") + + try: + write_file(path, text) + path.replace(self._path) # Assumed to be atomic + except (WriteFileException, OSError) as e: + raise CatastrophicError( + style_error("Error saving known files to ") + + style_path(path) + f": {e}") diff --git a/evering/parser.py b/evering/parser.py index b10b5f1..89296c9 100644 --- a/evering/parser.py +++ b/evering/parser.py @@ -13,12 +13,10 @@ This parsing solution has the following structure: 5. Evaluate the blocks recursively """ -__all__ = ["ParseException", "Parser"] - -class ParseException(Exception): - @classmethod - def on_line(cls, line: "Line", text: str) -> "ParseException": - return ParseException(f"Line {line.line_number}: {text}") +__all__ = [ + "split_header_and_rest", + "ParseException", "Parser", +] def split_header_and_rest(text: str) -> Tuple[List[str], List[str]]: lines = text.splitlines() @@ -40,20 +38,25 @@ def split_header_and_rest(text: str) -> Tuple[List[str], List[str]]: return header, rest +class ParseException(Exception): + @classmethod + def on_line(cls, line: "Line", text: str) -> "ParseException": + return ParseException(f"Line {line.line_number}: {text}") + class Parser: def __init__(self, raw_lines: List[str], - statement_initiator: str, - expression_opening_delimiter: str, - expression_closing_delimiter: str, + statement_prefix: str, + expression_prefix: str, + expression_suffix: str, ) -> None: """ May raise: ParseException """ - self.statement_initiator = statement_initiator - self.expression_opening_delimiter = expression_opening_delimiter - self.expression_closing_delimiter = expression_closing_delimiter + self.statement_prefix = statement_prefix + self.expression_prefix = expression_prefix + self.expression_suffix = expression_suffix # Split up the text into lines and parse those lines: List[Line] = [] @@ -89,7 +92,7 @@ class Line(ABC): pass try: - return EndStatement(parser, text, line_number) + return EndifStatement(parser, text, line_number) except ParseException: pass @@ -100,7 +103,7 @@ class Line(ABC): self.line_number = line_number def _parse_statement(self, text: str, statement_name: str) -> Optional[str]: - start = f"{self.parser.statement_initiator} {statement_name}" + start = f"{self.parser.statement_prefix} {statement_name}" text = text.strip() if text.startswith(start): return text[len(start):].strip() @@ -108,7 +111,7 @@ class Line(ABC): return None def _parse_statement_noarg(self, text: str, statement_name: str) -> bool: - return text.strip() == f"{self.parser.statement_initiator} {statement_name}" + return text.strip() == f"{self.parser.statement_prefix} {statement_name}" class ActualLine(Line): def __init__(self, parser: Parser, text: str, line_number: int) -> None: @@ -136,18 +139,18 @@ class ActualLine(Line): i = 0 while i < len(text): - # Find opening delimiter - od = text.find(self.parser.expression_opening_delimiter, i) + # Find expression prefix + od = text.find(self.parser.expression_prefix, i) if od == -1: chunks.append((text[i:], False)) break # We've consumed the entire string. - od_end = od + len(self.parser.expression_opening_delimiter) + od_end = od + len(self.parser.expression_prefix) - # Find closing delimiter - cd = text.find(self.parser.expression_closing_delimiter, od_end) + # Find expression suffix + cd = text.find(self.parser.expression_suffix, od_end) if cd == -1: - raise ParseException.on_line(self, f"No closing delimiter\n{text[:od_end]} <-- to THIS opening delimiter") - cd_end = cd + len(self.parser.expression_closing_delimiter) + raise ParseException.on_line(self, f"No matching expression suffix\n{text[:od_end]} <-- to THIS expression prefix") + cd_end = cd + len(self.parser.expression_suffix) # Split up into chunks chunks.append((text[i:od], False)) @@ -211,7 +214,7 @@ class ElseStatement(Line): if not self._parse_statement_noarg(text, "else"): raise ParseException.on_line(self, "Not an 'else' statement") -class EndStatement(Line): +class EndifStatement(Line): def __init__(self, parser: Parser, text: str, line_number: int) -> None: """ May raise: ParseException @@ -219,8 +222,8 @@ class EndStatement(Line): super().__init__(parser, line_number) - if not self._parse_statement_noarg(text, "end"): - raise ParseException.on_line(self, "Not an 'end' statement") + if not self._parse_statement_noarg(text, "endif"): + raise ParseException.on_line(self, "Not an 'endif' statement") # Block parsing @@ -302,7 +305,7 @@ class IfBlock(Block): if not lines_queue: raise ParseException("Unexpected end of file, expected 'if' statement") - if not isinstance(lines_queue[-1], EndStatement): + if not isinstance(lines_queue[-1], EndifStatement): raise ParseException.on_line(lines_queue[-1], "Expected 'end' statement") lines_queue.pop() diff --git a/evering/process.py b/evering/process.py new file mode 100644 index 0000000..52b77cd --- /dev/null +++ b/evering/process.py @@ -0,0 +1,119 @@ +import logging +import shutil +from pathlib import Path +from typing import List, Optional + +from .colors import * +from .config import * +from .known_files import * +from .parser import * +from .util import * + +__all__ = ["Processor"] +logger = logging.getLogger(__name__) + +class Processor: + def __init__(self, config: Config, known_files: KnownFiles) -> None: + self.config = config + self.known_files = known_files + + def process_file(self, path: Path, header_path: Optional[Path] = None) -> None: + logger.info(f"{style_path(path)}:") + + config = self.config.copy() + config.filename = path.name + + if header_path is None: + self._process_file_without_header(path, config) + else: + self._process_file_with_header(path, header_path, config) + + def _process_file_without_header(self, path: Path, config: Config) -> None: + logger.debug(f"Processing file {style_path(path)} with no header") + + try: + text = read_file(path) + except ReadFileException as e: + raise LessCatastrophicError( + style_error("Could not load file ") + + style_path(path) + f": {e}") + + header, lines = split_header_and_rest(text) + + try: + safer_exec("\n".join(header), config.local_vars) + except ExecuteException as e: + raise LessCatastrophicError( + style_error("Could not parse header of file ") + + style_path(path) + f": {e}") + + self._process_parseable(lines, config) + + def _process_file_with_header(self, path: Path, header_path: Path, config: Config) -> None: + logger.debug(f"Processing file {style_path(path)} " + f"with header {style_path(header_path)}") + + try: + header_text = read_file(header_path) + safer_exec(header_text, config.local_vars) + except ReadFileException as e: + raise LessCatastrophicError( + style_error("Could not load header file ") + + style_path(header_path) + f": {e}") + except ExecuteException as e: + raise LessCatastrophicError( + style_error("Could not parse header file ") + + style_path(header_path) + f": {e}") + + if config.binary: + self._process_binary(path, config) + else: + try: + lines = read_file(path).splitlines() + except ReadFileException as e: + raise LessCatastrophicError( + style_error("Could not load file ") + + style_path(path) + f": {e}") + + self._process_parseable(lines, config) + + def _process_binary(self, path: Path, config: Config) -> None: + logger.debug(f"Processing as a binary file") + + for target in config.targets: + logger.info(f" -> {style_path(str(target))}") + + try: + shutil.copy(path, target) + except (IOError, shutil.SameFileError) as e: + logger.warning(style_warning("Could not copy") + f": {e}") + + def _process_parseable(self, lines: List[str], config: Config) -> None: + for target in config.targets: + logger.info(f" -> {style_path(str(target))}") + + config_copy = config.copy() + config_copy.target = target + + try: + parser = Parser( + lines, + statement_prefix=config.statement_prefix, + expression_prefix=config.expression_delimiters[0], + expression_suffix=config.expression_delimiters[1], + ) + text = parser.evaluate(config_copy.local_vars) + except ParseException as e: + logger.warning(style_warning("Could not parse ") + + style_path(target) + f": {e}") + continue + except ExecuteException as e: + logger.warning(style_warning("Could not compile ") + + style_path(target) + f": {e}") + continue + + try: + write_file(target, text) + except WriteFileException as e: + logger.warning(style_warning("Could not write to ") + style_path(str(target)) + + f": {e}") diff --git a/evering/prompt.py b/evering/prompt.py new file mode 100644 index 0000000..e78cb89 --- /dev/null +++ b/evering/prompt.py @@ -0,0 +1,34 @@ +from typing import Optional + +__all__ = ["prompt_choice", "prompt_yes_no"] + +def prompt_choice(question: str, options: str) -> str: + default_option = None + for char in options: + if char.isupper(): + default_option = char + break + + option_string = "/".join(options) + + while True: + result = input(f"{question} [{option_string}] ").lower() + if not result and default_option: + return default_option + # The set() makes it so that we're only testing individual + # characters, not substrings. + elif result in set(options.lower()): + return result + else: + print(f"Invalid answer, please choose one of [{option_string}].") + +def prompt_yes_no(question: str, default_answer: Optional[bool]) -> bool: + if default_answer is None: + options = "yn" + elif default_answer: + options = "Yn" + else: + options = "yN" + + result = prompt_choice(question, options) + return result.lower() == "y" diff --git a/evering/util.py b/evering/util.py new file mode 100644 index 0000000..dbf2283 --- /dev/null +++ b/evering/util.py @@ -0,0 +1,88 @@ +import copy +import types +from pathlib import Path +from typing import Any, Dict + +__all__ = [ + "copy_local_variables", + "ExecuteException", "safer_exec", "safer_eval", + "ReadFileException", "read_file", + "WriteFileException", "write_file", + "CatastrophicError", "LessCatastrophicError", +] + +def copy_local_variables(local: Dict[str, Any]) -> Dict[str, Any]: + """ + Attempts to deep-copy a set of local variables, but keeping + modules at the top level alone, since they don't tend to deepcopy + well. + + May raise: Not sure at the moment + """ + + local_copy = {} + + for key, value in local.items(): + if isinstance(value, types.ModuleType): + local_copy[key] = value + else: + local_copy[key] = copy.deepcopy(value) + + return local_copy + +class ExecuteException(Exception): + pass + +def safer_exec(code: str, local_vars: Dict[str, Any]) -> None: + """ + May raise: ExecuteException + """ + + try: + exec(code, {}, local_vars) + except Exception as e: + raise ExecuteException(e) + +def safer_eval(code: str, local_vars: Dict[str, Any]) -> Any: + """ + May raise: ExecuteException + """ + + try: + return eval(code, {}, local_vars) + except Exception as e: + raise ExecuteException(e) + +class ReadFileException(Exception): + pass + +def read_file(path: Path) -> str: + """ + May raise: ReadFileException + """ + + try: + with open(path.expanduser()) as f: + return f.read() + except OSError as e: + raise ReadFileException(e) + +class WriteFileException(Exception): + pass + +def write_file(path: Path, text: str) -> None: + """ + May raise: WriteFileException + """ + + try: + with open(path.expanduser(), "w") as f: + f.write(text) + except OSError as e: + raise WriteFileException(e) + +class CatastrophicError(Exception): + pass + +class LessCatastrophicError(Exception): + pass