#!/usr/bin/python
"""
 * Copyright (c) 2008 Red Hat, Inc.
 *
 * This software is licensed to you under the GNU General Public License,
 * version 2 (GPLv2). There is NO WARRANTY for this software, express or
 * implied, including the implied warranties of MERCHANTABILITY or FITNESS
 * FOR A PARTICULAR PURPOSE. You should have received a copy of GPLv2
 * along with this software; if not, see
 * http://www.gnu.org/licenses/old-licenses/gpl-2.0.txt.
 * 
 * Red Hat trademarks are not licensed under GPLv2. No permission is
 * granted to use or replicate Red Hat trademarks that are incorporated
 * in this software or its documentation. 
"""

import getopt
import sys
import os
from pprint import pprint

# The satellite's server-side Python libraries live under /usr/share/rhn.
sys.path.append('/usr/share/rhn')
try:
   from server import rhnSQL
   from common import initCFG, CFG
except ImportError:
   # Only a missing library means "not a satellite"; any other error raised
   # while importing is a real problem and should surface with a traceback.
   print("Couldn't load needed libs, Are you sure you are running this on a satellite?")
   sys.exit(1)

initCFG()

# Database connection string taken from the satellite configuration
# (presumably "user/password@sid" form, e.g. "rhnsat/rhnsat@rhnsat" -- TODO confirm).
db_string = CFG.DEFAULT_DB


# Number of audited systems per release class; incremented by process_systems().
system_counts = {'other':0, 'server':0, 'client':0}

# Field separator for the system table ("" means fixed-width columns); set from --separator.
sep = ''
# Satellite major.minor version as a float; set in main() via getVersionInfo().
version = 0
# Report destination. Set this to None to have the report printed to stdout instead.
filename = "./audit.txt"
# Handle on `filename` while the report is being written; None means output goes to stdout.
out_file = None

def main():
   """Run the satellite audit and write the report.

   Command-line options:
     --separator=SEP   print the system table as SEP-delimited fields
                       instead of fixed-width columns
     --validate=FILE   show the checksum comparison for a saved report
                       (see validate()) and exit

   Report lines go through output_text(): into the module-level `filename`
   when it is set (secure_file() then appends its checksum), otherwise to
   stdout.
   """
   global sep
   global version
   global out_file

   # NOTE(review): "-h", "-o ARG" and "-v" are accepted here but never acted on.
   try:
      opts, args = getopt.getopt(sys.argv[1:], "ho:v", ["separator=", "validate="])
   except getopt.GetoptError:
      # Report a bad option cleanly instead of dying with a traceback.
      # sys.exc_info() keeps this valid on Pythons without "except ... as".
      print(str(sys.exc_info()[1]))
      print("usage: " + sys.argv[0] + " [--separator=SEP] [--validate=FILE]")
      sys.exit(2)
   sep = get_value(opts, "--separator")

   report_to_check = get_value(opts, "--validate")
   if report_to_check != "":
      validate(report_to_check)
      sys.exit(0)

   if filename is not None:
      out_file = open(filename, "w")
   try:
      rhnSQL.initDB(db_string)
      version = getVersionInfo()

      sys_list = getSystems()
      # process_systems() annotates sys_list in place and returns nothing.
      process_systems(sys_list)
      print_systems(sys_list)

      output_text("\n-------------Audit information-----------------------")
      try:
         import datetime
         output_text(datetime.datetime.now().strftime("The date is %A %m/%d/%Y at %H:%M"))
      except Exception:
         # Best effort: the timestamp is informational and must not abort the audit.
         output_text("Unable to find current date/time")
      output_text("Satellite version:  " + str(version))
      output_text("Org 1's name: " + get_cert_name())

      audit_virt(sys_list)
      audit_chan_ents(get_chan_ent_count())
   finally:
      # Close the report even when a query fails part-way through.
      if out_file is not None:
         out_file.close()
         out_file = None

   if filename is not None:
      print("Saved to " + filename)

   secure_file()


def get_cert_name():
   """Return the name of the organization whose web_customer id is 1.

   Assumes the row exists; an empty result makes the indexing below fail.
   """
   rows = run_query("select name from web_customer where id  = 1")
   first_row = rows[0]
   return first_row['name']

def audit_chan_ents(ents):
   """Write the entitlement-usage section of the report.

   ents -- dict with 'server_max', 'client_max', 'server_current' and
           'client_current' (as built by get_chan_ent_count()).

   Prints current/maximum usage, the per-class totals accumulated in the
   module-level system_counts, and an ALERT line for each class whose
   registered system count exceeds its maximum.
   """
   output_text("\nAudit system counts:")

   usage_lines = (
      ('server', " rhel-server entitlements being used."),
      ('client', " rhel-client entitlements being used."),
   )
   for kind, tail in usage_lines:
      output_text("There are " + str(ents[kind + '_current']) + " out of " + str(ents[kind + '_max']) + tail)

   total_lines = (
      ("Total Clients registered: ", 'client'),
      ("Total Servers registered: ", 'server'),
      ("Total Other systems registered: ", 'other'),
   )
   for prefix, kind in total_lines:
      output_text(prefix + str(system_counts[kind]))

   alerts = (
      ('server', "ALERT: There are more Server systems registered than there are entitlements for."),
      ('client', "ALERT: There are more Client systems registered than there are entitlements for."),
   )
   for kind, message in alerts:
      if system_counts[kind] > ents[kind + '_max']:
         output_text(message)

   output_text("System Count audit complete.")

def get_chan_ent_count():
   """Return entitlement totals for the rhel-server and rhel-client channel families.

   Returns a dict with keys 'server_max', 'client_max', 'server_current' and
   'client_current'. Each value is the SUM of max_members / current_members
   over the rhnChannelFamilyPermissions rows of that channel family.

   SQL SUM() over no matching rows yields NULL (None). Such totals are
   reported as 0 so callers can print and compare them numerically; in
   Python 2, "0 > None" is True and would otherwise raise false alerts.
   """
   query = "select sum(cfp.%s) from rhnChannelFamily cf inner join rhnChannelFamilyPermissions cfp on cf.id = cfp.channel_family_id and cf.label = "
   count = {}
   # Same query order as before: server_max, client_max, server_current, client_current.
   for column, suffix in (('max_members', '_max'), ('current_members', '_current')):
      for kind in ('server', 'client'):
         # Labels are fixed constants here, not user input.
         total = run_query_list((query % column) + "'rhel-" + kind + "'")[0][0]
         count[kind + suffix] = total or 0
   return count


def audit_virt(sys_list):
   """Flag virtual hosts with more than 4 guests but no platform entitlement.

   Each entry of sys_list must already carry 'virt_type', 'guests' and
   'sys_ents' (set by process_systems()). Nothing is checked on satellites
   older than 5.0 (module-level `version`).
   """
   output_text("\nVirt Audit:")
   if version < 5.0:
      output_text("This satellite is less than version 5.0, so no virt capabilities exist")
      return
   # Loop variable is not called "sys" so the sys module is not shadowed.
   for host in sys_list:
      if host['virt_type'] != "Host":
         continue
      if len(host['guests']) > 4 and 'virtualization_host_platform' not in host['sys_ents']:
         # Report only through output_text: the former debug pprint leaked the
         # entitlement list to stdout instead of the report.
         output_text("Alert: Virtual Host " + host['name'] + ' (' + str(host['id']) + ") has more than 4 guests, but does not have a 'virt host platform' entitlement")
   output_text("Virt Audit Completed")




def process_systems(sys_list):
   """Annotate every system dict in place and tally it into system_counts.

   For each system this sets:
     'guests'    -- guest rows from getVirtGuests(), or [] when there are
                    none or the satellite predates 5.0
     'virt_type' -- "Host" / "None" (5.0+) or "N/A" (older); each guest
                    row is marked "Guest"
     'sys_ents'  -- labels from getSystemEntitlements()
   It also increments system_counts['server' | 'client' | 'other'] based on
   the system's 'release' string. Returns None.
   """
   server_releases = ("5Server", "4AS", "4ES", "3AS", "3ES", "2.1AS", "2.1ES")
   client_releases = ("5Client", "4WS", "3WS", "4Desktop", "3Desktop", "2.1Desktop", "2.1WS")

   for machine in sys_list:
      # Virtualization data exists only on 5.0+ satellites.
      if version >= 5.0:
         guest_rows = getVirtGuests(machine)
         if guest_rows is None:
            machine['guests'] = []
            machine['virt_type'] = "None"
         else:
            machine['guests'] = guest_rows
            machine['virt_type'] = "Host"
            for guest in guest_rows:
               guest['virt_type'] = "Guest"
      else:
         machine['guests'] = []
         machine['virt_type'] = "N/A"

      # Classify by release string for the entitlement audit.
      release = machine['release']
      if release in server_releases:
         bucket = 'server'
      elif release in client_releases:
         bucket = 'client'
      else:
         bucket = 'other'
      system_counts[bucket] += 1

      machine['sys_ents'] = getSystemEntitlements(machine)
      


def getVirtGuests(sys):
   """Return the guest systems whose virtual host is *sys*.

   sys -- system dict; only its 'id' is used (embedded into the SQL text).

   Returns the run_query() result: rows with id, name, release, created and
   checkin. process_systems() treats a None result as "not a virtual host"
   (presumably what fetchall_dict yields for an empty result -- TODO confirm
   against rhnSQL).
   """
   # The WHERE on vi.host_system_id discards rows with no rhnVirtualInstance
   # match, so the LEFT JOIN effectively behaves as an inner join here.
   query = """select s.id, s.name, s.release, s.created, si.checkin
        from rhnServer s  left join
        rhnVirtualInstance vi on s.id = vi.virtual_system_id inner join
             rhnServerInfo si on s.id = si.server_id
	where vi.host_system_id = """ + str(sys['id'])
   return run_query(query)






def getVersionInfo():
   """Return the satellite's version as a major.minor float.

   Reads the package version string linked from rhnVersionInfo and keeps its
   first two dot-separated components (e.g. "5.3.0" -> 5.3).
   NOTE: the float form is lossy for two-digit minors ("5.10" -> 5.1); callers
   here only compare against 5.0.
   """
   rows = run_query_list("select evr.version from rhnPackageEVR evr inner join rhnVersionInfo info on evr.id = info.evr_id")
   pieces = rows[0][0].split('.')
   major_minor = "%s.%s" % (pieces[0], pieces[1])
   return float(major_minor)

def getSystemEntitlements(sys):
   """Return the entitlement labels (server group type labels) of one system.

   sys -- system dict; only its 'id' is used.
   Returns a list holding the first column of every result row.
   """
   sql = 'select sgt.label from rhnServerGroup sg inner join rhnServerGroupMembers sgm on sg.id = sgm.server_group_id inner join rhnServerGroupType sgt on sg.group_type = sgt.id where sgt.label is not null and sgm.server_id =' + str(sys['id'])
   return [row[0] for row in run_query_list(sql)]


# Systems to audit: on 5.0+ every system that is not a virtual guest; on older
# satellites (no virt tables are queried) every registered system.
def getSystems():
   """Return the top-level systems as run_query() rows.

   Each row has id, name, release, created and checkin. On 5.0+ satellites,
   the LEFT JOIN with "vi.virtual_system_id is null" keeps only systems that
   are not a virtual guest; guests are fetched per host by getVirtGuests().
   Depends on the module-level `version` having been set first.
   """
   if version >= 5.0:
      query = """select s.id, s.name, s.release, s.created, si.checkin
        from rhnServer s  left join
        rhnVirtualInstance vi on s.id = vi.virtual_system_id inner join
             rhnServerInfo si on s.id = si.server_id
	where vi.virtual_system_id is null"""
   else:
      query = """select s.id, s.name, s.release, s.created, si.checkin
        from rhnServer s  left join
        rhnServerInfo si on s.id = si.server_id"""
   return run_query(query)


def print_systems(sys_list):
   """Emit the system table: a header, then each system followed by its guests.

   With a non-empty `sep` the header is a single separator-joined line.
   Otherwise fixed-width column titles are followed by a dashed underline.
   """
   if sep == "":
      row_format = '%-15.15s %-35.35s %-10.10s %-15.25s %-25.25s %s'
      output_text(row_format % ('id', 'name', 'release', 'virt status', 'registration date', 'last checkin'))
      output_text(row_format % ('-' * 14, '-' * 35, '-' * 10, '-' * 15, '-' * 15, '-' * 17))
   else:
      output_text(sep.join(['id', 'name', 'release', 'virt status', 'registration date', 'last checkin']))

   for entry in sys_list:
      print_system(entry)
      # Guests are listed directly beneath their host.
      for guest in entry['guests']:
         print_system(guest)


def print_system(sys):
   """Emit one table row for a system or guest dict.

   Columns are id, name, release, virt_type, created and checkin. They are
   joined with `sep` when a separator was given, otherwise padded/truncated
   to the fixed widths used by the table header.
   """
   if sep == "":
      output_text('%-15.15s %-35.35s %-10.10s %-15.25s %-25.25s %s' % (sys['id'], sys['name'], sys['release'], sys['virt_type'], str(sys['created']), str(sys['checkin'])))
      return
   fields = [str(sys['id']), sys['name'], sys['release'], sys['virt_type'], str(sys['created']), str(sys['checkin'])]
   output_text(sep.join(fields))


def secure_file():
   """Append an md5sum line for the report file to the report itself.

   validate() later hashes every line except the last and prints that next
   to this stored line. When `filename` is None the report went to stdout,
   so there is no file to checksum and nothing is done.
   """
   if filename is None:
      # No report file exists; the former code printed an undefined name here.
      return
   os.system('md5sum ' + filename + '>> '  + filename)

def validate(filename):
   """Show the recomputed and stored checksums of a saved audit report.

   Prints the md5sum of every line of *filename* except the last, followed by
   the file's last line (the checksum line secure_file() appended). The two
   hashes should match if the report has not been modified. The external
   commands write straight to stdout.
   """
   import subprocess
   print("The following should match:")
   # Flush our buffered text so it appears before the child processes' output.
   sys.stdout.flush()
   # No shell: the --validate path is passed verbatim, so spaces or shell
   # metacharacters in it can neither break nor inject into the command.
   head = subprocess.Popen(['head', '-n', '-1', filename], stdout=subprocess.PIPE)
   subprocess.call(['md5sum', '-'], stdin=head.stdout)
   head.stdout.close()
   head.wait()
   subprocess.call(['tail', '-n', '1', filename])


def get_value(list, name):
   """Return the value of the first (option, value) pair whose option is *name*.

   list -- getopt-style sequence of (option, value) pairs.
   Returns "" when *name* does not occur.
   """
   found = ""
   for entry in list:
      if entry[0] != name:
         continue
      found = entry[1]
      break
   return found


def output_text(string):
   """Write one line of the report.

   The line goes to the open report file (module-level out_file) when there
   is one, otherwise to stdout.
   """
   if out_file is None:
      print(string)
      return
   out_file.write(string + "\n")


def run_query(query):
   """Execute *query* via rhnSQL and return rhnSQL's fetchall_dict() result.

   Rows are dicts keyed by column name. Callers elsewhere in this script
   check the result against None (see process_systems).
   """
   cursor = rhnSQL.prepare(query)
   cursor.execute()
   rows = cursor.fetchall_dict()
   return rows

def run_query_list(query):
   """Execute *query* via rhnSQL and return rhnSQL's fetchall() result.

   Rows are positional (indexed by column number), e.g. row[0].
   """
   cursor = rhnSQL.prepare(query)
   cursor.execute()
   rows = cursor.fetchall()
   return rows

# Run the audit only when executed as a script, not when imported.
if __name__ == "__main__":
   main()
