# coding: utf-8

# ---------------------------------------------------------------------------------
# MW-Linux panel
# ---------------------------------------------------------------------------------
# copyright (c) 2018-∞(https://github.com/midoks/mdserver-web) All rights reserved.
# ---------------------------------------------------------------------------------
# Author: midoks
# ---------------------------------------------------------------------------------

# ---------------------------------------------------------------------------------
# sqlite3 operations
# ---------------------------------------------------------------------------------

import os
import sys
import sqlite3

import core.mw as mw


def getPanelDir():
    """Return the parent directory of the current working directory.

    Used as the panel root: database files are resolved under ``<root>/data``.
    """
    return os.path.dirname(os.getcwd())


def getTracebackInfo():
    """Return the formatted traceback of the exception currently being handled."""
    import traceback
    return traceback.format_exc()


class Sql():
    """Small chainable query builder / executor for sqlite3.

    Example::

        Sql().table('users').where('id=?', (1,)).field('id,name').find()

    The chainable setters (``where``, ``andWhere``, ``field``, ``group``,
    ``order``, ``limit``) accumulate query-builder state. The terminal methods
    ``select`` (and therefore ``find``/``getField``/``count``), ``add``,
    ``save``, ``delete`` and ``commit`` reset that state afterwards.
    ``inquiry`` deliberately keeps it.

    Most methods report failures by returning an ``"error: ..."`` string
    instead of raising, so callers must check the return value.
    """

    __DB_FILE = None    # path of the database file
    __DB_CONN = None    # sqlite3 connection object, opened lazily by __getConn
    __DB_TABLE = ""     # name of the table being operated on
    __OPT_WHERE = ""    # WHERE clause (including the leading " WHERE ")
    __OPT_LIMIT = ""    # LIMIT clause
    __OPT_GROUP = ""    # GROUP BY clause
    __OPT_ORDER = ""    # ORDER BY clause
    __OPT_FIELD = "*"   # selected column list
    __OPT_PARAM = ()    # values bound to the WHERE placeholders
    __debug = False     # when True, save() prints its SQL and params

    def __init__(self):
        self.__DB_FILE = getPanelDir() + '/data/panel.db'

    def __getConn(self):
        """Open the connection on first use.

        Returns None on success, or an ``"error: ..."`` string if connecting
        fails (the traceback is printed).
        """
        try:
            if self.__DB_CONN is None:
                self.__DB_CONN = sqlite3.connect(self.__DB_FILE)
                self.__DB_CONN.text_factory = str
        except Exception as ex:
            print(getTracebackInfo())
            return "error: " + str(ex)

    def changeTextFactoryToBytes(self):
        """Make the already-open connection return TEXT columns as ``bytes``."""
        self.__DB_CONN.text_factory = bytes
        return True

    def debug(self, debug=False):
        """Enable/disable printing of the SQL executed by ``save``."""
        self.__debug = debug
        return self

    def autoTextFactory(self):
        """Decode TEXT columns as UTF-8 on the open connection, ignoring bad bytes."""
        if sys.version_info[0] == 3:
            self.__DB_CONN.text_factory = lambda x: str(x, encoding="utf-8", errors='ignore')
        else:
            self.__DB_CONN.text_factory = lambda x: unicode(x, "utf-8", "ignore")

    def dbfile(self, name):
        """Use ``<panel root>/data/<name>.db`` as the database file."""
        self.__DB_FILE = getPanelDir() + '/data/' + name + '.db'
        return self

    def dbPos(self, path, name, suffix_name='db'):
        """Use ``<path>/<name>.<suffix_name>`` as the database file."""
        self.__DB_FILE = path + '/' + name + '.' + suffix_name
        return self

    def table(self, table):
        """Set the table name used by subsequent queries."""
        self.__DB_TABLE = table
        return self

    def where(self, where, param=()):
        """Set the WHERE condition (with ``?`` placeholders) and its values."""
        if where:
            self.__OPT_WHERE = " WHERE " + where
        self.__OPT_PARAM = param
        return self

    def andWhere(self, where, param):
        """AND an extra condition onto the WHERE clause and append its values."""
        if where:
            if self.__OPT_WHERE:
                self.__OPT_WHERE = self.__OPT_WHERE + " and " + where
            else:
                # No prior where(): start the clause rather than emitting a
                # bare " and ..." that would be invalid SQL.
                self.__OPT_WHERE = " WHERE " + where
        self.__OPT_PARAM = self.__OPT_PARAM + param
        return self

    def order(self, order):
        """Set the ORDER BY expression; an empty value clears it."""
        if order:
            self.__OPT_ORDER = " ORDER BY " + order
        else:
            self.__OPT_ORDER = ""
        return self

    def group(self, group):
        """Set the GROUP BY expression; an empty value clears it."""
        if group:
            self.__OPT_GROUP = " GROUP BY " + group
        else:
            self.__OPT_GROUP = ""
        return self

    def limit(self, limit):
        """Set the LIMIT expression (e.g. ``"10"`` or ``"0,10"``); empty clears it."""
        if limit:
            self.__OPT_LIMIT = " LIMIT " + limit
        else:
            self.__OPT_LIMIT = ""
        return self

    def field(self, field):
        """Set the comma-separated column list; an empty value leaves it unchanged."""
        if field:
            self.__OPT_FIELD = field
        return self

    def getDbField(self, name):
        """Return the column names of table ``name``, in table order."""
        self.__getConn()
        sql = "PRAGMA table_info(%s)" % name
        data = self.__DB_CONN.execute(sql).fetchall()
        # Row layout of PRAGMA table_info is (cid, name, type, ...).
        return [col[1] for col in data]

    def getDbFieldString(self, name):
        """Return the column names of table ``name`` joined with commas."""
        return ','.join(self.getDbField(name))

    @staticmethod
    def __rows_to_dicts(fields, rows):
        """Map each row tuple to a dict keyed positionally by ``fields``."""
        # Index explicitly (not zip) so a field/column count mismatch still
        # raises IndexError instead of silently truncating.
        return [{key: row[i] for i, key in enumerate(fields)} for row in rows]

    def select(self):
        """Run the built SELECT and return the rows as a list of dicts.

        Dict keys are the comma-split FIELD list, or the table's column names
        when FIELD is ``*``. The query-builder state is always reset afterwards,
        whether rows were found, none matched, or the query failed.

        Returns:
            list of dicts (empty list when nothing matched), or an
            ``"error: ..."`` string on failure.
        """
        self.__getConn()
        try:
            sql = "SELECT " + self.__OPT_FIELD + " FROM " + self.__DB_TABLE + \
                self.__OPT_WHERE + self.__OPT_GROUP + self.__OPT_ORDER + self.__OPT_LIMIT
            data = self.__DB_CONN.execute(sql, self.__OPT_PARAM).fetchall()
            if len(data) == 0:
                return data

            if self.__OPT_FIELD != "*":
                fields = self.__OPT_FIELD.split(',')
            else:
                fields = self.getDbField(self.__DB_TABLE)
            return self.__rows_to_dicts(fields, data)
        except Exception as ex:
            return "error: " + str(ex)
        finally:
            # Reset on every path so stale conditions never leak into the
            # next query made with this object.
            self.__close()

    def inquiry(self, input_field=''):
        """Run the built SELECT WITHOUT resetting the query-builder state.

        If ``<panel root>/data/debug.pl`` exists, the SQL and params are printed.

        Args:
            input_field: optional comma-separated keys to use for the result
                dicts in place of the FIELD list (only when FIELD is not ``*``).

        Returns:
            With an explicit FIELD: list of dicts. With FIELD ``*``: list of
            lists, one per row. On failure: an ``"error: ..."`` string.
        """
        self.__getConn()
        try:
            sql = "SELECT " + self.__OPT_FIELD + " FROM " + self.__DB_TABLE + \
                self.__OPT_WHERE + self.__OPT_GROUP + self.__OPT_ORDER + self.__OPT_LIMIT
            debug = getPanelDir() + '/data/debug.pl'
            if os.path.exists(debug):
                print(sql, self.__OPT_PARAM)
            data = self.__DB_CONN.execute(sql, self.__OPT_PARAM).fetchall()

            if self.__OPT_FIELD != "*":
                if input_field != "":
                    fields = input_field.split(',')
                else:
                    fields = self.__OPT_FIELD.split(',')
                return self.__rows_to_dicts(fields, data)

            # Build a real list: on Python 3, map() would hand back a one-shot
            # iterator that breaks len()/indexing in callers.
            return [list(row) for row in data]
        except Exception as ex:
            return "error: " + str(ex)

    def getField(self, keyName):
        """Return column ``keyName`` of the single matching row, else None."""
        result = self.field(keyName).select()
        if len(result) == 1:
            return result[0][keyName]
        return None

    def setField(self, keyName, keyValue):
        """Update one column on the rows matched by the current WHERE."""
        return self.save(keyName, (keyValue,))

    def find(self):
        """Return the first matching row as a dict, or None."""
        result = self.limit("1").select()
        if len(result) == 1:
            return result[0]
        return None

    def count(self):
        """Return the number of matching rows, or 0 if the query fails."""
        key = "COUNT(*)"
        data = self.field(key).select()
        try:
            return int(data[0][key])
        except (IndexError, KeyError, TypeError, ValueError):
            # Empty result, or select() returned an "error: ..." string.
            return 0

    def add(self, keys, param):
        """INSERT one row and commit.

        Args:
            keys: comma-separated column names.
            param: values in the same order as ``keys``.

        Returns:
            The new row's ``lastrowid``, or 0 on failure (the error is printed).
        """
        self.__getConn()
        try:
            values = ",".join(["?"] * len(keys.split(',')))
            sql = "INSERT INTO " + self.__DB_TABLE + \
                "(" + keys + ") " + "VALUES(" + values + ")"
            result = self.__DB_CONN.execute(sql, param)
            last_id = result.lastrowid
            self.__DB_CONN.commit()
            return last_id
        except Exception as ex:
            print(str(ex))
            return 0
        finally:
            self.__close()

    def insert(self, pdata):
        """INSERT a row given as a ``{column: value}`` dict; False if empty."""
        if not pdata:
            return False
        keys, param = self.__format_pdata(pdata)
        return self.add(keys, param)

    def update(self, pdata):
        """UPDATE matched rows from a ``{column: value}`` dict; False if empty."""
        if not pdata:
            return False
        keys, param = self.__format_pdata(pdata)
        return self.save(keys, param)

    def __format_pdata(self, pdata):
        """Split a dict into a comma-joined key string and a matching value tuple."""
        keys = list(pdata.keys())
        return ','.join(keys), tuple(pdata[k] for k in keys)

    def checkInput(self, data):
        """Replace selected special characters in a ``str``; other values pass through."""
        if not data:
            return data
        if not isinstance(data, str):
            return data
        # NOTE(review): several entries below map a character to itself, so
        # they are no-ops as written. The replacements were probably full-width
        # look-alikes lost in an encoding round-trip — confirm intended values.
        checkList = [
            {'d': '<', 'r': '<'},
            {'d': '>', 'r': '>'},
            {'d': '\'', 'r': '‘'},
            {'d': '"', 'r': '“'},
            {'d': '&', 'r': '&'},
            {'d': '#', 'r': '#'},
            {'d': '<', 'r': '<'}
        ]
        for v in checkList:
            data = data.replace(v['d'], v['r'])
        return data

    def addAll(self, keys, param):
        """Queue one INSERT without committing; call ``commit()`` afterwards.

        Returns True on success or an ``"error: ..."`` string.
        """
        self.__getConn()
        try:
            values = ",".join(["?"] * len(keys.split(',')))
            sql = "INSERT INTO " + self.__DB_TABLE + \
                "(" + keys + ") " + "VALUES(" + values + ")"
            self.__DB_CONN.execute(sql, param)
            return True
        except Exception as ex:
            return "error: " + str(ex)

    def commit(self):
        """Reset query-builder state and commit the open transaction."""
        self.__close()
        self.__DB_CONN.commit()

    def save(self, keys, param):
        """UPDATE the rows matched by the current WHERE and commit.

        Args:
            keys: comma-separated column names to set.
            param: new values, in the same order as ``keys``.

        Returns:
            The number of updated rows, or an ``"error: ..."`` string.
        """
        self.__getConn()
        try:
            opt = ",".join(key + "=?" for key in keys.split(','))
            sql = "UPDATE " + self.__DB_TABLE + " SET " + opt + self.__OPT_WHERE
            if self.__debug:
                print(sql, param)
            # SET values first, then WHERE values, matching placeholder order.
            # Built locally so a failure cannot corrupt the stored WHERE params.
            all_param = tuple(param) + tuple(self.__OPT_PARAM)
            result = self.__DB_CONN.execute(sql, all_param)
            self.__DB_CONN.commit()
            return result.rowcount
        except Exception as ex:
            return "error: " + str(ex)
        finally:
            self.__close()

    def delete(self, id=None):
        """DELETE matching rows and commit.

        When ``id`` is truthy it replaces any WHERE with ``id=?``. With no
        ``id`` and no WHERE, every row of the table is deleted.

        Returns:
            The number of deleted rows, or an ``"error: ..."`` string.
        """
        self.__getConn()
        try:
            if id:
                self.__OPT_WHERE = " WHERE id=?"
                self.__OPT_PARAM = (id,)
            sql = "DELETE FROM " + self.__DB_TABLE + self.__OPT_WHERE
            result = self.__DB_CONN.execute(sql, self.__OPT_PARAM)
            self.__DB_CONN.commit()
            return result.rowcount
        except Exception as ex:
            return "error: " + str(ex)
        finally:
            self.__close()

    def originExecute(self, sql, param=()):
        """Execute raw SQL, commit, and return the cursor (or an error string)."""
        self.__getConn()
        try:
            result = self.__DB_CONN.execute(sql, param)
            self.__DB_CONN.commit()
            return result
        except Exception as ex:
            return "error: " + str(ex)

    def execute(self, sql, param=()):
        """Execute raw SQL, commit, and return the affected row count (or an error string)."""
        self.__getConn()
        try:
            result = self.__DB_CONN.execute(sql, param)
            self.__DB_CONN.commit()
            return result.rowcount
        except Exception as ex:
            return "error: " + str(ex)

    def query(self, sql, param=()):
        """Execute raw SQL without committing and return the cursor (or an error string)."""
        self.__getConn()
        try:
            return self.__DB_CONN.execute(sql, param)
        except Exception as ex:
            return "error: " + str(ex)

    def create(self, name):
        """Run the SQL script ``data/<name>.sql`` (relative to the CWD) and commit."""
        self.__getConn()
        # Uses the module-level ``core.mw`` import; a local ``import mw``
        # would shadow it and may not resolve under the package layout.
        script = mw.readFile('data/' + name + '.sql')
        result = self.__DB_CONN.executescript(script)
        self.__DB_CONN.commit()
        return result.rowcount

    def fofile(self, filename):
        """Run the SQL script in ``filename`` and commit."""
        self.__getConn()
        script = mw.readFile(filename)
        result = self.__DB_CONN.executescript(script)
        self.__DB_CONN.commit()
        return result.rowcount

    def __close(self):
        """Reset all query-builder state (the table name is kept)."""
        self.__OPT_WHERE = ""
        self.__OPT_FIELD = "*"
        self.__OPT_GROUP = ""
        self.__OPT_ORDER = ""
        self.__OPT_LIMIT = ""
        self.__OPT_PARAM = ()

    def close(self):
        """Close the connection; best-effort, errors are ignored."""
        try:
            self.__DB_CONN.close()
            self.__DB_CONN = None
        except Exception:
            pass