# Source code for molsystem.templateatoms

# -*- coding: utf-8 -*-

"""A dictionary-like object for holding atoms for the templates
"""

from itertools import zip_longest
import logging
from typing import Any, Dict, TypeVar

from molsystem.atoms import _Atoms as Atoms
from molsystem.column import _Column as Column
from molsystem.table import _Table as Table

# Logger shared by everything in this module.
logger = logging.getLogger(__name__)

# Type variables for annotations: a system object (or None), and an atoms
# object, a string, or None.
System_tp = TypeVar('System_tp', 'System', None)
Atoms_tp = TypeVar('Atoms_tp', '_Atoms', str, None)


def grouped(iterable, n, fillvalue=None):
    """Collect the items of an iterable into consecutive tuples of length n.

    s -> (s0,s1,s2,...sn-1), (sn,sn+1,sn+2,...s2n-1), (s2n,...s3n-1), ...

    Parameters
    ----------
    iterable : iterable
        The items to group.
    n : int
        The size of each group.
    fillvalue : any = None
        Used to pad the last tuple when the number of items is not a
        multiple of `n`.

    Returns
    -------
    iterator of tuple
        The groups, each of length `n`.
    """
    # Repeating the *same* iterator n times makes zip_longest pull n
    # consecutive items for each tuple it yields.
    iterator = iter(iterable)
    return zip_longest(*[iterator] * n, fillvalue=fillvalue)
class _Templateatoms(Atoms):
    """The atoms belonging to templates.

    Like the Atoms class, this holds arrays of attributes describing atoms.
    The data is split across two tables: one for the atoms themselves and one
    for their coordinates, linked by the 'templateatom' column of the
    coordinates table. Each atom row has a 'template' column, and the
    accessors restrict themselves to the current template unless told
    otherwise. Atoms can be added ('append') or removed ('remove').

    Templates only support Cartesian coordinates.
    """

    def __init__(
        self,
        system: System_tp,
        atom_tablename='templateatom',
        coordinates_tablename='templatecoordinates',
    ) -> None:
        # NOTE(review): the base-class __init__ is not called; the state the
        # methods below need is set up directly here.
        self._system = system
        self._atom_tablename = atom_tablename
        self._coordinates_tablename = coordinates_tablename

        self._atom_table = Table(system, self._atom_tablename)
        self._coordinates_table = Table(system, self._coordinates_tablename)
        self._templates = system['template']

    def __getitem__(self, key) -> Any:
        """Allow [] to access the data for the current template.

        Parameters
        ----------
        key : str
            The name of a column in either the atom or the coordinates table.

        Returns
        -------
        Column
            The column, restricted to atoms in the current template.

        Raises
        ------
        KeyError
            If `key` is not a column of either table.
        """
        if key in self._atom_table.attributes:
            sql = f'WHERE template = {self.current_template}'
            return Column(self._atom_table, key, where=sql)
        elif key in self._coordinates_table.attributes:
            # Coordinates point at their atom via 'templateatom', so select
            # the rows whose atom belongs to the current template.
            where = (
                "WHERE templateatom in ("
                f" SELECT id FROM {self._atom_tablename}"
                f" WHERE template = {self.current_template}"
                ")"
            )
            return Column(self._coordinates_table, key, where=where)
        else:
            raise KeyError(f"'{key}' not in template atoms")

    @property
    def current_template(self):
        """The current template in use."""
        return self._templates.current_template

    @current_template.setter
    def current_template(self, value):
        self._templates.current_template = value

    @property
    def n_atoms(self) -> int:
        """The number of atoms *in the current* template."""
        self.cursor.execute(
            f"SELECT COUNT(*) FROM {self._atom_tablename} WHERE template = ?",
            (self.current_template,)
        )
        return self.cursor.fetchone()[0]

    @property
    def attributes(self) -> Dict[str, Any]:
        """The definitions of the attributes.

        Combine the attributes of the atom and coordinates tables to make it
        look like a single larger table. The foreign key 'templateatom' of the
        coordinates table is left out.
        """
        # Copy first: adding the coordinate attributes must not modify the
        # dictionary belonging to the atom table, which __getitem__ and
        # append rely on to decide which table a column lives in.
        result = dict(self._atom_table.attributes)
        for key, value in self._coordinates_table.attributes.items():
            if key != 'templateatom':  # ignore foreign key linking tables
                result[key] = value
        return result

    @property
    def coordinate_system(self):
        """The type of coordinates: 'fractional' or 'Cartesian'"""
        return 'Cartesian'

    @coordinate_system.setter
    def coordinate_system(self, value):
        raise RuntimeError('Templates can only use Cartesian coordinates.')

    def append(self, n=None, **kwargs: Any) -> Any:
        """Append one or more atoms to a template.

        The keys give the field for the data. Keys that are columns of the
        atom table go to the atom table; keys that are columns of the
        coordinates table go to the coordinates table. Defaults for fields
        that are not mentioned are handled by the underlying tables.

        NOTE(review): keys matching neither table are not used here; confirm
        whether `_get_n_rows` or the tables reject them.

        Parameters
        ----------
        n : int = None
            The number of atoms to add. If given, the data must have either
            length 1 or length `n`.
        symbol : str or [str], optional
            Element symbols, converted to atomic numbers ('atno').
        template : int, optional
            The template for the atoms, defaulting to the current template.

        Returns
        -------
        ids
            The ids of the new atoms, as returned by the atom table's append.

        Raises
        ------
        RuntimeError
            If `n` is not compatible with the length of the data.
        """
        # Elements can be given as atomic numbers or symbols. By construction
        # the references to elements are identical to their atomic numbers.
        if 'symbol' in kwargs:
            symbols = kwargs.pop('symbol')
            kwargs['atno'] = self.to_atnos(symbols)

        # How many new rows there are
        n_rows, lengths = self._get_n_rows(**kwargs)

        if n is not None:
            if n_rows != 1 and n_rows != n:
                raise RuntimeError(
                    f"Requested number of template atoms ({n}) is not "
                    f"compatible with the length of the data ({n_rows})."
                )
            n_rows = n

        # Fill in the atom table
        data = {}
        for column in self._atom_table.attributes:
            if column != 'id' and column in kwargs:
                data[column] = kwargs.pop(column)
        if 'template' not in data:
            data['template'] = self.current_template

        ids = self._atom_table.append(n=n_rows, **data)

        # Now append to the coordinates table, but only if needed.
        data = {}
        for column in self._coordinates_table.attributes:
            if column != 'templateatom' and column in kwargs:
                data[column] = kwargs.pop(column)

        if len(data) > 0:
            data['templateatom'] = ids
            self._coordinates_table.append(n=n_rows, **data)

        return ids

    def atomic_numbers(self, template: int = None) -> [int]:
        """The atomic numbers of the atoms in a template.

        Parameters
        ----------
        template : int = None
            Which template, defaulting to the current template.

        Returns
        -------
        atnos : [int]
            The atomic numbers of the atoms in the template.
        """
        if template is None:
            template = self.current_template

        return [
            x[0] for x in self.db.execute(
                f"SELECT atno FROM {self._atom_tablename} WHERE template = ?",
                (template,)
            )
        ]

    def atom_ids(self, template: int = None) -> [int]:
        """The ids of the atoms in a template.

        Parameters
        ----------
        template : int = None
            Which template, defaulting to the current template.

        Returns
        -------
        ids : [int]
            The ids of the atoms in the template.
        """
        if template is None:
            template = self.current_template

        return [
            x[0] for x in self.db.execute(
                f"SELECT id FROM {self._atom_tablename} WHERE template = ?",
                (template,)
            )
        ]

    @staticmethod
    def _restriction_sql(args):
        """Turn (column, operator, value) triples into extra WHERE clauses.

        Parameters
        ----------
        args : sequence
            Flat sequence: column, operator, value, column, operator, ...
            The operator '==' is translated to SQL's '='.

        Returns
        -------
        sql : str
            Zero or more ' AND "column" op ?' clauses.
        values : list
            The values to bind to the placeholders, in order.
        """
        clauses = []
        values = []
        for col, op, value in grouped(args, 3):
            if op == '==':
                op = '='
            # Column names and operators are written into the SQL text;
            # only the values are bound as parameters.
            clauses.append(f' AND "{col}" {op} ?')
            values.append(value)
        return ''.join(clauses), values

    def atoms(self, *args, template=None):
        """Get an iterator over atoms in the template.

        Parameters
        ----------
        args : str, int or float
            SQL restrictions for a WHERE statement, each argument being one
            word, e.g. "atno" "=" 5
        template : int = None
            Which template, defaulting to the current template.

        Returns
        -------
        SQLite3.Cursor : The cursor containing the result.
        """
        atom_tbl = self._atom_tablename
        coord_tbl = self._coordinates_tablename

        atom_columns = [*self._atom_table.attributes]
        coord_columns = [*self._coordinates_table.attributes]
        # 'templateatom' equals the atom id in the joined rows, so skip it.
        coord_columns.remove('templateatom')

        columns = [f'{atom_tbl}.{x}' for x in atom_columns]
        columns += [f'{coord_tbl}.{x}' for x in coord_columns]
        column_defs = ', '.join(columns)

        if template is None:
            template = self.current_template

        sql = (
            f'SELECT {column_defs} FROM {atom_tbl}, {coord_tbl}'
            f' WHERE {coord_tbl}.templateatom = {atom_tbl}.id'
            f' AND {atom_tbl}.template = ?'
        )

        restrictions, values = self._restriction_sql(args)
        return self.db.execute(sql + restrictions, [template, *values])

    def remove(self, *args, template=None):
        """Remove atoms in the template.

        Only columns of the atom table can be used in the restrictions, since
        the DELETE acts on that table alone.

        NOTE(review): rows in the coordinates table are not deleted here;
        presumably they go via a foreign-key cascade. Confirm.

        Parameters
        ----------
        args : str, int or float
            SQL restrictions for a WHERE statement, each argument being one
            word, e.g. "atno" "=" 5
        template : int = None
            Which template, defaulting to the current template.

        Returns
        -------
        None
        """
        if template is None:
            template = self.current_template

        # Delete from the atom table by name, like every other query in this
        # class. The original used `self.table`, which this class never
        # defines.
        sql = f'DELETE FROM {self._atom_tablename} WHERE template = ?'
        restrictions, values = self._restriction_sql(args)

        self.db.execute(sql + restrictions, [template, *values])
        self.db.commit()